diff --git a/config.sample b/config.sample
index d314891..14579da 100644
--- a/config.sample
+++ b/config.sample
@@ -12,6 +12,12 @@
 ; Optional cache directory. Default: /tmp/supysonic
 cache_dir = /var/supysonic/cache

+; Main cache max size in MB. Default: 512
+cache_size = 512
+
+; Transcode cache max size in MB. Default: 1024 (1GB)
+transcode_cache_size = 1024
+
 ; Optional rotating log file. Default: none
 log_file = /var/supysonic/supysonic.log

diff --git a/setup.py b/setup.py
index 60b6584..79416a8 100755
--- a/setup.py
+++ b/setup.py
@@ -22,6 +22,7 @@ reqs = [
     'Pillow',
     'requests>=1.0.0',
     'mutagen>=1.33',
+    'scandir<2.0.0',
     'zipstream'
 ]
 extras = {
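Both settings are expressed in MB; create_application() (see the web.py hunk below) multiplies them by 1024² before handing them to Cache. A quick sanity check of the conversion, using the sample values:

    # MB -> bytes, as done in create_application() below
    MB = 1024 ** 2
    assert 512 * MB == 536870912       # cache_size
    assert 1024 * MB == 1073741824     # transcode_cache_size (1GB)
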
diff --git a/supysonic/api/media.py b/supysonic/api/media.py
index 3749d36..be02879 100644
--- a/supysonic/api/media.py
+++ b/supysonic/api/media.py
@@ -15,6 +15,10 @@ import requests
 import shlex
 import subprocess
 import uuid
+import io
+import hashlib
+import json
+import zlib

 from flask import request, Response, send_file
 from flask import current_app
@@ -25,6 +29,7 @@ from zipfile import ZIP_DEFLATED
 from zipstream import ZipFile

 from .. import scanner
+from ..cache import CacheMiss
 from ..db import Track, Album, Artist, Folder, User, ClientPrefs, now
 from ..py23 import dict

@@ -78,50 +83,59 @@ def stream_media():
     dst_mimetype = mimetypes.guess_type('dummyname.' + dst_suffix, False)[0] or 'application/octet-stream'

     if format != 'raw' and (dst_suffix != src_suffix or dst_bitrate != res.bitrate):
-        config = current_app.config['TRANSCODING']
-        transcoder = config.get('transcoder_{}_{}'.format(src_suffix, dst_suffix))
-        decoder = config.get('decoder_' + src_suffix) or config.get('decoder')
-        encoder = config.get('encoder_' + dst_suffix) or config.get('encoder')
-        if not transcoder and (not decoder or not encoder):
-            transcoder = config.get('transcoder')
-            if not transcoder:
-                message = 'No way to transcode from {} to {}'.format(src_suffix, dst_suffix)
-                logger.info(message)
-                raise GenericError(message)
+        # Requires transcoding
+        cache = current_app.transcode_cache
+        cache_key = "{}-{}.{}".format(res.id, dst_bitrate, dst_suffix)

-        transcoder, decoder, encoder = map(lambda x: prepare_transcoding_cmdline(x, res.path, src_suffix, dst_suffix, dst_bitrate), [ transcoder, decoder, encoder ])
         try:
-            if transcoder:
-                dec_proc = None
-                proc = subprocess.Popen(transcoder, stdout = subprocess.PIPE)
-            else:
-                dec_proc = subprocess.Popen(decoder, stdout = subprocess.PIPE)
-                proc = subprocess.Popen(encoder, stdin = dec_proc.stdout, stdout = subprocess.PIPE)
-        except OSError:
-            raise ServerError('Error while running the transcoding process')
+            response = send_file(cache.get(cache_key), mimetype=dst_mimetype, conditional=True)
+        except CacheMiss:
+            config = current_app.config['TRANSCODING']
+            transcoder = config.get('transcoder_{}_{}'.format(src_suffix, dst_suffix))
+            decoder = config.get('decoder_' + src_suffix) or config.get('decoder')
+            encoder = config.get('encoder_' + dst_suffix) or config.get('encoder')
+            if not transcoder and (not decoder or not encoder):
+                transcoder = config.get('transcoder')
+                if not transcoder:
+                    message = 'No way to transcode from {} to {}'.format(src_suffix, dst_suffix)
+                    logger.info(message)
+                    raise GenericError(message)

-        def transcode():
+            transcoder, decoder, encoder = map(lambda x: prepare_transcoding_cmdline(x, res.path, src_suffix, dst_suffix, dst_bitrate), [ transcoder, decoder, encoder ])
             try:
-                while True:
-                    data = proc.stdout.read(8192)
-                    if not data:
-                        break
-                    yield data
-            except: # pragma: nocover
-                if dec_proc != None:
-                    dec_proc.kill()
-                proc.kill()
+                if transcoder:
+                    dec_proc = None
+                    proc = subprocess.Popen(transcoder, stdout = subprocess.PIPE)
+                else:
+                    dec_proc = subprocess.Popen(decoder, stdout = subprocess.PIPE)
+                    proc = subprocess.Popen(encoder, stdin = dec_proc.stdout, stdout = subprocess.PIPE)
+            except OSError:
+                raise ServerError('Error while running the transcoding process')

-            if dec_proc != None:
-                dec_proc.wait()
-            proc.wait()
+            def transcode():
+                try:
+                    while True:
+                        data = proc.stdout.read(8192)
+                        if not data:
+                            break
+                        yield data
+                except: # pragma: nocover
+                    if dec_proc != None:
+                        dec_proc.kill()
+                    proc.kill()
+                    raise
+                finally:
+                    if dec_proc != None:
+                        dec_proc.wait()
+                    proc.wait()
+
+            resp_content = cache.set_generated(cache_key, transcode)

-        logger.info('Transcoding track {0.id} for user {1.id}. Source: {2} at {0.bitrate}kbps. Dest: {3} at {4}kbps'.format(res, request.user, src_suffix, dst_suffix, dst_bitrate))
-        response = Response(transcode(), mimetype = dst_mimetype)
-        if estimateContentLength == 'true':
-            response.headers.add('Content-Length', dst_bitrate * 1000 * res.duration // 8)
+            logger.info('Transcoding track {0.id} for user {1.id}. Source: {2} at {0.bitrate}kbps. Dest: {3} at {4}kbps'.format(res, request.user, src_suffix, dst_suffix, dst_bitrate))
+            response = Response(resp_content, mimetype=dst_mimetype)
+            if estimateContentLength == 'true':
+                response.headers.add('Content-Length', dst_bitrate * 1000 * res.duration // 8)
     else:
-        response = send_file(res.path, mimetype = dst_mimetype, conditional=True)
+        response = send_file(res.path, mimetype=dst_mimetype, conditional=True)

     res.play_count = res.play_count + 1
     res.last_play = now()
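The rewritten hunk tries the transcode cache first and only spawns the transcoding pipeline on a miss, teeing its output into the cache via set_generated(). A minimal standalone sketch of that pattern (the key and bytes are illustrative stand-ins for the real subprocess output):

    from supysonic.cache import Cache, CacheMiss

    cache = Cache('/tmp/demo-transcodes', max_size=1024 ** 2)
    cache_key = '42-320.mp3'  # "<track id>-<bitrate>.<suffix>"

    try:
        path = cache.get(cache_key)          # hit: the file can be sent directly
    except CacheMiss:
        def transcode():                     # stand-in for the real subprocess pipeline
            yield b'transcoded '
            yield b'audio'

        for chunk in cache.set_generated(cache_key, transcode):
            pass                             # in the API this generator feeds the Response
        path = cache.get(cache_key)          # fully stored once the generator completed
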
@@ -159,6 +173,7 @@ def download_media():

 @api.route('/getCoverArt.view', methods = [ 'GET', 'POST' ])
 def cover_art():
+    cache = current_app.cache
     eid = request.values['id']
     if Folder.exists(id=eid):
         res = get_entity(Folder)
@@ -166,18 +181,15 @@
             raise NotFound('Cover art')
         cover_path = os.path.join(res.path, res.cover_art)
     elif Track.exists(id=eid):
-        embed_cache = os.path.join(current_app.config['WEBAPP']['cache_dir'], 'embeded_art')
-        cover_path = os.path.join(embed_cache, eid)
-        if not os.path.exists(cover_path):
+        cache_key = "{}-cover".format(eid)
+        try:
+            cover_path = cache.get(cache_key)
+        except CacheMiss:
             res = get_entity(Track)
             art = res.extract_cover_art()
             if not art:
                 raise NotFound('Cover art')
-            #Art found, save to cache
-            if not os.path.exists(embed_cache):
-                os.makedirs(embed_cache)
-            with open(cover_path, 'wb') as cover_file:
-                cover_file.write(art)
+            cover_path = cache.set(cache_key, art)
     else:
         raise NotFound('Entity')

@@ -188,19 +200,18 @@ def cover_art():
         return send_file(cover_path)

     im = Image.open(cover_path)
+    mimetype = 'image/{}'.format(im.format.lower())
     if size > im.width and size > im.height:
-        return send_file(cover_path)
+        return send_file(cover_path, mimetype=mimetype)

-    size_path = os.path.join(current_app.config['WEBAPP']['cache_dir'], str(size))
-    path = os.path.abspath(os.path.join(size_path, eid))
-    if os.path.exists(path):
-        return send_file(path, mimetype = 'image/' + im.format.lower())
-    if not os.path.exists(size_path):
-        os.makedirs(size_path)
-
-    im.thumbnail([size, size], Image.ANTIALIAS)
-    im.save(path, im.format)
-    return send_file(path, mimetype = 'image/' + im.format.lower())
+    cache_key = "{}-cover-{}".format(eid, size)
+    try:
+        return send_file(cache.get(cache_key), mimetype=mimetype)
+    except CacheMiss:
+        im.thumbnail([size, size], Image.ANTIALIAS)
+        with cache.set_fileobj(cache_key) as fp:
+            im.save(fp, im.format)
+        return send_file(cache.get(cache_key), mimetype=mimetype)

 @api.route('/getLyrics.view', methods = [ 'GET', 'POST' ])
 def lyrics():
@@ -227,21 +238,37 @@ def lyrics():
                 value = lyrics
             ))

+    # Create a stable, unique, filesystem-compatible identifier for the artist+title
+    unique = hashlib.md5(json.dumps([x.lower() for x in (artist, title)]).encode('utf-8')).hexdigest()
+    cache_key = "lyrics-{}".format(unique)
+
+    lyrics = dict()
     try:
-        r = requests.get("http://api.chartlyrics.com/apiv1.asmx/SearchLyricDirect",
-            params = { 'artist': artist, 'song': title }, timeout = 5)
-        root = ElementTree.fromstring(r.content)
+        lyrics = json.loads(
+            zlib.decompress(
+                current_app.cache.get_value(cache_key)
+            ).decode('utf-8')
+        )
+    except (CacheMiss, zlib.error, TypeError, ValueError):
+        try:
+            r = requests.get("http://api.chartlyrics.com/apiv1.asmx/SearchLyricDirect",
+                params={'artist': artist, 'song': title}, timeout=5)
+            root = ElementTree.fromstring(r.content)

-        ns = { 'cl': 'http://api.chartlyrics.com/' }
-        return request.formatter('lyrics', dict(
-            artist = root.find('cl:LyricArtist', namespaces = ns).text,
-            title = root.find('cl:LyricSong', namespaces = ns).text,
-            value = root.find('cl:Lyric', namespaces = ns).text
-        ))
-    except requests.exceptions.RequestException as e: # pragma: nocover
-        logger.warning('Error while requesting the ChartLyrics API: ' + str(e))
+            ns = {'cl': 'http://api.chartlyrics.com/'}
+            lyrics = dict(
+                artist = root.find('cl:LyricArtist', namespaces=ns).text,
+                title = root.find('cl:LyricSong', namespaces=ns).text,
+                value = root.find('cl:Lyric', namespaces=ns).text
+            )

-    return request.formatter('lyrics', dict()) # pragma: nocover
+            current_app.cache.set(
+                cache_key, zlib.compress(json.dumps(lyrics).encode('utf-8'), 9)
+            )
+        except requests.exceptions.RequestException as e: # pragma: nocover
+            logger.warning('Error while requesting the ChartLyrics API: ' + str(e))
+
+    return request.formatter('lyrics', lyrics)

 def read_file_as_unicode(path):
     """ Opens a file trying with different encodings and returns the contents as a unicode string """
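The lyrics handler stores zlib-compressed JSON under a key derived from the lowercased artist and title. A sketch of the key derivation and the compression round trip, with sample inputs:

    import hashlib
    import json
    import zlib

    artist, title = 'The Artist', 'The Title'
    unique = hashlib.md5(
        json.dumps([x.lower() for x in (artist, title)]).encode('utf-8')
    ).hexdigest()
    cache_key = 'lyrics-{}'.format(unique)   # stable and filesystem-safe

    payload = zlib.compress(json.dumps({'value': 'some lyrics'}).encode('utf-8'), 9)
    restored = json.loads(zlib.decompress(payload).decode('utf-8'))
    assert restored == {'value': 'some lyrics'}
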
diff --git a/supysonic/cache.py b/supysonic/cache.py
new file mode 100644
index 0000000..b1b7b35
--- /dev/null
+++ b/supysonic/cache.py
@@ -0,0 +1,231 @@
+# coding: utf-8
+#
+# This file is part of Supysonic.
+# Supysonic is a Python implementation of the Subsonic server API.
+#
+# Copyright (C) 2013-2018 Alban 'spl0k' Féron
+#                    2017 Óscar García Amor
+#
+# Distributed under terms of the GNU AGPLv3 license.
+
+from collections import OrderedDict, namedtuple
+import contextlib
+import errno
+import logging
+import os
+import os.path
+import tempfile
+import threading
+from time import time
+
+from .py23 import scandir, osreplace
+
+
+logger = logging.getLogger(__name__)
+
+
+class CacheMiss(KeyError):
+    """The requested data is not in the cache"""
+    pass
+
+
+class ProtectedError(Exception):
+    """The data cannot be purged from the cache"""
+    pass
+
+
+CacheEntry = namedtuple("CacheEntry", ["size", "expires"])
+NULL_ENTRY = CacheEntry(0, 0)
+
+
+class Cache(object):
+    """Provides a common interface for caching files to disk"""
+    # Modeled after werkzeug.contrib.cache.FileSystemCache
+
+    # keys must be filename-compatible strings (no paths)
+    # values must be bytes (not strings)
+
+    def __init__(self, cache_dir, max_size, min_time=300, auto_prune=True):
+        """Initialize the cache
+
+        cache_dir: The folder to store cached files
+        max_size: The maximum allowed size of the cache in bytes
+        min_time: The minimum amount of time a file will be stored in the
+                  cache, in seconds (default 300 = 5min)
+        auto_prune: If True (default) the cache will automatically be pruned
+                    to the max_size when possible.
+
+        Note that max_size is not a hard restriction and in some cases will
+        temporarily be exceeded, even when auto-pruning is turned on.
+        """
+        self._cache_dir = os.path.abspath(cache_dir)
+        self.min_time = min_time
+        self.max_size = max_size
+        self._auto_prune = auto_prune
+        self._lock = threading.RLock()
+
+        # Create the cache directory
+        try:
+            os.makedirs(self._cache_dir)
+        except OSError as e:
+            if e.errno != errno.EEXIST:
+                raise
+
+        # Make a key -> CacheEntry(size, expiry) map ordered by mtime
+        self._size = 0
+        self._files = OrderedDict()
+        for mtime, size, key in sorted([(f.stat().st_mtime, f.stat().st_size, f.name)
+                                        for f in scandir(self._cache_dir)
+                                        if f.is_file()]):
+            self._files[key] = CacheEntry(size, mtime + self.min_time)
+            self._size += size
+
+    def _filepath(self, key):
+        return os.path.join(self._cache_dir, key)
+
+    def _make_space(self, required_space, key=None):
+        """Delete files to free up the required space (or close to it)
+
+        If key is provided and exists in the cache, its size will be
+        subtracted from the required size.
+        """
+        target = self.max_size - required_space
+        if key is not None:
+            target += self._files.get(key, NULL_ENTRY).size
+
+        with self._lock:
+            # Delete the oldest file until self._size <= target
+            for k in list(self._files.keys()):
+                if self._size <= target:
+                    break
+                try:
+                    self.delete(k)
+                except ProtectedError:
+                    pass
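+
+    # Eviction approximates LRU: _files is kept oldest-first (the initial scan
+    # sorts by mtime) and _freshen_file() re-appends entries on access, so
+    # _make_space() always removes the stalest unprotected entries first.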
+
+    def _record_file(self, key, size):
+        # If the file is being replaced, add only the difference in size
+        self._size += size - self._files.get(key, NULL_ENTRY).size
+        self._files[key] = CacheEntry(size, int(time()) + self.min_time)
+
+    def _freshen_file(self, key):
+        """Touch the file to change modified time and move it to the end of the cache dict"""
+        old = self._files.pop(key)
+        self._files[key] = CacheEntry(old.size, int(time()) + self.min_time)
+        os.utime(self._filepath(key), None)
+
+    @property
+    def size(self):
+        """The current amount of data cached"""
+        return self._size
+
+    def touch(self, key):
+        """Mark a cache entry as fresh"""
+        with self._lock:
+            if not self.has(key):
+                raise CacheMiss(key)
+            self._freshen_file(key)
+
+    @contextlib.contextmanager
+    def set_fileobj(self, key):
+        """Yields a file object that can have bytes written to it in order to
+        store them in the cache.
+
+        The contents of the file object will be stored in the cache when the
+        context is exited.
+
+        Ex:
+        >>> with cache.set_fileobj(key) as fp:
+        ...     json.dump(some_data, fp)
+        """
+        try:
+            with tempfile.NamedTemporaryFile(dir=self._cache_dir, suffix=".part", delete=True) as f:
+                yield f
+
+                # seek to end and get position to get filesize
+                f.seek(0, 2)
+                size = f.tell()
+
+                with self._lock:
+                    if self._auto_prune:
+                        self._make_space(size, key=key)
+                    osreplace(f.name, self._filepath(key))
+                    self._record_file(key, size)
+        except OSError as e:
+            # Ignore error from trying to delete the renamed temp file
+            if e.errno != errno.ENOENT:
+                raise
+
+    def set(self, key, value):
+        """Set a literal value into the cache and return its path"""
+        with self.set_fileobj(key) as f:
+            f.write(value)
+        return self._filepath(key)
+
+    def set_generated(self, key, gen_function):
+        """Pass the values yielded from the generator function through and set
+        the end result in the cache.
+
+        The contents will be set into the cache only if and when the generator
+        completes.
+
+        Ex:
+        >>> for x in cache.set_generated(key, generator_function):
+        ...     print(x)
+        """
+        with self.set_fileobj(key) as f:
+            for data in gen_function():
+                f.write(data)
+                yield data
+
+    def get(self, key):
+        """Return the path to the file where the cached data is stored"""
+        self.touch(key)
+        return self._filepath(key)
+
+    @contextlib.contextmanager
+    def get_fileobj(self, key):
+        """Yields a file object that can be used to read cached bytes"""
+        with open(self.get(key), 'rb') as f:
+            yield f
+
+    def get_value(self, key):
+        """Return the cached data"""
+        with self.get_fileobj(key) as f:
+            return f.read()
+
+    def delete(self, key):
+        """Delete a file from the cache"""
+        with self._lock:
+            if not self.has(key):
+                return
+            if time() < self._files[key].expires:
+                raise ProtectedError("File has not expired")
+
+            os.remove(self._filepath(key))
+            self._size -= self._files.pop(key).size
+
+    def prune(self):
+        """Prune the cache down to the max size
+
+        Note that protected files are not deleted
+        """
+        self._make_space(0)
+
+    def clear(self):
+        """Clear the cache
+
+        Note that protected files are not deleted
+        """
+        self._make_space(self.max_size)
+
+    def has(self, key):
+        """Check if a key is currently cached"""
+        if key not in self._files:
+            return False
+
+        if not os.path.exists(self._filepath(key)):
+            # Underlying file is gone, remove from the cache
+            self._size -= self._files.pop(key).size
+            return False
+
+        return True
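For reference, a minimal usage sketch of the class above (directory and values are illustrative):

    from supysonic.cache import Cache, CacheMiss

    cache = Cache('/tmp/demo-cache', max_size=10 * 1024 ** 2, min_time=0)
    path = cache.set('greeting', b'hello')       # write bytes, get the file path back
    assert cache.get_value('greeting') == b'hello'
    cache.delete('greeting')                     # allowed here since min_time=0
    try:
        cache.get('greeting')
    except CacheMiss:
        pass                                     # deleted entries raise CacheMiss
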
diff --git a/supysonic/config.py b/supysonic/config.py
index f82487c..1501346 100644
--- a/supysonic/config.py
+++ b/supysonic/config.py
@@ -26,6 +26,8 @@ class DefaultConfig(object):
     }
     WEBAPP = {
         'cache_dir': tempdir,
+        'cache_size': 512,
+        'transcode_cache_size': 1024,

         'log_file': None,
         'log_level': 'WARNING',
diff --git a/supysonic/py23.py b/supysonic/py23.py
index 319e00e..ae500cd 100644
--- a/supysonic/py23.py
+++ b/supysonic/py23.py
@@ -7,6 +7,31 @@
 #
 # Distributed under terms of the GNU AGPLv3 license.

+# Try built-in scandir, fall back to the package for Python 2.7
+try:
+    from os import scandir
+except ImportError:
+    from scandir import scandir
+
+# os.replace was added in Python 3.3, provide a fallback for Python 2.7
+try:
+    from os import replace as osreplace
+except ImportError:
+    # os.rename is equivalent to os.replace except on Windows,
+    # where an existing destination file will not be overwritten.
+    # This fallback just attempts to delete the dst file before using rename.
+    import sys
+    if sys.platform != 'win32':
+        from os import rename as osreplace
+    else:
+        import os
+        def osreplace(src, dst):
+            try:
+                os.remove(dst)
+            except OSError:
+                pass
+            os.rename(src, dst)
+
 try:
     # Python 2
     strtype = basestring
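A quick check of the fallback's semantics (sketch; paths are illustrative). os.replace/os.rename atomically replace the destination on POSIX, while the Windows fallback deletes it first and is therefore not atomic:

    import os
    import tempfile

    from supysonic.py23 import osreplace

    d = tempfile.mkdtemp()
    src = os.path.join(d, 'a.txt')
    dst = os.path.join(d, 'b.txt')
    with open(src, 'w') as f:
        f.write('new')
    with open(dst, 'w') as f:
        f.write('old')

    osreplace(src, dst)          # dst is overwritten on every platform
    with open(dst) as f:
        assert f.read() == 'new'
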
diff --git a/supysonic/web.py b/supysonic/web.py
index 278e5f8..fc1a906 100644
--- a/supysonic/web.py
+++ b/supysonic/web.py
@@ -17,6 +17,7 @@ from os import makedirs, path, urandom
 from pony.orm import db_session

 from .config import IniConfig
+from .cache import Cache
 from .db import init_database

 logger = logging.getLogger(__package__)
@@ -53,6 +54,14 @@ def create_application(config = None):
         if extension not in mimetypes.types_map:
             mimetypes.add_type(v, extension, False)

+    # Initialize Cache objects
+    # Max size is MB in the config file but Cache expects bytes
+    cache_dir = app.config['WEBAPP']['cache_dir']
+    max_size_cache = app.config['WEBAPP']['cache_size'] * 1024**2
+    max_size_transcodes = app.config['WEBAPP']['transcode_cache_size'] * 1024**2
+    app.cache = Cache(path.join(cache_dir, "cache"), max_size_cache)
+    app.transcode_cache = Cache(path.join(cache_dir, "transcodes"), max_size_transcodes)
+
     # Test for the cache directory
     cache_path = app.config['WEBAPP']['cache_dir']
     if not path.exists(cache_path):
diff --git a/tests/base/__init__.py b/tests/base/__init__.py
index 2f4630e..4343e7b 100644
--- a/tests/base/__init__.py
+++ b/tests/base/__init__.py
@@ -10,6 +10,7 @@
 import unittest

 from .test_cli import CLITestCase
+from .test_cache import CacheTestCase
 from .test_config import ConfigTestCase
 from .test_db import DbTestCase
 from .test_lastfm import LastFmTestCase
@@ -20,6 +21,7 @@ from .test_watcher import suite as watcher_suite

 def suite():
     suite = unittest.TestSuite()
+    suite.addTest(unittest.makeSuite(CacheTestCase))
     suite.addTest(unittest.makeSuite(ConfigTestCase))
     suite.addTest(unittest.makeSuite(DbTestCase))
     suite.addTest(unittest.makeSuite(ScannerTestCase))
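Once create_application() has run, both caches hang off the Flask app object. A hypothetical interactive check, assuming an application configured with the defaults above:

    from supysonic.web import create_application

    app = create_application()
    print(app.cache.max_size)            # 512 * 1024**2 with the default config
    print(app.transcode_cache.max_size)  # 1024 * 1024**2
    app.cache.set('example', b'...')     # general-purpose cache (covers, lyrics)
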
diff --git a/tests/base/test_cache.py b/tests/base/test_cache.py
new file mode 100644
index 0000000..a134b10
--- /dev/null
+++ b/tests/base/test_cache.py
@@ -0,0 +1,278 @@
+#!/usr/bin/env python
+# coding: utf-8
+#
+# This file is part of Supysonic.
+# Supysonic is a Python implementation of the Subsonic server API.
+#
+# Copyright (C) 2018 Alban 'spl0k' Féron
+#
+# Distributed under terms of the GNU AGPLv3 license.
+
+import os
+import unittest
+import shutil
+import time
+import tempfile
+
+from supysonic.cache import Cache, CacheMiss, ProtectedError
+
+
+class CacheTestCase(unittest.TestCase):
+    def setUp(self):
+        self.__dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.__dir)
+
+    def test_existing_files_order(self):
+        cache = Cache(self.__dir, 30)
+        val = b'0123456789'
+        cache.set("key1", val)
+        cache.set("key2", val)
+        cache.set("key3", val)
+        self.assertEqual(cache.size, 30)
+
+        # file mtime is accurate to the second
+        time.sleep(1)
+        cache.get_value("key1")
+
+        cache = Cache(self.__dir, 30, min_time=0)
+        self.assertEqual(cache.size, 30)
+        self.assertTrue(cache.has("key1"))
+        self.assertTrue(cache.has("key2"))
+        self.assertTrue(cache.has("key3"))
+
+        cache.set("key4", val)
+        self.assertEqual(cache.size, 30)
+        self.assertTrue(cache.has("key1"))
+        self.assertFalse(cache.has("key2"))
+        self.assertTrue(cache.has("key3"))
+        self.assertTrue(cache.has("key4"))
+
+    def test_missing(self):
+        cache = Cache(self.__dir, 10)
+        self.assertFalse(cache.has("missing"))
+        with self.assertRaises(CacheMiss):
+            cache.get_value("missing")
+
+    def test_delete_missing(self):
+        cache = Cache(self.__dir, 0, min_time=0)
+        cache.delete("missing1")
+        cache.delete("missing2")
+
+    def test_store_literal(self):
+        cache = Cache(self.__dir, 10)
+        val = b'0123456789'
+        cache.set("key", val)
+        self.assertEqual(cache.size, 10)
+        self.assertTrue(cache.has("key"))
+        self.assertEqual(cache.get_value("key"), val)
+
+    def test_store_generated(self):
+        cache = Cache(self.__dir, 10)
+        val = [b'0', b'12', b'345', b'6789']
+        def gen():
+            for b in val:
+                yield b
+
+        t = []
+        for x in cache.set_generated("key", gen):
+            t.append(x)
+            self.assertEqual(cache.size, 0)
+            self.assertFalse(cache.has("key"))
+
+        self.assertEqual(t, val)
+        self.assertEqual(cache.size, 10)
+        self.assertEqual(cache.get_value("key"), b''.join(val))
+
+    def test_store_to_fp(self):
+        cache = Cache(self.__dir, 10)
+        val = b'0123456789'
+        with cache.set_fileobj("key") as fp:
+            fp.write(val)
+            self.assertEqual(cache.size, 0)
+
+        self.assertEqual(cache.size, 10)
+        self.assertEqual(cache.get_value("key"), val)
+
+    def test_access_data(self):
+        cache = Cache(self.__dir, 25, min_time=0)
+        val = b'0123456789'
+        cache.set("key", val)
+
+        self.assertEqual(cache.get_value("key"), val)
+
+        with cache.get_fileobj("key") as f:
+            self.assertEqual(f.read(), val)
+
+        with open(cache.get("key"), 'rb') as f:
+            self.assertEqual(f.read(), val)
+
+    def test_accessing_preserves(self):
+        cache = Cache(self.__dir, 25, min_time=0)
+        val = b'0123456789'
+        cache.set("key1", val)
+        cache.set("key2", val)
+        self.assertEqual(cache.size, 20)
+
+        cache.get_value("key1")
+
+        cache.set("key3", val)
+        self.assertEqual(cache.size, 20)
+        self.assertTrue(cache.has("key1"))
+        self.assertFalse(cache.has("key2"))
+        self.assertTrue(cache.has("key3"))
+
+    def test_automatic_delete_oldest(self):
+        cache = Cache(self.__dir, 25, min_time=0)
+        val = b'0123456789'
+        cache.set("key1", val)
+        self.assertTrue(cache.has("key1"))
+        self.assertEqual(cache.size, 10)
+
+        cache.set("key2", val)
+        self.assertEqual(cache.size, 20)
+        self.assertTrue(cache.has("key1"))
+        self.assertTrue(cache.has("key2"))
+
+        cache.set("key3", val)
+        self.assertEqual(cache.size, 20)
+        self.assertFalse(cache.has("key1"))
+        self.assertTrue(cache.has("key2"))
+        self.assertTrue(cache.has("key3"))
+
+    def test_delete(self):
+        cache = Cache(self.__dir, 25, min_time=0)
+        val = b'0123456789'
+        cache.set("key1", val)
+        self.assertTrue(cache.has("key1"))
+        self.assertEqual(cache.size, 10)
+
+        cache.delete("key1")
+
+        self.assertFalse(cache.has("key1"))
+        self.assertEqual(cache.size, 0)
+
+    def test_cleanup_on_error(self):
+        cache = Cache(self.__dir, 10)
+        def gen():
+            # Cause a TypeError halfway through
+            for b in [b'0', b'12', object(), b'345', b'6789']:
+                yield b
+
+        with self.assertRaises(TypeError):
+            for x in cache.set_generated("key", gen):
+                pass
+
+        # Make sure no partial files are left after the error
+        self.assertEqual(list(os.listdir(self.__dir)), list())
+
+    def test_parallel_generation(self):
+        cache = Cache(self.__dir, 20)
+        def gen():
+            for b in [b'0', b'12', b'345', b'6789']:
+                yield b
+
+        g1 = cache.set_generated("key", gen)
+        g2 = cache.set_generated("key", gen)
+
+        next(g1)
+        files = os.listdir(self.__dir)
+        self.assertEqual(len(files), 1)
+        for x in files:
+            self.assertTrue(x.endswith(".part"))
+
+        next(g2)
+        files = os.listdir(self.__dir)
+        self.assertEqual(len(files), 2)
+        for x in files:
+            self.assertTrue(x.endswith(".part"))
+
+        self.assertEqual(cache.size, 0)
+        for x in g1:
+            pass
+        self.assertEqual(cache.size, 10)
+        self.assertTrue(cache.has("key"))
+
+        # Replace the file - size should stay the same
+        for x in g2:
+            pass
+        self.assertEqual(cache.size, 10)
+        self.assertTrue(cache.has("key"))
+
+        # Only a single file
+        self.assertEqual(len(os.listdir(self.__dir)), 1)
+
+    def test_replace(self):
+        cache = Cache(self.__dir, 20)
+        val_small = b'0'
+        val_big = b'0123456789'
+
+        cache.set("key", val_small)
+        self.assertEqual(cache.size, 1)
+
+        cache.set("key", val_big)
+        self.assertEqual(cache.size, 10)
+
+        cache.set("key", val_small)
+        self.assertEqual(cache.size, 1)
+
+    def test_no_auto_prune(self):
+        cache = Cache(self.__dir, 10, min_time=0, auto_prune=False)
+        val = b'0123456789'
+
+        cache.set("key1", val)
+        cache.set("key2", val)
+        cache.set("key3", val)
+        cache.set("key4", val)
+        self.assertEqual(cache.size, 40)
+        cache.prune()
+
+        self.assertEqual(cache.size, 10)
+
+    def test_min_time_clear(self):
+        cache = Cache(self.__dir, 40, min_time=1)
+        val = b'0123456789'
+
+        cache.set("key1", val)
+        cache.set("key2", val)
+        time.sleep(1)
+        cache.set("key3", val)
+        cache.set("key4", val)
+
+        self.assertEqual(cache.size, 40)
+        cache.clear()
+        self.assertEqual(cache.size, 20)
+        time.sleep(1)
+        cache.clear()
+        self.assertEqual(cache.size, 0)
+
+    def test_not_expired(self):
+        cache = Cache(self.__dir, 40, min_time=1)
+        val = b'0123456789'
+        cache.set("key1", val)
+        with self.assertRaises(ProtectedError):
+            cache.delete("key1")
+        time.sleep(1)
+        cache.delete("key1")
+        self.assertEqual(cache.size, 0)
+
+    def test_missing_cache_file(self):
+        cache = Cache(self.__dir, 10, min_time=0)
+        val = b'0123456789'
+        os.remove(cache.set("key", val))
+
+        self.assertEqual(cache.size, 10)
+        self.assertFalse(cache.has("key"))
+        self.assertEqual(cache.size, 0)
+
+        os.remove(cache.set("key", val))
+        self.assertEqual(cache.size, 10)
+        with self.assertRaises(CacheMiss):
+            cache.get("key")
+        self.assertEqual(cache.size, 0)
+
+
+if __name__ == '__main__':
+    unittest.main()
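test_missing_cache_file above relies on has() being self-healing when a backing file disappears out-of-band. A sketch of the same behaviour outside the test harness (directory is illustrative):

    import os
    import tempfile

    from supysonic.cache import Cache

    cache = Cache(tempfile.mkdtemp(), 10, min_time=0)
    path = cache.set("key", b"0123456789")
    os.remove(path)                 # the backing file vanishes out-of-band

    assert not cache.has("key")     # has() notices and drops the stale entry
    assert cache.size == 0          # bookkeeping is corrected as a side effect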