diff --git a/bin/supysonic-cli b/bin/supysonic-cli index cd1a33e..c63ff57 100755 --- a/bin/supysonic-cli +++ b/bin/supysonic-cli @@ -12,11 +12,11 @@ import sys from supysonic.cli import SupysonicCLI from supysonic.config import IniConfig -from supysonic.db import get_database, release_database +from supysonic.db import init_database, release_database if __name__ == "__main__": config = IniConfig.from_common_locations() - db = get_database(config.BASE['database_uri']) + init_database(config.BASE['database_uri']) cli = SupysonicCLI(config) if len(sys.argv) > 1: @@ -24,5 +24,5 @@ if __name__ == "__main__": else: cli.cmdloop() - release_database(db) + release_database() diff --git a/supysonic/config.py b/supysonic/config.py index a40fefc..26d0162 100644 --- a/supysonic/config.py +++ b/supysonic/config.py @@ -20,7 +20,7 @@ class DefaultConfig(object): tempdir = os.path.join(tempfile.gettempdir(), 'supysonic') BASE = { - 'database_uri': 'sqlite://' + os.path.join(tempdir, 'supysonic.db'), + 'database_uri': 'sqlite:///' + os.path.join(tempdir, 'supysonic.db'), 'scanner_extensions': None } WEBAPP = { diff --git a/supysonic/db.py b/supysonic/db.py index ee61870..9b1da74 100644 --- a/supysonic/db.py +++ b/supysonic/db.py @@ -248,7 +248,7 @@ class User(db.Entity): password = Required(str) salt = Required(str) admin = Required(bool, default = False) - lastfm_session = Optional(str) + lastfm_session = Optional(str, nullable = True) lastfm_status = Required(bool, default = True) # True: ok/unlinked, False: invalid session last_play = Optional(Track, column = 'last_play_id') @@ -450,15 +450,11 @@ def parse_uri(database_uri): return dict(provider = 'mysql', user = uri.username, passwd = uri.password, host = uri.hostname, db = uri.path[1:]) return dict() -def get_database(database_uri, create_tables = False): +def init_database(database_uri, create_tables = False): db.bind(**parse_uri(database_uri)) db.generate_mapping(create_tables = create_tables) - return db - -def release_database(db): - if not isinstance(db, Database): - raise TypeError('Expecting a pony.orm.Database instance') +def release_database(): db.disconnect() db.provider = None db.schema = None diff --git a/supysonic/frontend/__init__.py b/supysonic/frontend/__init__.py index 3db8481..f4845eb 100644 --- a/supysonic/frontend/__init__.py +++ b/supysonic/frontend/__init__.py @@ -11,8 +11,8 @@ from flask import session, request, redirect, url_for, current_app as app from functools import wraps +from pony.orm import db_session -from ..web import store from ..db import Artist, Album, Track from ..managers.user import UserManager @@ -27,7 +27,7 @@ def login_check(): request.user = None should_login = True if session.get('userid'): - code, user = UserManager.get(store, session.get('userid')) + code, user = UserManager.get(session.get('userid')) if code != UserManager.SUCCESS: session.clear() else: @@ -39,13 +39,14 @@ def login_check(): return redirect(url_for('login', returnUrl = request.script_root + request.url[len(request.url_root)-1:])) @app.route('/') +@db_session def index(): stats = { - 'artists': store.find(Artist).count(), - 'albums': store.find(Album).count(), - 'tracks': store.find(Track).count() + 'artists': Artist.select().count(), + 'albums': Album.select().count(), + 'tracks': Track.select().count() } - return render_template('home.html', stats = stats, admin = UserManager.get(store, session.get('userid'))[1].admin) + return render_template('home.html', stats = stats) def admin_only(f): @wraps(f) diff --git a/supysonic/frontend/folder.py b/supysonic/frontend/folder.py index f36cc81..64deffe 100644 --- a/supysonic/frontend/folder.py +++ b/supysonic/frontend/folder.py @@ -22,19 +22,20 @@ import os.path import uuid from flask import request, flash, render_template, redirect, url_for, current_app as app +from pony.orm import db_session from ..db import Folder from ..managers.user import UserManager from ..managers.folder import FolderManager from ..scanner import Scanner -from ..web import store from . import admin_only @app.route('/folder') @admin_only +@db_session def folder_index(): - return render_template('folders.html', folders = store.find(Folder, Folder.root == True)) + return render_template('folders.html', folders = Folder.select(lambda f: f.root)) @app.route('/folder/add') @admin_only @@ -55,7 +56,7 @@ def add_folder_post(): if error: return render_template('addfolder.html') - ret = FolderManager.add(store, name, path) + ret = FolderManager.add(name, path) if ret != FolderManager.SUCCESS: flash(FolderManager.error_str(ret)) return render_template('addfolder.html') @@ -73,7 +74,7 @@ def del_folder(id): flash('Invalid folder id') return redirect(url_for('folder_index')) - ret = FolderManager.delete(store, idid) + ret = FolderManager.delete(idid) if ret != FolderManager.SUCCESS: flash(FolderManager.error_str(ret)) else: @@ -84,16 +85,19 @@ def del_folder(id): @app.route('/folder/scan') @app.route('/folder/scan/') @admin_only +@db_session def scan_folder(id = None): extensions = app.config['BASE']['scanner_extensions'] if extensions: extensions = extensions.split(' ') - scanner = Scanner(store, extensions = extensions) + + scanner = Scanner(extensions = extensions) + if id is None: - for folder in store.find(Folder, Folder.root == True): + for folder in Folder.select(lambda f: f.root): scanner.scan(folder) else: - status, folder = FolderManager.get(store, id) + status, folder = FolderManager.get(id) if status != FolderManager.SUCCESS: flash(FolderManager.error_str(status)) return redirect(url_for('folder_index')) @@ -101,7 +105,6 @@ def scan_folder(id = None): scanner.finish() added, deleted = scanner.stats() - store.commit() flash('Added: %i artists, %i albums, %i tracks' % (added[0], added[1], added[2])) flash('Deleted: %i artists, %i albums, %i tracks' % (deleted[0], deleted[1], deleted[2])) diff --git a/supysonic/frontend/playlist.py b/supysonic/frontend/playlist.py index 00e4692..d3f62d2 100644 --- a/supysonic/frontend/playlist.py +++ b/supysonic/frontend/playlist.py @@ -21,33 +21,38 @@ import uuid from flask import request, flash, render_template, redirect, url_for, current_app as app +from pony.orm import db_session +from pony.orm import ObjectNotFound -from ..web import store from ..db import Playlist from ..managers.user import UserManager @app.route('/playlist') +@db_session def playlist_index(): return render_template('playlists.html', - mine = store.find(Playlist, Playlist.user_id == request.user.id), - others = store.find(Playlist, Playlist.user_id != request.user.id, Playlist.public == True)) + mine = Playlist.select(lambda p: p.user == request.user), + others = Playlist.select(lambda p: p.user != request.user and p.public)) @app.route('/playlist/') +@db_session def playlist_details(uid): try: - uid = uuid.UUID(uid) if type(uid) in (str, unicode) else uid + uid = uuid.UUID(uid) except: flash('Invalid playlist id') return redirect(url_for('playlist_index')) - playlist = store.get(Playlist, uid) - if not playlist: + try: + playlist = Playlist[uid] + except ObjectNotFound: flash('Unknown playlist') return redirect(url_for('playlist_index')) return render_template('playlist.html', playlist = playlist) @app.route('/playlist/', methods = [ 'POST' ]) +@db_session def playlist_update(uid): try: uid = uuid.UUID(uid) @@ -55,24 +60,25 @@ def playlist_update(uid): flash('Invalid playlist id') return redirect(url_for('playlist_index')) - playlist = store.get(Playlist, uid) - if not playlist: + try: + playlist = Playlist[uid] + except ObjectNotFound: flash('Unknown playlist') return redirect(url_for('playlist_index')) - if playlist.user_id != request.user.id: + if playlist.user.id != request.user.id: flash("You're not allowed to edit this playlist") elif not request.form.get('name'): flash('Missing playlist name') else: playlist.name = request.form.get('name') playlist.public = request.form.get('public') in (True, 'True', 1, '1', 'on', 'checked') - store.commit() flash('Playlist updated.') return playlist_details(uid) @app.route('/playlist/del/') +@db_session def playlist_delete(uid): try: uid = uuid.UUID(uid) @@ -80,14 +86,16 @@ def playlist_delete(uid): flash('Invalid playlist id') return redirect(url_for('playlist_index')) - playlist = store.get(Playlist, uid) - if not playlist: + try: + playlist = Playlist[uid] + except ObjectNotFound: flash('Unknown playlist') - elif playlist.user_id != request.user.id: + return redirect(url_for('playlist_index')) + + if playlist.user.id != request.user.id: flash("You're not allowed to delete this playlist") else: - store.remove(playlist) - store.commit() + playlist.delete() flash('Playlist deleted') return redirect(url_for('playlist_index')) diff --git a/supysonic/frontend/user.py b/supysonic/frontend/user.py index 573aa71..b016090 100644 --- a/supysonic/frontend/user.py +++ b/supysonic/frontend/user.py @@ -20,15 +20,16 @@ from flask import request, session, flash, render_template, redirect, url_for, current_app as app from functools import wraps +from pony.orm import db_session from ..db import User, ClientPrefs from ..lastfm import LastFm from ..managers.user import UserManager -from ..web import store from . import admin_only def me_or_uuid(f, arg = 'uid'): + @db_session @wraps(f) def decorated_func(*args, **kwargs): if kwargs: @@ -37,11 +38,11 @@ def me_or_uuid(f, arg = 'uid'): uid = args[0] if uid == 'me': - user = request.user + user = User[request.user.id] # Refetch user from previous transaction elif not request.user.admin: return redirect(url_for('index')) else: - code, user = UserManager.get(store, uid) + code, user = UserManager.get(uid) if code != UserManager.SUCCESS: flash(UserManager.error_str(code)) return redirect(url_for('index')) @@ -57,14 +58,14 @@ def me_or_uuid(f, arg = 'uid'): @app.route('/user') @admin_only +@db_session def user_index(): - return render_template('users.html', users = store.find(User)) + return render_template('users.html', users = User.select()) @app.route('/user/') @me_or_uuid def user_profile(uid, user): - prefs = store.find(ClientPrefs, ClientPrefs.user_id == user.id) - return render_template('profile.html', user = user, has_lastfm = app.config['LASTFM']['api_key'] != None, clients = prefs) + return render_template('profile.html', user = user, has_lastfm = app.config['LASTFM']['api_key'] != None, clients = user.clients) @app.route('/user/', methods = [ 'POST' ]) @me_or_uuid @@ -87,25 +88,24 @@ def update_clients(uid, user): app.logger.debug(clients_opts) for client, opts in clients_opts.iteritems(): - prefs = store.get(ClientPrefs, (user.id, client)) - if not prefs: + prefs = user.clients.select(lambda c: c.client_name == client).first() + if prefs is None: continue if 'delete' in opts and opts['delete'] in [ 'on', 'true', 'checked', 'selected', '1' ]: - store.remove(prefs) + prefs.delete() continue prefs.format = opts['format'] if 'format' in opts and opts['format'] else None prefs.bitrate = int(opts['bitrate']) if 'bitrate' in opts and opts['bitrate'] else None - store.commit() flash('Clients preferences updated.') return user_profile(uid, user) @app.route('/user//changeusername') @admin_only def change_username_form(uid): - code, user = UserManager.get(store, uid) + code, user = UserManager.get(uid) if code != UserManager.SUCCESS: flash(UserManager.error_str(code)) return redirect(url_for('index')) @@ -114,8 +114,9 @@ def change_username_form(uid): @app.route('/user//changeusername', methods = [ 'POST' ]) @admin_only +@db_session def change_username_post(uid): - code, user = UserManager.get(store, uid) + code, user = UserManager.get(uid) if code != UserManager.SUCCESS: return redirect(url_for('index')) @@ -123,7 +124,7 @@ def change_username_post(uid): if username in ('', None): flash('The username is required') return render_template('change_username.html', user = user) - if user.name != username and store.find(User, User.name == username).one(): + if user.name != username and User.get(name = username) is not None: flash('This name is already taken') return render_template('change_username.html', user = user) @@ -135,7 +136,6 @@ def change_username_post(uid): if user.name != username or user.admin != admin: user.name = username user.admin = admin - store.commit() flash("User '%s' updated." % username) else: flash("No changes for '%s'." % username) @@ -150,10 +150,9 @@ def change_mail_form(uid, user): @app.route('/user//changemail', methods = [ 'POST' ]) @me_or_uuid def change_mail_post(uid, user): - mail = request.form.get('mail') + mail = request.form.get('mail', '') # No validation, lol. user.mail = mail - store.commit() return redirect(url_for('user_profile', uid = uid)) @app.route('/user//changepass') @@ -182,9 +181,9 @@ def change_password_post(uid, user): if not error: if user.id == request.user.id: - status = UserManager.change_password(store, user.id, current, new) + status = UserManager.change_password(user.id, current, new) else: - status = UserManager.change_password2(store, user.name, new) + status = UserManager.change_password2(user.name, new) if status != UserManager.SUCCESS: flash(UserManager.error_str(status)) @@ -214,13 +213,12 @@ def add_user_post(): flash("The passwords don't match.") error = True - if admin is None: - admin = True if store.find(User, User.admin == True).count() == 0 else False - else: - admin = True + admin = admin is not None + if mail is None: + mail = '' if not error: - status = UserManager.add(store, name, passwd, mail, admin) + status = UserManager.add(name, passwd, mail, admin) if status == UserManager.SUCCESS: flash("User '%s' successfully added" % name) return redirect(url_for('user_index')) @@ -232,7 +230,7 @@ def add_user_post(): @app.route('/user/del/') @admin_only def del_user(uid): - status = UserManager.delete(store, uid) + status = UserManager.delete(uid) if status == UserManager.SUCCESS: flash('Deleted user') else: @@ -250,7 +248,6 @@ def lastfm_reg(uid, user): lfm = LastFm(app.config['LASTFM'], user, app.logger) status, error = lfm.link_account(token) - store.commit() flash(error if not status else 'Successfully linked LastFM account') return redirect(url_for('user_profile', uid = uid)) @@ -260,7 +257,6 @@ def lastfm_reg(uid, user): def lastfm_unreg(uid, user): lfm = LastFm(app.config['LASTFM'], user, app.logger) lfm.unlink_account() - store.commit() flash('Unlinked LastFM account') return redirect(url_for('user_profile', uid = uid)) @@ -284,7 +280,7 @@ def login(): error = True if not error: - status, user = UserManager.try_auth(store, name, password) + status, user = UserManager.try_auth(name, password) if status == UserManager.SUCCESS: session['userid'] = str(user.id) flash('Logged in!') diff --git a/supysonic/watcher.py b/supysonic/watcher.py index e36d8cb..4a0d90c 100644 --- a/supysonic/watcher.py +++ b/supysonic/watcher.py @@ -28,7 +28,7 @@ from threading import Thread, Condition, Timer from watchdog.observers import Observer from watchdog.events import PatternMatchingEventHandler -from .db import get_database, release_database, Folder +from .db import init_database, release_database, Folder from .scanner import Scanner OP_SCAN = 1 @@ -205,7 +205,7 @@ class SupysonicWatcher(object): def __init__(self, config): self.__config = config self.__running = True - self.__db = get_database(config.BASE['database_uri']) + init_database(config.BASE['database_uri']) def run(self): logger = logging.getLogger(__name__) @@ -230,7 +230,7 @@ class SupysonicWatcher(object): shouldrun = folders.exists() if not shouldrun: logger.info("No folder set. Exiting.") - release_database(self.__db) + release_database() return queue = ScannerProcessingQueue(self.__config.DAEMON['wait_delay'], logger) @@ -258,7 +258,7 @@ class SupysonicWatcher(object): observer.join() queue.stop() queue.join() - release_database(self.__db) + release_database() def stop(self): self.__running = False diff --git a/supysonic/web.py b/supysonic/web.py index 266a227..3d53b83 100644 --- a/supysonic/web.py +++ b/supysonic/web.py @@ -11,25 +11,11 @@ import mimetypes -from flask import Flask, g, current_app +from flask import Flask from os import makedirs, path -from werkzeug.local import LocalProxy from .config import IniConfig -from .db import get_store - -# Supysonic database open -def get_db(): - if not hasattr(g, 'database'): - g.database = get_store(current_app.config['BASE']['database_uri']) - return g.database - -# Supysonic database close -def close_db(error): - if hasattr(g, 'database'): - g.database.close() - -store = LocalProxy(get_db) +from .db import init_database, release_database def create_application(config = None): global app @@ -42,9 +28,6 @@ def create_application(config = None): config = IniConfig.from_common_locations() app.config.from_object(config) - # Close database connection on teardown - app.teardown_appcontext(close_db) - # Set loglevel logfile = app.config['WEBAPP']['log_file'] if logfile: @@ -63,6 +46,9 @@ def create_application(config = None): handler.setLevel(mapping.get(loglevel.upper(), logging.NOTSET)) app.logger.addHandler(handler) + # Initialize database + init_database(app.config['BASE']['database_uri']) + # Insert unknown mimetypes for k, v in app.config['MIMETYPES'].iteritems(): extension = '.' + k.lower() diff --git a/tests/base/test_cli.py b/tests/base/test_cli.py index 22bafb9..62418c8 100644 --- a/tests/base/test_cli.py +++ b/tests/base/test_cli.py @@ -18,7 +18,7 @@ from contextlib import contextmanager from pony.orm import db_session from StringIO import StringIO -from supysonic.db import Folder, User, get_database, release_database +from supysonic.db import Folder, User, init_database, release_database from supysonic.cli import SupysonicCLI from ..testbase import TestConfig @@ -30,7 +30,7 @@ class CLITestCase(unittest.TestCase): conf = TestConfig(False, False) self.__dbfile = tempfile.mkstemp()[1] conf.BASE['database_uri'] = 'sqlite:///' + self.__dbfile - self.__store = get_database(conf.BASE['database_uri'], True) + init_database(conf.BASE['database_uri'], True) self.__stdout = StringIO() self.__stderr = StringIO() @@ -39,7 +39,7 @@ class CLITestCase(unittest.TestCase): def tearDown(self): self.__stdout.close() self.__stderr.close() - release_database(self.__store) + release_database() os.unlink(self.__dbfile) @contextmanager diff --git a/tests/base/test_db.py b/tests/base/test_db.py index bacd6b0..36a9896 100644 --- a/tests/base/test_db.py +++ b/tests/base/test_db.py @@ -22,10 +22,10 @@ date_regex = re.compile(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$') class DbTestCase(unittest.TestCase): def setUp(self): - self.store = db.get_database('sqlite:', True) + db.init_database('sqlite:', True) def tearDown(self): - db.release_database(self.store) + db.release_database() def create_some_folders(self): root_folder = db.Folder( @@ -104,7 +104,6 @@ class DbTestCase(unittest.TestCase): @db_session def test_folder_base(self): root_folder, child_folder = self.create_some_folders() - self.store.commit() MockUser = namedtuple('User', [ 'id' ]) user = MockUser(uuid.uuid4()) @@ -149,7 +148,6 @@ class DbTestCase(unittest.TestCase): rated = root_folder, rating = 5 ) - self.store.commit() root = root_folder.as_subsonic_child(user) self.assertIn('starred', root) @@ -169,7 +167,6 @@ class DbTestCase(unittest.TestCase): user = self.create_user() star = db.StarredArtist(user = user, starred = artist) - self.store.commit() artist_dict = artist.as_subsonic_artist(user) self.assertIsInstance(artist_dict, dict) @@ -183,7 +180,6 @@ class DbTestCase(unittest.TestCase): db.Album(name = 'Test Artist', artist = artist) # self-titled db.Album(name = 'The Album After The First One', artist = artist) - self.store.commit() artist_dict = artist.as_subsonic_artist(user) self.assertEqual(artist_dict['albumCount'], 2) @@ -198,13 +194,11 @@ class DbTestCase(unittest.TestCase): user = user, starred = album ) - self.store.commit() # No tracks, shouldn't be stored under normal circumstances self.assertRaises(ValueError, album.as_subsonic_album, user) self.create_some_tracks(artist, album) - self.store.commit() album_dict = album.as_subsonic_album(user) self.assertIsInstance(album_dict, dict) @@ -227,7 +221,6 @@ class DbTestCase(unittest.TestCase): @db_session def test_track(self): track1, track2 = self.create_some_tracks() - self.store.commit() # Assuming SQLite doesn't enforce foreign key constraints MockUser = namedtuple('User', [ 'id' ]) @@ -245,7 +238,6 @@ class DbTestCase(unittest.TestCase): @db_session def test_user(self): user = self.create_user() - self.store.commit() user_dict = user.as_subsonic_user() self.assertIsInstance(user_dict, dict) @@ -258,7 +250,6 @@ class DbTestCase(unittest.TestCase): user = user, message = 'Hello world!' ) - self.store.commit() line_dict = line.responsize() self.assertIsInstance(line_dict, dict) diff --git a/tests/base/test_scanner.py b/tests/base/test_scanner.py index 719cb09..63d68e5 100644 --- a/tests/base/test_scanner.py +++ b/tests/base/test_scanner.py @@ -24,7 +24,7 @@ from supysonic.scanner import Scanner class ScannerTestCase(unittest.TestCase): def setUp(self): - self.store = db.get_database('sqlite:', True) + db.init_database('sqlite:', True) FolderManager.add('folder', os.path.abspath('tests/assets')) with db_session: @@ -37,7 +37,7 @@ class ScannerTestCase(unittest.TestCase): def tearDown(self): self.scanner.finish() - db.release_database(self.store) + db.release_database() @contextmanager def __temporary_track_copy(self): diff --git a/tests/base/test_watcher.py b/tests/base/test_watcher.py index 63a82a0..672bff7 100644 --- a/tests/base/test_watcher.py +++ b/tests/base/test_watcher.py @@ -21,7 +21,7 @@ from contextlib import contextmanager from pony.orm import db_session from threading import Thread -from supysonic.db import get_database, release_database, Track, Artist +from supysonic.db import init_database, release_database, Track, Artist from supysonic.managers.folder import FolderManager from supysonic.watcher import SupysonicWatcher @@ -42,7 +42,8 @@ class WatcherTestBase(unittest.TestCase): def setUp(self): self.__dbfile = tempfile.mkstemp()[1] dburi = 'sqlite:///' + self.__dbfile - release_database(get_database(dburi, True)) + init_database(dburi, True) + release_database() conf = WatcherTestConfig(dburi) self.__sleep_time = conf.DAEMON['wait_delay'] + 1 @@ -69,9 +70,9 @@ class WatcherTestBase(unittest.TestCase): @contextmanager def _tempdbrebind(self): - db = get_database('sqlite:///' + self.__dbfile) + init_database('sqlite:///' + self.__dbfile) try: yield - finally: release_database(db) + finally: release_database() class NothingToWatchTestCase(WatcherTestBase): def test_spawn_useless_watcher(self): diff --git a/tests/frontend/test_folder.py b/tests/frontend/test_folder.py index 30c9251..e3cd968 100644 --- a/tests/frontend/test_folder.py +++ b/tests/frontend/test_folder.py @@ -11,6 +11,8 @@ import uuid +from pony.orm import db_session + from supysonic.db import Folder from .frontendtestbase import FrontendTestBase @@ -50,20 +52,22 @@ class FolderTestCase(FrontendTestBase): self.assertIn('Add Folder', rv.data) rv = self.client.post('/folder/add', data = { 'name': 'name', 'path': 'tests/assets' }, follow_redirects = True) self.assertIn('created', rv.data) - self.assertEqual(self.store.find(Folder).count(), 1) + with db_session: + self.assertEqual(Folder.select().count(), 1) def test_delete(self): - folder = Folder() - folder.name = 'folder' - folder.path = 'tests/assets' - folder.root = True - self.store.add(folder) - self.store.commit() + with db_session: + folder = Folder( + name = 'folder', + path = 'tests/assets', + root = True + ) self._login('bob', 'B0b') rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True) self.assertIn('There\'s nothing much to see', rv.data) - self.assertEqual(self.store.find(Folder).count(), 1) + with db_session: + self.assertEqual(Folder.select().count(), 1) self._logout() self._login('alice', 'Alic3') @@ -73,15 +77,17 @@ class FolderTestCase(FrontendTestBase): self.assertIn('No such folder', rv.data) rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True) self.assertIn('Music folders', rv.data) - self.assertEqual(self.store.find(Folder).count(), 0) + with db_session: + self.assertEqual(Folder.select().count(), 0) def test_scan(self): - folder = Folder() - folder.name = 'folder' - folder.path = 'tests/assets' - folder.root = True - self.store.add(folder) - self.store.commit() + with db_session: + folder = Folder( + name = 'folder', + path = 'tests/assets', + root = True, + ) + self._login('alice', 'Alic3') rv = self.client.get('/folder/scan/string', follow_redirects = True) diff --git a/tests/frontend/test_login.py b/tests/frontend/test_login.py index b908abe..ccf7904 100644 --- a/tests/frontend/test_login.py +++ b/tests/frontend/test_login.py @@ -12,6 +12,8 @@ import uuid +from pony.orm import db_session + from supysonic.db import User from .frontendtestbase import FrontendTestBase @@ -50,8 +52,9 @@ class LoginTestCase(FrontendTestBase): def test_root_with_valid_session(self): # Root with valid session - with self.client.session_transaction() as sess: - sess['userid'] = self.store.find(User, User.name == 'alice').one().id + with db_session: + with self.client.session_transaction() as sess: + sess['userid'] = User.get(name = 'alice').id rv = self.client.get('/', follow_redirects=True) self.assertIn('alice', rv.data) self.assertIn('Log out', rv.data) diff --git a/tests/frontend/test_playlist.py b/tests/frontend/test_playlist.py index 80e6bed..b52dbfe 100644 --- a/tests/frontend/test_playlist.py +++ b/tests/frontend/test_playlist.py @@ -11,6 +11,8 @@ import uuid +from pony.orm import db_session + from supysonic.db import Folder, Artist, Album, Track, Playlist, User from .frontendtestbase import FrontendTestBase @@ -19,43 +21,34 @@ class PlaylistTestCase(FrontendTestBase): def setUp(self): super(PlaylistTestCase, self).setUp() - folder = Folder() - folder.name = 'Root' - folder.path = 'tests/assets' - folder.root = True + with db_session: + folder = Folder(name = 'Root', path = 'tests/assets', root = True) + artist = Artist(name = 'Artist!') + album = Album(name = 'Album!', artist = artist) - artist = Artist() - artist.name = 'Artist!' + track = Track( + path = 'tests/assets/23bytes', + title = '23bytes', + artist = artist, + album = album, + folder = folder, + root_folder = folder, + duration = 2, + disc = 1, + number = 1, + content_type = 'audio/mpeg', + bitrate = 320, + last_modification = 0 + ) - album = Album() - album.name = 'Album!' - album.artist = artist + playlist = Playlist( + name = 'Playlist!', + user = User.get(name = 'alice') + ) + for _ in range(4): + playlist.add(track) - track = Track() - track.path = 'tests/assets/23bytes' - track.title = '23bytes' - track.artist = artist - track.album = album - track.folder = folder - track.root_folder = folder - track.duration = 2 - track.disc = 1 - track.number = 1 - track.content_type = 'audio/mpeg' - track.bitrate = 320 - track.last_modification = 0 - - playlist = Playlist() - playlist.name = 'Playlist!' - playlist.user = self.store.find(User, User.name == 'alice').one() - for _ in range(4): - playlist.add(track) - - self.store.add(track) - self.store.add(playlist) - self.store.commit() - - self.playlist = playlist + self.playlistid = playlist.id def test_index(self): self._login('alice', 'Alic3') @@ -68,7 +61,7 @@ class PlaylistTestCase(FrontendTestBase): self.assertIn('Invalid', rv.data) rv = self.client.get('/playlist/' + str(uuid.uuid4()), follow_redirects = True) self.assertIn('Unknown', rv.data) - rv = self.client.get('/playlist/' + str(self.playlist.id)) + rv = self.client.get('/playlist/' + str(self.playlistid)) self.assertIn('Playlist!', rv.data) self.assertIn('23bytes', rv.data) self.assertIn('Artist!', rv.data) @@ -80,22 +73,25 @@ class PlaylistTestCase(FrontendTestBase): self.assertIn('Invalid', rv.data) rv = self.client.post('/playlist/' + str(uuid.uuid4()), follow_redirects = True) self.assertIn('Unknown', rv.data) - rv = self.client.post('/playlist/' + str(self.playlist.id), follow_redirects = True) + rv = self.client.post('/playlist/' + str(self.playlistid), follow_redirects = True) self.assertNotIn('updated', rv.data) self.assertIn('not allowed', rv.data) self._logout() self._login('alice', 'Alic3') - rv = self.client.post('/playlist/' + str(self.playlist.id), follow_redirects = True) + rv = self.client.post('/playlist/' + str(self.playlistid), follow_redirects = True) self.assertNotIn('updated', rv.data) self.assertIn('Missing', rv.data) - self.assertEqual(self.playlist.name, 'Playlist!') + with db_session: + self.assertEqual(Playlist[self.playlistid].name, 'Playlist!') - rv = self.client.post('/playlist/' + str(self.playlist.id), data = { 'name': 'abc', 'public': True }, follow_redirects = True) + rv = self.client.post('/playlist/' + str(self.playlistid), data = { 'name': 'abc', 'public': True }, follow_redirects = True) self.assertIn('updated', rv.data) self.assertNotIn('not allowed', rv.data) - self.assertEqual(self.playlist.name, 'abc') - self.assertTrue(self.playlist.public) + with db_session: + playlist = Playlist[self.playlistid] + self.assertEqual(playlist.name, 'abc') + self.assertTrue(playlist.public) def test_delete(self): self._login('bob', 'B0b') @@ -103,15 +99,17 @@ class PlaylistTestCase(FrontendTestBase): self.assertIn('Invalid', rv.data) rv = self.client.get('/playlist/del/' + str(uuid.uuid4()), follow_redirects = True) self.assertIn('Unknown', rv.data) - rv = self.client.get('/playlist/del/' + str(self.playlist.id), follow_redirects = True) + rv = self.client.get('/playlist/del/' + str(self.playlistid), follow_redirects = True) self.assertIn('not allowed', rv.data) - self.assertEqual(self.store.find(Playlist).count(), 1) + with db_session: + self.assertEqual(Playlist.select().count(), 1) self._logout() self._login('alice', 'Alic3') - rv = self.client.get('/playlist/del/' + str(self.playlist.id), follow_redirects = True) + rv = self.client.get('/playlist/del/' + str(self.playlistid), follow_redirects = True) self.assertIn('deleted', rv.data) - self.assertEqual(self.store.find(Playlist).count(), 0) + with db_session: + self.assertEqual(Playlist.select().count(), 0) if __name__ == '__main__': unittest.main() diff --git a/tests/frontend/test_user.py b/tests/frontend/test_user.py index cb695d4..26bd89d 100644 --- a/tests/frontend/test_user.py +++ b/tests/frontend/test_user.py @@ -11,6 +11,8 @@ import uuid +from pony.orm import db_session + from supysonic.db import User, ClientPrefs from .frontendtestbase import FrontendTestBase @@ -19,7 +21,8 @@ class UserTestCase(FrontendTestBase): def setUp(self): super(UserTestCase, self).setUp() - self.users = { u.name: u for u in self.store.find(User) } + with db_session: + self.users = { u.name: u.id for u in User.select() } def test_index(self): self._login('bob', 'B0b') @@ -38,18 +41,15 @@ class UserTestCase(FrontendTestBase): self.assertIn('Invalid', rv.data) rv = self.client.get('/user/' + str(uuid.uuid4()), follow_redirects = True) self.assertIn('No such user', rv.data) - rv = self.client.get('/user/' + str(self.users['bob'].id)) + rv = self.client.get('/user/' + str(self.users['bob'])) self.assertIn('bob', rv.data) self._logout() - prefs = ClientPrefs() - prefs.user_id = self.users['bob'].id - prefs.client_name = 'tests' - self.store.add(prefs) - self.store.commit() + with db_session: + ClientPrefs(user = User[self.users['bob']], client_name = 'tests') self._login('bob', 'B0b') - rv = self.client.get('/user/' + str(self.users['alice'].id), follow_redirects = True) + rv = self.client.get('/user/' + str(self.users['alice']), follow_redirects = True) self.assertIn('There\'s nothing much to see', rv.data) self.assertNotIn('

bob

', rv.data) rv = self.client.get('/user/me') @@ -68,19 +68,19 @@ class UserTestCase(FrontendTestBase): self.client.post('/user/me', data = { 'n_': 'o' }) self.client.post('/user/me', data = { 'inexisting_client': 'setting' }) - prefs = ClientPrefs() - prefs.user_id = self.users['alice'].id - prefs.client_name = 'tests' - self.store.add(prefs) - self.store.commit() + with db_session: + ClientPrefs(user = User[self.users['alice']], client_name = 'tests') rv = self.client.post('/user/me', data = { 'tests_format': 'mp3', 'tests_bitrate': 128 }) self.assertIn('updated', rv.data) - self.assertEqual(prefs.format, 'mp3') - self.assertEqual(prefs.bitrate, 128) + with db_session: + prefs = ClientPrefs[User[self.users['alice']], 'tests'] + self.assertEqual(prefs.format, 'mp3') + self.assertEqual(prefs.bitrate, 128) self.client.post('/user/me', data = { 'tests_delete': 1 }) - self.assertEqual(self.store.find(ClientPrefs).count(), 0) + with db_session: + self.assertEqual(ClientPrefs.select().count(), 0) def test_change_username_get(self): self._login('bob', 'B0b') @@ -93,13 +93,13 @@ class UserTestCase(FrontendTestBase): self.assertIn('Invalid', rv.data) rv = self.client.get('/user/{}/changeusername'.format(uuid.uuid4()), follow_redirects = True) self.assertIn('No such user', rv.data) - self.client.get('/user/{}/changeusername'.format(self.users['bob'].id)) + self.client.get('/user/{}/changeusername'.format(self.users['bob'])) def test_change_username_post(self): self._login('alice', 'Alic3') self.client.post('/user/whatever/changeusername') - path = '/user/{}/changeusername'.format(self.users['bob'].id) + path = '/user/{}/changeusername'.format(self.users['bob']) rv = self.client.post(path, follow_redirects = True) self.assertIn('required', rv.data) rv = self.client.post(path, data = { 'user': 'bob' }, follow_redirects = True) @@ -107,10 +107,13 @@ class UserTestCase(FrontendTestBase): rv = self.client.post(path, data = { 'user': 'b0b', 'admin': 1 }, follow_redirects = True) self.assertIn('updated', rv.data) self.assertIn('b0b', rv.data) - self.assertEqual(self.users['bob'].name, 'b0b') - self.assertTrue(self.users['bob'].admin) + with db_session: + bob = User[self.users['bob']] + self.assertEqual(bob.name, 'b0b') + self.assertTrue(bob.admin) rv = self.client.post(path, data = { 'user': 'alice' }, follow_redirects = True) - self.assertEqual(self.users['bob'].name, 'b0b') + with db_session: + self.assertEqual(User[self.users['bob']].name, 'b0b') def test_change_mail_get(self): self._login('alice', 'Alic3') @@ -126,7 +129,7 @@ class UserTestCase(FrontendTestBase): self._login('alice', 'Alic3') rv = self.client.get('/user/me/changepass') self.assertIn('Current password', rv.data) - rv = self.client.get('/user/{}/changepass'.format(self.users['bob'].id)) + rv = self.client.get('/user/{}/changepass'.format(self.users['bob'])) self.assertNotIn('Current password', rv.data) def test_change_password_post(self): @@ -151,7 +154,7 @@ class UserTestCase(FrontendTestBase): rv = self._login('alice', 'alice') self.assertIn('Logged in', rv.data) - path = '/user/{}/changepass'.format(self.users['bob'].id) + path = '/user/{}/changepass'.format(self.users['bob']) rv = self.client.post(path) self.assertIn('required', rv.data) rv = self.client.post(path, data = { 'new': 'alice' }) @@ -162,7 +165,6 @@ class UserTestCase(FrontendTestBase): rv = self._login('bob', 'alice') self.assertIn('Logged in', rv.data) - def test_add_get(self): self._login('bob', 'B0b') rv = self.client.get('/user/add', follow_redirects = True) @@ -186,22 +188,25 @@ class UserTestCase(FrontendTestBase): self.assertIn('passwords don', rv.data) rv = self.client.post('/user/add', data = { 'user': 'alice', 'passwd': 'passwd', 'passwd_confirm': 'passwd' }) self.assertIn('already a user with that name', rv.data) - self.assertEqual(self.store.find(User).count(), 2) + with db_session: + self.assertEqual(User.select().count(), 2) rv = self.client.post('/user/add', data = { 'user': 'user', 'passwd': 'passwd', 'passwd_confirm': 'passwd', 'admin': 1 }, follow_redirects = True) self.assertIn('added', rv.data) - self.assertEqual(self.store.find(User).count(), 3) + with db_session: + self.assertEqual(User.select().count(), 3) self._logout() rv = self._login('user', 'passwd') self.assertIn('Logged in', rv.data) def test_delete(self): - path = '/user/del/{}'.format(self.users['bob'].id) + path = '/user/del/{}'.format(self.users['bob']) self._login('bob', 'B0b') rv = self.client.get(path, follow_redirects = True) self.assertIn('There\'s nothing much to see', rv.data) - self.assertEqual(self.store.find(User).count(), 2) + with db_session: + self.assertEqual(User.select().count(), 2) self._logout() self._login('alice', 'Alic3') @@ -211,7 +216,8 @@ class UserTestCase(FrontendTestBase): self.assertIn('No such user', rv.data) rv = self.client.get(path, follow_redirects = True) self.assertIn('Deleted', rv.data) - self.assertEqual(self.store.find(User).count(), 1) + with db_session: + self.assertEqual(User.select().count(), 1) self._logout() rv = self._login('bob', 'B0b') self.assertIn('No such user', rv.data) diff --git a/tests/managers/test_manager_folder.py b/tests/managers/test_manager_folder.py index 97e721d..f1a5c5f 100644 --- a/tests/managers/test_manager_folder.py +++ b/tests/managers/test_manager_folder.py @@ -25,12 +25,17 @@ from pony.orm import db_session, ObjectNotFound class FolderManagerTestCase(unittest.TestCase): def setUp(self): # Create an empty sqlite database in memory - self.store = db.get_database('sqlite:', True) + db.init_database('sqlite:', True) # Create some temporary directories self.media_dir = tempfile.mkdtemp() self.music_dir = tempfile.mkdtemp() + def tearDown(self): + db.release_database() + shutil.rmtree(self.media_dir) + shutil.rmtree(self.music_dir) + @db_session def create_folders(self): # Add test folders @@ -62,11 +67,6 @@ class FolderManagerTestCase(unittest.TestCase): last_modification = 0 ) - def tearDown(self): - db.release_database(self.store) - shutil.rmtree(self.media_dir) - shutil.rmtree(self.music_dir) - @db_session def test_get_folder(self): self.create_folders() diff --git a/tests/managers/test_manager_user.py b/tests/managers/test_manager_user.py index 9cf2cba..632dc22 100644 --- a/tests/managers/test_manager_user.py +++ b/tests/managers/test_manager_user.py @@ -17,13 +17,16 @@ import io import unittest import uuid -from pony.orm import db_session +from pony.orm import db_session, commit from pony.orm import ObjectNotFound class UserManagerTestCase(unittest.TestCase): def setUp(self): # Create an empty sqlite database in memory - self.store = db.get_database('sqlite:', True) + db.init_database('sqlite:', True) + + def tearDown(self): + db.release_database() @db_session def create_data(self): @@ -56,9 +59,6 @@ class UserManagerTestCase(unittest.TestCase): ) playlist.add(track) - def tearDown(self): - db.release_database(self.store) - def test_encrypt_password(self): func = UserManager._UserManager__encrypt_password self.assertEqual(func(u'password',u'salt'), (u'59b3e8d637cf97edbe2384cf59cb7453dfe30789', u'salt')) @@ -107,7 +107,7 @@ class UserManagerTestCase(unittest.TestCase): user = db.User.get(name = name) self.assertEqual(UserManager.delete(user.id), UserManager.SUCCESS) self.assertRaises(ObjectNotFound, db.User.__getitem__, user.id) - self.store.commit() + commit() self.assertEqual(db.User.select().count(), 0) @db_session diff --git a/tests/testbase.py b/tests/testbase.py index bd4c2aa..1ae5330 100644 --- a/tests/testbase.py +++ b/tests/testbase.py @@ -10,22 +10,20 @@ import inspect import io +import os import shutil import sys import unittest import tempfile +from supysonic.db import init_database, release_database from supysonic.config import DefaultConfig from supysonic.managers.user import UserManager -from supysonic.web import create_application, store +from supysonic.web import create_application class TestConfig(DefaultConfig): TESTING = True LOGGER_HANDLER_POLICY = 'never' - BASE = { - 'database_uri': 'sqlite:', - 'scanner_extensions': None - } MIMETYPES = { 'mp3': 'audio/mpeg', 'weirdextension': 'application/octet-stream' @@ -60,31 +58,37 @@ class TestBase(unittest.TestCase): __with_api__ = False def setUp(self): + self.__dbfile = tempfile.mkstemp()[1] self.__dir = tempfile.mkdtemp() config = TestConfig(self.__with_webui__, self.__with_api__) + config.BASE['database_uri'] = 'sqlite:///' + self.__dbfile config.WEBAPP['cache_dir'] = self.__dir - app = create_application(config) - self.__ctx = app.app_context() - self.__ctx.push() + init_database(config.BASE['database_uri'], True) + release_database() - self.store = store - with io.open('schema/sqlite.sql', 'r') as sql: - schema = sql.read() - for statement in schema.split(';'): - self.store.execute(statement) - self.store.commit() + app = create_application(config) + #self.__ctx = app.app_context() + #self.__ctx.push() self.client = app.test_client() - UserManager.add(self.store, 'alice', 'Alic3', 'test@example.com', True) - UserManager.add(self.store, 'bob', 'B0b', 'bob@example.com', False) + UserManager.add('alice', 'Alic3', 'test@example.com', True) + UserManager.add('bob', 'B0b', 'bob@example.com', False) + + @staticmethod + def __should_unload_module(module): + if module.startswith('supysonic'): + return not module.startswith('supysonic.db') + return False def tearDown(self): - self.__ctx.pop() + #self.__ctx.pop() + release_database() shutil.rmtree(self.__dir) + os.remove(self.__dbfile) - to_unload = [ m for m in sys.modules if m.startswith('supysonic') ] + to_unload = [ m for m in sorted(sys.modules) if self.__should_unload_module(m) ] for m in to_unload: del sys.modules[m]