1
0
mirror of https://github.com/spl0k/supysonic.git synced 2024-11-09 19:52:16 +00:00

Refactoring API error handling

This commit is contained in:
spl0k 2018-02-25 11:39:26 +01:00
parent 86892f375d
commit 177b0cce0d
11 changed files with 189 additions and 169 deletions

View File

@ -30,6 +30,7 @@ from pony.orm import ObjectNotFound
from ..managers.user import UserManager
from ..py23 import dict
from .exceptions import Unauthorized
from .formatters import JSONFormatter, JSONPFormatter, XMLFormatter
api = Blueprint('api', __name__)
@ -56,33 +57,28 @@ def decode_password(password):
@api.before_request
def authorize():
error = request.formatter.error(40, 'Unauthorized'), 401
if request.authorization:
status, user = UserManager.try_auth(request.authorization.username, request.authorization.password)
if status == UserManager.SUCCESS:
request.username = request.authorization.username
request.user = user
return
raise Unauthorized()
(username, password) = map(request.values.get, [ 'u', 'p' ])
if not username or not password:
return error
username = request.values['u']
password = request.values['p']
password = decode_password(password)
status, user = UserManager.try_auth(username, password)
if status != UserManager.SUCCESS:
return error
raise Unauthorized()
request.username = username
request.user = user
@api.before_request
def get_client_prefs():
if 'c' not in request.values:
return request.formatter.error(10, 'Missing required parameter')
client = request.values.get('c')
client = request.values['c']
try:
ClientPrefs[request.user, client]
except ObjectNotFound:
@ -90,24 +86,13 @@ def get_client_prefs():
request.client = client
#@api.errorhandler(404)
@api.route('/<path:invalid>', methods = [ 'GET', 'POST' ]) # blueprint 404 workaround
def not_found(*args, **kwargs):
return request.formatter.error(0, 'Not implemented'), 501
def get_entity(cls, param = 'id'):
eid = request.values.get(param)
if not eid:
return False, request.formatter.error(10, 'Missing %s id' % cls.__name__)
eid = request.values[param]
eid = uuid.UUID(eid)
entity = cls[eid]
return entity
try:
eid = uuid.UUID(eid)
entity = cls[eid]
return True, entity
except ValueError:
return False, request.formatter.error(0, 'Invalid %s id' % cls.__name__)
except ObjectNotFound:
return False, (request.formatter.error(70, '%s not found' % cls.__name__), 404)
from .errors import *
from .system import *
from .browse import *

View File

@ -35,13 +35,10 @@ def rand_songs():
size = request.values.get('size', '10')
genre, fromYear, toYear, musicFolderId = map(request.values.get, [ 'genre', 'fromYear', 'toYear', 'musicFolderId' ])
try:
size = int(size) if size else 10
fromYear = int(fromYear) if fromYear else None
toYear = int(toYear) if toYear else None
fid = uuid.UUID(musicFolderId) if musicFolderId else None
except ValueError:
return request.formatter.error(0, 'Invalid parameter format')
size = int(size) if size else 10
fromYear = int(fromYear) if fromYear else None
toYear = int(toYear) if toYear else None
fid = uuid.UUID(musicFolderId) if musicFolderId else None
query = Track.select()
if fromYear:
@ -65,11 +62,9 @@ def album_list():
ltype, size, offset = map(request.values.get, [ 'type', 'size', 'offset' ])
if not ltype:
return request.formatter.error(10, 'Missing type')
try:
size = int(size) if size else 10
offset = int(offset) if offset else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter format')
size = int(size) if size else 10
offset = int(offset) if offset else 0
query = select(t.folder for t in Track)
if ltype == 'random':
@ -102,11 +97,9 @@ def album_list_id3():
ltype, size, offset = map(request.values.get, [ 'type', 'size', 'offset' ])
if not ltype:
return request.formatter.error(10, 'Missing type')
try:
size = int(size) if size else 10
offset = int(offset) if offset else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter format')
size = int(size) if size else 10
offset = int(offset) if offset else 0
query = Album.select()
if ltype == 'random':

View File

@ -175,10 +175,7 @@ def rate():
@api.route('/scrobble.view', methods = [ 'GET', 'POST' ])
def scrobble():
status, res = get_entity(Track)
if not status:
return res
res = get_entity(Track)
t, submission = map(request.values.get, [ 'time', 'submission' ])
if t:

View File

@ -43,23 +43,16 @@ def list_indexes():
musicFolderId = request.values.get('musicFolderId')
ifModifiedSince = request.values.get('ifModifiedSince')
if ifModifiedSince:
try:
ifModifiedSince = int(ifModifiedSince) / 1000
except ValueError:
return request.formatter.error(0, 'Invalid timestamp')
ifModifiedSince = int(ifModifiedSince) / 1000
if musicFolderId is None:
folders = Folder.select(lambda f: f.root)[:]
else:
try:
mfid = uuid.UUID(musicFolderId)
folder = Folder[mfid]
except ValueError:
return request.formatter.error(0, 'Invalid id')
except ObjectNotFound:
return request.formatter.error(70, 'Folder not found')
mfid = uuid.UUID(musicFolderId)
folder = Folder[mfid]
if not folder.root:
return request.formatter.error(70, 'Folder not found')
raise ObjectNotFound(Folder, mfid)
folders = [ folder ]
last_modif = max(map(lambda f: f.last_scan, folders))
@ -100,10 +93,7 @@ def list_indexes():
@api.route('/getMusicDirectory.view', methods = [ 'GET', 'POST' ])
def show_directory():
status, res = get_entity(Folder)
if not status:
return res
res = get_entity(Folder)
directory = dict(
id = str(res.id),
name = res.name,
@ -139,10 +129,7 @@ def list_artists():
@api.route('/getArtist.view', methods = [ 'GET', 'POST' ])
def artist_info():
status, res = get_entity(Artist)
if not status:
return res
res = get_entity(Artist)
info = res.as_subsonic_artist(request.user)
albums = set(res.albums)
albums |= { t.album for t in res.tracks }
@ -152,10 +139,7 @@ def artist_info():
@api.route('/getAlbum.view', methods = [ 'GET', 'POST' ])
def album_info():
status, res = get_entity(Album)
if not status:
return res
res = get_entity(Album)
info = res.as_subsonic_album(request.user)
info['song'] = [ t.as_subsonic_child(request.user, request.client) for t in sorted(res.tracks, key = lambda t: t.sort_key()) ]
@ -163,10 +147,7 @@ def album_info():
@api.route('/getSong.view', methods = [ 'GET', 'POST' ])
def track_info():
status, res = get_entity(Track)
if not status:
return res
res = get_entity(Track)
return request.formatter('song', res.as_subsonic_child(request.user, request.client))
@api.route('/getVideos.view', methods = [ 'GET', 'POST' ])

View File

@ -27,10 +27,7 @@ from . import api
@api.route('/getChatMessages.view', methods = [ 'GET', 'POST' ])
def get_chat():
since = request.values.get('since')
try:
since = int(since) / 1000 if since else None
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
since = int(since) / 1000 if since else None
query = ChatMessage.select().order_by(ChatMessage.time)
if since:

45
supysonic/api/errors.py Normal file
View File

@ -0,0 +1,45 @@
# coding: utf-8
#
# This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API.
#
# Copyright (C) 2018 Alban 'spl0k' Féron
#
# Distributed under terms of the GNU AGPLv3 license.
from flask import current_app
from pony.orm import rollback
from pony.orm import ObjectNotFound
from werkzeug.exceptions import BadRequestKeyError
from . import api
from .exceptions import GenericError, MissingParameter, NotFound
@api.errorhandler(ValueError)
def value_error(e):
rollback()
return GenericError(e)
@api.errorhandler(BadRequestKeyError)
def key_error(e):
rollback()
return MissingParameter(e.args[0])
@api.errorhandler(ObjectNotFound)
def not_found(e):
rollback()
return NotFound(e.entity.__name__)
@api.errorhandler(Exception)
def generic_error(e): # pragma: nocover
rollback()
if not current_app.testing:
return GenericError(e), 500
else:
raise e
#@api.errorhandler(404)
@api.route('/<path:invalid>', methods = [ 'GET', 'POST' ]) # blueprint 404 workaround
def not_found(*args, **kwargs):
return GenericError('Not implemented'), 501

View File

@ -0,0 +1,74 @@
# coding: utf-8
#
# This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API.
#
# Copyright (C) 2018 Alban 'spl0k' Féron
#
# Distributed under terms of the GNU AGPLv3 license.
from flask import request
from werkzeug.exceptions import HTTPException
class SubsonicAPIError(HTTPException):
code = 400
api_code = None
message = None
def get_response(self, environ = None):
rv = request.formatter.error(self.api_code, self.message)
rv.status_code = self.code
return rv
def __str__(self):
code = self.api_code if self.api_code is not None else '??'
return '{}: {}'.format(code, self.message)
class GenericError(SubsonicAPIError):
api_code = 0
def __init__(self, message, *args, **kwargs):
super(GenericError, self).__init__(*args, **kwargs)
self.message = message
class MissingParameter(SubsonicAPIError):
api_code = 10
def __init__(self, param, *args, **kwargs):
super(MissingParameter, self).__init__(*args, **kwargs)
self.message = "Required parameter '{}' is missing.".format(param)
class ClientMustUpgrade(SubsonicAPIError):
api_code = 20
message = 'Incompatible Subsonic REST protocol version. Client must upgrade.'
class ServerMustUpgrade(SubsonicAPIError):
code = 501
api_code = 30
message = 'Incompatible Subsonic REST protocol version. Server must upgrade.'
class Unauthorized(SubsonicAPIError):
code = 401
api_code = 40
message = 'Wrong username or password.'
class Forbidden(SubsonicAPIError):
code = 403
api_code = 50
message = 'User is not authorized for the given operation.'
class TrialExpired(SubsonicAPIError):
code = 402
api_code = 60
message = ("The trial period for the Supysonic server is over."
"But since it doesn't use any licensing you shouldn't be seeing this error ever."
"So something went wrong or you got scammed.")
class NotFound(SubsonicAPIError):
code = 404
api_code = 70
def __init__(self, entity, *args, **kwargs):
super(NotFound, self).__init__(*args, **kwargs)
self.message = '{} not found'.format(entity)

View File

@ -47,9 +47,7 @@ def prepare_transcoding_cmdline(base_cmdline, input_file, input_format, output_f
@api.route('/stream.view', methods = [ 'GET', 'POST' ])
def stream_media():
status, res = get_entity(Track)
if not status:
return res
res = get_entity(Track)
maxBitRate, format, timeOffset, size, estimateContentLength = map(request.values.get, [ 'maxBitRate', 'format', 'timeOffset', 'size', 'estimateContentLength' ])
if format:
@ -67,10 +65,7 @@ def stream_media():
dst_bitrate = prefs.bitrate
if maxBitRate:
try:
maxBitRate = int(maxBitRate)
except ValueError:
return request.formatter.error(0, 'Invalid bitrate value')
maxBitRate = int(maxBitRate)
if dst_bitrate > maxBitRate and maxBitRate != 0:
dst_bitrate = maxBitRate
@ -133,27 +128,18 @@ def stream_media():
@api.route('/download.view', methods = [ 'GET', 'POST' ])
def download_media():
status, res = get_entity(Track)
if not status:
return res
res = get_entity(Track)
return send_file(res.path, mimetype = res.content_type, conditional=True)
@api.route('/getCoverArt.view', methods = [ 'GET', 'POST' ])
def cover_art():
status, res = get_entity(Folder)
if not status:
return res
res = get_entity(Folder)
if not res.has_cover_art or not os.path.isfile(os.path.join(res.path, 'cover.jpg')):
return request.formatter.error(70, 'Cover art not found')
size = request.values.get('size')
if size:
try:
size = int(size)
except ValueError:
return request.formatter.error(0, 'Invalid size value')
size = int(size)
else:
return send_file(os.path.join(res.path, 'cover.jpg'))

View File

@ -21,8 +21,6 @@
import uuid
from flask import request
from pony.orm import rollback
from pony.orm import ObjectNotFound
from ..db import Playlist, User, Track
from ..py23 import dict
@ -48,10 +46,7 @@ def list_playlists():
@api.route('/getPlaylist.view', methods = [ 'GET', 'POST' ])
def show_playlist():
status, res = get_entity(Playlist)
if not status:
return res
res = get_entity(Playlist)
if res.user.id != request.user.id and not request.user.admin:
return request.formatter.error('50', 'Private playlist')
@ -64,16 +59,10 @@ def create_playlist():
playlist_id, name = map(request.values.get, [ 'playlistId', 'name' ])
# songId actually doesn't seem to be required
songs = request.values.getlist('songId')
try:
playlist_id = uuid.UUID(playlist_id) if playlist_id else None
except ValueError:
return request.formatter.error(0, 'Invalid playlist id')
playlist_id = uuid.UUID(playlist_id) if playlist_id else None
if playlist_id:
try:
playlist = Playlist[playlist_id]
except ObjectNotFound:
return request.formatter.error(70, 'Unknwon playlist')
playlist = Playlist[playlist_id]
if playlist.user.id != request.user.id and not request.user.admin:
return request.formatter.error(50, "You're not allowed to modify a playlist that isn't yours")
@ -86,26 +75,16 @@ def create_playlist():
else:
return request.formatter.error(10, 'Missing playlist id or name')
try:
songs = map(uuid.UUID, songs)
for sid in songs:
track = Track[sid]
playlist.add(track)
except ValueError:
rollback()
return request.formatter.error(0, 'Invalid song id')
except ObjectNotFound:
rollback()
return request.formatter.error(70, 'Unknown song')
songs = map(uuid.UUID, songs)
for sid in songs:
track = Track[sid]
playlist.add(track)
return request.formatter.empty
@api.route('/deletePlaylist.view', methods = [ 'GET', 'POST' ])
def delete_playlist():
status, res = get_entity(Playlist)
if not status:
return res
res = get_entity(Playlist)
if res.user.id != request.user.id and not request.user.admin:
return request.formatter.error(50, "You're not allowed to delete a playlist that isn't yours")
@ -114,10 +93,7 @@ def delete_playlist():
@api.route('/updatePlaylist.view', methods = [ 'GET', 'POST' ])
def update_playlist():
status, res = get_entity(Playlist, 'playlistId')
if not status:
return res
res = get_entity(Playlist, 'playlistId')
if res.user.id != request.user.id and not request.user.admin:
return request.formatter.error(50, "You're not allowed to delete a playlist that isn't yours")
@ -132,19 +108,14 @@ def update_playlist():
if public:
playlist.public = public in (True, 'True', 'true', 1, '1')
try:
to_add = map(uuid.UUID, to_add)
to_remove = map(int, to_remove)
to_add = map(uuid.UUID, to_add)
to_remove = map(int, to_remove)
for sid in to_add:
track = Track[sid]
playlist.add(track)
for sid in to_add:
track = Track[sid]
playlist.add(track)
playlist.remove_at_indexes(to_remove)
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
except ObjectNotFound:
return request.formatter.error(70, 'Unknown song')
playlist.remove_at_indexes(to_remove)
return request.formatter.empty

View File

@ -30,13 +30,10 @@ from . import api
@api.route('/search.view', methods = [ 'GET', 'POST' ])
def old_search():
artist, album, title, anyf, count, offset, newer_than = map(request.values.get, [ 'artist', 'album', 'title', 'any', 'count', 'offset', 'newerThan' ])
try:
count = int(count) if count else 20
offset = int(offset) if offset else 0
newer_than = int(newer_than) / 1000 if newer_than else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
count = int(count) if count else 20
offset = int(offset) if offset else 0
newer_than = int(newer_than) / 1000 if newer_than else 0
min_date = datetime.fromtimestamp(newer_than)
if artist:
@ -74,15 +71,12 @@ def new_search():
query, artist_count, artist_offset, album_count, album_offset, song_count, song_offset = map(
request.values.get, [ 'query', 'artistCount', 'artistOffset', 'albumCount', 'albumOffset', 'songCount', 'songOffset' ])
try:
artist_count = int(artist_count) if artist_count else 20
artist_offset = int(artist_offset) if artist_offset else 0
album_count = int(album_count) if album_count else 20
album_offset = int(album_offset) if album_offset else 0
song_count = int(song_count) if song_count else 20
song_offset = int(song_offset) if song_offset else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
artist_count = int(artist_count) if artist_count else 20
artist_offset = int(artist_offset) if artist_offset else 0
album_count = int(album_count) if album_count else 20
album_offset = int(album_offset) if album_offset else 0
song_count = int(song_count) if song_count else 20
song_offset = int(song_offset) if song_offset else 0
if not query:
return request.formatter.error(10, 'Missing query parameter')
@ -102,15 +96,12 @@ def search_id3():
query, artist_count, artist_offset, album_count, album_offset, song_count, song_offset = map(
request.values.get, [ 'query', 'artistCount', 'artistOffset', 'albumCount', 'albumOffset', 'songCount', 'songOffset' ])
try:
artist_count = int(artist_count) if artist_count else 20
artist_offset = int(artist_offset) if artist_offset else 0
album_count = int(album_count) if album_count else 20
album_offset = int(album_offset) if album_offset else 0
song_count = int(song_count) if song_count else 20
song_offset = int(song_offset) if song_offset else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
artist_count = int(artist_count) if artist_count else 20
artist_offset = int(artist_offset) if artist_offset else 0
album_count = int(album_count) if album_count else 20
album_offset = int(album_offset) if album_offset else 0
song_count = int(song_count) if song_count else 20
song_offset = int(song_offset) if song_offset else 0
if not query:
return request.formatter.error(10, 'Missing query parameter')

View File

@ -63,9 +63,9 @@ class ApiSetupTestCase(TestBase):
def test_auth_basic(self):
# No auth info
rv = self.client.get('/rest/ping.view?c=tests')
self.assertEqual(rv.status_code, 401)
self.assertEqual(rv.status_code, 400)
self.assertIn('status="failed"', rv.data)
self.assertIn('code="40"', rv.data)
self.assertIn('code="10"', rv.data)
self.__test_auth(self.__basic_auth_get)