46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
import tempfile
|
|
import unittest
|
|
import os
|
|
from json import loads
|
|
from ITPlanning.config import DefaultConfig
|
|
from ITPlanning.db import init_database, release_database
|
|
from ITPlanning.app import create_app
|
|
|
|
|
|
class TestBase(unittest.TestCase):
    """Base test case that runs the ITPlanning app against a throwaway SQLite file.

    Each test gets a fresh temporary database file, an application built from
    ``DefaultConfig`` with ``TESTING`` enabled, and that application's test
    client in ``self.client``.
    """

    def setUp(self):
        """Create the temporary database file, the app, and its test client.

        Cleanup is registered with ``addCleanup`` as each resource is acquired
        rather than done in ``tearDown``: unittest does not call ``tearDown``
        when ``setUp`` raises, so a failure in ``create_app`` would otherwise
        leak the file descriptor and leave the temporary file on disk.
        """
        # mkstemp returns (fd, path); the fd is kept open for the test's lifetime
        # and closed during cleanup, as before.
        self.__db = tempfile.mkstemp()
        fd, path = self.__db
        # Cleanups run in LIFO order, so the effective teardown sequence is:
        # release_database() -> os.close(fd) -> os.remove(path).
        self.addCleanup(os.remove, path)
        self.addCleanup(os.close, fd)

        self.config = DefaultConfig()
        # NOTE(review): if DefaultConfig.BASE is a class-level dict, this
        # assignment mutates state shared by every DefaultConfig instance --
        # confirm against ITPlanning.config.
        self.config.BASE["database_uri"] = "sqlite:///" + path
        self.config.TESTING = True

        self.__app = create_app(self.config)
        # Only release the database once the app was successfully created,
        # mirroring the original tearDown which ran only after a full setUp.
        self.addCleanup(release_database)

        self.client = self.__app.test_client()

    def app_context(self, *args, **kwargs):
        """Return the application's ``app_context(...)`` for use in a ``with`` block."""
        return self.__app.app_context(*args, **kwargs)

    def request_context(self, *args, **kwargs):
        """Return the application's ``test_request_context(...)``; arguments pass through."""
        return self.__app.test_request_context(*args, **kwargs)
|
|
|
|
|
|
class APITestBase(TestBase):
    """Test case base adding a helper for calling the ``/api/v1`` endpoints."""

    def make_request(self, endpoint, return_code, args=None, method="get"):
        """Issue a request to ``/api/v1/<endpoint>`` and assert its status code.

        :param endpoint: path below ``/api/v1/``.
        :param return_code: expected HTTP status code; asserted on the response.
        :param args: dict of parameters, sent as the query string for GET and
            as form data for POST. ``None`` (the default) means no parameters.
        :param method: ``"get"`` or ``"post"`` (case-insensitive).
        :returns: ``(response, data)``, where ``data`` is the JSON-decoded body
            when ``return_code`` is 200 and ``None`` otherwise.
        :raises TypeError: if ``args`` is neither ``None`` nor a dict.
        :raises ValueError: if ``method`` is not GET or POST.
        """
        # A None default avoids sharing one mutable dict across all calls.
        if args is None:
            args = {}
        if not isinstance(args, dict):
            raise TypeError("'args', expecting a dict, got " + type(args).__name__)

        uri = "/api/v1/{}".format(endpoint)

        verb = method.lower()
        if verb == "get":
            rv = self.client.get(uri, query_string=args, follow_redirects=True)
        elif verb == "post":
            rv = self.client.post(uri, data=args, follow_redirects=True)
        else:
            # Previously fell through with `rv` unbound (UnboundLocalError).
            raise ValueError("unsupported HTTP method: {!r}".format(method))

        # Assert the status before decoding: an unexpected error response is
        # usually not JSON, and a decode error would mask the real failure.
        self.assertEqual(rv.status_code, return_code)

        data = loads(rv.data) if return_code == 200 else None
        return rv, data
|