1
0
mirror of https://github.com/spl0k/supysonic.git synced 2024-12-23 01:16:18 +00:00

Merge remote-tracking branch 'pR0Ps/feature/transcode-cache'

This commit is contained in:
spl0k 2019-02-09 15:49:30 +01:00
commit cf846e88ee
9 changed files with 649 additions and 68 deletions

View File

@ -12,6 +12,12 @@
; Optional cache directory. Default: /tmp/supysonic ; Optional cache directory. Default: /tmp/supysonic
cache_dir = /var/supysonic/cache cache_dir = /var/supysonic/cache
; Main cache max size in MB. Default: 512
cache_size = 512
; Transcode cache max size in MB. Default: 1024 (1GB)
transcode_cache_size = 1024
; Optional rotating log file. Default: none ; Optional rotating log file. Default: none
log_file = /var/supysonic/supysonic.log log_file = /var/supysonic/supysonic.log

View File

@ -22,6 +22,7 @@ reqs = [
'Pillow', 'Pillow',
'requests>=1.0.0', 'requests>=1.0.0',
'mutagen>=1.33', 'mutagen>=1.33',
'scandir<2.0.0',
'zipstream' 'zipstream'
] ]
extras = { extras = {

View File

@ -15,6 +15,10 @@ import requests
import shlex import shlex
import subprocess import subprocess
import uuid import uuid
import io
import hashlib
import json
import zlib
from flask import request, Response, send_file from flask import request, Response, send_file
from flask import current_app from flask import current_app
@ -25,6 +29,7 @@ from zipfile import ZIP_DEFLATED
from zipstream import ZipFile from zipstream import ZipFile
from .. import scanner from .. import scanner
from ..cache import CacheMiss
from ..db import Track, Album, Artist, Folder, User, ClientPrefs, now from ..db import Track, Album, Artist, Folder, User, ClientPrefs, now
from ..py23 import dict from ..py23 import dict
@ -78,50 +83,59 @@ def stream_media():
dst_mimetype = mimetypes.guess_type('dummyname.' + dst_suffix, False)[0] or 'application/octet-stream' dst_mimetype = mimetypes.guess_type('dummyname.' + dst_suffix, False)[0] or 'application/octet-stream'
if format != 'raw' and (dst_suffix != src_suffix or dst_bitrate != res.bitrate): if format != 'raw' and (dst_suffix != src_suffix or dst_bitrate != res.bitrate):
config = current_app.config['TRANSCODING'] # Requires transcoding
transcoder = config.get('transcoder_{}_{}'.format(src_suffix, dst_suffix)) cache = current_app.transcode_cache
decoder = config.get('decoder_' + src_suffix) or config.get('decoder') cache_key = "{}-{}.{}".format(res.id, dst_bitrate, dst_suffix)
encoder = config.get('encoder_' + dst_suffix) or config.get('encoder')
if not transcoder and (not decoder or not encoder):
transcoder = config.get('transcoder')
if not transcoder:
message = 'No way to transcode from {} to {}'.format(src_suffix, dst_suffix)
logger.info(message)
raise GenericError(message)
transcoder, decoder, encoder = map(lambda x: prepare_transcoding_cmdline(x, res.path, src_suffix, dst_suffix, dst_bitrate), [ transcoder, decoder, encoder ])
try: try:
if transcoder: response = send_file(cache.get(cache_key), mimetype=dst_mimetype, conditional=True)
dec_proc = None except CacheMiss:
proc = subprocess.Popen(transcoder, stdout = subprocess.PIPE) config = current_app.config['TRANSCODING']
else: transcoder = config.get('transcoder_{}_{}'.format(src_suffix, dst_suffix))
dec_proc = subprocess.Popen(decoder, stdout = subprocess.PIPE) decoder = config.get('decoder_' + src_suffix) or config.get('decoder')
proc = subprocess.Popen(encoder, stdin = dec_proc.stdout, stdout = subprocess.PIPE) encoder = config.get('encoder_' + dst_suffix) or config.get('encoder')
except OSError: if not transcoder and (not decoder or not encoder):
raise ServerError('Error while running the transcoding process') transcoder = config.get('transcoder')
if not transcoder:
message = 'No way to transcode from {} to {}'.format(src_suffix, dst_suffix)
logger.info(message)
raise GenericError(message)
def transcode(): transcoder, decoder, encoder = map(lambda x: prepare_transcoding_cmdline(x, res.path, src_suffix, dst_suffix, dst_bitrate), [ transcoder, decoder, encoder ])
try: try:
while True: if transcoder:
data = proc.stdout.read(8192) dec_proc = None
if not data: proc = subprocess.Popen(transcoder, stdout = subprocess.PIPE)
break else:
yield data dec_proc = subprocess.Popen(decoder, stdout = subprocess.PIPE)
except: # pragma: nocover proc = subprocess.Popen(encoder, stdin = dec_proc.stdout, stdout = subprocess.PIPE)
if dec_proc != None: except OSError:
dec_proc.kill() raise ServerError('Error while running the transcoding process')
proc.kill()
if dec_proc != None: def transcode():
dec_proc.wait() try:
proc.wait() while True:
data = proc.stdout.read(8192)
if not data:
break
yield data
except: # pragma: nocover
if dec_proc != None:
dec_proc.kill()
proc.kill()
raise
finally:
if dec_proc != None:
dec_proc.wait()
proc.wait()
resp_content = cache.set_generated(cache_key, transcode)
logger.info('Transcoding track {0.id} for user {1.id}. Source: {2} at {0.bitrate}kbps. Dest: {3} at {4}kbps'.format(res, request.user, src_suffix, dst_suffix, dst_bitrate)) logger.info('Transcoding track {0.id} for user {1.id}. Source: {2} at {0.bitrate}kbps. Dest: {3} at {4}kbps'.format(res, request.user, src_suffix, dst_suffix, dst_bitrate))
response = Response(transcode(), mimetype = dst_mimetype) response = Response(resp_content, mimetype=dst_mimetype)
if estimateContentLength == 'true': if estimateContentLength == 'true':
response.headers.add('Content-Length', dst_bitrate * 1000 * res.duration // 8) response.headers.add('Content-Length', dst_bitrate * 1000 * res.duration // 8)
else: else:
response = send_file(res.path, mimetype = dst_mimetype, conditional=True) response = send_file(res.path, mimetype=dst_mimetype, conditional=True)
res.play_count = res.play_count + 1 res.play_count = res.play_count + 1
res.last_play = now() res.last_play = now()
@ -159,6 +173,7 @@ def download_media():
@api.route('/getCoverArt.view', methods = [ 'GET', 'POST' ]) @api.route('/getCoverArt.view', methods = [ 'GET', 'POST' ])
def cover_art(): def cover_art():
cache = current_app.cache
eid = request.values['id'] eid = request.values['id']
if Folder.exists(id=eid): if Folder.exists(id=eid):
res = get_entity(Folder) res = get_entity(Folder)
@ -166,18 +181,15 @@ def cover_art():
raise NotFound('Cover art') raise NotFound('Cover art')
cover_path = os.path.join(res.path, res.cover_art) cover_path = os.path.join(res.path, res.cover_art)
elif Track.exists(id=eid): elif Track.exists(id=eid):
embed_cache = os.path.join(current_app.config['WEBAPP']['cache_dir'], 'embeded_art') cache_key = "{}-cover".format(eid)
cover_path = os.path.join(embed_cache, eid) try:
if not os.path.exists(cover_path): cover_path = cache.get(cache_key)
except CacheMiss:
res = get_entity(Track) res = get_entity(Track)
art = res.extract_cover_art() art = res.extract_cover_art()
if not art: if not art:
raise NotFound('Cover art') raise NotFound('Cover art')
#Art found, save to cache cover_path = cache.set(cache_key, art)
if not os.path.exists(embed_cache):
os.makedirs(embed_cache)
with open(cover_path, 'wb') as cover_file:
cover_file.write(art)
else: else:
raise NotFound('Entity') raise NotFound('Entity')
@ -188,19 +200,18 @@ def cover_art():
return send_file(cover_path) return send_file(cover_path)
im = Image.open(cover_path) im = Image.open(cover_path)
mimetype = 'image/{}'.format(im.format.lower())
if size > im.width and size > im.height: if size > im.width and size > im.height:
return send_file(cover_path) return send_file(cover_path, mimetype=mimetype)
size_path = os.path.join(current_app.config['WEBAPP']['cache_dir'], str(size)) cache_key = "{}-cover-{}".format(eid, size)
path = os.path.abspath(os.path.join(size_path, eid)) try:
if os.path.exists(path): return send_file(cache.get(cache_key), mimetype=mimetype)
return send_file(path, mimetype = 'image/' + im.format.lower()) except CacheMiss:
if not os.path.exists(size_path): im.thumbnail([size, size], Image.ANTIALIAS)
os.makedirs(size_path) with cache.set_fileobj(cache_key) as fp:
im.save(fp, im.format)
im.thumbnail([size, size], Image.ANTIALIAS) return send_file(cache.get(cache_key), mimetype=mimetype)
im.save(path, im.format)
return send_file(path, mimetype = 'image/' + im.format.lower())
@api.route('/getLyrics.view', methods = [ 'GET', 'POST' ]) @api.route('/getLyrics.view', methods = [ 'GET', 'POST' ])
def lyrics(): def lyrics():
@ -227,21 +238,37 @@ def lyrics():
value = lyrics value = lyrics
)) ))
# Create a stable, unique, filesystem-compatible identifier for the artist+title
unique = hashlib.md5(json.dumps([x.lower() for x in (artist, title)]).encode('utf-8')).hexdigest()
cache_key = "lyrics-{}".format(unique)
lyrics = dict()
try: try:
r = requests.get("http://api.chartlyrics.com/apiv1.asmx/SearchLyricDirect", lyrics = json.loads(
params = { 'artist': artist, 'song': title }, timeout = 5) zlib.decompress(
root = ElementTree.fromstring(r.content) current_app.cache.get_value(cache_key)
).decode('utf-8')
)
except (CacheMiss, zlib.error, TypeError, ValueError):
try:
r = requests.get("http://api.chartlyrics.com/apiv1.asmx/SearchLyricDirect",
params={'artist': artist, 'song': title}, timeout=5)
root = ElementTree.fromstring(r.content)
ns = { 'cl': 'http://api.chartlyrics.com/' } ns = {'cl': 'http://api.chartlyrics.com/'}
return request.formatter('lyrics', dict( lyrics = dict(
artist = root.find('cl:LyricArtist', namespaces = ns).text, artist = root.find('cl:LyricArtist', namespaces=ns).text,
title = root.find('cl:LyricSong', namespaces = ns).text, title = root.find('cl:LyricSong', namespaces=ns).text,
value = root.find('cl:Lyric', namespaces = ns).text value = root.find('cl:Lyric', namespaces=ns).text
)) )
except requests.exceptions.RequestException as e: # pragma: nocover
logger.warning('Error while requesting the ChartLyrics API: ' + str(e))
return request.formatter('lyrics', dict()) # pragma: nocover current_app.cache.set(
cache_key, zlib.compress(json.dumps(lyrics).encode('utf-8'), 9)
)
except requests.exceptions.RequestException as e: # pragma: nocover
logger.warning('Error while requesting the ChartLyrics API: ' + str(e))
return request.formatter('lyrics', lyrics)
def read_file_as_unicode(path): def read_file_as_unicode(path):
""" Opens a file trying with different encodings and returns the contents as a unicode string """ """ Opens a file trying with different encodings and returns the contents as a unicode string """

231
supysonic/cache.py Normal file
View File

@ -0,0 +1,231 @@
# coding: utf-8
#
# This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API.
#
# Copyright (C) 2013-2018 Alban 'spl0k' Féron
# 2017 Óscar García Amor
#
# Distributed under terms of the GNU AGPLv3 license.
from collections import OrderedDict, namedtuple
import contextlib
import errno
import logging
import os
import os.path
import tempfile
import threading
from time import time
from .py23 import scandir, osreplace
logger = logging.getLogger(__name__)
class CacheMiss(KeyError):
"""The requested data is not in the cache"""
pass
class ProtectedError(Exception):
"""The data cannot be purged from the cache"""
pass
CacheEntry = namedtuple("CacheEntry", ["size", "expires"])
NULL_ENTRY = CacheEntry(0, 0)
class Cache(object):
"""Provides a common interface for caching files to disk"""
# Modeled after werkzeug.contrib.cache.FileSystemCache
# keys must be filename-compatible strings (no paths)
# values must be bytes (not strings)
def __init__(self, cache_dir, max_size, min_time=300, auto_prune=True):
"""Initialize the cache
cache_dir: The folder to store cached files
max_size: The maximum allowed size of the cache in bytes
min_time: The minimum amount of time a file will be stored in the cache
in seconds (default 300 = 5min)
auto_prune: If True (default) the cache will automatically be pruned to
the max_size when possible.
Note that max_size is not a hard restriction and in some cases will
temporarily be exceeded, even when auto-pruning is turned on.
"""
self._cache_dir = os.path.abspath(cache_dir)
self.min_time = min_time
self.max_size = max_size
self._auto_prune = auto_prune
self._lock = threading.RLock()
# Create the cache directory
try:
os.makedirs(self._cache_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
# Make a key -> CacheEntry(size, expiry) map ordered by mtime
self._size = 0
self._files = OrderedDict()
for mtime, size, key in sorted([(f.stat().st_mtime, f.stat().st_size, f.name)
for f in scandir(self._cache_dir)
if f.is_file()]):
self._files[key] = CacheEntry(size, mtime + self.min_time)
self._size += size
def _filepath(self, key):
return os.path.join(self._cache_dir, key)
def _make_space(self, required_space, key=None):
"""Delete files to free up the required space (or close to it)
If key is provided and exists in the cache, its size will be
subtracted from the required size.
"""
target = self.max_size - required_space
if key is not None:
target += self._files.get(key, NULL_ENTRY).size
with self._lock:
# Delete the oldest file until self._size <= target
for k in list(self._files.keys()):
if self._size <= target:
break
try:
self.delete(k)
except ProtectedError:
pass
def _record_file(self, key, size):
# If the file is being replaced, add only the difference in size
self._size += size - self._files.get(key, NULL_ENTRY).size
self._files[key] = CacheEntry(size, int(time()) + self.min_time)
def _freshen_file(self, key):
"""Touch the file to change modified time and move it to the end of the cache dict"""
old = self._files.pop(key)
self._files[key] = CacheEntry(old.size, int(time()) + self.min_time)
os.utime(self._filepath(key), None)
@property
def size(self):
"""The current amount of data cached"""
return self._size
def touch(self, key):
"""Mark a cache entry as fresh"""
with self._lock:
if not self.has(key):
raise CacheMiss(key)
self._freshen_file(key)
@contextlib.contextmanager
def set_fileobj(self, key):
"""Yields a file object that can have bytes written to it in order to
store them in the cache.
The contents of the file object will be stored in the cache when the
context is exited.
Ex:
>>> with cache.set_fileobj(key) as fp:
... json.dump(some_data, fp)
"""
try:
with tempfile.NamedTemporaryFile(dir=self._cache_dir, suffix=".part", delete=True) as f:
yield f
# seek to end and get position to get filesize
f.seek(0, 2)
size = f.tell()
with self._lock:
if self._auto_prune:
self._make_space(size, key=key)
osreplace(f.name, self._filepath(key))
self._record_file(key, size)
except OSError as e:
# Ignore error from trying to delete the renamed temp file
if e.errno != errno.ENOENT:
raise
def set(self, key, value):
"""Set a literal value into the cache and return its path"""
with self.set_fileobj(key) as f:
f.write(value)
return self._filepath(key)
def set_generated(self, key, gen_function):
"""Pass the values yielded from the generator function through and set
the end result in the cache.
The contents will be set into the cache only if and when the generator
completes.
Ex:
>>> for x in cache.set_generated(key, generator_function):
... print(x)
"""
with self.set_fileobj(key) as f:
for data in gen_function():
f.write(data)
yield data
def get(self, key):
"""Return the path to the file where the cached data is stored"""
self.touch(key)
return self._filepath(key)
@contextlib.contextmanager
def get_fileobj(self, key):
"""Yields a file object that can be used to read cached bytes"""
with open(self.get(key), 'rb') as f:
yield f
def get_value(self, key):
"""Return the cached data"""
with self.get_fileobj(key) as f:
return f.read()
def delete(self, key):
"""Delete a file from the cache"""
with self._lock:
if not self.has(key):
return
if time() < self._files[key].expires:
raise ProtectedError("File has not expired")
os.remove(self._filepath(key))
self._size -= self._files.pop(key).size
def prune(self):
"""Prune the cache down to the max size
Note that protected files are not deleted
"""
self._make_space(0)
def clear(self):
"""Clear the cache
Note that protected files are not deleted
"""
self._make_space(self.max_size)
def has(self, key):
"""Check if a key is currently cached"""
if key not in self._files:
return False
if not os.path.exists(self._filepath(key)):
# Underlying file is gone, remove from the cache
self._size -= self._files.pop(key).size
return False
return True

View File

@ -26,6 +26,8 @@ class DefaultConfig(object):
} }
WEBAPP = { WEBAPP = {
'cache_dir': tempdir, 'cache_dir': tempdir,
'cache_size': 1024,
'transcode_cache_size': 512,
'log_file': None, 'log_file': None,
'log_level': 'WARNING', 'log_level': 'WARNING',

View File

@ -7,6 +7,31 @@
# #
# Distributed under terms of the GNU AGPLv3 license. # Distributed under terms of the GNU AGPLv3 license.
# Try built-in scandir, fall back to the package for Python 2.7
try:
from os import scandir
except ImportError:
from scandir import scandir
# os.replace was added in Python 3.3, provide a fallback for Python 2.7
try:
from os import replace as osreplace
except ImportError:
# os.rename is equivalent to os.replace except on Windows
# On Windows an existing file will not be overwritten
# This fallback just attempts to delete the dst file before using rename
import sys
if sys.platform != 'win32':
from os import rename as osreplace
else:
import os
def osreplace(src, dst):
try:
os.remove(dst)
except OSError:
pass
os.rename(src, dst)
try: try:
# Python 2 # Python 2
strtype = basestring strtype = basestring

View File

@ -17,6 +17,7 @@ from os import makedirs, path, urandom
from pony.orm import db_session from pony.orm import db_session
from .config import IniConfig from .config import IniConfig
from .cache import Cache
from .db import init_database from .db import init_database
logger = logging.getLogger(__package__) logger = logging.getLogger(__package__)
@ -53,6 +54,14 @@ def create_application(config = None):
if extension not in mimetypes.types_map: if extension not in mimetypes.types_map:
mimetypes.add_type(v, extension, False) mimetypes.add_type(v, extension, False)
# Initialize Cache objects
# Max size is MB in the config file but Cache expects bytes
cache_dir = app.config['WEBAPP']['cache_dir']
max_size_cache = app.config['WEBAPP']['cache_size'] * 1024**2
max_size_transcodes = app.config['WEBAPP']['transcode_cache_size'] * 1024**2
app.cache = Cache(path.join(cache_dir, "cache"), max_size_cache)
app.transcode_cache = Cache(path.join(cache_dir, "transcodes"), max_size_transcodes)
# Test for the cache directory # Test for the cache directory
cache_path = app.config['WEBAPP']['cache_dir'] cache_path = app.config['WEBAPP']['cache_dir']
if not path.exists(cache_path): if not path.exists(cache_path):

View File

@ -10,6 +10,7 @@
import unittest import unittest
from .test_cli import CLITestCase from .test_cli import CLITestCase
from .test_cache import CacheTestCase
from .test_config import ConfigTestCase from .test_config import ConfigTestCase
from .test_db import DbTestCase from .test_db import DbTestCase
from .test_lastfm import LastFmTestCase from .test_lastfm import LastFmTestCase
@ -20,6 +21,7 @@ from .test_watcher import suite as watcher_suite
def suite(): def suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(CacheTestCase))
suite.addTest(unittest.makeSuite(ConfigTestCase)) suite.addTest(unittest.makeSuite(ConfigTestCase))
suite.addTest(unittest.makeSuite(DbTestCase)) suite.addTest(unittest.makeSuite(DbTestCase))
suite.addTest(unittest.makeSuite(ScannerTestCase)) suite.addTest(unittest.makeSuite(ScannerTestCase))

278
tests/base/test_cache.py Normal file
View File

@ -0,0 +1,278 @@
#!/usr/bin/env python
# coding: utf-8
#
# This file is part of Supysonic.
# Supysonic is a Python implementation of the Subsonic server API.
#
# Copyright (C) 2018 Alban 'spl0k' Féron
#
# Distributed under terms of the GNU AGPLv3 license.
import os
import unittest
import shutil
import time
import tempfile
from supysonic.cache import Cache, CacheMiss, ProtectedError
class CacheTestCase(unittest.TestCase):
def setUp(self):
self.__dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.__dir)
def test_existing_files_order(self):
cache = Cache(self.__dir, 30)
val = b'0123456789'
cache.set("key1", val)
cache.set("key2", val)
cache.set("key3", val)
self.assertEqual(cache.size, 30)
# file mtime is accurate to the second
time.sleep(1)
cache.get_value("key1")
cache = Cache(self.__dir, 30, min_time=0)
self.assertEqual(cache.size, 30)
self.assertTrue(cache.has("key1"))
self.assertTrue(cache.has("key2"))
self.assertTrue(cache.has("key3"))
cache.set("key4", val)
self.assertEqual(cache.size, 30)
self.assertTrue(cache.has("key1"))
self.assertFalse(cache.has("key2"))
self.assertTrue(cache.has("key3"))
self.assertTrue(cache.has("key4"))
def test_missing(self):
cache = Cache(self.__dir, 10)
self.assertFalse(cache.has("missing"))
with self.assertRaises(CacheMiss):
cache.get_value("missing")
def test_delete_missing(self):
cache = Cache(self.__dir, 0, min_time=0)
cache.delete("missing1")
cache.delete("missing2")
def test_store_literal(self):
cache = Cache(self.__dir, 10)
val = b'0123456789'
cache.set("key", val)
self.assertEqual(cache.size, 10)
self.assertTrue(cache.has("key"))
self.assertEqual(cache.get_value("key"), val)
def test_store_generated(self):
cache = Cache(self.__dir, 10)
val = [b'0', b'12', b'345', b'6789']
def gen():
for b in val:
yield b
t = []
for x in cache.set_generated("key", gen):
t.append(x)
self.assertEqual(cache.size, 0)
self.assertFalse(cache.has("key"))
self.assertEqual(t, val)
self.assertEqual(cache.size, 10)
self.assertEqual(cache.get_value("key"), b''.join(val))
def test_store_to_fp(self):
cache = Cache(self.__dir, 10)
val = b'0123456789'
with cache.set_fileobj("key") as fp:
fp.write(val)
self.assertEqual(cache.size, 0)
self.assertEqual(cache.size, 10)
self.assertEqual(cache.get_value("key"), val)
def test_access_data(self):
cache = Cache(self.__dir, 25, min_time=0)
val = b'0123456789'
cache.set("key", val)
self.assertEqual(cache.get_value("key"), val)
with cache.get_fileobj("key") as f:
self.assertEqual(f.read(), val)
with open(cache.get("key"), 'rb') as f:
self.assertEqual(f.read(), val)
def test_accessing_preserves(self):
cache = Cache(self.__dir, 25, min_time=0)
val = b'0123456789'
cache.set("key1", val)
cache.set("key2", val)
self.assertEqual(cache.size, 20)
cache.get_value("key1")
cache.set("key3", val)
self.assertEqual(cache.size, 20)
self.assertTrue(cache.has("key1"))
self.assertFalse(cache.has("key2"))
self.assertTrue(cache.has("key3"))
def test_automatic_delete_oldest(self):
cache = Cache(self.__dir, 25, min_time=0)
val = b'0123456789'
cache.set("key1", val)
self.assertTrue(cache.has("key1"))
self.assertEqual(cache.size, 10)
cache.set("key2", val)
self.assertEqual(cache.size, 20)
self.assertTrue(cache.has("key1"))
self.assertTrue(cache.has("key2"))
cache.set("key3", val)
self.assertEqual(cache.size, 20)
self.assertFalse(cache.has("key1"))
self.assertTrue(cache.has("key2"))
self.assertTrue(cache.has("key3"))
def test_delete(self):
cache = Cache(self.__dir, 25, min_time=0)
val = b'0123456789'
cache.set("key1", val)
self.assertTrue(cache.has("key1"))
self.assertEqual(cache.size, 10)
cache.delete("key1")
self.assertFalse(cache.has("key1"))
self.assertEqual(cache.size, 0)
def test_cleanup_on_error(self):
cache = Cache(self.__dir, 10)
def gen():
# Cause a TypeError halfway through
for b in [b'0', b'12', object(), b'345', b'6789']:
yield b
with self.assertRaises(TypeError):
for x in cache.set_generated("key", gen):
pass
# Make sure no partial files are left after the error
self.assertEqual(list(os.listdir(self.__dir)), list())
def test_parallel_generation(self):
cache = Cache(self.__dir, 20)
def gen():
for b in [b'0', b'12', b'345', b'6789']:
yield b
g1 = cache.set_generated("key", gen)
g2 = cache.set_generated("key", gen)
next(g1)
files = os.listdir(self.__dir)
self.assertEqual(len(files), 1)
for x in files:
self.assertTrue(x.endswith(".part"))
next(g2)
files = os.listdir(self.__dir)
self.assertEqual(len(files), 2)
for x in files:
self.assertTrue(x.endswith(".part"))
self.assertEqual(cache.size, 0)
for x in g1:
pass
self.assertEqual(cache.size, 10)
self.assertTrue(cache.has("key"))
# Replace the file - size should stay the same
for x in g2:
pass
self.assertEqual(cache.size, 10)
self.assertTrue(cache.has("key"))
# Only a single file
self.assertEqual(len(os.listdir(self.__dir)), 1)
def test_replace(self):
cache = Cache(self.__dir, 20)
val_small = b'0'
val_big = b'0123456789'
cache.set("key", val_small)
self.assertEqual(cache.size, 1)
cache.set("key", val_big)
self.assertEqual(cache.size, 10)
cache.set("key", val_small)
self.assertEqual(cache.size, 1)
def test_no_auto_prune(self):
cache = Cache(self.__dir, 10, min_time=0, auto_prune=False)
val = b'0123456789'
cache.set("key1", val)
cache.set("key2", val)
cache.set("key3", val)
cache.set("key4", val)
self.assertEqual(cache.size, 40)
cache.prune()
self.assertEqual(cache.size, 10)
def test_min_time_clear(self):
cache = Cache(self.__dir, 40, min_time=1)
val = b'0123456789'
cache.set("key1", val)
cache.set("key2", val)
time.sleep(1)
cache.set("key3", val)
cache.set("key4", val)
self.assertEqual(cache.size, 40)
cache.clear()
self.assertEqual(cache.size, 20)
time.sleep(1)
cache.clear()
self.assertEqual(cache.size, 0)
def test_not_expired(self):
cache = Cache(self.__dir, 40, min_time=1)
val = b'0123456789'
cache.set("key1", val)
with self.assertRaises(ProtectedError):
cache.delete("key1")
time.sleep(1)
cache.delete("key1")
self.assertEqual(cache.size, 0)
def test_missing_cache_file(self):
cache = Cache(self.__dir, 10, min_time=0)
val = b'0123456789'
os.remove(cache.set("key", val))
self.assertEqual(cache.size, 10)
self.assertFalse(cache.has("key"))
self.assertEqual(cache.size, 0)
os.remove(cache.set("key", val))
self.assertEqual(cache.size, 10)
with self.assertRaises(CacheMiss):
cache.get("key")
self.assertEqual(cache.size, 0)
if __name__ == '__main__':
unittest.main()