diff --git a/supysonic/api/playlists.py b/supysonic/api/playlists.py index 6ae9079..1eb967b 100644 --- a/supysonic/api/playlists.py +++ b/supysonic/api/playlists.py @@ -9,30 +9,29 @@ import uuid from flask import request -from ..db import Playlist, User, Track +from ..db import Playlist, User, Track, db from . import get_entity, api_routing -from .exceptions import Forbidden, MissingParameter, NotFound +from .exceptions import Forbidden, MissingParameter @api_routing("/getPlaylists") def list_playlists(): - query = Playlist.select( - lambda p: p.user.id == request.user.id or p.public - ).order_by(Playlist.name) + query = ( + Playlist.select() + .orwhere(Playlist.user == request.user, Playlist.public) + .order_by(Playlist.name) + ) username = request.values.get("username") if username: if not request.user.admin: raise Forbidden() + # get rather than join in the following query to raise an exception if the + # requested user doesn't exist user = User.get(name=username) - if user is None: - raise NotFound("User") - - query = Playlist.select(lambda p: p.user.name == username).order_by( - Playlist.name - ) + query = Playlist.select().where(Playlist.user == user).order_by(Playlist.name) return request.formatter( "playlists", @@ -43,7 +42,7 @@ def list_playlists(): @api_routing("/getPlaylist") def show_playlist(): res = get_entity(Playlist) - if res.user.id != request.user.id and not res.public and not request.user.admin: + if res.user != request.user and not res.public and not request.user.admin: raise Forbidden() info = res.as_subsonic_playlist(request.user) @@ -54,6 +53,7 @@ def show_playlist(): @api_routing("/createPlaylist") +@db.atomic() def create_playlist(): playlist_id, name = map(request.values.get, ("playlistId", "name")) # songId actually doesn't seem to be required @@ -63,14 +63,14 @@ def create_playlist(): if playlist_id: playlist = Playlist[playlist_id] - if playlist.user.id != request.user.id and not request.user.admin: + if playlist.user != request.user and not request.user.admin: raise Forbidden() playlist.clear() if name: playlist.name = name elif name: - playlist = Playlist(user=request.user, name=name) + playlist = Playlist.create(user=request.user, name=name) else: raise MissingParameter("playlistId or name") @@ -78,6 +78,7 @@ def create_playlist(): sid = uuid.UUID(sid) track = Track[sid] playlist.add(track) + playlist.save() return request.formatter.empty @@ -85,17 +86,17 @@ def create_playlist(): @api_routing("/deletePlaylist") def delete_playlist(): res = get_entity(Playlist) - if res.user.id != request.user.id and not request.user.admin: + if res.user != request.user and not request.user.admin: raise Forbidden() - res.delete() + res.delete_instance() return request.formatter.empty @api_routing("/updatePlaylist") def update_playlist(): res = get_entity(Playlist, "playlistId") - if res.user.id != request.user.id and not request.user.admin: + if res.user != request.user and not request.user.admin: raise Forbidden() playlist = res @@ -119,5 +120,6 @@ def update_playlist(): playlist.add(track) playlist.remove_at_indexes(to_remove) + playlist.save() return request.formatter.empty diff --git a/tests/api/test_playlist.py b/tests/api/test_playlist.py index a7c2a71..d93d379 100644 --- a/tests/api/test_playlist.py +++ b/tests/api/test_playlist.py @@ -1,15 +1,13 @@ # This file is part of Supysonic. # Supysonic is a Python implementation of the Subsonic server API. # -# Copyright (C) 2017-2018 Alban 'spl0k' Féron +# Copyright (C) 2017-2022 Alban 'spl0k' Féron # # Distributed under terms of the GNU AGPLv3 license. import unittest import uuid -from pony.orm import db_session - from supysonic.db import Folder, Artist, Album, Track, Playlist, User from .apitestbase import ApiTestBase @@ -19,41 +17,45 @@ class PlaylistTestCase(ApiTestBase): def setUp(self): super().setUp() - with db_session: - root = Folder(root=True, name="Root folder", path="tests/assets") - artist = Artist(name="Artist") - album = Album(name="Album", artist=artist) + root = Folder.create(root=True, name="Root folder", path="tests/assets") + artist = Artist.create(name="Artist") + album = Album.create(name="Album", artist=artist) - songs = {} - for num, song in enumerate(["One", "Two", "Three", "Four"]): - track = Track( - disc=1, - number=num, - title=song, - duration=2, - album=album, - artist=artist, - bitrate=320, - path="tests/assets/" + song, - last_modification=0, - root_folder=root, - folder=root, - ) - songs[song] = track + songs = {} + for num, song in enumerate(["One", "Two", "Three", "Four"]): + track = Track.create( + disc=1, + number=num, + title=song, + duration=2, + album=album, + artist=artist, + bitrate=320, + path="tests/assets/" + song, + last_modification=0, + root_folder=root, + folder=root, + ) + songs[song] = track - users = {u.name: u for u in User.select()} + users = {u.name: u for u in User.select()} - playlist = Playlist(user=users["alice"], name="Alice's") - playlist.add(songs["One"]) - playlist.add(songs["Three"]) + playlist = Playlist.create(user=users["alice"], name="Alice's") + playlist.add(songs["One"]) + playlist.add(songs["Three"]) + playlist.save() - playlist = Playlist(user=users["alice"], public=True, name="Alice's public") - playlist.add(songs["One"]) - playlist.add(songs["Two"]) + playlist = Playlist.create( + user=users["alice"], public=True, name="Alice's public" + ) + playlist.add(songs["One"]) + playlist.add(songs["Two"]) + playlist.save() - playlist = Playlist(user=users["bob"], name="Bob's") - playlist.add(songs["Two"]) - playlist.add(songs["Four"]) + playlist = Playlist.create(user=users["bob"], name="Bob's") + playlist.add(songs["Two"]) + playlist.add(songs["Four"]) + playlist.save() def test_get_playlists(self): # get own playlists @@ -99,8 +101,12 @@ class PlaylistTestCase(ApiTestBase): self._make_request("getPlaylist", {"id": str(uuid.uuid4())}, error=70) # other's private from non admin - with db_session: - playlist = Playlist.get(lambda p: not p.public and p.user.name == "alice") + playlist = ( + Playlist.select() + .join(User) + .where(~Playlist.public, User.name == "alice") + .get() + ) self._make_request( "getPlaylist", {"u": "bob", "p": "B0b", "id": str(playlist.id)}, error=50 ) @@ -166,8 +172,7 @@ class PlaylistTestCase(ApiTestBase): ) # create more useful playlist - with db_session: - songs = {s.title: str(s.id) for s in Track.select()} + songs = {s.title: str(s.id) for s in Track.select()} self._make_request( "createPlaylist", { @@ -176,8 +181,7 @@ class PlaylistTestCase(ApiTestBase): }, skip_post=True, ) - with db_session: - playlist = Playlist.get(name="songs") + playlist = Playlist.get(name="songs") self.assertIsNotNone(playlist) rv, child = self._make_request( "getPlaylist", {"id": str(playlist.id)}, tag="playlist" @@ -201,7 +205,6 @@ class PlaylistTestCase(ApiTestBase): self.assertEqual(self._xpath(child, "count(./entry)"), 1) self.assertEqual(child[0].get("title"), "Two") - @db_session def assertPlaylistCountEqual(self, count): self.assertEqual(Playlist.select().count(), count) @@ -212,8 +215,7 @@ class PlaylistTestCase(ApiTestBase): self._make_request("deletePlaylist", {"id": str(uuid.uuid4())}, error=70) # delete unowned when not admin - with db_session: - playlist = Playlist.select(lambda p: p.user.name == "alice").first() + playlist = Playlist.select().join(User).where(User.name == "alice").first() self._make_request( "deletePlaylist", {"u": "bob", "p": "B0b", "id": str(playlist.id)}, error=50 ) @@ -226,8 +228,7 @@ class PlaylistTestCase(ApiTestBase): self.assertPlaylistCountEqual(2) # delete unowned when admin - with db_session: - playlist = Playlist.get(lambda p: p.user.name == "bob") + playlist = Playlist.select().join(User).where(User.name == "bob").get() self._make_request("deletePlaylist", {"id": str(playlist.id)}, skip_post=True) self.assertPlaylistCountEqual(1) @@ -238,12 +239,13 @@ class PlaylistTestCase(ApiTestBase): "updatePlaylist", {"playlistId": str(uuid.uuid4())}, error=70 ) - with db_session: - playlist = ( - Playlist.select(lambda p: p.user.name == "alice") - .order_by(Playlist.created) - .first() - ) + playlist = ( + Playlist.select() + .join(User) + .where(User.name == "alice") + .order_by(Playlist.created) + .first() + ) pid = str(playlist.id) self._make_request( "updatePlaylist", {"playlistId": pid, "songIdToAdd": "string"}, error=0 @@ -288,8 +290,7 @@ class PlaylistTestCase(ApiTestBase): self.assertEqual(self._xpath(child, "count(./entry)"), 1) self.assertEqual(self._find(child, "./entry").get("title"), "Three") - with db_session: - songs = {s.title: str(s.id) for s in Track.select()} + songs = {s.title: str(s.id) for s in Track.select()} self._make_request( "updatePlaylist",