1
0
mirror of https://github.com/spl0k/supysonic.git synced 2024-11-09 19:52:16 +00:00

Merge branch 'pony'

This commit is contained in:
spl0k 2018-01-04 21:36:56 +01:00
commit 4b446f7121
48 changed files with 1865 additions and 1751 deletions

2
.gitignore vendored
View File

@ -66,3 +66,5 @@ Session.vim
.netrwhist .netrwhist
*~ *~
*.orig

View File

@ -51,32 +51,27 @@ You'll need these to run Supysonic:
* Python 2.7 * Python 2.7
* [Flask](http://flask.pocoo.org/) >= 0.9 * [Flask](http://flask.pocoo.org/) >= 0.9
* [Storm](https://storm.canonical.com/) * [PonyORM](https://ponyorm.com/)
* [Python Imaging Library](https://github.com/python-pillow/Pillow) * [Python Imaging Library](https://github.com/python-pillow/Pillow)
* [simplejson](https://simplejson.readthedocs.io/en/latest/) * [simplejson](https://simplejson.readthedocs.io/en/latest/)
* [requests](http://docs.python-requests.org/) * [requests](http://docs.python-requests.org/)
* [mutagen](https://mutagen.readthedocs.io/en/latest/) * [mutagen](https://mutagen.readthedocs.io/en/latest/)
* [watchdog](https://github.com/gorakhargosh/watchdog) * [watchdog](https://github.com/gorakhargosh/watchdog)
On a Debian-like OS (Debian, Ubuntu, Linux Mint, etc.), you can install them You can install all of them using `pip`:
this way:
$ apt-get install python-flask python-storm python-imaging python-simplesjon python-requests python-mutagen python-watchdog $ pip install -r requirements.txt
You may also need a database specific package: You may also need a database specific package:
* MySQL: `apt install python-mysqldb` * MySQL: `pip install pymysql` or `pip install mysqlclient`
* PostgreSQL: `apt-install python-psycopg2` * PostgreSQL: `pip install psycopg2`
Due to a bug in `storm`, `psycopg2` version 2.5 and later does not work
properly. You can either use version 2.4 or [patch storm][storm] yourself.
[storm]: https://bugs.launchpad.net/storm/+bug/1170063
### Configuration ### Configuration
Supysonic looks for two files for its configuration: `/etc/supysonic` and Supysonic looks for four files for its configuration: `/etc/supysonic`,
`~/.supysonic`, merging values from the two files. `~/.supysonic`, `~/.config/supysonic/supysonic.conf` and `supysonic.conf` in
the current folder, merging values from all files.
Configuration files must respect a structure similar to Windows INI file, with Configuration files must respect a structure similar to Windows INI file, with
`[section]` headers and using a `KEY = VALUE` or `KEY: VALUE` syntax. `[section]` headers and using a `KEY = VALUE` or `KEY: VALUE` syntax.
@ -85,7 +80,7 @@ The sample configuration (`config.sample`) looks like this:
```ini ```ini
[base] [base]
; A Storm database URI. See the 'schema' folder for schema creation scripts ; A database URI. See the 'schema' folder for schema creation scripts
; Default: sqlite:///tmp/supysonic/supysonic.db ; Default: sqlite:///tmp/supysonic/supysonic.db
;database_uri = sqlite:////var/supysonic/supysonic.db ;database_uri = sqlite:////var/supysonic/supysonic.db
;database_uri = mysql://supysonic:supysonic@localhost/supysonic ;database_uri = mysql://supysonic:supysonic@localhost/supysonic
@ -389,3 +384,7 @@ the case migration scripts will be provided in the `schema/migration`
folder, prefixed by the date of commit that introduced the changes. Those folder, prefixed by the date of commit that introduced the changes. Those
scripts shouldn't be used when initializing a new database, only when scripts shouldn't be used when initializing a new database, only when
upgrading from a previous schema. upgrading from a previous schema.
There could be both SQL scripts or Python scripts. The Python scripts require
arguments that are explained when the script is invoked with the `-h` flag.
If a migration script isn't provided for a specific database engine, it simply
means that no migration is needed for this engine.

View File

@ -12,9 +12,11 @@ import sys
from supysonic.cli import SupysonicCLI from supysonic.cli import SupysonicCLI
from supysonic.config import IniConfig from supysonic.config import IniConfig
from supysonic.db import init_database, release_database
if __name__ == "__main__": if __name__ == "__main__":
config = IniConfig.from_common_locations() config = IniConfig.from_common_locations()
init_database(config.BASE['database_uri'])
cli = SupysonicCLI(config) cli = SupysonicCLI(config)
if len(sys.argv) > 1: if len(sys.argv) > 1:
@ -22,3 +24,5 @@ if __name__ == "__main__":
else: else:
cli.cmdloop() cli.cmdloop()
release_database()

View File

@ -1,5 +1,5 @@
[base] [base]
; A Storm database URI. See the 'schema' folder for schema creation scripts ; A database URI. See the 'schema' folder for schema creation scripts
; Default: sqlite:///tmp/supysonic/supysonic.db ; Default: sqlite:///tmp/supysonic/supysonic.db
;database_uri = sqlite:////var/supysonic/supysonic.db ;database_uri = sqlite:////var/supysonic/supysonic.db
;database_uri = mysql://supysonic:supysonic@localhost/supysonic ;database_uri = mysql://supysonic:supysonic@localhost/supysonic

View File

@ -1,5 +1,5 @@
flask>=0.9 flask>=0.9
storm pony
Pillow Pillow
simplejson simplejson
requests>=1.0.0 requests>=1.0.0

View File

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API.
#
# Copyright (C) 2017 Alban 'spl0k' Féron
#
# Distributed under terms of the GNU AGPLv3 license.
# Converts ids from hex-encoded strings to binary data
import argparse
try:
import MySQLdb as provider
except ImportError:
import pymysql as provider
from uuid import UUID
from warnings import filterwarnings
parser = argparse.ArgumentParser()
parser.add_argument('username')
parser.add_argument('password')
parser.add_argument('database')
parser.add_argument('-H', '--host', default = 'localhost', help = 'default: localhost')
args = parser.parse_args()
def process_table(connection, table, fields, nullable_fields = ()):
to_update = { field: set() for field in fields + nullable_fields }
c = connection.cursor()
c.execute('SELECT {1} FROM {0}'.format(table, ','.join(fields + nullable_fields)))
for row in c:
for field, value in zip(fields + nullable_fields, row):
if value is None or not isinstance(value, basestring):
continue
to_update[field].add(value)
for field, values in to_update.iteritems():
if not values:
continue
sql = 'UPDATE {0} SET {1}=%s WHERE {1}=%s'.format(table, field)
c.executemany(sql, map(lambda v: (UUID(v).bytes, v), values))
for field in fields:
sql = 'ALTER TABLE {0} MODIFY {1} BINARY(16) NOT NULL'.format(table, field)
c.execute(sql)
for field in nullable_fields:
sql = 'ALTER TABLE {0} MODIFY {1} BINARY(16)'.format(table, field)
c.execute(sql)
connection.commit()
filterwarnings('ignore', category = provider.Warning)
conn = provider.connect(host = args.host, user = args.username, passwd = args.password, db = args.database)
conn.cursor().execute('SET FOREIGN_KEY_CHECKS = 0')
process_table(conn, 'folder', ('id',), ('parent_id',))
process_table(conn, 'artist', ('id',))
process_table(conn, 'album', ('id', 'artist_id'))
process_table(conn, 'track', ('id', 'album_id', 'artist_id', 'root_folder_id', 'folder_id'))
process_table(conn, 'user', ('id',), ('last_play_id',))
process_table(conn, 'client_prefs', ('user_id',))
process_table(conn, 'starred_folder', ('user_id', 'starred_id'))
process_table(conn, 'starred_artist', ('user_id', 'starred_id'))
process_table(conn, 'starred_album', ('user_id', 'starred_id'))
process_table(conn, 'starred_track', ('user_id', 'starred_id'))
process_table(conn, 'rating_folder', ('user_id', 'rated_id'))
process_table(conn, 'rating_track', ('user_id', 'rated_id'))
process_table(conn, 'chat_message', ('id', 'user_id'))
process_table(conn, 'playlist', ('id', 'user_id'))
conn.cursor().execute('SET FOREIGN_KEY_CHECKS = 1')
conn.close()

View File

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API.
#
# Copyright (C) 2017 Alban 'spl0k' Féron
#
# Distributed under terms of the GNU AGPLv3 license.
# Converts ids from hex-encoded strings to binary data
import argparse
import sqlite3
from uuid import UUID
parser = argparse.ArgumentParser()
parser.add_argument('dbfile', help = 'Path to the SQLite database file')
args = parser.parse_args()
def process_table(connection, table, fields):
to_update = { field: set() for field in fields }
c = connection.cursor()
for row in c.execute('SELECT {1} FROM {0}'.format(table, ','.join(fields))):
for field, value in zip(fields, row):
if value is None or not isinstance(value, basestring):
continue
to_update[field].add(value)
for field, values in to_update.iteritems():
sql = 'UPDATE {0} SET {1}=? WHERE {1}=?'.format(table, field)
c.executemany(sql, map(lambda v: (buffer(UUID(v).bytes), v), values))
connection.commit()
with sqlite3.connect(args.dbfile) as conn:
conn.cursor().execute('PRAGMA foreign_keys = OFF')
process_table(conn, 'folder', ('id', 'parent_id'))
process_table(conn, 'artist', ('id',))
process_table(conn, 'album', ('id', 'artist_id'))
process_table(conn, 'track', ('id', 'album_id', 'artist_id', 'root_folder_id', 'folder_id'))
process_table(conn, 'user', ('id', 'last_play_id'))
process_table(conn, 'client_prefs', ('user_id',))
process_table(conn, 'starred_folder', ('user_id', 'starred_id'))
process_table(conn, 'starred_artist', ('user_id', 'starred_id'))
process_table(conn, 'starred_album', ('user_id', 'starred_id'))
process_table(conn, 'starred_track', ('user_id', 'starred_id'))
process_table(conn, 'rating_folder', ('user_id', 'rated_id'))
process_table(conn, 'rating_track', ('user_id', 'rated_id'))
process_table(conn, 'chat_message', ('id', 'user_id'))
process_table(conn, 'playlist', ('id', 'user_id'))

View File

@ -1,35 +1,35 @@
CREATE TABLE folder ( CREATE TABLE folder (
id CHAR(36) PRIMARY KEY, id BINARY(16) PRIMARY KEY,
root BOOLEAN NOT NULL, root BOOLEAN NOT NULL,
name VARCHAR(256) NOT NULL, name VARCHAR(256) NOT NULL,
path VARCHAR(4096) NOT NULL, path VARCHAR(4096) NOT NULL,
created DATETIME NOT NULL, created DATETIME NOT NULL,
has_cover_art BOOLEAN NOT NULL, has_cover_art BOOLEAN NOT NULL,
last_scan INTEGER NOT NULL, last_scan INTEGER NOT NULL,
parent_id CHAR(36) REFERENCES folder parent_id BINARY(16) REFERENCES folder
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE artist ( CREATE TABLE artist (
id CHAR(36) PRIMARY KEY, id BINARY(16) PRIMARY KEY,
name VARCHAR(256) NOT NULL name VARCHAR(256) NOT NULL
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE album ( CREATE TABLE album (
id CHAR(36) PRIMARY KEY, id BINARY(16) PRIMARY KEY,
name VARCHAR(256) NOT NULL, name VARCHAR(256) NOT NULL,
artist_id CHAR(36) NOT NULL REFERENCES artist artist_id BINARY(16) NOT NULL REFERENCES artist
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE track ( CREATE TABLE track (
id CHAR(36) PRIMARY KEY, id BINARY(16) PRIMARY KEY,
disc INTEGER NOT NULL, disc INTEGER NOT NULL,
number INTEGER NOT NULL, number INTEGER NOT NULL,
title VARCHAR(256) NOT NULL, title VARCHAR(256) NOT NULL,
year INTEGER, year INTEGER,
genre VARCHAR(256), genre VARCHAR(256),
duration INTEGER NOT NULL, duration INTEGER NOT NULL,
album_id CHAR(36) NOT NULL REFERENCES album, album_id BINARY(16) NOT NULL REFERENCES album,
artist_id CHAR(36) NOT NULL REFERENCES artist, artist_id BINARY(16) NOT NULL REFERENCES artist,
bitrate INTEGER NOT NULL, bitrate INTEGER NOT NULL,
path VARCHAR(4096) NOT NULL, path VARCHAR(4096) NOT NULL,
content_type VARCHAR(32) NOT NULL, content_type VARCHAR(32) NOT NULL,
@ -37,12 +37,12 @@ CREATE TABLE track (
last_modification INTEGER NOT NULL, last_modification INTEGER NOT NULL,
play_count INTEGER NOT NULL, play_count INTEGER NOT NULL,
last_play DATETIME, last_play DATETIME,
root_folder_id CHAR(36) NOT NULL REFERENCES folder, root_folder_id BINARY(16) NOT NULL REFERENCES folder,
folder_id CHAR(36) NOT NULL REFERENCES folder folder_id BINARY(16) NOT NULL REFERENCES folder
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE user ( CREATE TABLE user (
id CHAR(36) PRIMARY KEY, id BINARY(16) PRIMARY KEY,
name VARCHAR(64) NOT NULL, name VARCHAR(64) NOT NULL,
mail VARCHAR(256), mail VARCHAR(256),
password CHAR(40) NOT NULL, password CHAR(40) NOT NULL,
@ -50,12 +50,12 @@ CREATE TABLE user (
admin BOOLEAN NOT NULL, admin BOOLEAN NOT NULL,
lastfm_session CHAR(32), lastfm_session CHAR(32),
lastfm_status BOOLEAN NOT NULL, lastfm_status BOOLEAN NOT NULL,
last_play_id CHAR(36) REFERENCES track, last_play_id BINARY(16) REFERENCES track,
last_play_date DATETIME last_play_date DATETIME
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE client_prefs ( CREATE TABLE client_prefs (
user_id CHAR(36) NOT NULL, user_id BINARY(16) NOT NULL,
client_name VARCHAR(32) NOT NULL, client_name VARCHAR(32) NOT NULL,
format VARCHAR(8), format VARCHAR(8),
bitrate INTEGER, bitrate INTEGER,
@ -63,57 +63,57 @@ CREATE TABLE client_prefs (
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE starred_folder ( CREATE TABLE starred_folder (
user_id CHAR(36) NOT NULL REFERENCES user, user_id BINARY(16) NOT NULL REFERENCES user,
starred_id CHAR(36) NOT NULL REFERENCES folder, starred_id BINARY(16) NOT NULL REFERENCES folder,
date DATETIME NOT NULL, date DATETIME NOT NULL,
PRIMARY KEY (user_id, starred_id) PRIMARY KEY (user_id, starred_id)
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE starred_artist ( CREATE TABLE starred_artist (
user_id CHAR(36) NOT NULL REFERENCES user, user_id BINARY(16) NOT NULL REFERENCES user,
starred_id CHAR(36) NOT NULL REFERENCES artist, starred_id BINARY(16) NOT NULL REFERENCES artist,
date DATETIME NOT NULL, date DATETIME NOT NULL,
PRIMARY KEY (user_id, starred_id) PRIMARY KEY (user_id, starred_id)
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE starred_album ( CREATE TABLE starred_album (
user_id CHAR(36) NOT NULL REFERENCES user, user_id BINARY(16) NOT NULL REFERENCES user,
starred_id CHAR(36) NOT NULL REFERENCES album, starred_id BINARY(16) NOT NULL REFERENCES album,
date DATETIME NOT NULL, date DATETIME NOT NULL,
PRIMARY KEY (user_id, starred_id) PRIMARY KEY (user_id, starred_id)
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE starred_track ( CREATE TABLE starred_track (
user_id CHAR(36) NOT NULL REFERENCES user, user_id BINARY(16) NOT NULL REFERENCES user,
starred_id CHAR(36) NOT NULL REFERENCES track, starred_id BINARY(16) NOT NULL REFERENCES track,
date DATETIME NOT NULL, date DATETIME NOT NULL,
PRIMARY KEY (user_id, starred_id) PRIMARY KEY (user_id, starred_id)
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE rating_folder ( CREATE TABLE rating_folder (
user_id CHAR(36) NOT NULL REFERENCES user, user_id BINARY(16) NOT NULL REFERENCES user,
rated_id CHAR(36) NOT NULL REFERENCES folder, rated_id BINARY(16) NOT NULL REFERENCES folder,
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5), rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
PRIMARY KEY (user_id, rated_id) PRIMARY KEY (user_id, rated_id)
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE rating_track ( CREATE TABLE rating_track (
user_id CHAR(36) NOT NULL REFERENCES user, user_id BINARY(16) NOT NULL REFERENCES user,
rated_id CHAR(36) NOT NULL REFERENCES track, rated_id BINARY(16) NOT NULL REFERENCES track,
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5), rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
PRIMARY KEY (user_id, rated_id) PRIMARY KEY (user_id, rated_id)
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE chat_message ( CREATE TABLE chat_message (
id CHAR(36) PRIMARY KEY, id BINARY(16) PRIMARY KEY,
user_id CHAR(36) NOT NULL REFERENCES user, user_id BINARY(16) NOT NULL REFERENCES user,
time INTEGER NOT NULL, time INTEGER NOT NULL,
message VARCHAR(512) NOT NULL message VARCHAR(512) NOT NULL
) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci; ) DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE TABLE playlist ( CREATE TABLE playlist (
id CHAR(36) PRIMARY KEY, id BINARY(16) PRIMARY KEY,
user_id CHAR(36) NOT NULL REFERENCES user, user_id BINARY(16) NOT NULL REFERENCES user,
name VARCHAR(256) NOT NULL, name VARCHAR(256) NOT NULL,
comment VARCHAR(256), comment VARCHAR(256),
public BOOLEAN NOT NULL, public BOOLEAN NOT NULL,

View File

@ -18,15 +18,15 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import binascii
import simplejson import simplejson
import uuid import uuid
import binascii
from flask import request, current_app as app from flask import request, current_app as app
from pony.orm import db_session, ObjectNotFound
from xml.dom import minidom from xml.dom import minidom
from xml.etree import ElementTree from xml.etree import ElementTree
from ..web import store
from ..managers.user import UserManager from ..managers.user import UserManager
@app.before_request @app.before_request
@ -70,7 +70,7 @@ def authorize():
error = request.error_formatter(40, 'Unauthorized'), 401 error = request.error_formatter(40, 'Unauthorized'), 401
if request.authorization: if request.authorization:
status, user = UserManager.try_auth(store, request.authorization.username, request.authorization.password) status, user = UserManager.try_auth(request.authorization.username, request.authorization.password)
if status == UserManager.SUCCESS: if status == UserManager.SUCCESS:
request.username = request.authorization.username request.username = request.authorization.username
request.user = user request.user = user
@ -81,7 +81,7 @@ def authorize():
return error return error
password = decode_password(password) password = decode_password(password)
status, user = UserManager.try_auth(store, username, password) status, user = UserManager.try_auth(username, password)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
return error return error
@ -97,15 +97,13 @@ def get_client_prefs():
return request.error_formatter(10, 'Missing required parameter') return request.error_formatter(10, 'Missing required parameter')
client = request.values.get('c') client = request.values.get('c')
prefs = store.get(ClientPrefs, (request.user.id, client)) with db_session:
if not prefs: try:
prefs = ClientPrefs() ClientPrefs[request.user.id, client]
prefs.user_id = request.user.id except ObjectNotFound:
prefs.client_name = client ClientPrefs(user = User[request.user.id], client_name = client)
store.add(prefs)
store.commit()
request.prefs = prefs request.client = client
@app.after_request @app.after_request
def set_headers(response): def set_headers(response):
@ -218,19 +216,20 @@ class ResponseHelper:
return str(value).lower() return str(value).lower()
return str(value) return str(value)
def get_entity(req, ent, param = 'id'): def get_entity(req, cls, param = 'id'):
eid = req.values.get(param) eid = req.values.get(param)
if not eid: if not eid:
return False, req.error_formatter(10, 'Missing %s id' % ent.__name__) return False, req.error_formatter(10, 'Missing %s id' % cls.__name__)
try: try:
eid = uuid.UUID(eid) eid = uuid.UUID(eid)
except: except:
return False, req.error_formatter(0, 'Invalid %s id' % ent.__name__) return False, req.error_formatter(0, 'Invalid %s id' % cls.__name__)
entity = store.get(ent, eid) try:
if not entity: entity = cls[eid]
return False, (req.error_formatter(70, '%s not found' % ent.__name__), 404) except ObjectNotFound:
return False, (req.error_formatter(70, '%s not found' % cls.__name__), 404)
return True, entity return True, entity

View File

@ -23,12 +23,10 @@ import uuid
from datetime import timedelta from datetime import timedelta
from flask import request, current_app as app from flask import request, current_app as app
from storm.expr import Desc, Avg, Min, Max from pony.orm import db_session, select, desc, avg, max, min, count
from storm.info import ClassAlias
from ..db import Folder, Artist, Album, Track, RatingFolder, StarredFolder, StarredArtist, StarredAlbum, StarredTrack, User from ..db import Folder, Artist, Album, Track, RatingFolder, StarredFolder, StarredArtist, StarredAlbum, StarredTrack, User
from ..db import now from ..db import now
from ..web import store
@app.route('/rest/getRandomSongs.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getRandomSongs.view', methods = [ 'GET', 'POST' ])
def rand_songs(): def rand_songs():
@ -43,31 +41,24 @@ def rand_songs():
except: except:
return request.error_formatter(0, 'Invalid parameter format') return request.error_formatter(0, 'Invalid parameter format')
query = store.find(Track) query = Track.select()
if fromYear: if fromYear:
query = query.find(Track.year >= fromYear) query = query.filter(lambda t: t.year >= fromYear)
if toYear: if toYear:
query = query.find(Track.year <= toYear) query = query.filter(lambda t: t.year <= toYear)
if genre: if genre:
query = query.find(Track.genre == genre) query = query.filter(lambda t: t.genre == genre)
if fid: if fid:
if not store.find(Folder, Folder.id == fid, Folder.root == True).one(): with db_session:
if not Folder.exists(id = fid, root = True):
return request.error_formatter(70, 'Unknown folder') return request.error_formatter(70, 'Unknown folder')
query = query.find(Track.root_folder_id == fid) query = query.filter(lambda t: t.root_folder.id == fid)
count = query.count()
if not count:
return request.formatter({ 'randomSongs': {} })
tracks = []
for _ in xrange(size):
x = random.choice(xrange(count))
tracks.append(query[x])
with db_session:
return request.formatter({ return request.formatter({
'randomSongs': { 'randomSongs': {
'song': [ t.as_subsonic_child(request.user, request.prefs) for t in tracks ] 'song': [ t.as_subsonic_child(request.user, request.client) for t in query.random(size) ]
} }
}) })
@ -82,44 +73,35 @@ def album_list():
except: except:
return request.error_formatter(0, 'Invalid parameter format') return request.error_formatter(0, 'Invalid parameter format')
query = store.find(Folder, Track.folder_id == Folder.id) query = select(t.folder for t in Track)
if ltype == 'random': if ltype == 'random':
albums = [] with db_session:
count = query.count()
if not count:
return request.formatter({ 'albumList': {} })
for _ in xrange(size):
x = random.choice(xrange(count))
albums.append(query[x])
return request.formatter({ return request.formatter({
'albumList': { 'albumList': {
'album': [ a.as_subsonic_child(request.user) for a in albums ] 'album': [ a.as_subsonic_child(request.user) for a in query.random(size) ]
} }
}) })
elif ltype == 'newest': elif ltype == 'newest':
query = query.order_by(Desc(Folder.created)).config(distinct = True) query = query.order_by(desc(Folder.created))
elif ltype == 'highest': elif ltype == 'highest':
query = query.find(RatingFolder.rated_id == Folder.id).group_by(Folder.id).order_by(Desc(Avg(RatingFolder.rating))) query = query.order_by(lambda f: desc(avg(f.ratings.rating)))
elif ltype == 'frequent': elif ltype == 'frequent':
query = query.group_by(Folder.id).order_by(Desc(Avg(Track.play_count))) query = query.order_by(lambda f: desc(avg(f.tracks.play_count)))
elif ltype == 'recent': elif ltype == 'recent':
query = query.group_by(Folder.id).order_by(Desc(Max(Track.last_play))) query = query.order_by(lambda f: desc(max(f.tracks.last_play)))
elif ltype == 'starred': elif ltype == 'starred':
query = query.find(StarredFolder.starred_id == Folder.id, User.id == StarredFolder.user_id, User.name == request.username) query = select(s.starred for s in StarredFolder if s.user.id == request.user.id and count(s.starred.tracks) > 0)
elif ltype == 'alphabeticalByName': elif ltype == 'alphabeticalByName':
query = query.order_by(Folder.name).config(distinct = True) query = query.order_by(Folder.name)
elif ltype == 'alphabeticalByArtist': elif ltype == 'alphabeticalByArtist':
parent = ClassAlias(Folder) query = query.order_by(lambda f: f.parent.name + f.name)
query = query.find(Folder.parent_id == parent.id).order_by(parent.name, Folder.name).config(distinct = True)
else: else:
return request.error_formatter(0, 'Unknown search type') return request.error_formatter(0, 'Unknown search type')
with db_session:
return request.formatter({ return request.formatter({
'albumList': { 'albumList': {
'album': [ f.as_subsonic_child(request.user) for f in query[offset:offset+size] ] 'album': [ f.as_subsonic_child(request.user) for f in query.limit(size, offset) ]
} }
}) })
@ -134,76 +116,71 @@ def album_list_id3():
except: except:
return request.error_formatter(0, 'Invalid parameter format') return request.error_formatter(0, 'Invalid parameter format')
query = store.find(Album) query = Album.select()
if ltype == 'random': if ltype == 'random':
albums = [] with db_session:
count = query.count()
if not count:
return request.formatter({ 'albumList2': {} })
for _ in xrange(size):
x = random.choice(xrange(count))
albums.append(query[x])
return request.formatter({ return request.formatter({
'albumList2': { 'albumList2': {
'album': [ a.as_subsonic_album(request.user) for a in albums ] 'album': [ a.as_subsonic_album(request.user) for a in query.random(size) ]
} }
}) })
elif ltype == 'newest': elif ltype == 'newest':
query = query.find(Track.album_id == Album.id).group_by(Album.id).order_by(Desc(Min(Track.created))) query = query.order_by(lambda a: desc(min(a.tracks.created)))
elif ltype == 'frequent': elif ltype == 'frequent':
query = query.find(Track.album_id == Album.id).group_by(Album.id).order_by(Desc(Avg(Track.play_count))) query = query.order_by(lambda a: desc(avg(a.tracks.play_count)))
elif ltype == 'recent': elif ltype == 'recent':
query = query.find(Track.album_id == Album.id).group_by(Album.id).order_by(Desc(Max(Track.last_play))) query = query.order_by(lambda a: desc(max(a.tracks.last_play)))
elif ltype == 'starred': elif ltype == 'starred':
query = query.find(StarredAlbum.starred_id == Album.id, User.id == StarredAlbum.user_id, User.name == request.username) query = select(s.starred for s in StarredAlbum if s.user.id == request.user.id)
elif ltype == 'alphabeticalByName': elif ltype == 'alphabeticalByName':
query = query.order_by(Album.name) query = query.order_by(Album.name)
elif ltype == 'alphabeticalByArtist': elif ltype == 'alphabeticalByArtist':
query = query.find(Artist.id == Album.artist_id).order_by(Artist.name, Album.name) query = query.order_by(lambda a: a.artist.name + a.name)
else: else:
return request.error_formatter(0, 'Unknown search type') return request.error_formatter(0, 'Unknown search type')
with db_session:
return request.formatter({ return request.formatter({
'albumList2': { 'albumList2': {
'album': [ f.as_subsonic_album(request.user) for f in query[offset:offset+size] ] 'album': [ f.as_subsonic_album(request.user) for f in query.limit(size, offset) ]
} }
}) })
@app.route('/rest/getNowPlaying.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getNowPlaying.view', methods = [ 'GET', 'POST' ])
@db_session
def now_playing(): def now_playing():
query = store.find(User, Track.id == User.last_play_id) query = User.select(lambda u: u.last_play is not None and u.last_play_date + timedelta(minutes = 3) > now())
return request.formatter({ return request.formatter({
'nowPlaying': { 'nowPlaying': {
'entry': [ dict( 'entry': [ dict(
u.last_play.as_subsonic_child(request.user, request.prefs).items() + u.last_play.as_subsonic_child(request.user, request.client).items() +
{ 'username': u.name, 'minutesAgo': (now() - u.last_play_date).seconds / 60, 'playerId': 0 }.items() { 'username': u.name, 'minutesAgo': (now() - u.last_play_date).seconds / 60, 'playerId': 0 }.items()
) for u in query if u.last_play_date + timedelta(seconds = u.last_play.duration * 2) > now() ] ) for u in query ]
} }
}) })
@app.route('/rest/getStarred.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getStarred.view', methods = [ 'GET', 'POST' ])
@db_session
def get_starred(): def get_starred():
folders = store.find(StarredFolder, StarredFolder.user_id == User.id, User.name == request.username) folders = select(s.starred for s in StarredFolder if s.user.id == request.user.id)
return request.formatter({ return request.formatter({
'starred': { 'starred': {
'artist': [ { 'id': str(sf.starred_id), 'name': sf.starred.name } for sf in folders.find(Folder.parent_id == StarredFolder.starred_id, Track.folder_id == Folder.id).config(distinct = True) ], 'artist': [ { 'id': str(sf.id), 'name': sf.name } for sf in folders.filter(lambda f: count(f.tracks) == 0) ],
'album': [ sf.starred.as_subsonic_child(request.user) for sf in folders.find(Track.folder_id == StarredFolder.starred_id).config(distinct = True) ], 'album': [ sf.as_subsonic_child(request.user) for sf in folders.filter(lambda f: count(f.tracks) > 0) ],
'song': [ st.starred.as_subsonic_child(request.user, request.prefs) for st in store.find(StarredTrack, StarredTrack.user_id == User.id, User.name == request.username) ] 'song': [ st.as_subsonic_child(request.user, request.client) for st in select(s.starred for s in StarredTrack if s.user.id == request.user.id) ]
} }
}) })
@app.route('/rest/getStarred2.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getStarred2.view', methods = [ 'GET', 'POST' ])
@db_session
def get_starred_id3(): def get_starred_id3():
return request.formatter({ return request.formatter({
'starred2': { 'starred2': {
'artist': [ sa.starred.as_subsonic_artist(request.user) for sa in store.find(StarredArtist, StarredArtist.user_id == User.id, User.name == request.username) ], 'artist': [ sa.as_subsonic_artist(request.user) for sa in select(s.starred for s in StarredArtist if s.user.id == request.user.id) ],
'album': [ sa.starred.as_subsonic_album(request.user) for sa in store.find(StarredAlbum, StarredAlbum.user_id == User.id, User.name == request.username) ], 'album': [ sa.as_subsonic_album(request.user) for sa in select(s.starred for s in StarredAlbum if s.user.id == request.user.id) ],
'song': [ st.starred.as_subsonic_child(request.user, request.prefs) for st in store.find(StarredTrack, StarredTrack.user_id == User.id, User.name == request.username) ] 'song': [ st.as_subsonic_child(request.user, request.client) for st in select(s.starred for s in StarredTrack if s.user.id == request.user.id) ]
} }
}) })

View File

@ -22,20 +22,22 @@ import time
import uuid import uuid
from flask import request, current_app as app from flask import request, current_app as app
from pony.orm import db_session, delete
from pony.orm import ObjectNotFound
from ..db import Track, Album, Artist, Folder from ..db import Track, Album, Artist, Folder, User
from ..db import StarredTrack, StarredAlbum, StarredArtist, StarredFolder from ..db import StarredTrack, StarredAlbum, StarredArtist, StarredFolder
from ..db import RatingTrack, RatingFolder from ..db import RatingTrack, RatingFolder
from ..lastfm import LastFm from ..lastfm import LastFm
from ..web import store
from . import get_entity from . import get_entity
def try_star(ent, starred_ent, eid): @db_session
def try_star(cls, starred_cls, eid):
""" Stars an entity """ Stars an entity
:param ent: entity class, Folder, Artist, Album or Track :param cls: entity class, Folder, Artist, Album or Track
:param starred_ent: class used for the db representation of the starring of ent :param starred_cls: class used for the db representation of the starring of ent
:param eid: id of the entity to star :param eid: id of the entity to star
:return error dict, if any. None otherwise :return error dict, if any. None otherwise
""" """
@ -43,26 +45,27 @@ def try_star(ent, starred_ent, eid):
try: try:
uid = uuid.UUID(eid) uid = uuid.UUID(eid)
except: except:
return { 'code': 0, 'message': 'Invalid {} id {}'.format(ent.__name__, eid) } return { 'code': 0, 'message': 'Invalid {} id {}'.format(cls.__name__, eid) }
if store.get(starred_ent, (request.user.id, uid)): try:
return { 'code': 0, 'message': '{} {} already starred'.format(ent.__name__, eid) } e = cls[uid]
except ObjectNotFound:
return { 'code': 70, 'message': 'Unknown {} id {}'.format(cls.__name__, eid) }
e = store.get(ent, uid) try:
if not e: starred_cls[request.user.id, uid]
return { 'code': 70, 'message': 'Unknown {} id {}'.format(ent.__name__, eid) } return { 'code': 0, 'message': '{} {} already starred'.format(cls.__name__, eid) }
except ObjectNotFound:
starred = starred_ent() pass
starred.user_id = request.user.id
starred.starred_id = uid
store.add(starred)
starred_cls(user = User[request.user.id], starred = e)
return None return None
def try_unstar(starred_ent, eid): @db_session
def try_unstar(starred_cls, eid):
""" Unstars an entity """ Unstars an entity
:param starred_ent: class used for the db representation of the starring of the entity :param starred_cls: class used for the db representation of the starring of the entity
:param eid: id of the entity to unstar :param eid: id of the entity to unstar
:return error dict, if any. None otherwise :return error dict, if any. None otherwise
""" """
@ -72,7 +75,7 @@ def try_unstar(starred_ent, eid):
except: except:
return { 'code': 0, 'message': 'Invalid id {}'.format(eid) } return { 'code': 0, 'message': 'Invalid id {}'.format(eid) }
store.find(starred_ent, starred_ent.user_id == request.user.id, starred_ent.starred_id == uid).remove() delete(s for s in starred_cls if s.user.id == request.user.id and s.starred.id == uid)
return None return None
def merge_errors(errors): def merge_errors(errors):
@ -106,7 +109,6 @@ def star():
for arId in artistId: for arId in artistId:
errors.append(try_star(Artist, StarredArtist, arId)) errors.append(try_star(Artist, StarredArtist, arId))
store.commit()
error = merge_errors(errors) error = merge_errors(errors)
return request.formatter({ 'error': error }, error = True) if error else request.formatter({}) return request.formatter({ 'error': error }, error = True) if error else request.formatter({})
@ -130,7 +132,6 @@ def unstar():
for arId in artistId: for arId in artistId:
errors.append(try_unstar(StarredArtist, arId)) errors.append(try_unstar(StarredArtist, arId))
store.commit()
error = merge_errors(errors) error = merge_errors(errors)
return request.formatter({ 'error': error }, error = True) if error else request.formatter({}) return request.formatter({ 'error': error }, error = True) if error else request.formatter({})
@ -149,32 +150,31 @@ def rate():
if not rating in xrange(6): if not rating in xrange(6):
return request.error_formatter(0, 'rating must be between 0 and 5 (inclusive)') return request.error_formatter(0, 'rating must be between 0 and 5 (inclusive)')
with db_session:
if rating == 0: if rating == 0:
store.find(RatingTrack, RatingTrack.user_id == request.user.id, RatingTrack.rated_id == uid).remove() delete(r for r in RatingTrack if r.user.id == request.user.id and r.rated.id == uid)
store.find(RatingFolder, RatingFolder.user_id == request.user.id, RatingFolder.rated_id == uid).remove() delete(r for r in RatingFolder if r.user.id == request.user.id and r.rated.id == uid)
else: else:
rated = store.get(Track, uid) try:
rating_ent = RatingTrack rated = Track[uid]
if not rated: rating_cls = RatingTrack
rated = store.get(Folder, uid) except ObjectNotFound:
rating_ent = RatingFolder try:
if not rated: rated = Folder[uid]
rating_cls = RatingFolder
except ObjectNotFound:
return request.error_formatter(70, 'Unknown id') return request.error_formatter(70, 'Unknown id')
rating_info = store.get(rating_ent, (request.user.id, uid)) try:
if rating_info: rating_info = rating_cls[request.user.id, uid]
rating_info.rating = rating rating_info.rating = rating
else: except ObjectNotFound:
rating_info = rating_ent() rating_cls(user = User[request.user.id], rated = rated, rating = rating)
rating_info.user_id = request.user.id
rating_info.rated_id = uid
rating_info.rating = rating
store.add(rating_info)
store.commit()
return request.formatter({}) return request.formatter({})
@app.route('/rest/scrobble.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/scrobble.view', methods = [ 'GET', 'POST' ])
@db_session
def scrobble(): def scrobble():
status, res = get_entity(request, Track) status, res = get_entity(request, Track)
if not status: if not status:
@ -190,7 +190,7 @@ def scrobble():
else: else:
t = int(time.time()) t = int(time.time())
lfm = LastFm(app.config['LASTFM'], request.user, app.logger) lfm = LastFm(app.config['LASTFM'], User[request.user.id], app.logger)
if submission in (None, '', True, 'true', 'True', 1, '1'): if submission in (None, '', True, 'true', 'True', 1, '1'):
lfm.scrobble(res, t) lfm.scrobble(res, t)

View File

@ -22,24 +22,27 @@ import string
import uuid import uuid
from flask import request, current_app as app from flask import request, current_app as app
from pony.orm import db_session
from pony.orm import ObjectNotFound
from ..db import Folder, Artist, Album, Track from ..db import Folder, Artist, Album, Track
from ..web import store
from . import get_entity from . import get_entity
@app.route('/rest/getMusicFolders.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getMusicFolders.view', methods = [ 'GET', 'POST' ])
@db_session
def list_folders(): def list_folders():
return request.formatter({ return request.formatter({
'musicFolders': { 'musicFolders': {
'musicFolder': [ { 'musicFolder': [ {
'id': str(f.id), 'id': str(f.id),
'name': f.name 'name': f.name
} for f in store.find(Folder, Folder.root == True).order_by(Folder.name) ] } for f in Folder.select(lambda f: f.root).order_by(Folder.name) ]
} }
}) })
@app.route('/rest/getIndexes.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getIndexes.view', methods = [ 'GET', 'POST' ])
@db_session
def list_indexes(): def list_indexes():
musicFolderId = request.values.get('musicFolderId') musicFolderId = request.values.get('musicFolderId')
ifModifiedSince = request.values.get('ifModifiedSince') ifModifiedSince = request.values.get('ifModifiedSince')
@ -50,33 +53,31 @@ def list_indexes():
return request.error_formatter(0, 'Invalid timestamp') return request.error_formatter(0, 'Invalid timestamp')
if musicFolderId is None: if musicFolderId is None:
folder = store.find(Folder, Folder.root == True) folders = Folder.select(lambda f: f.root)[:]
else: else:
try: try:
mfid = uuid.UUID(musicFolderId) mfid = uuid.UUID(musicFolderId)
except: except:
return request.error_formatter(0, 'Invalid id') return request.error_formatter(0, 'Invalid id')
folder = store.get(Folder, mfid) try:
folder = Folder[mfid]
if not folder or (type(folder) is Folder and not folder.root): except ObjectNotFound:
return request.error_formatter(70, 'Folder not found') return request.error_formatter(70, 'Folder not found')
if not folder.root:
return request.error_formatter(70, 'Folder not found')
folders = [ folder ]
last_modif = max(map(lambda f: f.last_scan, folder)) if type(folder) is not Folder else folder.last_scan last_modif = max(map(lambda f: f.last_scan, folders))
if ifModifiedSince is not None and last_modif < ifModifiedSince:
if (not ifModifiedSince is None) and last_modif < ifModifiedSince:
return request.formatter({ 'indexes': { 'lastModified': last_modif * 1000 } }) return request.formatter({ 'indexes': { 'lastModified': last_modif * 1000 } })
# The XSD lies, we don't return artists but a directory structure # The XSD lies, we don't return artists but a directory structure
if type(folder) is not Folder:
artists = [] artists = []
childs = [] children = []
for f in folder: for f in folders:
artists += f.children artists += f.children.select()[:]
childs += f.tracks children += f.tracks.select()[:]
else:
artists = folder.children
childs = folder.tracks
indexes = {} indexes = {}
for artist in artists: for artist in artists:
@ -101,11 +102,12 @@ def list_indexes():
'name': a.name 'name': a.name
} for a in sorted(v, key = lambda a: a.name.lower()) ] } for a in sorted(v, key = lambda a: a.name.lower()) ]
} for k, v in sorted(indexes.iteritems()) ], } for k, v in sorted(indexes.iteritems()) ],
'child': [ c.as_subsonic_child(request.user, request.prefs) for c in sorted(childs, key = lambda t: t.sort_key()) ] 'child': [ c.as_subsonic_child(request.user, request.client) for c in sorted(children, key = lambda t: t.sort_key()) ]
} }
}) })
@app.route('/rest/getMusicDirectory.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getMusicDirectory.view', methods = [ 'GET', 'POST' ])
@db_session
def show_directory(): def show_directory():
status, res = get_entity(request, Folder) status, res = get_entity(request, Folder)
if not status: if not status:
@ -114,18 +116,19 @@ def show_directory():
directory = { directory = {
'id': str(res.id), 'id': str(res.id),
'name': res.name, 'name': res.name,
'child': [ f.as_subsonic_child(request.user) for f in sorted(res.children, key = lambda c: c.name.lower()) ] + [ t.as_subsonic_child(request.user, request.prefs) for t in sorted(res.tracks, key = lambda t: t.sort_key()) ] 'child': [ f.as_subsonic_child(request.user) for f in res.children.order_by(lambda c: c.name.lower()) ] + [ t.as_subsonic_child(request.user, request.client) for t in sorted(res.tracks, key = lambda t: t.sort_key()) ]
} }
if not res.root: if not res.root:
directory['parent'] = str(res.parent_id) directory['parent'] = str(res.parent.id)
return request.formatter({ 'directory': directory }) return request.formatter({ 'directory': directory })
@app.route('/rest/getArtists.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getArtists.view', methods = [ 'GET', 'POST' ])
@db_session
def list_artists(): def list_artists():
# According to the API page, there are no parameters? # According to the API page, there are no parameters?
indexes = {} indexes = {}
for artist in store.find(Artist): for artist in Artist.select():
index = artist.name[0].upper() if artist.name else '?' index = artist.name[0].upper() if artist.name else '?'
if index in map(str, xrange(10)): if index in map(str, xrange(10)):
index = '#' index = '#'
@ -147,6 +150,7 @@ def list_artists():
}) })
@app.route('/rest/getArtist.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getArtist.view', methods = [ 'GET', 'POST' ])
@db_session
def artist_info(): def artist_info():
status, res = get_entity(request, Artist) status, res = get_entity(request, Artist)
if not status: if not status:
@ -160,23 +164,25 @@ def artist_info():
return request.formatter({ 'artist': info }) return request.formatter({ 'artist': info })
@app.route('/rest/getAlbum.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getAlbum.view', methods = [ 'GET', 'POST' ])
@db_session
def album_info(): def album_info():
status, res = get_entity(request, Album) status, res = get_entity(request, Album)
if not status: if not status:
return res return res
info = res.as_subsonic_album(request.user) info = res.as_subsonic_album(request.user)
info['song'] = [ t.as_subsonic_child(request.user, request.prefs) for t in sorted(res.tracks, key = lambda t: t.sort_key()) ] info['song'] = [ t.as_subsonic_child(request.user, request.client) for t in sorted(res.tracks, key = lambda t: t.sort_key()) ]
return request.formatter({ 'album': info }) return request.formatter({ 'album': info })
@app.route('/rest/getSong.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getSong.view', methods = [ 'GET', 'POST' ])
@db_session
def track_info(): def track_info():
status, res = get_entity(request, Track) status, res = get_entity(request, Track)
if not status: if not status:
return res return res
return request.formatter({ 'song': res.as_subsonic_child(request.user, request.prefs) }) return request.formatter({ 'song': res.as_subsonic_child(request.user, request.client) })
@app.route('/rest/getVideos.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getVideos.view', methods = [ 'GET', 'POST' ])
def list_videos(): def list_videos():

View File

@ -19,9 +19,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from flask import request, current_app as app from flask import request, current_app as app
from pony.orm import db_session
from ..db import ChatMessage from ..db import ChatMessage, User
from ..web import store
@app.route('/rest/getChatMessages.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getChatMessages.view', methods = [ 'GET', 'POST' ])
def get_chat(): def get_chat():
@ -31,9 +31,10 @@ def get_chat():
except: except:
return request.error_formatter(0, 'Invalid parameter') return request.error_formatter(0, 'Invalid parameter')
query = store.find(ChatMessage).order_by(ChatMessage.time) with db_session:
query = ChatMessage.select().order_by(ChatMessage.time)
if since: if since:
query = query.find(ChatMessage.time > since) query = query.filter(lambda m: m.time > since)
return request.formatter({ 'chatMessages': { 'chatMessage': [ msg.responsize() for msg in query ] }}) return request.formatter({ 'chatMessages': { 'chatMessage': [ msg.responsize() for msg in query ] }})
@ -43,10 +44,8 @@ def add_chat_message():
if not msg: if not msg:
return request.error_formatter(10, 'Missing message') return request.error_formatter(10, 'Missing message')
chat = ChatMessage() with db_session:
chat.user_id = request.user.id ChatMessage(user = User[request.user.id], message = msg)
chat.message = msg
store.add(chat)
store.commit()
return request.formatter({}) return request.formatter({})

View File

@ -26,10 +26,10 @@ import subprocess
from flask import request, send_file, Response, current_app as app from flask import request, send_file, Response, current_app as app
from PIL import Image from PIL import Image
from pony.orm import db_session
from xml.etree import ElementTree from xml.etree import ElementTree
from .. import scanner from .. import scanner
from ..web import store
from ..db import Track, Album, Artist, Folder, User, ClientPrefs, now from ..db import Track, Album, Artist, Folder, User, ClientPrefs, now
from . import get_entity from . import get_entity
@ -43,6 +43,7 @@ def prepare_transcoding_cmdline(base_cmdline, input_file, input_format, output_f
return ret return ret
@app.route('/rest/stream.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/stream.view', methods = [ 'GET', 'POST' ])
@db_session
def stream_media(): def stream_media():
status, res = get_entity(request, Track) status, res = get_entity(request, Track)
if not status: if not status:
@ -57,10 +58,11 @@ def stream_media():
dst_bitrate = res.bitrate dst_bitrate = res.bitrate
dst_mimetype = res.content_type dst_mimetype = res.content_type
if request.prefs.format: prefs = ClientPrefs.get(lambda p: p.user.id == request.user.id and p.client_name == request.client)
dst_suffix = request.prefs.format if prefs.format:
if request.prefs.bitrate and request.prefs.bitrate < dst_bitrate: dst_suffix = prefs.format
dst_bitrate = request.prefs.bitrate if prefs.bitrate and prefs.bitrate < dst_bitrate:
dst_bitrate = prefs.bitrate
if maxBitRate: if maxBitRate:
try: try:
@ -121,14 +123,15 @@ def stream_media():
res.play_count = res.play_count + 1 res.play_count = res.play_count + 1
res.last_play = now() res.last_play = now()
request.user.last_play = res user = User[request.user.id]
request.user.last_play_date = now() user.last_play = res
store.commit() user.last_play_date = now()
return response return response
@app.route('/rest/download.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/download.view', methods = [ 'GET', 'POST' ])
def download_media(): def download_media():
with db_session:
status, res = get_entity(request, Track) status, res = get_entity(request, Track)
if not status: if not status:
return res return res
@ -137,6 +140,7 @@ def download_media():
@app.route('/rest/getCoverArt.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getCoverArt.view', methods = [ 'GET', 'POST' ])
def cover_art(): def cover_art():
with db_session:
status, res = get_entity(request, Folder) status, res = get_entity(request, Folder)
if not status: if not status:
return res return res
@ -176,7 +180,8 @@ def lyrics():
if not title: if not title:
return request.error_formatter(10, 'Missing title parameter') return request.error_formatter(10, 'Missing title parameter')
query = store.find(Track, Album.id == Track.album_id, Artist.id == Album.artist_id, Track.title.like(title), Artist.name.like(artist)) with db_session:
query = Track.select(lambda t: title in t.title and artist in t.artist.name)
for track in query: for track in query:
lyrics_path = os.path.splitext(track.path)[0] + '.txt' lyrics_path = os.path.splitext(track.path)[0] + '.txt'
if os.path.exists(lyrics_path): if os.path.exists(lyrics_path):

View File

@ -21,46 +21,49 @@
import uuid import uuid
from flask import request, current_app as app from flask import request, current_app as app
from storm.expr import Or from pony.orm import db_session, rollback
from pony.orm import ObjectNotFound
from ..db import Playlist, User, Track from ..db import Playlist, User, Track
from ..web import store
from . import get_entity from . import get_entity
@app.route('/rest/getPlaylists.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getPlaylists.view', methods = [ 'GET', 'POST' ])
def list_playlists(): def list_playlists():
query = store.find(Playlist, Or(Playlist.user_id == request.user.id, Playlist.public == True)).order_by(Playlist.name) query = Playlist.select(lambda p: p.user.id == request.user.id or p.public).order_by(Playlist.name)
username = request.values.get('username') username = request.values.get('username')
if username: if username:
if not request.user.admin: if not request.user.admin:
return request.error_formatter(50, 'Restricted to admins') return request.error_formatter(50, 'Restricted to admins')
user = store.find(User, User.name == username).one() with db_session:
if not user: user = User.get(name = username)
if user is None:
return request.error_formatter(70, 'No such user') return request.error_formatter(70, 'No such user')
query = store.find(Playlist, Playlist.user_id == User.id, User.name == username).order_by(Playlist.name) query = Playlist.select(lambda p: p.user.name == username).order_by(Playlist.name)
with db_session:
return request.formatter({ 'playlists': { 'playlist': [ p.as_subsonic_playlist(request.user) for p in query ] } }) return request.formatter({ 'playlists': { 'playlist': [ p.as_subsonic_playlist(request.user) for p in query ] } })
@app.route('/rest/getPlaylist.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/getPlaylist.view', methods = [ 'GET', 'POST' ])
@db_session
def show_playlist(): def show_playlist():
status, res = get_entity(request, Playlist) status, res = get_entity(request, Playlist)
if not status: if not status:
return res return res
if res.user_id != request.user.id and not request.user.admin: if res.user.id != request.user.id and not request.user.admin:
return request.error_formatter('50', 'Private playlist') return request.error_formatter('50', 'Private playlist')
info = res.as_subsonic_playlist(request.user) info = res.as_subsonic_playlist(request.user)
info['entry'] = [ t.as_subsonic_child(request.user, request.prefs) for t in res.get_tracks() ] info['entry'] = [ t.as_subsonic_child(request.user, request.client) for t in res.get_tracks() ]
return request.formatter({ 'playlist': info }) return request.formatter({ 'playlist': info })
@app.route('/rest/createPlaylist.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/createPlaylist.view', methods = [ 'GET', 'POST' ])
@db_session
def create_playlist(): def create_playlist():
# Only(?) method where the android client uses form data rather than GET params
playlist_id, name = map(request.values.get, [ 'playlistId', 'name' ]) playlist_id, name = map(request.values.get, [ 'playlistId', 'name' ])
# songId actually doesn't seem to be required # songId actually doesn't seem to be required
songs = request.values.getlist('songId') songs = request.values.getlist('songId')
@ -71,55 +74,54 @@ def create_playlist():
return request.error_formatter(0, 'Invalid parameter') return request.error_formatter(0, 'Invalid parameter')
if playlist_id: if playlist_id:
playlist = store.get(Playlist, playlist_id) try:
if not playlist: playlist = Playlist[playlist_id]
except ObjectNotFound:
return request.error_formatter(70, 'Unknwon playlist') return request.error_formatter(70, 'Unknwon playlist')
if playlist.user_id != request.user.id and not request.user.admin: if playlist.user.id != request.user.id and not request.user.admin:
return request.error_formatter(50, "You're not allowed to modify a playlist that isn't yours") return request.error_formatter(50, "You're not allowed to modify a playlist that isn't yours")
playlist.clear() playlist.clear()
if name: if name:
playlist.name = name playlist.name = name
elif name: elif name:
playlist = Playlist() playlist = Playlist(user = User[request.user.id], name = name)
playlist.user_id = request.user.id
playlist.name = name
store.add(playlist)
else: else:
return request.error_formatter(10, 'Missing playlist id or name') return request.error_formatter(10, 'Missing playlist id or name')
for sid in songs: for sid in songs:
track = store.get(Track, sid) try:
if not track: track = Track[sid]
store.rollback() except ObjectNotFound:
rollback()
return request.error_formatter(70, 'Unknown song') return request.error_formatter(70, 'Unknown song')
playlist.add(track) playlist.add(track)
store.commit()
return request.formatter({}) return request.formatter({})
@app.route('/rest/deletePlaylist.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/deletePlaylist.view', methods = [ 'GET', 'POST' ])
@db_session
def delete_playlist(): def delete_playlist():
status, res = get_entity(request, Playlist) status, res = get_entity(request, Playlist)
if not status: if not status:
return res return res
if res.user_id != request.user.id and not request.user.admin: if res.user.id != request.user.id and not request.user.admin:
return request.error_formatter(50, "You're not allowed to delete a playlist that isn't yours") return request.error_formatter(50, "You're not allowed to delete a playlist that isn't yours")
store.remove(res) res.delete()
store.commit()
return request.formatter({}) return request.formatter({})
@app.route('/rest/updatePlaylist.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/updatePlaylist.view', methods = [ 'GET', 'POST' ])
@db_session
def update_playlist(): def update_playlist():
status, res = get_entity(request, Playlist, 'playlistId') status, res = get_entity(request, Playlist, 'playlistId')
if not status: if not status:
return res return res
if res.user_id != request.user.id and not request.user.admin: if res.user.id != request.user.id and not request.user.admin:
return request.error_formatter(50, "You're not allowed to delete a playlist that isn't yours") return request.error_formatter(50, "You're not allowed to delete a playlist that isn't yours")
playlist = res playlist = res
@ -139,13 +141,13 @@ def update_playlist():
playlist.public = public in (True, 'True', 'true', 1, '1') playlist.public = public in (True, 'True', 'true', 1, '1')
for sid in to_add: for sid in to_add:
track = store.get(Track, sid) try:
if not track: track = Track[sid]
except ObjectNotFound:
return request.error_formatter(70, 'Unknown song') return request.error_formatter(70, 'Unknown song')
playlist.add(track) playlist.add(track)
playlist.remove_at_indexes(to_remove) playlist.remove_at_indexes(to_remove)
store.commit()
return request.formatter({}) return request.formatter({})

View File

@ -20,10 +20,9 @@
from datetime import datetime from datetime import datetime
from flask import request, current_app as app from flask import request, current_app as app
from storm.info import ClassAlias from pony.orm import db_session, select
from ..db import Folder, Track, Artist, Album from ..db import Folder, Track, Artist, Album
from ..web import store
@app.route('/rest/search.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/search.view', methods = [ 'GET', 'POST' ])
def old_search(): def old_search():
@ -38,33 +37,35 @@ def old_search():
min_date = datetime.fromtimestamp(newer_than) min_date = datetime.fromtimestamp(newer_than)
if artist: if artist:
parent = ClassAlias(Folder) query = select(t.folder.parent for t in Track if artist in t.folder.parent.name and t.folder.parent.created > min_date)
query = store.find(parent, Folder.parent_id == parent.id, Track.folder_id == Folder.id, parent.name.contains_string(artist), parent.created > min_date).config(distinct = True)
elif album: elif album:
query = store.find(Folder, Track.folder_id == Folder.id, Folder.name.contains_string(album), Folder.created > min_date).config(distinct = True) query = select(t.folder for t in Track if album in t.folder.name and t.folder.created > min_date)
elif title: elif title:
query = store.find(Track, Track.title.contains_string(title), Track.created > min_date) query = Track.select(lambda t: title in t.title and t.created > min_date)
elif anyf: elif anyf:
folders = store.find(Folder, Folder.name.contains_string(anyf), Folder.created > min_date) folders = Folder.select(lambda f: anyf in f.name and f.created > min_date)
tracks = store.find(Track, Track.title.contains_string(anyf), Track.created > min_date) tracks = Track.select(lambda t: anyf in t.title and t.created > min_date)
res = list(folders[offset : offset + count]) with db_session:
if offset + count > folders.count(): res = folders[offset : offset + count]
toff = max(0, offset - folders.count()) fcount = folders.count()
tend = offset + count - folders.count() if offset + count > fcount:
res += list(tracks[toff : tend]) toff = max(0, offset - fcount)
tend = offset + count - fcount
res += tracks[toff : tend]
return request.formatter({ 'searchResult': { return request.formatter({ 'searchResult': {
'totalHits': folders.count() + tracks.count(), 'totalHits': folders.count() + tracks.count(),
'offset': offset, 'offset': offset,
'match': [ r.as_subsonic_child(request.user) if isinstance(r, Folder) else r.as_subsonic_child(request.user, request.prefs) for r in res ] 'match': [ r.as_subsonic_child(request.user) if isinstance(r, Folder) else r.as_subsonic_child(request.user, request.client) for r in res ]
}}) }})
else: else:
return request.error_formatter(10, 'Missing search parameter') return request.error_formatter(10, 'Missing search parameter')
with db_session:
return request.formatter({ 'searchResult': { return request.formatter({ 'searchResult': {
'totalHits': query.count(), 'totalHits': query.count(),
'offset': offset, 'offset': offset,
'match': [ r.as_subsonic_child(request.user) if isinstance(r, Folder) else r.as_subsonic_child(request.user, request.prefs) for r in query[offset : offset + count] ] 'match': [ r.as_subsonic_child(request.user) if isinstance(r, Folder) else r.as_subsonic_child(request.user, request.client) for r in query[offset : offset + count] ]
}}) }})
@app.route('/rest/search2.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/search2.view', methods = [ 'GET', 'POST' ])
@ -85,15 +86,15 @@ def new_search():
if not query: if not query:
return request.error_formatter(10, 'Missing query parameter') return request.error_formatter(10, 'Missing query parameter')
parent = ClassAlias(Folder) with db_session:
artist_query = store.find(parent, Folder.parent_id == parent.id, Track.folder_id == Folder.id, parent.name.contains_string(query)).config(distinct = True, offset = artist_offset, limit = artist_count) artists = select(t.folder.parent for t in Track if query in t.folder.parent.name).limit(artist_count, artist_offset)
album_query = store.find(Folder, Track.folder_id == Folder.id, Folder.name.contains_string(query)).config(distinct = True, offset = album_offset, limit = album_count) albums = select(t.folder for t in Track if query in t.folder.name).limit(album_count, album_offset)
song_query = store.find(Track, Track.title.contains_string(query))[song_offset : song_offset + song_count] songs = Track.select(lambda t: query in t.title).limit(song_count, song_offset)
return request.formatter({ 'searchResult2': { return request.formatter({ 'searchResult2': {
'artist': [ { 'id': str(a.id), 'name': a.name } for a in artist_query ], 'artist': [ { 'id': str(a.id), 'name': a.name } for a in artists ],
'album': [ f.as_subsonic_child(request.user) for f in album_query ], 'album': [ f.as_subsonic_child(request.user) for f in albums ],
'song': [ t.as_subsonic_child(request.user, request.prefs) for t in song_query ] 'song': [ t.as_subsonic_child(request.user, request.client) for t in songs ]
}}) }})
@app.route('/rest/search3.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/search3.view', methods = [ 'GET', 'POST' ])
@ -114,13 +115,14 @@ def search_id3():
if not query: if not query:
return request.error_formatter(10, 'Missing query parameter') return request.error_formatter(10, 'Missing query parameter')
artist_query = store.find(Artist, Artist.name.contains_string(query))[artist_offset : artist_offset + artist_count] with db_session:
album_query = store.find(Album, Album.name.contains_string(query))[album_offset : album_offset + album_count] artists = Artist.select(lambda a: query in a.name).limit(artist_count, artist_offset)
song_query = store.find(Track, Track.title.contains_string(query))[song_offset : song_offset + song_count] albums = Album.select(lambda a: query in a.name).limit(album_count, album_offset)
songs = Track.select(lambda t: query in t.title).limit(song_count, song_offset)
return request.formatter({ 'searchResult3': { return request.formatter({ 'searchResult3': {
'artist': [ a.as_subsonic_artist(request.user) for a in artist_query ], 'artist': [ a.as_subsonic_artist(request.user) for a in artists ],
'album': [ a.as_subsonic_album(request.user) for a in album_query ], 'album': [ a.as_subsonic_album(request.user) for a in albums ],
'song': [ t.as_subsonic_child(request.user, request.prefs) for t in song_query ] 'song': [ t.as_subsonic_child(request.user, request.client) for t in songs ]
}}) }})

View File

@ -19,10 +19,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from flask import request, current_app as app from flask import request, current_app as app
from pony.orm import db_session
from ..db import User from ..db import User
from ..managers.user import UserManager from ..managers.user import UserManager
from ..web import store
from . import decode_password from . import decode_password
@ -35,7 +35,8 @@ def user_info():
if username != request.username and not request.user.admin: if username != request.username and not request.user.admin:
return request.error_formatter(50, 'Admin restricted') return request.error_formatter(50, 'Admin restricted')
user = store.find(User, User.name == username).one() with db_session:
user = User.get(name = username)
if user is None: if user is None:
return request.error_formatter(70, 'Unknown user') return request.error_formatter(70, 'Unknown user')
@ -46,7 +47,8 @@ def users_info():
if not request.user.admin: if not request.user.admin:
return request.error_formatter(50, 'Admin restricted') return request.error_formatter(50, 'Admin restricted')
return request.formatter({ 'users': { 'user': [ u.as_subsonic_user() for u in store.find(User) ] } }) with db_session:
return request.formatter({ 'users': { 'user': [ u.as_subsonic_user() for u in User.select() ] } })
@app.route('/rest/createUser.view', methods = [ 'GET', 'POST' ]) @app.route('/rest/createUser.view', methods = [ 'GET', 'POST' ])
def user_add(): def user_add():
@ -59,7 +61,7 @@ def user_add():
admin = True if admin in (True, 'True', 'true', 1, '1') else False admin = True if admin in (True, 'True', 'true', 1, '1') else False
password = decode_password(password) password = decode_password(password)
status = UserManager.add(store, username, password, email, admin) status = UserManager.add(username, password, email, admin)
if status == UserManager.NAME_EXISTS: if status == UserManager.NAME_EXISTS:
return request.error_formatter(0, 'There is already a user with that username') return request.error_formatter(0, 'There is already a user with that username')
@ -74,11 +76,12 @@ def user_del():
if not username: if not username:
return request.error_formatter(10, 'Missing parameter') return request.error_formatter(10, 'Missing parameter')
user = store.find(User, User.name == username).one() with db_session:
if not user: user = User.get(name = username)
if user is None:
return request.error_formatter(70, 'Unknown user') return request.error_formatter(70, 'Unknown user')
status = UserManager.delete(store, user.id) status = UserManager.delete(user.id)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
return request.error_formatter(0, UserManager.error_str(status)) return request.error_formatter(0, UserManager.error_str(status))
@ -94,7 +97,7 @@ def user_changepass():
return request.error_formatter(50, 'Admin restricted') return request.error_formatter(50, 'Admin restricted')
password = decode_password(password) password = decode_password(password)
status = UserManager.change_password2(store, username, password) status = UserManager.change_password2(username, password)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
code = 0 code = 0
if status == UserManager.NO_SUCH_USER: if status == UserManager.NO_SUCH_USER:

View File

@ -25,7 +25,9 @@ import getpass
import sys import sys
import time import time
from .db import get_store, Folder, User from pony.orm import db_session
from .db import Folder, User
from .managers.folder import FolderManager from .managers.folder import FolderManager
from .managers.user import UserManager from .managers.user import UserManager
from .scanner import Scanner from .scanner import Scanner
@ -105,8 +107,6 @@ class SupysonicCLI(cmd.Cmd):
for action, subparser in getattr(self.__class__, command + '_subparsers').choices.iteritems(): for action, subparser in getattr(self.__class__, command + '_subparsers').choices.iteritems():
setattr(self, 'help_{} {}'.format(command, action), subparser.print_help) setattr(self, 'help_{} {}'.format(command, action), subparser.print_help)
self.__store = get_store(config.BASE['database_uri'])
def write_line(self, line = ''): def write_line(self, line = ''):
self.stdout.write(line + '\n') self.stdout.write(line + '\n')
@ -148,44 +148,49 @@ class SupysonicCLI(cmd.Cmd):
folder_scan_parser.add_argument('folders', metavar = 'folder', nargs = '*', help = 'Folder(s) to be scanned. If ommitted, all folders are scanned') folder_scan_parser.add_argument('folders', metavar = 'folder', nargs = '*', help = 'Folder(s) to be scanned. If ommitted, all folders are scanned')
folder_scan_parser.add_argument('-f', '--force', action = 'store_true', help = "Force scan of already know files even if they haven't changed") folder_scan_parser.add_argument('-f', '--force', action = 'store_true', help = "Force scan of already know files even if they haven't changed")
@db_session
def folder_list(self): def folder_list(self):
self.write_line('Name\t\tPath\n----\t\t----') self.write_line('Name\t\tPath\n----\t\t----')
self.write_line('\n'.join('{0: <16}{1}'.format(f.name, f.path) for f in self.__store.find(Folder, Folder.root == True))) self.write_line('\n'.join('{0: <16}{1}'.format(f.name, f.path) for f in Folder.select(lambda f: f.root)))
def folder_add(self, name, path): def folder_add(self, name, path):
ret = FolderManager.add(self.__store, name, path) ret = FolderManager.add(name, path)
if ret != FolderManager.SUCCESS: if ret != FolderManager.SUCCESS:
self.write_error_line(FolderManager.error_str(ret)) self.write_error_line(FolderManager.error_str(ret))
else: else:
self.write_line("Folder '{}' added".format(name)) self.write_line("Folder '{}' added".format(name))
def folder_delete(self, name): def folder_delete(self, name):
ret = FolderManager.delete_by_name(self.__store, name) ret = FolderManager.delete_by_name(name)
if ret != FolderManager.SUCCESS: if ret != FolderManager.SUCCESS:
self.write_error_line(FolderManager.error_str(ret)) self.write_error_line(FolderManager.error_str(ret))
else: else:
self.write_line("Deleted folder '{}'".format(name)) self.write_line("Deleted folder '{}'".format(name))
@db_session
def folder_scan(self, folders, force): def folder_scan(self, folders, force):
extensions = self.__config.BASE['scanner_extensions'] extensions = self.__config.BASE['scanner_extensions']
if extensions: if extensions:
extensions = extensions.split(' ') extensions = extensions.split(' ')
scanner = Scanner(self.__store, force = force, extensions = extensions)
scanner = Scanner(force = force, extensions = extensions)
if folders: if folders:
folders = map(lambda n: self.__store.find(Folder, Folder.name == n, Folder.root == True).one() or n, folders) fstrs = folders
if any(map(lambda f: isinstance(f, basestring), folders)): folders = Folder.select(lambda f: f.root and f.name in fstrs)[:]
self.write_line("No such folder(s): " + ' '.join(f for f in folders if isinstance(f, basestring))) notfound = set(fstrs) - set(map(lambda f: f.name, folders))
for folder in filter(lambda f: isinstance(f, Folder), folders): if notfound:
self.write_line("No such folder(s): " + ' '.join(notfound))
for folder in folders:
scanner.scan(folder, TimedProgressDisplay(folder.name, self.stdout)) scanner.scan(folder, TimedProgressDisplay(folder.name, self.stdout))
self.write_line() self.write_line()
else: else:
for folder in self.__store.find(Folder, Folder.root == True): for folder in Folder.select(lambda f: f.root):
scanner.scan(folder, TimedProgressDisplay(folder.name, self.stdout)) scanner.scan(folder, TimedProgressDisplay(folder.name, self.stdout))
self.write_line() self.write_line()
scanner.finish() scanner.finish()
added, deleted = scanner.stats() added, deleted = scanner.stats()
self.__store.commit()
self.write_line("Scanning done") self.write_line("Scanning done")
self.write_line('Added: %i artists, %i albums, %i tracks' % (added[0], added[1], added[2])) self.write_line('Added: %i artists, %i albums, %i tracks' % (added[0], added[1], added[2]))
@ -208,9 +213,10 @@ class SupysonicCLI(cmd.Cmd):
user_pass_parser.add_argument('name', help = 'Name/login of the user to which change the password') user_pass_parser.add_argument('name', help = 'Name/login of the user to which change the password')
user_pass_parser.add_argument('password', nargs = '?', help = 'New password') user_pass_parser.add_argument('password', nargs = '?', help = 'New password')
@db_session
def user_list(self): def user_list(self):
self.write_line('Name\t\tAdmin\tEmail\n----\t\t-----\t-----') self.write_line('Name\t\tAdmin\tEmail\n----\t\t-----\t-----')
self.write_line('\n'.join('{0: <16}{1}\t{2}'.format(u.name, '*' if u.admin else '', u.mail) for u in self.__store.find(User))) self.write_line('\n'.join('{0: <16}{1}\t{2}'.format(u.name, '*' if u.admin else '', u.mail) for u in User.select()))
def user_add(self, name, admin, password, email): def user_add(self, name, admin, password, email):
if not password: if not password:
@ -219,24 +225,24 @@ class SupysonicCLI(cmd.Cmd):
if password != confirm: if password != confirm:
self.write_error_line("Passwords don't match") self.write_error_line("Passwords don't match")
return return
status = UserManager.add(self.__store, name, password, email, admin) status = UserManager.add(name, password, email, admin)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
self.write_error_line(UserManager.error_str(status)) self.write_error_line(UserManager.error_str(status))
def user_delete(self, name): def user_delete(self, name):
ret = UserManager.delete_by_name(self.__store, name) ret = UserManager.delete_by_name(name)
if ret != UserManager.SUCCESS: if ret != UserManager.SUCCESS:
self.write_error_line(UserManager.error_str(ret)) self.write_error_line(UserManager.error_str(ret))
else: else:
self.write_line("Deleted user '{}'".format(name)) self.write_line("Deleted user '{}'".format(name))
@db_session
def user_setadmin(self, name, off): def user_setadmin(self, name, off):
user = self.__store.find(User, User.name == name).one() user = User.get(name = name)
if not user: if user is None:
self.write_error_line('No such user') self.write_error_line('No such user')
else: else:
user.admin = not off user.admin = not off
self.__store.commit()
self.write_line("{0} '{1}' admin rights".format('Revoked' if off else 'Granted', name)) self.write_line("{0} '{1}' admin rights".format('Revoked' if off else 'Granted', name))
def user_changepass(self, name, password): def user_changepass(self, name, password):
@ -246,7 +252,7 @@ class SupysonicCLI(cmd.Cmd):
if password != confirm: if password != confirm:
self.write_error_line("Passwords don't match") self.write_error_line("Passwords don't match")
return return
status = UserManager.change_password2(self.__store, name, password) status = UserManager.change_password2(name, password)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
self.write_error_line(UserManager.error_str(status)) self.write_error_line(UserManager.error_str(status))
else: else:

View File

@ -20,7 +20,7 @@ class DefaultConfig(object):
tempdir = os.path.join(tempfile.gettempdir(), 'supysonic') tempdir = os.path.join(tempfile.gettempdir(), 'supysonic')
BASE = { BASE = {
'database_uri': 'sqlite://' + os.path.join(tempdir, 'supysonic.db'), 'database_uri': 'sqlite:///' + os.path.join(tempdir, 'supysonic.db'),
'scanner_extensions': None 'scanner_extensions': None
} }
WEBAPP = { WEBAPP = {

View File

@ -18,45 +18,41 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from storm.properties import * import time
from storm.references import Reference, ReferenceSet
from storm.database import create_database
from storm.store import Store
from storm.variables import Variable
import uuid, datetime, time
import mimetypes import mimetypes
import os.path import os.path
from datetime import datetime
from pony.orm import Database, Required, Optional, Set, PrimaryKey, LongStr
from pony.orm import ObjectNotFound
from pony.orm import min, max, avg, sum
from urlparse import urlparse
from uuid import UUID, uuid4
def now(): def now():
return datetime.datetime.now().replace(microsecond = 0) return datetime.now().replace(microsecond = 0)
class UnicodeOrStrVariable(Variable): db = Database()
__slots__ = ()
def parse_set(self, value, from_db): class Folder(db.Entity):
if isinstance(value, unicode): _table_ = 'folder'
return value
elif isinstance(value, str):
return unicode(value)
raise TypeError("Expected unicode, found %r: %r" % (type(value), value))
Unicode.variable_class = UnicodeOrStrVariable id = PrimaryKey(UUID, default = uuid4)
root = Required(bool, default = False)
name = Required(str)
path = Required(str, 4096) # unique
created = Required(datetime, precision = 0, default = now)
has_cover_art = Required(bool, default = False)
last_scan = Required(int, default = 0)
class Folder(object): parent = Optional(lambda: Folder, reverse = 'children', column = 'parent_id')
__storm_table__ = 'folder' children = Set(lambda: Folder, reverse = 'parent')
id = UUID(primary = True, default_factory = uuid.uuid4) __alltracks = Set(lambda: Track, lazy = True, reverse = 'root_folder') # Never used, hide it. Could be huge, lazy load
root = Bool(default = False) tracks = Set(lambda: Track, reverse = 'folder')
name = Unicode()
path = Unicode() # unique
created = DateTime(default_factory = now)
has_cover_art = Bool(default = False)
last_scan = Int(default = 0)
parent_id = UUID() # nullable stars = Set(lambda: StarredFolder)
parent = Reference(parent_id, id) ratings = Set(lambda: RatingFolder)
children = ReferenceSet(id, parent_id)
def as_subsonic_child(self, user): def as_subsonic_child(self, user):
info = { info = {
@ -67,29 +63,36 @@ class Folder(object):
'created': self.created.isoformat() 'created': self.created.isoformat()
} }
if not self.root: if not self.root:
info['parent'] = str(self.parent_id) info['parent'] = str(self.parent.id)
info['artist'] = self.parent.name info['artist'] = self.parent.name
if self.has_cover_art: if self.has_cover_art:
info['coverArt'] = str(self.id) info['coverArt'] = str(self.id)
starred = Store.of(self).get(StarredFolder, (user.id, self.id)) try:
if starred: starred = StarredFolder[user.id, self.id]
info['starred'] = starred.date.isoformat() info['starred'] = starred.date.isoformat()
except ObjectNotFound: pass
rating = Store.of(self).get(RatingFolder, (user.id, self.id)) try:
if rating: rating = RatingFolder[user.id, self.id]
info['userRating'] = rating.rating info['userRating'] = rating.rating
avgRating = Store.of(self).find(RatingFolder, RatingFolder.rated_id == self.id).avg(RatingFolder.rating) except ObjectNotFound: pass
avgRating = avg(self.ratings.rating)
if avgRating: if avgRating:
info['averageRating'] = avgRating info['averageRating'] = avgRating
return info return info
class Artist(object): class Artist(db.Entity):
__storm_table__ = 'artist' _table_ = 'artist'
id = UUID(primary = True, default_factory = uuid.uuid4) id = PrimaryKey(UUID, default = uuid4)
name = Unicode() # unique name = Required(str) # unique
albums = Set(lambda: Album)
tracks = Set(lambda: Track)
stars = Set(lambda: StarredArtist)
def as_subsonic_artist(self, user): def as_subsonic_artist(self, user):
info = { info = {
@ -99,38 +102,42 @@ class Artist(object):
'albumCount': self.albums.count() 'albumCount': self.albums.count()
} }
starred = Store.of(self).get(StarredArtist, (user.id, self.id)) try:
if starred: starred = StarredArtist[user.id, self.id]
info['starred'] = starred.date.isoformat() info['starred'] = starred.date.isoformat()
except ObjectNotFound: pass
return info return info
class Album(object): class Album(db.Entity):
__storm_table__ = 'album' _table_ = 'album'
id = UUID(primary = True, default_factory = uuid.uuid4) id = PrimaryKey(UUID, default = uuid4)
name = Unicode() name = Required(str)
artist_id = UUID() artist = Required(Artist, column = 'artist_id')
artist = Reference(artist_id, Artist.id) tracks = Set(lambda: Track)
stars = Set(lambda: StarredAlbum)
def as_subsonic_album(self, user): def as_subsonic_album(self, user):
info = { info = {
'id': str(self.id), 'id': str(self.id),
'name': self.name, 'name': self.name,
'artist': self.artist.name, 'artist': self.artist.name,
'artistId': str(self.artist_id), 'artistId': str(self.artist.id),
'songCount': self.tracks.count(), 'songCount': self.tracks.count(),
'duration': sum(self.tracks.values(Track.duration)), 'duration': sum(self.tracks.duration),
'created': min(self.tracks.values(Track.created)).isoformat() 'created': min(self.tracks.created).isoformat()
} }
track_with_cover = self.tracks.find(Track.folder_id == Folder.id, Folder.has_cover_art).any() track_with_cover = self.tracks.select(lambda t: t.folder.has_cover_art).first()
if track_with_cover: if track_with_cover is not None:
info['coverArt'] = str(track_with_cover.folder_id) info['coverArt'] = str(track_with_cover.folder.id)
starred = Store.of(self).get(StarredAlbum, (user.id, self.id)) try:
if starred: starred = StarredAlbum[user.id, self.id]
info['starred'] = starred.date.isoformat() info['starred'] = starred.date.isoformat()
except ObjectNotFound: pass
return info return info
@ -138,41 +145,42 @@ class Album(object):
year = min(map(lambda t: t.year if t.year else 9999, self.tracks)) year = min(map(lambda t: t.year if t.year else 9999, self.tracks))
return '%i%s' % (year, self.name.lower()) return '%i%s' % (year, self.name.lower())
Artist.albums = ReferenceSet(Artist.id, Album.artist_id) class Track(db.Entity):
_table_ = 'track'
class Track(object): id = PrimaryKey(UUID, default = uuid4)
__storm_table__ = 'track' disc = Required(int)
number = Required(int)
title = Required(str)
year = Optional(int)
genre = Optional(str, nullable = True)
duration = Required(int)
id = UUID(primary = True, default_factory = uuid.uuid4) album = Required(Album, column = 'album_id')
disc = Int() artist = Required(Artist, column = 'artist_id')
number = Int()
title = Unicode()
year = Int() # nullable
genre = Unicode() # nullable
duration = Int()
album_id = UUID()
album = Reference(album_id, Album.id)
artist_id = UUID()
artist = Reference(artist_id, Artist.id)
bitrate = Int()
path = Unicode() # unique bitrate = Required(int)
content_type = Unicode()
created = DateTime(default_factory = now)
last_modification = Int()
play_count = Int(default = 0) path = Required(str, 4096) # unique
last_play = DateTime() # nullable content_type = Required(str)
created = Required(datetime, precision = 0, default = now)
last_modification = Required(int)
root_folder_id = UUID() play_count = Required(int, default = 0)
root_folder = Reference(root_folder_id, Folder.id) last_play = Optional(datetime, precision = 0)
folder_id = UUID()
folder = Reference(folder_id, Folder.id)
def as_subsonic_child(self, user, prefs): root_folder = Required(Folder, column = 'root_folder_id')
folder = Required(Folder, column = 'folder_id')
__lastly_played_by = Set(lambda: User) # Never used, hide it
stars = Set(lambda: StarredTrack)
ratings = Set(lambda: RatingTrack)
def as_subsonic_child(self, user, client):
info = { info = {
'id': str(self.id), 'id': str(self.id),
'parent': str(self.folder_id), 'parent': str(self.folder.id),
'isDir': False, 'isDir': False,
'title': self.title, 'title': self.title,
'album': self.album.name, 'album': self.album.name,
@ -187,8 +195,8 @@ class Track(object):
'isVideo': False, 'isVideo': False,
'discNumber': self.disc, 'discNumber': self.disc,
'created': self.created.isoformat(), 'created': self.created.isoformat(),
'albumId': str(self.album_id), 'albumId': str(self.album.id),
'artistId': str(self.artist_id), 'artistId': str(self.artist.id),
'type': 'music' 'type': 'music'
} }
@ -197,20 +205,24 @@ class Track(object):
if self.genre: if self.genre:
info['genre'] = self.genre info['genre'] = self.genre
if self.folder.has_cover_art: if self.folder.has_cover_art:
info['coverArt'] = str(self.folder_id) info['coverArt'] = str(self.folder.id)
starred = Store.of(self).get(StarredTrack, (user.id, self.id)) try:
if starred: starred = StarredTrack[user.id, self.id]
info['starred'] = starred.date.isoformat() info['starred'] = starred.date.isoformat()
except ObjectNotFound: pass
rating = Store.of(self).get(RatingTrack, (user.id, self.id)) try:
if rating: rating = RatingTrack[user.id, self.id]
info['userRating'] = rating.rating info['userRating'] = rating.rating
avgRating = Store.of(self).find(RatingTrack, RatingTrack.rated_id == self.id).avg(RatingTrack.rating) except ObjectNotFound: pass
avgRating = avg(self.ratings.rating)
if avgRating: if avgRating:
info['averageRating'] = avgRating info['averageRating'] = avgRating
if prefs and prefs.format and prefs.format != self.suffix(): prefs = ClientPrefs.get(lambda p: p.user.id == user.id and p.client_name == client)
if prefs is not None and prefs.format is not None and prefs.format != self.suffix():
info['transcodedSuffix'] = prefs.format info['transcodedSuffix'] = prefs.format
info['transcodedContentType'] = mimetypes.guess_type('dummyname.' + prefs.format, False)[0] or 'application/octet-stream' info['transcodedContentType'] = mimetypes.guess_type('dummyname.' + prefs.format, False)[0] or 'application/octet-stream'
@ -228,25 +240,31 @@ class Track(object):
def sort_key(self): def sort_key(self):
return (self.album.artist.name + self.album.name + ("%02i" % self.disc) + ("%02i" % self.number) + self.title).lower() return (self.album.artist.name + self.album.name + ("%02i" % self.disc) + ("%02i" % self.number) + self.title).lower()
Folder.tracks = ReferenceSet(Folder.id, Track.folder_id) class User(db.Entity):
Album.tracks = ReferenceSet(Album.id, Track.album_id) _table_ = 'user'
Artist.tracks = ReferenceSet(Artist.id, Track.artist_id)
class User(object): id = PrimaryKey(UUID, default = uuid4)
__storm_table__ = 'user' name = Required(str, 64) # unique
mail = Optional(str)
password = Required(str, 40)
salt = Required(str, 6)
admin = Required(bool, default = False)
lastfm_session = Optional(str, 32, nullable = True)
lastfm_status = Required(bool, default = True) # True: ok/unlinked, False: invalid session
id = UUID(primary = True, default_factory = uuid.uuid4) last_play = Optional(Track, column = 'last_play_id')
name = Unicode() # unique last_play_date = Optional(datetime, precision = 0)
mail = Unicode()
password = Unicode()
salt = Unicode()
admin = Bool(default = False)
lastfm_session = Unicode() # nullable
lastfm_status = Bool(default = True) # True: ok/unlinked, False: invalid session
last_play_id = UUID() # nullable clients = Set(lambda: ClientPrefs)
last_play = Reference(last_play_id, Track.id) playlists = Set(lambda: Playlist)
last_play_date = DateTime() # nullable __messages = Set(lambda: ChatMessage, lazy = True) # Never used, hide it
starred_folders = Set(lambda: StarredFolder, lazy = True)
starred_artists = Set(lambda: StarredArtist, lazy = True)
starred_albums = Set(lambda: StarredAlbum, lazy = True)
starred_tracks = Set(lambda: StarredTrack, lazy = True)
folder_ratings = Set(lambda: RatingFolder, lazy = True)
track_ratings = Set(lambda: RatingTrack, lazy = True)
def as_subsonic_user(self): def as_subsonic_user(self):
return { return {
@ -266,72 +284,74 @@ class User(object):
'shareRole': False 'shareRole': False
} }
class ClientPrefs(object): class ClientPrefs(db.Entity):
__storm_table__ = 'client_prefs' _table_ = 'client_prefs'
__storm_primary__ = 'user_id', 'client_name'
user_id = UUID() user = Required(User, column = 'user_id')
client_name = Unicode() client_name = Required(str, 32)
format = Unicode() # nullable PrimaryKey(user, client_name)
bitrate = Int() # nullable format = Optional(str, 8)
bitrate = Optional(int)
class BaseStarred(object): class StarredFolder(db.Entity):
__storm_primary__ = 'user_id', 'starred_id' _table_ = 'starred_folder'
user_id = UUID() user = Required(User, column = 'user_id')
starred_id = UUID() starred = Required(Folder, column = 'starred_id')
date = DateTime(default_factory = now) date = Required(datetime, precision = 0, default = now)
user = Reference(user_id, User.id) PrimaryKey(user, starred)
class StarredFolder(BaseStarred): class StarredArtist(db.Entity):
__storm_table__ = 'starred_folder' _table_ = 'starred_artist'
starred = Reference(BaseStarred.starred_id, Folder.id) user = Required(User, column = 'user_id')
starred = Required(Artist, column = 'starred_id')
date = Required(datetime, precision = 0, default = now)
class StarredArtist(BaseStarred): PrimaryKey(user, starred)
__storm_table__ = 'starred_artist'
starred = Reference(BaseStarred.starred_id, Artist.id) class StarredAlbum(db.Entity):
_table_ = 'starred_album'
class StarredAlbum(BaseStarred): user = Required(User, column = 'user_id')
__storm_table__ = 'starred_album' starred = Required(Album, column = 'starred_id')
date = Required(datetime, precision = 0, default = now)
starred = Reference(BaseStarred.starred_id, Album.id) PrimaryKey(user, starred)
class StarredTrack(BaseStarred): class StarredTrack(db.Entity):
__storm_table__ = 'starred_track' _table_ = 'starred_track'
starred = Reference(BaseStarred.starred_id, Track.id) user = Required(User, column = 'user_id')
starred = Required(Track, column = 'starred_id')
date = Required(datetime, precision = 0, default = now)
class BaseRating(object): PrimaryKey(user, starred)
__storm_primary__ = 'user_id', 'rated_id'
user_id = UUID() class RatingFolder(db.Entity):
rated_id = UUID() _table_ = 'rating_folder'
rating = Int() user = Required(User, column = 'user_id')
rated = Required(Folder, column = 'rated_id')
rating = Required(int, min = 1, max = 5)
user = Reference(user_id, User.id) PrimaryKey(user, rated)
class RatingFolder(BaseRating): class RatingTrack(db.Entity):
__storm_table__ = 'rating_folder' _table_ = 'rating_track'
user = Required(User, column = 'user_id')
rated = Required(Track, column = 'rated_id')
rating = Required(int, min = 1, max = 5)
rated = Reference(BaseRating.rated_id, Folder.id) PrimaryKey(user, rated)
class RatingTrack(BaseRating): class ChatMessage(db.Entity):
__storm_table__ = 'rating_track' _table_ = 'chat_message'
rated = Reference(BaseRating.rated_id, Track.id) id = PrimaryKey(UUID, default = uuid4)
user = Required(User, column = 'user_id')
class ChatMessage(object): time = Required(int, default = lambda: int(time.time()))
__storm_table__ = 'chat_message' message = Required(str, 512)
id = UUID(primary = True, default_factory = uuid.uuid4)
user_id = UUID()
time = Int(default_factory = lambda: int(time.time()))
message = Unicode()
user = Reference(user_id, User.id)
def responsize(self): def responsize(self):
return { return {
@ -340,24 +360,22 @@ class ChatMessage(object):
'message': self.message 'message': self.message
} }
class Playlist(object): class Playlist(db.Entity):
__storm_table__ = 'playlist' _table_ = 'playlist'
id = UUID(primary = True, default_factory = uuid.uuid4) id = PrimaryKey(UUID, default = uuid4)
user_id = UUID() user = Required(User, column = 'user_id')
name = Unicode() name = Required(str)
comment = Unicode() # nullable comment = Optional(str)
public = Bool(default = False) public = Required(bool, default = False)
created = DateTime(default_factory = now) created = Required(datetime, precision = 0, default = now)
tracks = Unicode() tracks = Optional(LongStr)
user = Reference(user_id, User.id)
def as_subsonic_playlist(self, user): def as_subsonic_playlist(self, user):
tracks = self.get_tracks() tracks = self.get_tracks()
info = { info = {
'id': str(self.id), 'id': str(self.id),
'name': self.name if self.user_id == user.id else '[%s] %s' % (self.user.name, self.name), 'name': self.name if self.user.id == user.id else '[%s] %s' % (self.user.name, self.name),
'owner': self.user.name, 'owner': self.user.name,
'public': self.public, 'public': self.public,
'songCount': len(tracks), 'songCount': len(tracks),
@ -374,38 +392,34 @@ class Playlist(object):
tracks = [] tracks = []
should_fix = False should_fix = False
store = Store.of(self)
for t in self.tracks.split(','): for t in self.tracks.split(','):
try: try:
tid = uuid.UUID(t) tid = UUID(t)
track = store.get(Track, tid) track = Track[tid]
if track:
tracks.append(track) tracks.append(track)
else:
should_fix = True
except: except:
should_fix = True should_fix = True
if should_fix: if should_fix:
self.tracks = ','.join(map(lambda t: str(t.id), tracks)) self.tracks = ','.join(map(lambda t: str(t.id), tracks))
store.commit() db.commit()
return tracks return tracks
def clear(self): def clear(self):
self.tracks = "" self.tracks = ''
def add(self, track): def add(self, track):
if isinstance(track, uuid.UUID): if isinstance(track, UUID):
tid = track tid = track
elif isinstance(track, Track): elif isinstance(track, Track):
tid = track.id tid = track.id
elif isinstance(track, basestring): elif isinstance(track, basestring):
tid = uuid.UUID(track) tid = UUID(track)
if self.tracks and len(self.tracks) > 0: if self.tracks and len(self.tracks) > 0:
self.tracks = "{},{}".format(self.tracks, tid) self.tracks = '{},{}'.format(self.tracks, tid)
else: else:
self.tracks = str(tid) self.tracks = str(tid)
@ -418,8 +432,31 @@ class Playlist(object):
self.tracks = ','.join(t for t in tracks if t) self.tracks = ','.join(t for t in tracks if t)
def get_store(database_uri): def parse_uri(database_uri):
database = create_database(database_uri) if not isinstance(database_uri, basestring):
store = Store(database) raise TypeError('Expecting a string')
return store
uri = urlparse(database_uri)
if uri.scheme == 'sqlite':
path = uri.path
if not path:
path = ':memory:'
elif path[0] == '/':
path = path[1:]
return dict(provider = 'sqlite', filename = path)
elif uri.scheme in ('postgres', 'postgresql'):
return dict(provider = 'postgres', user = uri.username, password = uri.password, host = uri.hostname, database = uri.path[1:])
elif uri.scheme == 'mysql':
return dict(provider = 'mysql', user = uri.username, passwd = uri.password, host = uri.hostname, db = uri.path[1:])
return dict()
def init_database(database_uri, create_tables = False):
db.bind(**parse_uri(database_uri))
db.generate_mapping(create_tables = create_tables)
def release_database():
db.disconnect()
db.provider = None
db.schema = None

View File

@ -11,8 +11,8 @@
from flask import session, request, redirect, url_for, current_app as app from flask import session, request, redirect, url_for, current_app as app
from functools import wraps from functools import wraps
from pony.orm import db_session
from ..web import store
from ..db import Artist, Album, Track from ..db import Artist, Album, Track
from ..managers.user import UserManager from ..managers.user import UserManager
@ -27,7 +27,7 @@ def login_check():
request.user = None request.user = None
should_login = True should_login = True
if session.get('userid'): if session.get('userid'):
code, user = UserManager.get(store, session.get('userid')) code, user = UserManager.get(session.get('userid'))
if code != UserManager.SUCCESS: if code != UserManager.SUCCESS:
session.clear() session.clear()
else: else:
@ -39,13 +39,14 @@ def login_check():
return redirect(url_for('login', returnUrl = request.script_root + request.url[len(request.url_root)-1:])) return redirect(url_for('login', returnUrl = request.script_root + request.url[len(request.url_root)-1:]))
@app.route('/') @app.route('/')
@db_session
def index(): def index():
stats = { stats = {
'artists': store.find(Artist).count(), 'artists': Artist.select().count(),
'albums': store.find(Album).count(), 'albums': Album.select().count(),
'tracks': store.find(Track).count() 'tracks': Track.select().count()
} }
return render_template('home.html', stats = stats, admin = UserManager.get(store, session.get('userid'))[1].admin) return render_template('home.html', stats = stats)
def admin_only(f): def admin_only(f):
@wraps(f) @wraps(f)

View File

@ -22,19 +22,20 @@ import os.path
import uuid import uuid
from flask import request, flash, render_template, redirect, url_for, current_app as app from flask import request, flash, render_template, redirect, url_for, current_app as app
from pony.orm import db_session
from ..db import Folder from ..db import Folder
from ..managers.user import UserManager from ..managers.user import UserManager
from ..managers.folder import FolderManager from ..managers.folder import FolderManager
from ..scanner import Scanner from ..scanner import Scanner
from ..web import store
from . import admin_only from . import admin_only
@app.route('/folder') @app.route('/folder')
@admin_only @admin_only
@db_session
def folder_index(): def folder_index():
return render_template('folders.html', folders = store.find(Folder, Folder.root == True)) return render_template('folders.html', folders = Folder.select(lambda f: f.root))
@app.route('/folder/add') @app.route('/folder/add')
@admin_only @admin_only
@ -55,7 +56,7 @@ def add_folder_post():
if error: if error:
return render_template('addfolder.html') return render_template('addfolder.html')
ret = FolderManager.add(store, name, path) ret = FolderManager.add(name, path)
if ret != FolderManager.SUCCESS: if ret != FolderManager.SUCCESS:
flash(FolderManager.error_str(ret)) flash(FolderManager.error_str(ret))
return render_template('addfolder.html') return render_template('addfolder.html')
@ -73,7 +74,7 @@ def del_folder(id):
flash('Invalid folder id') flash('Invalid folder id')
return redirect(url_for('folder_index')) return redirect(url_for('folder_index'))
ret = FolderManager.delete(store, idid) ret = FolderManager.delete(idid)
if ret != FolderManager.SUCCESS: if ret != FolderManager.SUCCESS:
flash(FolderManager.error_str(ret)) flash(FolderManager.error_str(ret))
else: else:
@ -84,16 +85,19 @@ def del_folder(id):
@app.route('/folder/scan') @app.route('/folder/scan')
@app.route('/folder/scan/<id>') @app.route('/folder/scan/<id>')
@admin_only @admin_only
@db_session
def scan_folder(id = None): def scan_folder(id = None):
extensions = app.config['BASE']['scanner_extensions'] extensions = app.config['BASE']['scanner_extensions']
if extensions: if extensions:
extensions = extensions.split(' ') extensions = extensions.split(' ')
scanner = Scanner(store, extensions = extensions)
scanner = Scanner(extensions = extensions)
if id is None: if id is None:
for folder in store.find(Folder, Folder.root == True): for folder in Folder.select(lambda f: f.root):
scanner.scan(folder) scanner.scan(folder)
else: else:
status, folder = FolderManager.get(store, id) status, folder = FolderManager.get(id)
if status != FolderManager.SUCCESS: if status != FolderManager.SUCCESS:
flash(FolderManager.error_str(status)) flash(FolderManager.error_str(status))
return redirect(url_for('folder_index')) return redirect(url_for('folder_index'))
@ -101,7 +105,6 @@ def scan_folder(id = None):
scanner.finish() scanner.finish()
added, deleted = scanner.stats() added, deleted = scanner.stats()
store.commit()
flash('Added: %i artists, %i albums, %i tracks' % (added[0], added[1], added[2])) flash('Added: %i artists, %i albums, %i tracks' % (added[0], added[1], added[2]))
flash('Deleted: %i artists, %i albums, %i tracks' % (deleted[0], deleted[1], deleted[2])) flash('Deleted: %i artists, %i albums, %i tracks' % (deleted[0], deleted[1], deleted[2]))

View File

@ -21,33 +21,38 @@
import uuid import uuid
from flask import request, flash, render_template, redirect, url_for, current_app as app from flask import request, flash, render_template, redirect, url_for, current_app as app
from pony.orm import db_session
from pony.orm import ObjectNotFound
from ..web import store
from ..db import Playlist from ..db import Playlist
from ..managers.user import UserManager from ..managers.user import UserManager
@app.route('/playlist') @app.route('/playlist')
@db_session
def playlist_index(): def playlist_index():
return render_template('playlists.html', return render_template('playlists.html',
mine = store.find(Playlist, Playlist.user_id == request.user.id), mine = Playlist.select(lambda p: p.user == request.user),
others = store.find(Playlist, Playlist.user_id != request.user.id, Playlist.public == True)) others = Playlist.select(lambda p: p.user != request.user and p.public))
@app.route('/playlist/<uid>') @app.route('/playlist/<uid>')
@db_session
def playlist_details(uid): def playlist_details(uid):
try: try:
uid = uuid.UUID(uid) if type(uid) in (str, unicode) else uid uid = uuid.UUID(uid)
except: except:
flash('Invalid playlist id') flash('Invalid playlist id')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
playlist = store.get(Playlist, uid) try:
if not playlist: playlist = Playlist[uid]
except ObjectNotFound:
flash('Unknown playlist') flash('Unknown playlist')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
return render_template('playlist.html', playlist = playlist) return render_template('playlist.html', playlist = playlist)
@app.route('/playlist/<uid>', methods = [ 'POST' ]) @app.route('/playlist/<uid>', methods = [ 'POST' ])
@db_session
def playlist_update(uid): def playlist_update(uid):
try: try:
uid = uuid.UUID(uid) uid = uuid.UUID(uid)
@ -55,24 +60,25 @@ def playlist_update(uid):
flash('Invalid playlist id') flash('Invalid playlist id')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
playlist = store.get(Playlist, uid) try:
if not playlist: playlist = Playlist[uid]
except ObjectNotFound:
flash('Unknown playlist') flash('Unknown playlist')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
if playlist.user_id != request.user.id: if playlist.user.id != request.user.id:
flash("You're not allowed to edit this playlist") flash("You're not allowed to edit this playlist")
elif not request.form.get('name'): elif not request.form.get('name'):
flash('Missing playlist name') flash('Missing playlist name')
else: else:
playlist.name = request.form.get('name') playlist.name = request.form.get('name')
playlist.public = request.form.get('public') in (True, 'True', 1, '1', 'on', 'checked') playlist.public = request.form.get('public') in (True, 'True', 1, '1', 'on', 'checked')
store.commit()
flash('Playlist updated.') flash('Playlist updated.')
return playlist_details(uid) return playlist_details(uid)
@app.route('/playlist/del/<uid>') @app.route('/playlist/del/<uid>')
@db_session
def playlist_delete(uid): def playlist_delete(uid):
try: try:
uid = uuid.UUID(uid) uid = uuid.UUID(uid)
@ -80,14 +86,16 @@ def playlist_delete(uid):
flash('Invalid playlist id') flash('Invalid playlist id')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
playlist = store.get(Playlist, uid) try:
if not playlist: playlist = Playlist[uid]
except ObjectNotFound:
flash('Unknown playlist') flash('Unknown playlist')
elif playlist.user_id != request.user.id: return redirect(url_for('playlist_index'))
if playlist.user.id != request.user.id:
flash("You're not allowed to delete this playlist") flash("You're not allowed to delete this playlist")
else: else:
store.remove(playlist) playlist.delete()
store.commit()
flash('Playlist deleted') flash('Playlist deleted')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))

View File

@ -20,15 +20,16 @@
from flask import request, session, flash, render_template, redirect, url_for, current_app as app from flask import request, session, flash, render_template, redirect, url_for, current_app as app
from functools import wraps from functools import wraps
from pony.orm import db_session
from ..db import User, ClientPrefs from ..db import User, ClientPrefs
from ..lastfm import LastFm from ..lastfm import LastFm
from ..managers.user import UserManager from ..managers.user import UserManager
from ..web import store
from . import admin_only from . import admin_only
def me_or_uuid(f, arg = 'uid'): def me_or_uuid(f, arg = 'uid'):
@db_session
@wraps(f) @wraps(f)
def decorated_func(*args, **kwargs): def decorated_func(*args, **kwargs):
if kwargs: if kwargs:
@ -37,11 +38,11 @@ def me_or_uuid(f, arg = 'uid'):
uid = args[0] uid = args[0]
if uid == 'me': if uid == 'me':
user = request.user user = User[request.user.id] # Refetch user from previous transaction
elif not request.user.admin: elif not request.user.admin:
return redirect(url_for('index')) return redirect(url_for('index'))
else: else:
code, user = UserManager.get(store, uid) code, user = UserManager.get(uid)
if code != UserManager.SUCCESS: if code != UserManager.SUCCESS:
flash(UserManager.error_str(code)) flash(UserManager.error_str(code))
return redirect(url_for('index')) return redirect(url_for('index'))
@ -57,14 +58,14 @@ def me_or_uuid(f, arg = 'uid'):
@app.route('/user') @app.route('/user')
@admin_only @admin_only
@db_session
def user_index(): def user_index():
return render_template('users.html', users = store.find(User)) return render_template('users.html', users = User.select())
@app.route('/user/<uid>') @app.route('/user/<uid>')
@me_or_uuid @me_or_uuid
def user_profile(uid, user): def user_profile(uid, user):
prefs = store.find(ClientPrefs, ClientPrefs.user_id == user.id) return render_template('profile.html', user = user, has_lastfm = app.config['LASTFM']['api_key'] != None, clients = user.clients)
return render_template('profile.html', user = user, has_lastfm = app.config['LASTFM']['api_key'] != None, clients = prefs)
@app.route('/user/<uid>', methods = [ 'POST' ]) @app.route('/user/<uid>', methods = [ 'POST' ])
@me_or_uuid @me_or_uuid
@ -87,25 +88,24 @@ def update_clients(uid, user):
app.logger.debug(clients_opts) app.logger.debug(clients_opts)
for client, opts in clients_opts.iteritems(): for client, opts in clients_opts.iteritems():
prefs = store.get(ClientPrefs, (user.id, client)) prefs = user.clients.select(lambda c: c.client_name == client).first()
if not prefs: if prefs is None:
continue continue
if 'delete' in opts and opts['delete'] in [ 'on', 'true', 'checked', 'selected', '1' ]: if 'delete' in opts and opts['delete'] in [ 'on', 'true', 'checked', 'selected', '1' ]:
store.remove(prefs) prefs.delete()
continue continue
prefs.format = opts['format'] if 'format' in opts and opts['format'] else None prefs.format = opts['format'] if 'format' in opts and opts['format'] else None
prefs.bitrate = int(opts['bitrate']) if 'bitrate' in opts and opts['bitrate'] else None prefs.bitrate = int(opts['bitrate']) if 'bitrate' in opts and opts['bitrate'] else None
store.commit()
flash('Clients preferences updated.') flash('Clients preferences updated.')
return user_profile(uid, user) return user_profile(uid, user)
@app.route('/user/<uid>/changeusername') @app.route('/user/<uid>/changeusername')
@admin_only @admin_only
def change_username_form(uid): def change_username_form(uid):
code, user = UserManager.get(store, uid) code, user = UserManager.get(uid)
if code != UserManager.SUCCESS: if code != UserManager.SUCCESS:
flash(UserManager.error_str(code)) flash(UserManager.error_str(code))
return redirect(url_for('index')) return redirect(url_for('index'))
@ -114,8 +114,9 @@ def change_username_form(uid):
@app.route('/user/<uid>/changeusername', methods = [ 'POST' ]) @app.route('/user/<uid>/changeusername', methods = [ 'POST' ])
@admin_only @admin_only
@db_session
def change_username_post(uid): def change_username_post(uid):
code, user = UserManager.get(store, uid) code, user = UserManager.get(uid)
if code != UserManager.SUCCESS: if code != UserManager.SUCCESS:
return redirect(url_for('index')) return redirect(url_for('index'))
@ -123,7 +124,7 @@ def change_username_post(uid):
if username in ('', None): if username in ('', None):
flash('The username is required') flash('The username is required')
return render_template('change_username.html', user = user) return render_template('change_username.html', user = user)
if user.name != username and store.find(User, User.name == username).one(): if user.name != username and User.get(name = username) is not None:
flash('This name is already taken') flash('This name is already taken')
return render_template('change_username.html', user = user) return render_template('change_username.html', user = user)
@ -135,7 +136,6 @@ def change_username_post(uid):
if user.name != username or user.admin != admin: if user.name != username or user.admin != admin:
user.name = username user.name = username
user.admin = admin user.admin = admin
store.commit()
flash("User '%s' updated." % username) flash("User '%s' updated." % username)
else: else:
flash("No changes for '%s'." % username) flash("No changes for '%s'." % username)
@ -150,10 +150,9 @@ def change_mail_form(uid, user):
@app.route('/user/<uid>/changemail', methods = [ 'POST' ]) @app.route('/user/<uid>/changemail', methods = [ 'POST' ])
@me_or_uuid @me_or_uuid
def change_mail_post(uid, user): def change_mail_post(uid, user):
mail = request.form.get('mail') mail = request.form.get('mail', '')
# No validation, lol. # No validation, lol.
user.mail = mail user.mail = mail
store.commit()
return redirect(url_for('user_profile', uid = uid)) return redirect(url_for('user_profile', uid = uid))
@app.route('/user/<uid>/changepass') @app.route('/user/<uid>/changepass')
@ -182,9 +181,9 @@ def change_password_post(uid, user):
if not error: if not error:
if user.id == request.user.id: if user.id == request.user.id:
status = UserManager.change_password(store, user.id, current, new) status = UserManager.change_password(user.id, current, new)
else: else:
status = UserManager.change_password2(store, user.name, new) status = UserManager.change_password2(user.name, new)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
flash(UserManager.error_str(status)) flash(UserManager.error_str(status))
@ -214,13 +213,12 @@ def add_user_post():
flash("The passwords don't match.") flash("The passwords don't match.")
error = True error = True
if admin is None: admin = admin is not None
admin = True if store.find(User, User.admin == True).count() == 0 else False if mail is None:
else: mail = ''
admin = True
if not error: if not error:
status = UserManager.add(store, name, passwd, mail, admin) status = UserManager.add(name, passwd, mail, admin)
if status == UserManager.SUCCESS: if status == UserManager.SUCCESS:
flash("User '%s' successfully added" % name) flash("User '%s' successfully added" % name)
return redirect(url_for('user_index')) return redirect(url_for('user_index'))
@ -232,7 +230,7 @@ def add_user_post():
@app.route('/user/del/<uid>') @app.route('/user/del/<uid>')
@admin_only @admin_only
def del_user(uid): def del_user(uid):
status = UserManager.delete(store, uid) status = UserManager.delete(uid)
if status == UserManager.SUCCESS: if status == UserManager.SUCCESS:
flash('Deleted user') flash('Deleted user')
else: else:
@ -250,7 +248,6 @@ def lastfm_reg(uid, user):
lfm = LastFm(app.config['LASTFM'], user, app.logger) lfm = LastFm(app.config['LASTFM'], user, app.logger)
status, error = lfm.link_account(token) status, error = lfm.link_account(token)
store.commit()
flash(error if not status else 'Successfully linked LastFM account') flash(error if not status else 'Successfully linked LastFM account')
return redirect(url_for('user_profile', uid = uid)) return redirect(url_for('user_profile', uid = uid))
@ -260,7 +257,6 @@ def lastfm_reg(uid, user):
def lastfm_unreg(uid, user): def lastfm_unreg(uid, user):
lfm = LastFm(app.config['LASTFM'], user, app.logger) lfm = LastFm(app.config['LASTFM'], user, app.logger)
lfm.unlink_account() lfm.unlink_account()
store.commit()
flash('Unlinked LastFM account') flash('Unlinked LastFM account')
return redirect(url_for('user_profile', uid = uid)) return redirect(url_for('user_profile', uid = uid))
@ -284,7 +280,7 @@ def login():
error = True error = True
if not error: if not error:
status, user = UserManager.try_auth(store, name, password) status, user = UserManager.try_auth(name, password)
if status == UserManager.SUCCESS: if status == UserManager.SUCCESS:
session['userid'] = str(user.id) session['userid'] = str(user.id)
flash('Logged in!') flash('Logged in!')

View File

@ -21,6 +21,9 @@
import os.path import os.path
import uuid import uuid
from pony.orm import db_session, select
from pony.orm import ObjectNotFound
from ..db import Folder, Artist, Album, Track, StarredFolder, RatingFolder from ..db import Folder, Artist, Album, Track, StarredFolder, RatingFolder
from ..scanner import Scanner from ..scanner import Scanner
@ -34,7 +37,8 @@ class FolderManager:
SUBPATH_EXISTS = 6 SUBPATH_EXISTS = 6
@staticmethod @staticmethod
def get(store, uid): @db_session
def get(uid):
if isinstance(uid, basestring): if isinstance(uid, basestring):
try: try:
uid = uuid.UUID(uid) uid = uuid.UUID(uid)
@ -45,65 +49,56 @@ class FolderManager:
else: else:
return FolderManager.INVALID_ID, None return FolderManager.INVALID_ID, None
folder = store.get(Folder, uid) try:
if not folder: folder = Folder[uid]
return FolderManager.SUCCESS, folder
except ObjectNotFound:
return FolderManager.NO_SUCH_FOLDER, None return FolderManager.NO_SUCH_FOLDER, None
return FolderManager.SUCCESS, folder
@staticmethod @staticmethod
def add(store, name, path): @db_session
if not store.find(Folder, Folder.name == name, Folder.root == True).is_empty(): def add(name, path):
if Folder.get(name = name, root = True) is not None:
return FolderManager.NAME_EXISTS return FolderManager.NAME_EXISTS
path = unicode(os.path.abspath(path)) path = unicode(os.path.abspath(path))
if not os.path.isdir(path): if not os.path.isdir(path):
return FolderManager.INVALID_PATH return FolderManager.INVALID_PATH
if not store.find(Folder, Folder.path == path).is_empty(): if Folder.get(path = path) is not None:
return FolderManager.PATH_EXISTS return FolderManager.PATH_EXISTS
if any(path.startswith(p) for p in store.find(Folder).values(Folder.path)): if any(path.startswith(p) for p in select(f.path for f in Folder)):
return FolderManager.PATH_EXISTS return FolderManager.PATH_EXISTS
if not store.find(Folder, Folder.path.startswith(path)).is_empty(): if Folder.exists(lambda f: f.path.startswith(path)):
return FolderManager.SUBPATH_EXISTS return FolderManager.SUBPATH_EXISTS
folder = Folder() folder = Folder(root = True, name = name, path = path)
folder.root = True
folder.name = name
folder.path = path
store.add(folder)
store.commit()
return FolderManager.SUCCESS return FolderManager.SUCCESS
@staticmethod @staticmethod
def delete(store, uid): @db_session
status, folder = FolderManager.get(store, uid) def delete(uid):
status, folder = FolderManager.get(uid)
if status != FolderManager.SUCCESS: if status != FolderManager.SUCCESS:
return status return status
if not folder.root: if not folder.root:
return FolderManager.NO_SUCH_FOLDER return FolderManager.NO_SUCH_FOLDER
scanner = Scanner(store) scanner = Scanner()
for track in store.find(Track, Track.root_folder_id == folder.id): for track in Track.select(lambda t: t.root_folder == folder):
scanner.remove_file(track.path) scanner.remove_file(track.path)
scanner.finish() scanner.finish()
store.find(StarredFolder, StarredFolder.starred_id == uid).remove() folder.delete()
store.find(RatingFolder, RatingFolder.rated_id == uid).remove()
store.remove(folder)
store.commit()
return FolderManager.SUCCESS return FolderManager.SUCCESS
@staticmethod @staticmethod
def delete_by_name(store, name): @db_session
folder = store.find(Folder, Folder.name == name, Folder.root == True).one() def delete_by_name(name):
folder = Folder.get(name = name, root = True)
if not folder: if not folder:
return FolderManager.NO_SUCH_FOLDER return FolderManager.NO_SUCH_FOLDER
return FolderManager.delete(store, folder.id) return FolderManager.delete(folder.id)
@staticmethod @staticmethod
def error_str(err): def error_str(err):

View File

@ -14,6 +14,9 @@ import random
import string import string
import uuid import uuid
from pony.orm import db_session
from pony.orm import ObjectNotFound
from ..db import User, ChatMessage, Playlist from ..db import User, ChatMessage, Playlist
from ..db import StarredFolder, StarredArtist, StarredAlbum, StarredTrack from ..db import StarredFolder, StarredArtist, StarredAlbum, StarredTrack
from ..db import RatingFolder, RatingTrack from ..db import RatingFolder, RatingTrack
@ -26,7 +29,8 @@ class UserManager:
WRONG_PASS = 4 WRONG_PASS = 4
@staticmethod @staticmethod
def get(store, uid): @db_session
def get(uid):
if type(uid) in (str, unicode): if type(uid) in (str, unicode):
try: try:
uid = uuid.UUID(uid) uid = uuid.UUID(uid)
@ -37,63 +41,53 @@ class UserManager:
else: else:
return UserManager.INVALID_ID, None return UserManager.INVALID_ID, None
user = store.get(User, uid) try:
if user is None: user = User[uid]
return UserManager.SUCCESS, user
except ObjectNotFound:
return UserManager.NO_SUCH_USER, None return UserManager.NO_SUCH_USER, None
return UserManager.SUCCESS, user
@staticmethod @staticmethod
def add(store, name, password, mail, admin): @db_session
if store.find(User, User.name == name).one(): def add(name, password, mail, admin):
if User.get(name = name) is not None:
return UserManager.NAME_EXISTS return UserManager.NAME_EXISTS
crypt, salt = UserManager.__encrypt_password(password) crypt, salt = UserManager.__encrypt_password(password)
user = User() user = User(
user.name = name name = name,
user.mail = mail mail = mail,
user.password = crypt password = crypt,
user.salt = salt salt = salt,
user.admin = admin admin = admin
)
store.add(user)
store.commit()
return UserManager.SUCCESS return UserManager.SUCCESS
@staticmethod @staticmethod
def delete(store, uid): @db_session
status, user = UserManager.get(store, uid) def delete(uid):
status, user = UserManager.get(uid)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
return status return status
store.find(StarredFolder, StarredFolder.user_id == user.id).remove() user.delete()
store.find(StarredArtist, StarredArtist.user_id == user.id).remove()
store.find(StarredAlbum, StarredAlbum.user_id == user.id).remove()
store.find(StarredTrack, StarredTrack.user_id == user.id).remove()
store.find(RatingFolder, RatingFolder.user_id == user.id).remove()
store.find(RatingTrack, RatingTrack.user_id == user.id).remove()
store.find(ChatMessage, ChatMessage.user_id == user.id).remove()
for playlist in store.find(Playlist, Playlist.user_id == user.id):
store.remove(playlist)
store.remove(user)
store.commit()
return UserManager.SUCCESS return UserManager.SUCCESS
@staticmethod @staticmethod
def delete_by_name(store, name): @db_session
user = store.find(User, User.name == name).one() def delete_by_name(name):
if not user: user = User.get(name = name)
if user is None:
return UserManager.NO_SUCH_USER return UserManager.NO_SUCH_USER
return UserManager.delete(store, user.id) return UserManager.delete(user.id)
@staticmethod @staticmethod
def try_auth(store, name, password): @db_session
user = store.find(User, User.name == name).one() def try_auth(name, password):
if not user: user = User.get(name = name)
if user is None:
return UserManager.NO_SUCH_USER, None return UserManager.NO_SUCH_USER, None
elif UserManager.__encrypt_password(password, user.salt)[0] != user.password: elif UserManager.__encrypt_password(password, user.salt)[0] != user.password:
return UserManager.WRONG_PASS, None return UserManager.WRONG_PASS, None
@ -101,8 +95,9 @@ class UserManager:
return UserManager.SUCCESS, user return UserManager.SUCCESS, user
@staticmethod @staticmethod
def change_password(store, uid, old_pass, new_pass): @db_session
status, user = UserManager.get(store, uid) def change_password(uid, old_pass, new_pass):
status, user = UserManager.get(uid)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
return status return status
@ -110,17 +105,16 @@ class UserManager:
return UserManager.WRONG_PASS return UserManager.WRONG_PASS
user.password = UserManager.__encrypt_password(new_pass, user.salt)[0] user.password = UserManager.__encrypt_password(new_pass, user.salt)[0]
store.commit()
return UserManager.SUCCESS return UserManager.SUCCESS
@staticmethod @staticmethod
def change_password2(store, name, new_pass): @db_session
user = store.find(User, User.name == name).one() def change_password2(name, new_pass):
if not user: user = User.get(name = name)
if user is None:
return UserManager.NO_SUCH_USER return UserManager.NO_SUCH_USER
user.password = UserManager.__encrypt_password(new_pass, user.salt)[0] user.password = UserManager.__encrypt_password(new_pass, user.salt)[0]
store.commit()
return UserManager.SUCCESS return UserManager.SUCCESS
@staticmethod @staticmethod

View File

@ -23,40 +23,17 @@ import mimetypes
import mutagen import mutagen
import time import time
from storm.expr import ComparableExpr, compile, Like from pony.orm import db_session
from storm.exceptions import NotSupportedError
from .db import Folder, Artist, Album, Track, User from .db import Folder, Artist, Album, Track, User
from .db import StarredFolder, StarredArtist, StarredAlbum, StarredTrack from .db import StarredFolder, StarredArtist, StarredAlbum, StarredTrack
from .db import RatingFolder, RatingTrack from .db import RatingFolder, RatingTrack
# Hacking in support for a concatenation expression
class Concat(ComparableExpr):
__slots__ = ("left", "right", "db")
def __init__(self, left, right, db):
self.left = left
self.right = right
self.db = db
@compile.when(Concat)
def compile_concat(compile, concat, state):
left = compile(concat.left, state)
right = compile(concat.right, state)
if concat.db in ('sqlite', 'postgres'):
statement = "%s||%s"
elif concat.db == 'mysql':
statement = "CONCAT(%s, %s)"
else:
raise NotSupportedError("Unspported database (%s)" % concat.db)
return statement % (left, right)
class Scanner: class Scanner:
def __init__(self, store, force = False, extensions = None): def __init__(self, force = False, extensions = None):
if extensions is not None and not isinstance(extensions, list): if extensions is not None and not isinstance(extensions, list):
raise TypeError('Invalid extensions type') raise TypeError('Invalid extensions type')
self.__store = store
self.__force = force self.__force = force
self.__added_artists = 0 self.__added_artists = 0
@ -92,7 +69,8 @@ class Scanner:
progress_callback(current, total) progress_callback(current, total)
# Remove files that have been deleted # Remove files that have been deleted
for track in [ t for t in self.__store.find(Track, Track.root_folder_id == folder.id) if not self.__is_valid_path(t.path) ]: for track in Track.select(lambda t: t.root_folder == folder):
if not self.__is_valid_path(track.path):
self.remove_file(track.path) self.remove_file(track.path)
# Update cover art info # Update cover art info
@ -104,33 +82,33 @@ class Scanner:
folder.last_scan = int(time.time()) folder.last_scan = int(time.time())
@db_session
def finish(self): def finish(self):
for album in [ a for a in self.__albums_to_check if not a.tracks.count() ]: for album in Album.select(lambda a: a.id in self.__albums_to_check):
self.__store.find(StarredAlbum, StarredAlbum.starred_id == album.id).remove() if not album.tracks.is_empty():
continue
self.__artists_to_check.add(album.artist) self.__artists_to_check.add(album.artist.id)
self.__store.remove(album)
self.__deleted_albums += 1 self.__deleted_albums += 1
album.delete()
self.__albums_to_check.clear() self.__albums_to_check.clear()
for artist in [ a for a in self.__artists_to_check if not a.albums.count() and not a.tracks.count() ]: for artist in Artist.select(lambda a: a.id in self.__artists_to_check):
self.__store.find(StarredArtist, StarredArtist.starred_id == artist.id).remove() if not artist.albums.is_empty() or not artist.tracks.is_empty():
continue
self.__store.remove(artist)
self.__deleted_artists += 1 self.__deleted_artists += 1
artist.delete()
self.__artists_to_check.clear() self.__artists_to_check.clear()
while self.__folders_to_check: while self.__folders_to_check:
folder = self.__folders_to_check.pop() folder = Folder[self.__folders_to_check.pop()]
if folder.root: if folder.root:
continue continue
if not folder.tracks.count() and not folder.children.count(): if folder.tracks.is_empty() and folder.children.is_empty():
self.__store.find(StarredFolder, StarredFolder.starred_id == folder.id).remove() self.__folders_to_check.add(folder.parent.id)
self.__store.find(RatingFolder, RatingFolder.rated_id == folder.id).remove() folder.delete()
self.__folders_to_check.add(folder.parent)
self.__store.remove(folder)
def __is_valid_path(self, path): def __is_valid_path(self, path):
if not os.path.exists(path): if not os.path.exists(path):
@ -139,13 +117,13 @@ class Scanner:
return True return True
return os.path.splitext(path)[1][1:].lower() in self.__extensions return os.path.splitext(path)[1][1:].lower() in self.__extensions
@db_session
def scan_file(self, path): def scan_file(self, path):
if not isinstance(path, basestring): if not isinstance(path, basestring):
raise TypeError('Expecting string, got ' + str(type(path))) raise TypeError('Expecting string, got ' + str(type(path)))
tr = self.__store.find(Track, Track.path == path).one() tr = Track.get(path = path)
add = False if tr is not None:
if tr:
if not self.__force and not int(os.path.getmtime(path)) > tr.last_modification: if not self.__force and not int(os.path.getmtime(path)) > tr.last_modification:
return return
@ -153,74 +131,70 @@ class Scanner:
if not tag: if not tag:
self.remove_file(path) self.remove_file(path)
return return
trdict = {}
else: else:
tag = self.__try_load_tag(path) tag = self.__try_load_tag(path)
if not tag: if not tag:
return return
tr = Track() trdict = { 'path': path }
tr.path = path
add = True
artist = self.__try_read_tag(tag, 'artist', '') artist = self.__try_read_tag(tag, 'artist')
album = self.__try_read_tag(tag, 'album', '') if not artist:
return
album = self.__try_read_tag(tag, 'album', '[non-album tracks]')
albumartist = self.__try_read_tag(tag, 'albumartist', artist) albumartist = self.__try_read_tag(tag, 'albumartist', artist)
tr.disc = self.__try_read_tag(tag, 'discnumber', 1, lambda x: int(x[0].split('/')[0])) trdict['disc'] = self.__try_read_tag(tag, 'discnumber', 1, lambda x: int(x[0].split('/')[0]))
tr.number = self.__try_read_tag(tag, 'tracknumber', 1, lambda x: int(x[0].split('/')[0])) trdict['number'] = self.__try_read_tag(tag, 'tracknumber', 1, lambda x: int(x[0].split('/')[0]))
tr.title = self.__try_read_tag(tag, 'title', '') trdict['title'] = self.__try_read_tag(tag, 'title', '')
tr.year = self.__try_read_tag(tag, 'date', None, lambda x: int(x[0].split('-')[0])) trdict['year'] = self.__try_read_tag(tag, 'date', None, lambda x: int(x[0].split('-')[0]))
tr.genre = self.__try_read_tag(tag, 'genre') trdict['genre'] = self.__try_read_tag(tag, 'genre')
tr.duration = int(tag.info.length) trdict['duration'] = int(tag.info.length)
tr.bitrate = (tag.info.bitrate if hasattr(tag.info, 'bitrate') else int(os.path.getsize(path) * 8 / tag.info.length)) / 1000 trdict['bitrate'] = (tag.info.bitrate if hasattr(tag.info, 'bitrate') else int(os.path.getsize(path) * 8 / tag.info.length)) / 1000
tr.content_type = mimetypes.guess_type(path, False)[0] or 'application/octet-stream' trdict['content_type'] = mimetypes.guess_type(path, False)[0] or 'application/octet-stream'
tr.last_modification = os.path.getmtime(path) trdict['last_modification'] = int(os.path.getmtime(path))
tralbum = self.__find_album(albumartist, album) tralbum = self.__find_album(albumartist, album)
trartist = self.__find_artist(artist) trartist = self.__find_artist(artist)
if add: if tr is None:
trroot = self.__find_root_folder(path) trdict['root_folder'] = self.__find_root_folder(path)
trfolder = self.__find_folder(path) trdict['folder'] = self.__find_folder(path)
trdict['album'] = tralbum
trdict['artist'] = trartist
# Set the references at the very last as searching for them will cause the added track to be flushed, even if Track(**trdict)
# it is incomplete, causing not null constraints errors.
tr.album = tralbum
tr.artist = trartist
tr.folder = trfolder
tr.root_folder = trroot
self.__store.add(tr)
self.__added_tracks += 1 self.__added_tracks += 1
else: else:
if tr.album.id != tralbum.id: if tr.album.id != tralbum.id:
self.__albums_to_check.add(tr.album) self.__albums_to_check.add(tr.album.id)
tr.album = tralbum trdict['album'] = tralbum
if tr.artist.id != trartist.id: if tr.artist.id != trartist.id:
self.__artists_to_check.add(tr.artist) self.__artists_to_check.add(tr.artist.id)
tr.artist = trartist trdict['artist'] = trartist
tr.set(**trdict)
@db_session
def remove_file(self, path): def remove_file(self, path):
if not isinstance(path, basestring): if not isinstance(path, basestring):
raise TypeError('Expecting string, got ' + str(type(path))) raise TypeError('Expecting string, got ' + str(type(path)))
tr = self.__store.find(Track, Track.path == path).one() tr = Track.get(path = path)
if not tr: if not tr:
return return
self.__store.find(StarredTrack, StarredTrack.starred_id == tr.id).remove() self.__folders_to_check.add(tr.folder.id)
self.__store.find(RatingTrack, RatingTrack.rated_id == tr.id).remove() self.__albums_to_check.add(tr.album.id)
# Playlist autofix themselves self.__artists_to_check.add(tr.artist.id)
self.__store.find(User, User.last_play_id == tr.id).set(last_play_id = None)
self.__folders_to_check.add(tr.folder)
self.__albums_to_check.add(tr.album)
self.__artists_to_check.add(tr.artist)
self.__store.remove(tr)
self.__deleted_tracks += 1 self.__deleted_tracks += 1
tr.delete()
@db_session
def move_file(self, src_path, dst_path): def move_file(self, src_path, dst_path):
if not isinstance(src_path, basestring): if not isinstance(src_path, basestring):
raise TypeError('Expecting string, got ' + str(type(src_path))) raise TypeError('Expecting string, got ' + str(type(src_path)))
@ -230,16 +204,18 @@ class Scanner:
if src_path == dst_path: if src_path == dst_path:
return return
tr = self.__store.find(Track, Track.path == src_path).one() tr = Track.get(path = src_path)
if not tr: if tr is None:
return return
self.__folders_to_check.add(tr.folder) self.__folders_to_check.add(tr.folder.id)
tr_dst = self.__store.find(Track, Track.path == dst_path).one() tr_dst = Track.get(path = dst_path)
if tr_dst: if tr_dst is not None:
tr.root_folder = tr_dst.root_folder root = tr_dst.root_folder
tr.folder = tr_dst.folder folder = tr_dst.folder
self.remove_file(dst_path) self.remove_file(dst_path)
tr.root_folder = root
tr.folder = folder
else: else:
root = self.__find_root_folder(dst_path) root = self.__find_root_folder(dst_path)
folder = self.__find_folder(dst_path) folder = self.__find_folder(dst_path)
@ -249,70 +225,48 @@ class Scanner:
def __find_album(self, artist, album): def __find_album(self, artist, album):
ar = self.__find_artist(artist) ar = self.__find_artist(artist)
al = ar.albums.find(name = album).one() al = ar.albums.select(lambda a: a.name == album).first()
if al: if al:
return al return al
al = Album() al = Album(name = album, artist = ar)
al.name = album
al.artist = ar
self.__store.add(al)
self.__added_albums += 1 self.__added_albums += 1
return al return al
def __find_artist(self, artist): def __find_artist(self, artist):
ar = self.__store.find(Artist, Artist.name == artist).one() ar = Artist.get(name = artist)
if ar: if ar:
return ar return ar
ar = Artist() ar = Artist(name = artist)
ar.name = artist
self.__store.add(ar)
self.__added_artists += 1 self.__added_artists += 1
return ar return ar
def __find_root_folder(self, path): def __find_root_folder(self, path):
path = os.path.dirname(path) path = os.path.dirname(path)
db = self.__store.get_database().__module__[len('storm.databases.'):] for folder in Folder.select(lambda f: f.root):
folders = self.__store.find(Folder, Like(path, Concat(Folder.path, u'%', db)), Folder.root == True) if path.startswith(folder.path):
count = folders.count() return folder
if count > 1:
raise Exception("Found multiple root folders for '{}'.".format(path))
elif count == 0:
raise Exception("Couldn't find the root folder for '{}'.\nDon't scan files that aren't located in a defined music folder".format(path)) raise Exception("Couldn't find the root folder for '{}'.\nDon't scan files that aren't located in a defined music folder".format(path))
return folders.one()
def __find_folder(self, path): def __find_folder(self, path):
children = []
drive, _ = os.path.splitdrive(path)
path = os.path.dirname(path) path = os.path.dirname(path)
folders = self.__store.find(Folder, Folder.path == path) while path != drive and path != '/':
count = folders.count() folder = Folder.get(path = path)
if count > 1: if folder is not None:
raise Exception("Found multiple folders for '{}'.".format(path)) break
elif count == 1:
return folders.one()
db = self.__store.get_database().__module__[len('storm.databases.'):] children.append(dict(root = False, name = os.path.basename(path), path = path))
folder = self.__store.find(Folder, Like(path, Concat(Folder.path, os.sep + u'%', db))).order_by(Folder.path).last() path = os.path.dirname(path)
full_path = folder.path assert folder is not None
path = path[len(folder.path) + 1:] while children:
folder = Folder(parent = folder, **children.pop())
for name in path.split(os.sep):
full_path = os.path.join(full_path, name)
fold = Folder()
fold.root = False
fold.name = name
fold.path = full_path
fold.parent = folder
self.__store.add(fold)
folder = fold
return folder return folder

View File

@ -22,12 +22,13 @@ import logging
import time import time
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
from pony.orm import db_session
from signal import signal, SIGTERM, SIGINT from signal import signal, SIGTERM, SIGINT
from threading import Thread, Condition, Timer from threading import Thread, Condition, Timer
from watchdog.observers import Observer from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler from watchdog.events import PatternMatchingEventHandler
from . import db from .db import init_database, release_database, Folder
from .scanner import Scanner from .scanner import Scanner
OP_SCAN = 1 OP_SCAN = 1
@ -109,12 +110,11 @@ class Event(object):
return self.__src return self.__src
class ScannerProcessingQueue(Thread): class ScannerProcessingQueue(Thread):
def __init__(self, database_uri, delay, logger): def __init__(self, delay, logger):
super(ScannerProcessingQueue, self).__init__() super(ScannerProcessingQueue, self).__init__()
self.__logger = logger self.__logger = logger
self.__timeout = delay self.__timeout = delay
self.__database_uri = database_uri
self.__cond = Condition() self.__cond = Condition()
self.__timer = None self.__timer = None
self.__queue = {} self.__queue = {}
@ -138,8 +138,7 @@ class ScannerProcessingQueue(Thread):
continue continue
self.__logger.debug("Instantiating scanner") self.__logger.debug("Instantiating scanner")
store = db.get_store(self.__database_uri) scanner = Scanner()
scanner = Scanner(store)
item = self.__next_item() item = self.__next_item()
while item: while item:
@ -155,8 +154,6 @@ class ScannerProcessingQueue(Thread):
item = self.__next_item() item = self.__next_item()
scanner.finish() scanner.finish()
store.commit()
store.close()
self.__logger.debug("Freeing scanner") self.__logger.debug("Freeing scanner")
del scanner del scanner
@ -208,6 +205,7 @@ class SupysonicWatcher(object):
def __init__(self, config): def __init__(self, config):
self.__config = config self.__config = config
self.__running = True self.__running = True
init_database(config.BASE['database_uri'])
def run(self): def run(self):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -227,22 +225,22 @@ class SupysonicWatcher(object):
} }
logger.setLevel(mapping.get(self.__config.DAEMON['log_level'].upper(), logging.NOTSET)) logger.setLevel(mapping.get(self.__config.DAEMON['log_level'].upper(), logging.NOTSET))
store = db.get_store(self.__config.BASE['database_uri']) with db_session:
folders = store.find(db.Folder, db.Folder.root == True) folders = Folder.select(lambda f: f.root)
shouldrun = folders.exists()
if not folders.count(): if not shouldrun:
logger.info("No folder set. Exiting.") logger.info("No folder set. Exiting.")
store.close() release_database()
return return
queue = ScannerProcessingQueue(self.__config.BASE['database_uri'], self.__config.DAEMON['wait_delay'], logger) queue = ScannerProcessingQueue(self.__config.DAEMON['wait_delay'], logger)
handler = SupysonicWatcherEventHandler(self.__config.BASE['scanner_extensions'], queue, logger) handler = SupysonicWatcherEventHandler(self.__config.BASE['scanner_extensions'], queue, logger)
observer = Observer() observer = Observer()
with db_session:
for folder in folders: for folder in folders:
logger.info("Starting watcher for %s", folder.path) logger.info("Starting watcher for %s", folder.path)
observer.schedule(handler, folder.path, recursive = True) observer.schedule(handler, folder.path, recursive = True)
store.close()
try: try:
signal(SIGTERM, self.__terminate) signal(SIGTERM, self.__terminate)
@ -260,6 +258,7 @@ class SupysonicWatcher(object):
observer.join() observer.join()
queue.stop() queue.stop()
queue.join() queue.join()
release_database()
def stop(self): def stop(self):
self.__running = False self.__running = False

View File

@ -11,25 +11,11 @@
import mimetypes import mimetypes
from flask import Flask, g, current_app from flask import Flask
from os import makedirs, path from os import makedirs, path
from werkzeug.local import LocalProxy
from .config import IniConfig from .config import IniConfig
from .db import get_store from .db import init_database, release_database
# Supysonic database open
def get_db():
if not hasattr(g, 'database'):
g.database = get_store(current_app.config['BASE']['database_uri'])
return g.database
# Supysonic database close
def close_db(error):
if hasattr(g, 'database'):
g.database.close()
store = LocalProxy(get_db)
def create_application(config = None): def create_application(config = None):
global app global app
@ -42,9 +28,6 @@ def create_application(config = None):
config = IniConfig.from_common_locations() config = IniConfig.from_common_locations()
app.config.from_object(config) app.config.from_object(config)
# Close database connection on teardown
app.teardown_appcontext(close_db)
# Set loglevel # Set loglevel
logfile = app.config['WEBAPP']['log_file'] logfile = app.config['WEBAPP']['log_file']
if logfile: if logfile:
@ -63,6 +46,9 @@ def create_application(config = None):
handler.setLevel(mapping.get(loglevel.upper(), logging.NOTSET)) handler.setLevel(mapping.get(loglevel.upper(), logging.NOTSET))
app.logger.addHandler(handler) app.logger.addHandler(handler)
# Initialize database
init_database(app.config['BASE']['database_uri'])
# Insert unknown mimetypes # Insert unknown mimetypes
for k, v in app.config['MIMETYPES'].iteritems(): for k, v in app.config['MIMETYPES'].iteritems():
extension = '.' + k.lower() extension = '.' + k.lower()

View File

@ -11,7 +11,10 @@
import unittest import unittest
from . import base, managers, api, frontend from . import base
from . import managers
from . import api
from . import frontend
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import Folder, Artist, Album, Track from supysonic.db import Folder, Artist, Album, Track
from .apitestbase import ApiTestBase from .apitestbase import ApiTestBase
@ -22,34 +24,25 @@ class AlbumSongsTestCase(ApiTestBase):
def setUp(self): def setUp(self):
super(AlbumSongsTestCase, self).setUp() super(AlbumSongsTestCase, self).setUp()
folder = Folder() with db_session:
folder.name = 'Root' folder = Folder(name = 'Root', root = True, path = 'tests/assets')
folder.root = True artist = Artist(name = 'Artist')
folder.path = 'tests/assets' album = Album(name = 'Album', artist = artist)
artist = Artist() track = Track(
artist.name = 'Artist' title = 'Track',
album = album,
album = Album() artist = artist,
album.name = 'Album' disc = 1,
album.artist = artist number = 1,
path = 'tests/assets/empty',
track = Track() folder = folder,
track.title = 'Track' root_folder = folder,
track.album = album duration = 2,
track.artist = artist bitrate = 320,
track.disc = 1 content_type = 'audio/mpeg',
track.number = 1 last_modification = 0
track.path = 'tests/assets/empty' )
track.folder = folder
track.root_folder = folder
track.duration = 2
track.bitrate = 320
track.content_type = 'audio/mpeg'
track.last_modification = 0
self.store.add(track)
self.store.commit()
def test_get_album_list(self): def test_get_album_list(self):
self._make_request('getAlbumList', error = 10) self._make_request('getAlbumList', error = 10)
@ -63,11 +56,9 @@ class AlbumSongsTestCase(ApiTestBase):
self._make_request('getAlbumList', { 'type': t }, tag = 'albumList', skip_post = True) self._make_request('getAlbumList', { 'type': t }, tag = 'albumList', skip_post = True)
rv, child = self._make_request('getAlbumList', { 'type': 'random' }, tag = 'albumList', skip_post = True) rv, child = self._make_request('getAlbumList', { 'type': 'random' }, tag = 'albumList', skip_post = True)
self.assertEqual(len(child), 10)
rv, child = self._make_request('getAlbumList', { 'type': 'random', 'size': 3 }, tag = 'albumList', skip_post = True)
self.assertEqual(len(child), 3)
self.store.remove(self.store.find(Folder).one()) with db_session:
Folder.get().delete()
rv, child = self._make_request('getAlbumList', { 'type': 'random' }, tag = 'albumList') rv, child = self._make_request('getAlbumList', { 'type': 'random' }, tag = 'albumList')
self.assertEqual(len(child), 0) self.assertEqual(len(child), 0)
@ -82,12 +73,10 @@ class AlbumSongsTestCase(ApiTestBase):
self._make_request('getAlbumList2', { 'type': t }, tag = 'albumList2', skip_post = True) self._make_request('getAlbumList2', { 'type': t }, tag = 'albumList2', skip_post = True)
rv, child = self._make_request('getAlbumList2', { 'type': 'random' }, tag = 'albumList2', skip_post = True) rv, child = self._make_request('getAlbumList2', { 'type': 'random' }, tag = 'albumList2', skip_post = True)
self.assertEqual(len(child), 10)
rv, child = self._make_request('getAlbumList2', { 'type': 'random', 'size': 3 }, tag = 'albumList2', skip_post = True)
self.assertEqual(len(child), 3)
self.store.remove(self.store.find(Track).one()) with db_session:
self.store.remove(self.store.find(Album).one()) Track.get().delete()
Album.get().delete()
rv, child = self._make_request('getAlbumList2', { 'type': 'random' }, tag = 'albumList2') rv, child = self._make_request('getAlbumList2', { 'type': 'random' }, tag = 'albumList2')
self.assertEqual(len(child), 0) self.assertEqual(len(child), 0)
@ -98,12 +87,10 @@ class AlbumSongsTestCase(ApiTestBase):
self._make_request('getRandomSongs', { 'musicFolderId': 'idid' }, error = 0) self._make_request('getRandomSongs', { 'musicFolderId': 'idid' }, error = 0)
self._make_request('getRandomSongs', { 'musicFolderId': uuid.uuid4() }, error = 70) self._make_request('getRandomSongs', { 'musicFolderId': uuid.uuid4() }, error = 70)
rv, child = self._make_request('getRandomSongs', tag = 'randomSongs') rv, child = self._make_request('getRandomSongs', tag = 'randomSongs', skip_post = True)
self.assertEqual(len(child), 10)
rv, child = self._make_request('getRandomSongs', { 'size': 3 }, tag = 'randomSongs')
self.assertEqual(len(child), 3)
fid = self.store.find(Folder).one().id with db_session:
fid = Folder.get().id
self._make_request('getRandomSongs', { 'fromYear': -52, 'toYear': '1984', 'genre': 'some cryptic subgenre youve never heard of', 'musicFolderId': fid }, tag = 'randomSongs') self._make_request('getRandomSongs', { 'fromYear': -52, 'toYear': '1984', 'genre': 'some cryptic subgenre youve never heard of', 'musicFolderId': fid }, tag = 'randomSongs')
def test_now_playing(self): def test_now_playing(self):

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import Folder, Artist, Album, Track, User, ClientPrefs from supysonic.db import Folder, Artist, Album, Track, User, ClientPrefs
from .apitestbase import ApiTestBase from .apitestbase import ApiTestBase
@ -19,45 +21,32 @@ class AnnotationTestCase(ApiTestBase):
def setUp(self): def setUp(self):
super(AnnotationTestCase, self).setUp() super(AnnotationTestCase, self).setUp()
root = Folder() with db_session:
root.name = 'Root' root = Folder(name = 'Root', root = True, path = 'tests')
root.root = True folder = Folder(name = 'Folder', path = 'tests/assets', parent = root)
root.path = 'tests/assets' artist = Artist(name = 'Artist')
album = Album(name = 'Album', artist = artist)
folder = Folder() track = Track(
folder.name = 'Folder' title = 'Track',
folder.path = 'tests/assets' album = album,
folder.parent = root artist = artist,
disc = 1,
number = 1,
path = 'tests/assets/empty',
folder = folder,
root_folder = root,
duration = 2,
bitrate = 320,
content_type = 'audio/mpeg',
last_modification = 0
)
artist = Artist() self.folderid = folder.id
artist.name = 'Artist' self.artistid = artist.id
self.albumid = album.id
album = Album() self.trackid = track.id
album.name = 'Album' self.user = User.get(name = 'alice')
album.artist = artist
track = Track()
track.title = 'Track'
track.album = album
track.artist = artist
track.disc = 1
track.number = 1
track.path = 'tests/assets/empty'
track.folder = folder
track.root_folder = root
track.duration = 2
track.bitrate = 320
track.content_type = 'audio/mpeg'
track.last_modification = 0
self.store.add(track)
self.store.commit()
self.folder = folder
self.artist = artist
self.album = album
self.track = track
self.user = self.store.find(User, User.name == 'alice').one()
def test_star(self): def test_star(self):
self._make_request('star', error = 10) self._make_request('star', error = 10)
@ -68,88 +57,101 @@ class AnnotationTestCase(ApiTestBase):
self._make_request('star', { 'albumId': str(uuid.uuid4()) }, error = 70) self._make_request('star', { 'albumId': str(uuid.uuid4()) }, error = 70)
self._make_request('star', { 'artistId': str(uuid.uuid4()) }, error = 70) self._make_request('star', { 'artistId': str(uuid.uuid4()) }, error = 70)
self._make_request('star', { 'id': str(self.artist.id) }, error = 70, skip_xsd = True) self._make_request('star', { 'id': str(self.artistid) }, error = 70, skip_xsd = True)
self._make_request('star', { 'id': str(self.album.id) }, error = 70, skip_xsd = True) self._make_request('star', { 'id': str(self.albumid) }, error = 70, skip_xsd = True)
self._make_request('star', { 'id': str(self.track.id) }, skip_post = True) self._make_request('star', { 'id': str(self.trackid) }, skip_post = True)
self.assertIn('starred', self.track.as_subsonic_child(self.user, ClientPrefs())) with db_session:
self._make_request('star', { 'id': str(self.track.id) }, error = 0, skip_xsd = True) self.assertIn('starred', Track[self.trackid].as_subsonic_child(self.user, 'tests'))
self._make_request('star', { 'id': str(self.trackid) }, error = 0, skip_xsd = True)
self._make_request('star', { 'id': str(self.folder.id) }, skip_post = True) self._make_request('star', { 'id': str(self.folderid) }, skip_post = True)
self.assertIn('starred', self.folder.as_subsonic_child(self.user)) with db_session:
self._make_request('star', { 'id': str(self.folder.id) }, error = 0, skip_xsd = True) self.assertIn('starred', Folder[self.folderid].as_subsonic_child(self.user))
self._make_request('star', { 'id': str(self.folderid) }, error = 0, skip_xsd = True)
self._make_request('star', { 'albumId': str(self.folder.id) }, error = 70) self._make_request('star', { 'albumId': str(self.folderid) }, error = 70)
self._make_request('star', { 'albumId': str(self.artist.id) }, error = 70) self._make_request('star', { 'albumId': str(self.artistid) }, error = 70)
self._make_request('star', { 'albumId': str(self.track.id) }, error = 70) self._make_request('star', { 'albumId': str(self.trackid) }, error = 70)
self._make_request('star', { 'albumId': str(self.album.id) }, skip_post = True) self._make_request('star', { 'albumId': str(self.albumid) }, skip_post = True)
self.assertIn('starred', self.album.as_subsonic_album(self.user)) with db_session:
self._make_request('star', { 'albumId': str(self.album.id) }, error = 0) self.assertIn('starred', Album[self.albumid].as_subsonic_album(self.user))
self._make_request('star', { 'albumId': str(self.albumid) }, error = 0)
self._make_request('star', { 'artistId': str(self.folder.id) }, error = 70) self._make_request('star', { 'artistId': str(self.folderid) }, error = 70)
self._make_request('star', { 'artistId': str(self.album.id) }, error = 70) self._make_request('star', { 'artistId': str(self.albumid) }, error = 70)
self._make_request('star', { 'artistId': str(self.track.id) }, error = 70) self._make_request('star', { 'artistId': str(self.trackid) }, error = 70)
self._make_request('star', { 'artistId': str(self.artist.id) }, skip_post = True) self._make_request('star', { 'artistId': str(self.artistid) }, skip_post = True)
self.assertIn('starred', self.artist.as_subsonic_artist(self.user)) with db_session:
self._make_request('star', { 'artistId': str(self.artist.id) }, error = 0) self.assertIn('starred', Artist[self.artistid].as_subsonic_artist(self.user))
self._make_request('star', { 'artistId': str(self.artistid) }, error = 0)
def test_unstar(self): def test_unstar(self):
self._make_request('star', { 'id': [ str(self.folder.id), str(self.track.id) ], 'artistId': str(self.artist.id), 'albumId': str(self.album.id) }, skip_post = True) self._make_request('star', { 'id': [ str(self.folderid), str(self.trackid) ], 'artistId': str(self.artistid), 'albumId': str(self.albumid) }, skip_post = True)
self._make_request('unstar', error = 10) self._make_request('unstar', error = 10)
self._make_request('unstar', { 'id': 'unknown' }, error = 0, skip_xsd = True) self._make_request('unstar', { 'id': 'unknown' }, error = 0, skip_xsd = True)
self._make_request('unstar', { 'albumId': 'unknown' }, error = 0) self._make_request('unstar', { 'albumId': 'unknown' }, error = 0)
self._make_request('unstar', { 'artistId': 'unknown' }, error = 0) self._make_request('unstar', { 'artistId': 'unknown' }, error = 0)
self._make_request('unstar', { 'id': str(self.track.id) }, skip_post = True) self._make_request('unstar', { 'id': str(self.trackid) }, skip_post = True)
self.assertNotIn('starred', self.track.as_subsonic_child(self.user, ClientPrefs())) with db_session:
self.assertNotIn('starred', Track[self.trackid].as_subsonic_child(self.user, 'tests'))
self._make_request('unstar', { 'id': str(self.folder.id) }, skip_post = True) self._make_request('unstar', { 'id': str(self.folderid) }, skip_post = True)
self.assertNotIn('starred', self.folder.as_subsonic_child(self.user)) with db_session:
self.assertNotIn('starred', Folder[self.folderid].as_subsonic_child(self.user))
self._make_request('unstar', { 'albumId': str(self.album.id) }, skip_post = True) self._make_request('unstar', { 'albumId': str(self.albumid) }, skip_post = True)
self.assertNotIn('starred', self.album.as_subsonic_album(self.user)) with db_session:
self.assertNotIn('starred', Album[self.albumid].as_subsonic_album(self.user))
self._make_request('unstar', { 'artistId': str(self.artist.id) }, skip_post = True) self._make_request('unstar', { 'artistId': str(self.artistid) }, skip_post = True)
self.assertNotIn('starred', self.artist.as_subsonic_artist(self.user)) with db_session:
self.assertNotIn('starred', Artist[self.artistid].as_subsonic_artist(self.user))
def test_set_rating(self): def test_set_rating(self):
self._make_request('setRating', error = 10) self._make_request('setRating', error = 10)
self._make_request('setRating', { 'id': str(self.track.id) }, error = 10) self._make_request('setRating', { 'id': str(self.trackid) }, error = 10)
self._make_request('setRating', { 'rating': 3 }, error = 10) self._make_request('setRating', { 'rating': 3 }, error = 10)
self._make_request('setRating', { 'id': 'string', 'rating': 3 }, error = 0) self._make_request('setRating', { 'id': 'string', 'rating': 3 }, error = 0)
self._make_request('setRating', { 'id': str(uuid.uuid4()), 'rating': 3 }, error = 70) self._make_request('setRating', { 'id': str(uuid.uuid4()), 'rating': 3 }, error = 70)
self._make_request('setRating', { 'id': str(self.artist.id), 'rating': 3 }, error = 70) self._make_request('setRating', { 'id': str(self.artistid), 'rating': 3 }, error = 70)
self._make_request('setRating', { 'id': str(self.album.id), 'rating': 3 }, error = 70) self._make_request('setRating', { 'id': str(self.albumid), 'rating': 3 }, error = 70)
self._make_request('setRating', { 'id': str(self.track.id), 'rating': 'string' }, error = 0) self._make_request('setRating', { 'id': str(self.trackid), 'rating': 'string' }, error = 0)
self._make_request('setRating', { 'id': str(self.track.id), 'rating': -1 }, error = 0) self._make_request('setRating', { 'id': str(self.trackid), 'rating': -1 }, error = 0)
self._make_request('setRating', { 'id': str(self.track.id), 'rating': 6 }, error = 0) self._make_request('setRating', { 'id': str(self.trackid), 'rating': 6 }, error = 0)
prefs = ClientPrefs() with db_session:
self.assertNotIn('userRating', self.track.as_subsonic_child(self.user, prefs)) self.assertNotIn('userRating', Track[self.trackid].as_subsonic_child(self.user, 'tests'))
for i in range(1, 6): for i in range(1, 6):
self._make_request('setRating', { 'id': str(self.track.id), 'rating': i }, skip_post = True) self._make_request('setRating', { 'id': str(self.trackid), 'rating': i }, skip_post = True)
self.assertEqual(self.track.as_subsonic_child(self.user, prefs)['userRating'], i) with db_session:
self._make_request('setRating', { 'id': str(self.track.id), 'rating': 0 }, skip_post = True) self.assertEqual(Track[self.trackid].as_subsonic_child(self.user, 'tests')['userRating'], i)
self.assertNotIn('userRating', self.track.as_subsonic_child(self.user, prefs))
self.assertNotIn('userRating', self.folder.as_subsonic_child(self.user)) self._make_request('setRating', { 'id': str(self.trackid), 'rating': 0 }, skip_post = True)
with db_session:
self.assertNotIn('userRating', Track[self.trackid].as_subsonic_child(self.user, 'tests'))
self.assertNotIn('userRating', Folder[self.folderid].as_subsonic_child(self.user))
for i in range(1, 6): for i in range(1, 6):
self._make_request('setRating', { 'id': str(self.folder.id), 'rating': i }, skip_post = True) self._make_request('setRating', { 'id': str(self.folderid), 'rating': i }, skip_post = True)
self.assertEqual(self.folder.as_subsonic_child(self.user)['userRating'], i) with db_session:
self._make_request('setRating', { 'id': str(self.folder.id), 'rating': 0 }, skip_post = True) self.assertEqual(Folder[self.folderid].as_subsonic_child(self.user)['userRating'], i)
self.assertNotIn('userRating', self.folder.as_subsonic_child(self.user)) self._make_request('setRating', { 'id': str(self.folderid), 'rating': 0 }, skip_post = True)
with db_session:
self.assertNotIn('userRating', Folder[self.folderid].as_subsonic_child(self.user))
def test_scrobble(self): def test_scrobble(self):
self._make_request('scrobble', error = 10) self._make_request('scrobble', error = 10)
self._make_request('scrobble', { 'id': 'song' }, error = 0) self._make_request('scrobble', { 'id': 'song' }, error = 0)
self._make_request('scrobble', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('scrobble', { 'id': str(uuid.uuid4()) }, error = 70)
self._make_request('scrobble', { 'id': str(self.folder.id) }, error = 70) self._make_request('scrobble', { 'id': str(self.folderid) }, error = 70)
self.skipTest('Weird request context/logger issue at exit') self.skipTest('Weird request context/logger issue at exit')
self._make_request('scrobble', { 'id': str(self.track.id) }) self._make_request('scrobble', { 'id': str(self.trackid) })
self._make_request('scrobble', { 'id': str(self.track.id), 'submission': True }) self._make_request('scrobble', { 'id': str(self.trackid), 'submission': True })
self._make_request('scrobble', { 'id': str(self.track.id), 'submission': False }) self._make_request('scrobble', { 'id': str(self.trackid), 'submission': False })
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -9,10 +9,12 @@
# #
# Distributed under terms of the GNU AGPLv3 license. # Distributed under terms of the GNU AGPLv3 license.
from lxml import etree
import time import time
import uuid import uuid
from lxml import etree
from pony.orm import db_session
from supysonic.db import Folder, Artist, Album, Track from supysonic.db import Folder, Artist, Album, Track
from .apitestbase import ApiTestBase from .apitestbase import ApiTestBase
@ -21,60 +23,49 @@ class BrowseTestCase(ApiTestBase):
def setUp(self): def setUp(self):
super(BrowseTestCase, self).setUp() super(BrowseTestCase, self).setUp()
empty = Folder() with db_session:
empty.root = True Folder(root = True, name = 'Empty root', path = '/tmp')
empty.name = 'Empty root' root = Folder(root = True, name = 'Root folder', path = 'tests/assets')
empty.path = '/tmp'
self.store.add(empty)
root = Folder()
root.root = True
root.name = 'Root folder'
root.path = 'tests/assets'
self.store.add(root)
for letter in 'ABC': for letter in 'ABC':
folder = Folder() folder = Folder(
folder.name = letter + 'rtist' name = letter + 'rtist',
folder.path = 'tests/assets/{}rtist'.format(letter) path = 'tests/assets/{}rtist'.format(letter),
folder.parent = root parent = root
)
artist = Artist() artist = Artist(name = letter + 'rtist')
artist.name = letter + 'rtist'
for lether in 'AB': for lether in 'AB':
afolder = Folder() afolder = Folder(
afolder.name = letter + lether + 'lbum' name = letter + lether + 'lbum',
afolder.path = 'tests/assets/{0}rtist/{0}{1}lbum'.format(letter, lether) path = 'tests/assets/{0}rtist/{0}{1}lbum'.format(letter, lether),
afolder.parent = folder parent = folder
)
album = Album() album = Album(name = letter + lether + 'lbum', artist = artist)
album.name = letter + lether + 'lbum'
album.artist = artist
for num, song in enumerate([ 'One', 'Two', 'Three' ]): for num, song in enumerate([ 'One', 'Two', 'Three' ]):
track = Track() track = Track(
track.disc = 1 disc = 1,
track.number = num number = num,
track.title = song title = song,
track.duration = 2 duration = 2,
track.album = album album = album,
track.artist = artist artist = artist,
track.bitrate = 320 bitrate = 320,
track.path = 'tests/assets/{0}rtist/{0}{1}lbum/{2}'.format(letter, lether, song) path = 'tests/assets/{0}rtist/{0}{1}lbum/{2}'.format(letter, lether, song),
track.content_type = 'audio/mpeg' content_type = 'audio/mpeg',
track.last_modification = 0 last_modification = 0,
track.root_folder = root root_folder = root,
track.folder = afolder folder = afolder
self.store.add(track) )
self.store.commit() self.assertEqual(Folder.select().count(), 11)
self.assertEqual(Folder.select(lambda f: f.root).count(), 2)
self.assertEqual(self.store.find(Folder).count(), 11) self.assertEqual(Artist.select().count(), 3)
self.assertEqual(self.store.find(Folder, Folder.root == True).count(), 2) self.assertEqual(Album.select().count(), 6)
self.assertEqual(self.store.find(Artist).count(), 3) self.assertEqual(Track.select().count(), 18)
self.assertEqual(self.store.find(Album).count(), 6)
self.assertEqual(self.store.find(Track).count(), 18)
def test_get_music_folders(self): def test_get_music_folders(self):
# Do not validate against the XSD here, this is the only place where the API should return ids as ints # Do not validate against the XSD here, this is the only place where the API should return ids as ints
@ -91,7 +82,8 @@ class BrowseTestCase(ApiTestBase):
rv, child = self._make_request('getIndexes', { 'ifModifiedSince': int(time.time() * 1000 + 1000) }, tag = 'indexes') rv, child = self._make_request('getIndexes', { 'ifModifiedSince': int(time.time() * 1000 + 1000) }, tag = 'indexes')
self.assertEqual(len(child), 0) self.assertEqual(len(child), 0)
fid = self.store.find(Folder, Folder.name == 'Empty root').one().id with db_session:
fid = Folder.get(name = 'Empty root').id
rv, child = self._make_request('getIndexes', { 'musicFolderId': str(fid) }, tag = 'indexes') rv, child = self._make_request('getIndexes', { 'musicFolderId': str(fid) }, tag = 'indexes')
self.assertEqual(len(child), 0) self.assertEqual(len(child), 0)
@ -108,7 +100,8 @@ class BrowseTestCase(ApiTestBase):
self._make_request('getMusicDirectory', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('getMusicDirectory', { 'id': str(uuid.uuid4()) }, error = 70)
# should test with folders with both children folders and tracks. this code would break in that case # should test with folders with both children folders and tracks. this code would break in that case
for f in self.store.find(Folder): with db_session:
for f in Folder.select():
rv, child = self._make_request('getMusicDirectory', { 'id': str(f.id) }, tag = 'directory') rv, child = self._make_request('getMusicDirectory', { 'id': str(f.id) }, tag = 'directory')
self.assertEqual(child.get('id'), str(f.id)) self.assertEqual(child.get('id'), str(f.id))
self.assertEqual(child.get('name'), f.name) self.assertEqual(child.get('name'), f.name)
@ -138,7 +131,8 @@ class BrowseTestCase(ApiTestBase):
self._make_request('getArtist', { 'id': 'artist' }, error = 0) self._make_request('getArtist', { 'id': 'artist' }, error = 0)
self._make_request('getArtist', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('getArtist', { 'id': str(uuid.uuid4()) }, error = 70)
for ar in self.store.find(Artist): with db_session:
for ar in Artist.select():
rv, child = self._make_request('getArtist', { 'id': str(ar.id) }, tag = 'artist') rv, child = self._make_request('getArtist', { 'id': str(ar.id) }, tag = 'artist')
self.assertEqual(child.get('id'), str(ar.id)) self.assertEqual(child.get('id'), str(ar.id))
self.assertEqual(child.get('albumCount'), str(len(child))) self.assertEqual(child.get('albumCount'), str(len(child)))
@ -153,7 +147,8 @@ class BrowseTestCase(ApiTestBase):
self._make_request('getAlbum', { 'id': 'nastynasty' }, error = 0) self._make_request('getAlbum', { 'id': 'nastynasty' }, error = 0)
self._make_request('getAlbum', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('getAlbum', { 'id': str(uuid.uuid4()) }, error = 70)
a = self.store.find(Album)[0] with db_session:
a = Album.select().first()
rv, child = self._make_request('getAlbum', { 'id': str(a.id) }, tag = 'album') rv, child = self._make_request('getAlbum', { 'id': str(a.id) }, tag = 'album')
self.assertEqual(child.get('id'), str(a.id)) self.assertEqual(child.get('id'), str(a.id))
self.assertEqual(child.get('songCount'), str(len(child))) self.assertEqual(child.get('songCount'), str(len(child)))
@ -169,7 +164,8 @@ class BrowseTestCase(ApiTestBase):
self._make_request('getSong', { 'id': 'nastynasty' }, error = 0) self._make_request('getSong', { 'id': 'nastynasty' }, error = 0)
self._make_request('getSong', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('getSong', { 'id': str(uuid.uuid4()) }, error = 70)
s = self.store.find(Track)[0] with db_session:
s = Track.select().first()
self._make_request('getSong', { 'id': str(s.id) }, tag = 'song') self._make_request('getSong', { 'id': str(s.id) }, tag = 'song')
def test_get_videos(self): def test_get_videos(self):

View File

@ -11,8 +11,10 @@
import os.path import os.path
import uuid import uuid
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from pony.orm import db_session
from supysonic.db import Folder, Artist, Album, Track from supysonic.db import Folder, Artist, Album, Track
@ -22,69 +24,69 @@ class MediaTestCase(ApiTestBase):
def setUp(self): def setUp(self):
super(MediaTestCase, self).setUp() super(MediaTestCase, self).setUp()
self.folder = Folder() with db_session:
self.folder.name = 'Root' folder = Folder(
self.folder.path = os.path.abspath('tests/assets') name = 'Root',
self.folder.root = True path = os.path.abspath('tests/assets'),
self.folder.has_cover_art = True # 420x420 PNG root = True,
has_cover_art = True # 420x420 PNG
)
self.folderid = folder.id
artist = Artist() artist = Artist(name = 'Artist')
artist.name = 'Artist' album = Album(artist = artist, name = 'Album')
album = Album() track = Track(
album.artist = artist title = '23bytes',
album.name = 'Album' number = 1,
disc = 1,
self.track = Track() artist = artist,
self.track.title = '23bytes' album = album,
self.track.number = 1 path = os.path.abspath('tests/assets/23bytes'),
self.track.disc = 1 root_folder = folder,
self.track.artist = artist folder = folder,
self.track.album = album duration = 2,
self.track.path = os.path.abspath('tests/assets/23bytes') bitrate = 320,
self.track.root_folder = self.folder content_type = 'audio/mpeg',
self.track.folder = self.folder last_modification = 0
self.track.duration = 2 )
self.track.bitrate = 320 self.trackid = track.id
self.track.content_type = 'audio/mpeg'
self.track.last_modification = 0
self.store.add(self.track)
self.store.commit()
def test_stream(self): def test_stream(self):
self._make_request('stream', error = 10) self._make_request('stream', error = 10)
self._make_request('stream', { 'id': 'string' }, error = 0) self._make_request('stream', { 'id': 'string' }, error = 0)
self._make_request('stream', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('stream', { 'id': str(uuid.uuid4()) }, error = 70)
self._make_request('stream', { 'id': str(self.folder.id) }, error = 70) self._make_request('stream', { 'id': str(self.folderid) }, error = 70)
self._make_request('stream', { 'id': str(self.track.id), 'maxBitRate': 'string' }, error = 0) self._make_request('stream', { 'id': str(self.trackid), 'maxBitRate': 'string' }, error = 0)
rv = self.client.get('/rest/stream.view', query_string = { 'u': 'alice', 'p': 'Alic3', 'c': 'tests', 'id': str(self.track.id) }) rv = self.client.get('/rest/stream.view', query_string = { 'u': 'alice', 'p': 'Alic3', 'c': 'tests', 'id': str(self.trackid) })
self.assertEqual(rv.status_code, 200) self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.mimetype, 'audio/mpeg') self.assertEqual(rv.mimetype, 'audio/mpeg')
self.assertEqual(len(rv.data), 23) self.assertEqual(len(rv.data), 23)
self.assertEqual(self.track.play_count, 1) with db_session:
self.assertEqual(Track[self.trackid].play_count, 1)
def test_download(self): def test_download(self):
self._make_request('download', error = 10) self._make_request('download', error = 10)
self._make_request('download', { 'id': 'string' }, error = 0) self._make_request('download', { 'id': 'string' }, error = 0)
self._make_request('download', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('download', { 'id': str(uuid.uuid4()) }, error = 70)
self._make_request('download', { 'id': str(self.folder.id) }, error = 70) self._make_request('download', { 'id': str(self.folderid) }, error = 70)
rv = self.client.get('/rest/download.view', query_string = { 'u': 'alice', 'p': 'Alic3', 'c': 'tests', 'id': str(self.track.id) }) rv = self.client.get('/rest/download.view', query_string = { 'u': 'alice', 'p': 'Alic3', 'c': 'tests', 'id': str(self.trackid) })
self.assertEqual(rv.status_code, 200) self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.mimetype, 'audio/mpeg') self.assertEqual(rv.mimetype, 'audio/mpeg')
self.assertEqual(len(rv.data), 23) self.assertEqual(len(rv.data), 23)
self.assertEqual(self.track.play_count, 0) with db_session:
self.assertEqual(Track[self.trackid].play_count, 0)
def test_get_cover_art(self): def test_get_cover_art(self):
self._make_request('getCoverArt', error = 10) self._make_request('getCoverArt', error = 10)
self._make_request('getCoverArt', { 'id': 'string' }, error = 0) self._make_request('getCoverArt', { 'id': 'string' }, error = 0)
self._make_request('getCoverArt', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('getCoverArt', { 'id': str(uuid.uuid4()) }, error = 70)
self._make_request('getCoverArt', { 'id': str(self.track.id) }, error = 70) self._make_request('getCoverArt', { 'id': str(self.trackid) }, error = 70)
self._make_request('getCoverArt', { 'id': str(self.folder.id), 'size': 'large' }, error = 0) self._make_request('getCoverArt', { 'id': str(self.folderid), 'size': 'large' }, error = 0)
args = { 'u': 'alice', 'p': 'Alic3', 'c': 'tests', 'id': str(self.folder.id) } args = { 'u': 'alice', 'p': 'Alic3', 'c': 'tests', 'id': str(self.folderid) }
rv = self.client.get('/rest/getCoverArt.view', query_string = args) rv = self.client.get('/rest/getCoverArt.view', query_string = args)
self.assertEqual(rv.status_code, 200) self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.mimetype, 'image/jpeg') self.assertEqual(rv.mimetype, 'image/jpeg')

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import Folder, Artist, Album, Track, Playlist, User from supysonic.db import Folder, Artist, Album, Track, Playlist, User
from .apitestbase import ApiTestBase from .apitestbase import ApiTestBase
@ -19,63 +21,42 @@ class PlaylistTestCase(ApiTestBase):
def setUp(self): def setUp(self):
super(PlaylistTestCase, self).setUp() super(PlaylistTestCase, self).setUp()
root = Folder() with db_session:
root.root = True root = Folder(root = True, name = 'Root folder', path = 'tests/assets')
root.name = 'Root folder' artist = Artist(name = 'Artist')
root.path = 'tests/assets' album = Album(name = 'Album', artist = artist)
self.store.add(root)
artist = Artist()
artist.name = 'Artist'
album = Album()
album.name = 'Album'
album.artist = artist
songs = {} songs = {}
for num, song in enumerate([ 'One', 'Two', 'Three', 'Four' ]): for num, song in enumerate([ 'One', 'Two', 'Three', 'Four' ]):
track = Track() track = Track(
track.disc = 1 disc = 1,
track.number = num number = num,
track.title = song title = song,
track.duration = 2 duration = 2,
track.album = album album = album,
track.artist = artist artist = artist,
track.bitrate = 320 bitrate = 320,
track.path = 'tests/assets/empty' path = 'tests/assets/' + song,
track.content_type = 'audio/mpeg' content_type = 'audio/mpeg',
track.last_modification = 0 last_modification = 0,
track.root_folder = root root_folder = root,
track.folder = root folder = root
)
self.store.add(track)
songs[song] = track songs[song] = track
users = { u.name: u for u in self.store.find(User) } users = { u.name: u for u in User.select() }
playlist = Playlist() playlist = Playlist(user = users['alice'], name = "Alice's")
playlist.user = users['alice']
playlist.name = "Alice's"
playlist.add(songs['One']) playlist.add(songs['One'])
playlist.add(songs['Three']) playlist.add(songs['Three'])
self.store.add(playlist)
playlist = Playlist() playlist = Playlist(user = users['alice'], public = True, name = "Alice's public")
playlist.user = users['alice']
playlist.public = True
playlist.name = "Alice's public"
playlist.add(songs['One']) playlist.add(songs['One'])
playlist.add(songs['Two']) playlist.add(songs['Two'])
self.store.add(playlist)
playlist = Playlist() playlist = Playlist(user = users['bob'], name = "Bob's")
playlist.user = users['bob']
playlist.name = "Bob's"
playlist.add(songs['Two']) playlist.add(songs['Two'])
playlist.add(songs['Four']) playlist.add(songs['Four'])
self.store.add(playlist)
self.store.commit()
def test_get_playlists(self): def test_get_playlists(self):
# get own playlists # get own playlists
@ -113,7 +94,8 @@ class PlaylistTestCase(ApiTestBase):
self._make_request('getPlaylist', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('getPlaylist', { 'id': str(uuid.uuid4()) }, error = 70)
# other's private from non admin # other's private from non admin
playlist = self.store.find(Playlist, Playlist.public == False, Playlist.user_id == User.id, User.name == 'alice').one() with db_session:
playlist = Playlist.get(lambda p: not p.public == False and p.user.name == 'alice')
self._make_request('getPlaylist', { 'u': 'bob', 'p': 'B0b', 'id': str(playlist.id) }, error = 50) self._make_request('getPlaylist', { 'u': 'bob', 'p': 'B0b', 'id': str(playlist.id) }, error = 50)
# standard # standard
@ -156,9 +138,11 @@ class PlaylistTestCase(ApiTestBase):
self._make_request('createPlaylist', { 'u': 'bob', 'p': 'B0b', 'playlistId': playlist.get('id') }, error = 50) self._make_request('createPlaylist', { 'u': 'bob', 'p': 'B0b', 'playlistId': playlist.get('id') }, error = 50)
# create more useful playlist # create more useful playlist
songs = { s.title: str(s.id) for s in self.store.find(Track) } with db_session:
songs = { s.title: str(s.id) for s in Track.select() }
self._make_request('createPlaylist', { 'name': 'songs', 'songId': map(lambda s: songs[s], [ 'Three', 'One', 'Two' ]) }, skip_post = True) self._make_request('createPlaylist', { 'name': 'songs', 'songId': map(lambda s: songs[s], [ 'Three', 'One', 'Two' ]) }, skip_post = True)
playlist = self.store.find(Playlist, Playlist.name == 'songs').one() with db_session:
playlist = Playlist.get(name = 'songs')
self.assertIsNotNone(playlist) self.assertIsNotNone(playlist)
rv, child = self._make_request('getPlaylist', { 'id': str(playlist.id) }, tag = 'playlist') rv, child = self._make_request('getPlaylist', { 'id': str(playlist.id) }, tag = 'playlist')
self.assertEqual(child.get('songCount'), '3') self.assertEqual(child.get('songCount'), '3')
@ -174,6 +158,10 @@ class PlaylistTestCase(ApiTestBase):
self.assertEqual(self._xpath(child, 'count(./entry)'), 1) self.assertEqual(self._xpath(child, 'count(./entry)'), 1)
self.assertEqual(child[0].get('title'), 'Two') self.assertEqual(child[0].get('title'), 'Two')
@db_session
def assertPlaylistCountEqual(self, count):
self.assertEqual(Playlist.select().count(), count)
def test_delete_playlist(self): def test_delete_playlist(self):
# check params # check params
self._make_request('deletePlaylist', error = 10) self._make_request('deletePlaylist', error = 10)
@ -181,27 +169,30 @@ class PlaylistTestCase(ApiTestBase):
self._make_request('deletePlaylist', { 'id': str(uuid.uuid4()) }, error = 70) self._make_request('deletePlaylist', { 'id': str(uuid.uuid4()) }, error = 70)
# delete unowned when not admin # delete unowned when not admin
playlist = self.store.find(Playlist, Playlist.user_id == User.id, User.name == 'alice')[0] with db_session:
playlist = Playlist.select(lambda p: p.user.name == 'alice').first()
self._make_request('deletePlaylist', { 'u': 'bob', 'p': 'B0b', 'id': str(playlist.id) }, error = 50) self._make_request('deletePlaylist', { 'u': 'bob', 'p': 'B0b', 'id': str(playlist.id) }, error = 50)
self.assertEqual(self.store.find(Playlist).count(), 3) self.assertPlaylistCountEqual(3);
# delete owned # delete owned
self._make_request('deletePlaylist', { 'id': str(playlist.id) }, skip_post = True) self._make_request('deletePlaylist', { 'id': str(playlist.id) }, skip_post = True)
self.assertEqual(self.store.find(Playlist).count(), 2) self.assertPlaylistCountEqual(2);
self._make_request('deletePlaylist', { 'id': str(playlist.id) }, error = 70) self._make_request('deletePlaylist', { 'id': str(playlist.id) }, error = 70)
self.assertEqual(self.store.find(Playlist).count(), 2) self.assertPlaylistCountEqual(2);
# delete unowned when admin # delete unowned when admin
playlist = self.store.find(Playlist, Playlist.user_id == User.id, User.name == 'bob').one() with db_session:
playlist = Playlist.get(lambda p: p.user.name == 'bob')
self._make_request('deletePlaylist', { 'id': str(playlist.id) }, skip_post = True) self._make_request('deletePlaylist', { 'id': str(playlist.id) }, skip_post = True)
self.assertEqual(self.store.find(Playlist).count(), 1) self.assertPlaylistCountEqual(1);
def test_update_playlist(self): def test_update_playlist(self):
self._make_request('updatePlaylist', error = 10) self._make_request('updatePlaylist', error = 10)
self._make_request('updatePlaylist', { 'playlistId': 1234 }, error = 0) self._make_request('updatePlaylist', { 'playlistId': 1234 }, error = 0)
self._make_request('updatePlaylist', { 'playlistId': str(uuid.uuid4()) }, error = 70) self._make_request('updatePlaylist', { 'playlistId': str(uuid.uuid4()) }, error = 70)
playlist = self.store.find(Playlist, Playlist.user_id == User.id, User.name == 'alice')[0] with db_session:
playlist = Playlist.select(lambda p: p.user.name == 'alice').order_by(Playlist.created).first()
pid = str(playlist.id) pid = str(playlist.id)
self._make_request('updatePlaylist', { 'playlistId': pid, 'songIdToAdd': 'string' }, error = 0) self._make_request('updatePlaylist', { 'playlistId': pid, 'songIdToAdd': 'string' }, error = 0)
self._make_request('updatePlaylist', { 'playlistId': pid, 'songIndexToRemove': 'string' }, error = 0) self._make_request('updatePlaylist', { 'playlistId': pid, 'songIndexToRemove': 'string' }, error = 0)
@ -226,7 +217,8 @@ class PlaylistTestCase(ApiTestBase):
self.assertEqual(self._xpath(child, 'count(./entry)'), 1) self.assertEqual(self._xpath(child, 'count(./entry)'), 1)
self.assertEqual(self._find(child, './entry').get('title'), 'Three') self.assertEqual(self._find(child, './entry').get('title'), 'Three')
songs = { s.title: str(s.id) for s in self.store.find(Track) } with db_session:
songs = { s.title: str(s.id) for s in Track.select() }
self._make_request('updatePlaylist', { 'playlistId': pid, 'songIdToAdd': [ songs['One'], songs['Two'], songs['Two'] ] }, skip_post = True) self._make_request('updatePlaylist', { 'playlistId': pid, 'songIdToAdd': [ songs['One'], songs['Two'], songs['Two'] ] }, skip_post = True)
rv, child = self._make_request('getPlaylist', { 'id': pid }, tag = 'playlist') rv, child = self._make_request('getPlaylist', { 'id': pid }, tag = 'playlist')

View File

@ -12,6 +12,8 @@
import time import time
import unittest import unittest
from pony.orm import db_session, commit
from supysonic.db import Folder, Artist, Album, Track from supysonic.db import Folder, Artist, Album, Track
from .apitestbase import ApiTestBase from .apitestbase import ApiTestBase
@ -20,53 +22,44 @@ class SearchTestCase(ApiTestBase):
def setUp(self): def setUp(self):
super(SearchTestCase, self).setUp() super(SearchTestCase, self).setUp()
root = Folder() with db_session:
root.root = True root = Folder(root = True, name = 'Root folder', path = 'tests/assets')
root.name = 'Root folder'
root.path = 'tests/assets'
self.store.add(root)
for letter in 'ABC': for letter in 'ABC':
folder = Folder() folder = Folder(name = letter + 'rtist', path = 'tests/assets/{}rtist'.format(letter), parent = root)
folder.name = letter + 'rtist' artist = Artist(name = letter + 'rtist')
folder.path = 'tests/assets/{}rtist'.format(letter)
folder.parent = root
artist = Artist()
artist.name = letter + 'rtist'
for lether in 'AB': for lether in 'AB':
afolder = Folder() afolder = Folder(
afolder.name = letter + lether + 'lbum' name = letter + lether + 'lbum',
afolder.path = 'tests/assets/{0}rtist/{0}{1}lbum'.format(letter, lether) path = 'tests/assets/{0}rtist/{0}{1}lbum'.format(letter, lether),
afolder.parent = folder parent = folder
)
album = Album() album = Album(name = letter + lether + 'lbum', artist = artist)
album.name = letter + lether + 'lbum'
album.artist = artist
for num, song in enumerate([ 'One', 'Two', 'Three' ]): for num, song in enumerate([ 'One', 'Two', 'Three' ]):
track = Track() track = Track(
track.disc = 1 disc = 1,
track.number = num number = num,
track.title = song title = song,
track.duration = 2 duration = 2,
track.album = album album = album,
track.artist = artist artist = artist,
track.bitrate = 320 bitrate = 320,
track.path = 'tests/assets/{0}rtist/{0}{1}lbum/{2}'.format(letter, lether, song) path = 'tests/assets/{0}rtist/{0}{1}lbum/{2}'.format(letter, lether, song),
track.content_type = 'audio/mpeg' content_type = 'audio/mpeg',
track.last_modification = 0 last_modification = 0,
track.root_folder = root root_folder = root,
track.folder = afolder folder = afolder
self.store.add(track) )
self.store.commit() commit()
self.assertEqual(self.store.find(Folder).count(), 10) self.assertEqual(Folder.select().count(), 10)
self.assertEqual(self.store.find(Artist).count(), 3) self.assertEqual(Artist.select().count(), 3)
self.assertEqual(self.store.find(Album).count(), 6) self.assertEqual(Album.select().count(), 6)
self.assertEqual(self.store.find(Track).count(), 18) self.assertEqual(Track.select().count(), 18)
def __track_as_pseudo_unique_str(self, elem): def __track_as_pseudo_unique_str(self, elem):
return elem.get('artist') + elem.get('album') + elem.get('title') return elem.get('artist') + elem.get('album') + elem.get('title')

View File

@ -11,6 +11,8 @@
import unittest import unittest
from pony.orm import db_session
from supysonic.db import Folder, Track from supysonic.db import Folder, Track
from supysonic.managers.folder import FolderManager from supysonic.managers.folder import FolderManager
from supysonic.scanner import Scanner from supysonic.scanner import Scanner
@ -23,12 +25,13 @@ class TranscodingTestCase(ApiTestBase):
super(TranscodingTestCase, self).setUp() super(TranscodingTestCase, self).setUp()
FolderManager.add(self.store, 'Folder', 'tests/assets/folder') FolderManager.add('Folder', 'tests/assets/folder')
scanner = Scanner(self.store) scanner = Scanner()
scanner.scan(self.store.find(Folder).one()) with db_session:
scanner.scan(Folder.get())
scanner.finish() scanner.finish()
self.trackid = self.store.find(Track).one().id self.trackid = Track.get().id
def _stream(self, **kwargs): def _stream(self, **kwargs):
kwargs.update({ 'u': 'alice', 'p': 'Alic3', 'c': 'tests', 'v': '1.8.0', 'id': self.trackid }) kwargs.update({ 'u': 'alice', 'p': 'Alic3', 'c': 'tests', 'v': '1.8.0', 'id': self.trackid })

View File

@ -15,27 +15,22 @@ import tempfile
import unittest import unittest
from contextlib import contextmanager from contextlib import contextmanager
from pony.orm import db_session
from StringIO import StringIO from StringIO import StringIO
from supysonic.db import Folder, User, get_store from supysonic.db import Folder, User, init_database, release_database
from supysonic.cli import SupysonicCLI from supysonic.cli import SupysonicCLI
from ..testbase import TestConfig from ..testbase import TestConfig
class CLITestCase(unittest.TestCase): class CLITestCase(unittest.TestCase):
""" Really basic tests. Some even don't check anything but are juste there for coverage """ """ Really basic tests. Some even don't check anything but are just there for coverage """
def setUp(self): def setUp(self):
conf = TestConfig(False, False) conf = TestConfig(False, False)
self.__dbfile = tempfile.mkstemp()[1] self.__dbfile = tempfile.mkstemp()[1]
conf.BASE['database_uri'] = 'sqlite:///' + self.__dbfile conf.BASE['database_uri'] = 'sqlite:///' + self.__dbfile
self.__store = get_store(conf.BASE['database_uri']) init_database(conf.BASE['database_uri'], True)
with io.open('schema/sqlite.sql', 'r') as sql:
schema = sql.read()
for statement in schema.split(';'):
self.__store.execute(statement)
self.__store.commit()
self.__stdout = StringIO() self.__stdout = StringIO()
self.__stderr = StringIO() self.__stderr = StringIO()
@ -44,7 +39,7 @@ class CLITestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.__stdout.close() self.__stdout.close()
self.__stderr.close() self.__stderr.close()
self.__store.close() release_database()
os.unlink(self.__dbfile) os.unlink(self.__dbfile)
@contextmanager @contextmanager
@ -59,7 +54,8 @@ class CLITestCase(unittest.TestCase):
with self._tempdir() as d: with self._tempdir() as d:
self.__cli.onecmd('folder add tmpfolder ' + d) self.__cli.onecmd('folder add tmpfolder ' + d)
f = self.__store.find(Folder).one() with db_session:
f = Folder.select().first()
self.assertIsNotNone(f) self.assertIsNotNone(f)
self.assertEqual(f.path, d) self.assertEqual(f.path, d)
@ -71,14 +67,17 @@ class CLITestCase(unittest.TestCase):
self.__cli.onecmd('folder add f1 ' + d) self.__cli.onecmd('folder add f1 ' + d)
self.__cli.onecmd('folder add f3 /invalid/path') self.__cli.onecmd('folder add f3 /invalid/path')
self.assertEqual(self.__store.find(Folder).count(), 1) with db_session:
self.assertEqual(Folder.select().count(), 1)
def test_folder_delete(self): def test_folder_delete(self):
with self._tempdir() as d: with self._tempdir() as d:
self.__cli.onecmd('folder add tmpfolder ' + d) self.__cli.onecmd('folder add tmpfolder ' + d)
self.__cli.onecmd('folder delete randomfolder') self.__cli.onecmd('folder delete randomfolder')
self.__cli.onecmd('folder delete tmpfolder') self.__cli.onecmd('folder delete tmpfolder')
self.assertEqual(self.__store.find(Folder).count(), 0)
with db_session:
self.assertEqual(Folder.select().count(), 0)
def test_folder_list(self): def test_folder_list(self):
with self._tempdir() as d: with self._tempdir() as d:
@ -97,13 +96,17 @@ class CLITestCase(unittest.TestCase):
def test_user_add(self): def test_user_add(self):
self.__cli.onecmd('user add -p Alic3 alice') self.__cli.onecmd('user add -p Alic3 alice')
self.__cli.onecmd('user add -p alice alice') self.__cli.onecmd('user add -p alice alice')
self.assertEqual(self.__store.find(User).count(), 1)
with db_session:
self.assertEqual(User.select().count(), 1)
def test_user_delete(self): def test_user_delete(self):
self.__cli.onecmd('user add -p Alic3 alice') self.__cli.onecmd('user add -p Alic3 alice')
self.__cli.onecmd('user delete alice') self.__cli.onecmd('user delete alice')
self.__cli.onecmd('user delete bob') self.__cli.onecmd('user delete bob')
self.assertEqual(self.__store.find(User).count(), 0)
with db_session:
self.assertEqual(User.select().count(), 0)
def test_user_list(self): def test_user_list(self):
self.__cli.onecmd('user add -p Alic3 alice') self.__cli.onecmd('user add -p Alic3 alice')
@ -114,7 +117,8 @@ class CLITestCase(unittest.TestCase):
self.__cli.onecmd('user add -p Alic3 alice') self.__cli.onecmd('user add -p Alic3 alice')
self.__cli.onecmd('user setadmin alice') self.__cli.onecmd('user setadmin alice')
self.__cli.onecmd('user setadmin bob') self.__cli.onecmd('user setadmin bob')
self.assertTrue(self.__store.find(User, User.name == 'alice').one().admin) with db_session:
self.assertTrue(User.get(name = 'alice').admin)
def test_user_changepass(self): def test_user_changepass(self):
self.__cli.onecmd('user add -p Alic3 alice') self.__cli.onecmd('user add -p Alic3 alice')

View File

@ -9,41 +9,38 @@
# #
# Distributed under terms of the GNU AGPLv3 license. # Distributed under terms of the GNU AGPLv3 license.
import re
import unittest import unittest
import io, re
from collections import namedtuple
import uuid import uuid
from collections import namedtuple
from pony.orm import db_session
from supysonic import db from supysonic import db
date_regex = re.compile(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$') date_regex = re.compile(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$')
class DbTestCase(unittest.TestCase): class DbTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = db.get_store(u'sqlite:') db.init_database('sqlite:', True)
with io.open(u'schema/sqlite.sql', u'r') as f:
for statement in f.read().split(u';'):
self.store.execute(statement)
def tearDown(self): def tearDown(self):
self.store.close() db.release_database()
def create_some_folders(self): def create_some_folders(self):
root_folder = db.Folder() root_folder = db.Folder(
root_folder.root = True root = True,
root_folder.name = u'Root folder' name = 'Root folder',
root_folder.path = u'tests' path = 'tests'
)
child_folder = db.Folder() child_folder = db.Folder(
child_folder.root = False root = False,
child_folder.name = u'Child folder' name = 'Child folder',
child_folder.path = u'tests/assets' path = 'tests/assets',
child_folder.has_cover_art = True has_cover_art = True,
child_folder.parent = root_folder parent = root_folder
)
self.store.add(root_folder)
self.store.add(child_folder)
self.store.commit()
return root_folder, child_folder return root_folder, child_folder
@ -51,175 +48,152 @@ class DbTestCase(unittest.TestCase):
root, child = self.create_some_folders() root, child = self.create_some_folders()
if not artist: if not artist:
artist = db.Artist() artist = db.Artist(name = 'Test artist')
artist.name = u'Test Artist'
if not album: if not album:
album = db.Album() album = db.Album(artist = artist, name = 'Test Album')
album.artist = artist
album.name = u'Test Album'
track1 = db.Track() track1 = db.Track(
track1.title = u'Track Title' title = 'Track Title',
track1.album = album album = album,
track1.artist = artist artist = artist,
track1.disc = 1 disc = 1,
track1.number = 1 number = 1,
track1.duration = 3 duration = 3,
track1.bitrate = 320 bitrate = 320,
track1.path = u'tests/assets/empty' path = 'tests/assets/empty',
track1.content_type = u'audio/mpeg' content_type = 'audio/mpeg',
track1.last_modification = 1234 last_modification = 1234,
track1.root_folder = root root_folder = root,
track1.folder = child folder = child
self.store.add(track1) )
track2 = db.Track() track2 = db.Track(
track2.title = u'One Awesome Song' title = 'One Awesome Song',
track2.album = album album = album,
track2.artist = artist artist = artist,
track2.disc = 1 disc = 1,
track2.number = 2 number = 2,
track2.duration = 5 duration = 5,
track2.bitrate = 96 bitrate = 96,
track2.path = u'tests/assets/empty' path = 'tests/assets/23bytes',
track2.content_type = u'audio/mpeg' content_type = 'audio/mpeg',
track2.last_modification = 1234 last_modification = 1234,
track2.root_folder = root root_folder = root,
track2.folder = child folder = child
self.store.add(track2) )
return track1, track2 return track1, track2
def create_playlist(self): def create_user(self, name = 'Test User'):
user = db.User() return db.User(
user.name = u'Test User' name = name,
user.password = u'secret' password = 'secret',
user.salt = u'ABC+' salt = 'ABC+',
)
playlist = db.Playlist() def create_playlist(self):
playlist.user = user
playlist.name = u'Playlist!' playlist = db.Playlist(
self.store.add(playlist) user = self.create_user(),
name = 'Playlist!'
)
return playlist return playlist
@db_session
def test_folder_base(self): def test_folder_base(self):
root_folder, child_folder = self.create_some_folders() root_folder, child_folder = self.create_some_folders()
MockUser = namedtuple(u'User', [ u'id' ]) MockUser = namedtuple('User', [ 'id' ])
user = MockUser(uuid.uuid4()) user = MockUser(uuid.uuid4())
root = root_folder.as_subsonic_child(user) root = root_folder.as_subsonic_child(user)
self.assertIsInstance(root, dict) self.assertIsInstance(root, dict)
self.assertIn(u'id', root) self.assertIn('id', root)
self.assertIn(u'isDir', root) self.assertIn('isDir', root)
self.assertIn(u'title', root) self.assertIn('title', root)
self.assertIn(u'album', root) self.assertIn('album', root)
self.assertIn(u'created', root) self.assertIn('created', root)
self.assertTrue(root[u'isDir']) self.assertTrue(root['isDir'])
self.assertEqual(root[u'title'], u'Root folder') self.assertEqual(root['title'], 'Root folder')
self.assertEqual(root[u'album'], u'Root folder') self.assertEqual(root['album'], 'Root folder')
self.assertRegexpMatches(root['created'], date_regex) self.assertRegexpMatches(root['created'], date_regex)
child = child_folder.as_subsonic_child(user) child = child_folder.as_subsonic_child(user)
self.assertIn(u'parent', child) self.assertIn('parent', child)
self.assertIn(u'artist', child) self.assertIn('artist', child)
self.assertIn(u'coverArt', child) self.assertIn('coverArt', child)
self.assertEqual(child[u'parent'], str(root_folder.id)) self.assertEqual(child['parent'], str(root_folder.id))
self.assertEqual(child[u'artist'], root_folder.name) self.assertEqual(child['artist'], root_folder.name)
self.assertEqual(child[u'coverArt'], child[u'id']) self.assertEqual(child['coverArt'], child['id'])
@db_session
def test_folder_annotation(self): def test_folder_annotation(self):
root_folder, child_folder = self.create_some_folders() root_folder, child_folder = self.create_some_folders()
# Assuming SQLite doesn't enforce foreign key constraints user = self.create_user()
MockUser = namedtuple(u'User', [ u'id' ]) star = db.StarredFolder(
user = MockUser(uuid.uuid4()) user = user,
starred = root_folder
star = db.StarredFolder() )
star.user_id = user.id rating_user = db.RatingFolder(
star.starred_id = root_folder.id user = user,
rated = root_folder,
rating_user = db.RatingFolder() rating = 2
rating_user.user_id = user.id )
rating_user.rated_id = root_folder.id other = self.create_user('Other')
rating_user.rating = 2 rating_other = db.RatingFolder(
user = other,
rating_other = db.RatingFolder() rated = root_folder,
rating_other.user_id = uuid.uuid4() rating = 5
rating_other.rated_id = root_folder.id )
rating_other.rating = 5
self.store.add(star)
self.store.add(rating_user)
self.store.add(rating_other)
root = root_folder.as_subsonic_child(user) root = root_folder.as_subsonic_child(user)
self.assertIn(u'starred', root) self.assertIn('starred', root)
self.assertIn(u'userRating', root) self.assertIn('userRating', root)
self.assertIn(u'averageRating', root) self.assertIn('averageRating', root)
self.assertRegexpMatches(root[u'starred'], date_regex) self.assertRegexpMatches(root['starred'], date_regex)
self.assertEqual(root[u'userRating'], 2) self.assertEqual(root['userRating'], 2)
self.assertEqual(root[u'averageRating'], 3.5) self.assertEqual(root['averageRating'], 3.5)
child = child_folder.as_subsonic_child(user) child = child_folder.as_subsonic_child(user)
self.assertNotIn(u'starred', child) self.assertNotIn('starred', child)
self.assertNotIn(u'userRating', child) self.assertNotIn('userRating', child)
@db_session
def test_artist(self): def test_artist(self):
artist = db.Artist() artist = db.Artist(name = 'Test Artist')
artist.name = u'Test Artist'
self.store.add(artist)
# Assuming SQLite doesn't enforce foreign key constraints user = self.create_user()
MockUser = namedtuple(u'User', [ u'id' ]) star = db.StarredArtist(user = user, starred = artist)
user = MockUser(uuid.uuid4())
star = db.StarredArtist()
star.user_id = user.id
star.starred_id = artist.id
self.store.add(star)
artist_dict = artist.as_subsonic_artist(user) artist_dict = artist.as_subsonic_artist(user)
self.assertIsInstance(artist_dict, dict) self.assertIsInstance(artist_dict, dict)
self.assertIn(u'id', artist_dict) self.assertIn('id', artist_dict)
self.assertIn(u'name', artist_dict) self.assertIn('name', artist_dict)
self.assertIn(u'albumCount', artist_dict) self.assertIn('albumCount', artist_dict)
self.assertIn(u'starred', artist_dict) self.assertIn('starred', artist_dict)
self.assertEqual(artist_dict[u'name'], u'Test Artist') self.assertEqual(artist_dict['name'], 'Test Artist')
self.assertEqual(artist_dict[u'albumCount'], 0) self.assertEqual(artist_dict['albumCount'], 0)
self.assertRegexpMatches(artist_dict[u'starred'], date_regex) self.assertRegexpMatches(artist_dict['starred'], date_regex)
album = db.Album() db.Album(name = 'Test Artist', artist = artist) # self-titled
album.name = u'Test Artist' # self-titled db.Album(name = 'The Album After The First One', artist = artist)
artist.albums.add(album)
album = db.Album()
album.name = u'The Album After The Frist One'
artist.albums.add(album)
artist_dict = artist.as_subsonic_artist(user) artist_dict = artist.as_subsonic_artist(user)
self.assertEqual(artist_dict[u'albumCount'], 2) self.assertEqual(artist_dict['albumCount'], 2)
@db_session
def test_album(self): def test_album(self):
artist = db.Artist() artist = db.Artist(name = 'Test Artist')
artist.name = u'Test Artist' album = db.Album(artist = artist, name = 'Test Album')
album = db.Album() user = self.create_user()
album.artist = artist star = db.StarredAlbum(
album.name = u'Test Album' user = user,
starred = album
# Assuming SQLite doesn't enforce foreign key constraints )
MockUser = namedtuple(u'User', [ u'id' ])
user = MockUser(uuid.uuid4())
star = db.StarredAlbum()
star.user_id = user.id
star.starred = album
self.store.add(album)
self.store.add(star)
# No tracks, shouldn't be stored under normal circumstances # No tracks, shouldn't be stored under normal circumstances
self.assertRaises(ValueError, album.as_subsonic_album, user) self.assertRaises(ValueError, album.as_subsonic_album, user)
@ -228,67 +202,67 @@ class DbTestCase(unittest.TestCase):
album_dict = album.as_subsonic_album(user) album_dict = album.as_subsonic_album(user)
self.assertIsInstance(album_dict, dict) self.assertIsInstance(album_dict, dict)
self.assertIn(u'id', album_dict) self.assertIn('id', album_dict)
self.assertIn(u'name', album_dict) self.assertIn('name', album_dict)
self.assertIn(u'artist', album_dict) self.assertIn('artist', album_dict)
self.assertIn(u'artistId', album_dict) self.assertIn('artistId', album_dict)
self.assertIn(u'songCount', album_dict) self.assertIn('songCount', album_dict)
self.assertIn(u'duration', album_dict) self.assertIn('duration', album_dict)
self.assertIn(u'created', album_dict) self.assertIn('created', album_dict)
self.assertIn(u'starred', album_dict) self.assertIn('starred', album_dict)
self.assertEqual(album_dict[u'name'], album.name) self.assertEqual(album_dict['name'], album.name)
self.assertEqual(album_dict[u'artist'], artist.name) self.assertEqual(album_dict['artist'], artist.name)
self.assertEqual(album_dict[u'artistId'], str(artist.id)) self.assertEqual(album_dict['artistId'], str(artist.id))
self.assertEqual(album_dict[u'songCount'], 2) self.assertEqual(album_dict['songCount'], 2)
self.assertEqual(album_dict[u'duration'], 8) self.assertEqual(album_dict['duration'], 8)
self.assertRegexpMatches(album_dict[u'created'], date_regex) self.assertRegexpMatches(album_dict['created'], date_regex)
self.assertRegexpMatches(album_dict[u'starred'], date_regex) self.assertRegexpMatches(album_dict['starred'], date_regex)
@db_session
def test_track(self): def test_track(self):
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
# Assuming SQLite doesn't enforce foreign key constraints # Assuming SQLite doesn't enforce foreign key constraints
MockUser = namedtuple(u'User', [ u'id' ]) MockUser = namedtuple('User', [ 'id' ])
user = MockUser(uuid.uuid4()) user = MockUser(uuid.uuid4())
track1_dict = track1.as_subsonic_child(user, None) track1_dict = track1.as_subsonic_child(user, None)
self.assertIsInstance(track1_dict, dict) self.assertIsInstance(track1_dict, dict)
self.assertIn(u'id', track1_dict) self.assertIn('id', track1_dict)
self.assertIn(u'parent', track1_dict) self.assertIn('parent', track1_dict)
self.assertIn(u'isDir', track1_dict) self.assertIn('isDir', track1_dict)
self.assertIn(u'title', track1_dict) self.assertIn('title', track1_dict)
self.assertFalse(track1_dict[u'isDir']) self.assertFalse(track1_dict['isDir'])
# ... we'll test the rest against the API XSD. # ... we'll test the rest against the API XSD.
@db_session
def test_user(self): def test_user(self):
user = db.User() user = self.create_user()
user.name = u'Test User'
user.password = u'secret'
user.salt = u'ABC+'
user_dict = user.as_subsonic_user() user_dict = user.as_subsonic_user()
self.assertIsInstance(user_dict, dict) self.assertIsInstance(user_dict, dict)
@db_session
def test_chat(self): def test_chat(self):
user = db.User() user = self.create_user()
user.name = u'Test User'
user.password = u'secret'
user.salt = u'ABC+'
line = db.ChatMessage() line = db.ChatMessage(
line.user = user user = user,
line.message = u'Hello world!' message = 'Hello world!'
)
line_dict = line.responsize() line_dict = line.responsize()
self.assertIsInstance(line_dict, dict) self.assertIsInstance(line_dict, dict)
self.assertIn(u'username', line_dict) self.assertIn('username', line_dict)
self.assertEqual(line_dict[u'username'], user.name) self.assertEqual(line_dict['username'], user.name)
@db_session
def test_playlist(self): def test_playlist(self):
playlist = self.create_playlist() playlist = self.create_playlist()
playlist_dict = playlist.as_subsonic_playlist(playlist.user) playlist_dict = playlist.as_subsonic_playlist(playlist.user)
self.assertIsInstance(playlist_dict, dict) self.assertIsInstance(playlist_dict, dict)
@db_session
def test_playlist_tracks(self): def test_playlist_tracks(self):
playlist = self.create_playlist() playlist = self.create_playlist()
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
@ -307,9 +281,10 @@ class DbTestCase(unittest.TestCase):
playlist.add(str(track1.id)) playlist.add(str(track1.id))
self.assertSequenceEqual(playlist.get_tracks(), [ track1 ]) self.assertSequenceEqual(playlist.get_tracks(), [ track1 ])
self.assertRaises(ValueError, playlist.add, u'some string') self.assertRaises(ValueError, playlist.add, 'some string')
self.assertRaises(NameError, playlist.add, 2345) self.assertRaises(NameError, playlist.add, 2345)
@db_session
def test_playlist_remove_tracks(self): def test_playlist_remove_tracks(self):
playlist = self.create_playlist() playlist = self.create_playlist()
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
@ -329,6 +304,7 @@ class DbTestCase(unittest.TestCase):
playlist.remove_at_indexes([ 1, 1 ]) playlist.remove_at_indexes([ 1, 1 ])
self.assertSequenceEqual(playlist.get_tracks(), [ track2, track1 ]) self.assertSequenceEqual(playlist.get_tracks(), [ track2, track1 ])
@db_session
def test_playlist_fixing(self): def test_playlist_fixing(self):
playlist = self.create_playlist() playlist = self.create_playlist()
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
@ -338,10 +314,10 @@ class DbTestCase(unittest.TestCase):
playlist.add(track2) playlist.add(track2)
self.assertSequenceEqual(playlist.get_tracks(), [ track1, track2 ]) self.assertSequenceEqual(playlist.get_tracks(), [ track1, track2 ])
self.store.remove(track2) track2.delete()
self.assertSequenceEqual(playlist.get_tracks(), [ track1 ]) self.assertSequenceEqual(playlist.get_tracks(), [ track1 ])
playlist.tracks = u'{0},{0},some random garbage,{0}'.format(track1.id) playlist.tracks = '{0},{0},some random garbage,{0}'.format(track1.id)
self.assertSequenceEqual(playlist.get_tracks(), [ track1, track1, track1 ]) self.assertSequenceEqual(playlist.get_tracks(), [ track1, track1, track1 ])
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -16,6 +16,7 @@ import tempfile
import unittest import unittest
from contextlib import contextmanager from contextlib import contextmanager
from pony.orm import db_session, commit
from supysonic import db from supysonic import db
from supysonic.managers.folder import FolderManager from supysonic.managers.folder import FolderManager
@ -23,133 +24,158 @@ from supysonic.scanner import Scanner
class ScannerTestCase(unittest.TestCase): class ScannerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = db.get_store('sqlite:') db.init_database('sqlite:', True)
with io.open('schema/sqlite.sql', 'r') as f:
for statement in f.read().split(';'):
self.store.execute(statement)
FolderManager.add(self.store, 'folder', os.path.abspath('tests/assets')) FolderManager.add('folder', os.path.abspath('tests/assets'))
self.folder = self.store.find(db.Folder).one() with db_session:
self.assertIsNotNone(self.folder) folder = db.Folder.select().first()
self.assertIsNotNone(folder)
self.folderid = folder.id
self.scanner = Scanner(self.store) self.scanner = Scanner()
self.scanner.scan(self.folder) self.scanner.scan(folder)
def tearDown(self): def tearDown(self):
self.scanner.finish() self.scanner.finish()
self.store.close() db.release_database()
@contextmanager @contextmanager
def __temporary_track_copy(self): def __temporary_track_copy(self):
track = self.store.find(db.Track).one() track = db.Track.select().first()
with tempfile.NamedTemporaryFile(dir = os.path.dirname(track.path)) as tf: with tempfile.NamedTemporaryFile(dir = os.path.dirname(track.path)) as tf:
with io.open(track.path, 'rb') as f: with io.open(track.path, 'rb') as f:
tf.write(f.read()) tf.write(f.read())
yield tf yield tf
@db_session
def test_scan(self): def test_scan(self):
self.assertEqual(self.store.find(db.Track).count(), 1) self.assertEqual(db.Track.select().count(), 1)
self.assertRaises(TypeError, self.scanner.scan, None) self.assertRaises(TypeError, self.scanner.scan, None)
self.assertRaises(TypeError, self.scanner.scan, 'string') self.assertRaises(TypeError, self.scanner.scan, 'string')
@db_session
def test_progress(self): def test_progress(self):
def progress(processed, total): def progress(processed, total):
self.assertIsInstance(processed, int) self.assertIsInstance(processed, int)
self.assertIsInstance(total, int) self.assertIsInstance(total, int)
self.assertLessEqual(processed, total) self.assertLessEqual(processed, total)
self.scanner.scan(self.folder, progress) self.scanner.scan(db.Folder[self.folderid], progress)
@db_session
def test_rescan(self): def test_rescan(self):
self.scanner.scan(self.folder) self.scanner.scan(db.Folder[self.folderid])
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
@db_session
def test_force_rescan(self): def test_force_rescan(self):
self.scanner = Scanner(self.store, True) self.scanner = Scanner(True)
self.scanner.scan(self.folder) self.scanner.scan(db.Folder[self.folderid])
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
@db_session
def test_scan_file(self): def test_scan_file(self):
track = self.store.find(db.Track).one() track = db.Track.select().first()
self.assertRaises(TypeError, self.scanner.scan_file, None) self.assertRaises(TypeError, self.scanner.scan_file, None)
self.assertRaises(TypeError, self.scanner.scan_file, track) self.assertRaises(TypeError, self.scanner.scan_file, track)
self.scanner.scan_file('/some/inexistent/path') self.scanner.scan_file('/some/inexistent/path')
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
@db_session
def test_remove_file(self): def test_remove_file(self):
track = self.store.find(db.Track).one() track = db.Track.select().first()
self.assertRaises(TypeError, self.scanner.remove_file, None) self.assertRaises(TypeError, self.scanner.remove_file, None)
self.assertRaises(TypeError, self.scanner.remove_file, track) self.assertRaises(TypeError, self.scanner.remove_file, track)
self.scanner.remove_file('/some/inexistent/path') self.scanner.remove_file('/some/inexistent/path')
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
self.scanner.remove_file(track.path) self.scanner.remove_file(track.path)
self.scanner.finish() self.scanner.finish()
self.assertEqual(self.store.find(db.Track).count(), 0) commit()
self.assertEqual(self.store.find(db.Album).count(), 0) self.assertEqual(db.Track.select().count(), 0)
self.assertEqual(self.store.find(db.Artist).count(), 0) self.assertEqual(db.Album.select().count(), 0)
self.assertEqual(db.Artist.select().count(), 0)
@db_session
def test_move_file(self): def test_move_file(self):
track = self.store.find(db.Track).one() track = db.Track.select().first()
self.assertRaises(TypeError, self.scanner.move_file, None, 'string') self.assertRaises(TypeError, self.scanner.move_file, None, 'string')
self.assertRaises(TypeError, self.scanner.move_file, track, 'string') self.assertRaises(TypeError, self.scanner.move_file, track, 'string')
self.assertRaises(TypeError, self.scanner.move_file, 'string', None) self.assertRaises(TypeError, self.scanner.move_file, 'string', None)
self.assertRaises(TypeError, self.scanner.move_file, 'string', track) self.assertRaises(TypeError, self.scanner.move_file, 'string', track)
self.scanner.move_file('/some/inexistent/path', track.path) self.scanner.move_file('/some/inexistent/path', track.path)
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
self.scanner.move_file(track.path, track.path) self.scanner.move_file(track.path, track.path)
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
self.assertRaises(Exception, self.scanner.move_file, track.path, '/some/inexistent/path') self.assertRaises(Exception, self.scanner.move_file, track.path, '/some/inexistent/path')
with self.__temporary_track_copy() as tf: with self.__temporary_track_copy() as tf:
self.scanner.scan(self.folder) self.scanner.scan(db.Folder[self.folderid])
self.assertEqual(self.store.find(db.Track).count(), 2) commit()
self.assertEqual(db.Track.select().count(), 2)
self.scanner.move_file(tf.name, track.path) self.scanner.move_file(tf.name, track.path)
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
track = self.store.find(db.Track).one() track = db.Track.select().first()
new_path = os.path.abspath(os.path.join(os.path.dirname(track.path), '..', 'silence.mp3')) new_path = os.path.abspath(os.path.join(os.path.dirname(track.path), '..', 'silence.mp3'))
self.scanner.move_file(track.path, new_path) self.scanner.move_file(track.path, new_path)
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
self.assertEqual(track.path, new_path) self.assertEqual(track.path, new_path)
@db_session
def test_rescan_corrupt_file(self): def test_rescan_corrupt_file(self):
track = self.store.find(db.Track).one() track = db.Track.select().first()
self.scanner = Scanner(self.store, True) self.scanner = Scanner(True)
with self.__temporary_track_copy() as tf: with self.__temporary_track_copy() as tf:
self.scanner.scan(self.folder) self.scanner.scan(db.Folder[self.folderid])
self.assertEqual(self.store.find(db.Track).count(), 2) commit()
self.assertEqual(db.Track.select().count(), 2)
tf.seek(0, 0) tf.seek(0, 0)
tf.write('\x00' * 4096) tf.write('\x00' * 4096)
tf.truncate() tf.truncate()
self.scanner.scan(self.folder) self.scanner.scan(db.Folder[self.folderid])
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
@db_session
def test_rescan_removed_file(self): def test_rescan_removed_file(self):
track = self.store.find(db.Track).one() track = db.Track.select().first()
with self.__temporary_track_copy() as tf: with self.__temporary_track_copy() as tf:
self.scanner.scan(self.folder) self.scanner.scan(db.Folder[self.folderid])
self.assertEqual(self.store.find(db.Track).count(), 2) commit()
self.assertEqual(db.Track.select().count(), 2)
self.scanner.scan(self.folder) self.scanner.scan(db.Folder[self.folderid])
self.assertEqual(self.store.find(db.Track).count(), 1) commit()
self.assertEqual(db.Track.select().count(), 1)
@db_session
def test_scan_tag_change(self): def test_scan_tag_change(self):
self.scanner = Scanner(self.store, True) self.scanner = Scanner(True)
folder = db.Folder[self.folderid]
with self.__temporary_track_copy() as tf: with self.__temporary_track_copy() as tf:
self.scanner.scan(self.folder) self.scanner.scan(folder)
copy = self.store.find(db.Track, db.Track.path == tf.name).one() commit()
copy = db.Track.get(path = tf.name)
self.assertEqual(copy.artist.name, 'Some artist') self.assertEqual(copy.artist.name, 'Some artist')
self.assertEqual(copy.album.name, 'Awesome album') self.assertEqual(copy.album.name, 'Awesome album')
@ -158,12 +184,12 @@ class ScannerTestCase(unittest.TestCase):
tags['album'] = 'Crappy album' tags['album'] = 'Crappy album'
tags.save() tags.save()
self.scanner.scan(self.folder) self.scanner.scan(folder)
self.scanner.finish() self.scanner.finish()
self.assertEqual(copy.artist.name, 'Renamed artist') self.assertEqual(copy.artist.name, 'Renamed artist')
self.assertEqual(copy.album.name, 'Crappy album') self.assertEqual(copy.album.name, 'Crappy album')
self.assertIsNotNone(self.store.find(db.Artist, db.Artist.name == 'Some artist').one()) self.assertIsNotNone(db.Artist.get(name = 'Some artist'))
self.assertIsNotNone(self.store.find(db.Album, db.Album.name == 'Awesome album').one()) self.assertIsNotNone(db.Album.get(name = 'Awesome album'))
def test_stats(self): def test_stats(self):
self.assertEqual(self.scanner.stats(), ((1,1,1),(0,0,0))) self.assertEqual(self.scanner.stats(), ((1,1,1),(0,0,0)))

View File

@ -18,9 +18,10 @@ import time
import unittest import unittest
from contextlib import contextmanager from contextlib import contextmanager
from pony.orm import db_session
from threading import Thread from threading import Thread
from supysonic.db import get_store, Track, Artist from supysonic.db import init_database, release_database, Track, Artist
from supysonic.managers.folder import FolderManager from supysonic.managers.folder import FolderManager
from supysonic.watcher import SupysonicWatcher from supysonic.watcher import SupysonicWatcher
@ -38,29 +39,14 @@ class WatcherTestConfig(TestConfig):
self.BASE['database_uri'] = db_uri self.BASE['database_uri'] = db_uri
class WatcherTestBase(unittest.TestCase): class WatcherTestBase(unittest.TestCase):
@contextmanager
def _get_store(self):
store = None
try:
store = get_store('sqlite:///' + self.__dbfile)
yield store
store.commit()
store.close()
except:
store.rollback()
store.close()
raise
def setUp(self): def setUp(self):
self.__dbfile = tempfile.mkstemp()[1] self.__dbfile = tempfile.mkstemp()[1]
conf = WatcherTestConfig('sqlite:///' + self.__dbfile) dburi = 'sqlite:///' + self.__dbfile
self.__sleep_time = conf.DAEMON['wait_delay'] + 1 init_database(dburi, True)
release_database()
with self._get_store() as store: conf = WatcherTestConfig(dburi)
with io.open('schema/sqlite.sql', 'r') as sql: self.__sleep_time = conf.DAEMON['wait_delay'] + 1
schema = sql.read()
for statement in schema.split(';'):
store.execute(statement)
self.__watcher = SupysonicWatcher(conf) self.__watcher = SupysonicWatcher(conf)
self.__thread = Thread(target = self.__watcher.run) self.__thread = Thread(target = self.__watcher.run)
@ -82,6 +68,12 @@ class WatcherTestBase(unittest.TestCase):
def _sleep(self): def _sleep(self):
time.sleep(self.__sleep_time) time.sleep(self.__sleep_time)
@contextmanager
def _tempdbrebind(self):
init_database('sqlite:///' + self.__dbfile)
try: yield
finally: release_database()
class NothingToWatchTestCase(WatcherTestBase): class NothingToWatchTestCase(WatcherTestBase):
def test_spawn_useless_watcher(self): def test_spawn_useless_watcher(self):
self._start() self._start()
@ -93,8 +85,7 @@ class WatcherTestCase(WatcherTestBase):
def setUp(self): def setUp(self):
super(WatcherTestCase, self).setUp() super(WatcherTestCase, self).setUp()
self.__dir = tempfile.mkdtemp() self.__dir = tempfile.mkdtemp()
with self._get_store() as store: FolderManager.add('Folder', self.__dir)
FolderManager.add(store, 'Folder', self.__dir)
self._start() self._start()
def tearDown(self): def tearDown(self):
@ -115,9 +106,9 @@ class WatcherTestCase(WatcherTestBase):
shutil.copyfile('tests/assets/folder/silence.mp3', path) shutil.copyfile('tests/assets/folder/silence.mp3', path)
return path return path
@db_session
def assertTrackCountEqual(self, expected): def assertTrackCountEqual(self, expected):
with self._get_store() as store: self.assertEqual(Track.select().count(), expected)
self.assertEqual(store.find(Track).count(), expected)
def test_add(self): def test_add(self):
self._addfile() self._addfile()
@ -128,6 +119,7 @@ class WatcherTestCase(WatcherTestBase):
def test_add_nowait_stop(self): def test_add_nowait_stop(self):
self._addfile() self._addfile()
self._stop() self._stop()
with self._tempdbrebind():
self.assertTrackCountEqual(1) self.assertTrackCountEqual(1)
def test_add_multiple(self): def test_add_multiple(self):
@ -136,46 +128,46 @@ class WatcherTestCase(WatcherTestBase):
self._addfile() self._addfile()
self.assertTrackCountEqual(0) self.assertTrackCountEqual(0)
self._sleep() self._sleep()
with self._get_store() as store: with db_session:
self.assertEqual(store.find(Track).count(), 3) self.assertEqual(Track.select().count(), 3)
self.assertEqual(store.find(Artist).count(), 1) self.assertEqual(Artist.select().count(), 1)
def test_change(self): def test_change(self):
path = self._addfile() path = self._addfile()
self._sleep() self._sleep()
trackid = None trackid = None
with self._get_store() as store: with db_session:
self.assertEqual(store.find(Track).count(), 1) self.assertEqual(Track.select().count(), 1)
self.assertEqual(store.find(Artist, Artist.name == 'Some artist').count(), 1) self.assertEqual(Artist.select(lambda a: a.name == 'Some artist').count(), 1)
trackid = store.find(Track).one().id trackid = Track.select().first().id
tags = mutagen.File(path, easy = True) tags = mutagen.File(path, easy = True)
tags['artist'] = 'Renamed' tags['artist'] = 'Renamed'
tags.save() tags.save()
self._sleep() self._sleep()
with self._get_store() as store: with db_session:
self.assertEqual(store.find(Track).count(), 1) self.assertEqual(Track.select().count(), 1)
self.assertEqual(store.find(Artist, Artist.name == 'Some artist').count(), 0) self.assertEqual(Artist.select(lambda a: a.name == 'Some artist').count(), 0)
self.assertEqual(store.find(Artist, Artist.name == 'Renamed').count(), 1) self.assertEqual(Artist.select(lambda a: a.name == 'Renamed').count(), 1)
self.assertEqual(store.find(Track).one().id, trackid) self.assertEqual(Track.select().first().id, trackid)
def test_rename(self): def test_rename(self):
path = self._addfile() path = self._addfile()
self._sleep() self._sleep()
trackid = None trackid = None
with self._get_store() as store: with db_session:
self.assertEqual(store.find(Track).count(), 1) self.assertEqual(Track.select().count(), 1)
trackid = store.find(Track).one().id trackid = Track.select().first().id
newpath = self._temppath() newpath = self._temppath()
shutil.move(path, newpath) shutil.move(path, newpath)
self._sleep() self._sleep()
with self._get_store() as store: with db_session:
track = store.find(Track).one() track = Track.select().first()
self.assertIsNotNone(track) self.assertIsNotNone(track)
self.assertNotEqual(track.path, path) self.assertNotEqual(track.path, path)
self.assertEqual(track.path, newpath) self.assertEqual(track.path, newpath)

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import Folder from supysonic.db import Folder
from .frontendtestbase import FrontendTestBase from .frontendtestbase import FrontendTestBase
@ -50,20 +52,22 @@ class FolderTestCase(FrontendTestBase):
self.assertIn('Add Folder', rv.data) self.assertIn('Add Folder', rv.data)
rv = self.client.post('/folder/add', data = { 'name': 'name', 'path': 'tests/assets' }, follow_redirects = True) rv = self.client.post('/folder/add', data = { 'name': 'name', 'path': 'tests/assets' }, follow_redirects = True)
self.assertIn('created', rv.data) self.assertIn('created', rv.data)
self.assertEqual(self.store.find(Folder).count(), 1) with db_session:
self.assertEqual(Folder.select().count(), 1)
def test_delete(self): def test_delete(self):
folder = Folder() with db_session:
folder.name = 'folder' folder = Folder(
folder.path = 'tests/assets' name = 'folder',
folder.root = True path = 'tests/assets',
self.store.add(folder) root = True
self.store.commit() )
self._login('bob', 'B0b') self._login('bob', 'B0b')
rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True) rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True)
self.assertIn('There\'s nothing much to see', rv.data) self.assertIn('There\'s nothing much to see', rv.data)
self.assertEqual(self.store.find(Folder).count(), 1) with db_session:
self.assertEqual(Folder.select().count(), 1)
self._logout() self._logout()
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
@ -73,15 +77,17 @@ class FolderTestCase(FrontendTestBase):
self.assertIn('No such folder', rv.data) self.assertIn('No such folder', rv.data)
rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True) rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True)
self.assertIn('Music folders', rv.data) self.assertIn('Music folders', rv.data)
self.assertEqual(self.store.find(Folder).count(), 0) with db_session:
self.assertEqual(Folder.select().count(), 0)
def test_scan(self): def test_scan(self):
folder = Folder() with db_session:
folder.name = 'folder' folder = Folder(
folder.path = 'tests/assets' name = 'folder',
folder.root = True path = 'tests/assets',
self.store.add(folder) root = True,
self.store.commit() )
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
rv = self.client.get('/folder/scan/string', follow_redirects = True) rv = self.client.get('/folder/scan/string', follow_redirects = True)

View File

@ -12,6 +12,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import User from supysonic.db import User
from .frontendtestbase import FrontendTestBase from .frontendtestbase import FrontendTestBase
@ -50,8 +52,9 @@ class LoginTestCase(FrontendTestBase):
def test_root_with_valid_session(self): def test_root_with_valid_session(self):
# Root with valid session # Root with valid session
with db_session:
with self.client.session_transaction() as sess: with self.client.session_transaction() as sess:
sess['userid'] = self.store.find(User, User.name == 'alice').one().id sess['userid'] = User.get(name = 'alice').id
rv = self.client.get('/', follow_redirects=True) rv = self.client.get('/', follow_redirects=True)
self.assertIn('alice', rv.data) self.assertIn('alice', rv.data)
self.assertIn('Log out', rv.data) self.assertIn('Log out', rv.data)

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import Folder, Artist, Album, Track, Playlist, User from supysonic.db import Folder, Artist, Album, Track, Playlist, User
from .frontendtestbase import FrontendTestBase from .frontendtestbase import FrontendTestBase
@ -19,43 +21,34 @@ class PlaylistTestCase(FrontendTestBase):
def setUp(self): def setUp(self):
super(PlaylistTestCase, self).setUp() super(PlaylistTestCase, self).setUp()
folder = Folder() with db_session:
folder.name = 'Root' folder = Folder(name = 'Root', path = 'tests/assets', root = True)
folder.path = 'tests/assets' artist = Artist(name = 'Artist!')
folder.root = True album = Album(name = 'Album!', artist = artist)
artist = Artist() track = Track(
artist.name = 'Artist!' path = 'tests/assets/23bytes',
title = '23bytes',
artist = artist,
album = album,
folder = folder,
root_folder = folder,
duration = 2,
disc = 1,
number = 1,
content_type = 'audio/mpeg',
bitrate = 320,
last_modification = 0
)
album = Album() playlist = Playlist(
album.name = 'Album!' name = 'Playlist!',
album.artist = artist user = User.get(name = 'alice')
)
track = Track()
track.path = 'tests/assets/23bytes'
track.title = '23bytes'
track.artist = artist
track.album = album
track.folder = folder
track.root_folder = folder
track.duration = 2
track.disc = 1
track.number = 1
track.content_type = 'audio/mpeg'
track.bitrate = 320
track.last_modification = 0
playlist = Playlist()
playlist.name = 'Playlist!'
playlist.user = self.store.find(User, User.name == 'alice').one()
for _ in range(4): for _ in range(4):
playlist.add(track) playlist.add(track)
self.store.add(track) self.playlistid = playlist.id
self.store.add(playlist)
self.store.commit()
self.playlist = playlist
def test_index(self): def test_index(self):
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
@ -68,7 +61,7 @@ class PlaylistTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.get('/playlist/' + str(uuid.uuid4()), follow_redirects = True) rv = self.client.get('/playlist/' + str(uuid.uuid4()), follow_redirects = True)
self.assertIn('Unknown', rv.data) self.assertIn('Unknown', rv.data)
rv = self.client.get('/playlist/' + str(self.playlist.id)) rv = self.client.get('/playlist/' + str(self.playlistid))
self.assertIn('Playlist!', rv.data) self.assertIn('Playlist!', rv.data)
self.assertIn('23bytes', rv.data) self.assertIn('23bytes', rv.data)
self.assertIn('Artist!', rv.data) self.assertIn('Artist!', rv.data)
@ -80,22 +73,25 @@ class PlaylistTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.post('/playlist/' + str(uuid.uuid4()), follow_redirects = True) rv = self.client.post('/playlist/' + str(uuid.uuid4()), follow_redirects = True)
self.assertIn('Unknown', rv.data) self.assertIn('Unknown', rv.data)
rv = self.client.post('/playlist/' + str(self.playlist.id), follow_redirects = True) rv = self.client.post('/playlist/' + str(self.playlistid), follow_redirects = True)
self.assertNotIn('updated', rv.data) self.assertNotIn('updated', rv.data)
self.assertIn('not allowed', rv.data) self.assertIn('not allowed', rv.data)
self._logout() self._logout()
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
rv = self.client.post('/playlist/' + str(self.playlist.id), follow_redirects = True) rv = self.client.post('/playlist/' + str(self.playlistid), follow_redirects = True)
self.assertNotIn('updated', rv.data) self.assertNotIn('updated', rv.data)
self.assertIn('Missing', rv.data) self.assertIn('Missing', rv.data)
self.assertEqual(self.playlist.name, 'Playlist!') with db_session:
self.assertEqual(Playlist[self.playlistid].name, 'Playlist!')
rv = self.client.post('/playlist/' + str(self.playlist.id), data = { 'name': 'abc', 'public': True }, follow_redirects = True) rv = self.client.post('/playlist/' + str(self.playlistid), data = { 'name': 'abc', 'public': True }, follow_redirects = True)
self.assertIn('updated', rv.data) self.assertIn('updated', rv.data)
self.assertNotIn('not allowed', rv.data) self.assertNotIn('not allowed', rv.data)
self.assertEqual(self.playlist.name, 'abc') with db_session:
self.assertTrue(self.playlist.public) playlist = Playlist[self.playlistid]
self.assertEqual(playlist.name, 'abc')
self.assertTrue(playlist.public)
def test_delete(self): def test_delete(self):
self._login('bob', 'B0b') self._login('bob', 'B0b')
@ -103,15 +99,17 @@ class PlaylistTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.get('/playlist/del/' + str(uuid.uuid4()), follow_redirects = True) rv = self.client.get('/playlist/del/' + str(uuid.uuid4()), follow_redirects = True)
self.assertIn('Unknown', rv.data) self.assertIn('Unknown', rv.data)
rv = self.client.get('/playlist/del/' + str(self.playlist.id), follow_redirects = True) rv = self.client.get('/playlist/del/' + str(self.playlistid), follow_redirects = True)
self.assertIn('not allowed', rv.data) self.assertIn('not allowed', rv.data)
self.assertEqual(self.store.find(Playlist).count(), 1) with db_session:
self.assertEqual(Playlist.select().count(), 1)
self._logout() self._logout()
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
rv = self.client.get('/playlist/del/' + str(self.playlist.id), follow_redirects = True) rv = self.client.get('/playlist/del/' + str(self.playlistid), follow_redirects = True)
self.assertIn('deleted', rv.data) self.assertIn('deleted', rv.data)
self.assertEqual(self.store.find(Playlist).count(), 0) with db_session:
self.assertEqual(Playlist.select().count(), 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import User, ClientPrefs from supysonic.db import User, ClientPrefs
from .frontendtestbase import FrontendTestBase from .frontendtestbase import FrontendTestBase
@ -19,7 +21,8 @@ class UserTestCase(FrontendTestBase):
def setUp(self): def setUp(self):
super(UserTestCase, self).setUp() super(UserTestCase, self).setUp()
self.users = { u.name: u for u in self.store.find(User) } with db_session:
self.users = { u.name: u.id for u in User.select() }
def test_index(self): def test_index(self):
self._login('bob', 'B0b') self._login('bob', 'B0b')
@ -38,18 +41,15 @@ class UserTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.get('/user/' + str(uuid.uuid4()), follow_redirects = True) rv = self.client.get('/user/' + str(uuid.uuid4()), follow_redirects = True)
self.assertIn('No such user', rv.data) self.assertIn('No such user', rv.data)
rv = self.client.get('/user/' + str(self.users['bob'].id)) rv = self.client.get('/user/' + str(self.users['bob']))
self.assertIn('bob', rv.data) self.assertIn('bob', rv.data)
self._logout() self._logout()
prefs = ClientPrefs() with db_session:
prefs.user_id = self.users['bob'].id ClientPrefs(user = User[self.users['bob']], client_name = 'tests')
prefs.client_name = 'tests'
self.store.add(prefs)
self.store.commit()
self._login('bob', 'B0b') self._login('bob', 'B0b')
rv = self.client.get('/user/' + str(self.users['alice'].id), follow_redirects = True) rv = self.client.get('/user/' + str(self.users['alice']), follow_redirects = True)
self.assertIn('There\'s nothing much to see', rv.data) self.assertIn('There\'s nothing much to see', rv.data)
self.assertNotIn('<h2>bob</h2>', rv.data) self.assertNotIn('<h2>bob</h2>', rv.data)
rv = self.client.get('/user/me') rv = self.client.get('/user/me')
@ -68,19 +68,19 @@ class UserTestCase(FrontendTestBase):
self.client.post('/user/me', data = { 'n_': 'o' }) self.client.post('/user/me', data = { 'n_': 'o' })
self.client.post('/user/me', data = { 'inexisting_client': 'setting' }) self.client.post('/user/me', data = { 'inexisting_client': 'setting' })
prefs = ClientPrefs() with db_session:
prefs.user_id = self.users['alice'].id ClientPrefs(user = User[self.users['alice']], client_name = 'tests')
prefs.client_name = 'tests'
self.store.add(prefs)
self.store.commit()
rv = self.client.post('/user/me', data = { 'tests_format': 'mp3', 'tests_bitrate': 128 }) rv = self.client.post('/user/me', data = { 'tests_format': 'mp3', 'tests_bitrate': 128 })
self.assertIn('updated', rv.data) self.assertIn('updated', rv.data)
with db_session:
prefs = ClientPrefs[User[self.users['alice']], 'tests']
self.assertEqual(prefs.format, 'mp3') self.assertEqual(prefs.format, 'mp3')
self.assertEqual(prefs.bitrate, 128) self.assertEqual(prefs.bitrate, 128)
self.client.post('/user/me', data = { 'tests_delete': 1 }) self.client.post('/user/me', data = { 'tests_delete': 1 })
self.assertEqual(self.store.find(ClientPrefs).count(), 0) with db_session:
self.assertEqual(ClientPrefs.select().count(), 0)
def test_change_username_get(self): def test_change_username_get(self):
self._login('bob', 'B0b') self._login('bob', 'B0b')
@ -93,13 +93,13 @@ class UserTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.get('/user/{}/changeusername'.format(uuid.uuid4()), follow_redirects = True) rv = self.client.get('/user/{}/changeusername'.format(uuid.uuid4()), follow_redirects = True)
self.assertIn('No such user', rv.data) self.assertIn('No such user', rv.data)
self.client.get('/user/{}/changeusername'.format(self.users['bob'].id)) self.client.get('/user/{}/changeusername'.format(self.users['bob']))
def test_change_username_post(self): def test_change_username_post(self):
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
self.client.post('/user/whatever/changeusername') self.client.post('/user/whatever/changeusername')
path = '/user/{}/changeusername'.format(self.users['bob'].id) path = '/user/{}/changeusername'.format(self.users['bob'])
rv = self.client.post(path, follow_redirects = True) rv = self.client.post(path, follow_redirects = True)
self.assertIn('required', rv.data) self.assertIn('required', rv.data)
rv = self.client.post(path, data = { 'user': 'bob' }, follow_redirects = True) rv = self.client.post(path, data = { 'user': 'bob' }, follow_redirects = True)
@ -107,10 +107,13 @@ class UserTestCase(FrontendTestBase):
rv = self.client.post(path, data = { 'user': 'b0b', 'admin': 1 }, follow_redirects = True) rv = self.client.post(path, data = { 'user': 'b0b', 'admin': 1 }, follow_redirects = True)
self.assertIn('updated', rv.data) self.assertIn('updated', rv.data)
self.assertIn('b0b', rv.data) self.assertIn('b0b', rv.data)
self.assertEqual(self.users['bob'].name, 'b0b') with db_session:
self.assertTrue(self.users['bob'].admin) bob = User[self.users['bob']]
self.assertEqual(bob.name, 'b0b')
self.assertTrue(bob.admin)
rv = self.client.post(path, data = { 'user': 'alice' }, follow_redirects = True) rv = self.client.post(path, data = { 'user': 'alice' }, follow_redirects = True)
self.assertEqual(self.users['bob'].name, 'b0b') with db_session:
self.assertEqual(User[self.users['bob']].name, 'b0b')
def test_change_mail_get(self): def test_change_mail_get(self):
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
@ -126,7 +129,7 @@ class UserTestCase(FrontendTestBase):
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
rv = self.client.get('/user/me/changepass') rv = self.client.get('/user/me/changepass')
self.assertIn('Current password', rv.data) self.assertIn('Current password', rv.data)
rv = self.client.get('/user/{}/changepass'.format(self.users['bob'].id)) rv = self.client.get('/user/{}/changepass'.format(self.users['bob']))
self.assertNotIn('Current password', rv.data) self.assertNotIn('Current password', rv.data)
def test_change_password_post(self): def test_change_password_post(self):
@ -151,7 +154,7 @@ class UserTestCase(FrontendTestBase):
rv = self._login('alice', 'alice') rv = self._login('alice', 'alice')
self.assertIn('Logged in', rv.data) self.assertIn('Logged in', rv.data)
path = '/user/{}/changepass'.format(self.users['bob'].id) path = '/user/{}/changepass'.format(self.users['bob'])
rv = self.client.post(path) rv = self.client.post(path)
self.assertIn('required', rv.data) self.assertIn('required', rv.data)
rv = self.client.post(path, data = { 'new': 'alice' }) rv = self.client.post(path, data = { 'new': 'alice' })
@ -162,7 +165,6 @@ class UserTestCase(FrontendTestBase):
rv = self._login('bob', 'alice') rv = self._login('bob', 'alice')
self.assertIn('Logged in', rv.data) self.assertIn('Logged in', rv.data)
def test_add_get(self): def test_add_get(self):
self._login('bob', 'B0b') self._login('bob', 'B0b')
rv = self.client.get('/user/add', follow_redirects = True) rv = self.client.get('/user/add', follow_redirects = True)
@ -186,22 +188,25 @@ class UserTestCase(FrontendTestBase):
self.assertIn('passwords don', rv.data) self.assertIn('passwords don', rv.data)
rv = self.client.post('/user/add', data = { 'user': 'alice', 'passwd': 'passwd', 'passwd_confirm': 'passwd' }) rv = self.client.post('/user/add', data = { 'user': 'alice', 'passwd': 'passwd', 'passwd_confirm': 'passwd' })
self.assertIn('already a user with that name', rv.data) self.assertIn('already a user with that name', rv.data)
self.assertEqual(self.store.find(User).count(), 2) with db_session:
self.assertEqual(User.select().count(), 2)
rv = self.client.post('/user/add', data = { 'user': 'user', 'passwd': 'passwd', 'passwd_confirm': 'passwd', 'admin': 1 }, follow_redirects = True) rv = self.client.post('/user/add', data = { 'user': 'user', 'passwd': 'passwd', 'passwd_confirm': 'passwd', 'admin': 1 }, follow_redirects = True)
self.assertIn('added', rv.data) self.assertIn('added', rv.data)
self.assertEqual(self.store.find(User).count(), 3) with db_session:
self.assertEqual(User.select().count(), 3)
self._logout() self._logout()
rv = self._login('user', 'passwd') rv = self._login('user', 'passwd')
self.assertIn('Logged in', rv.data) self.assertIn('Logged in', rv.data)
def test_delete(self): def test_delete(self):
path = '/user/del/{}'.format(self.users['bob'].id) path = '/user/del/{}'.format(self.users['bob'])
self._login('bob', 'B0b') self._login('bob', 'B0b')
rv = self.client.get(path, follow_redirects = True) rv = self.client.get(path, follow_redirects = True)
self.assertIn('There\'s nothing much to see', rv.data) self.assertIn('There\'s nothing much to see', rv.data)
self.assertEqual(self.store.find(User).count(), 2) with db_session:
self.assertEqual(User.select().count(), 2)
self._logout() self._logout()
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
@ -211,7 +216,8 @@ class UserTestCase(FrontendTestBase):
self.assertIn('No such user', rv.data) self.assertIn('No such user', rv.data)
rv = self.client.get(path, follow_redirects = True) rv = self.client.get(path, follow_redirects = True)
self.assertIn('Deleted', rv.data) self.assertIn('Deleted', rv.data)
self.assertEqual(self.store.find(User).count(), 1) with db_session:
self.assertEqual(User.select().count(), 1)
self._logout() self._logout()
rv = self._login('bob', 'B0b') rv = self._login('bob', 'B0b')
self.assertIn('No such user', rv.data) self.assertIn('No such user', rv.data)

View File

@ -20,129 +20,133 @@ import tempfile
import unittest import unittest
import uuid import uuid
from pony.orm import db_session, ObjectNotFound
class FolderManagerTestCase(unittest.TestCase): class FolderManagerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
# Create an empty sqlite database in memory # Create an empty sqlite database in memory
self.store = db.get_store("sqlite:") db.init_database('sqlite:', True)
# Read schema from file
with io.open('schema/sqlite.sql', 'r') as sql:
schema = sql.read()
# Create tables on memory database
for command in schema.split(';'):
self.store.execute(command)
# Create some temporary directories # Create some temporary directories
self.media_dir = tempfile.mkdtemp() self.media_dir = tempfile.mkdtemp()
self.music_dir = tempfile.mkdtemp() self.music_dir = tempfile.mkdtemp()
# Add test folders
self.assertEqual(FolderManager.add(self.store, 'media', self.media_dir), FolderManager.SUCCESS)
self.assertEqual(FolderManager.add(self.store, 'music', self.music_dir), FolderManager.SUCCESS)
folder = db.Folder()
folder.root = False
folder.name = 'non-root'
folder.path = os.path.join(self.music_dir, 'subfolder')
self.store.add(folder)
artist = db.Artist()
artist.name = 'Artist'
album = db.Album()
album.name = 'Album'
album.artist = artist
root = self.store.find(db.Folder, db.Folder.name == 'media').one()
track = db.Track()
track.title = 'Track'
track.artist = artist
track.album = album
track.disc = 1
track.number = 1
track.path = os.path.join(self.media_dir, 'somefile')
track.folder = root
track.root_folder = root
track.duration = 2
track.content_type = 'audio/mpeg'
track.bitrate = 320
track.last_modification = 0
self.store.add(track)
self.store.commit()
def tearDown(self): def tearDown(self):
db.release_database()
shutil.rmtree(self.media_dir) shutil.rmtree(self.media_dir)
shutil.rmtree(self.music_dir) shutil.rmtree(self.music_dir)
@db_session
def create_folders(self):
# Add test folders
self.assertEqual(FolderManager.add('media', self.media_dir), FolderManager.SUCCESS)
self.assertEqual(FolderManager.add('music', self.music_dir), FolderManager.SUCCESS)
folder = db.Folder(
root = False,
name = 'non-root',
path = os.path.join(self.music_dir, 'subfolder')
)
artist = db.Artist(name = 'Artist')
album = db.Album(name = 'Album', artist = artist)
root = db.Folder.get(name = 'media')
track = db.Track(
title = 'Track',
artist = artist,
album = album,
disc = 1,
number = 1,
path = os.path.join(self.media_dir, 'somefile'),
folder = root,
root_folder = root,
duration = 2,
content_type = 'audio/mpeg',
bitrate = 320,
last_modification = 0
)
@db_session
def test_get_folder(self): def test_get_folder(self):
self.create_folders()
# Get existing folders # Get existing folders
for name in ['media', 'music']: for name in ['media', 'music']:
folder = self.store.find(db.Folder, db.Folder.name == name, db.Folder.root == True).one() folder = db.Folder.get(name = name, root = True)
self.assertEqual(FolderManager.get(self.store, folder.id), (FolderManager.SUCCESS, folder)) self.assertEqual(FolderManager.get(folder.id), (FolderManager.SUCCESS, folder))
# Get with invalid UUID # Get with invalid UUID
self.assertEqual(FolderManager.get(self.store, 'invalid-uuid'), (FolderManager.INVALID_ID, None)) self.assertEqual(FolderManager.get('invalid-uuid'), (FolderManager.INVALID_ID, None))
self.assertEqual(FolderManager.get(self.store, 0xdeadbeef), (FolderManager.INVALID_ID, None)) self.assertEqual(FolderManager.get(0xdeadbeef), (FolderManager.INVALID_ID, None))
# Non-existent folder # Non-existent folder
self.assertEqual(FolderManager.get(self.store, uuid.uuid4()), (FolderManager.NO_SUCH_FOLDER, None)) self.assertEqual(FolderManager.get(uuid.uuid4()), (FolderManager.NO_SUCH_FOLDER, None))
@db_session
def test_add_folder(self): def test_add_folder(self):
# Added in setUp() self.create_folders()
self.assertEqual(self.store.find(db.Folder).count(), 3) self.assertEqual(db.Folder.select().count(), 3)
# Create duplicate # Create duplicate
self.assertEqual(FolderManager.add(self.store,'media', self.media_dir), FolderManager.NAME_EXISTS) self.assertEqual(FolderManager.add('media', self.media_dir), FolderManager.NAME_EXISTS)
self.assertEqual(self.store.find(db.Folder, db.Folder.name == 'media').count(), 1) self.assertEqual(db.Folder.select(lambda f: f.name == 'media').count(), 1)
# Duplicate path # Duplicate path
self.assertEqual(FolderManager.add(self.store,'new-folder', self.media_dir), FolderManager.PATH_EXISTS) self.assertEqual(FolderManager.add('new-folder', self.media_dir), FolderManager.PATH_EXISTS)
self.assertEqual(self.store.find(db.Folder, db.Folder.path == self.media_dir).count(), 1) self.assertEqual(db.Folder.select(lambda f: f.path == self.media_dir).count(), 1)
# Invalid path # Invalid path
path = os.path.abspath('/this/not/is/valid') path = os.path.abspath('/this/not/is/valid')
self.assertEqual(FolderManager.add(self.store,'invalid-path', path), FolderManager.INVALID_PATH) self.assertEqual(FolderManager.add('invalid-path', path), FolderManager.INVALID_PATH)
self.assertEqual(self.store.find(db.Folder, db.Folder.path == path).count(), 0) self.assertFalse(db.Folder.exists(path = path))
# Subfolder of already added path # Subfolder of already added path
path = os.path.join(self.media_dir, 'subfolder') path = os.path.join(self.media_dir, 'subfolder')
os.mkdir(path) os.mkdir(path)
self.assertEqual(FolderManager.add(self.store,'subfolder', path), FolderManager.PATH_EXISTS) self.assertEqual(FolderManager.add('subfolder', path), FolderManager.PATH_EXISTS)
self.assertEqual(self.store.find(db.Folder).count(), 3) self.assertEqual(db.Folder.select().count(), 3)
# Parent folder of an already added path # Parent folder of an already added path
path = os.path.join(self.media_dir, '..') path = os.path.join(self.media_dir, '..')
self.assertEqual(FolderManager.add(self.store, 'parent', path), FolderManager.SUBPATH_EXISTS) self.assertEqual(FolderManager.add('parent', path), FolderManager.SUBPATH_EXISTS)
self.assertEqual(self.store.find(db.Folder).count(), 3) self.assertEqual(db.Folder.select().count(), 3)
@db_session
def test_delete_folder(self): def test_delete_folder(self):
self.create_folders()
# Delete existing folders # Delete existing folders
for name in ['media', 'music']: for name in ['media', 'music']:
folder = self.store.find(db.Folder, db.Folder.name == name, db.Folder.root == True).one() folder = db.Folder.get(name = name, root = True)
self.assertEqual(FolderManager.delete(self.store, folder.id), FolderManager.SUCCESS) self.assertEqual(FolderManager.delete(folder.id), FolderManager.SUCCESS)
self.assertIsNone(self.store.get(db.Folder, folder.id)) self.assertRaises(ObjectNotFound, db.Folder.__getitem__, folder.id)
# Delete invalid UUID # Delete invalid UUID
self.assertEqual(FolderManager.delete(self.store, 'invalid-uuid'), FolderManager.INVALID_ID) self.assertEqual(FolderManager.delete('invalid-uuid'), FolderManager.INVALID_ID)
self.assertEqual(self.store.find(db.Folder).count(), 1) # 'non-root' remaining self.assertEqual(db.Folder.select().count(), 1) # 'non-root' remaining
# Delete non-existent folder # Delete non-existent folder
self.assertEqual(FolderManager.delete(self.store, uuid.uuid4()), FolderManager.NO_SUCH_FOLDER) self.assertEqual(FolderManager.delete(uuid.uuid4()), FolderManager.NO_SUCH_FOLDER)
self.assertEqual(self.store.find(db.Folder).count(), 1) # 'non-root' remaining self.assertEqual(db.Folder.select().count(), 1) # 'non-root' remaining
# Delete non-root folder # Delete non-root folder
folder = self.store.find(db.Folder, db.Folder.name == 'non-root').one() folder = db.Folder.get(name = 'non-root')
self.assertEqual(FolderManager.delete(self.store, folder.id), FolderManager.NO_SUCH_FOLDER) self.assertEqual(FolderManager.delete(folder.id), FolderManager.NO_SUCH_FOLDER)
self.assertEqual(self.store.find(db.Folder).count(), 1) # 'non-root' remaining self.assertEqual(db.Folder.select().count(), 1) # 'non-root' remaining
@db_session
def test_delete_by_name(self): def test_delete_by_name(self):
self.create_folders()
# Delete existing folders # Delete existing folders
for name in ['media', 'music']: for name in ['media', 'music']:
self.assertEqual(FolderManager.delete_by_name(self.store, name), FolderManager.SUCCESS) self.assertEqual(FolderManager.delete_by_name(name), FolderManager.SUCCESS)
self.assertEqual(self.store.find(db.Folder, db.Folder.name == name).count(), 0) self.assertFalse(db.Folder.exists(name = name))
# Delete non-existent folder # Delete non-existent folder
self.assertEqual(FolderManager.delete_by_name(self.store, 'null'), FolderManager.NO_SUCH_FOLDER) self.assertEqual(FolderManager.delete_by_name('null'), FolderManager.NO_SUCH_FOLDER)
self.assertEqual(self.store.find(db.Folder).count(), 1) # 'non-root' remaining self.assertEqual(db.Folder.select().count(), 1) # 'non-root' remaining
def test_human_readable_error(self): def test_human_readable_error(self):
values = [ FolderManager.SUCCESS, FolderManager.INVALID_ID, FolderManager.NAME_EXISTS, values = [ FolderManager.SUCCESS, FolderManager.INVALID_ID, FolderManager.NAME_EXISTS,

View File

@ -13,61 +13,51 @@
from supysonic import db from supysonic import db
from supysonic.managers.user import UserManager from supysonic.managers.user import UserManager
import io
import unittest import unittest
import uuid import uuid
import io
from pony.orm import db_session, commit
from pony.orm import ObjectNotFound
class UserManagerTestCase(unittest.TestCase): class UserManagerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
# Create an empty sqlite database in memory # Create an empty sqlite database in memory
self.store = db.get_store("sqlite:") db.init_database('sqlite:', True)
# Read schema from file
with io.open('schema/sqlite.sql', 'r') as sql:
schema = sql.read()
# Create tables on memory database
for command in schema.split(';'):
self.store.execute(command)
def tearDown(self):
db.release_database()
@db_session
def create_data(self):
# Create some users # Create some users
self.assertEqual(UserManager.add(self.store, 'alice', 'ALICE', 'test@example.com', True), UserManager.SUCCESS) self.assertEqual(UserManager.add('alice', 'ALICE', 'test@example.com', True), UserManager.SUCCESS)
self.assertEqual(UserManager.add(self.store, 'bob', 'BOB', 'bob@example.com', False), UserManager.SUCCESS) self.assertEqual(UserManager.add('bob', 'BOB', 'bob@example.com', False), UserManager.SUCCESS)
self.assertEqual(UserManager.add(self.store, 'charlie', 'CHARLIE', 'charlie@example.com', False), UserManager.SUCCESS) self.assertEqual(UserManager.add('charlie', 'CHARLIE', 'charlie@example.com', False), UserManager.SUCCESS)
folder = db.Folder() folder = db.Folder(name = 'Root', path = 'tests/assets', root = True)
folder.name = 'Root' artist = db.Artist(name = 'Artist')
folder.path = 'tests/assets' album = db.Album(name = 'Album', artist = artist)
folder.root = True track = db.Track(
title = 'Track',
disc = 1,
number = 1,
duration = 1,
artist = artist,
album = album,
path = 'tests/assets/empty',
folder = folder,
root_folder = folder,
content_type = 'audio/mpeg',
bitrate = 320,
last_modification = 0
)
artist = db.Artist() playlist = db.Playlist(
artist.name = 'Artist' name = 'Playlist',
user = db.User.get(name = 'alice')
album = db.Album() )
album.name = 'Album'
album.artist = artist
track = db.Track()
track.title = 'Track'
track.disc = 1
track.number = 1
track.duration = 1
track.artist = artist
track.album = album
track.path = 'tests/assets/empty'
track.folder = folder
track.root_folder = folder
track.duration = 2
track.content_type = 'audio/mpeg'
track.bitrate = 320
track.last_modification = 0
self.store.add(track)
self.store.commit()
playlist = db.Playlist()
playlist.name = 'Playlist'
playlist.user = self.store.find(db.User, db.User.name == 'alice').one()
playlist.add(track) playlist.add(track)
self.store.add(playlist)
self.store.commit()
def test_encrypt_password(self): def test_encrypt_password(self):
func = UserManager._UserManager__encrypt_password func = UserManager._UserManager__encrypt_password
@ -75,96 +65,116 @@ class UserManagerTestCase(unittest.TestCase):
self.assertEqual(func(u'pass-word',u'pepper'), (u'd68c95a91ed7773aa57c7c044d2309a5bf1da2e7', u'pepper')) self.assertEqual(func(u'pass-word',u'pepper'), (u'd68c95a91ed7773aa57c7c044d2309a5bf1da2e7', u'pepper'))
self.assertEqual(func(u'éèàïô', u'ABC+'), (u'b639ba5217b89c906019d89d5816b407d8730898', u'ABC+')) self.assertEqual(func(u'éèàïô', u'ABC+'), (u'b639ba5217b89c906019d89d5816b407d8730898', u'ABC+'))
@db_session
def test_get_user(self): def test_get_user(self):
self.create_data()
# Get existing users # Get existing users
for name in ['alice', 'bob', 'charlie']: for name in ['alice', 'bob', 'charlie']:
user = self.store.find(db.User, db.User.name == name).one() user = db.User.get(name = name)
self.assertEqual(UserManager.get(self.store, user.id), (UserManager.SUCCESS, user)) self.assertEqual(UserManager.get(user.id), (UserManager.SUCCESS, user))
# Get with invalid UUID # Get with invalid UUID
self.assertEqual(UserManager.get(self.store, 'invalid-uuid'), (UserManager.INVALID_ID, None)) self.assertEqual(UserManager.get('invalid-uuid'), (UserManager.INVALID_ID, None))
self.assertEqual(UserManager.get(self.store, 0xfee1bad), (UserManager.INVALID_ID, None)) self.assertEqual(UserManager.get(0xfee1bad), (UserManager.INVALID_ID, None))
# Non-existent user # Non-existent user
self.assertEqual(UserManager.get(self.store, uuid.uuid4()), (UserManager.NO_SUCH_USER, None)) self.assertEqual(UserManager.get(uuid.uuid4()), (UserManager.NO_SUCH_USER, None))
@db_session
def test_add_user(self): def test_add_user(self):
# Added in setUp() self.create_data()
self.assertEqual(self.store.find(db.User).count(), 3) self.assertEqual(db.User.select().count(), 3)
# Create duplicate # Create duplicate
self.assertEqual(UserManager.add(self.store, 'alice', 'Alic3', 'alice@example.com', True), UserManager.NAME_EXISTS) self.assertEqual(UserManager.add('alice', 'Alic3', 'alice@example.com', True), UserManager.NAME_EXISTS)
@db_session
def test_delete_user(self): def test_delete_user(self):
self.create_data()
# Delete invalid UUID # Delete invalid UUID
self.assertEqual(UserManager.delete(self.store, 'invalid-uuid'), UserManager.INVALID_ID) self.assertEqual(UserManager.delete('invalid-uuid'), UserManager.INVALID_ID)
self.assertEqual(UserManager.delete(self.store, 0xfee1b4d), UserManager.INVALID_ID) self.assertEqual(UserManager.delete(0xfee1b4d), UserManager.INVALID_ID)
self.assertEqual(self.store.find(db.User).count(), 3) self.assertEqual(db.User.select().count(), 3)
# Delete non-existent user # Delete non-existent user
self.assertEqual(UserManager.delete(self.store, uuid.uuid4()), UserManager.NO_SUCH_USER) self.assertEqual(UserManager.delete(uuid.uuid4()), UserManager.NO_SUCH_USER)
self.assertEqual(self.store.find(db.User).count(), 3) self.assertEqual(db.User.select().count(), 3)
# Delete existing users # Delete existing users
for name in ['alice', 'bob', 'charlie']: for name in ['alice', 'bob', 'charlie']:
user = self.store.find(db.User, db.User.name == name).one() user = db.User.get(name = name)
self.assertEqual(UserManager.delete(self.store, user.id), UserManager.SUCCESS) self.assertEqual(UserManager.delete(user.id), UserManager.SUCCESS)
self.assertIsNone(self.store.get(db.User, user.id)) self.assertRaises(ObjectNotFound, db.User.__getitem__, user.id)
self.assertEqual(self.store.find(db.User).count(), 0) commit()
self.assertEqual(db.User.select().count(), 0)
@db_session
def test_delete_by_name(self): def test_delete_by_name(self):
self.create_data()
# Delete existing users # Delete existing users
for name in ['alice', 'bob', 'charlie']: for name in ['alice', 'bob', 'charlie']:
self.assertEqual(UserManager.delete_by_name(self.store, name), UserManager.SUCCESS) self.assertEqual(UserManager.delete_by_name(name), UserManager.SUCCESS)
self.assertEqual(self.store.find(db.User, db.User.name == name).count(), 0) self.assertFalse(db.User.exists(name = name))
# Delete non-existent user # Delete non-existent user
self.assertEqual(UserManager.delete_by_name(self.store, 'null'), UserManager.NO_SUCH_USER) self.assertEqual(UserManager.delete_by_name('null'), UserManager.NO_SUCH_USER)
@db_session
def test_try_auth(self): def test_try_auth(self):
self.create_data()
# Test authentication # Test authentication
for name in ['alice', 'bob', 'charlie']: for name in ['alice', 'bob', 'charlie']:
user = self.store.find(db.User, db.User.name == name).one() user = db.User.get(name = name)
self.assertEqual(UserManager.try_auth(self.store, name, name.upper()), (UserManager.SUCCESS, user)) self.assertEqual(UserManager.try_auth(name, name.upper()), (UserManager.SUCCESS, user))
# Wrong password # Wrong password
self.assertEqual(UserManager.try_auth(self.store, 'alice', 'bad'), (UserManager.WRONG_PASS, None)) self.assertEqual(UserManager.try_auth('alice', 'bad'), (UserManager.WRONG_PASS, None))
self.assertEqual(UserManager.try_auth(self.store, 'alice', 'alice'), (UserManager.WRONG_PASS, None)) self.assertEqual(UserManager.try_auth('alice', 'alice'), (UserManager.WRONG_PASS, None))
# Non-existent user # Non-existent user
self.assertEqual(UserManager.try_auth(self.store, 'null', 'null'), (UserManager.NO_SUCH_USER, None)) self.assertEqual(UserManager.try_auth('null', 'null'), (UserManager.NO_SUCH_USER, None))
@db_session
def test_change_password(self): def test_change_password(self):
self.create_data()
# With existing users # With existing users
for name in ['alice', 'bob', 'charlie']: for name in ['alice', 'bob', 'charlie']:
user = self.store.find(db.User, db.User.name == name).one() user = db.User.get(name = name)
# Good password # Good password
self.assertEqual(UserManager.change_password(self.store, user.id, name.upper(), 'newpass'), UserManager.SUCCESS) self.assertEqual(UserManager.change_password(user.id, name.upper(), 'newpass'), UserManager.SUCCESS)
self.assertEqual(UserManager.try_auth(self.store, name, 'newpass'), (UserManager.SUCCESS, user)) self.assertEqual(UserManager.try_auth(name, 'newpass'), (UserManager.SUCCESS, user))
# Old password # Old password
self.assertEqual(UserManager.try_auth(self.store, name, name.upper()), (UserManager.WRONG_PASS, None)) self.assertEqual(UserManager.try_auth(name, name.upper()), (UserManager.WRONG_PASS, None))
# Wrong password # Wrong password
self.assertEqual(UserManager.change_password(self.store, user.id, 'badpass', 'newpass'), UserManager.WRONG_PASS) self.assertEqual(UserManager.change_password(user.id, 'badpass', 'newpass'), UserManager.WRONG_PASS)
# Ensure we still got the same number of users # Ensure we still got the same number of users
self.assertEqual(self.store.find(db.User).count(), 3) self.assertEqual(db.User.select().count(), 3)
# With invalid UUID # With invalid UUID
self.assertEqual(UserManager.change_password(self.store, 'invalid-uuid', 'oldpass', 'newpass'), UserManager.INVALID_ID) self.assertEqual(UserManager.change_password('invalid-uuid', 'oldpass', 'newpass'), UserManager.INVALID_ID)
# Non-existent user # Non-existent user
self.assertEqual(UserManager.change_password(self.store, uuid.uuid4(), 'oldpass', 'newpass'), UserManager.NO_SUCH_USER) self.assertEqual(UserManager.change_password(uuid.uuid4(), 'oldpass', 'newpass'), UserManager.NO_SUCH_USER)
@db_session
def test_change_password2(self): def test_change_password2(self):
self.create_data()
# With existing users # With existing users
for name in ['alice', 'bob', 'charlie']: for name in ['alice', 'bob', 'charlie']:
self.assertEqual(UserManager.change_password2(self.store, name, 'newpass'), UserManager.SUCCESS) self.assertEqual(UserManager.change_password2(name, 'newpass'), UserManager.SUCCESS)
user = self.store.find(db.User, db.User.name == name).one() user = db.User.get(name = name)
self.assertEqual(UserManager.try_auth(self.store, name, 'newpass'), (UserManager.SUCCESS, user)) self.assertEqual(UserManager.try_auth(name, 'newpass'), (UserManager.SUCCESS, user))
self.assertEqual(UserManager.try_auth(self.store, name, name.upper()), (UserManager.WRONG_PASS, None)) self.assertEqual(UserManager.try_auth(name, name.upper()), (UserManager.WRONG_PASS, None))
# Non-existent user # Non-existent user
self.assertEqual(UserManager.change_password2(self.store, 'null', 'newpass'), UserManager.NO_SUCH_USER) self.assertEqual(UserManager.change_password2('null', 'newpass'), UserManager.NO_SUCH_USER)
def test_human_readable_error(self): def test_human_readable_error(self):
values = [ UserManager.SUCCESS, UserManager.INVALID_ID, UserManager.NO_SUCH_USER, UserManager.NAME_EXISTS, values = [ UserManager.SUCCESS, UserManager.INVALID_ID, UserManager.NO_SUCH_USER, UserManager.NAME_EXISTS,

View File

@ -10,22 +10,20 @@
import inspect import inspect
import io import io
import os
import shutil import shutil
import sys import sys
import unittest import unittest
import tempfile import tempfile
from supysonic.db import init_database, release_database
from supysonic.config import DefaultConfig from supysonic.config import DefaultConfig
from supysonic.managers.user import UserManager from supysonic.managers.user import UserManager
from supysonic.web import create_application, store from supysonic.web import create_application
class TestConfig(DefaultConfig): class TestConfig(DefaultConfig):
TESTING = True TESTING = True
LOGGER_HANDLER_POLICY = 'never' LOGGER_HANDLER_POLICY = 'never'
BASE = {
'database_uri': 'sqlite:',
'scanner_extensions': None
}
MIMETYPES = { MIMETYPES = {
'mp3': 'audio/mpeg', 'mp3': 'audio/mpeg',
'weirdextension': 'application/octet-stream' 'weirdextension': 'application/octet-stream'
@ -60,31 +58,37 @@ class TestBase(unittest.TestCase):
__with_api__ = False __with_api__ = False
def setUp(self): def setUp(self):
self.__dbfile = tempfile.mkstemp()[1]
self.__dir = tempfile.mkdtemp() self.__dir = tempfile.mkdtemp()
config = TestConfig(self.__with_webui__, self.__with_api__) config = TestConfig(self.__with_webui__, self.__with_api__)
config.BASE['database_uri'] = 'sqlite:///' + self.__dbfile
config.WEBAPP['cache_dir'] = self.__dir config.WEBAPP['cache_dir'] = self.__dir
init_database(config.BASE['database_uri'], True)
release_database()
app = create_application(config) app = create_application(config)
self.__ctx = app.app_context() self.__ctx = app.app_context()
self.__ctx.push() self.__ctx.push()
self.store = store
with io.open('schema/sqlite.sql', 'r') as sql:
schema = sql.read()
for statement in schema.split(';'):
self.store.execute(statement)
self.store.commit()
self.client = app.test_client() self.client = app.test_client()
UserManager.add(self.store, 'alice', 'Alic3', 'test@example.com', True) UserManager.add('alice', 'Alic3', 'test@example.com', True)
UserManager.add(self.store, 'bob', 'B0b', 'bob@example.com', False) UserManager.add('bob', 'B0b', 'bob@example.com', False)
@staticmethod
def __should_unload_module(module):
if module.startswith('supysonic'):
return not module.startswith('supysonic.db')
return False
def tearDown(self): def tearDown(self):
self.__ctx.pop() self.__ctx.pop()
release_database()
shutil.rmtree(self.__dir) shutil.rmtree(self.__dir)
os.remove(self.__dbfile)
to_unload = [ m for m in sys.modules if m.startswith('supysonic') ] to_unload = [ m for m in sorted(sys.modules) if self.__should_unload_module(m) ]
for m in to_unload: for m in to_unload:
del sys.modules[m] del sys.modules[m]