python/6871/aiohttp/tests/test_wsgi.py

test_wsgi.py
"""Tests for aiohttp/wsgi.py"""

import asyncio
import io
import socket
import unittest
from unittest import mock

import multidict

import aiohttp
from aiohttp import helpers, protocol, wsgi


class TestHttpWsgiServerProtocol(unittest.TestCase):
    """Tests for ``wsgi.WSGIServerHttpProtocol`` and its WSGI helpers.

    Every test gets a fresh event loop, mock reader/writer/transport
    objects, a ``GET /path`` HTTP/1.0 request message carrying a
    ``Host: python.org`` header, and a payload queue that yields
    ``b'data'`` twice before EOF.
    """

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        # Clear the current loop so code under test cannot silently fall
        # back to a global loop instead of the one passed explicitly.
        asyncio.set_event_loop(None)

        self.wsgi = mock.Mock()
        self.reader = mock.Mock()
        self.writer = mock.Mock()
        self.writer.drain.return_value = ()
        self.transport = mock.Mock()
        # Successive transport.get_extra_info() results, in call order:
        #   1. the socket (its family selects how addresses are resolved),
        #   2. the peer address -> REMOTE_ADDR / REMOTE_PORT,
        #   3. an address used for SERVER_NAME / SERVER_PORT when the request
        #      has no Host header (presumably 'sockname'; see
        #      test_http_1_0_no_host).
        self.transport.get_extra_info.side_effect = [
            mock.Mock(family=socket.AF_INET),
            ('1.2.3.4', 1234),
            ('2.3.4.5', 80)]

        self.headers = multidict.CIMultiDict({"HOST": "python.org"})
        self.raw_headers = [(b"HOST", b"python.org")]
        self.message = protocol.RawRequestMessage(
            'GET', '/path', (1, 0), self.headers, self.raw_headers,
            True, 'deflate')
        self.payload = aiohttp.FlowControlDataQueue(self.reader)
        self.payload.feed_data(b'data', 4)
        self.payload.feed_data(b'data', 4)
        self.payload.feed_eof()

    def tearDown(self):
        self.loop.close()

    def test_ctor(self):
        srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop)
        self.assertIs(srv.wsgi, self.wsgi)
        self.assertFalse(srv.readpayload)

    def _make_one(self, **kw):
        """Return the WSGI environ built for ``self.message``/``self.payload``.

        ``kw`` is forwarded to the ``WSGIServerHttpProtocol`` constructor.
        """
        srv = self._make_srv(**kw)
        return srv.create_wsgi_environ(self.message, self.payload)

    def _make_srv(self, app=None, **kw):
        """Return a protocol for *app* wired to the mock reader/writer/transport.

        *app* defaults to the ``self.wsgi`` mock; ``kw`` is forwarded to the
        ``WSGIServerHttpProtocol`` constructor.
        """
        if app is None:
            app = self.wsgi
        srv = wsgi.WSGIServerHttpProtocol(app, loop=self.loop, **kw)
        srv.reader = self.reader
        srv.writer = self.writer
        srv.transport = self.transport
        return srv

    def test_environ(self):
        environ = self._make_one()
        self.assertEqual(environ['RAW_URI'], '/path')
        self.assertEqual(environ['wsgi.async'], True)

    def test_environ_headers(self):
        self.headers.extend(
            (('SCRIPT_NAME', 'script'),
             ('CONTENT-TYPE', 'text/plain'),
             ('CONTENT-LENGTH', '209'),
             ('X_TEST', '123'),
             ('X_TEST', '456')))
        environ = self._make_one()
        self.assertEqual(environ['CONTENT_TYPE'], 'text/plain')
        self.assertEqual(environ['CONTENT_LENGTH'], '209')
        # Repeated headers are joined with a comma.
        self.assertEqual(environ['HTTP_X_TEST'], '123,456')
        self.assertEqual(environ['SCRIPT_NAME'], 'script')
        self.assertEqual(environ['SERVER_NAME'], 'python.org')
        self.assertEqual(environ['SERVER_PORT'], '80')
        # With a Host header present, only the socket and the peer name
        # are queried from the transport.
        get_extra_info_calls = self.transport.get_extra_info.mock_calls
        expected_calls = [
            mock.call('socket'),
            mock.call('peername'),
        ]
        self.assertEqual(expected_calls, get_extra_info_calls)

    def test_environ_host_header_alternate_port(self):
        self.headers.update({'HOST': 'example.com:9999'})
        environ = self._make_one()
        self.assertEqual(environ['SERVER_PORT'], '9999')

    def test_environ_host_header_alternate_port_ssl(self):
        self.headers.update({'HOST': 'example.com:9999'})
        environ = self._make_one(is_ssl=True)
        self.assertEqual(environ['SERVER_PORT'], '9999')

    def test_wsgi_response(self):
        srv = self._make_srv()
        resp = srv.create_wsgi_response(self.message)
        self.assertIsInstance(resp, wsgi.WsgiResponse)

    def test_wsgi_response_start_response(self):
        srv = self._make_srv()
        resp = srv.create_wsgi_response(self.message)
        resp.start_response(
            '200 OK', [('CONTENT-TYPE', 'text/plain')])
        self.assertEqual(resp.status, '200 OK')
        self.assertIsInstance(resp.response, protocol.Response)

    def test_wsgi_response_start_response_exc(self):
        srv = self._make_srv()
        resp = srv.create_wsgi_response(self.message)
        # exc_info passed before any response was started is accepted.
        resp.start_response(
            '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()])
        self.assertEqual(resp.status, '200 OK')
        self.assertIsInstance(resp.response, protocol.Response)

    def test_wsgi_response_start_response_exc_status(self):
        srv = self._make_srv()
        resp = srv.create_wsgi_response(self.message)
        resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')])

        # Once a response has been started, a second start_response call
        # with exc_info re-raises the carried exception.
        self.assertRaises(
            ValueError,
            resp.start_response,
            '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()])

    @mock.patch('aiohttp.wsgi.aiohttp')
    def test_wsgi_response_101_upgrade_to_websocket(self, m_aiohttp):
        srv = self._make_srv()
        resp = srv.create_wsgi_response(self.message)
        resp.start_response(
            '101 Switching Protocols', (('UPGRADE', 'websocket'),
                                        ('CONNECTION', 'upgrade')))
        self.assertEqual(resp.status, '101 Switching Protocols')
        # For an upgrade, start_response sends the headers right away.
        self.assertTrue(m_aiohttp.Response.return_value.send_headers.called)

    def test_file_wrapper(self):
        fobj = io.BytesIO(b'data')
        wrapper = wsgi.FileWrapper(fobj, 2)
        self.assertIs(wrapper, iter(wrapper))
        self.assertTrue(hasattr(wrapper, 'close'))

        self.assertEqual(next(wrapper), b'da')
        self.assertEqual(next(wrapper), b'ta')
        self.assertRaises(StopIteration, next, wrapper)

        # Wrapping an object without close() exposes no close() either.
        wrapper = wsgi.FileWrapper(b'data', 2)
        self.assertFalse(hasattr(wrapper, 'close'))

    def test_handle_request_futures(self):

        def wsgi_app(env, start):
            # Return a future that resolves to a list of futures, each of
            # which resolves to a body chunk.
            start('200 OK', [('Content-Type', 'text/plain')])
            f1 = helpers.create_future(self.loop)
            f1.set_result(b'data')
            fut = helpers.create_future(self.loop)
            fut.set_result([f1])
            return fut

        srv = self._make_srv(wsgi_app)
        self.loop.run_until_complete(
            srv.handle_request(self.message, self.payload))

        content = b''.join(
            [c[1][0] for c in self.writer.write.mock_calls])
        self.assertTrue(content.startswith(b'HTTP/1.0 200 OK'))
        self.assertTrue(content.endswith(b'data'))

    def test_handle_request_simple(self):

        def wsgi_app(env, start):
            start('200 OK', [('Content-Type', 'text/plain')])
            return [b'data']

        # HTTP/1.1 request; the sixth field (presumably should_close) is
        # True, so the connection must not be kept alive.
        self.message = protocol.RawRequestMessage(
            'GET', '/path', (1, 1), self.headers, self.raw_headers,
            True, 'deflate')

        srv = self._make_srv(wsgi_app, readpayload=True)
        self.loop.run_until_complete(
            srv.handle_request(self.message, self.payload))

        content = b''.join(
            [c[1][0] for c in self.writer.write.mock_calls])
        self.assertTrue(content.startswith(b'HTTP/1.1 200 OK'))
        # Body ends with the chunked-encoding terminator.
        self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n'))
        self.assertFalse(srv._keepalive)

    def test_handle_request_io(self):

        def wsgi_app(env, start):
            start('200 OK', [('Content-Type', 'text/plain')])
            return io.BytesIO(b'data')

        srv = self._make_srv(wsgi_app)

        self.loop.run_until_complete(
            srv.handle_request(self.message, self.payload))

        content = b''.join(
            [c[1][0] for c in self.writer.write.mock_calls])
        self.assertTrue(content.startswith(b'HTTP/1.0 200 OK'))
        self.assertTrue(content.endswith(b'data'))

    def test_handle_request_keep_alive(self):

        def wsgi_app(env, start):
            start('200 OK', [('Content-Type', 'text/plain')])
            return [b'data']

        # Same as test_handle_request_simple but the sixth field
        # (presumably should_close) is False, so keep-alive stays on.
        self.message = protocol.RawRequestMessage(
            'GET', '/path', (1, 1), self.headers, self.raw_headers,
            False, 'deflate')

        srv = self._make_srv(wsgi_app, readpayload=True)

        self.loop.run_until_complete(
            srv.handle_request(self.message, self.payload))

        content = b''.join(
            [c[1][0] for c in self.writer.write.mock_calls])
        self.assertTrue(content.startswith(b'HTTP/1.1 200 OK'))
        self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n'))
        self.assertTrue(srv._keepalive)

    def test_handle_request_readpayload(self):

        def wsgi_app(env, start):
            start('200 OK', [('Content-Type', 'text/plain')])
            return [env['wsgi.input'].read()]

        srv = self._make_srv(wsgi_app, readpayload=True)

        self.loop.run_until_complete(
            srv.handle_request(self.message, self.payload))

        content = b''.join(
            [c[1][0] for c in self.writer.write.mock_calls])
        self.assertTrue(content.startswith(b'HTTP/1.0 200 OK'))
        self.assertTrue(content.endswith(b'data'))

    def test_dont_unquote_environ_path_info(self):
        path = '/path/some%20text'
        self.message = protocol.RawRequestMessage(
            'GET', path, (1, 0), self.headers, self.raw_headers,
            True, 'deflate')
        environ = self._make_one()
        self.assertEqual(environ['PATH_INFO'], path)

    def test_authorization(self):
        # This header should be removed according to CGI/1.1 and WSGI but
        # in our case basic auth is not handled by server, so should
        # not be removed
        self.headers.extend({'AUTHORIZATION': 'spam'})
        self.message = protocol.RawRequestMessage(
            'GET', '/', (1, 1), self.headers, self.raw_headers,
            True, 'deflate')
        environ = self._make_one()
        self.assertEqual('spam', environ['HTTP_AUTHORIZATION'])

    def test_http_1_0_no_host(self):
        # Without a Host header the server address comes from the third
        # get_extra_info() result configured in setUp.
        headers = multidict.MultiDict({})
        self.message = protocol.RawRequestMessage(
            'GET', '/', (1, 0), headers, [], True, 'deflate')
        environ = self._make_one()
        self.assertEqual(environ['SERVER_NAME'], '2.3.4.5')
        self.assertEqual(environ['SERVER_PORT'], '80')

    def test_family_inet6(self):
        # IPv6 peer addresses are 4-tuples; only host and port are used.
        self.transport.get_extra_info.side_effect = [
            mock.Mock(family=socket.AF_INET6),
            ("::", 1122, 0, 0),
            ('2.3.4.5', 80)]
        self.message = protocol.RawRequestMessage(
            'GET', '/', (1, 0), self.headers, self.raw_headers,
            True, 'deflate')
        environ = self._make_one()
        self.assertEqual(environ['SERVER_NAME'], 'python.org')
        self.assertEqual(environ['SERVER_PORT'], '80')
        self.assertEqual(environ['REMOTE_ADDR'], '::')
        self.assertEqual(environ['REMOTE_PORT'], '1122')

    def test_family_unix(self):
        if not hasattr(socket, "AF_UNIX"):
            self.skipTest("No UNIX address family. (Windows?)")
        # Only the socket is queried; the address variables are taken from
        # the request headers instead of the transport.
        self.transport.get_extra_info.side_effect = [
            mock.Mock(family=socket.AF_UNIX)]
        headers = multidict.MultiDict({
            'SERVER_NAME': '1.2.3.4', 'SERVER_PORT': '5678',
            'REMOTE_ADDR': '4.3.2.1', 'REMOTE_PORT': '8765'})
        self.message = protocol.RawRequestMessage(
            'GET', '/', (1, 0), headers, self.raw_headers, True, 'deflate')
        environ = self._make_one()
        self.assertEqual(environ['SERVER_NAME'], '1.2.3.4')
        self.assertEqual(environ['SERVER_PORT'], '5678')
        self.assertEqual(environ['REMOTE_ADDR'], '4.3.2.1')
        self.assertEqual(environ['REMOTE_PORT'], '8765')