diff --git a/supysonic/db.py b/supysonic/db.py index 607f39d..3fef24f 100644 --- a/supysonic/db.py +++ b/supysonic/db.py @@ -25,9 +25,10 @@ from peewee import ( IntegerField, TextField, ) -from peewee import CompositeKey -from playhouse.flask_utils import FlaskDB -from urllib.parse import urlparse, parse_qsl +from peewee import CompositeKey, DatabaseProxy +from peewee import fn +from playhouse.db_url import parseresult_to_dict, schemes +from urllib.parse import urlparse from uuid import UUID, uuid4 SCHEMA_VERSION = "20200607" @@ -41,7 +42,7 @@ def PrimaryKeyField(**kwargs): return BinaryUUIDField(primary_key=True, default=uuid4, **kwargs) -db = FlaskDB() +db = DatabaseProxy() class Meta(db.Model): @@ -105,16 +106,20 @@ class Folder(PathMixin, db.Model): try: starred = StarredFolder[user.id, self.id] info["starred"] = starred.date.isoformat() - except ObjectNotFound: + except StarredFolder.DoesNotExist: pass try: rating = RatingFolder[user.id, self.id] info["userRating"] = rating.rating - except ObjectNotFound: + except RatingFolder.DoesNotExist: pass - avgRating = avg(self.ratings.rating) + avgRating = ( + RatingFolder.select(fn.avg(RatingFolder.rating)) + .where(RatingFolder.rated == self) + .scalar() + ) if avgRating: info["averageRating"] = avgRating @@ -126,7 +131,7 @@ class Folder(PathMixin, db.Model): try: starred = StarredFolder[user.id, self.id] info["starred"] = starred.date.isoformat() - except ObjectNotFound: + except StarredFolder.DoesNotExist: pass return info @@ -179,7 +184,7 @@ class Artist(db.Model): try: starred = StarredArtist[user.id, self.id] info["starred"] = starred.date.isoformat() - except ObjectNotFound: + except StarredArtist.DoesNotExist: pass return info @@ -198,43 +203,53 @@ class Album(db.Model): artist = ForeignKeyField(Artist, backref="albums") def as_subsonic_album(self, user): # "AlbumID3" type in XSD + duration, created, year = self.tracks.select( + fn.sum(Track.duration), fn.min(Track.created), fn.min(Track.year) + ).scalar(as_tuple=True) + info = { "id": str(self.id), "name": self.name, "artist": self.artist.name, "artistId": str(self.artist.id), "songCount": self.tracks.count(), - "duration": sum(self.tracks.duration), - "created": min(self.tracks.created).isoformat(), + "duration": duration, + "created": created.isoformat(), } - track_with_cover = self.tracks.select( - lambda t: t.folder.cover_art is not None - ).first() + track_with_cover = ( + self.tracks.join(Folder).where(Folder.cover_art.is_null(False)).first() + ) if track_with_cover is not None: info["coverArt"] = str(track_with_cover.folder.id) else: - track_with_cover = self.tracks.select(lambda t: t.has_art).first() + track_with_cover = self.tracks.where(Track.has_art).first() if track_with_cover is not None: info["coverArt"] = str(track_with_cover.id) - if count(self.tracks.year) > 0: - info["year"] = min(self.tracks.year) + if year: + info["year"] = year - genre = ", ".join(self.tracks.genre.distinct()) + genre = ", ".join( + g + for (g,) in self.tracks.select(Track.genre) + .where(Track.genre.is_null(False)) + .distinct() + .tuples() + ) if genre: info["genre"] = genre try: starred = StarredAlbum[user.id, self.id] info["starred"] = starred.date.isoformat() - except ObjectNotFound: + except StarredAlbum.DoesNotExist: pass return info def sort_key(self): - year = min(t.year if t.year else 9999 for t in self.tracks) + year = self.tracks.select(fn.min(Track.year)).scalar() or 9999 return f"{year}{self.name.lower()}" @classmethod @@ -305,16 +320,20 @@ class Track(PathMixin, db.Model): try: starred = StarredTrack[user.id, self.id] info["starred"] = starred.date.isoformat() - except ObjectNotFound: + except StarredTrack.DoesNotExist: pass try: rating = RatingTrack[user.id, self.id] info["userRating"] = rating.rating - except ObjectNotFound: + except RatingTrack.DoesNotExist: pass - avgRating = avg(self.ratings.rating) + avgRating = ( + RatingTrack.select(fn.avg(RatingTrack.rating)) + .where(RatingTrack.rated == self) + .scalar() + ) if avgRating: info["averageRating"] = avgRating @@ -483,7 +502,7 @@ class Playlist(db.Model): tid = UUID(t) track = Track[tid] tracks.append(track) - except (ValueError, ObjectNotFound): + except (ValueError, Track.DoesNotExist): should_fix = True if should_fix: @@ -535,108 +554,61 @@ class RadioStation(db.Model): return info -def parse_uri(database_uri): - if not isinstance(database_uri, str): - raise TypeError("Expecting a string") - - uri = urlparse(database_uri) - args = dict(parse_qsl(uri.query)) - if uri.port is not None: - args["port"] = uri.port - - if uri.scheme == "sqlite": - path = uri.path - if not path: - path = ":memory:" - elif path[0] == "/": - path = path[1:] - - return {"provider": "sqlite", "filename": path, "create_db": True, **args} - elif uri.scheme in ("postgres", "postgresql"): - return { - "provider": "postgres", - "user": uri.username, - "password": uri.password, - "host": uri.hostname, - "dbname": uri.path[1:], - **args, - } - elif uri.scheme == "mysql": - args.setdefault("charset", "utf8mb4") - args.setdefault("binary_prefix", True) - return { - "provider": "mysql", - "user": uri.username, - "passwd": uri.password, - "host": uri.hostname, - "db": uri.path[1:], - **args, - } - return {} - - def execute_sql_resource_script(respath): sql = pkg_resources.resource_string(__package__, respath).decode("utf-8") for statement in sql.split(";"): statement = statement.strip() if statement and not statement.startswith("--"): - metadb.execute(statement) + db.execute_sql(statement) def init_database(database_uri): - settings = parse_uri(database_uri) + uri = urlparse(database_uri) + args = parseresult_to_dict(uri) + if uri.scheme.startswith("mysql"): + args.setdefault("charset", "utf8mb4") + args.setdefault("binary_prefix", True) - metadb.bind(**settings) - metadb.generate_mapping(check_tables=False) + if uri.scheme.startswith("mysql"): + provider = "mysql" + elif uri.scheme.startswith("postgres"): + provider = "postgres" + elif uri.scheme.startswith("sqlite"): + provider = "sqlite" + else: + raise RuntimeError(f"Unsupported database: {uri.scheme}") + + db_class = schemes.get(uri.scheme) + db.initialize(db_class(**args)) + db.connect() # Check if we should create the tables - try: - metadb.check_tables() - except DatabaseError: - with db_session: - execute_sql_resource_script("schema/" + settings["provider"] + ".sql") - Meta(key="schema_version", value=SCHEMA_VERSION) + if not db.table_exists("meta"): + execute_sql_resource_script(f"schema/{provider}.sql") + Meta.create(key="schema_version", value=SCHEMA_VERSION) # Check for schema changes - with db_session: - version = Meta["schema_version"] - if version.value < SCHEMA_VERSION: - migrations = sorted( - pkg_resources.resource_listdir( - __package__, "schema/migration/" + settings["provider"] + version = Meta["schema_version"] + if version.value < SCHEMA_VERSION: + migrations = sorted( + pkg_resources.resource_listdir(__package__, f"schema/migration/{provider}") + ) + for migration in migrations: + date, ext = os.path.splitext(migration) + if date <= version.value: + continue + if ext == ".sql": + execute_sql_resource_script(f"schema/migration/{provider}/{migration}") + elif ext == ".py": + m = importlib.import_module( + f".schema.migration.{provider}.{date}", __package__ ) - ) - for migration in migrations: - date, ext = os.path.splitext(migration) - if date <= version.value: - continue - if ext == ".sql": - execute_sql_resource_script( - "schema/migration/{}/{}".format(settings["provider"], migration) - ) - elif ext == ".py": - m = importlib.import_module( - ".schema.migration.{}.{}".format(settings["provider"], date), - __package__, - ) - m.apply(settings.copy()) - version.value = SCHEMA_VERSION + m.apply(args.copy()) - # Hack for in-memory SQLite databases (used in tests), otherwise 'db' and 'metadb' would be two distinct databases - # and 'db' wouldn't have any table - if settings["provider"] == "sqlite" and settings["filename"] == ":memory:": - db.provider = metadb.provider - else: - metadb.disconnect() - db.bind(**settings) - # Force requests to Meta to use the same connection as other tables - metadb.provider = db.provider - - db.generate_mapping(check_tables=False) + version.value = SCHEMA_VERSION + version.save() def release_database(): - metadb.disconnect() - db.disconnect() - db.provider = metadb.provider = None - db.schema = metadb.schema = None + db.close() + db.initialize(None) diff --git a/tests/base/test_db.py b/tests/base/test_db.py index 9fd5d09..7acbc34 100644 --- a/tests/base/test_db.py +++ b/tests/base/test_db.py @@ -1,7 +1,7 @@ # This file is part of Supysonic. # Supysonic is a Python implementation of the Subsonic server API. # -# Copyright (C) 2017-2018 Alban 'spl0k' Féron +# Copyright (C) 2017-2022 Alban 'spl0k' Féron # # Distributed under terms of the GNU AGPLv3 license. @@ -10,7 +10,6 @@ import unittest import uuid from collections import namedtuple -from pony.orm import db_session from supysonic import db @@ -30,9 +29,9 @@ class DbTestCase(unittest.TestCase): db.release_database() def create_some_folders(self): - root_folder = db.Folder(root=True, name="Root folder", path="tests") + root_folder = db.Folder.create(root=True, name="Root folder", path="tests") - db.Folder( + f1 = db.Folder.create( root=False, name="Child folder", path="tests/assets", @@ -40,30 +39,25 @@ class DbTestCase(unittest.TestCase): parent=root_folder, ) - db.Folder( + f2 = db.Folder.create( root=False, name="Child folder (No Art)", path="tests/formats", parent=root_folder, ) - # Folder IDs don't get populated until we query the db. - return ( - db.Folder.get(name="Root folder"), - db.Folder.get(name="Child folder"), - db.Folder.get(name="Child Folder (No Art)"), - ) + return root_folder, f1, f2 def create_some_tracks(self, artist=None, album=None): root, child, child_2 = self.create_some_folders() if not artist: - artist = db.Artist(name="Test artist") + artist = db.Artist.create(name="Test artist") if not album: - album = db.Album(artist=artist, name="Test Album") + album = db.Album.create(artist=artist, name="Test Album") - track1 = db.Track( + track1 = db.Track.create( title="Track Title", album=album, artist=artist, @@ -78,7 +72,7 @@ class DbTestCase(unittest.TestCase): folder=child, ) - track2 = db.Track( + track2 = db.Track.create( title="One Awesome Song", album=album, artist=artist, @@ -95,9 +89,9 @@ class DbTestCase(unittest.TestCase): return track1, track2 def create_track_in(self, folder, root, artist=None, album=None, has_art=True): - artist = artist or db.Artist(name="Snazzy Artist") - album = album or db.Album(artist=artist, name="Rockin' Album") - return db.Track( + artist = artist or db.Artist.create(name="Snazzy Artist") + album = album or db.Album.create(artist=artist, name="Rockin' Album") + return db.Track.create( title="Nifty Number", album=album, artist=artist, @@ -113,15 +107,13 @@ class DbTestCase(unittest.TestCase): ) def create_user(self, name="Test User"): - return db.User(name=name, password="secret", salt="ABC+") + return db.User.create(name=name, password="secret", salt="ABC+") def create_playlist(self): - - playlist = db.Playlist(user=self.create_user(), name="Playlist!") + playlist = db.Playlist.create(user=self.create_user(), name="Playlist!") return playlist - @db_session def test_folder_base(self): root_folder, child_folder, child_noart = self.create_some_folders() track_embededart = self.create_track_in(child_noart, root_folder) @@ -153,15 +145,14 @@ class DbTestCase(unittest.TestCase): self.assertIn("coverArt", noart) self.assertEqual(noart["coverArt"], str(track_embededart.id)) - @db_session def test_folder_annotation(self): root_folder, child_folder, _ = self.create_some_folders() user = self.create_user() - db.StarredFolder(user=user, starred=root_folder) - db.RatingFolder(user=user, rated=root_folder, rating=2) + db.StarredFolder.create(user=user, starred=root_folder) + db.RatingFolder.create(user=user, rated=root_folder, rating=2) other = self.create_user("Other") - db.RatingFolder(user=other, rated=root_folder, rating=5) + db.RatingFolder.create(user=other, rated=root_folder, rating=5) root = root_folder.as_subsonic_child(user) self.assertIn("starred", root) @@ -175,12 +166,11 @@ class DbTestCase(unittest.TestCase): self.assertNotIn("starred", child) self.assertNotIn("userRating", child) - @db_session def test_artist(self): - artist = db.Artist(name="Test Artist") + artist = db.Artist.create(name="Test Artist") user = self.create_user() - db.StarredArtist(user=user, starred=artist) + db.StarredArtist.create(user=user, starred=artist) artist_dict = artist.as_subsonic_artist(user) self.assertIsInstance(artist_dict, dict) @@ -192,22 +182,18 @@ class DbTestCase(unittest.TestCase): self.assertEqual(artist_dict["albumCount"], 0) self.assertRegex(artist_dict["starred"], date_regex) - db.Album(name="Test Artist", artist=artist) # self-titled - db.Album(name="The Album After The First One", artist=artist) + db.Album.create(name="Test Artist", artist=artist) # self-titled + db.Album.create(name="The Album After The First One", artist=artist) artist_dict = artist.as_subsonic_artist(user) self.assertEqual(artist_dict["albumCount"], 2) - @db_session def test_album(self): - artist = db.Artist(name="Test Artist") - album = db.Album(artist=artist, name="Test Album") + artist = db.Artist.create(name="Test Artist") + album = db.Album.create(artist=artist, name="Test Album") user = self.create_user() - db.StarredAlbum(user=user, starred=album) - - # No tracks, shouldn't be stored under normal circumstances - self.assertRaises(ValueError, album.as_subsonic_album, user) + db.StarredAlbum.create(user=user, starred=album) root_folder, folder_art, folder_noart = self.create_some_folders() track1 = self.create_track_in( @@ -234,7 +220,6 @@ class DbTestCase(unittest.TestCase): self.assertRegex(album_dict["created"], date_regex) self.assertRegex(album_dict["starred"], date_regex) - @db_session def test_track(self): track1, track2 = self.create_some_tracks() @@ -256,14 +241,12 @@ class DbTestCase(unittest.TestCase): self.assertEqual(track2_dict["coverArt"], track2_dict["parent"]) # ... we'll test the rest against the API XSD. - @db_session def test_user(self): user = self.create_user() user_dict = user.as_subsonic_user() self.assertIsInstance(user_dict, dict) - @db_session def test_chat(self): user = self.create_user() @@ -274,13 +257,11 @@ class DbTestCase(unittest.TestCase): self.assertIn("username", line_dict) self.assertEqual(line_dict["username"], user.name) - @db_session def test_playlist(self): playlist = self.create_playlist() playlist_dict = playlist.as_subsonic_playlist(playlist.user) self.assertIsInstance(playlist_dict, dict) - @db_session def test_playlist_tracks(self): playlist = self.create_playlist() track1, track2 = self.create_some_tracks() @@ -304,7 +285,6 @@ class DbTestCase(unittest.TestCase): self.assertRaises(ValueError, playlist.add, "some string") self.assertRaises(NameError, playlist.add, 2345) - @db_session def test_playlist_remove_tracks(self): playlist = self.create_playlist() track1, track2 = self.create_some_tracks() @@ -324,7 +304,6 @@ class DbTestCase(unittest.TestCase): playlist.remove_at_indexes([1, 1]) self.assertSequenceEqual(playlist.get_tracks(), [track2, track1]) - @db_session def test_playlist_fixing(self): playlist = self.create_playlist() track1, track2 = self.create_some_tracks() @@ -334,7 +313,7 @@ class DbTestCase(unittest.TestCase): playlist.add(track2) self.assertSequenceEqual(playlist.get_tracks(), [track1, track2]) - track2.delete() + track2.delete_instance() self.assertSequenceEqual(playlist.get_tracks(), [track1]) playlist.tracks = "{0},{0},some random garbage,{0}".format(track1.id)