# tests/test_session.py
import asyncio
from unittest import mock
from datetime import datetime, timedelta
try:
    from asyncio import ensure_future
except ImportError:
    # Older asyncio only provides the legacy name.  ``async`` is a reserved
    # keyword since Python 3.7, so writing ``asyncio.async`` literally is a
    # SyntaxError at parse time even when this branch never executes; look
    # the attribute up dynamically instead.
    ensure_future = getattr(asyncio, 'async')
from sockjs import Session, SessionIsClosed, protocol, SessionIsAcquired
from test_base import BaseSockjsTestCase
class SessionTestCase(BaseSockjsTestCase):
    """Unit tests for ``sockjs.Session``.

    Covers construction, expiry bookkeeping, the outgoing frame queue,
    waiting for frames, and the remote-event callbacks that forward
    messages to the session handler.
    """

    @mock.patch('sockjs.session.datetime')
    def test_ctor(self, dt):
        now = dt.now.return_value = datetime.now()
        handler = self.make_handler([])
        session = Session('id', handler, loop=self.loop)

        self.assertEqual(session.id, 'id')
        self.assertEqual(session.expired, False)
        # With no explicit timeout the session expires 10 seconds from now.
        self.assertEqual(session.expires, now + timedelta(seconds=10))

        self.assertEqual(session._hits, 0)
        self.assertEqual(session._heartbeats, 0)
        self.assertEqual(session.state, protocol.STATE_NEW)

        session = Session('id', handler, loop=self.loop,
                          timeout=timedelta(seconds=15))

        self.assertEqual(session.id, 'id')
        self.assertEqual(session.expired, False)
        self.assertEqual(session.expires, now + timedelta(seconds=15))

    def test_str(self):
        session = self.make_session('test')
        session.state = protocol.STATE_OPEN

        self.assertEqual(str(session), "id='test' connected")

        session._hits = 10
        session._heartbeats = 50
        session.state = protocol.STATE_CLOSING
        self.assertEqual(str(session),
                         "id='test' disconnected hits=10 heartbeats=50")

        # Pending frames are reported as the queue length.
        session._feed(protocol.FRAME_MESSAGE, 'msg')
        self.assertEqual(
            str(session),
            "id='test' disconnected queue[1] hits=10 heartbeats=50")

        session.state = protocol.STATE_CLOSED
        self.assertEqual(
            str(session),
            "id='test' closed queue[1] hits=10 heartbeats=50")

        session.state = protocol.STATE_OPEN
        session.acquired = True
        self.assertEqual(
            str(session),
            "id='test' connected acquired queue[1] hits=10 heartbeats=50")

    @mock.patch('sockjs.session.datetime')
    def test_tick(self, dt):
        now = dt.now.return_value = datetime.now()
        session = self.make_session('test')

        # Move the mocked clock forward so the refreshed expiry is
        # distinguishable from the one set at construction time.
        now = dt.now.return_value = now + timedelta(hours=1)
        session._tick()
        self.assertEqual(session.expires, now + session.timeout)

    @mock.patch('sockjs.session.datetime')
    def test_tick_different_timeout(self, dt):
        now = dt.now.return_value = datetime.now()
        session = self.make_session('test', timeout=timedelta(seconds=20))

        now = dt.now.return_value = now + timedelta(hours=1)
        session._tick()
        self.assertEqual(session.expires, now + timedelta(seconds=20))

    @mock.patch('sockjs.session.datetime')
    def test_tick_custom(self, dt):
        now = dt.now.return_value = datetime.now()
        session = self.make_session('test', timeout=timedelta(seconds=20))

        # An explicit argument to _tick() overrides the session timeout.
        now = dt.now.return_value = now + timedelta(hours=1)
        session._tick(timedelta(seconds=30))
        self.assertEqual(session.expires, now + timedelta(seconds=30))

    def test_heartbeat(self):
        session = self.make_session('test')
        session._tick = mock.Mock()
        self.assertEqual(session._heartbeats, 0)

        session._heartbeat()
        self.assertEqual(session._heartbeats, 1)
        session._heartbeat()
        self.assertEqual(session._heartbeats, 2)
        # Every heartbeat also refreshes the expiry.
        self.assertEqual(session._tick.call_count, 2)

    def test_heartbeat_transport(self):
        session = self.make_session('test')
        session._heartbeat_transport = True
        session._heartbeat()
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_HEARTBEAT, protocol.FRAME_HEARTBEAT)])

    def test_expire(self):
        session = self.make_session('test')
        self.assertFalse(session.expired)

        session.expire()
        self.assertTrue(session.expired)

    def test_send(self):
        session = self.make_session('test')
        # A session that is not open yet silently drops outgoing messages.
        session.send('message')
        self.assertEqual(list(session._queue), [])

        session._tick = mock.Mock()
        session.state = protocol.STATE_OPEN
        session.send('message')

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_MESSAGE, ['message'])])
        self.assertTrue(session._tick.called)

    def test_send_non_str(self):
        session = self.make_session('test')
        self.assertRaises(AssertionError, session.send, b'str')

    def test_send_frame(self):
        session = self.make_session('test')
        session.send_frame('a["message"]')
        self.assertEqual(list(session._queue), [])

        session._tick = mock.Mock()
        session.state = protocol.STATE_OPEN
        session.send_frame('a["message"]')

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_MESSAGE_BLOB, 'a["message"]')])
        self.assertTrue(session._tick.called)

    def test_feed(self):
        session = self.make_session('test')
        session._feed(protocol.FRAME_OPEN, protocol.FRAME_OPEN)
        session._feed(protocol.FRAME_MESSAGE, 'msg')
        session._feed(protocol.FRAME_CLOSE, (3001, 'reason'))

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_OPEN, protocol.FRAME_OPEN),
             (protocol.FRAME_MESSAGE, ['msg']),
             (protocol.FRAME_CLOSE, (3001, 'reason'))])

    def test_feed_msg_packing(self):
        session = self.make_session('test')
        session._feed(protocol.FRAME_MESSAGE, 'msg1')
        session._feed(protocol.FRAME_MESSAGE, 'msg2')
        session._feed(protocol.FRAME_CLOSE, (3001, 'reason'))
        session._feed(protocol.FRAME_MESSAGE, 'msg3')

        # Consecutive MESSAGE frames are packed into one queue entry;
        # any other frame type starts a new run.
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_MESSAGE, ['msg1', 'msg2']),
             (protocol.FRAME_CLOSE, (3001, 'reason')),
             (protocol.FRAME_MESSAGE, ['msg3'])])

    def test_feed_with_waiter(self):
        session = self.make_session('test')
        session._waiter = waiter = asyncio.Future(loop=self.loop)
        session._feed(protocol.FRAME_MESSAGE, 'msg')

        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_MESSAGE, ['msg'])])
        # Feeding a frame resolves and clears the pending waiter.
        self.assertIsNone(session._waiter)
        self.assertTrue(waiter.done())

    def test_wait(self):
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN

        @asyncio.coroutine
        def send():
            # Deliver the frame only after _wait() has started waiting.
            yield from asyncio.sleep(0.001, loop=self.loop)
            s._feed(protocol.FRAME_MESSAGE, 'msg1')

        ensure_future(send(), loop=self.loop)
        frame, payload = self.loop.run_until_complete(s._wait())
        self.assertEqual(frame, protocol.FRAME_MESSAGE)
        self.assertEqual(payload, 'a["msg1"]')

    def test_wait_closed(self):
        s = self.make_session('test')
        s.state = protocol.STATE_CLOSED
        self.assertRaises(SessionIsClosed,
                          self.loop.run_until_complete, s._wait())

    def test_wait_message(self):
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN
        s._feed(protocol.FRAME_MESSAGE, 'msg1')
        frame, payload = self.loop.run_until_complete(s._wait())
        self.assertEqual(frame, protocol.FRAME_MESSAGE)
        self.assertEqual(payload, 'a["msg1"]')

    def test_wait_close(self):
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN
        s._feed(protocol.FRAME_CLOSE, (3000, 'Go away!'))
        frame, payload = self.loop.run_until_complete(s._wait())
        self.assertEqual(frame, protocol.FRAME_CLOSE)
        self.assertEqual(payload, 'c[3000,"Go away!"]')

    def test_wait_message_unpack(self):
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN
        s._feed(protocol.FRAME_MESSAGE, 'msg1')
        # pack=False returns the raw payload instead of the encoded frame.
        frame, payload = self.loop.run_until_complete(s._wait(pack=False))
        self.assertEqual(frame, protocol.FRAME_MESSAGE)
        self.assertEqual(payload, ['msg1'])

    def test_wait_close_unpack(self):
        s = self.make_session('test')
        s.state = protocol.STATE_OPEN
        s._feed(protocol.FRAME_CLOSE, (3000, 'Go away!'))
        frame, payload = self.loop.run_until_complete(s._wait(pack=False))
        self.assertEqual(frame, protocol.FRAME_CLOSE)
        self.assertEqual(payload, (3000, 'Go away!'))

    def test_close(self):
        session = self.make_session('test')
        session.state = protocol.STATE_OPEN
        session.close()
        self.assertEqual(session.state, protocol.STATE_CLOSING)
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_CLOSE, (3000, 'Go away!'))])

    def test_close_idempotent(self):
        session = self.make_session('test')
        session.state = protocol.STATE_CLOSED
        session.close()
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(list(session._queue), [])

    def test_acquire_new_session(self):
        manager = object()
        messages = []

        session = self.make_session(result=messages)
        self.assertEqual(session.state, protocol.STATE_NEW)

        self.loop.run_until_complete(session._acquire(manager))
        self.assertEqual(session.state, protocol.STATE_OPEN)
        self.assertIs(session.manager, manager)
        self.assertTrue(session._heartbeat_transport)
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_OPEN, protocol.FRAME_OPEN)])
        self.assertEqual(messages, [(protocol.OpenMessage, session)])

    def test_acquire_exception_in_handler(self):

        @asyncio.coroutine
        def handler(msg, s):
            raise ValueError

        session = self.make_session(handler=handler)
        self.assertEqual(session.state, protocol.STATE_NEW)

        self.loop.run_until_complete(session._acquire(object()))
        # A failing open handler closes the session with an internal error.
        self.assertEqual(session.state, protocol.STATE_CLOSING)
        self.assertTrue(session._heartbeat_transport)
        self.assertTrue(session.interrupted)
        self.assertEqual(
            list(session._queue),
            [(protocol.FRAME_OPEN, protocol.FRAME_OPEN),
             (protocol.FRAME_CLOSE, (3000, 'Internal error'))])

    def test_remote_close(self):
        messages = []
        session = self.make_session(result=messages)

        self.loop.run_until_complete(session._remote_close())
        self.assertFalse(session.interrupted)
        self.assertEqual(session.state, protocol.STATE_CLOSING)
        self.assertEqual(
            messages,
            [(protocol.SockjsMessage(
                tp=protocol.MSG_CLOSE, data=None), session)])

    def test_remote_close_idempotent(self):
        messages = []
        session = self.make_session(result=messages)
        session.state = protocol.STATE_CLOSED

        self.loop.run_until_complete(session._remote_close())
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(messages, [])

    def test_remote_close_with_exc(self):
        messages = []
        session = self.make_session(result=messages)

        exc = ValueError()
        self.loop.run_until_complete(session._remote_close(exc=exc))
        # Closing with an exception marks the session as interrupted and
        # passes the exception to the handler as the message data.
        self.assertTrue(session.interrupted)
        self.assertEqual(session.state, protocol.STATE_CLOSING)
        self.assertEqual(
            messages,
            [(protocol.SockjsMessage(tp=protocol.MSG_CLOSE, data=exc),
              session)])

    def test_remote_close_exc_in_handler(self):
        handler = self.make_handler([], exc=True)
        session = self.make_session(handler=handler)

        self.loop.run_until_complete(session._remote_close())
        self.assertFalse(session.interrupted)
        self.assertEqual(session.state, protocol.STATE_CLOSING)

    def test_remote_closed(self):
        messages = []
        session = self.make_session(result=messages)

        self.loop.run_until_complete(session._remote_closed())
        self.assertTrue(session.expired)
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(
            messages, [(protocol.ClosedMessage, session)])

    def test_remote_closed_idempotent(self):
        messages = []
        session = self.make_session(result=messages)
        session.state = protocol.STATE_CLOSED

        self.loop.run_until_complete(session._remote_closed())
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(messages, [])

    def test_remote_closed_with_waiter(self):
        messages = []
        session = self.make_session(result=messages)
        session._waiter = waiter = asyncio.Future(loop=self.loop)

        self.loop.run_until_complete(session._remote_closed())
        # The pending waiter is resolved and cleared on close.
        self.assertTrue(waiter.done())
        self.assertTrue(session.expired)
        self.assertIsNone(session._waiter)
        self.assertEqual(session.state, protocol.STATE_CLOSED)
        self.assertEqual(
            messages, [(protocol.ClosedMessage, session)])

    def test_remote_closed_exc_in_handler(self):
        handler = self.make_handler([], exc=True)
        session = self.make_session(handler=handler)

        self.loop.run_until_complete(session._remote_closed())
        self.assertTrue(session.expired)
        self.assertEqual(session.state, protocol.STATE_CLOSED)

    def test_remote_message(self):
        messages = []
        session = self.make_session(result=messages)

        self.loop.run_until_complete(session._remote_message('msg'))
        self.assertEqual(
            messages,
            [(protocol.SockjsMessage(tp=protocol.MSG_MESSAGE, data='msg'),
              session)])

    def test_remote_message_exc(self):
        messages = []
        handler = self.make_handler(messages, exc=True)
        session = self.make_session(handler=handler)

        self.loop.run_until_complete(session._remote_message('msg'))
        self.assertEqual(messages, [])

    def test_remote_messages(self):
        messages = []
        session = self.make_session(result=messages)

        self.loop.run_until_complete(
            session._remote_messages(('msg1', 'msg2')))
        # Each message in the batch is dispatched to the handler in order.
        self.assertEqual(
            messages,
            [(protocol.SockjsMessage(tp=protocol.MSG_MESSAGE, data='msg1'),
              session),
             (protocol.SockjsMessage(tp=protocol.MSG_MESSAGE, data='msg2'),
              session)])

    def test_remote_messages_exc(self):
        messages = []
        handler = self.make_handler(messages, exc=True)
        session = self.make_session(handler=handler)

        self.loop.run_until_complete(
            session._remote_messages(('msg1', 'msg2')))
        self.assertEqual(messages, [])
class SessionManagerTestCase(BaseSockjsTestCase):
    """Unit tests for the session manager built by ``make_manager()``.

    Covers registration and lookup, acquire/release locking, broadcast,
    clearing, and the heartbeat task that garbage-collects expired sessions.
    """

    def test_fresh(self):
        s, sm = self.make_manager()
        sm._add(s)
        self.assertIn('test', sm)

    def test_add(self):
        s, sm = self.make_manager()

        sm._add(s)
        self.assertIn('test', sm)
        self.assertIs(sm['test'], s)
        # Registering a session binds it back to its manager.
        self.assertIs(s.manager, sm)

    def test_add_expired(self):
        s, sm = self.make_manager()
        s.expire()

        self.assertRaises(ValueError, sm._add, s)

    def test_get(self):
        s, sm = self.make_manager()
        self.assertRaises(KeyError, sm.get, 'test')

        sm._add(s)
        self.assertIs(sm.get('test'), s)

    def test_get_unknown_with_default(self):
        _, sm = self.make_manager()
        default = object()

        item = sm.get('id', default=default)
        self.assertIs(item, default)

    def test_get_with_create(self):
        _, sm = self.make_manager()

        # The second positional argument asks get() to create the session.
        s = sm.get('test', True)
        self.assertIn(s.id, sm)
        self.assertIsInstance(s, Session)

    def test_acquire(self):
        s1, sm = self.make_manager()
        sm._add(s1)
        s1._acquire = mock.Mock()
        s1._acquire.return_value = asyncio.Future(loop=self.loop)
        s1._acquire.return_value.set_result(1)

        s2 = self.loop.run_until_complete(sm.acquire(s1))

        self.assertIs(s1, s2)
        self.assertIn(s1.id, sm.acquired)
        self.assertTrue(sm.acquired[s1.id])
        self.assertTrue(sm.is_acquired(s1))
        self.assertTrue(s1._acquire.called)

    def test_acquire_unknown(self):
        s, sm = self.make_manager()
        self.assertRaises(
            KeyError, self.loop.run_until_complete, sm.acquire(s))

    def test_acquire_locked(self):
        s, sm = self.make_manager()
        sm._add(s)
        self.loop.run_until_complete(sm.acquire(s))

        # A session may be held by only one acquirer at a time.
        self.assertRaises(
            SessionIsAcquired,
            self.loop.run_until_complete, sm.acquire(s))

    def test_release(self):
        _, sm = self.make_manager()
        s = sm.get('test', True)
        s._release = mock.Mock()

        self.loop.run_until_complete(sm.acquire(s))
        self.loop.run_until_complete(sm.release(s))

        self.assertNotIn('test', sm.acquired)
        self.assertFalse(sm.is_acquired(s))
        self.assertTrue(s._release.called)

    def test_active_sessions(self):
        _, sm = self.make_manager()

        s1 = sm.get('test1', True)
        s2 = sm.get('test2', True)
        s2.expire()

        # Expired sessions are excluded from active_sessions().
        active = list(sm.active_sessions())
        self.assertEqual(len(active), 1)
        self.assertIn(s1, active)

    def test_broadcast(self):
        _, sm = self.make_manager()

        s1 = sm.get('test1', True)
        s1.state = protocol.STATE_OPEN
        s2 = sm.get('test2', True)
        s2.state = protocol.STATE_OPEN
        sm.broadcast('msg')

        # Each open session receives the same pre-encoded frame.
        self.assertEqual(
            list(s1._queue),
            [(protocol.FRAME_MESSAGE_BLOB, 'a["msg"]')])
        self.assertEqual(
            list(s2._queue),
            [(protocol.FRAME_MESSAGE_BLOB, 'a["msg"]')])

    def test_clear(self):
        _, sm = self.make_manager()

        s1 = sm.get('s1', True)
        s1.state = protocol.STATE_OPEN
        s2 = sm.get('s2', True)
        s2.state = protocol.STATE_OPEN

        self.loop.run_until_complete(sm.clear())

        self.assertFalse(bool(sm))
        self.assertTrue(s1.expired)
        self.assertTrue(s2.expired)
        self.assertEqual(s1.state, protocol.STATE_CLOSED)
        self.assertEqual(s2.state, protocol.STATE_CLOSED)

    def test_heartbeat(self):
        _, sm = self.make_manager()
        self.assertFalse(sm.started)
        self.assertIsNone(sm._hb_task)

        sm.start()
        self.assertTrue(sm.started)
        self.assertIsNotNone(sm._hb_handle)

        sm._heartbeat()
        self.assertIsNotNone(sm._hb_task)

        hb_task = sm._hb_task

        # stop() clears the scheduled handle and cancels the running task.
        sm.stop()
        self.assertFalse(sm.started)
        self.assertIsNone(sm._hb_handle)
        self.assertIsNone(sm._hb_task)
        self.assertTrue(hb_task._must_cancel)

    def test_heartbeat_task(self):
        _, sm = self.make_manager()
        sm._hb_task = mock.Mock()

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertTrue(sm.started)
        self.assertIsNone(sm._hb_task)

    def test_gc_expire(self):
        s, sm = self.make_manager()

        sm._add(s)
        self.loop.run_until_complete(sm.acquire(s))
        self.loop.run_until_complete(sm.release(s))

        # Push the expiry into the past so the next heartbeat collects it.
        s.expires = datetime.now() - timedelta(seconds=30)

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertNotIn(s.id, sm)
        self.assertTrue(s.expired)
        self.assertEqual(s.state, protocol.STATE_CLOSED)

    def test_gc_expire_acquired(self):
        """An acquired session is not expired by the heartbeat task.

        It can only be released and closed as a result of an error while
        sending a heartbeat message.
        """
        s, sm = self.make_manager()

        sm._add(s)
        self.loop.run_until_complete(sm.acquire(s))

        s.expires = datetime.now() - timedelta(seconds=30)

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertIn(s.id, sm)
        self.assertIn(s.id, sm.acquired)
        self.assertFalse(s.expired)
        self.assertEqual(s.state, protocol.STATE_OPEN)

        # Simulate the session being released because of an error.
        self.loop.run_until_complete(sm.release(s))
        s.expires = datetime.now() - timedelta(seconds=30)

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertNotIn(s.id, sm)
        self.assertNotIn(s.id, sm.acquired)
        self.assertTrue(s.expired)
        self.assertEqual(s.state, protocol.STATE_CLOSED)

    def test_gc_one_expire(self):
        _, sm = self.make_manager()
        s1 = self.make_session('id1')
        s2 = self.make_session('id2')

        sm._add(s1)
        sm._add(s2)
        self.loop.run_until_complete(sm.acquire(s1))
        self.loop.run_until_complete(sm.acquire(s2))
        self.loop.run_until_complete(sm.release(s1))
        self.loop.run_until_complete(sm.release(s2))

        # Only s1 is past its expiry; s2 must survive collection.
        s1.expires = datetime.now() - timedelta(seconds=30)

        self.loop.run_until_complete(sm._heartbeat_task())
        self.assertNotIn(s1.id, sm)
        self.assertIn(s2.id, sm)