python/6871/aiohttp/tests/test_websocket_handshake.py

test_websocket_handshake.py
"""Tests for http/websocket.py"""

import base64
import hashlib
import os
from unittest import mock

import multidict
import pytest

from aiohttp import errors, protocol
from aiohttp._ws_impl import WS_KEY, do_handshake


@pytest.fixture()
def transport():
    """Provide a stand-in transport object for ``do_handshake`` calls.

    The tests in this module never inspect it, so a bare mock is enough.
    """
    fake_transport = mock.Mock()
    return fake_transport


@pytest.fixture()
def message():
    """Build a bare ``GET /path`` request message with no headers yet.

    The header container is a mutable ``MultiDict`` so individual tests can
    ``extend`` it with whatever handshake headers their scenario needs.
    """
    empty_headers = multidict.MultiDict()
    request = protocol.RawRequestMessage(
        'GET', '/path', (1, 0), empty_headers, [], True, None)
    return request


def gen_ws_headers(protocols=''):
    """Return ``(headers, key)`` describing a well-formed client handshake.

    ``headers`` is a list of ``(name, value)`` tuples carrying Upgrade,
    Connection, ``Sec-Websocket-Version: 13`` and a freshly generated
    ``Sec-Websocket-Key`` (16 random bytes, base64-encoded). When *protocols*
    is non-empty it is appended verbatim as ``Sec-Websocket-Protocol``.
    ``key`` is the generated key string, returned so callers can compute
    the expected accept value.
    """
    sec_key = base64.b64encode(os.urandom(16)).decode()
    header_list = [
        ('Upgrade', 'websocket'),
        ('Connection', 'upgrade'),
        ('Sec-Websocket-Version', '13'),
        ('Sec-Websocket-Key', sec_key),
    ]
    if protocols:
        header_list.append(('Sec-Websocket-Protocol', protocols))
    return header_list, sec_key


def test_not_get(message, transport):
    """A non-GET method must be refused with an HTTP processing error."""
    method = 'POST'
    with pytest.raises(errors.HttpProcessingError):
        do_handshake(method, message.headers, transport)


def test_no_upgrade(message, transport):
    """A GET with an empty header set (no Upgrade header) is a bad request."""
    request_headers = message.headers
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, request_headers, transport)


def test_no_connection(message, transport):
    """``Upgrade: websocket`` without an upgrading Connection is rejected."""
    keep_alive_headers = [
        ('Upgrade', 'websocket'),
        ('Connection', 'keep-alive'),
    ]
    message.headers.extend(keep_alive_headers)
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, message.headers, transport)


def test_protocol_version(message, transport):
    """A missing or unsupported ``Sec-Websocket-Version`` is a bad request.

    Each scenario gets its own header set. Appending the second scenario's
    headers onto the first would duplicate Upgrade/Connection and let one
    case's headers leak into the next.
    """
    base = [('Upgrade', 'websocket'),
            ('Connection', 'upgrade')]

    # No version header at all.
    message.headers.extend(base)
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, message.headers, transport)

    # A version the server does not accept.
    headers = multidict.MultiDict(base + [('Sec-Websocket-Version', '1')])
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, headers, transport)


def test_protocol_key(message, transport):
    """A missing, malformed or wrong-length ``Sec-Websocket-Key`` is rejected.

    Every scenario builds a fresh header set. A single-value lookup on a
    ``MultiDict`` (``get`` / ``[]``) returns the *first* value for a key.
    Appending another ``Sec-Websocket-Key`` to headers that already hold
    one would therefore silently re-test the earlier key instead of the
    new one.
    """
    base = [('Upgrade', 'websocket'),
            ('Connection', 'upgrade'),
            ('Sec-Websocket-Version', '13')]

    def check_rejected(extra):
        # Isolated header set per scenario; see docstring for why.
        headers = multidict.MultiDict(base + extra)
        with pytest.raises(errors.HttpBadRequest):
            do_handshake(message.method, headers, transport)

    # No key at all.
    check_rejected([])
    # Key that is not valid base64.
    check_rejected([('Sec-Websocket-Key', '123')])
    # Valid base64, but decodes to 2 bytes rather than 16.
    short_key = base64.b64encode(os.urandom(2)).decode()
    check_rejected([('Sec-Websocket-Key', short_key)])


def test_handshake(message, transport):
    """A well-formed request is accepted with status 101.

    Also checks that ``Sec-Websocket-Accept`` equals
    ``base64(sha1(key + WS_KEY))``.
    """
    hdrs, sec_key = gen_ws_headers()

    message.headers.extend(hdrs)
    # ``negotiated`` avoids shadowing the imported ``aiohttp.protocol``.
    status, resp_headers, _, _, negotiated = do_handshake(
        message.method, message.headers, transport)
    assert status == 101
    # No Sec-Websocket-Protocol was sent, so no protocol is selected.
    assert negotiated is None

    expected_accept = base64.b64encode(
        hashlib.sha1(sec_key.encode() + WS_KEY).digest())
    resp_headers = dict(resp_headers)
    assert resp_headers['Sec-Websocket-Accept'] == expected_accept.decode()


def test_handshake_protocol(message, transport):
    """A single matching protocol is selected and echoed in the response."""
    proto = 'chat'

    message.headers.extend(gen_ws_headers(proto)[0])
    # ``negotiated`` avoids shadowing the imported ``aiohttp.protocol``.
    _, resp_headers, _, _, negotiated = do_handshake(
        message.method, message.headers, transport,
        protocols=[proto])

    assert negotiated == proto

    # The selected protocol must also be sent back to the client.
    resp_headers = dict(resp_headers)
    assert resp_headers['Sec-Websocket-Protocol'] == proto


def test_handshake_protocol_agreement(message, transport):
    """The client's first requested protocol that the server supports wins.

    The client requests ``'worse_proto,chat'`` (in that order) and the
    server supports ``'best'``, ``'chat'`` and ``'worse_proto'``. The
    expected result is ``'worse_proto'``. Both candidates are supported,
    so selection follows the client's order, not the server's list order.
    """
    expected = 'worse_proto'
    server_protos = ['best', 'chat', 'worse_proto']
    client_protos = 'worse_proto,chat'

    # gen_ws_headers puts its argument into the *request* protocol header.
    message.headers.extend(gen_ws_headers(client_protos)[0])
    _, resp_headers, _, _, negotiated = do_handshake(
        message.method, message.headers, transport,
        protocols=server_protos)

    assert negotiated == expected
    # The selected protocol is echoed back, as in test_handshake_protocol.
    resp_headers = dict(resp_headers)
    assert resp_headers['Sec-Websocket-Protocol'] == expected


def test_handshake_protocol_unsupported(log, message, transport):
    """With no protocol overlap, the handshake logs a warning and picks none.

    ``log`` is a fixture not defined in this module (presumably conftest).
    It is used as a context manager that collects records from the named
    logger into ``ctx.records``. TODO confirm.
    """
    proto = 'chat'
    # The client asks only for 'test', which the server does not support.
    message.headers.extend(gen_ws_headers('test')[0])

    with log('aiohttp.websocket') as ctx:
        _, _, _, _, negotiated = do_handshake(
            message.method, message.headers, transport,
            protocols=[proto])

        assert negotiated is None
    assert (ctx.records[-1].msg ==
            'Client protocols %r don’t overlap server-known ones %r')