1
0
mirror of https://github.com/spl0k/supysonic.git synced 2024-12-22 17:06:17 +00:00

Fixing a good chunk of supysonic.db

This commit is contained in:
Alban Féron 2022-12-10 15:14:37 +01:00
parent 0b6891a5c4
commit 6bdee81e57
No known key found for this signature in database
GPG Key ID: 8CE0313646D16165
2 changed files with 108 additions and 157 deletions

View File

@ -25,9 +25,10 @@ from peewee import (
IntegerField, IntegerField,
TextField, TextField,
) )
from peewee import CompositeKey from peewee import CompositeKey, DatabaseProxy
from playhouse.flask_utils import FlaskDB from peewee import fn
from urllib.parse import urlparse, parse_qsl from playhouse.db_url import parseresult_to_dict, schemes
from urllib.parse import urlparse
from uuid import UUID, uuid4 from uuid import UUID, uuid4
SCHEMA_VERSION = "20200607" SCHEMA_VERSION = "20200607"
@ -41,7 +42,7 @@ def PrimaryKeyField(**kwargs):
return BinaryUUIDField(primary_key=True, default=uuid4, **kwargs) return BinaryUUIDField(primary_key=True, default=uuid4, **kwargs)
db = FlaskDB() db = DatabaseProxy()
class Meta(db.Model): class Meta(db.Model):
@ -105,16 +106,20 @@ class Folder(PathMixin, db.Model):
try: try:
starred = StarredFolder[user.id, self.id] starred = StarredFolder[user.id, self.id]
info["starred"] = starred.date.isoformat() info["starred"] = starred.date.isoformat()
except ObjectNotFound: except StarredFolder.DoesNotExist:
pass pass
try: try:
rating = RatingFolder[user.id, self.id] rating = RatingFolder[user.id, self.id]
info["userRating"] = rating.rating info["userRating"] = rating.rating
except ObjectNotFound: except RatingFolder.DoesNotExist:
pass pass
avgRating = avg(self.ratings.rating) avgRating = (
RatingFolder.select(fn.avg(RatingFolder.rating))
.where(RatingFolder.rated == self)
.scalar()
)
if avgRating: if avgRating:
info["averageRating"] = avgRating info["averageRating"] = avgRating
@ -126,7 +131,7 @@ class Folder(PathMixin, db.Model):
try: try:
starred = StarredFolder[user.id, self.id] starred = StarredFolder[user.id, self.id]
info["starred"] = starred.date.isoformat() info["starred"] = starred.date.isoformat()
except ObjectNotFound: except StarredFolder.DoesNotExist:
pass pass
return info return info
@ -179,7 +184,7 @@ class Artist(db.Model):
try: try:
starred = StarredArtist[user.id, self.id] starred = StarredArtist[user.id, self.id]
info["starred"] = starred.date.isoformat() info["starred"] = starred.date.isoformat()
except ObjectNotFound: except StarredArtist.DoesNotExist:
pass pass
return info return info
@ -198,43 +203,53 @@ class Album(db.Model):
artist = ForeignKeyField(Artist, backref="albums") artist = ForeignKeyField(Artist, backref="albums")
def as_subsonic_album(self, user): # "AlbumID3" type in XSD def as_subsonic_album(self, user): # "AlbumID3" type in XSD
duration, created, year = self.tracks.select(
fn.sum(Track.duration), fn.min(Track.created), fn.min(Track.year)
).scalar(as_tuple=True)
info = { info = {
"id": str(self.id), "id": str(self.id),
"name": self.name, "name": self.name,
"artist": self.artist.name, "artist": self.artist.name,
"artistId": str(self.artist.id), "artistId": str(self.artist.id),
"songCount": self.tracks.count(), "songCount": self.tracks.count(),
"duration": sum(self.tracks.duration), "duration": duration,
"created": min(self.tracks.created).isoformat(), "created": created.isoformat(),
} }
track_with_cover = self.tracks.select( track_with_cover = (
lambda t: t.folder.cover_art is not None self.tracks.join(Folder).where(Folder.cover_art.is_null(False)).first()
).first() )
if track_with_cover is not None: if track_with_cover is not None:
info["coverArt"] = str(track_with_cover.folder.id) info["coverArt"] = str(track_with_cover.folder.id)
else: else:
track_with_cover = self.tracks.select(lambda t: t.has_art).first() track_with_cover = self.tracks.where(Track.has_art).first()
if track_with_cover is not None: if track_with_cover is not None:
info["coverArt"] = str(track_with_cover.id) info["coverArt"] = str(track_with_cover.id)
if count(self.tracks.year) > 0: if year:
info["year"] = min(self.tracks.year) info["year"] = year
genre = ", ".join(self.tracks.genre.distinct()) genre = ", ".join(
g
for (g,) in self.tracks.select(Track.genre)
.where(Track.genre.is_null(False))
.distinct()
.tuples()
)
if genre: if genre:
info["genre"] = genre info["genre"] = genre
try: try:
starred = StarredAlbum[user.id, self.id] starred = StarredAlbum[user.id, self.id]
info["starred"] = starred.date.isoformat() info["starred"] = starred.date.isoformat()
except ObjectNotFound: except StarredAlbum.DoesNotExist:
pass pass
return info return info
def sort_key(self): def sort_key(self):
year = min(t.year if t.year else 9999 for t in self.tracks) year = self.tracks.select(fn.min(Track.year)).scalar() or 9999
return f"{year}{self.name.lower()}" return f"{year}{self.name.lower()}"
@classmethod @classmethod
@ -305,16 +320,20 @@ class Track(PathMixin, db.Model):
try: try:
starred = StarredTrack[user.id, self.id] starred = StarredTrack[user.id, self.id]
info["starred"] = starred.date.isoformat() info["starred"] = starred.date.isoformat()
except ObjectNotFound: except StarredTrack.DoesNotExist:
pass pass
try: try:
rating = RatingTrack[user.id, self.id] rating = RatingTrack[user.id, self.id]
info["userRating"] = rating.rating info["userRating"] = rating.rating
except ObjectNotFound: except RatingTrack.DoesNotExist:
pass pass
avgRating = avg(self.ratings.rating) avgRating = (
RatingTrack.select(fn.avg(RatingTrack.rating))
.where(RatingTrack.rated == self)
.scalar()
)
if avgRating: if avgRating:
info["averageRating"] = avgRating info["averageRating"] = avgRating
@ -483,7 +502,7 @@ class Playlist(db.Model):
tid = UUID(t) tid = UUID(t)
track = Track[tid] track = Track[tid]
tracks.append(track) tracks.append(track)
except (ValueError, ObjectNotFound): except (ValueError, Track.DoesNotExist):
should_fix = True should_fix = True
if should_fix: if should_fix:
@ -535,108 +554,61 @@ class RadioStation(db.Model):
return info return info
def parse_uri(database_uri):
if not isinstance(database_uri, str):
raise TypeError("Expecting a string")
uri = urlparse(database_uri)
args = dict(parse_qsl(uri.query))
if uri.port is not None:
args["port"] = uri.port
if uri.scheme == "sqlite":
path = uri.path
if not path:
path = ":memory:"
elif path[0] == "/":
path = path[1:]
return {"provider": "sqlite", "filename": path, "create_db": True, **args}
elif uri.scheme in ("postgres", "postgresql"):
return {
"provider": "postgres",
"user": uri.username,
"password": uri.password,
"host": uri.hostname,
"dbname": uri.path[1:],
**args,
}
elif uri.scheme == "mysql":
args.setdefault("charset", "utf8mb4")
args.setdefault("binary_prefix", True)
return {
"provider": "mysql",
"user": uri.username,
"passwd": uri.password,
"host": uri.hostname,
"db": uri.path[1:],
**args,
}
return {}
def execute_sql_resource_script(respath): def execute_sql_resource_script(respath):
sql = pkg_resources.resource_string(__package__, respath).decode("utf-8") sql = pkg_resources.resource_string(__package__, respath).decode("utf-8")
for statement in sql.split(";"): for statement in sql.split(";"):
statement = statement.strip() statement = statement.strip()
if statement and not statement.startswith("--"): if statement and not statement.startswith("--"):
metadb.execute(statement) db.execute_sql(statement)
def init_database(database_uri): def init_database(database_uri):
settings = parse_uri(database_uri) uri = urlparse(database_uri)
args = parseresult_to_dict(uri)
if uri.scheme.startswith("mysql"):
args.setdefault("charset", "utf8mb4")
args.setdefault("binary_prefix", True)
metadb.bind(**settings) if uri.scheme.startswith("mysql"):
metadb.generate_mapping(check_tables=False) provider = "mysql"
elif uri.scheme.startswith("postgres"):
provider = "postgres"
elif uri.scheme.startswith("sqlite"):
provider = "sqlite"
else:
raise RuntimeError(f"Unsupported database: {uri.scheme}")
db_class = schemes.get(uri.scheme)
db.initialize(db_class(**args))
db.connect()
# Check if we should create the tables # Check if we should create the tables
try: if not db.table_exists("meta"):
metadb.check_tables() execute_sql_resource_script(f"schema/{provider}.sql")
except DatabaseError: Meta.create(key="schema_version", value=SCHEMA_VERSION)
with db_session:
execute_sql_resource_script("schema/" + settings["provider"] + ".sql")
Meta(key="schema_version", value=SCHEMA_VERSION)
# Check for schema changes # Check for schema changes
with db_session:
version = Meta["schema_version"] version = Meta["schema_version"]
if version.value < SCHEMA_VERSION: if version.value < SCHEMA_VERSION:
migrations = sorted( migrations = sorted(
pkg_resources.resource_listdir( pkg_resources.resource_listdir(__package__, f"schema/migration/{provider}")
__package__, "schema/migration/" + settings["provider"]
)
) )
for migration in migrations: for migration in migrations:
date, ext = os.path.splitext(migration) date, ext = os.path.splitext(migration)
if date <= version.value: if date <= version.value:
continue continue
if ext == ".sql": if ext == ".sql":
execute_sql_resource_script( execute_sql_resource_script(f"schema/migration/{provider}/{migration}")
"schema/migration/{}/{}".format(settings["provider"], migration)
)
elif ext == ".py": elif ext == ".py":
m = importlib.import_module( m = importlib.import_module(
".schema.migration.{}.{}".format(settings["provider"], date), f".schema.migration.{provider}.{date}", __package__
__package__,
) )
m.apply(settings.copy()) m.apply(args.copy())
version.value = SCHEMA_VERSION version.value = SCHEMA_VERSION
version.save()
# Hack for in-memory SQLite databases (used in tests), otherwise 'db' and 'metadb' would be two distinct databases
# and 'db' wouldn't have any table
if settings["provider"] == "sqlite" and settings["filename"] == ":memory:":
db.provider = metadb.provider
else:
metadb.disconnect()
db.bind(**settings)
# Force requests to Meta to use the same connection as other tables
metadb.provider = db.provider
db.generate_mapping(check_tables=False)
def release_database(): def release_database():
metadb.disconnect() db.close()
db.disconnect() db.initialize(None)
db.provider = metadb.provider = None
db.schema = metadb.schema = None

View File

@ -1,7 +1,7 @@
# This file is part of Supysonic. # This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API. # Supysonic is a Python implementation of the Subsonic server API.
# #
# Copyright (C) 2017-2018 Alban 'spl0k' Féron # Copyright (C) 2017-2022 Alban 'spl0k' Féron
# #
# Distributed under terms of the GNU AGPLv3 license. # Distributed under terms of the GNU AGPLv3 license.
@ -10,7 +10,6 @@ import unittest
import uuid import uuid
from collections import namedtuple from collections import namedtuple
from pony.orm import db_session
from supysonic import db from supysonic import db
@ -30,9 +29,9 @@ class DbTestCase(unittest.TestCase):
db.release_database() db.release_database()
def create_some_folders(self): def create_some_folders(self):
root_folder = db.Folder(root=True, name="Root folder", path="tests") root_folder = db.Folder.create(root=True, name="Root folder", path="tests")
db.Folder( f1 = db.Folder.create(
root=False, root=False,
name="Child folder", name="Child folder",
path="tests/assets", path="tests/assets",
@ -40,30 +39,25 @@ class DbTestCase(unittest.TestCase):
parent=root_folder, parent=root_folder,
) )
db.Folder( f2 = db.Folder.create(
root=False, root=False,
name="Child folder (No Art)", name="Child folder (No Art)",
path="tests/formats", path="tests/formats",
parent=root_folder, parent=root_folder,
) )
# Folder IDs don't get populated until we query the db. return root_folder, f1, f2
return (
db.Folder.get(name="Root folder"),
db.Folder.get(name="Child folder"),
db.Folder.get(name="Child Folder (No Art)"),
)
def create_some_tracks(self, artist=None, album=None): def create_some_tracks(self, artist=None, album=None):
root, child, child_2 = self.create_some_folders() root, child, child_2 = self.create_some_folders()
if not artist: if not artist:
artist = db.Artist(name="Test artist") artist = db.Artist.create(name="Test artist")
if not album: if not album:
album = db.Album(artist=artist, name="Test Album") album = db.Album.create(artist=artist, name="Test Album")
track1 = db.Track( track1 = db.Track.create(
title="Track Title", title="Track Title",
album=album, album=album,
artist=artist, artist=artist,
@ -78,7 +72,7 @@ class DbTestCase(unittest.TestCase):
folder=child, folder=child,
) )
track2 = db.Track( track2 = db.Track.create(
title="One Awesome Song", title="One Awesome Song",
album=album, album=album,
artist=artist, artist=artist,
@ -95,9 +89,9 @@ class DbTestCase(unittest.TestCase):
return track1, track2 return track1, track2
def create_track_in(self, folder, root, artist=None, album=None, has_art=True): def create_track_in(self, folder, root, artist=None, album=None, has_art=True):
artist = artist or db.Artist(name="Snazzy Artist") artist = artist or db.Artist.create(name="Snazzy Artist")
album = album or db.Album(artist=artist, name="Rockin' Album") album = album or db.Album.create(artist=artist, name="Rockin' Album")
return db.Track( return db.Track.create(
title="Nifty Number", title="Nifty Number",
album=album, album=album,
artist=artist, artist=artist,
@ -113,15 +107,13 @@ class DbTestCase(unittest.TestCase):
) )
def create_user(self, name="Test User"): def create_user(self, name="Test User"):
return db.User(name=name, password="secret", salt="ABC+") return db.User.create(name=name, password="secret", salt="ABC+")
def create_playlist(self): def create_playlist(self):
playlist = db.Playlist.create(user=self.create_user(), name="Playlist!")
playlist = db.Playlist(user=self.create_user(), name="Playlist!")
return playlist return playlist
@db_session
def test_folder_base(self): def test_folder_base(self):
root_folder, child_folder, child_noart = self.create_some_folders() root_folder, child_folder, child_noart = self.create_some_folders()
track_embededart = self.create_track_in(child_noart, root_folder) track_embededart = self.create_track_in(child_noart, root_folder)
@ -153,15 +145,14 @@ class DbTestCase(unittest.TestCase):
self.assertIn("coverArt", noart) self.assertIn("coverArt", noart)
self.assertEqual(noart["coverArt"], str(track_embededart.id)) self.assertEqual(noart["coverArt"], str(track_embededart.id))
@db_session
def test_folder_annotation(self): def test_folder_annotation(self):
root_folder, child_folder, _ = self.create_some_folders() root_folder, child_folder, _ = self.create_some_folders()
user = self.create_user() user = self.create_user()
db.StarredFolder(user=user, starred=root_folder) db.StarredFolder.create(user=user, starred=root_folder)
db.RatingFolder(user=user, rated=root_folder, rating=2) db.RatingFolder.create(user=user, rated=root_folder, rating=2)
other = self.create_user("Other") other = self.create_user("Other")
db.RatingFolder(user=other, rated=root_folder, rating=5) db.RatingFolder.create(user=other, rated=root_folder, rating=5)
root = root_folder.as_subsonic_child(user) root = root_folder.as_subsonic_child(user)
self.assertIn("starred", root) self.assertIn("starred", root)
@ -175,12 +166,11 @@ class DbTestCase(unittest.TestCase):
self.assertNotIn("starred", child) self.assertNotIn("starred", child)
self.assertNotIn("userRating", child) self.assertNotIn("userRating", child)
@db_session
def test_artist(self): def test_artist(self):
artist = db.Artist(name="Test Artist") artist = db.Artist.create(name="Test Artist")
user = self.create_user() user = self.create_user()
db.StarredArtist(user=user, starred=artist) db.StarredArtist.create(user=user, starred=artist)
artist_dict = artist.as_subsonic_artist(user) artist_dict = artist.as_subsonic_artist(user)
self.assertIsInstance(artist_dict, dict) self.assertIsInstance(artist_dict, dict)
@ -192,22 +182,18 @@ class DbTestCase(unittest.TestCase):
self.assertEqual(artist_dict["albumCount"], 0) self.assertEqual(artist_dict["albumCount"], 0)
self.assertRegex(artist_dict["starred"], date_regex) self.assertRegex(artist_dict["starred"], date_regex)
db.Album(name="Test Artist", artist=artist) # self-titled db.Album.create(name="Test Artist", artist=artist) # self-titled
db.Album(name="The Album After The First One", artist=artist) db.Album.create(name="The Album After The First One", artist=artist)
artist_dict = artist.as_subsonic_artist(user) artist_dict = artist.as_subsonic_artist(user)
self.assertEqual(artist_dict["albumCount"], 2) self.assertEqual(artist_dict["albumCount"], 2)
@db_session
def test_album(self): def test_album(self):
artist = db.Artist(name="Test Artist") artist = db.Artist.create(name="Test Artist")
album = db.Album(artist=artist, name="Test Album") album = db.Album.create(artist=artist, name="Test Album")
user = self.create_user() user = self.create_user()
db.StarredAlbum(user=user, starred=album) db.StarredAlbum.create(user=user, starred=album)
# No tracks, shouldn't be stored under normal circumstances
self.assertRaises(ValueError, album.as_subsonic_album, user)
root_folder, folder_art, folder_noart = self.create_some_folders() root_folder, folder_art, folder_noart = self.create_some_folders()
track1 = self.create_track_in( track1 = self.create_track_in(
@ -234,7 +220,6 @@ class DbTestCase(unittest.TestCase):
self.assertRegex(album_dict["created"], date_regex) self.assertRegex(album_dict["created"], date_regex)
self.assertRegex(album_dict["starred"], date_regex) self.assertRegex(album_dict["starred"], date_regex)
@db_session
def test_track(self): def test_track(self):
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
@ -256,14 +241,12 @@ class DbTestCase(unittest.TestCase):
self.assertEqual(track2_dict["coverArt"], track2_dict["parent"]) self.assertEqual(track2_dict["coverArt"], track2_dict["parent"])
# ... we'll test the rest against the API XSD. # ... we'll test the rest against the API XSD.
@db_session
def test_user(self): def test_user(self):
user = self.create_user() user = self.create_user()
user_dict = user.as_subsonic_user() user_dict = user.as_subsonic_user()
self.assertIsInstance(user_dict, dict) self.assertIsInstance(user_dict, dict)
@db_session
def test_chat(self): def test_chat(self):
user = self.create_user() user = self.create_user()
@ -274,13 +257,11 @@ class DbTestCase(unittest.TestCase):
self.assertIn("username", line_dict) self.assertIn("username", line_dict)
self.assertEqual(line_dict["username"], user.name) self.assertEqual(line_dict["username"], user.name)
@db_session
def test_playlist(self): def test_playlist(self):
playlist = self.create_playlist() playlist = self.create_playlist()
playlist_dict = playlist.as_subsonic_playlist(playlist.user) playlist_dict = playlist.as_subsonic_playlist(playlist.user)
self.assertIsInstance(playlist_dict, dict) self.assertIsInstance(playlist_dict, dict)
@db_session
def test_playlist_tracks(self): def test_playlist_tracks(self):
playlist = self.create_playlist() playlist = self.create_playlist()
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
@ -304,7 +285,6 @@ class DbTestCase(unittest.TestCase):
self.assertRaises(ValueError, playlist.add, "some string") self.assertRaises(ValueError, playlist.add, "some string")
self.assertRaises(NameError, playlist.add, 2345) self.assertRaises(NameError, playlist.add, 2345)
@db_session
def test_playlist_remove_tracks(self): def test_playlist_remove_tracks(self):
playlist = self.create_playlist() playlist = self.create_playlist()
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
@ -324,7 +304,6 @@ class DbTestCase(unittest.TestCase):
playlist.remove_at_indexes([1, 1]) playlist.remove_at_indexes([1, 1])
self.assertSequenceEqual(playlist.get_tracks(), [track2, track1]) self.assertSequenceEqual(playlist.get_tracks(), [track2, track1])
@db_session
def test_playlist_fixing(self): def test_playlist_fixing(self):
playlist = self.create_playlist() playlist = self.create_playlist()
track1, track2 = self.create_some_tracks() track1, track2 = self.create_some_tracks()
@ -334,7 +313,7 @@ class DbTestCase(unittest.TestCase):
playlist.add(track2) playlist.add(track2)
self.assertSequenceEqual(playlist.get_tracks(), [track1, track2]) self.assertSequenceEqual(playlist.get_tracks(), [track1, track2])
track2.delete() track2.delete_instance()
self.assertSequenceEqual(playlist.get_tracks(), [track1]) self.assertSequenceEqual(playlist.get_tracks(), [track1])
playlist.tracks = "{0},{0},some random garbage,{0}".format(track1.id) playlist.tracks = "{0},{0},some random garbage,{0}".format(track1.id)