# tests/test_worker.py
"""Tests for aiohttp/worker.py"""
import asyncio
import pathlib
import socket
import ssl
from unittest import mock
import pytest
from aiohttp import helpers
from aiohttp.test_utils import make_mocked_coro
# Skip this whole module when aiohttp.worker cannot be imported
# (presumably because gunicorn is not installed — confirm).
base_worker = pytest.importorskip('aiohttp.worker')

# uvloop is optional; when it is missing the uvloop worker variant is not
# exercised.
try:
    import uvloop
except ImportError:
    uvloop = None

# Mixes "%(name)s"-style directives into an access-log format; the worker's
# _get_valid_log_format is expected to reject it with ValueError.
WRONG_LOG_FORMAT = '%a "%{Referrer}i" %(h)s %(l)s %s'
# Uses only "%x"-style directives; expected to pass through unchanged.
ACCEPTABLE_LOG_FORMAT = '%a "%{Referrer}i" %s'
class BaseTestWorker:
    """Mixin that gives a gunicorn-based worker just enough state for tests.

    ``__init__`` sets only the attributes the tests read (``servers``,
    ``exit_code`` and a mocked ``cfg``). It does not call
    ``super().__init__()``, presumably to bypass the real gunicorn worker
    constructor and its required arguments — confirm before changing.
    """

    def __init__(self):
        self.servers = {}
        self.exit_code = 0
        self.cfg = mock.Mock()
        self.cfg.graceful_timeout = 100
class AsyncioWorker(BaseTestWorker, base_worker.GunicornWebWorker):
    """GunicornWebWorker with BaseTestWorker's lightweight constructor."""

    pass
# Worker classes fed to the ``worker`` fixture (``params=PARAMS``); the uvloop
# variant is only included when uvloop could be imported.
PARAMS = [AsyncioWorker]
if uvloop is not None:
    class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker):
        """GunicornUVLoopWebWorker with BaseTestWorker's constructor."""

        pass

    PARAMS.append(UvloopWorker)
@pytest.fixture(params=PARAMS)
def worker(request):
    """Return a fresh instance of each class in PARAMS with ``notify`` mocked."""
    ret = request.param()
    ret.notify = mock.Mock()
    return ret
def test_init_process(worker):
    """init_process closes the current event loop and installs a new one."""
    with mock.patch('aiohttp.worker.asyncio') as m_asyncio:
        try:
            worker.init_process()
        except TypeError:
            # NOTE(review): presumably the gunicorn base init_process fails
            # against the mocked environment after the loop swap; only the
            # loop handling below is under test — confirm.
            pass

        assert m_asyncio.get_event_loop.return_value.close.called
        assert m_asyncio.new_event_loop.called
        assert m_asyncio.set_event_loop.called
def test_run(worker, loop):
    """run() awaits wsgi.startup and _run, exits via SystemExit, closes loop."""
    worker.wsgi = mock.Mock()
    worker.loop = loop
    # Replace the serving coroutine with a no-op so run() returns at once.
    worker._run = mock.Mock(
        wraps=asyncio.coroutine(lambda: None))
    worker.wsgi.startup = make_mocked_coro(None)

    with pytest.raises(SystemExit):
        worker.run()

    assert worker._run.called
    worker.wsgi.startup.assert_called_once_with()
    assert loop.is_closed()
def test_handle_quit(worker):
    """handle_quit stops the worker and keeps a success exit code (0)."""
    worker.handle_quit(object(), object())
    assert not worker.alive
    assert worker.exit_code == 0
def test_handle_abort(worker):
    """handle_abort stops the worker and sets a failure exit code (1)."""
    worker.handle_abort(object(), object())
    assert not worker.alive
    assert worker.exit_code == 1
def test_init_signals(worker):
    """init_signals registers signal handlers on the worker's loop."""
    worker.loop = mock.Mock()
    worker.init_signals()
    assert worker.loop.add_signal_handler.called
def test_make_handler(worker, mocker):
    """make_handler validates the log format and returns the app's handler."""
    worker.wsgi = mock.Mock()
    worker.loop = mock.Mock()
    worker.log = mock.Mock()
    worker.cfg = mock.Mock()
    worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
    # Spy keeps the real implementation while recording the call.
    mocker.spy(worker, '_get_valid_log_format')

    f = worker.make_handler(worker.wsgi)

    assert f is worker.wsgi.make_handler.return_value
    assert worker._get_valid_log_format.called
@pytest.mark.parametrize('source,result', [
    # A valid aiohttp format is returned unchanged.
    (ACCEPTABLE_LOG_FORMAT, ACCEPTABLE_LOG_FORMAT),
    # gunicorn's default format is mapped to aiohttp's default format.
    (AsyncioWorker.DEFAULT_GUNICORN_LOG_FORMAT,
     AsyncioWorker.DEFAULT_AIOHTTP_LOG_FORMAT),
])
def test__get_valid_log_format_ok(worker, source, result):
    """_get_valid_log_format returns the expected format for valid input."""
    assert result == worker._get_valid_log_format(source)
def test__get_valid_log_format_exc(worker):
    """A format with %(name)s-style directives raises a descriptive ValueError."""
    with pytest.raises(ValueError) as exc:
        worker._get_valid_log_format(WRONG_LOG_FORMAT)
    # Inspect the exception itself: str() of the ExceptionInfo wrapper is a
    # pytest-version-dependent summary, not the error message.
    assert '%(name)s' in str(exc.value)
@asyncio.coroutine
def test__run_ok(worker, loop):
    """_run starts a TCP server with the configured SSL context and exits
    with the "Parent changed" message."""
    worker.ppid = 1
    worker.alive = True
    worker.servers = {}
    sock = mock.Mock()
    sock.cfg_addr = ('localhost', 8080)
    worker.sockets = [sock]
    worker.wsgi = mock.Mock()
    worker.close = make_mocked_coro(None)
    worker.log = mock.Mock()
    worker.loop = loop
    loop.create_server = make_mocked_coro(sock)
    worker.wsgi.make_handler.return_value.requests_count = 1
    worker.cfg.max_requests = 100
    worker.cfg.is_ssl = True
    worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT

    ssl_context = mock.Mock()
    with mock.patch('ssl.SSLContext', return_value=ssl_context):
        with mock.patch('aiohttp.worker.asyncio') as m_asyncio:
            # Make the worker's sleep a no-op coroutine.
            m_asyncio.sleep = mock.Mock(
                wraps=asyncio.coroutine(lambda *a, **kw: None))
            yield from worker._run()

    worker.notify.assert_called_with()
    worker.log.info.assert_called_with("Parent changed, shutting down: %s",
                                       worker)

    args, kwargs = loop.create_server.call_args
    assert 'ssl' in kwargs
    ctx = kwargs['ssl']
    assert ctx is ssl_context
@pytest.mark.skipif(not hasattr(socket, 'AF_UNIX'),
                    reason="UNIX sockets are not supported")
@asyncio.coroutine
def test__run_ok_unix_socket(worker, loop):
    """_run uses create_unix_server for AF_UNIX sockets and still passes
    the SSL context."""
    worker.ppid = 1
    worker.alive = True
    worker.servers = {}
    sock = mock.Mock()
    # A UNIX socket address is a plain path string, not a tuple.
    sock.cfg_addr = '/path/to'
    sock.family = socket.AF_UNIX
    worker.sockets = [sock]
    worker.wsgi = mock.Mock()
    worker.close = make_mocked_coro(None)
    worker.log = mock.Mock()
    worker.loop = loop
    loop.create_unix_server = make_mocked_coro(sock)
    worker.wsgi.make_handler.return_value.requests_count = 1
    worker.cfg.max_requests = 100
    worker.cfg.is_ssl = True
    worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT

    ssl_context = mock.Mock()
    with mock.patch('ssl.SSLContext', return_value=ssl_context):
        with mock.patch('aiohttp.worker.asyncio') as m_asyncio:
            # Make the worker's sleep a no-op coroutine.
            m_asyncio.sleep = mock.Mock(
                wraps=asyncio.coroutine(lambda *a, **kw: None))
            yield from worker._run()

    worker.notify.assert_called_with()
    worker.log.info.assert_called_with("Parent changed, shutting down: %s",
                                       worker)

    args, kwargs = loop.create_unix_server.call_args
    assert 'ssl' in kwargs
    ctx = kwargs['ssl']
    assert ctx is ssl_context
@asyncio.coroutine
def test__run_exc(worker, loop):
    """An exception raised from the worker's sleep ends _run, and close()
    is still awaited."""
    with mock.patch('aiohttp.worker.os') as m_os:
        # getppid() matches worker.ppid, so the parent never looks changed.
        m_os.getpid.return_value = 1
        m_os.getppid.return_value = 1

        handler = mock.Mock()
        handler.requests_count = 0
        worker.servers = {mock.Mock(): handler}
        worker.ppid = 1
        worker.alive = True
        worker.sockets = []
        worker.log = mock.Mock()
        worker.loop = loop
        worker.cfg.is_ssl = False
        # NOTE(review): max_redirects may be a leftover/typo; nothing here
        # shows _run reading it — verify before removing.
        worker.cfg.max_redirects = 0
        worker.cfg.max_requests = 100

        with mock.patch('aiohttp.worker.asyncio.sleep') as m_sleep:
            # The first sleep resolves to KeyboardInterrupt, which is what
            # breaks the worker out of its loop here.
            slp = helpers.create_future(loop)
            slp.set_exception(KeyboardInterrupt)
            m_sleep.return_value = slp

            worker.close = make_mocked_coro(None)

            yield from worker._run()

        m_sleep.assert_called_with(1.0, loop=loop)
        worker.close.assert_called_with()
@asyncio.coroutine
def test_close(worker, loop):
    """close() shuts down the app, handlers and servers, then clears
    ``servers``; a second call must not fail."""
    srv = mock.Mock()
    srv.wait_closed = make_mocked_coro(None)
    handler = mock.Mock()
    worker.servers = {srv: handler}
    worker.log = mock.Mock()
    worker.loop = loop
    app = worker.wsgi = mock.Mock()
    app.cleanup = make_mocked_coro(None)
    handler.connections = [object()]
    handler.shutdown.return_value = helpers.create_future(loop)
    handler.shutdown.return_value.set_result(1)
    app.shutdown.return_value = helpers.create_future(loop)
    app.shutdown.return_value.set_result(None)

    yield from worker.close()
    app.shutdown.assert_called_with()
    app.cleanup.assert_called_with()
    # Presumably derived from cfg.graceful_timeout (100, set by
    # BaseTestWorker) — confirm against aiohttp.worker.close().
    handler.shutdown.assert_called_with(timeout=95.0)
    srv.close.assert_called_with()
    assert worker.servers is None

    # Calling close() again once servers is None must not raise.
    yield from worker.close()
@asyncio.coroutine
def test__run_ok_no_max_requests(worker, loop):
    """With max_requests = 0 the worker runs until the parent changes."""
    worker.ppid = 1
    worker.alive = True
    worker.servers = {}
    sock = mock.Mock()
    sock.cfg_addr = ('localhost', 8080)
    worker.sockets = [sock]
    worker.wsgi = mock.Mock()
    worker.close = make_mocked_coro(None)
    worker.log = mock.Mock()
    worker.loop = loop
    loop.create_server = make_mocked_coro(sock)
    worker.wsgi.make_handler.return_value.requests_count = 1
    worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
    worker.cfg.max_requests = 0
    worker.cfg.is_ssl = True

    ssl_context = mock.Mock()
    with mock.patch('ssl.SSLContext', return_value=ssl_context):
        with mock.patch('aiohttp.worker.asyncio') as m_asyncio:
            # Make the worker's sleep a no-op coroutine.
            m_asyncio.sleep = mock.Mock(
                wraps=asyncio.coroutine(lambda *a, **kw: None))
            yield from worker._run()

    worker.notify.assert_called_with()
    worker.log.info.assert_called_with("Parent changed, shutting down: %s",
                                       worker)

    args, kwargs = loop.create_server.call_args
    assert 'ssl' in kwargs
    ctx = kwargs['ssl']
    assert ctx is ssl_context
@asyncio.coroutine
def test__run_ok_max_requests_exceeded(worker, loop):
    """When requests_count exceeds max_requests, the worker shuts down with
    the "Max requests" message."""
    worker.ppid = 1
    worker.alive = True
    worker.servers = {}
    sock = mock.Mock()
    sock.cfg_addr = ('localhost', 8080)
    worker.sockets = [sock]
    worker.wsgi = mock.Mock()
    worker.close = make_mocked_coro(None)
    worker.log = mock.Mock()
    worker.loop = loop
    loop.create_server = make_mocked_coro(sock)
    # 15 handled requests against a limit of 10.
    worker.wsgi.make_handler.return_value.requests_count = 15
    worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
    worker.cfg.max_requests = 10
    worker.cfg.is_ssl = True

    ssl_context = mock.Mock()
    with mock.patch('ssl.SSLContext', return_value=ssl_context):
        with mock.patch('aiohttp.worker.asyncio') as m_asyncio:
            # Make the worker's sleep a no-op coroutine.
            m_asyncio.sleep = mock.Mock(
                wraps=asyncio.coroutine(lambda *a, **kw: None))
            yield from worker._run()

    worker.notify.assert_called_with()
    worker.log.info.assert_called_with("Max requests, shutting down: %s",
                                       worker)

    args, kwargs = loop.create_server.call_args
    assert 'ssl' in kwargs
    ctx = kwargs['ssl']
    assert ctx is ssl_context
def test__create_ssl_context_without_certs_and_ciphers(worker):
    """_create_ssl_context builds an SSLContext from a cert/key pair with no
    CA bundle and no cipher list configured."""
    here = pathlib.Path(__file__).parent
    worker.cfg.ssl_version = ssl.PROTOCOL_SSLv23
    worker.cfg.cert_reqs = ssl.CERT_OPTIONAL
    worker.cfg.certfile = str(here / 'sample.crt')
    worker.cfg.keyfile = str(here / 'sample.key')
    worker.cfg.ca_certs = None
    worker.cfg.ciphers = None

    ctx = worker._create_ssl_context(worker.cfg)

    assert isinstance(ctx, ssl.SSLContext)