# python/326/sockjs/tests/test_session.py

# test_session.py
import asyncio
from unittest import mock
from datetime import datetime, timedelta


try:
    from asyncio import ensure_future
except ImportError:
    # Interpreters older than 3.4.4 only provide ``asyncio.async``. ``async``
    # is a reserved keyword since Python 3.7, so spelling it as an attribute
    # would make this whole module a SyntaxError there; look it up by name.
    ensure_future = getattr(asyncio, 'async')

from sockjs import Session, SessionIsClosed, protocol, SessionIsAcquired
from test_base import BaseSockjsTestCase


class SessionTestCase(BaseSockjsTestCase):
    """Unit tests for ``sockjs.Session``: state, queueing and handler calls."""

    @mock.patch('sockjs.session.datetime')
    def test_ctor(self, dt):
        """A new session is NEW, unexpired and expires after its timeout."""
        # Freeze "now" inside sockjs.session so expiry times are exact.
        now = dt.now.return_value = datetime.now()

        handler = self.make_handler([])
        session = Session('id', handler, loop=self.loop)

        self.assertEqual(session.id, 'id')
        self.assertEqual(session.expired, False)
        self.assertEqual(session.expires, now + timedelta(seconds=10))

        self.assertEqual(session._hits, 0)
        self.assertEqual(session._heartbeats, 0)
        self.assertEqual(session.state, protocol.STATE_NEW)

        # Bind to the test loop, like the first construction above.
        session = Session('id', handler, loop=self.loop,
                          timeout=timedelta(seconds=15))

        self.assertEqual(session.id, 'id')
        self.assertEqual(session.expired, False)
        self.assertEqual(session.expires, now + timedelta(seconds=15))

    def test_str(self):
        """str() reports id, state, acquired flag, queue size and counters."""
        session = self.make_session('test')
        session.state = protocol.STATE_OPEN

        self.assertEqual(str(session), "id='test' connected")

        session._hits = 10
        session._heartbeats = 50
        session.state = protocol.STATE_CLOSING
        self.assertEqual(str(session),
                         "id='test' disconnected hits=10 heartbeats=50")

        session._feed(protocol.FRAME_MESSAGE, 'msg')
        self.assertEqual(
            str(session),
            "id='test' disconnected queue[1] hits=10 heartbeats=50")

        session.state = protocol.STATE_CLOSED
        self.assertEqual(
            str(session),
            "id='test' closed queue[1] hits=10 heartbeats=50")

        session.state = protocol.STATE_OPEN
        session.acquired = True
        self.assertEqual(
            str(session),
            "id='test' connected acquired queue[1] hits=10 heartbeats=50")

    @mock.patch('sockjs.session.datetime')
    def test_tick(self, dt):
        """_tick() moves the expiry to now + session.timeout."""
        now = dt.now.return_value = datetime.now()
        session = self.make_session('test')

        now = dt.now.return_value = now + timedelta(hours=1)
        session._tick()
        self.assertEqual(session.expires, now + session.timeout)

    @mock.patch('sockjs.session.datetime')
    def test_tick_different_timeout(self, dt):
        """_tick() honours a per-session timeout given at construction."""
        now = dt.now.return_value = datetime.now()
        session = self.make_session('test', timeout=timedelta(seconds=20))

        now = dt.now.return_value = now + timedelta(hours=1)
        session._tick()
        self.assertEqual(session.expires, now + timedelta(seconds=20))

    @mock.patch('sockjs.session.datetime')
    def test_tick_custom(self, dt):
        """An explicit _tick() argument overrides the session timeout."""
        now = dt.now.return_value = datetime.now()
        session = self.make_session('test', timeout=timedelta(seconds=20))

        now = dt.now.return_value = now + timedelta(hours=1)
        session._tick(timedelta(seconds=30))
        self.assertEqual(session.expires, now + timedelta(seconds=30))

    def test_heartbeat(self):
        """Each _heartbeat() increments the counter and calls _tick()."""
        session = self.make_session('test')
        session._tick = mock.Mock()
        self.assertEqual(session._heartbeats, 0)

        session._heartbeat()
        self.assertEqual(session._heartbeats, 1)
        session._heartbeat()
        self.assertEqual(session._heartbeats, 2)
        self.assertEqual(session._tick.call_count, 2)

    def test_heartbeat_transport(self):
        """With _heartbeat_transport set, _heartbeat() queues a heartbeat."""
        session = self.make_session('test')
        session._heartbeat_transport = True
        session._heartbeat()
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_HEARTBEAT, protocol.FRAME_HEARTBEAT)])

    def test_expire(self):
        """expire() flips the expired flag."""
        session = self.make_session('test')
        self.assertFalse(session.expired)

        session.expire()
        self.assertTrue(session.expired)

    def test_send(self):
        """send() is ignored until the session is open, then queues it."""
        session = self.make_session('test')
        session.send('message')
        self.assertEqual(list(session._queue), [])

        session._tick = mock.Mock()
        session.state = protocol.STATE_OPEN
        session.send('message')

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_MESSAGE, ['message'])])
        self.assertTrue(session._tick.called)

    def test_send_non_str(self):
        """send() rejects a bytes payload with AssertionError."""
        session = self.make_session('test')
        self.assertRaises(AssertionError, session.send, b'str')

    def test_send_frame(self):
        """send_frame() queues a raw blob only once the session is open."""
        session = self.make_session('test')
        session.send_frame('a["message"]')
        self.assertEqual(list(session._queue), [])

        session._tick = mock.Mock()
        session.state = protocol.STATE_OPEN
        session.send_frame('a["message"]')

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_MESSAGE_BLOB, 'a["message"]')])
        self.assertTrue(session._tick.called)

    def test_feed(self):
        """_feed() appends frames; message payloads are wrapped in a list."""
        session = self.make_session('test')
        session._feed(protocol.FRAME_OPEN, protocol.FRAME_OPEN)
        session._feed(protocol.FRAME_MESSAGE, 'msg')
        session._feed(protocol.FRAME_CLOSE, (3001, 'reason'))

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_OPEN, protocol.FRAME_OPEN),
             (protocol.FRAME_MESSAGE, ['msg']),
             (protocol.FRAME_CLOSE, (3001, 'reason'))])

    def test_feed_msg_packing(self):
        """Adjacent messages share one frame until another frame intervenes."""
        session = self.make_session('test')
        session._feed(protocol.FRAME_MESSAGE, 'msg1')
        session._feed(protocol.FRAME_MESSAGE, 'msg2')
        session._feed(protocol.FRAME_CLOSE, (3001, 'reason'))
        session._feed(protocol.FRAME_MESSAGE, 'msg3')

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_MESSAGE, ['msg1', 'msg2']),
             (protocol.FRAME_CLOSE, (3001, 'reason')),
             (protocol.FRAME_MESSAGE, ['msg3'])])

    def test_feed_with_waiter(self):
        """_feed() resolves and clears a pending waiter."""
        session = self.make_session('test')
        session._waiter = waiter = asyncio.Future(loop=self.loop)
        session._feed(protocol.FRAME_MESSAGE, 'msg')

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_MESSAGE, ['msg'])])
        self.assertIsNone(session._waiter)
        self.assertTrue(waiter.done())

    def test_wait(self):
        """_wait() blocks until a message is fed, then returns it encoded."""
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN

        @asyncio.coroutine
        def send():
            # Feed only after _wait() has started waiting.
            yield from asyncio.sleep(0.001, loop=self.loop)
            s._feed(protocol.FRAME_MESSAGE, 'msg1')

        ensure_future(send(), loop=self.loop)
        frame, payload = self.loop.run_until_complete(s._wait())
        self.assertEqual(frame, protocol.FRAME_MESSAGE)
        self.assertEqual(payload, 'a["msg1"]')

    def test_wait_closed(self):
        """_wait() on a closed session raises SessionIsClosed."""
        s = self.make_session('test')
        s.state = protocol.STATE_CLOSED
        self.assertRaises(SessionIsClosed,
                          self.loop.run_until_complete, s._wait())

    def test_wait_message(self):
        """_wait() returns an already queued message frame, encoded."""
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN
        s._feed(protocol.FRAME_MESSAGE, 'msg1')
        frame, payload = self.loop.run_until_complete(s._wait())
        self.assertEqual(frame, protocol.FRAME_MESSAGE)
        self.assertEqual(payload, 'a["msg1"]')

    def test_wait_close(self):
        """_wait() returns an already queued close frame, encoded."""
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN
        s._feed(protocol.FRAME_CLOSE, (3000, 'Go away!'))
        frame, payload = self.loop.run_until_complete(s._wait())
        self.assertEqual(frame, protocol.FRAME_CLOSE)
        self.assertEqual(payload, 'c[3000,"Go away!"]')

    def test_wait_message_unpack(self):
        """_wait(pack=False) returns the raw message list."""
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN
        s._feed(protocol.FRAME_MESSAGE, 'msg1')
        frame, payload = self.loop.run_until_complete(s._wait(pack=False))
        self.assertEqual(frame, protocol.FRAME_MESSAGE)
        self.assertEqual(payload, ['msg1'])

    def test_wait_close_unpack(self):
        """_wait(pack=False) returns the raw close tuple."""
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN
        s._feed(protocol.FRAME_CLOSE, (3000, 'Go away!'))
        frame, payload = self.loop.run_until_complete(s._wait(pack=False))
        self.assertEqual(frame, protocol.FRAME_CLOSE)
        self.assertEqual(payload, (3000, 'Go away!'))

    def test_close(self):
        """close() on an open session goes CLOSING and queues a close frame."""
        session = self.make_session('test')
        session.state = protocol.STATE_OPEN
        session.close()
        self.assertEqual(session.state, protocol.STATE_CLOSING)
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_CLOSE, (3000, 'Go away!'))])

    def test_close_idempotent(self):
        """close() on an already closed session changes nothing."""
        session = self.make_session('test')
        session.state = protocol.STATE_CLOSED
        session.close()
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(list(session._queue), [])

    def test_acquire_new_session(self):
        """Acquiring a NEW session opens it and notifies the handler."""
        manager = object()
        messages = []

        session = self.make_session(result=messages)
        self.assertEqual(session.state, protocol.STATE_NEW)

        self.loop.run_until_complete(session._acquire(manager))
        self.assertEqual(session.state, protocol.STATE_OPEN)
        self.assertIs(session.manager, manager)
        self.assertTrue(session._heartbeat_transport)
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_OPEN, protocol.FRAME_OPEN)])
        self.assertEqual(messages, [(protocol.OpenMessage, session)])

    def test_acquire_exception_in_handler(self):
        """A failing open handler interrupts the session with an error close."""

        @asyncio.coroutine
        def handler(msg, s):
            raise ValueError

        session = self.make_session(handler=handler)
        self.assertEqual(session.state, protocol.STATE_NEW)

        self.loop.run_until_complete(session._acquire(object()))
        self.assertEqual(session.state, protocol.STATE_CLOSING)
        self.assertTrue(session._heartbeat_transport)
        self.assertTrue(session.interrupted)
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_OPEN, protocol.FRAME_OPEN),
             (protocol.FRAME_CLOSE, (3000, 'Internal error'))])

    def test_remote_close(self):
        """_remote_close() goes CLOSING and sends MSG_CLOSE to the handler."""
        messages = []
        session = self.make_session(result=messages)

        self.loop.run_until_complete(session._remote_close())
        self.assertFalse(session.interrupted)
        self.assertEqual(session.state, protocol.STATE_CLOSING)
        self.assertEqual(
            messages,
            [(protocol.SockjsMessage(
                tp=protocol.MSG_CLOSE, data=None), session)])

    def test_remote_close_idempotent(self):
        """_remote_close() on a closed session does nothing."""
        messages = []
        session = self.make_session(result=messages)
        session.state = protocol.STATE_CLOSED

        self.loop.run_until_complete(session._remote_close())
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(messages, [])

    def test_remote_close_with_exc(self):
        """_remote_close(exc=...) marks the session interrupted."""
        messages = []
        session = self.make_session(result=messages)

        exc = ValueError()
        self.loop.run_until_complete(session._remote_close(exc=exc))
        self.assertTrue(session.interrupted)
        self.assertEqual(session.state, protocol.STATE_CLOSING)
        self.assertEqual(
            messages,
            [(protocol.SockjsMessage(tp=protocol.MSG_CLOSE, data=exc),
              session)])

    def test_remote_close_exc_in_handler(self):
        """A handler error in _remote_close() does not interrupt the session."""
        handler = self.make_handler([], exc=True)
        session = self.make_session(handler=handler)

        self.loop.run_until_complete(session._remote_close())
        self.assertFalse(session.interrupted)
        self.assertEqual(session.state, protocol.STATE_CLOSING)

    def test_remote_closed(self):
        """_remote_closed() expires and closes the session, then notifies."""
        messages = []
        session = self.make_session(result=messages)

        self.loop.run_until_complete(session._remote_closed())
        self.assertTrue(session.expired)
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(
            messages, [(protocol.ClosedMessage, session)])

    def test_remote_closed_idempotent(self):
        """_remote_closed() on a closed session does nothing."""
        messages = []
        session = self.make_session(result=messages)
        session.state = protocol.STATE_CLOSED

        self.loop.run_until_complete(session._remote_closed())
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(messages, [])

    def test_remote_closed_with_waiter(self):
        """_remote_closed() also resolves and clears a pending waiter."""
        messages = []
        session = self.make_session(result=messages)
        session._waiter = waiter = asyncio.Future(loop=self.loop)

        self.loop.run_until_complete(session._remote_closed())
        self.assertTrue(waiter.done())
        self.assertTrue(session.expired)
        self.assertIsNone(session._waiter)
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(
            messages, [(protocol.ClosedMessage, session)])

    def test_remote_closed_exc_in_handler(self):
        """A handler error in _remote_closed() still closes the session."""
        handler = self.make_handler([], exc=True)
        session = self.make_session(handler=handler)

        self.loop.run_until_complete(session._remote_closed())
        self.assertTrue(session.expired)
        self.assertEqual(session.state, protocol.STATE_CLOSED)

    def test_remote_message(self):
        """_remote_message() forwards one MSG_MESSAGE to the handler."""
        messages = []
        session = self.make_session(result=messages)

        self.loop.run_until_complete(session._remote_message('msg'))
        self.assertEqual(
            messages,
            [(protocol.SockjsMessage(tp=protocol.MSG_MESSAGE, data='msg'),
              session)])

    def test_remote_message_exc(self):
        """A handler error in _remote_message() does not propagate."""
        messages = []
        handler = self.make_handler(messages, exc=True)
        session = self.make_session(handler=handler)

        self.loop.run_until_complete(session._remote_message('msg'))
        self.assertEqual(messages, [])

    def test_remote_messages(self):
        """_remote_messages() forwards each message to the handler in order."""
        messages = []
        session = self.make_session(result=messages)

        self.loop.run_until_complete(
            session._remote_messages(('msg1', 'msg2')))
        self.assertEqual(
            messages,
            [(protocol.SockjsMessage(tp=protocol.MSG_MESSAGE, data='msg1'),
              session),
             (protocol.SockjsMessage(tp=protocol.MSG_MESSAGE, data='msg2'),
              session)])

    def test_remote_messages_exc(self):
        """A handler error in _remote_messages() does not propagate."""
        messages = []
        handler = self.make_handler(messages, exc=True)
        session = self.make_session(handler=handler)

        self.loop.run_until_complete(
            session._remote_messages(('msg1', 'msg2')))
        self.assertEqual(messages, [])


class SessionManagerTestCase(BaseSockjsTestCase):
    """Unit tests for the session manager: registry, locking, heartbeat, GC."""

    def test_fresh(self):
        """A freshly added session is visible in the manager."""
        s, sm = self.make_manager()
        sm._add(s)
        self.assertIn('test', sm)

    def test_add(self):
        """_add() registers the session and binds it to the manager."""
        s, sm = self.make_manager()

        sm._add(s)
        self.assertIn('test', sm)
        self.assertIs(sm['test'], s)
        self.assertIs(s.manager, sm)

    def test_add_expired(self):
        """_add() refuses an expired session with ValueError."""
        s, sm = self.make_manager()
        s.expire()

        self.assertRaises(ValueError, sm._add, s)

    def test_get(self):
        """get() raises KeyError for unknown ids and returns added sessions."""
        s, sm = self.make_manager()
        self.assertRaises(KeyError, sm.get, 'test')

        sm._add(s)
        self.assertIs(sm.get('test'), s)

    def test_get_unknown_with_default(self):
        """get() returns the supplied default for an unknown id."""
        _, sm = self.make_manager()
        default = object()

        item = sm.get('id', default=default)
        self.assertIs(item, default)

    def test_get_with_create(self):
        """get(id, True) creates and registers a new Session."""
        _, sm = self.make_manager()

        s = sm.get('test', True)
        self.assertIn(s.id, sm)
        self.assertIsInstance(s, Session)

    def test_acquire(self):
        """acquire() marks the session acquired and calls its _acquire()."""
        s1, sm = self.make_manager()
        sm._add(s1)
        # Replace the session's own _acquire with an already-resolved future.
        s1._acquire = mock.Mock()
        s1._acquire.return_value = asyncio.Future(loop=self.loop)
        s1._acquire.return_value.set_result(1)

        s2 = self.loop.run_until_complete(sm.acquire(s1))

        self.assertIs(s1, s2)
        self.assertIn(s1.id, sm.acquired)
        self.assertTrue(sm.acquired[s1.id])
        self.assertTrue(sm.is_acquired(s1))
        self.assertTrue(s1._acquire.called)

    def test_acquire_unknown(self):
        """Acquiring a session that was never added raises KeyError."""
        s, sm = self.make_manager()
        self.assertRaises(
            KeyError, self.loop.run_until_complete, sm.acquire(s))

    def test_acquire_locked(self):
        """Acquiring an already acquired session raises SessionIsAcquired."""
        s, sm = self.make_manager()
        sm._add(s)
        self.loop.run_until_complete(sm.acquire(s))

        self.assertRaises(
            SessionIsAcquired,
            self.loop.run_until_complete, sm.acquire(s))

    def test_release(self):
        """release() un-marks the session and calls its _release()."""
        _, sm = self.make_manager()
        s = sm.get('test', True)
        s._release = mock.Mock()

        self.loop.run_until_complete(sm.acquire(s))
        self.loop.run_until_complete(sm.release(s))

        self.assertNotIn('test', sm.acquired)
        self.assertFalse(sm.is_acquired(s))
        self.assertTrue(s._release.called)

    def test_active_sessions(self):
        """active_sessions() skips expired sessions."""
        _, sm = self.make_manager()

        s1 = sm.get('test1', True)
        s2 = sm.get('test2', True)
        s2.expire()

        active = list(sm.active_sessions())
        self.assertEqual(len(active), 1)
        self.assertIn(s1, active)

    def test_broadcast(self):
        """broadcast() queues the same encoded blob on every open session."""
        _, sm = self.make_manager()

        s1 = sm.get('test1', True)
        s1.state = protocol.STATE_OPEN
        s2 = sm.get('test2', True)
        s2.state = protocol.STATE_OPEN
        sm.broadcast('msg')

        self.assertEqual(
            list(s1._queue),
            [(protocol.FRAME_MESSAGE_BLOB, 'a["msg"]')])
        self.assertEqual(
            list(s2._queue),
            [(protocol.FRAME_MESSAGE_BLOB, 'a["msg"]')])

    def test_clear(self):
        """clear() empties the manager and closes and expires every session."""
        _, sm = self.make_manager()

        s1 = sm.get('s1', True)
        s1.state = protocol.STATE_OPEN
        s2 = sm.get('s2', True)
        s2.state = protocol.STATE_OPEN

        self.loop.run_until_complete(sm.clear())

        self.assertFalse(bool(sm))
        self.assertTrue(s1.expired)
        self.assertTrue(s2.expired)
        self.assertEqual(s1.state, protocol.STATE_CLOSED)
        self.assertEqual(s2.state, protocol.STATE_CLOSED)

    def test_heartbeat(self):
        """start()/stop() manage the heartbeat handle and cancel its task."""
        _, sm = self.make_manager()
        self.assertFalse(sm.started)
        self.assertIsNone(sm._hb_task)

        sm.start()
        self.assertTrue(sm.started)
        self.assertIsNotNone(sm._hb_handle)

        sm._heartbeat()
        self.assertIsNotNone(sm._hb_task)

        hb_task = sm._hb_task

        sm.stop()
        self.assertFalse(sm.started)
        self.assertIsNone(sm._hb_handle)
        self.assertIsNone(sm._hb_task)
        # The loop has not run since stop(), so cancellation is only pending.
        self.assertTrue(hb_task._must_cancel)

    def test_heartbeat_task(self):
        """_heartbeat_task() leaves the manager started and clears _hb_task."""
        _, sm = self.make_manager()
        sm._hb_task = mock.Mock()

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertTrue(sm.started)
        self.assertIsNone(sm._hb_task)

    def test_gc_expire(self):
        """A released session past its expiry is removed and closed."""
        s, sm = self.make_manager()

        sm._add(s)
        self.loop.run_until_complete(sm.acquire(s))
        self.loop.run_until_complete(sm.release(s))

        s.expires = datetime.now() - timedelta(seconds=30)

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertNotIn(s.id, sm)
        self.assertTrue(s.expired)
        self.assertEqual(s.state, protocol.STATE_CLOSED)

    def test_gc_expire_acquired(self):
        """An acquired session is never expired by the heartbeat task.

        It is collected only after being released, which in practice happens
        when sending a heartbeat message fails.
        """
        s, sm = self.make_manager()

        sm._add(s)
        self.loop.run_until_complete(sm.acquire(s))

        s.expires = datetime.now() - timedelta(seconds=30)

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertIn(s.id, sm)
        self.assertIn(s.id, sm.acquired)
        self.assertFalse(s.expired)
        self.assertEqual(s.state, protocol.STATE_OPEN)

        # Simulating the releasing of the session due to an error
        self.loop.run_until_complete(sm.release(s))
        s.expires = datetime.now() - timedelta(seconds=30)
        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertNotIn(s.id, sm)
        self.assertNotIn(s.id, sm.acquired)
        self.assertTrue(s.expired)
        self.assertEqual(s.state, protocol.STATE_CLOSED)

    def test_gc_one_expire(self):
        """Only the session whose expiry has passed is collected."""
        _, sm = self.make_manager()
        s1 = self.make_session('id1')
        s2 = self.make_session('id2')

        sm._add(s1)
        sm._add(s2)
        self.loop.run_until_complete(sm.acquire(s1))
        self.loop.run_until_complete(sm.acquire(s2))
        self.loop.run_until_complete(sm.release(s1))
        self.loop.run_until_complete(sm.release(s2))

        s1.expires = datetime.now() - timedelta(seconds=30)

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertNotIn(s1.id, sm)
        self.assertIn(s2.id, sm)