1
0
mirror of https://github.com/spl0k/supysonic.git synced 2024-12-22 17:06:17 +00:00

I'm on a pony

This commit is contained in:
spl0k 2017-12-19 23:16:55 +01:00
parent 2428ffeb57
commit 6daedc6919
20 changed files with 230 additions and 231 deletions

View File

@ -12,11 +12,11 @@ import sys
from supysonic.cli import SupysonicCLI from supysonic.cli import SupysonicCLI
from supysonic.config import IniConfig from supysonic.config import IniConfig
from supysonic.db import get_database, release_database from supysonic.db import init_database, release_database
if __name__ == "__main__": if __name__ == "__main__":
config = IniConfig.from_common_locations() config = IniConfig.from_common_locations()
db = get_database(config.BASE['database_uri']) init_database(config.BASE['database_uri'])
cli = SupysonicCLI(config) cli = SupysonicCLI(config)
if len(sys.argv) > 1: if len(sys.argv) > 1:
@ -24,5 +24,5 @@ if __name__ == "__main__":
else: else:
cli.cmdloop() cli.cmdloop()
release_database(db) release_database()

View File

@ -20,7 +20,7 @@ class DefaultConfig(object):
tempdir = os.path.join(tempfile.gettempdir(), 'supysonic') tempdir = os.path.join(tempfile.gettempdir(), 'supysonic')
BASE = { BASE = {
'database_uri': 'sqlite://' + os.path.join(tempdir, 'supysonic.db'), 'database_uri': 'sqlite:///' + os.path.join(tempdir, 'supysonic.db'),
'scanner_extensions': None 'scanner_extensions': None
} }
WEBAPP = { WEBAPP = {

View File

@ -248,7 +248,7 @@ class User(db.Entity):
password = Required(str) password = Required(str)
salt = Required(str) salt = Required(str)
admin = Required(bool, default = False) admin = Required(bool, default = False)
lastfm_session = Optional(str) lastfm_session = Optional(str, nullable = True)
lastfm_status = Required(bool, default = True) # True: ok/unlinked, False: invalid session lastfm_status = Required(bool, default = True) # True: ok/unlinked, False: invalid session
last_play = Optional(Track, column = 'last_play_id') last_play = Optional(Track, column = 'last_play_id')
@ -450,15 +450,11 @@ def parse_uri(database_uri):
return dict(provider = 'mysql', user = uri.username, passwd = uri.password, host = uri.hostname, db = uri.path[1:]) return dict(provider = 'mysql', user = uri.username, passwd = uri.password, host = uri.hostname, db = uri.path[1:])
return dict() return dict()
def get_database(database_uri, create_tables = False): def init_database(database_uri, create_tables = False):
db.bind(**parse_uri(database_uri)) db.bind(**parse_uri(database_uri))
db.generate_mapping(create_tables = create_tables) db.generate_mapping(create_tables = create_tables)
return db
def release_database(db):
if not isinstance(db, Database):
raise TypeError('Expecting a pony.orm.Database instance')
def release_database():
db.disconnect() db.disconnect()
db.provider = None db.provider = None
db.schema = None db.schema = None

View File

@ -11,8 +11,8 @@
from flask import session, request, redirect, url_for, current_app as app from flask import session, request, redirect, url_for, current_app as app
from functools import wraps from functools import wraps
from pony.orm import db_session
from ..web import store
from ..db import Artist, Album, Track from ..db import Artist, Album, Track
from ..managers.user import UserManager from ..managers.user import UserManager
@ -27,7 +27,7 @@ def login_check():
request.user = None request.user = None
should_login = True should_login = True
if session.get('userid'): if session.get('userid'):
code, user = UserManager.get(store, session.get('userid')) code, user = UserManager.get(session.get('userid'))
if code != UserManager.SUCCESS: if code != UserManager.SUCCESS:
session.clear() session.clear()
else: else:
@ -39,13 +39,14 @@ def login_check():
return redirect(url_for('login', returnUrl = request.script_root + request.url[len(request.url_root)-1:])) return redirect(url_for('login', returnUrl = request.script_root + request.url[len(request.url_root)-1:]))
@app.route('/') @app.route('/')
@db_session
def index(): def index():
stats = { stats = {
'artists': store.find(Artist).count(), 'artists': Artist.select().count(),
'albums': store.find(Album).count(), 'albums': Album.select().count(),
'tracks': store.find(Track).count() 'tracks': Track.select().count()
} }
return render_template('home.html', stats = stats, admin = UserManager.get(store, session.get('userid'))[1].admin) return render_template('home.html', stats = stats)
def admin_only(f): def admin_only(f):
@wraps(f) @wraps(f)

View File

@ -22,19 +22,20 @@ import os.path
import uuid import uuid
from flask import request, flash, render_template, redirect, url_for, current_app as app from flask import request, flash, render_template, redirect, url_for, current_app as app
from pony.orm import db_session
from ..db import Folder from ..db import Folder
from ..managers.user import UserManager from ..managers.user import UserManager
from ..managers.folder import FolderManager from ..managers.folder import FolderManager
from ..scanner import Scanner from ..scanner import Scanner
from ..web import store
from . import admin_only from . import admin_only
@app.route('/folder') @app.route('/folder')
@admin_only @admin_only
@db_session
def folder_index(): def folder_index():
return render_template('folders.html', folders = store.find(Folder, Folder.root == True)) return render_template('folders.html', folders = Folder.select(lambda f: f.root))
@app.route('/folder/add') @app.route('/folder/add')
@admin_only @admin_only
@ -55,7 +56,7 @@ def add_folder_post():
if error: if error:
return render_template('addfolder.html') return render_template('addfolder.html')
ret = FolderManager.add(store, name, path) ret = FolderManager.add(name, path)
if ret != FolderManager.SUCCESS: if ret != FolderManager.SUCCESS:
flash(FolderManager.error_str(ret)) flash(FolderManager.error_str(ret))
return render_template('addfolder.html') return render_template('addfolder.html')
@ -73,7 +74,7 @@ def del_folder(id):
flash('Invalid folder id') flash('Invalid folder id')
return redirect(url_for('folder_index')) return redirect(url_for('folder_index'))
ret = FolderManager.delete(store, idid) ret = FolderManager.delete(idid)
if ret != FolderManager.SUCCESS: if ret != FolderManager.SUCCESS:
flash(FolderManager.error_str(ret)) flash(FolderManager.error_str(ret))
else: else:
@ -84,16 +85,19 @@ def del_folder(id):
@app.route('/folder/scan') @app.route('/folder/scan')
@app.route('/folder/scan/<id>') @app.route('/folder/scan/<id>')
@admin_only @admin_only
@db_session
def scan_folder(id = None): def scan_folder(id = None):
extensions = app.config['BASE']['scanner_extensions'] extensions = app.config['BASE']['scanner_extensions']
if extensions: if extensions:
extensions = extensions.split(' ') extensions = extensions.split(' ')
scanner = Scanner(store, extensions = extensions)
scanner = Scanner(extensions = extensions)
if id is None: if id is None:
for folder in store.find(Folder, Folder.root == True): for folder in Folder.select(lambda f: f.root):
scanner.scan(folder) scanner.scan(folder)
else: else:
status, folder = FolderManager.get(store, id) status, folder = FolderManager.get(id)
if status != FolderManager.SUCCESS: if status != FolderManager.SUCCESS:
flash(FolderManager.error_str(status)) flash(FolderManager.error_str(status))
return redirect(url_for('folder_index')) return redirect(url_for('folder_index'))
@ -101,7 +105,6 @@ def scan_folder(id = None):
scanner.finish() scanner.finish()
added, deleted = scanner.stats() added, deleted = scanner.stats()
store.commit()
flash('Added: %i artists, %i albums, %i tracks' % (added[0], added[1], added[2])) flash('Added: %i artists, %i albums, %i tracks' % (added[0], added[1], added[2]))
flash('Deleted: %i artists, %i albums, %i tracks' % (deleted[0], deleted[1], deleted[2])) flash('Deleted: %i artists, %i albums, %i tracks' % (deleted[0], deleted[1], deleted[2]))

View File

@ -21,33 +21,38 @@
import uuid import uuid
from flask import request, flash, render_template, redirect, url_for, current_app as app from flask import request, flash, render_template, redirect, url_for, current_app as app
from pony.orm import db_session
from pony.orm import ObjectNotFound
from ..web import store
from ..db import Playlist from ..db import Playlist
from ..managers.user import UserManager from ..managers.user import UserManager
@app.route('/playlist') @app.route('/playlist')
@db_session
def playlist_index(): def playlist_index():
return render_template('playlists.html', return render_template('playlists.html',
mine = store.find(Playlist, Playlist.user_id == request.user.id), mine = Playlist.select(lambda p: p.user == request.user),
others = store.find(Playlist, Playlist.user_id != request.user.id, Playlist.public == True)) others = Playlist.select(lambda p: p.user != request.user and p.public))
@app.route('/playlist/<uid>') @app.route('/playlist/<uid>')
@db_session
def playlist_details(uid): def playlist_details(uid):
try: try:
uid = uuid.UUID(uid) if type(uid) in (str, unicode) else uid uid = uuid.UUID(uid)
except: except:
flash('Invalid playlist id') flash('Invalid playlist id')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
playlist = store.get(Playlist, uid) try:
if not playlist: playlist = Playlist[uid]
except ObjectNotFound:
flash('Unknown playlist') flash('Unknown playlist')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
return render_template('playlist.html', playlist = playlist) return render_template('playlist.html', playlist = playlist)
@app.route('/playlist/<uid>', methods = [ 'POST' ]) @app.route('/playlist/<uid>', methods = [ 'POST' ])
@db_session
def playlist_update(uid): def playlist_update(uid):
try: try:
uid = uuid.UUID(uid) uid = uuid.UUID(uid)
@ -55,24 +60,25 @@ def playlist_update(uid):
flash('Invalid playlist id') flash('Invalid playlist id')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
playlist = store.get(Playlist, uid) try:
if not playlist: playlist = Playlist[uid]
except ObjectNotFound:
flash('Unknown playlist') flash('Unknown playlist')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
if playlist.user_id != request.user.id: if playlist.user.id != request.user.id:
flash("You're not allowed to edit this playlist") flash("You're not allowed to edit this playlist")
elif not request.form.get('name'): elif not request.form.get('name'):
flash('Missing playlist name') flash('Missing playlist name')
else: else:
playlist.name = request.form.get('name') playlist.name = request.form.get('name')
playlist.public = request.form.get('public') in (True, 'True', 1, '1', 'on', 'checked') playlist.public = request.form.get('public') in (True, 'True', 1, '1', 'on', 'checked')
store.commit()
flash('Playlist updated.') flash('Playlist updated.')
return playlist_details(uid) return playlist_details(uid)
@app.route('/playlist/del/<uid>') @app.route('/playlist/del/<uid>')
@db_session
def playlist_delete(uid): def playlist_delete(uid):
try: try:
uid = uuid.UUID(uid) uid = uuid.UUID(uid)
@ -80,14 +86,16 @@ def playlist_delete(uid):
flash('Invalid playlist id') flash('Invalid playlist id')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))
playlist = store.get(Playlist, uid) try:
if not playlist: playlist = Playlist[uid]
except ObjectNotFound:
flash('Unknown playlist') flash('Unknown playlist')
elif playlist.user_id != request.user.id: return redirect(url_for('playlist_index'))
if playlist.user.id != request.user.id:
flash("You're not allowed to delete this playlist") flash("You're not allowed to delete this playlist")
else: else:
store.remove(playlist) playlist.delete()
store.commit()
flash('Playlist deleted') flash('Playlist deleted')
return redirect(url_for('playlist_index')) return redirect(url_for('playlist_index'))

View File

@ -20,15 +20,16 @@
from flask import request, session, flash, render_template, redirect, url_for, current_app as app from flask import request, session, flash, render_template, redirect, url_for, current_app as app
from functools import wraps from functools import wraps
from pony.orm import db_session
from ..db import User, ClientPrefs from ..db import User, ClientPrefs
from ..lastfm import LastFm from ..lastfm import LastFm
from ..managers.user import UserManager from ..managers.user import UserManager
from ..web import store
from . import admin_only from . import admin_only
def me_or_uuid(f, arg = 'uid'): def me_or_uuid(f, arg = 'uid'):
@db_session
@wraps(f) @wraps(f)
def decorated_func(*args, **kwargs): def decorated_func(*args, **kwargs):
if kwargs: if kwargs:
@ -37,11 +38,11 @@ def me_or_uuid(f, arg = 'uid'):
uid = args[0] uid = args[0]
if uid == 'me': if uid == 'me':
user = request.user user = User[request.user.id] # Refetch user from previous transaction
elif not request.user.admin: elif not request.user.admin:
return redirect(url_for('index')) return redirect(url_for('index'))
else: else:
code, user = UserManager.get(store, uid) code, user = UserManager.get(uid)
if code != UserManager.SUCCESS: if code != UserManager.SUCCESS:
flash(UserManager.error_str(code)) flash(UserManager.error_str(code))
return redirect(url_for('index')) return redirect(url_for('index'))
@ -57,14 +58,14 @@ def me_or_uuid(f, arg = 'uid'):
@app.route('/user') @app.route('/user')
@admin_only @admin_only
@db_session
def user_index(): def user_index():
return render_template('users.html', users = store.find(User)) return render_template('users.html', users = User.select())
@app.route('/user/<uid>') @app.route('/user/<uid>')
@me_or_uuid @me_or_uuid
def user_profile(uid, user): def user_profile(uid, user):
prefs = store.find(ClientPrefs, ClientPrefs.user_id == user.id) return render_template('profile.html', user = user, has_lastfm = app.config['LASTFM']['api_key'] != None, clients = user.clients)
return render_template('profile.html', user = user, has_lastfm = app.config['LASTFM']['api_key'] != None, clients = prefs)
@app.route('/user/<uid>', methods = [ 'POST' ]) @app.route('/user/<uid>', methods = [ 'POST' ])
@me_or_uuid @me_or_uuid
@ -87,25 +88,24 @@ def update_clients(uid, user):
app.logger.debug(clients_opts) app.logger.debug(clients_opts)
for client, opts in clients_opts.iteritems(): for client, opts in clients_opts.iteritems():
prefs = store.get(ClientPrefs, (user.id, client)) prefs = user.clients.select(lambda c: c.client_name == client).first()
if not prefs: if prefs is None:
continue continue
if 'delete' in opts and opts['delete'] in [ 'on', 'true', 'checked', 'selected', '1' ]: if 'delete' in opts and opts['delete'] in [ 'on', 'true', 'checked', 'selected', '1' ]:
store.remove(prefs) prefs.delete()
continue continue
prefs.format = opts['format'] if 'format' in opts and opts['format'] else None prefs.format = opts['format'] if 'format' in opts and opts['format'] else None
prefs.bitrate = int(opts['bitrate']) if 'bitrate' in opts and opts['bitrate'] else None prefs.bitrate = int(opts['bitrate']) if 'bitrate' in opts and opts['bitrate'] else None
store.commit()
flash('Clients preferences updated.') flash('Clients preferences updated.')
return user_profile(uid, user) return user_profile(uid, user)
@app.route('/user/<uid>/changeusername') @app.route('/user/<uid>/changeusername')
@admin_only @admin_only
def change_username_form(uid): def change_username_form(uid):
code, user = UserManager.get(store, uid) code, user = UserManager.get(uid)
if code != UserManager.SUCCESS: if code != UserManager.SUCCESS:
flash(UserManager.error_str(code)) flash(UserManager.error_str(code))
return redirect(url_for('index')) return redirect(url_for('index'))
@ -114,8 +114,9 @@ def change_username_form(uid):
@app.route('/user/<uid>/changeusername', methods = [ 'POST' ]) @app.route('/user/<uid>/changeusername', methods = [ 'POST' ])
@admin_only @admin_only
@db_session
def change_username_post(uid): def change_username_post(uid):
code, user = UserManager.get(store, uid) code, user = UserManager.get(uid)
if code != UserManager.SUCCESS: if code != UserManager.SUCCESS:
return redirect(url_for('index')) return redirect(url_for('index'))
@ -123,7 +124,7 @@ def change_username_post(uid):
if username in ('', None): if username in ('', None):
flash('The username is required') flash('The username is required')
return render_template('change_username.html', user = user) return render_template('change_username.html', user = user)
if user.name != username and store.find(User, User.name == username).one(): if user.name != username and User.get(name = username) is not None:
flash('This name is already taken') flash('This name is already taken')
return render_template('change_username.html', user = user) return render_template('change_username.html', user = user)
@ -135,7 +136,6 @@ def change_username_post(uid):
if user.name != username or user.admin != admin: if user.name != username or user.admin != admin:
user.name = username user.name = username
user.admin = admin user.admin = admin
store.commit()
flash("User '%s' updated." % username) flash("User '%s' updated." % username)
else: else:
flash("No changes for '%s'." % username) flash("No changes for '%s'." % username)
@ -150,10 +150,9 @@ def change_mail_form(uid, user):
@app.route('/user/<uid>/changemail', methods = [ 'POST' ]) @app.route('/user/<uid>/changemail', methods = [ 'POST' ])
@me_or_uuid @me_or_uuid
def change_mail_post(uid, user): def change_mail_post(uid, user):
mail = request.form.get('mail') mail = request.form.get('mail', '')
# No validation, lol. # No validation, lol.
user.mail = mail user.mail = mail
store.commit()
return redirect(url_for('user_profile', uid = uid)) return redirect(url_for('user_profile', uid = uid))
@app.route('/user/<uid>/changepass') @app.route('/user/<uid>/changepass')
@ -182,9 +181,9 @@ def change_password_post(uid, user):
if not error: if not error:
if user.id == request.user.id: if user.id == request.user.id:
status = UserManager.change_password(store, user.id, current, new) status = UserManager.change_password(user.id, current, new)
else: else:
status = UserManager.change_password2(store, user.name, new) status = UserManager.change_password2(user.name, new)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
flash(UserManager.error_str(status)) flash(UserManager.error_str(status))
@ -214,13 +213,12 @@ def add_user_post():
flash("The passwords don't match.") flash("The passwords don't match.")
error = True error = True
if admin is None: admin = admin is not None
admin = True if store.find(User, User.admin == True).count() == 0 else False if mail is None:
else: mail = ''
admin = True
if not error: if not error:
status = UserManager.add(store, name, passwd, mail, admin) status = UserManager.add(name, passwd, mail, admin)
if status == UserManager.SUCCESS: if status == UserManager.SUCCESS:
flash("User '%s' successfully added" % name) flash("User '%s' successfully added" % name)
return redirect(url_for('user_index')) return redirect(url_for('user_index'))
@ -232,7 +230,7 @@ def add_user_post():
@app.route('/user/del/<uid>') @app.route('/user/del/<uid>')
@admin_only @admin_only
def del_user(uid): def del_user(uid):
status = UserManager.delete(store, uid) status = UserManager.delete(uid)
if status == UserManager.SUCCESS: if status == UserManager.SUCCESS:
flash('Deleted user') flash('Deleted user')
else: else:
@ -250,7 +248,6 @@ def lastfm_reg(uid, user):
lfm = LastFm(app.config['LASTFM'], user, app.logger) lfm = LastFm(app.config['LASTFM'], user, app.logger)
status, error = lfm.link_account(token) status, error = lfm.link_account(token)
store.commit()
flash(error if not status else 'Successfully linked LastFM account') flash(error if not status else 'Successfully linked LastFM account')
return redirect(url_for('user_profile', uid = uid)) return redirect(url_for('user_profile', uid = uid))
@ -260,7 +257,6 @@ def lastfm_reg(uid, user):
def lastfm_unreg(uid, user): def lastfm_unreg(uid, user):
lfm = LastFm(app.config['LASTFM'], user, app.logger) lfm = LastFm(app.config['LASTFM'], user, app.logger)
lfm.unlink_account() lfm.unlink_account()
store.commit()
flash('Unlinked LastFM account') flash('Unlinked LastFM account')
return redirect(url_for('user_profile', uid = uid)) return redirect(url_for('user_profile', uid = uid))
@ -284,7 +280,7 @@ def login():
error = True error = True
if not error: if not error:
status, user = UserManager.try_auth(store, name, password) status, user = UserManager.try_auth(name, password)
if status == UserManager.SUCCESS: if status == UserManager.SUCCESS:
session['userid'] = str(user.id) session['userid'] = str(user.id)
flash('Logged in!') flash('Logged in!')

View File

@ -28,7 +28,7 @@ from threading import Thread, Condition, Timer
from watchdog.observers import Observer from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler from watchdog.events import PatternMatchingEventHandler
from .db import get_database, release_database, Folder from .db import init_database, release_database, Folder
from .scanner import Scanner from .scanner import Scanner
OP_SCAN = 1 OP_SCAN = 1
@ -205,7 +205,7 @@ class SupysonicWatcher(object):
def __init__(self, config): def __init__(self, config):
self.__config = config self.__config = config
self.__running = True self.__running = True
self.__db = get_database(config.BASE['database_uri']) init_database(config.BASE['database_uri'])
def run(self): def run(self):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -230,7 +230,7 @@ class SupysonicWatcher(object):
shouldrun = folders.exists() shouldrun = folders.exists()
if not shouldrun: if not shouldrun:
logger.info("No folder set. Exiting.") logger.info("No folder set. Exiting.")
release_database(self.__db) release_database()
return return
queue = ScannerProcessingQueue(self.__config.DAEMON['wait_delay'], logger) queue = ScannerProcessingQueue(self.__config.DAEMON['wait_delay'], logger)
@ -258,7 +258,7 @@ class SupysonicWatcher(object):
observer.join() observer.join()
queue.stop() queue.stop()
queue.join() queue.join()
release_database(self.__db) release_database()
def stop(self): def stop(self):
self.__running = False self.__running = False

View File

@ -11,25 +11,11 @@
import mimetypes import mimetypes
from flask import Flask, g, current_app from flask import Flask
from os import makedirs, path from os import makedirs, path
from werkzeug.local import LocalProxy
from .config import IniConfig from .config import IniConfig
from .db import get_store from .db import init_database, release_database
# Supysonic database open
def get_db():
if not hasattr(g, 'database'):
g.database = get_store(current_app.config['BASE']['database_uri'])
return g.database
# Supysonic database close
def close_db(error):
if hasattr(g, 'database'):
g.database.close()
store = LocalProxy(get_db)
def create_application(config = None): def create_application(config = None):
global app global app
@ -42,9 +28,6 @@ def create_application(config = None):
config = IniConfig.from_common_locations() config = IniConfig.from_common_locations()
app.config.from_object(config) app.config.from_object(config)
# Close database connection on teardown
app.teardown_appcontext(close_db)
# Set loglevel # Set loglevel
logfile = app.config['WEBAPP']['log_file'] logfile = app.config['WEBAPP']['log_file']
if logfile: if logfile:
@ -63,6 +46,9 @@ def create_application(config = None):
handler.setLevel(mapping.get(loglevel.upper(), logging.NOTSET)) handler.setLevel(mapping.get(loglevel.upper(), logging.NOTSET))
app.logger.addHandler(handler) app.logger.addHandler(handler)
# Initialize database
init_database(app.config['BASE']['database_uri'])
# Insert unknown mimetypes # Insert unknown mimetypes
for k, v in app.config['MIMETYPES'].iteritems(): for k, v in app.config['MIMETYPES'].iteritems():
extension = '.' + k.lower() extension = '.' + k.lower()

View File

@ -18,7 +18,7 @@ from contextlib import contextmanager
from pony.orm import db_session from pony.orm import db_session
from StringIO import StringIO from StringIO import StringIO
from supysonic.db import Folder, User, get_database, release_database from supysonic.db import Folder, User, init_database, release_database
from supysonic.cli import SupysonicCLI from supysonic.cli import SupysonicCLI
from ..testbase import TestConfig from ..testbase import TestConfig
@ -30,7 +30,7 @@ class CLITestCase(unittest.TestCase):
conf = TestConfig(False, False) conf = TestConfig(False, False)
self.__dbfile = tempfile.mkstemp()[1] self.__dbfile = tempfile.mkstemp()[1]
conf.BASE['database_uri'] = 'sqlite:///' + self.__dbfile conf.BASE['database_uri'] = 'sqlite:///' + self.__dbfile
self.__store = get_database(conf.BASE['database_uri'], True) init_database(conf.BASE['database_uri'], True)
self.__stdout = StringIO() self.__stdout = StringIO()
self.__stderr = StringIO() self.__stderr = StringIO()
@ -39,7 +39,7 @@ class CLITestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.__stdout.close() self.__stdout.close()
self.__stderr.close() self.__stderr.close()
release_database(self.__store) release_database()
os.unlink(self.__dbfile) os.unlink(self.__dbfile)
@contextmanager @contextmanager

View File

@ -22,10 +22,10 @@ date_regex = re.compile(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$')
class DbTestCase(unittest.TestCase): class DbTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = db.get_database('sqlite:', True) db.init_database('sqlite:', True)
def tearDown(self): def tearDown(self):
db.release_database(self.store) db.release_database()
def create_some_folders(self): def create_some_folders(self):
root_folder = db.Folder( root_folder = db.Folder(
@ -104,7 +104,6 @@ class DbTestCase(unittest.TestCase):
@db_session @db_session
def test_folder_base(self): def test_folder_base(self):
root_folder, child_folder = self.create_some_folders() root_folder, child_folder = self.create_some_folders()
self.store.commit()
MockUser = namedtuple('User', [ 'id' ]) MockUser = namedtuple('User', [ 'id' ])
user = MockUser(uuid.uuid4()) user = MockUser(uuid.uuid4())
@ -149,7 +148,6 @@ class DbTestCase(unittest.TestCase):
rated = root_folder, rated = root_folder,
rating = 5 rating = 5
) )
self.store.commit()
root = root_folder.as_subsonic_child(user) root = root_folder.as_subsonic_child(user)
self.assertIn('starred', root) self.assertIn('starred', root)
@ -169,7 +167,6 @@ class DbTestCase(unittest.TestCase):
user = self.create_user() user = self.create_user()
star = db.StarredArtist(user = user, starred = artist) star = db.StarredArtist(user = user, starred = artist)
self.store.commit()
artist_dict = artist.as_subsonic_artist(user) artist_dict = artist.as_subsonic_artist(user)
self.assertIsInstance(artist_dict, dict) self.assertIsInstance(artist_dict, dict)
@ -183,7 +180,6 @@ class DbTestCase(unittest.TestCase):
db.Album(name = 'Test Artist', artist = artist) # self-titled db.Album(name = 'Test Artist', artist = artist) # self-titled
db.Album(name = 'The Album After The First One', artist = artist) db.Album(name = 'The Album After The First One', artist = artist)
self.store.commit()
artist_dict = artist.as_subsonic_artist(user) artist_dict = artist.as_subsonic_artist(user)
self.assertEqual(artist_dict['albumCount'], 2) self.assertEqual(artist_dict['albumCount'], 2)
@ -198,13 +194,11 @@ class DbTestCase(unittest.TestCase):
user = user, user = user,
starred = album starred = album
) )
self.store.commit()
# No tracks, shouldn't be stored under normal circumstances # No tracks, shouldn't be stored under normal circumstances
self.assertRaises(ValueError, album.as_subsonic_album, user) self.assertRaises(ValueError, album.as_subsonic_album, user)
self.create_some_tracks(artist, album) self.create_some_tracks(artist, album)
self.store.commit()
album_dict = album.as_subsonic_album(user) album_dict = album.as_subsonic_album(user)
self.assertIsInstance(album_dict, dict) self.assertIsInstance(album_dict, dict)
@ -227,7 +221,6 @@ class DbTestCase(unittest.TestCase):
@db_session @db_session
def test_track(self): def test_track(self):
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
self.store.commit()
# Assuming SQLite doesn't enforce foreign key constraints # Assuming SQLite doesn't enforce foreign key constraints
MockUser = namedtuple('User', [ 'id' ]) MockUser = namedtuple('User', [ 'id' ])
@ -245,7 +238,6 @@ class DbTestCase(unittest.TestCase):
@db_session @db_session
def test_user(self): def test_user(self):
user = self.create_user() user = self.create_user()
self.store.commit()
user_dict = user.as_subsonic_user() user_dict = user.as_subsonic_user()
self.assertIsInstance(user_dict, dict) self.assertIsInstance(user_dict, dict)
@ -258,7 +250,6 @@ class DbTestCase(unittest.TestCase):
user = user, user = user,
message = 'Hello world!' message = 'Hello world!'
) )
self.store.commit()
line_dict = line.responsize() line_dict = line.responsize()
self.assertIsInstance(line_dict, dict) self.assertIsInstance(line_dict, dict)

View File

@ -24,7 +24,7 @@ from supysonic.scanner import Scanner
class ScannerTestCase(unittest.TestCase): class ScannerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = db.get_database('sqlite:', True) db.init_database('sqlite:', True)
FolderManager.add('folder', os.path.abspath('tests/assets')) FolderManager.add('folder', os.path.abspath('tests/assets'))
with db_session: with db_session:
@ -37,7 +37,7 @@ class ScannerTestCase(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.scanner.finish() self.scanner.finish()
db.release_database(self.store) db.release_database()
@contextmanager @contextmanager
def __temporary_track_copy(self): def __temporary_track_copy(self):

View File

@ -21,7 +21,7 @@ from contextlib import contextmanager
from pony.orm import db_session from pony.orm import db_session
from threading import Thread from threading import Thread
from supysonic.db import get_database, release_database, Track, Artist from supysonic.db import init_database, release_database, Track, Artist
from supysonic.managers.folder import FolderManager from supysonic.managers.folder import FolderManager
from supysonic.watcher import SupysonicWatcher from supysonic.watcher import SupysonicWatcher
@ -42,7 +42,8 @@ class WatcherTestBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.__dbfile = tempfile.mkstemp()[1] self.__dbfile = tempfile.mkstemp()[1]
dburi = 'sqlite:///' + self.__dbfile dburi = 'sqlite:///' + self.__dbfile
release_database(get_database(dburi, True)) init_database(dburi, True)
release_database()
conf = WatcherTestConfig(dburi) conf = WatcherTestConfig(dburi)
self.__sleep_time = conf.DAEMON['wait_delay'] + 1 self.__sleep_time = conf.DAEMON['wait_delay'] + 1
@ -69,9 +70,9 @@ class WatcherTestBase(unittest.TestCase):
@contextmanager @contextmanager
def _tempdbrebind(self): def _tempdbrebind(self):
db = get_database('sqlite:///' + self.__dbfile) init_database('sqlite:///' + self.__dbfile)
try: yield try: yield
finally: release_database(db) finally: release_database()
class NothingToWatchTestCase(WatcherTestBase): class NothingToWatchTestCase(WatcherTestBase):
def test_spawn_useless_watcher(self): def test_spawn_useless_watcher(self):

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import Folder from supysonic.db import Folder
from .frontendtestbase import FrontendTestBase from .frontendtestbase import FrontendTestBase
@ -50,20 +52,22 @@ class FolderTestCase(FrontendTestBase):
self.assertIn('Add Folder', rv.data) self.assertIn('Add Folder', rv.data)
rv = self.client.post('/folder/add', data = { 'name': 'name', 'path': 'tests/assets' }, follow_redirects = True) rv = self.client.post('/folder/add', data = { 'name': 'name', 'path': 'tests/assets' }, follow_redirects = True)
self.assertIn('created', rv.data) self.assertIn('created', rv.data)
self.assertEqual(self.store.find(Folder).count(), 1) with db_session:
self.assertEqual(Folder.select().count(), 1)
def test_delete(self): def test_delete(self):
folder = Folder() with db_session:
folder.name = 'folder' folder = Folder(
folder.path = 'tests/assets' name = 'folder',
folder.root = True path = 'tests/assets',
self.store.add(folder) root = True
self.store.commit() )
self._login('bob', 'B0b') self._login('bob', 'B0b')
rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True) rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True)
self.assertIn('There\'s nothing much to see', rv.data) self.assertIn('There\'s nothing much to see', rv.data)
self.assertEqual(self.store.find(Folder).count(), 1) with db_session:
self.assertEqual(Folder.select().count(), 1)
self._logout() self._logout()
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
@ -73,15 +77,17 @@ class FolderTestCase(FrontendTestBase):
self.assertIn('No such folder', rv.data) self.assertIn('No such folder', rv.data)
rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True) rv = self.client.get('/folder/del/' + str(folder.id), follow_redirects = True)
self.assertIn('Music folders', rv.data) self.assertIn('Music folders', rv.data)
self.assertEqual(self.store.find(Folder).count(), 0) with db_session:
self.assertEqual(Folder.select().count(), 0)
def test_scan(self): def test_scan(self):
folder = Folder() with db_session:
folder.name = 'folder' folder = Folder(
folder.path = 'tests/assets' name = 'folder',
folder.root = True path = 'tests/assets',
self.store.add(folder) root = True,
self.store.commit() )
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
rv = self.client.get('/folder/scan/string', follow_redirects = True) rv = self.client.get('/folder/scan/string', follow_redirects = True)

View File

@ -12,6 +12,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import User from supysonic.db import User
from .frontendtestbase import FrontendTestBase from .frontendtestbase import FrontendTestBase
@ -50,8 +52,9 @@ class LoginTestCase(FrontendTestBase):
def test_root_with_valid_session(self): def test_root_with_valid_session(self):
# Root with valid session # Root with valid session
with self.client.session_transaction() as sess: with db_session:
sess['userid'] = self.store.find(User, User.name == 'alice').one().id with self.client.session_transaction() as sess:
sess['userid'] = User.get(name = 'alice').id
rv = self.client.get('/', follow_redirects=True) rv = self.client.get('/', follow_redirects=True)
self.assertIn('alice', rv.data) self.assertIn('alice', rv.data)
self.assertIn('Log out', rv.data) self.assertIn('Log out', rv.data)

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import Folder, Artist, Album, Track, Playlist, User from supysonic.db import Folder, Artist, Album, Track, Playlist, User
from .frontendtestbase import FrontendTestBase from .frontendtestbase import FrontendTestBase
@ -19,43 +21,34 @@ class PlaylistTestCase(FrontendTestBase):
def setUp(self): def setUp(self):
super(PlaylistTestCase, self).setUp() super(PlaylistTestCase, self).setUp()
folder = Folder() with db_session:
folder.name = 'Root' folder = Folder(name = 'Root', path = 'tests/assets', root = True)
folder.path = 'tests/assets' artist = Artist(name = 'Artist!')
folder.root = True album = Album(name = 'Album!', artist = artist)
artist = Artist() track = Track(
artist.name = 'Artist!' path = 'tests/assets/23bytes',
title = '23bytes',
artist = artist,
album = album,
folder = folder,
root_folder = folder,
duration = 2,
disc = 1,
number = 1,
content_type = 'audio/mpeg',
bitrate = 320,
last_modification = 0
)
album = Album() playlist = Playlist(
album.name = 'Album!' name = 'Playlist!',
album.artist = artist user = User.get(name = 'alice')
)
for _ in range(4):
playlist.add(track)
track = Track() self.playlistid = playlist.id
track.path = 'tests/assets/23bytes'
track.title = '23bytes'
track.artist = artist
track.album = album
track.folder = folder
track.root_folder = folder
track.duration = 2
track.disc = 1
track.number = 1
track.content_type = 'audio/mpeg'
track.bitrate = 320
track.last_modification = 0
playlist = Playlist()
playlist.name = 'Playlist!'
playlist.user = self.store.find(User, User.name == 'alice').one()
for _ in range(4):
playlist.add(track)
self.store.add(track)
self.store.add(playlist)
self.store.commit()
self.playlist = playlist
def test_index(self): def test_index(self):
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
@ -68,7 +61,7 @@ class PlaylistTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.get('/playlist/' + str(uuid.uuid4()), follow_redirects = True) rv = self.client.get('/playlist/' + str(uuid.uuid4()), follow_redirects = True)
self.assertIn('Unknown', rv.data) self.assertIn('Unknown', rv.data)
rv = self.client.get('/playlist/' + str(self.playlist.id)) rv = self.client.get('/playlist/' + str(self.playlistid))
self.assertIn('Playlist!', rv.data) self.assertIn('Playlist!', rv.data)
self.assertIn('23bytes', rv.data) self.assertIn('23bytes', rv.data)
self.assertIn('Artist!', rv.data) self.assertIn('Artist!', rv.data)
@ -80,22 +73,25 @@ class PlaylistTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.post('/playlist/' + str(uuid.uuid4()), follow_redirects = True) rv = self.client.post('/playlist/' + str(uuid.uuid4()), follow_redirects = True)
self.assertIn('Unknown', rv.data) self.assertIn('Unknown', rv.data)
rv = self.client.post('/playlist/' + str(self.playlist.id), follow_redirects = True) rv = self.client.post('/playlist/' + str(self.playlistid), follow_redirects = True)
self.assertNotIn('updated', rv.data) self.assertNotIn('updated', rv.data)
self.assertIn('not allowed', rv.data) self.assertIn('not allowed', rv.data)
self._logout() self._logout()
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
rv = self.client.post('/playlist/' + str(self.playlist.id), follow_redirects = True) rv = self.client.post('/playlist/' + str(self.playlistid), follow_redirects = True)
self.assertNotIn('updated', rv.data) self.assertNotIn('updated', rv.data)
self.assertIn('Missing', rv.data) self.assertIn('Missing', rv.data)
self.assertEqual(self.playlist.name, 'Playlist!') with db_session:
self.assertEqual(Playlist[self.playlistid].name, 'Playlist!')
rv = self.client.post('/playlist/' + str(self.playlist.id), data = { 'name': 'abc', 'public': True }, follow_redirects = True) rv = self.client.post('/playlist/' + str(self.playlistid), data = { 'name': 'abc', 'public': True }, follow_redirects = True)
self.assertIn('updated', rv.data) self.assertIn('updated', rv.data)
self.assertNotIn('not allowed', rv.data) self.assertNotIn('not allowed', rv.data)
self.assertEqual(self.playlist.name, 'abc') with db_session:
self.assertTrue(self.playlist.public) playlist = Playlist[self.playlistid]
self.assertEqual(playlist.name, 'abc')
self.assertTrue(playlist.public)
def test_delete(self): def test_delete(self):
self._login('bob', 'B0b') self._login('bob', 'B0b')
@ -103,15 +99,17 @@ class PlaylistTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.get('/playlist/del/' + str(uuid.uuid4()), follow_redirects = True) rv = self.client.get('/playlist/del/' + str(uuid.uuid4()), follow_redirects = True)
self.assertIn('Unknown', rv.data) self.assertIn('Unknown', rv.data)
rv = self.client.get('/playlist/del/' + str(self.playlist.id), follow_redirects = True) rv = self.client.get('/playlist/del/' + str(self.playlistid), follow_redirects = True)
self.assertIn('not allowed', rv.data) self.assertIn('not allowed', rv.data)
self.assertEqual(self.store.find(Playlist).count(), 1) with db_session:
self.assertEqual(Playlist.select().count(), 1)
self._logout() self._logout()
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
rv = self.client.get('/playlist/del/' + str(self.playlist.id), follow_redirects = True) rv = self.client.get('/playlist/del/' + str(self.playlistid), follow_redirects = True)
self.assertIn('deleted', rv.data) self.assertIn('deleted', rv.data)
self.assertEqual(self.store.find(Playlist).count(), 0) with db_session:
self.assertEqual(Playlist.select().count(), 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -11,6 +11,8 @@
import uuid import uuid
from pony.orm import db_session
from supysonic.db import User, ClientPrefs from supysonic.db import User, ClientPrefs
from .frontendtestbase import FrontendTestBase from .frontendtestbase import FrontendTestBase
@ -19,7 +21,8 @@ class UserTestCase(FrontendTestBase):
def setUp(self): def setUp(self):
super(UserTestCase, self).setUp() super(UserTestCase, self).setUp()
self.users = { u.name: u for u in self.store.find(User) } with db_session:
self.users = { u.name: u.id for u in User.select() }
def test_index(self): def test_index(self):
self._login('bob', 'B0b') self._login('bob', 'B0b')
@ -38,18 +41,15 @@ class UserTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.get('/user/' + str(uuid.uuid4()), follow_redirects = True) rv = self.client.get('/user/' + str(uuid.uuid4()), follow_redirects = True)
self.assertIn('No such user', rv.data) self.assertIn('No such user', rv.data)
rv = self.client.get('/user/' + str(self.users['bob'].id)) rv = self.client.get('/user/' + str(self.users['bob']))
self.assertIn('bob', rv.data) self.assertIn('bob', rv.data)
self._logout() self._logout()
prefs = ClientPrefs() with db_session:
prefs.user_id = self.users['bob'].id ClientPrefs(user = User[self.users['bob']], client_name = 'tests')
prefs.client_name = 'tests'
self.store.add(prefs)
self.store.commit()
self._login('bob', 'B0b') self._login('bob', 'B0b')
rv = self.client.get('/user/' + str(self.users['alice'].id), follow_redirects = True) rv = self.client.get('/user/' + str(self.users['alice']), follow_redirects = True)
self.assertIn('There\'s nothing much to see', rv.data) self.assertIn('There\'s nothing much to see', rv.data)
self.assertNotIn('<h2>bob</h2>', rv.data) self.assertNotIn('<h2>bob</h2>', rv.data)
rv = self.client.get('/user/me') rv = self.client.get('/user/me')
@ -68,19 +68,19 @@ class UserTestCase(FrontendTestBase):
self.client.post('/user/me', data = { 'n_': 'o' }) self.client.post('/user/me', data = { 'n_': 'o' })
self.client.post('/user/me', data = { 'inexisting_client': 'setting' }) self.client.post('/user/me', data = { 'inexisting_client': 'setting' })
prefs = ClientPrefs() with db_session:
prefs.user_id = self.users['alice'].id ClientPrefs(user = User[self.users['alice']], client_name = 'tests')
prefs.client_name = 'tests'
self.store.add(prefs)
self.store.commit()
rv = self.client.post('/user/me', data = { 'tests_format': 'mp3', 'tests_bitrate': 128 }) rv = self.client.post('/user/me', data = { 'tests_format': 'mp3', 'tests_bitrate': 128 })
self.assertIn('updated', rv.data) self.assertIn('updated', rv.data)
self.assertEqual(prefs.format, 'mp3') with db_session:
self.assertEqual(prefs.bitrate, 128) prefs = ClientPrefs[User[self.users['alice']], 'tests']
self.assertEqual(prefs.format, 'mp3')
self.assertEqual(prefs.bitrate, 128)
self.client.post('/user/me', data = { 'tests_delete': 1 }) self.client.post('/user/me', data = { 'tests_delete': 1 })
self.assertEqual(self.store.find(ClientPrefs).count(), 0) with db_session:
self.assertEqual(ClientPrefs.select().count(), 0)
def test_change_username_get(self): def test_change_username_get(self):
self._login('bob', 'B0b') self._login('bob', 'B0b')
@ -93,13 +93,13 @@ class UserTestCase(FrontendTestBase):
self.assertIn('Invalid', rv.data) self.assertIn('Invalid', rv.data)
rv = self.client.get('/user/{}/changeusername'.format(uuid.uuid4()), follow_redirects = True) rv = self.client.get('/user/{}/changeusername'.format(uuid.uuid4()), follow_redirects = True)
self.assertIn('No such user', rv.data) self.assertIn('No such user', rv.data)
self.client.get('/user/{}/changeusername'.format(self.users['bob'].id)) self.client.get('/user/{}/changeusername'.format(self.users['bob']))
def test_change_username_post(self): def test_change_username_post(self):
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
self.client.post('/user/whatever/changeusername') self.client.post('/user/whatever/changeusername')
path = '/user/{}/changeusername'.format(self.users['bob'].id) path = '/user/{}/changeusername'.format(self.users['bob'])
rv = self.client.post(path, follow_redirects = True) rv = self.client.post(path, follow_redirects = True)
self.assertIn('required', rv.data) self.assertIn('required', rv.data)
rv = self.client.post(path, data = { 'user': 'bob' }, follow_redirects = True) rv = self.client.post(path, data = { 'user': 'bob' }, follow_redirects = True)
@ -107,10 +107,13 @@ class UserTestCase(FrontendTestBase):
rv = self.client.post(path, data = { 'user': 'b0b', 'admin': 1 }, follow_redirects = True) rv = self.client.post(path, data = { 'user': 'b0b', 'admin': 1 }, follow_redirects = True)
self.assertIn('updated', rv.data) self.assertIn('updated', rv.data)
self.assertIn('b0b', rv.data) self.assertIn('b0b', rv.data)
self.assertEqual(self.users['bob'].name, 'b0b') with db_session:
self.assertTrue(self.users['bob'].admin) bob = User[self.users['bob']]
self.assertEqual(bob.name, 'b0b')
self.assertTrue(bob.admin)
rv = self.client.post(path, data = { 'user': 'alice' }, follow_redirects = True) rv = self.client.post(path, data = { 'user': 'alice' }, follow_redirects = True)
self.assertEqual(self.users['bob'].name, 'b0b') with db_session:
self.assertEqual(User[self.users['bob']].name, 'b0b')
def test_change_mail_get(self): def test_change_mail_get(self):
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
@ -126,7 +129,7 @@ class UserTestCase(FrontendTestBase):
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
rv = self.client.get('/user/me/changepass') rv = self.client.get('/user/me/changepass')
self.assertIn('Current password', rv.data) self.assertIn('Current password', rv.data)
rv = self.client.get('/user/{}/changepass'.format(self.users['bob'].id)) rv = self.client.get('/user/{}/changepass'.format(self.users['bob']))
self.assertNotIn('Current password', rv.data) self.assertNotIn('Current password', rv.data)
def test_change_password_post(self): def test_change_password_post(self):
@ -151,7 +154,7 @@ class UserTestCase(FrontendTestBase):
rv = self._login('alice', 'alice') rv = self._login('alice', 'alice')
self.assertIn('Logged in', rv.data) self.assertIn('Logged in', rv.data)
path = '/user/{}/changepass'.format(self.users['bob'].id) path = '/user/{}/changepass'.format(self.users['bob'])
rv = self.client.post(path) rv = self.client.post(path)
self.assertIn('required', rv.data) self.assertIn('required', rv.data)
rv = self.client.post(path, data = { 'new': 'alice' }) rv = self.client.post(path, data = { 'new': 'alice' })
@ -162,7 +165,6 @@ class UserTestCase(FrontendTestBase):
rv = self._login('bob', 'alice') rv = self._login('bob', 'alice')
self.assertIn('Logged in', rv.data) self.assertIn('Logged in', rv.data)
def test_add_get(self): def test_add_get(self):
self._login('bob', 'B0b') self._login('bob', 'B0b')
rv = self.client.get('/user/add', follow_redirects = True) rv = self.client.get('/user/add', follow_redirects = True)
@ -186,22 +188,25 @@ class UserTestCase(FrontendTestBase):
self.assertIn('passwords don', rv.data) self.assertIn('passwords don', rv.data)
rv = self.client.post('/user/add', data = { 'user': 'alice', 'passwd': 'passwd', 'passwd_confirm': 'passwd' }) rv = self.client.post('/user/add', data = { 'user': 'alice', 'passwd': 'passwd', 'passwd_confirm': 'passwd' })
self.assertIn('already a user with that name', rv.data) self.assertIn('already a user with that name', rv.data)
self.assertEqual(self.store.find(User).count(), 2) with db_session:
self.assertEqual(User.select().count(), 2)
rv = self.client.post('/user/add', data = { 'user': 'user', 'passwd': 'passwd', 'passwd_confirm': 'passwd', 'admin': 1 }, follow_redirects = True) rv = self.client.post('/user/add', data = { 'user': 'user', 'passwd': 'passwd', 'passwd_confirm': 'passwd', 'admin': 1 }, follow_redirects = True)
self.assertIn('added', rv.data) self.assertIn('added', rv.data)
self.assertEqual(self.store.find(User).count(), 3) with db_session:
self.assertEqual(User.select().count(), 3)
self._logout() self._logout()
rv = self._login('user', 'passwd') rv = self._login('user', 'passwd')
self.assertIn('Logged in', rv.data) self.assertIn('Logged in', rv.data)
def test_delete(self): def test_delete(self):
path = '/user/del/{}'.format(self.users['bob'].id) path = '/user/del/{}'.format(self.users['bob'])
self._login('bob', 'B0b') self._login('bob', 'B0b')
rv = self.client.get(path, follow_redirects = True) rv = self.client.get(path, follow_redirects = True)
self.assertIn('There\'s nothing much to see', rv.data) self.assertIn('There\'s nothing much to see', rv.data)
self.assertEqual(self.store.find(User).count(), 2) with db_session:
self.assertEqual(User.select().count(), 2)
self._logout() self._logout()
self._login('alice', 'Alic3') self._login('alice', 'Alic3')
@ -211,7 +216,8 @@ class UserTestCase(FrontendTestBase):
self.assertIn('No such user', rv.data) self.assertIn('No such user', rv.data)
rv = self.client.get(path, follow_redirects = True) rv = self.client.get(path, follow_redirects = True)
self.assertIn('Deleted', rv.data) self.assertIn('Deleted', rv.data)
self.assertEqual(self.store.find(User).count(), 1) with db_session:
self.assertEqual(User.select().count(), 1)
self._logout() self._logout()
rv = self._login('bob', 'B0b') rv = self._login('bob', 'B0b')
self.assertIn('No such user', rv.data) self.assertIn('No such user', rv.data)

View File

@ -25,12 +25,17 @@ from pony.orm import db_session, ObjectNotFound
class FolderManagerTestCase(unittest.TestCase): class FolderManagerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
# Create an empty sqlite database in memory # Create an empty sqlite database in memory
self.store = db.get_database('sqlite:', True) db.init_database('sqlite:', True)
# Create some temporary directories # Create some temporary directories
self.media_dir = tempfile.mkdtemp() self.media_dir = tempfile.mkdtemp()
self.music_dir = tempfile.mkdtemp() self.music_dir = tempfile.mkdtemp()
def tearDown(self):
db.release_database()
shutil.rmtree(self.media_dir)
shutil.rmtree(self.music_dir)
@db_session @db_session
def create_folders(self): def create_folders(self):
# Add test folders # Add test folders
@ -62,11 +67,6 @@ class FolderManagerTestCase(unittest.TestCase):
last_modification = 0 last_modification = 0
) )
def tearDown(self):
db.release_database(self.store)
shutil.rmtree(self.media_dir)
shutil.rmtree(self.music_dir)
@db_session @db_session
def test_get_folder(self): def test_get_folder(self):
self.create_folders() self.create_folders()

View File

@ -17,13 +17,16 @@ import io
import unittest import unittest
import uuid import uuid
from pony.orm import db_session from pony.orm import db_session, commit
from pony.orm import ObjectNotFound from pony.orm import ObjectNotFound
class UserManagerTestCase(unittest.TestCase): class UserManagerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
# Create an empty sqlite database in memory # Create an empty sqlite database in memory
self.store = db.get_database('sqlite:', True) db.init_database('sqlite:', True)
def tearDown(self):
db.release_database()
@db_session @db_session
def create_data(self): def create_data(self):
@ -56,9 +59,6 @@ class UserManagerTestCase(unittest.TestCase):
) )
playlist.add(track) playlist.add(track)
def tearDown(self):
db.release_database(self.store)
def test_encrypt_password(self): def test_encrypt_password(self):
func = UserManager._UserManager__encrypt_password func = UserManager._UserManager__encrypt_password
self.assertEqual(func(u'password',u'salt'), (u'59b3e8d637cf97edbe2384cf59cb7453dfe30789', u'salt')) self.assertEqual(func(u'password',u'salt'), (u'59b3e8d637cf97edbe2384cf59cb7453dfe30789', u'salt'))
@ -107,7 +107,7 @@ class UserManagerTestCase(unittest.TestCase):
user = db.User.get(name = name) user = db.User.get(name = name)
self.assertEqual(UserManager.delete(user.id), UserManager.SUCCESS) self.assertEqual(UserManager.delete(user.id), UserManager.SUCCESS)
self.assertRaises(ObjectNotFound, db.User.__getitem__, user.id) self.assertRaises(ObjectNotFound, db.User.__getitem__, user.id)
self.store.commit() commit()
self.assertEqual(db.User.select().count(), 0) self.assertEqual(db.User.select().count(), 0)
@db_session @db_session

View File

@ -10,22 +10,20 @@
import inspect import inspect
import io import io
import os
import shutil import shutil
import sys import sys
import unittest import unittest
import tempfile import tempfile
from supysonic.db import init_database, release_database
from supysonic.config import DefaultConfig from supysonic.config import DefaultConfig
from supysonic.managers.user import UserManager from supysonic.managers.user import UserManager
from supysonic.web import create_application, store from supysonic.web import create_application
class TestConfig(DefaultConfig): class TestConfig(DefaultConfig):
TESTING = True TESTING = True
LOGGER_HANDLER_POLICY = 'never' LOGGER_HANDLER_POLICY = 'never'
BASE = {
'database_uri': 'sqlite:',
'scanner_extensions': None
}
MIMETYPES = { MIMETYPES = {
'mp3': 'audio/mpeg', 'mp3': 'audio/mpeg',
'weirdextension': 'application/octet-stream' 'weirdextension': 'application/octet-stream'
@ -60,31 +58,37 @@ class TestBase(unittest.TestCase):
__with_api__ = False __with_api__ = False
def setUp(self): def setUp(self):
self.__dbfile = tempfile.mkstemp()[1]
self.__dir = tempfile.mkdtemp() self.__dir = tempfile.mkdtemp()
config = TestConfig(self.__with_webui__, self.__with_api__) config = TestConfig(self.__with_webui__, self.__with_api__)
config.BASE['database_uri'] = 'sqlite:///' + self.__dbfile
config.WEBAPP['cache_dir'] = self.__dir config.WEBAPP['cache_dir'] = self.__dir
app = create_application(config) init_database(config.BASE['database_uri'], True)
self.__ctx = app.app_context() release_database()
self.__ctx.push()
self.store = store app = create_application(config)
with io.open('schema/sqlite.sql', 'r') as sql: #self.__ctx = app.app_context()
schema = sql.read() #self.__ctx.push()
for statement in schema.split(';'):
self.store.execute(statement)
self.store.commit()
self.client = app.test_client() self.client = app.test_client()
UserManager.add(self.store, 'alice', 'Alic3', 'test@example.com', True) UserManager.add('alice', 'Alic3', 'test@example.com', True)
UserManager.add(self.store, 'bob', 'B0b', 'bob@example.com', False) UserManager.add('bob', 'B0b', 'bob@example.com', False)
@staticmethod
def __should_unload_module(module):
if module.startswith('supysonic'):
return not module.startswith('supysonic.db')
return False
def tearDown(self): def tearDown(self):
self.__ctx.pop() #self.__ctx.pop()
release_database()
shutil.rmtree(self.__dir) shutil.rmtree(self.__dir)
os.remove(self.__dbfile)
to_unload = [ m for m in sys.modules if m.startswith('supysonic') ] to_unload = [ m for m in sorted(sys.modules) if self.__should_unload_module(m) ]
for m in to_unload: for m in to_unload:
del sys.modules[m] del sys.modules[m]