1
0
mirror of https://github.com/spl0k/supysonic.git synced 2024-12-22 17:06:17 +00:00

Refactoring API error handling

This commit is contained in:
spl0k 2018-02-25 11:39:26 +01:00
parent 86892f375d
commit 177b0cce0d
11 changed files with 189 additions and 169 deletions

View File

@ -30,6 +30,7 @@ from pony.orm import ObjectNotFound
from ..managers.user import UserManager from ..managers.user import UserManager
from ..py23 import dict from ..py23 import dict
from .exceptions import Unauthorized
from .formatters import JSONFormatter, JSONPFormatter, XMLFormatter from .formatters import JSONFormatter, JSONPFormatter, XMLFormatter
api = Blueprint('api', __name__) api = Blueprint('api', __name__)
@ -56,33 +57,28 @@ def decode_password(password):
@api.before_request @api.before_request
def authorize(): def authorize():
error = request.formatter.error(40, 'Unauthorized'), 401
if request.authorization: if request.authorization:
status, user = UserManager.try_auth(request.authorization.username, request.authorization.password) status, user = UserManager.try_auth(request.authorization.username, request.authorization.password)
if status == UserManager.SUCCESS: if status == UserManager.SUCCESS:
request.username = request.authorization.username request.username = request.authorization.username
request.user = user request.user = user
return return
raise Unauthorized()
(username, password) = map(request.values.get, [ 'u', 'p' ]) username = request.values['u']
if not username or not password: password = request.values['p']
return error
password = decode_password(password) password = decode_password(password)
status, user = UserManager.try_auth(username, password) status, user = UserManager.try_auth(username, password)
if status != UserManager.SUCCESS: if status != UserManager.SUCCESS:
return error raise Unauthorized()
request.username = username request.username = username
request.user = user request.user = user
@api.before_request @api.before_request
def get_client_prefs(): def get_client_prefs():
if 'c' not in request.values: client = request.values['c']
return request.formatter.error(10, 'Missing required parameter')
client = request.values.get('c')
try: try:
ClientPrefs[request.user, client] ClientPrefs[request.user, client]
except ObjectNotFound: except ObjectNotFound:
@ -90,24 +86,13 @@ def get_client_prefs():
request.client = client request.client = client
#@api.errorhandler(404)
@api.route('/<path:invalid>', methods = [ 'GET', 'POST' ]) # blueprint 404 workaround
def not_found(*args, **kwargs):
return request.formatter.error(0, 'Not implemented'), 501
def get_entity(cls, param = 'id'): def get_entity(cls, param = 'id'):
eid = request.values.get(param) eid = request.values[param]
if not eid: eid = uuid.UUID(eid)
return False, request.formatter.error(10, 'Missing %s id' % cls.__name__) entity = cls[eid]
return entity
try: from .errors import *
eid = uuid.UUID(eid)
entity = cls[eid]
return True, entity
except ValueError:
return False, request.formatter.error(0, 'Invalid %s id' % cls.__name__)
except ObjectNotFound:
return False, (request.formatter.error(70, '%s not found' % cls.__name__), 404)
from .system import * from .system import *
from .browse import * from .browse import *

View File

@ -35,13 +35,10 @@ def rand_songs():
size = request.values.get('size', '10') size = request.values.get('size', '10')
genre, fromYear, toYear, musicFolderId = map(request.values.get, [ 'genre', 'fromYear', 'toYear', 'musicFolderId' ]) genre, fromYear, toYear, musicFolderId = map(request.values.get, [ 'genre', 'fromYear', 'toYear', 'musicFolderId' ])
try: size = int(size) if size else 10
size = int(size) if size else 10 fromYear = int(fromYear) if fromYear else None
fromYear = int(fromYear) if fromYear else None toYear = int(toYear) if toYear else None
toYear = int(toYear) if toYear else None fid = uuid.UUID(musicFolderId) if musicFolderId else None
fid = uuid.UUID(musicFolderId) if musicFolderId else None
except ValueError:
return request.formatter.error(0, 'Invalid parameter format')
query = Track.select() query = Track.select()
if fromYear: if fromYear:
@ -65,11 +62,9 @@ def album_list():
ltype, size, offset = map(request.values.get, [ 'type', 'size', 'offset' ]) ltype, size, offset = map(request.values.get, [ 'type', 'size', 'offset' ])
if not ltype: if not ltype:
return request.formatter.error(10, 'Missing type') return request.formatter.error(10, 'Missing type')
try:
size = int(size) if size else 10 size = int(size) if size else 10
offset = int(offset) if offset else 0 offset = int(offset) if offset else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter format')
query = select(t.folder for t in Track) query = select(t.folder for t in Track)
if ltype == 'random': if ltype == 'random':
@ -102,11 +97,9 @@ def album_list_id3():
ltype, size, offset = map(request.values.get, [ 'type', 'size', 'offset' ]) ltype, size, offset = map(request.values.get, [ 'type', 'size', 'offset' ])
if not ltype: if not ltype:
return request.formatter.error(10, 'Missing type') return request.formatter.error(10, 'Missing type')
try:
size = int(size) if size else 10 size = int(size) if size else 10
offset = int(offset) if offset else 0 offset = int(offset) if offset else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter format')
query = Album.select() query = Album.select()
if ltype == 'random': if ltype == 'random':

View File

@ -175,10 +175,7 @@ def rate():
@api.route('/scrobble.view', methods = [ 'GET', 'POST' ]) @api.route('/scrobble.view', methods = [ 'GET', 'POST' ])
def scrobble(): def scrobble():
status, res = get_entity(Track) res = get_entity(Track)
if not status:
return res
t, submission = map(request.values.get, [ 'time', 'submission' ]) t, submission = map(request.values.get, [ 'time', 'submission' ])
if t: if t:

View File

@ -43,23 +43,16 @@ def list_indexes():
musicFolderId = request.values.get('musicFolderId') musicFolderId = request.values.get('musicFolderId')
ifModifiedSince = request.values.get('ifModifiedSince') ifModifiedSince = request.values.get('ifModifiedSince')
if ifModifiedSince: if ifModifiedSince:
try: ifModifiedSince = int(ifModifiedSince) / 1000
ifModifiedSince = int(ifModifiedSince) / 1000
except ValueError:
return request.formatter.error(0, 'Invalid timestamp')
if musicFolderId is None: if musicFolderId is None:
folders = Folder.select(lambda f: f.root)[:] folders = Folder.select(lambda f: f.root)[:]
else: else:
try: mfid = uuid.UUID(musicFolderId)
mfid = uuid.UUID(musicFolderId) folder = Folder[mfid]
folder = Folder[mfid]
except ValueError:
return request.formatter.error(0, 'Invalid id')
except ObjectNotFound:
return request.formatter.error(70, 'Folder not found')
if not folder.root: if not folder.root:
return request.formatter.error(70, 'Folder not found') raise ObjectNotFound(Folder, mfid)
folders = [ folder ] folders = [ folder ]
last_modif = max(map(lambda f: f.last_scan, folders)) last_modif = max(map(lambda f: f.last_scan, folders))
@ -100,10 +93,7 @@ def list_indexes():
@api.route('/getMusicDirectory.view', methods = [ 'GET', 'POST' ]) @api.route('/getMusicDirectory.view', methods = [ 'GET', 'POST' ])
def show_directory(): def show_directory():
status, res = get_entity(Folder) res = get_entity(Folder)
if not status:
return res
directory = dict( directory = dict(
id = str(res.id), id = str(res.id),
name = res.name, name = res.name,
@ -139,10 +129,7 @@ def list_artists():
@api.route('/getArtist.view', methods = [ 'GET', 'POST' ]) @api.route('/getArtist.view', methods = [ 'GET', 'POST' ])
def artist_info(): def artist_info():
status, res = get_entity(Artist) res = get_entity(Artist)
if not status:
return res
info = res.as_subsonic_artist(request.user) info = res.as_subsonic_artist(request.user)
albums = set(res.albums) albums = set(res.albums)
albums |= { t.album for t in res.tracks } albums |= { t.album for t in res.tracks }
@ -152,10 +139,7 @@ def artist_info():
@api.route('/getAlbum.view', methods = [ 'GET', 'POST' ]) @api.route('/getAlbum.view', methods = [ 'GET', 'POST' ])
def album_info(): def album_info():
status, res = get_entity(Album) res = get_entity(Album)
if not status:
return res
info = res.as_subsonic_album(request.user) info = res.as_subsonic_album(request.user)
info['song'] = [ t.as_subsonic_child(request.user, request.client) for t in sorted(res.tracks, key = lambda t: t.sort_key()) ] info['song'] = [ t.as_subsonic_child(request.user, request.client) for t in sorted(res.tracks, key = lambda t: t.sort_key()) ]
@ -163,10 +147,7 @@ def album_info():
@api.route('/getSong.view', methods = [ 'GET', 'POST' ]) @api.route('/getSong.view', methods = [ 'GET', 'POST' ])
def track_info(): def track_info():
status, res = get_entity(Track) res = get_entity(Track)
if not status:
return res
return request.formatter('song', res.as_subsonic_child(request.user, request.client)) return request.formatter('song', res.as_subsonic_child(request.user, request.client))
@api.route('/getVideos.view', methods = [ 'GET', 'POST' ]) @api.route('/getVideos.view', methods = [ 'GET', 'POST' ])

View File

@ -27,10 +27,7 @@ from . import api
@api.route('/getChatMessages.view', methods = [ 'GET', 'POST' ]) @api.route('/getChatMessages.view', methods = [ 'GET', 'POST' ])
def get_chat(): def get_chat():
since = request.values.get('since') since = request.values.get('since')
try: since = int(since) / 1000 if since else None
since = int(since) / 1000 if since else None
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
query = ChatMessage.select().order_by(ChatMessage.time) query = ChatMessage.select().order_by(ChatMessage.time)
if since: if since:

45
supysonic/api/errors.py Normal file
View File

@ -0,0 +1,45 @@
# coding: utf-8
#
# This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API.
#
# Copyright (C) 2018 Alban 'spl0k' Féron
#
# Distributed under terms of the GNU AGPLv3 license.
from flask import current_app
from pony.orm import rollback
from pony.orm import ObjectNotFound
from werkzeug.exceptions import BadRequestKeyError
from . import api
from .exceptions import GenericError, MissingParameter, NotFound
@api.errorhandler(ValueError)
def value_error(e):
rollback()
return GenericError(e)
@api.errorhandler(BadRequestKeyError)
def key_error(e):
rollback()
return MissingParameter(e.args[0])
@api.errorhandler(ObjectNotFound)
def not_found(e):
rollback()
return NotFound(e.entity.__name__)
@api.errorhandler(Exception)
def generic_error(e): # pragma: nocover
rollback()
if not current_app.testing:
return GenericError(e), 500
else:
raise e
#@api.errorhandler(404)
@api.route('/<path:invalid>', methods = [ 'GET', 'POST' ]) # blueprint 404 workaround
def not_found(*args, **kwargs):
return GenericError('Not implemented'), 501

View File

@ -0,0 +1,74 @@
# coding: utf-8
#
# This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API.
#
# Copyright (C) 2018 Alban 'spl0k' Féron
#
# Distributed under terms of the GNU AGPLv3 license.
from flask import request
from werkzeug.exceptions import HTTPException
class SubsonicAPIError(HTTPException):
code = 400
api_code = None
message = None
def get_response(self, environ = None):
rv = request.formatter.error(self.api_code, self.message)
rv.status_code = self.code
return rv
def __str__(self):
code = self.api_code if self.api_code is not None else '??'
return '{}: {}'.format(code, self.message)
class GenericError(SubsonicAPIError):
api_code = 0
def __init__(self, message, *args, **kwargs):
super(GenericError, self).__init__(*args, **kwargs)
self.message = message
class MissingParameter(SubsonicAPIError):
api_code = 10
def __init__(self, param, *args, **kwargs):
super(MissingParameter, self).__init__(*args, **kwargs)
self.message = "Required parameter '{}' is missing.".format(param)
class ClientMustUpgrade(SubsonicAPIError):
api_code = 20
message = 'Incompatible Subsonic REST protocol version. Client must upgrade.'
class ServerMustUpgrade(SubsonicAPIError):
code = 501
api_code = 30
message = 'Incompatible Subsonic REST protocol version. Server must upgrade.'
class Unauthorized(SubsonicAPIError):
code = 401
api_code = 40
message = 'Wrong username or password.'
class Forbidden(SubsonicAPIError):
code = 403
api_code = 50
message = 'User is not authorized for the given operation.'
class TrialExpired(SubsonicAPIError):
code = 402
api_code = 60
message = ("The trial period for the Supysonic server is over."
"But since it doesn't use any licensing you shouldn't be seeing this error ever."
"So something went wrong or you got scammed.")
class NotFound(SubsonicAPIError):
code = 404
api_code = 70
def __init__(self, entity, *args, **kwargs):
super(NotFound, self).__init__(*args, **kwargs)
self.message = '{} not found'.format(entity)

View File

@ -47,9 +47,7 @@ def prepare_transcoding_cmdline(base_cmdline, input_file, input_format, output_f
@api.route('/stream.view', methods = [ 'GET', 'POST' ]) @api.route('/stream.view', methods = [ 'GET', 'POST' ])
def stream_media(): def stream_media():
status, res = get_entity(Track) res = get_entity(Track)
if not status:
return res
maxBitRate, format, timeOffset, size, estimateContentLength = map(request.values.get, [ 'maxBitRate', 'format', 'timeOffset', 'size', 'estimateContentLength' ]) maxBitRate, format, timeOffset, size, estimateContentLength = map(request.values.get, [ 'maxBitRate', 'format', 'timeOffset', 'size', 'estimateContentLength' ])
if format: if format:
@ -67,10 +65,7 @@ def stream_media():
dst_bitrate = prefs.bitrate dst_bitrate = prefs.bitrate
if maxBitRate: if maxBitRate:
try: maxBitRate = int(maxBitRate)
maxBitRate = int(maxBitRate)
except ValueError:
return request.formatter.error(0, 'Invalid bitrate value')
if dst_bitrate > maxBitRate and maxBitRate != 0: if dst_bitrate > maxBitRate and maxBitRate != 0:
dst_bitrate = maxBitRate dst_bitrate = maxBitRate
@ -133,27 +128,18 @@ def stream_media():
@api.route('/download.view', methods = [ 'GET', 'POST' ]) @api.route('/download.view', methods = [ 'GET', 'POST' ])
def download_media(): def download_media():
status, res = get_entity(Track) res = get_entity(Track)
if not status:
return res
return send_file(res.path, mimetype = res.content_type, conditional=True) return send_file(res.path, mimetype = res.content_type, conditional=True)
@api.route('/getCoverArt.view', methods = [ 'GET', 'POST' ]) @api.route('/getCoverArt.view', methods = [ 'GET', 'POST' ])
def cover_art(): def cover_art():
status, res = get_entity(Folder) res = get_entity(Folder)
if not status:
return res
if not res.has_cover_art or not os.path.isfile(os.path.join(res.path, 'cover.jpg')): if not res.has_cover_art or not os.path.isfile(os.path.join(res.path, 'cover.jpg')):
return request.formatter.error(70, 'Cover art not found') return request.formatter.error(70, 'Cover art not found')
size = request.values.get('size') size = request.values.get('size')
if size: if size:
try: size = int(size)
size = int(size)
except ValueError:
return request.formatter.error(0, 'Invalid size value')
else: else:
return send_file(os.path.join(res.path, 'cover.jpg')) return send_file(os.path.join(res.path, 'cover.jpg'))

View File

@ -21,8 +21,6 @@
import uuid import uuid
from flask import request from flask import request
from pony.orm import rollback
from pony.orm import ObjectNotFound
from ..db import Playlist, User, Track from ..db import Playlist, User, Track
from ..py23 import dict from ..py23 import dict
@ -48,10 +46,7 @@ def list_playlists():
@api.route('/getPlaylist.view', methods = [ 'GET', 'POST' ]) @api.route('/getPlaylist.view', methods = [ 'GET', 'POST' ])
def show_playlist(): def show_playlist():
status, res = get_entity(Playlist) res = get_entity(Playlist)
if not status:
return res
if res.user.id != request.user.id and not request.user.admin: if res.user.id != request.user.id and not request.user.admin:
return request.formatter.error('50', 'Private playlist') return request.formatter.error('50', 'Private playlist')
@ -64,16 +59,10 @@ def create_playlist():
playlist_id, name = map(request.values.get, [ 'playlistId', 'name' ]) playlist_id, name = map(request.values.get, [ 'playlistId', 'name' ])
# songId actually doesn't seem to be required # songId actually doesn't seem to be required
songs = request.values.getlist('songId') songs = request.values.getlist('songId')
try: playlist_id = uuid.UUID(playlist_id) if playlist_id else None
playlist_id = uuid.UUID(playlist_id) if playlist_id else None
except ValueError:
return request.formatter.error(0, 'Invalid playlist id')
if playlist_id: if playlist_id:
try: playlist = Playlist[playlist_id]
playlist = Playlist[playlist_id]
except ObjectNotFound:
return request.formatter.error(70, 'Unknwon playlist')
if playlist.user.id != request.user.id and not request.user.admin: if playlist.user.id != request.user.id and not request.user.admin:
return request.formatter.error(50, "You're not allowed to modify a playlist that isn't yours") return request.formatter.error(50, "You're not allowed to modify a playlist that isn't yours")
@ -86,26 +75,16 @@ def create_playlist():
else: else:
return request.formatter.error(10, 'Missing playlist id or name') return request.formatter.error(10, 'Missing playlist id or name')
try: songs = map(uuid.UUID, songs)
songs = map(uuid.UUID, songs) for sid in songs:
for sid in songs: track = Track[sid]
track = Track[sid] playlist.add(track)
playlist.add(track)
except ValueError:
rollback()
return request.formatter.error(0, 'Invalid song id')
except ObjectNotFound:
rollback()
return request.formatter.error(70, 'Unknown song')
return request.formatter.empty return request.formatter.empty
@api.route('/deletePlaylist.view', methods = [ 'GET', 'POST' ]) @api.route('/deletePlaylist.view', methods = [ 'GET', 'POST' ])
def delete_playlist(): def delete_playlist():
status, res = get_entity(Playlist) res = get_entity(Playlist)
if not status:
return res
if res.user.id != request.user.id and not request.user.admin: if res.user.id != request.user.id and not request.user.admin:
return request.formatter.error(50, "You're not allowed to delete a playlist that isn't yours") return request.formatter.error(50, "You're not allowed to delete a playlist that isn't yours")
@ -114,10 +93,7 @@ def delete_playlist():
@api.route('/updatePlaylist.view', methods = [ 'GET', 'POST' ]) @api.route('/updatePlaylist.view', methods = [ 'GET', 'POST' ])
def update_playlist(): def update_playlist():
status, res = get_entity(Playlist, 'playlistId') res = get_entity(Playlist, 'playlistId')
if not status:
return res
if res.user.id != request.user.id and not request.user.admin: if res.user.id != request.user.id and not request.user.admin:
return request.formatter.error(50, "You're not allowed to delete a playlist that isn't yours") return request.formatter.error(50, "You're not allowed to delete a playlist that isn't yours")
@ -132,19 +108,14 @@ def update_playlist():
if public: if public:
playlist.public = public in (True, 'True', 'true', 1, '1') playlist.public = public in (True, 'True', 'true', 1, '1')
try: to_add = map(uuid.UUID, to_add)
to_add = map(uuid.UUID, to_add) to_remove = map(int, to_remove)
to_remove = map(int, to_remove)
for sid in to_add: for sid in to_add:
track = Track[sid] track = Track[sid]
playlist.add(track) playlist.add(track)
playlist.remove_at_indexes(to_remove) playlist.remove_at_indexes(to_remove)
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
except ObjectNotFound:
return request.formatter.error(70, 'Unknown song')
return request.formatter.empty return request.formatter.empty

View File

@ -30,13 +30,10 @@ from . import api
@api.route('/search.view', methods = [ 'GET', 'POST' ]) @api.route('/search.view', methods = [ 'GET', 'POST' ])
def old_search(): def old_search():
artist, album, title, anyf, count, offset, newer_than = map(request.values.get, [ 'artist', 'album', 'title', 'any', 'count', 'offset', 'newerThan' ]) artist, album, title, anyf, count, offset, newer_than = map(request.values.get, [ 'artist', 'album', 'title', 'any', 'count', 'offset', 'newerThan' ])
try:
count = int(count) if count else 20
offset = int(offset) if offset else 0
newer_than = int(newer_than) / 1000 if newer_than else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
count = int(count) if count else 20
offset = int(offset) if offset else 0
newer_than = int(newer_than) / 1000 if newer_than else 0
min_date = datetime.fromtimestamp(newer_than) min_date = datetime.fromtimestamp(newer_than)
if artist: if artist:
@ -74,15 +71,12 @@ def new_search():
query, artist_count, artist_offset, album_count, album_offset, song_count, song_offset = map( query, artist_count, artist_offset, album_count, album_offset, song_count, song_offset = map(
request.values.get, [ 'query', 'artistCount', 'artistOffset', 'albumCount', 'albumOffset', 'songCount', 'songOffset' ]) request.values.get, [ 'query', 'artistCount', 'artistOffset', 'albumCount', 'albumOffset', 'songCount', 'songOffset' ])
try: artist_count = int(artist_count) if artist_count else 20
artist_count = int(artist_count) if artist_count else 20 artist_offset = int(artist_offset) if artist_offset else 0
artist_offset = int(artist_offset) if artist_offset else 0 album_count = int(album_count) if album_count else 20
album_count = int(album_count) if album_count else 20 album_offset = int(album_offset) if album_offset else 0
album_offset = int(album_offset) if album_offset else 0 song_count = int(song_count) if song_count else 20
song_count = int(song_count) if song_count else 20 song_offset = int(song_offset) if song_offset else 0
song_offset = int(song_offset) if song_offset else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
if not query: if not query:
return request.formatter.error(10, 'Missing query parameter') return request.formatter.error(10, 'Missing query parameter')
@ -102,15 +96,12 @@ def search_id3():
query, artist_count, artist_offset, album_count, album_offset, song_count, song_offset = map( query, artist_count, artist_offset, album_count, album_offset, song_count, song_offset = map(
request.values.get, [ 'query', 'artistCount', 'artistOffset', 'albumCount', 'albumOffset', 'songCount', 'songOffset' ]) request.values.get, [ 'query', 'artistCount', 'artistOffset', 'albumCount', 'albumOffset', 'songCount', 'songOffset' ])
try: artist_count = int(artist_count) if artist_count else 20
artist_count = int(artist_count) if artist_count else 20 artist_offset = int(artist_offset) if artist_offset else 0
artist_offset = int(artist_offset) if artist_offset else 0 album_count = int(album_count) if album_count else 20
album_count = int(album_count) if album_count else 20 album_offset = int(album_offset) if album_offset else 0
album_offset = int(album_offset) if album_offset else 0 song_count = int(song_count) if song_count else 20
song_count = int(song_count) if song_count else 20 song_offset = int(song_offset) if song_offset else 0
song_offset = int(song_offset) if song_offset else 0
except ValueError:
return request.formatter.error(0, 'Invalid parameter')
if not query: if not query:
return request.formatter.error(10, 'Missing query parameter') return request.formatter.error(10, 'Missing query parameter')

View File

@ -63,9 +63,9 @@ class ApiSetupTestCase(TestBase):
def test_auth_basic(self): def test_auth_basic(self):
# No auth info # No auth info
rv = self.client.get('/rest/ping.view?c=tests') rv = self.client.get('/rest/ping.view?c=tests')
self.assertEqual(rv.status_code, 401) self.assertEqual(rv.status_code, 400)
self.assertIn('status="failed"', rv.data) self.assertIn('status="failed"', rv.data)
self.assertIn('code="40"', rv.data) self.assertIn('code="10"', rv.data)
self.__test_auth(self.__basic_auth_get) self.__test_auth(self.__basic_auth_get)