asyncio.coroutine

Here are examples of the Python API `asyncio.coroutine`, taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.

199 Examples (page 7)

Example 1

Project: Agent
Source File: test_agent.py
View license
async def test_coroutines():
    """Verify how agent.async_generator treats yield vs. yield-from values."""

    @agent.async_generator
    def inner_gen():
        # Value delivered through ``yield from`` is consumed, not re-yielded.
        yield from asyncio.coroutine(lambda: 'Not sent')()
        # Re-yield the value obtained via ``yield from``.
        yield (yield from asyncio.coroutine(lambda: 'yield_from')())
        # Yield the coroutine object itself; the caller must await it.
        yield asyncio.coroutine(lambda: 'yield')()

    agen = inner_gen()
    assert await agent.anext(agen) == 'yield_from'
    assert await (await agent.anext(agen)) == 'yield'

Example 2

Project: tanner
Source File: test_session_analyzer.py
View license
    def test_create_stats(self):
        """create_stats should report 'attacker' as the possible owner."""

        @asyncio.coroutine
        def fake_get():
            # Stand-in for the redis GET returning the stored session.
            return session

        @asyncio.coroutine
        def fake_smembers(key):
            # Stand-in for the redis SMEMBERS returning no members.
            return set()

        @asyncio.coroutine
        def fake_lpush():
            # Stand-in for the redis LPUSH.
            return ''

        redis_mock = mock.Mock()
        redis_mock.get = fake_get
        redis_mock.smembers_asset = fake_smembers
        redis_mock.lpush = fake_lpush

        loop = asyncio.get_event_loop()
        stats = loop.run_until_complete(
            self.handler.create_stats(self.session, redis_mock))
        self.assertEqual(stats['possible_owners'], ['attacker'])

Example 3

Project: tanner
Source File: test_session_analyzer.py
View license
    def test_create_stats(self):
        """Session stats must attribute the session to an 'attacker' owner."""
        # Plain callables wrapped into coroutine functions via the legacy
        # asyncio.coroutine decorator.
        sess_coro = asyncio.coroutine(lambda: session)
        members_coro = asyncio.coroutine(lambda key: set())
        lpush_coro = asyncio.coroutine(lambda: '')

        redis_mock = mock.Mock()
        redis_mock.get = sess_coro
        redis_mock.smembers_asset = members_coro
        redis_mock.lpush = lpush_coro

        result = asyncio.get_event_loop().run_until_complete(
            self.handler.create_stats(self.session, redis_mock))
        self.assertEqual(result['possible_owners'], ['attacker'])

Example 4

Project: aiodocker
Source File: docker.py
View license
    @asyncio.coroutine
    def _websocket(self, url, **params):
        """Open a websocket to *url*, defaulting to streaming stdout/stderr."""
        if not params:
            # No explicit options: stream both output channels.
            params = dict(stdout=1, stderr=1, stream=1)
        full_url = "{}?{}".format(
            self._endpoint(url), urllib.parse.urlencode(params))
        ws = yield from aiohttp.ws_connect(full_url, connector=self.connector)
        return ws

Example 5

Project: pycoinnet
Source File: Client.py
View license
    def __init__(self, network, host_port_q, should_download_block_f, block_chain_store,
                 blockchain_change_callback, server_port=9999):
        """
        network:
            a value from pycoinnet.helpers.networks
        host_port_q:
            a Queue that is being fed potential places to connect
        should_download_block_f:
            a function that accepting(block_hash, block_index) and returning a boolean
            indicating whether that block should be downloaded. Only used during fast-forward.
        block_chain_store:
            usually a BlockChainStore instance
        blockchain_change_callback:
            a callback that expects (blockchain, list_of_ops) that is invoked whenever the
            block chain is updated; blockchain is a BlockChain object and list_of_ops is a pair
            of tuples of the form (op, block_hash, block_index) where op is one of "add" or "remove",
            block_hash is a binary block hash, and block_index is an integer index number.
        """

        # Rebuild the in-memory chain from the previously persisted headers.
        block_chain = BlockChain(did_lock_to_index_f=block_chain_store.did_lock_to_index)

        block_chain.preload_locked_blocks(block_chain_store.headers())

        block_chain.add_change_callback(block_chain_locker_callback)

        self.blockfetcher = Blockfetcher()
        self.inv_collector = InvCollector()

        # Cache of recently seen blocks; periodically rotated by _rotate below.
        self.block_store = TwoLevelDict()

        @asyncio.coroutine
        def _rotate(block_store):
            # Rotate the block cache every 1800 s (30 minutes), forever.
            while True:
                block_store.rotate()
                yield from asyncio.sleep(1800)
        self.rotate_task = asyncio.Task(_rotate(self.block_store))

        self.blockhandler = BlockHandler(self.inv_collector, block_chain, self.block_store,
                                         should_download_f=should_download_block_f)

        block_chain.add_change_callback(blockchain_change_callback)

        self.fast_forward_add_peer = fast_forwarder_add_peer_f(block_chain)
        self.fetcher_task = asyncio.Task(new_block_fetcher(self.inv_collector, block_chain))

        # Random 8-byte nonce sent in the version message; presumably used to
        # detect self-connections — TODO confirm against version_data_for_peer.
        self.nonce = int.from_bytes(os.urandom(8), byteorder="big")
        self.subversion = "/Notoshi/".encode("utf8")

        @asyncio.coroutine
        def run_peer(peer, fetcher, fast_forward_add_peer, blockfetcher, inv_collector, blockhandler):
            # Per-connection lifecycle: wait for the transport, perform the
            # version handshake, then register the peer with each subsystem.
            yield from asyncio.wait_for(peer.connection_made_future, timeout=None)
            version_parameters = version_data_for_peer(
                peer, local_port=(server_port or 0), last_block_index=block_chain.length(),
                nonce=self.nonce, subversion=self.subversion)
            version_data = yield from initial_handshake(peer, version_parameters)
            last_block_index = version_data["last_block_index"]
            fast_forward_add_peer(peer, last_block_index)
            blockfetcher.add_peer(peer, fetcher, last_block_index)
            inv_collector.add_peer(peer)
            blockhandler.add_peer(peer)

        def create_protocol_callback():
            # Protocol factory shared by the outbound connection manager and
            # the inbound listener below.
            peer = BitcoinPeerProtocol(network["MAGIC_HEADER"])
            install_pingpong_manager(peer)
            fetcher = Fetcher(peer)
            peer.add_task(run_peer(
                peer, fetcher, self.fast_forward_add_peer,
                self.blockfetcher, self.inv_collector, self.blockhandler))
            return peer

        # Maintain up to 8 outbound connections fed from host_port_q.
        self.connection_info_q = manage_connection_count(host_port_q, create_protocol_callback, 8)
        self.show_task = asyncio.Task(show_connection_info(self.connection_info_q))

        # listener
        @asyncio.coroutine
        def run_listener():
            abstract_server = None
            try:
                abstract_server = yield from asyncio.get_event_loop().create_server(
                    protocol_factory=create_protocol_callback, port=server_port)
                return abstract_server
            except OSError:
                # Port unavailable: log and continue with outbound-only mode.
                logging.info("can't listen on port %d", server_port)

        if server_port:
            self.server_task = asyncio.Task(run_listener())

Example 6

Project: asyncssh
Source File: misc.py
View license
def async_context_manager(coro):
    """Decorator for methods returning asynchronous context managers

       This function can be used as a decorator for coroutines which
       return objects intended to be used as Python 3.5 asynchronous
       context managers. The object returned should implement __aenter__
       and __aexit__ methods to run when the async context is entered
       and exited.

       This wrapper also allows non-async context managers to be defined
       on the returned object, as well as the use of "await" or "yield
       from" on the function being decorated for backward compatibility
       with the API defined by older versions of AsyncSSH.

    """

    class AsyncContextManager:
        """Async context manager wrapper for Python 3.5 and later"""

        def __init__(self, coro):
            # The wrapped, not-yet-awaited coroutine object.
            self._coro = coro
            # Result of awaiting the coroutine; held between __aenter__
            # and __aexit__ so the exit call reaches the same object.
            self._result = None

        def __iter__(self):
            # Supports legacy "yield from mgr" usage.
            return (yield from self._coro)

        def __await__(self):
            # Supports "await mgr" without entering the context.
            return (yield from self._coro)

        @asyncio.coroutine
        def __aenter__(self):
            # Await the coroutine, then delegate to the result's own
            # async context-manager protocol.
            self._result = yield from self._coro
            return (yield from self._result.__aenter__())

        @asyncio.coroutine
        def __aexit__(self, *exc_info):
            yield from self._result.__aexit__(*exc_info)
            self._result = None

    @functools.wraps(coro)
    def coro_wrapper(*args, **kwargs):
        """Return an async context manager wrapper for this coroutine"""

        return AsyncContextManager(asyncio.coroutine(coro)(*args, **kwargs))

    if python35:
        return coro_wrapper
    else:
        # "async with" does not exist before Python 3.5, so hand back the
        # undecorated coroutine function.
        return coro

Example 7

Project: aiofiles
Source File: test_concurrency.py
View license
@pytest.mark.asyncio
def test_slow_file(monkeypatch, unused_tcp_port):
    """Monkey patch open and file.read(), and assert the loop still works.

    A second server/client pair ("spam") keeps exchanging messages while
    the file server streams a deliberately slowed file; the spam counter
    proves the event loop was never blocked by the slow file I/O.
    """
    filename = join(dirname(__file__), '..', 'resources', 'multiline_file.txt')

    with open(filename, mode='rb') as f:
        contents = f.read()

    def new_open(*args, **kwargs):
        # Simulate a slow filesystem: block for one second before opening.
        time.sleep(1)
        return open(*args, **kwargs)

    monkeypatch.setattr(aiofiles.threadpool, '_sync_open', value=new_open)

    @asyncio.coroutine
    def serve_file(_, writer):
        # Stream the file one byte at a time to maximize read() calls.
        file = yield from aiofiles.threadpool.open(filename, mode='rb')
        try:
            while True:
                data = yield from file.read(1)
                if not data:
                    break
                writer.write(data)
                yield from writer.drain()
            yield from writer.drain()
        finally:
            writer.close()
            yield from file.close()

    @asyncio.coroutine
    def return_one(_, writer):
        writer.write(b'1')
        yield from writer.drain()
        writer.close()

    counter = 0

    @asyncio.coroutine
    def spam_client():
        # Hammer the spam server; each round-trip increments the counter.
        nonlocal counter
        while True:
            # NOTE(review): port 30001 is hard-coded and could collide with
            # another process — consider a second unused_tcp_port fixture.
            r, w = yield from asyncio.open_connection('127.0.0.1', port=30001)
            assert (yield from r.read()) == b'1'
            counter += 1
            w.close()
            yield from asyncio.sleep(0.01)

    file_server = yield from asyncio.start_server(serve_file,
                                                  port=unused_tcp_port)
    spam_server = yield from asyncio.start_server(return_one, port=30001)

    # BUGFIX: asyncio.async() was removed in Python 3.7 ("async" became a
    # keyword, so the old call is a SyntaxError). ensure_future() is the
    # equivalent spelling available since Python 3.4.4.
    spam_task = asyncio.ensure_future(spam_client())

    reader, writer = yield from asyncio.open_connection('127.0.0.1',
                                                        port=unused_tcp_port)

    actual_contents = yield from reader.read()
    writer.close()

    yield from asyncio.sleep(0)

    file_server.close()
    spam_server.close()

    yield from file_server.wait_closed()
    yield from spam_server.wait_closed()

    spam_task.cancel()

    assert actual_contents == contents
    assert counter > 40

Example 8

Project: pyzmq
Source File: _test_asyncio.py
View license
    def test_aiohttp(self):
        """Run an aiohttp server and client on the zmq-enabled asyncio loop,
        then round-trip the HTTP response text through a PUSH/PULL pair."""
        try:
            import aiohttp
        except ImportError:
            raise SkipTest("Requires aiohttp")
        from aiohttp import web

        # Install the zmq-aware asyncio event loop.
        zmq.asyncio.install()

        @asyncio.coroutine
        def echo(request):
            print(request.path)
            return web.Response(body=str(request).encode('utf8'))

        @asyncio.coroutine
        def server(loop):
            app = web.Application(loop=loop)
            app.router.add_route('GET', '/', echo)

            srv = yield from loop.create_server(app.make_handler(),
                                                '127.0.0.1', 8080)
            print("Server started at http://127.0.0.1:8080")
            return srv

        @asyncio.coroutine
        def client():
            # Fetch the page, push its text over zmq, and verify the pulled
            # bytes decode back to the same text.
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)

            res = yield from aiohttp.request('GET', 'http://127.0.0.1:8080/')
            text = yield from res.text()
            yield from push.send(text.encode('utf8'))
            rcvd = yield from pull.recv()
            self.assertEqual(rcvd.decode('utf8'), text)

        loop = asyncio.get_event_loop()
        loop.run_until_complete(server(loop))
        print("servered")
        loop.run_until_complete(client())

Example 9

Project: aiohttp-cors
Source File: test_real_browser.py
View license
    @asyncio.coroutine
    def start_servers(self):
        """Start every test server and configure CORS on each.

        Builds the "origin", "no_cors", "allowing", "denying" and
        "free_for_all" aiohttp applications, binds each to an OS-assigned
        port, records its handler/server/url in ``self.servers``, then
        applies the per-server CORS defaults and CORS-enabled routes.
        """
        test_page_path = pathlib.Path(__file__).with_name("test_page.html")

        @asyncio.coroutine
        def handle_test_page(request: web.Request) -> web.StreamResponse:
            # Serve the browser-side test harness page.
            with test_page_path.open("r", encoding="utf-8") as f:
                return web.Response(
                    text=f.read(),
                    headers={hdrs.CONTENT_TYPE: "text/html"})

        @asyncio.coroutine
        def handle_no_cors(request: web.Request) -> web.StreamResponse:
            return web.Response(
                text="""{"type": "no_cors.json"}""",
                headers={hdrs.CONTENT_TYPE: "application/json"})

        @asyncio.coroutine
        def handle_resource(request: web.Request) -> web.StreamResponse:
            return web.Response(
                text="""{"type": "resource"}""",
                headers={hdrs.CONTENT_TYPE: "application/json"})

        @asyncio.coroutine
        def handle_servers_addresses(
                request: web.Request) -> web.StreamResponse:
            # Lets the browser page discover the URL of every test server.
            servers_addresses = \
                {name: descr.url for name, descr in self.servers.items()}
            return web.Response(
                text=json.dumps(servers_addresses))

        # For most resources:
        # "origin" server has no CORS configuration.
        # "allowing" server explicitly allows CORS requests to "origin" server.
        # "denying" server explicitly disallows CORS requests to "origin"
        # server.
        # "free_for_all" server allows CORS requests for all origins server.
        # "no_cors" server has no CORS configuration.
        cors_server_names = ["allowing", "denying", "free_for_all"]
        server_names = cors_server_names + ["origin", "no_cors"]

        for server_name in server_names:
            assert server_name not in self.servers
            self.servers[server_name] = _ServerDescr()

        # Create applications.
        for server_descr in self.servers.values():
            server_descr.app = web.Application()

        # Server test page from origin server.
        self.servers["origin"].app.router.add_route(
            "GET", "/", handle_test_page)
        self.servers["origin"].app.router.add_route(
            "GET", "/servers_addresses", handle_servers_addresses)

        # Add routes to all servers.
        for server_name in server_names:
            app = self.servers[server_name].app
            app.router.add_route("GET", "/no_cors.json", handle_no_cors)
            app.router.add_route("GET", "/cors_resource", handle_resource,
                                 name="cors_resource")

        # Start servers.
        for server_name, server_descr in self.servers.items():
            handler = server_descr.app.make_handler()
            server = yield from create_server(handler, self.loop)
            server_descr.handler = handler
            server_descr.server = server

            # Port 0 was requested, so read back the actual bound address.
            hostaddr, port = server.sockets[0].getsockname()
            server_descr.url = "http://{host}:{port}".format(
                host=hostaddr, port=port)

            self._logger.info("Started server '%s' at '%s'",
                              server_name, server_descr.url)

        cors_default_configs = {
            "allowing": {
                self.servers["origin"].url:
                    ResourceOptions(
                        allow_credentials=True, expose_headers="*",
                        allow_headers="*")
            },
            "denying": {
                # Allow requests to other than "origin" server.
                self.servers["allowing"].url:
                    ResourceOptions(
                        allow_credentials=True, expose_headers="*",
                        allow_headers="*")
            },
            "free_for_all": {
                "*":
                    ResourceOptions(
                        allow_credentials=True, expose_headers="*",
                        allow_headers="*")
            },
        }

        # Configure CORS.
        for server_name, server_descr in self.servers.items():
            default_config = cors_default_configs.get(server_name)
            if default_config is None:
                continue
            server_descr.cors = setup(
                server_descr.app, defaults=default_config)

        # Add CORS routes.
        for server_name in cors_server_names:
            server_descr = self.servers[server_name]
            # TODO: Starting from aiohttp 0.21.0 name-based access returns
            # Resource, not Route. Manually get route while aiohttp_cors
            # doesn't support configuring for Resources.
            resource = server_descr.app.router["cors_resource"]
            route = next(iter(resource))
            if self.use_resources:
                server_descr.cors.add(resource)
                server_descr.cors.add(route)

            else:
                server_descr.cors.add(route)

Example 10

View license
@asyncio.coroutine
def test_forget(loop, test_client):
    """After logout the AIOHTTP_SECURITY identity cookie must be gone."""

    @asyncio.coroutine
    def handle_index(request):
        return web.Response()

    @asyncio.coroutine
    def handle_login(request):
        redirect = web.HTTPFound(location='/')
        yield from remember(request, redirect, 'Andrew')
        return redirect

    @asyncio.coroutine
    def handle_logout(request):
        redirect = web.HTTPFound(location='/')
        yield from forget(request, redirect)
        return redirect

    app = web.Application(loop=loop)
    _setup(app, CookiesIdentityPolicy(), Autz())
    app.router.add_route('GET', '/', handle_index)
    app.router.add_route('POST', '/login', handle_login)
    app.router.add_route('POST', '/logout', handle_logout)
    client = yield from test_client(app)

    # Log in and verify the identity cookie was set.
    resp = yield from client.post('/login')
    assert 200 == resp.status
    assert resp.url.endswith('/')
    cookies = client.session.cookie_jar.filter_cookies(client.make_url('/'))
    assert 'Andrew' == cookies['AIOHTTP_SECURITY'].value
    yield from resp.release()

    # Log out and verify the cookie was dropped.
    resp = yield from client.post('/logout')
    assert 200 == resp.status
    assert resp.url.endswith('/')
    cookies = client.session.cookie_jar.filter_cookies(client.make_url('/'))
    assert 'AIOHTTP_SECURITY' not in cookies
    yield from resp.release()

Example 11

View license
@asyncio.coroutine
def test_forget(make_app, test_client):
    """After logout the session no longer holds the remembered identity."""

    @asyncio.coroutine
    def index(request):
        # Report whatever identity the session currently stores ('' if none).
        sess = yield from get_session(request)
        return web.HTTPOk(text=sess.get('AIOHTTP_SECURITY', ''))

    @asyncio.coroutine
    def login(request):
        redirect = web.HTTPFound(location='/')
        yield from remember(request, redirect, 'Andrew')
        return redirect

    @asyncio.coroutine
    def logout(request):
        redirect = web.HTTPFound('/')
        yield from forget(request, redirect)
        return redirect

    app = make_app()
    for method, path, handler in (('GET', '/', index),
                                  ('POST', '/login', login),
                                  ('POST', '/logout', logout)):
        app.router.add_route(method, path, handler)

    client = yield from test_client(app)

    # Login, then '/' must report the remembered identity.
    resp = yield from client.post('/login')
    assert 200 == resp.status
    assert resp.url.endswith('/')
    assert 'Andrew' == (yield from resp.text())
    yield from resp.release()

    # Logout, then '/' must report an empty identity.
    resp = yield from client.post('/logout')
    assert 200 == resp.status
    assert resp.url.endswith('/')
    assert '' == (yield from resp.text())
    yield from resp.release()

Example 12

Project: aiokafka
Source File: test_fetcher.py
View license
    @run_until_complete
    def test_proc_fetch_request(self):
        """Exercise Fetcher._proc_fetch_request across success and error codes.

        The client network layer is fully mocked with canned FetchResponse
        objects; each stanza seeds partition state, fires the request, and
        checks the "needs_wake_up" result plus fetched records or raised
        errors.
        """
        client = AIOKafkaClient(
            loop=self.loop,
            bootstrap_servers=[])
        subscriptions = SubscriptionState('latest')
        fetcher = Fetcher(client, subscriptions, loop=self.loop)

        tp = TopicPartition('test', 0)
        tp_info = (tp.topic, [(tp.partition, 155, 100000)])
        req = FetchRequest(
            -1,  # replica_id
            100, 100, [tp_info])

        # Mock out readiness, metadata refresh, and send so no broker is hit.
        client.ready = mock.MagicMock()
        client.ready.side_effect = asyncio.coroutine(lambda a: True)
        client.force_metadata_update = mock.MagicMock()
        client.force_metadata_update.side_effect = asyncio.coroutine(
            lambda: False)
        client.send = mock.MagicMock()
        msg = Message(b"test msg")
        msg._encode_self()
        client.send.side_effect = asyncio.coroutine(
            lambda n, r: FetchResponse(
                [('test', [(0, 0, 9, [(4, 10, msg)])])]))
        fetcher._in_flight.add(0)
        # Partition has not been assigned yet: no wake-up expected.
        needs_wake_up = yield from fetcher._proc_fetch_request(0, req)
        self.assertEqual(needs_wake_up, False)

        # Assign the partition at offset 0; the fetched message is at
        # offset 4, which does not match the seek position.
        state = TopicPartitionState()
        state.seek(0)
        subscriptions.assignment[tp] = state
        subscriptions.needs_partition_assignment = False
        fetcher._in_flight.add(0)
        needs_wake_up = yield from fetcher._proc_fetch_request(0, req)
        self.assertEqual(needs_wake_up, True)
        buf = fetcher._records[tp]
        self.assertEqual(buf.getone(), None)  # invalid offset, msg is ignored

        # Seek to the message's offset; now it should come through.
        state.seek(4)
        fetcher._in_flight.add(0)
        fetcher._records.clear()
        needs_wake_up = yield from fetcher._proc_fetch_request(0, req)
        self.assertEqual(needs_wake_up, True)
        buf = fetcher._records[tp]
        self.assertEqual(buf.getone().value, b"test msg")

        # error -> no partition found
        client.send.side_effect = asyncio.coroutine(
            lambda n, r: FetchResponse(
                [('test', [(0, 3, 9, [(4, 10, msg)])])]))
        fetcher._in_flight.add(0)
        fetcher._records.clear()
        needs_wake_up = yield from fetcher._proc_fetch_request(0, req)
        self.assertEqual(needs_wake_up, False)

        # error -> topic auth failed
        client.send.side_effect = asyncio.coroutine(
            lambda n, r: FetchResponse(
                [('test', [(0, 29, 9, [(4, 10, msg)])])]))
        fetcher._in_flight.add(0)
        fetcher._records.clear()
        needs_wake_up = yield from fetcher._proc_fetch_request(0, req)
        self.assertEqual(needs_wake_up, True)
        with self.assertRaises(TopicAuthorizationFailedError):
            yield from fetcher.next_record([])

        # error -> unknown
        client.send.side_effect = asyncio.coroutine(
            lambda n, r: FetchResponse(
                [('test', [(0, -1, 9, [(4, 10, msg)])])]))
        fetcher._in_flight.add(0)
        fetcher._records.clear()
        needs_wake_up = yield from fetcher._proc_fetch_request(0, req)
        self.assertEqual(needs_wake_up, False)

        # error -> offset out of range
        client.send.side_effect = asyncio.coroutine(
            lambda n, r: FetchResponse(
                [('test', [(0, 1, 9, [(4, 10, msg)])])]))
        fetcher._in_flight.add(0)
        fetcher._records.clear()
        needs_wake_up = yield from fetcher._proc_fetch_request(0, req)
        self.assertEqual(needs_wake_up, False)
        self.assertEqual(state.is_fetchable(), False)

        # With reset strategy NONE the out-of-range error surfaces to the
        # caller via next_record instead of being handled internally.
        state.seek(4)
        subscriptions._default_offset_reset_strategy = OffsetResetStrategy.NONE
        client.send.side_effect = asyncio.coroutine(
            lambda n, r: FetchResponse(
                [('test', [(0, 1, 9, [(4, 10, msg)])])]))
        fetcher._in_flight.add(0)
        fetcher._records.clear()
        needs_wake_up = yield from fetcher._proc_fetch_request(0, req)
        self.assertEqual(needs_wake_up, True)
        with self.assertRaises(OffsetOutOfRangeError):
            yield from fetcher.next_record([])

        yield from fetcher.close()

Example 13

Project: aioredis
Source File: pipeline.py
View license
@asyncio.coroutine
def main():
    """Compare three ways of issuing multiple Redis commands."""
    redis = yield from aioredis.create_redis(
        ('localhost', 6379))

    # No pipelining;
    @asyncio.coroutine
    def wait_each_command():
        # Each command fully round-trips before the next one is sent.
        foo_val = yield from redis.get('foo')
        bar_cnt = yield from redis.incr('bar')
        return foo_val, bar_cnt

    # Sending multiple commands and then gathering results
    @asyncio.coroutine
    def pipelined():
        # Both futures are in flight before either result is awaited.
        foo_fut = redis.get('foo')
        bar_fut = redis.incr('bar')
        foo_val, bar_cnt = yield from asyncio.gather(foo_fut, bar_fut)
        return foo_val, bar_cnt

    # Convenient way
    @asyncio.coroutine
    def convenience_way():
        # pipeline() queues commands and flushes them all in execute().
        pipe = redis.pipeline()
        foo_fut = pipe.get('foo')
        bar_fut = pipe.incr('bar')
        result = yield from pipe.execute()
        foo_val, bar_cnt = yield from asyncio.gather(foo_fut, bar_fut)
        assert result == [foo_val, bar_cnt]
        return foo_val, bar_cnt

    for runner in (wait_each_command, pipelined, convenience_way):
        print((yield from runner()))

    redis.close()
    yield from redis.wait_closed()

Example 14

Project: aioredis
Source File: pubsub2.py
View license
@asyncio.coroutine
def pubsub():
    """Demonstrate two channel-reader styles plus publish and teardown."""
    sub = yield from aioredis.create_redis(
         ('localhost', 6379))

    ch1, ch2 = yield from sub.subscribe('channel:1', 'channel:2')
    assert isinstance(ch1, aioredis.Channel)
    assert isinstance(ch2, aioredis.Channel)

    @asyncio.coroutine
    def async_reader(channel):
        # Style 1: wait_message() returns False once the channel closes.
        while (yield from channel.wait_message()):
            msg = yield from channel.get(encoding='utf-8')
            # ... process message ...
            print("message in {}: {}".format(channel.name, msg))

    # BUGFIX: asyncio.async() was removed in Python 3.7 ("async" became a
    # keyword); ensure_future() is the equivalent since Python 3.4.4.
    tsk1 = asyncio.ensure_future(async_reader(ch1))

    # Or alternatively:

    @asyncio.coroutine
    def async_reader2(channel):
        # Style 2: get() returns None once the channel closes.
        while True:
            msg = yield from channel.get(encoding='utf-8')
            if msg is None:
                break
            # ... process message ...
            print("message in {}: {}".format(channel.name, msg))

    tsk2 = asyncio.ensure_future(async_reader2(ch2))

    # Publish messages and terminate
    pub = yield from aioredis.create_redis(
        ('localhost', 6379))
    # Wait until the subscriber is registered on both channels.
    while True:
        channels = yield from pub.pubsub_channels()
        if len(channels) == 2:
            break

    for msg in ("Hello", ",", "world!"):
        for ch in ('channel:1', 'channel:2'):
            yield from pub.publish(ch, msg)
    pub.close()
    sub.close()
    yield from asyncio.sleep(0)
    yield from pub.wait_closed()
    yield from sub.wait_closed()
    # Readers exit on channel close; gather to surface any exceptions.
    yield from asyncio.gather(tsk1, tsk2)
Example 15

Project: aiorest
Source File: session_factory_test.py
View license
    def test_load_and_save(self):
        """A fresh session is created, and saved only once it is modified.

        BUGFIX: the original used ``assert_call_count`` and
        ``assert_call_once_with``, which are not real ``unittest.mock``
        methods — a Mock silently accepts any attribute access, so those
        "assertions" always passed without checking anything. They are
        replaced with real ``call_count`` checks and the correctly spelled
        ``assert_called_once_with``.
        """
        factory = create_session_factory(self.dummy_sid, self.dummy_storage,
                                         loop=self.loop)

        @asyncio.coroutine
        def load(session_id):
            # No stored data: forces creation of a brand-new session.
            return (None, None)

        @asyncio.coroutine
        def save(session):
            return

        self.dummy_storage.load_session_data.side_effect = load
        self.dummy_storage.save_session_data.side_effect = save

        @asyncio.coroutine
        def go():
            waiter = asyncio.Future(loop=self.loop)
            req = mock.Mock()

            factory(req, waiter)

            sess = yield from asyncio.wait_for(waiter, timeout=1,
                                               loop=self.loop)

            self.assertIsInstance(sess, Session)
            self.assertTrue(sess.new)
            self.assertIsNone(sess.identity)
            req.add_response_callback.assert_called_once_with(
                factory._save, session=sess)

            # An unmodified session must not be written back to storage.
            yield from factory._save(req, session=sess)
            self.assertEqual(
                self.dummy_storage.save_session_data.call_count, 0)

            # Once modified, the session is saved exactly once.
            sess['foo'] = 'bar'
            yield from factory._save(req, session=sess)
            self.dummy_storage.save_session_data.assert_called_once_with(sess)

        self.loop.run_until_complete(go())

        # No additional saves happened after the coroutine finished.
        self.assertEqual(self.dummy_storage.save_session_data.call_count, 1)

Example 16

Project: aiozmq
Source File: zmq_events_test.py
View license
    def test_req_rep(self):
        """Round-trip a request/answer pair over inproc REQ/REP sockets,
        then close both transports and verify the protocols reach CLOSED."""
        @asyncio.coroutine
        def connect_req():
            # REQ side binds the inproc endpoint.
            tr1, pr1 = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REQ,
                bind='inproc://test',
                loop=self.loop)
            self.assertEqual('CONNECTED', pr1.state)
            yield from pr1.connected
            return tr1, pr1

        tr1, pr1 = self.loop.run_until_complete(connect_req())

        @asyncio.coroutine
        def connect_rep():
            # REP side connects to the endpoint bound above.
            tr2, pr2 = yield from aiozmq.create_zmq_connection(
                lambda: Protocol(self.loop),
                zmq.REP,
                connect='inproc://test',
                loop=self.loop)
            self.assertEqual('CONNECTED', pr2.state)
            yield from pr2.connected
            return tr2, pr2

        tr2, pr2 = self.loop.run_until_complete(connect_rep())

        @asyncio.coroutine
        def communicate():
            # REQ writes, REP receives, REP answers, REQ receives.
            tr1.write([b'request'])
            request = yield from pr2.received.get()
            self.assertEqual([b'request'], request)
            tr2.write([b'answer'])
            answer = yield from pr1.received.get()
            self.assertEqual([b'answer'], answer)

        self.loop.run_until_complete(communicate())

        @asyncio.coroutine
        def closing():
            tr1.close()
            tr2.close()

            yield from pr1.closed
            self.assertEqual('CLOSED', pr1.state)
            yield from pr2.closed
            self.assertEqual('CLOSED', pr2.state)

        self.loop.run_until_complete(closing())

Example 17

Project: gns3-server
Source File: route.py
View license
    @classmethod
    def _route(cls, method, path, *args, **kw):
        # This block is executed only the first time
        output_schema = kw.get("output", {})
        input_schema = kw.get("input", {})
        api_version = kw.get("api_version", 1)
        raw = kw.get("raw", False)

        # If it's a JSON api endpoint just register the endpoint an do nothing
        if api_version is None:
            cls._path = path
        else:
            cls._path = "/v{version}{path}".format(path=path, version=api_version)

        def register(func):
            """Register *func* as the handler for this route.

            Records the route in the class-level documentation table, wraps
            the handler as an asyncio coroutine, and returns a dispatch
            coroutine performing authentication, request parsing/validation
            and error-to-HTTP-response translation.
            NOTE(review): closure variables (cls, method, api_version,
            output_schema, input_schema, raw, kw) come from the enclosing
            decorator, which is outside this chunk — confirm against it.
            """
            route = cls._path

            # Derive a documentation section name from the handler's module,
            # e.g. "gns3server.handlers.api.foo_handler" -> "foo".
            handler = func.__module__.replace("_handler", "").replace("gns3server.handlers.api.", "")
            cls._documentation.setdefault(handler, {})
            cls._documentation[handler].setdefault(route, {"api_version": api_version,
                                                           "methods": []})

            # One entry per HTTP method on this route; consumed by API doc generation.
            cls._documentation[handler][route]["methods"].append({
                "method": method,
                "status_codes": kw.get("status_codes", {200: "OK"}),
                "parameters": kw.get("parameters", {}),
                "output_schema": output_schema,
                "input_schema": input_schema,
                "description": kw.get("description", ""),
            })
            # Ensure the handler is a coroutine even if written as a plain function.
            func = asyncio.coroutine(func)

            @asyncio.coroutine
            def control_schema(request):
                """Authenticate, validate and execute one request, mapping
                expected failures to JSON (or HTML) error responses."""
                # This block is executed at each method call

                server_config = Config.instance().get_section_config("Server")

                # Authenticate
                response = cls.authenticate(request, route, server_config)
                if response:
                    return response

                # Non API call
                if api_version is None or raw is True:
                    response = Response(request=request, route=route, output_schema=output_schema)

                    request = yield from parse_request(request, None, raw)
                    yield from func(request, response)
                    return response

                # API call
                try:
                    request = yield from parse_request(request, input_schema, raw)
                    # Optionally record each API call as a replayable curl command.
                    record_file = server_config.get("record")
                    if record_file:
                        try:
                            with open(record_file, "a", encoding="utf-8") as f:
                                f.write("curl -X {} 'http://{}{}' -d '{}'".format(request.method, request.host, request.path_qs, json.dumps(request.json)))
                                f.write("\n")
                        except OSError as e:
                            log.warn("Could not write to the record file {}: {}".format(record_file, e))
                    response = Response(request=request, route=route, output_schema=output_schema)
                    yield from func(request, response)
                except aiohttp.web.HTTPBadRequest as e:
                    # Validation failure: echo back the offending request body.
                    response = Response(request=request, route=route)
                    response.set_status(e.status)
                    response.json({"message": e.text, "status": e.status, "path": route, "request": request.json, "method": request.method})
                except aiohttp.web.HTTPException as e:
                    response = Response(request=request, route=route)
                    response.set_status(e.status)
                    response.json({"message": e.text, "status": e.status})
                except (VMError, UbridgeError) as e:
                    # VM-level errors are reported as 409 Conflict.
                    log.error("VM error detected: {type}".format(type=type(e)), exc_info=1)
                    response = Response(request=request, route=route)
                    response.set_status(409)
                    response.json({"message": str(e), "status": 409})
                except asyncio.futures.CancelledError as e:
                    log.error("Request canceled")
                    response = Response(request=request, route=route)
                    response.set_status(408)
                    response.json({"message": "Request canceled", "status": 408})
                except aiohttp.ClientDisconnectedError:
                    log.warn("Client disconnected")
                    response = Response(request=request, route=route)
                    response.set_status(408)
                    response.json({"message": "Client disconnected", "status": 408})
                except ConnectionResetError:
                    log.error("Client connection reset")
                    response = Response(request=request, route=route)
                    response.set_status(408)
                    response.json({"message": "Connection reset", "status": 408})
                except Exception as e:
                    # Unexpected failure: file a crash report and return a 500
                    # carrying the traceback (JSON for API calls, HTML otherwise).
                    log.error("Uncaught exception detected: {type}".format(type=type(e)), exc_info=1)
                    response = Response(request=request, route=route)
                    response.set_status(500)
                    CrashReport.instance().capture_exception(request)
                    exc_type, exc_value, exc_tb = sys.exc_info()
                    lines = traceback.format_exception(exc_type, exc_value, exc_tb)
                    if api_version is not None:
                        tb = "".join(lines)
                        response.json({"message": tb, "status": 500})
                    else:
                        tb = "\n".join(lines)
                        response.html("<h1>Internal error</h1><pre>{}</pre>".format(tb))

                return response

            @asyncio.coroutine
            def vm_concurrency(request):
                """
                To avoid strange effect we prevent concurrency
                between the same instance of the vm
                """

                if "vm_id" in request.match_info or "device_id" in request.match_info:
                    vm_id = request.match_info.get("vm_id")
                    if vm_id is None:
                        vm_id = request.match_info["device_id"]
                    # Lazily create a per-VM lock and count requests using it.
                    cls._vm_locks.setdefault(vm_id, {"lock": asyncio.Lock(), "concurrency": 0})
                    cls._vm_locks[vm_id]["concurrency"] += 1

                    with (yield from cls._vm_locks[vm_id]["lock"]):
                        response = yield from control_schema(request)
                    cls._vm_locks[vm_id]["concurrency"] -= 1

                    # No more waiting requests, garbage collect the lock
                    if cls._vm_locks[vm_id]["concurrency"] <= 0:
                        del cls._vm_locks[vm_id]
                else:
                    response = yield from control_schema(request)
                return response

            cls._routes.append((method, cls._path, vm_concurrency))

            return vm_concurrency
        return register

Example 18

Project: Flask-aiohttp
Source File: handler.py
View license
    @asyncio.coroutine
    def handle_request(self, request: aiohttp.web.Request) -> \
            aiohttp.web.StreamResponse:
        """Handle WSGI request with aiohttp.

        Builds a WSGI environ from the aiohttp request, runs the wrapped
        WSGI app (self.wsgi), streams its output into an aiohttp
        StreamResponse, and returns that response.  For websocket requests
        the HTTP write delegates are replaced with no-ops and the
        WebSocketResponse is returned instead.
        """

        # Use aiohttp's WSGI implementation
        protocol = WSGIServerHttpProtocol(request.app, True)
        protocol.transport = request.transport

        # Build WSGI Response
        environ = protocol.create_wsgi_environ(request, request.content)

        # Create responses
        ws = aiohttp.web.WebSocketResponse()
        response = aiohttp.web.StreamResponse()

        #: Write delegate
        @asyncio.coroutine
        def write(data):
            yield from response.write(data)

        #: EOF Write delegate
        @asyncio.coroutine
        def write_eof():
            yield from response.write_eof()

        # WSGI start_response function
        def start_response(status, headers, exc_info=None):
            if exc_info:
                raise exc_info[1]

            # Split e.g. "200 OK" into the integer code and optional reason.
            status_parts = status.split(' ', 1)
            status = int(status_parts.pop(0))
            reason = status_parts[0] if status_parts else None

            response.set_status(status, reason=reason)

            for name, value in headers:
                response.headers[name] = value

            response.start(request)

            return write
        if is_websocket_request(request):
            ws.start(request)

            # WSGI HTTP responses in websocket are meaningless.
            # Rebind start_response/write/write_eof with no-op variants; the
            # earlier closures resolve these names at call time, so this
            # rebinding takes effect everywhere.
            def start_response(status, headers, exc_info=None):
                if exc_info:
                    raise exc_info[1]
                ws.start(request)
                return []

            @asyncio.coroutine
            def write(data):
                return

            @asyncio.coroutine
            def write_eof():
                return

            response = ws
        else:
            ws = None

        # Add websocket response to WSGI environment
        environ['wsgi.websocket'] = ws

        # Run WSGI app
        response_iter = self.wsgi(environ, start_response)

        try:
            iterator = iter(response_iter)

            wsgi_response = []
            try:
                item = next(iterator)
            except StopIteration as stop:
                # A coroutine-style app returns its iterable via
                # StopIteration.value (PEP 380 return semantics).
                try:
                    iterator = iter(stop.value)
                except TypeError:
                    pass
                else:
                    wsgi_response = iterator
            else:
                if isinstance(item, bytes):
                    # This is plain WSGI response iterator
                    wsgi_response = itertools.chain([item], iterator)
                else:
                    # This is coroutine
                    yield item
                    wsgi_response = yield from iterator
            for item in wsgi_response:
                yield from write(item)

            yield from write_eof()
        finally:
            if hasattr(response_iter, 'close'):
                response_iter.close()

        # Return selected response
        return response

Example 19

Project: home-assistant
Source File: __init__.py
View license
@asyncio.coroutine
def async_setup(hass, config):
    """Setup the automation.

    Loads automations from *config*, then registers the trigger, reload,
    toggle and turn_on/turn_off services on the DOMAIN.  Returns True on
    success, False if processing the config failed.
    """
    component = EntityComponent(_LOGGER, DOMAIN, hass,
                                group_name=GROUP_NAME_ALL_AUTOMATIONS)

    success = yield from _async_process_config(hass, config, component)

    if not success:
        return False

    # Load service descriptions off the event loop (blocking file I/O).
    descriptions = yield from hass.loop.run_in_executor(
        None, conf_util.load_yaml_config_file, os.path.join(
            os.path.dirname(__file__), 'services.yaml')
    )

    @asyncio.coroutine
    def trigger_service_handler(service_call):
        """Handle automation triggers."""
        tasks = []
        for entity in component.async_extract_from_service(service_call):
            tasks.append(entity.async_trigger(
                service_call.data.get(ATTR_VARIABLES), True))
        # Bug fix: asyncio.wait() raises ValueError on an empty set, so only
        # wait when the service call actually matched some entities.
        if tasks:
            yield from asyncio.wait(tasks, loop=hass.loop)

    @asyncio.coroutine
    def turn_onoff_service_handler(service_call):
        """Handle automation turn on/off service calls."""
        tasks = []
        # Dispatch to async_turn_on / async_turn_off based on the service name.
        method = 'async_{}'.format(service_call.service)
        for entity in component.async_extract_from_service(service_call):
            tasks.append(getattr(entity, method)())
        if tasks:
            yield from asyncio.wait(tasks, loop=hass.loop)

    @asyncio.coroutine
    def toggle_service_handler(service_call):
        """Handle automation toggle service calls."""
        tasks = []
        for entity in component.async_extract_from_service(service_call):
            if entity.is_on:
                tasks.append(entity.async_turn_off())
            else:
                tasks.append(entity.async_turn_on())
        if tasks:
            yield from asyncio.wait(tasks, loop=hass.loop)

    @asyncio.coroutine
    def reload_service_handler(service_call):
        """Remove all automations and load new ones from config."""
        conf = yield from component.async_prepare_reload()
        if conf is None:
            return
        yield from _async_process_config(hass, conf, component)

    hass.services.async_register(
        DOMAIN, SERVICE_TRIGGER, trigger_service_handler,
        descriptions.get(SERVICE_TRIGGER), schema=TRIGGER_SERVICE_SCHEMA)

    hass.services.async_register(
        DOMAIN, SERVICE_RELOAD, reload_service_handler,
        descriptions.get(SERVICE_RELOAD), schema=RELOAD_SERVICE_SCHEMA)

    hass.services.async_register(
        DOMAIN, SERVICE_TOGGLE, toggle_service_handler,
        descriptions.get(SERVICE_TOGGLE), schema=SERVICE_SCHEMA)

    # turn_on and turn_off share one handler; it inspects service_call.service.
    for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF):
        hass.services.async_register(
            DOMAIN, service, turn_onoff_service_handler,
            descriptions.get(service), schema=SERVICE_SCHEMA)

    return True

Example 20

Project: xcat
Source File: xpath1.py
View license
    @asyncio.coroutine
    def retrieve_node(self, node, simple=False):
        """Retrieve a node's name, attributes, text, comments and child count.

        The five sub-queries run in parallel via asyncio.gather.  With
        simple=True, cheaper placeholder/count versions are used instead of
        retrieving full content.  Returns an XMLNode describing *node*.
        """
        # Full-retrieval helpers.  Bug fix: the originals were plain
        # generators; decorate them like their "simple" counterparts so every
        # task handed to asyncio.Task is a proper coroutine.
        @asyncio.coroutine
        def attributes(self, node):
            attribute_count = yield from self.count_nodes(node.attributes)
            attributes_result = yield from self.get_attributes(node, attribute_count)
            return attributes_result

        @asyncio.coroutine
        def child_node_count(self, node):
            child_node_count_result = yield from self.count_nodes(node.children)
            return child_node_count_result

        @asyncio.coroutine
        def text(self, node):
            text_count = yield from self.count_nodes(node.text)
            text_result = yield from self.get_node_text(node, text_count)
            return text_result

        @asyncio.coroutine
        def comments(self, node):
            comment_count = yield from self.count_nodes(node.comments)
            comments_result = yield from self.get_comments(node, comment_count)
            return comments_result

        @asyncio.coroutine
        def node_name(self, node):
            node_name_result = yield from self.get_string(node.name)
            return node_name_result

        # "Simple" variants: report counts/placeholders rather than content.
        @asyncio.coroutine
        def simple_attributes(self, node):
            attribute_count = yield from self.count_nodes(node.attributes)
            return {"att%i" % i: "att%i_placeholder" % i for i in range(attribute_count)}

        @asyncio.coroutine
        def simple_text(self, node):
            text_count = yield from self.count_nodes(node.text)
            count_not_empty = 0

            for text in node.text(text_count):
                if not (yield from self.is_empty_string(text)):
                    count_not_empty += 1

            if count_not_empty > 0:
                return "%i text node found." % count_not_empty
            else:
                return ""

        @asyncio.coroutine
        def simple_comments(self, node):
            comment_count = yield from self.count_nodes(node.comments)
            comments = ["%i comments found." % comment_count] if comment_count > 0 else []
            return comments

        @asyncio.coroutine
        def simple_node_name(self, node):
            node_name_length = yield from self.string_length(node.name)

            # Short names are fetched whole; long names only by their ends.
            if node_name_length <= 6:
                node_name_result = yield from self.get_string(node.name)
                return node_name_result

            left_part = yield from self.get_string(substring(node.name, 0, 3))
            right_part = yield from self.get_string(substring(node.name, node_name_length - 3, node_name_length))
            remaining = node_name_length - 6
            # Bug fix: message typo "chracters" -> "characters".
            return "%s ... %i more characters  ... %s" % (left_part, remaining, right_part)

        # Bug fix: the original mixed tabs and spaces in this if/else, which
        # raises TabError under Python 3; indentation normalized to spaces.
        if simple:
            tasks = {
                "attributes": simple_attributes,
                "child_node_count": child_node_count,
                "text": simple_text,
                "comments": simple_comments,
                "node_name": simple_node_name
            }
        else:
            tasks = {
                "attributes": attributes,
                "child_node_count": child_node_count,
                "text": text,
                "comments": comments,
                "node_name": node_name
            }

        task_list = list(tasks.keys())

        # Launch every sub-query concurrently and gather results in order.
        futures = map(asyncio.Task, (tasks[task_name](self, node) for task_name in task_list))
        results = (yield from asyncio.gather(*futures))
        results = dict(zip(task_list, results))

        return XMLNode(
            name=results["node_name"],
            attributes=results["attributes"],
            comments=results["comments"],
            text=results["text"],
            child_count=results["child_node_count"],
            node=node,
            children=[]
        )

Example 21

Project: xcat
Source File: xcat.py
View license
@run.command(help="Let's you manually explore the XML file with a console.")
@click.pass_context
def console(ctx):
    """Interactive console for exploring the target XML document.

    Maintains a current XPath node and dispatches simple shell-like
    commands (ls, attr, cd, content, comment, name) against it via the
    injection executor stored on the click context.
    """
    click.echo("Opening console")

    current_node = "/*[1]"
    executor = ctx.obj["executor"]

    @asyncio.coroutine
    def command_attr(node, params):
        """Print every attribute name/value pair of *node*."""
        attribute_count = yield from executor.count_nodes(node.attributes)
        attributes_result = yield from executor.get_attributes(node, attribute_count)

        if attribute_count == 0:
            click.echo("No attributes found.")
        else:
            for name in attributes_result:
                # Skip attributes whose name could not be retrieved.
                if name != "":
                    click.echo("%s = %s" % (name, attributes_result[name]))

    @asyncio.coroutine
    def command_ls(node, params):
        """List the names of all child nodes, fetched in parallel."""
        child_node_count_result = yield from executor.count_nodes(node.children)
        click.echo("%i child node found." % child_node_count_result)

        futures = map(asyncio.Task, (executor.get_string(child.name) for child in node.children(child_node_count_result)))
        results = (yield from asyncio.gather(*futures))

        for result in results:
            click.echo(result)

    @asyncio.coroutine
    def command_cd(node, params):
        """Resolve the requested path; return it if the node exists, else None."""
        if len(params) < 1:
            click.echo("You must specify a node to navigate to.")
            return

        selected_node = params[0]

        # We consider anything that starts with a slash an absolute path;
        # "." and ".." behave like their filesystem counterparts.
        if selected_node[0] == "/":
            new_node = selected_node
        elif selected_node == "..":
            new_node = "/".join(current_node.split("/")[:-1])
        elif selected_node == ".":
            new_node = current_node
        else:
            new_node = current_node + "/" + selected_node

        if (yield from executor.is_empty_string(E(new_node).name)):
            click.echo("Node does not exists.")
        else:
            return new_node

    @asyncio.coroutine
    def command_content(node, params):
        """Print the text content of *node*."""
        text_count = yield from executor.count_nodes(node.text)
        click.echo((yield from executor.get_node_text(node, text_count)))

    @asyncio.coroutine
    def command_comment(node, params):
        """Print all comment nodes under *node*."""
        comment_count = yield from executor.count_nodes(node.comments)
        click.echo("%i comment node found." % comment_count)

        for comment in (yield from executor.get_comments(node, comment_count)):
            click.echo("<!-- %s -->" % comment)

    @asyncio.coroutine
    def command_name(node, params):
        """Print the name of *node*."""
        # Consistency fix: use the node passed in instead of rebuilding it
        # from current_node (they are the same node; see the dispatch below).
        node_name = yield from executor.get_string(node.name)
        click.echo(node_name)

    commands = {
        "ls": command_ls,
        "attr": command_attr,
        "cd": command_cd,
        "content": command_content,
        "comment": command_comment,
        "name": command_name
    }

    while True:
        command = click.prompt("%s : " % current_node, prompt_suffix="")
        command_parts = command.split(" ")
        # Renamed from command_name to avoid shadowing the helper above.
        cmd = command_parts[0]
        parameters = command_parts[1:]

        if cmd in commands:
            command_execution = commands[cmd](E(current_node), parameters)
            new_node = run_then_return(command_execution)

            # Only "cd" returns a value; every other command returns None.
            # Bug fix: identity comparison (is not None) instead of == None.
            if new_node is not None:
                current_node = new_node
        else:
            click.echo("Unknown command")
Example 22

View license
@asyncio.coroutine
def websocket_to_order_book():
    """Stream the Coinbase BTC-USD feed into the in-memory order book.

    Connects to the public websocket feed, buffers messages while the
    level-3 snapshot is fetched, then applies every subsequent message to
    the module-level order_book.  Returns False on any feed/parse failure.
    NOTE(review): relies on module globals (order_book,
    order_book_file_logger, args, open_orders) defined outside this chunk.
    """
    try:
        coinbase_websocket = yield from websockets.connect("wss://ws-feed.exchange.coinbase.com")
    except gaierror:
        order_book_file_logger.error('socket.gaierror - had a problem connecting to Coinbase feed')
        return

    yield from coinbase_websocket.send('{"type": "subscribe", "product_id": "BTC-USD"}')

    # Buffer some messages before requesting the snapshot, so messages newer
    # than the snapshot sequence are not lost.  Assumes 21 buffered messages
    # are enough to cover the snapshot boundary -- TODO confirm.
    messages = []
    while True:
        message = yield from coinbase_websocket.recv()
        message = json.loads(message)
        messages += [message]
        if len(messages) > 20:
            break

    order_book.get_level3()

    # Replay only the buffered messages newer than the level-3 snapshot.
    [order_book.process_message(message) for message in messages if message['sequence'] > order_book.level3_sequence]
    messages = []
    while True:
        message = yield from coinbase_websocket.recv()
        if message is None:
            order_book_file_logger.error('Websocket message is None.')
            return False
        try:
            message = json.loads(message)
        except TypeError:
            order_book_file_logger.error('JSON did not load, see ' + str(message))
            return False
        if args.command_line:
            # Keep a sliding one-minute window of arrival timestamps and
            # derive average/fastest/slowest inter-message intervals.
            messages += [datetime.strptime(message['time'], '%Y-%m-%dT%H:%M:%S.%fZ').replace(tzinfo=pytz.UTC)]
            messages = [message for message in messages if (datetime.now(tzlocal()) - message).seconds < 60]
            if len(messages) > 2:
                diff = numpy.diff(messages)
                diff = [float(sec.microseconds) for sec in diff]
                order_book.average_rate = numpy.mean(diff)
                order_book.fastest_rate = min(diff)
                order_book.slowest_rate = max(diff)
        if not order_book.process_message(message):
            print(pformat(message))
            return False
        if args.trading:
            # Mirror feed events onto our locally tracked open ask/bid
            # orders; a 'done' event clears the corresponding order state.
            if 'order_id' in message and message['order_id'] == open_orders.open_ask_order_id:
                if message['type'] == 'done':
                    open_orders.open_ask_order_id = None
                    open_orders.open_ask_price = None
                    open_orders.open_ask_status = None
                    open_orders.open_ask_rejections = Decimal('0.0')
                    open_orders.open_ask_cancelled = False
                else:
                    open_orders.open_ask_status = message['type']
            elif 'order_id' in message and message['order_id'] == open_orders.open_bid_order_id:
                if message['type'] == 'done':
                    open_orders.open_bid_order_id = None
                    open_orders.open_bid_price = None
                    open_orders.open_bid_status = None
                    open_orders.open_bid_rejections = Decimal('0.0')
                    open_orders.open_bid_cancelled = False
                else:
                    open_orders.open_bid_status = message['type']

Example 23

Project: asyncio
Source File: crawl.py
View license
    @asyncio.coroutine
    def fetch(self):
        """Attempt to fetch the contents of the URL.

        If successful, and the data is HTML, extract further links and
        add them to the crawler.  Redirects are also added back there.

        Retries up to self.max_tries times on BadStatusLine/OSError; if no
        attempt succeeds (while/else branch), logs the failure and returns.
        """
        while self.tries < self.max_tries:
            self.tries += 1
            self.request = None
            try:
                self.request = Request(self.log, self.url, self.crawler.pool)
                yield from self.request.connect()
                yield from self.request.send_request()
                self.response = yield from self.request.get_response()
                self.body = yield from self.response.read()
                # Recycle the connection into the pool unless the server
                # asked for it to be closed; clearing self.request keeps the
                # finally-block below from closing the recycled connection.
                h_conn = self.response.get_header('connection').lower()
                if h_conn != 'close':
                    self.request.close(recycle=True)
                    self.request = None
                if self.tries > 1:
                    self.log(1, 'try', self.tries, 'for', self.url, 'success')
                break
            except (BadStatusLine, OSError) as exc:
                self.exceptions.append(exc)
                self.log(1, 'try', self.tries, 'for', self.url,
                            'raised', repr(exc))
                ##import pdb; pdb.set_trace()
                # Don't reuse the connection in this case.
            finally:
                if self.request is not None:
                    self.request.close()
        else:
            # We never broke out of the while loop, i.e. all tries failed.
            self.log(0, 'no success for', self.url,
                        'in', self.max_tries, 'tries')
            return
        next_url = self.response.get_redirect_url()
        if next_url:
            # Follow the redirect by handing it back to the crawler with a
            # decremented redirect budget.
            self.next_url = urllib.parse.urljoin(self.url, next_url)
            if self.max_redirect > 0:
                self.log(1, 'redirect to', self.next_url, 'from', self.url)
                self.crawler.add_url(self.next_url, self.max_redirect-1)
            else:
                self.log(0, 'redirect limit reached for', self.next_url,
                            'from', self.url)
        else:
            if self.response.status == 200:
                self.ctype = self.response.get_header('content-type')
                self.pdict = {}
                if self.ctype:
                    self.ctype, self.pdict = cgi.parse_header(self.ctype)
                self.encoding = self.pdict.get('charset', 'utf-8')
                if self.ctype == 'text/html':
                    body = self.body.decode(self.encoding, 'replace')
                    # Replace href with (?:href|src) to follow image links.
                    self.urls = set(re.findall(r'(?i)href=["\']?([^\s"\'<>]+)',
                                               body))
                    if self.urls:
                        self.log(1, 'got', len(self.urls),
                                    'distinct urls from', self.url)
                    # Resolve, defragment and register each discovered link.
                    self.new_urls = set()
                    for url in self.urls:
                        url = unescape(url)
                        url = urllib.parse.urljoin(self.url, url)
                        url, frag = urllib.parse.urldefrag(url)
                        if self.crawler.add_url(url):
                            self.new_urls.add(url)

Example 24

Project: asyncio
Source File: test_streams.py
View license
    def test_start_server(self):
        """asyncio.start_server works with both a coroutine and a plain
        callback as client handler; verify a one-line echo round trip."""

        class MyServer:
            # Minimal echo server owning its asyncio.Server instance.

            def __init__(self, loop):
                self.server = None
                self.loop = loop

            @asyncio.coroutine
            def handle_client(self, client_reader, client_writer):
                # Echo one line back to the client, then close.
                data = yield from client_reader.readline()
                client_writer.write(data)
                yield from client_writer.drain()
                client_writer.close()

            def start(self):
                # Bind an ephemeral port and hand the ready socket to
                # start_server via sock=.
                sock = socket.socket()
                sock.bind(('127.0.0.1', 0))
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client,
                                         sock=sock,
                                         loop=self.loop))
                return sock.getsockname()

            def handle_client_callback(self, client_reader, client_writer):
                # Plain-callback variant: schedule the coroutine handler.
                self.loop.create_task(self.handle_client(client_reader,
                                                         client_writer))

            def start_callback(self):
                # Grab a free port number first, then start by host/port.
                sock = socket.socket()
                sock.bind(('127.0.0.1', 0))
                addr = sock.getsockname()
                sock.close()
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client_callback,
                                         host=addr[0], port=addr[1],
                                         loop=self.loop))
                return addr

            def stop(self):
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        @asyncio.coroutine
        def client(addr):
            # Send one line and return whatever the server echoes back.
            reader, writer = yield from asyncio.open_connection(
                *addr, loop=self.loop)
            # send a line
            writer.write(b"hello world!\n")
            # read it back
            msgback = yield from reader.readline()
            writer.close()
            return msgback

        # test the server variant with a coroutine as client handler
        server = MyServer(self.loop)
        addr = server.start()
        msg = self.loop.run_until_complete(asyncio.Task(client(addr),
                                                        loop=self.loop))
        server.stop()
        self.assertEqual(msg, b"hello world!\n")

        # test the server variant with a callback as client handler
        server = MyServer(self.loop)
        addr = server.start_callback()
        msg = self.loop.run_until_complete(asyncio.Task(client(addr),
                                                        loop=self.loop))
        server.stop()
        self.assertEqual(msg, b"hello world!\n")

Example 25

Project: asyncio
Source File: test_streams.py
View license
    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
    def test_start_unix_server(self):
        """Same echo round trip as test_start_server, but over a UNIX
        domain socket via asyncio.start_unix_server."""

        class MyServer:
            # Minimal echo server bound to a UNIX socket path.

            def __init__(self, loop, path):
                self.server = None
                self.loop = loop
                self.path = path

            @asyncio.coroutine
            def handle_client(self, client_reader, client_writer):
                # Echo one line back to the client, then close.
                data = yield from client_reader.readline()
                client_writer.write(data)
                yield from client_writer.drain()
                client_writer.close()

            def start(self):
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client,
                                              path=self.path,
                                              loop=self.loop))

            def handle_client_callback(self, client_reader, client_writer):
                # Plain-callback variant: schedule the coroutine handler.
                self.loop.create_task(self.handle_client(client_reader,
                                                         client_writer))

            def start_callback(self):
                start = asyncio.start_unix_server(self.handle_client_callback,
                                                  path=self.path,
                                                  loop=self.loop)
                self.server = self.loop.run_until_complete(start)

            def stop(self):
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        @asyncio.coroutine
        def client(path):
            # Send one line and return whatever the server echoes back.
            reader, writer = yield from asyncio.open_unix_connection(
                path, loop=self.loop)
            # send a line
            writer.write(b"hello world!\n")
            # read it back
            msgback = yield from reader.readline()
            writer.close()
            return msgback

        # test the server variant with a coroutine as client handler
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            server.start()
            msg = self.loop.run_until_complete(asyncio.Task(client(path),
                                                            loop=self.loop))
            server.stop()
            self.assertEqual(msg, b"hello world!\n")

        # test the server variant with a callback as client handler
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            server.start_callback()
            msg = self.loop.run_until_complete(asyncio.Task(client(path),
                                                            loop=self.loop))
            server.stop()
            self.assertEqual(msg, b"hello world!\n")

Example 26

Project: hyper-h2
Source File: wsgi-server.py
View license
    @asyncio.coroutine
    def sending_loop(self):
        """
        A call that loops forever, attempting to send data. This sending loop
        contains most of the flow-control smarts of this class: it pulls data
        off of the asyncio queue and then attempts to send it.

        The difficulties here are all around flow control. Specifically, a
        chunk of data may be too large to send. In this case, what will happen
        is that this coroutine will attempt to send what it can and will then
        store the unsent data locally. When a flow control event comes in that
        data will be freed up and placed back onto the asyncio queue, causing
        it to pop back up into the sending logic of this coroutine.

        This method explicitly *does not* handle HTTP/2 priority. That adds an
        extra layer of complexity to what is already a fairly complex method,
        and we'll look at how to do it another time.

        This coroutine explicitly *does not end*.
        """
        while True:
            stream_id, data, event = yield from self._stream_data.get()

            # If this stream got reset, just drop the data on the floor. Note
            # that we need to reset the event here to make sure that
            # application doesn't lock up.
            if stream_id in self._reset_streams:
                event.set()
                # Bug fix: without this `continue` the chunk was still sent
                # on the reset stream below, contradicting the comment above
                # and risking a protocol error on a closed stream.
                continue

            # Check if the body is done. If it is, this is really easy! Again,
            # we *must* set the event here or the application will lock up.
            if data is END_DATA_SENTINEL:
                self.conn.end_stream(stream_id)
                self.transport.write(self.conn.data_to_send())
                event.set()
                continue

            # We need to send data, but not to exceed the flow control window.
            # For that reason, grab only the data that fits: we'll buffer the
            # rest.
            window_size = self.conn.local_flow_control_window(stream_id)
            chunk_size = min(window_size, len(data))
            data_to_send = data[:chunk_size]
            data_to_buffer = data[chunk_size:]

            if data_to_send:
                # There's a maximum frame size we have to respect. Because we
                # aren't paying any attention to priority here, we can quite
                # safely just split this string up into chunks of max frame
                # size and blast them out.
                #
                # In a *real* application you'd want to consider priority here.
                max_size = self.conn.max_outbound_frame_size
                chunks = (
                    data_to_send[x:x+max_size]
                    for x in range(0, len(data_to_send), max_size)
                )
                for chunk in chunks:
                    self.conn.send_data(stream_id, chunk)
                self.transport.write(self.conn.data_to_send())

            # If there's data left to buffer, we should do that. Put it in a
            # dictionary and *don't set the event*: the app must not generate
            # any more data until we got rid of all of this data.
            if data_to_buffer:
                self._flow_controlled_data[stream_id] = (
                    stream_id, data_to_buffer, event
                )
            else:
                # We sent everything. We can let the WSGI app progress.
                event.set()
Example 27

Project: django-c10k-demo
Source File: client.py
View license
@asyncio.coroutine
def run(row, col, size, wrap, speed, steps=None, state=None):
    """Simulate one Game of Life cell at grid position (row, col).

    Connects to the game server over a WebSocket, subscribes to its
    neighbors' state updates, then repeatedly publishes its own state and
    collects neighbor states, one generation ("step") at a time.

    row, col -- this cell's grid coordinates
    size     -- side length of the (size x size) grid
    wrap     -- whether neighbor lookup wraps around the grid edges
    speed    -- target number of steps per second (used for throttling)
    steps    -- stop after this many steps; None means run forever
    state    -- initial alive/dead state; None picks randomly (~25% alive)
    """

    if state is None:
        state = random.choice((True, False, False, False))

    # Map each neighbor's (row, col) to a stable index into the state
    # buffers below.
    neighbors = get_neighbors(row, col, size, wrap)
    neighbors = {n: i for i, n in enumerate(neighbors)}
    n = len(neighbors)

    # Throttle at 100 connections / second on average
    yield from asyncio.sleep(size * size / 100 * random.random())
    ws = yield from websockets.connect(BASE_URL + '/worker/')

    # Wait until all clients are connected.
    msg = yield from ws.recv()
    if msg != 'sub':
        raise Exception("Unexpected message: {}".format(msg))

    # Subscribe to updates sent by neighbors.
    for neighbor in neighbors:
        yield from ws.send('{} {}'.format(*neighbor))
    yield from ws.send('sub')

    # Wait until all clients are subscribed.
    msg = yield from ws.recv()
    if msg != 'run':
        raise Exception("Unexpected message: {}".format(msg))

    # Publish our initial state (step 0) as "step row col state".
    yield from ws.send('{} {} {} {}'.format(0, row, col, int(state)))

    # This is the step for which we last sent our state, and for which we're
    # collecting the states of our neighbors.
    step = 0
    # Once we know all our neighbors' states at step N - 1, we compute and
    # send our state at step N. At this point, our neighbors can send their
    # states at steps N and N + 1, but not N + 2, since that requires our
    # state at step N + 1. We only need to keep track of two sets of states.
    states = [[None] * n, [None] * n]

    # Gather state updates from neighbors and send our own state updates.
    while (steps is None or step < steps):
        msg = yield from ws.recv()
        # NOTE(review): older websockets versions returned None on a closed
        # connection instead of raising — this check relies on that.
        if msg is None:
            break
        _step, _row, _col, _state = (int(x) for x in msg.split())
        # Alternate between the two buffers based on step parity.
        target = _step % 2
        states[target][neighbors[(_row, _col)]] = bool(_state)
        # Compute next state
        if None not in states[target]:
            assert _step == step
            step += 1
            # Standard Conway rules: a cell is alive with exactly 3 live
            # neighbors, or 2 if it was already alive.
            alive = states[target].count(True)
            state = alive == 3 or (state and alive == 2)
            states[target] = [None] * n
            yield from ws.send('{} {} {} {}'.format(step, row, col, int(state)))
            # Throttle, speed is a number of steps per second
            yield from asyncio.sleep(1 / speed)

    yield from ws.close()

Example 28

Project: django-userlog
Source File: realtime.py
View license
@asyncio.coroutine
def userlog(websocket, uri):
    """Stream a user's log lines to a connected WebSocket client.

    The client first sends an authentication token, which is resolved to a
    username through Redis; unknown tokens end the connection silently.
    For a single user, the stored backlog is replayed before new lines are
    streamed from the pub/sub channel.  A username ending in '*'
    (presumably a pattern covering several users — confirm with the token
    issuer) skips the backlog and streams all matching channels, tagging
    each line with the originating username.
    """
    token = yield from websocket.recv()

    redis = yield from redis_connection()

    token_key = 'token:{}'.format(token)

    # Access control
    username = yield from redis.get(token_key)
    if username is None:
        return

    log_key = 'log:{}'.format(username)
    channel = 'userlog:{}'.format(log_key)

    try:
        if channel.endswith('*'):       # logs for several users
            # Stream new lines
            subscriber = yield from redis.start_subscribe()
            yield from subscriber.psubscribe([channel])
            while True:
                reply = yield from subscriber.next_published()
                data = json.loads(reply.value)
                # Channel name is 'userlog:log:<username>'; recover the
                # username from the component after the last ':'.
                data['username'] = reply.channel.rpartition(':')[2]
                line = json.dumps(data)
                try:
                    yield from websocket.send(line)
                except websockets.ConnectionClosed:
                    return

        else:                           # logs for a single user
            # Send backlog, oldest entry first.
            log = yield from redis.lrange(log_key, 0, -1)
            for item in reversed(list(log)):
                # Each item is awaited: presumably the redis client yields
                # one future per list element — TODO confirm against the
                # client library in use.
                line = yield from item
                try:
                    yield from websocket.send(line)
                except websockets.ConnectionClosed:
                    return

            # Stream new lines
            subscriber = yield from redis.start_subscribe()
            yield from subscriber.subscribe([channel])
            while True:
                reply = yield from subscriber.next_published()
                line = reply.value
                try:
                    yield from websocket.send(line)
                except websockets.ConnectionClosed:
                    return

    finally:
        redis.close()
        # Loop one more time to complete the cancellation of redis._reader_f,
        # which runs redis._reader_coroutine(), after redis.connection_lost().
        yield from asyncio.sleep(0)

Example 29

Project: websockets
Source File: client.py
View license
    @asyncio.coroutine
    def handshake(self, wsuri,
                  origin=None, subprotocols=None, extra_headers=None):
        """
        Perform the client side of the opening handshake.

        If provided, ``origin`` sets the Origin HTTP header.

        If provided, ``subprotocols`` is a list of supported subprotocols in
        order of decreasing preference.

        If provided, ``extra_headers`` sets additional HTTP request headers.
        It must be a mapping or an iterable of (name, value) pairs.

        Raises :exc:`InvalidHandshake` if the response is malformed, has a
        status code other than 101, fails the key check, or selects a
        subprotocol that was not offered.

        """
        headers = []
        set_header = lambda k, v: headers.append((k, v))
        # Only add an explicit port to the Host header when it differs from
        # the scheme's default port.
        if wsuri.port == (443 if wsuri.secure else 80):     # pragma: no cover
            set_header('Host', wsuri.host)
        else:
            set_header('Host', '{}:{}'.format(wsuri.host, wsuri.port))
        if origin is not None:
            set_header('Origin', origin)
        if subprotocols is not None:
            set_header('Sec-WebSocket-Protocol', ', '.join(subprotocols))
        if extra_headers is not None:
            if isinstance(extra_headers, collections.abc.Mapping):
                extra_headers = extra_headers.items()
            for name, value in extra_headers:
                set_header(name, value)
        set_header('User-Agent', USER_AGENT)
        # build_request adds the standard WebSocket upgrade headers and
        # returns the nonce validated against Sec-WebSocket-Accept below.
        key = build_request(set_header)

        self.request_headers = email.message.Message()
        for name, value in headers:
            self.request_headers[name] = value
        self.raw_request_headers = headers

        # Send handshake request. Since the URI and the headers only contain
        # ASCII characters, we can keep this simple.
        request = ['GET %s HTTP/1.1' % wsuri.resource_name]
        request.extend('{}: {}'.format(k, v) for k, v in headers)
        request.append('\r\n')
        request = '\r\n'.join(request).encode()
        self.writer.write(request)

        # Read handshake response.
        try:
            status_code, headers = yield from read_response(self.reader)
        except ValueError as exc:
            raise InvalidHandshake("Malformed HTTP message") from exc
        if status_code != 101:
            raise InvalidHandshake("Bad status code: {}".format(status_code))

        self.response_headers = headers
        self.raw_response_headers = list(headers.raw_items())

        get_header = lambda k: headers.get(k, '')
        check_response(get_header, key)

        self.subprotocol = headers.get('Sec-WebSocket-Protocol', None)
        if self.subprotocol is not None:
            # Bug fix: when the server selects a subprotocol although none
            # were offered, ``subprotocols`` is None and the old membership
            # test (``x not in None``) raised TypeError instead of the
            # intended InvalidHandshake.  Per RFC 6455, a server must only
            # echo a subprotocol the client offered.
            if subprotocols is None or self.subprotocol not in subprotocols:
                raise InvalidHandshake(
                    "Unknown subprotocol: {}".format(self.subprotocol))

        assert self.state == CONNECTING
        self.state = OPEN
        self.opening_handshake.set_result(True)

Example 30

Project: websockets
Source File: server.py
View license
    @asyncio.coroutine
    def handshake(self, origins=None, subprotocols=None, extra_headers=None):
        """
        Perform the server side of the opening handshake.

        If provided, ``origins`` is a list of acceptable HTTP Origin values.
        Include ``''`` if the lack of an origin is acceptable.

        If provided, ``subprotocols`` is a list of supported subprotocols in
        order of decreasing preference.

        If provided, ``extra_headers`` sets additional HTTP response headers.
        It can be a mapping or an iterable of (name, value) pairs. It can also
        be a callable taking the request path and headers in arguments.

        Raises ``InvalidHandshake`` for a malformed request and
        ``InvalidOrigin`` when the Origin check fails.

        Return the URI of the request.

        """
        # Read handshake request.
        try:
            path, headers = yield from read_request(self.reader)
        except ValueError as exc:
            raise InvalidHandshake("Malformed HTTP message") from exc

        self.request_headers = headers
        self.raw_request_headers = list(headers.raw_items())

        get_header = lambda k: headers.get(k, '')
        # check_request validates the upgrade headers; its return value is
        # fed to build_response below (presumably the Sec-WebSocket-Key —
        # confirm in the handshake module).
        key = check_request(get_header)

        if origins is not None:
            origin = get_header('Origin')
            # A missing Origin header yields [''], so including '' in
            # ``origins`` accepts clients that send no origin.
            if not set(origin.split() or ['']) <= set(origins):
                raise InvalidOrigin("Origin not allowed: {}".format(origin))

        if subprotocols is not None:
            protocol = get_header('Sec-WebSocket-Protocol')
            if protocol:
                client_subprotocols = [p.strip() for p in protocol.split(',')]
                self.subprotocol = self.select_subprotocol(
                    client_subprotocols, subprotocols)

        # From here on, ``headers`` holds the *response* headers.
        headers = []
        set_header = lambda k, v: headers.append((k, v))
        set_header('Server', USER_AGENT)
        if self.subprotocol:
            set_header('Sec-WebSocket-Protocol', self.subprotocol)
        if extra_headers is not None:
            if callable(extra_headers):
                extra_headers = extra_headers(path, self.raw_request_headers)
            if isinstance(extra_headers, collections.abc.Mapping):
                extra_headers = extra_headers.items()
            for name, value in extra_headers:
                set_header(name, value)
        build_response(set_header, key)

        self.response_headers = email.message.Message()
        for name, value in headers:
            self.response_headers[name] = value
        self.raw_response_headers = headers

        # Send handshake response. Since the status line and headers only
        # contain ASCII characters, we can keep this simple.
        response = ['HTTP/1.1 101 Switching Protocols']
        response.extend('{}: {}'.format(k, v) for k, v in headers)
        response.append('\r\n')
        response = '\r\n'.join(response).encode()
        self.writer.write(response)

        assert self.state == CONNECTING
        self.state = OPEN
        self.opening_handshake.set_result(True)

        return path

Example 31

Project: websockets
Source File: server.py
View license
@asyncio.coroutine
def serve(ws_handler, host=None, port=None, *,
          klass=WebSocketServerProtocol,
          timeout=10, max_size=2 ** 20, max_queue=2 ** 5,
          loop=None, legacy_recv=False,
          origins=None, subprotocols=None, extra_headers=None,
          **kwds):
    """
    Create a WebSocket server.

    Yields a :class:`~asyncio.Server` wrapper that provides:

    * ``close()`` — close open connections with status code 1001 and stop
      accepting new connections
    * ``wait_closed()`` — a coroutine that waits until closing handshakes
      complete and connections are closed.

    ``ws_handler`` must be a coroutine accepting two arguments: a
    :class:`WebSocketServerProtocol` and the request URI.

    ``host``, ``port`` and any extra keyword arguments are forwarded to the
    event loop's :meth:`~asyncio.BaseEventLoop.create_server` — for example,
    pass an :class:`~ssl.SSLContext` via ``ssl`` to enable TLS.

    ``timeout``, ``max_size`` and ``max_queue`` behave as documented on
    :class:`~websockets.protocol.WebSocketCommonProtocol`.

    ``origins`` lists acceptable Origin HTTP headers (include ``''`` to
    accept a missing origin); ``subprotocols`` lists supported subprotocols
    in order of decreasing preference; ``extra_headers`` sets additional
    HTTP response headers (a mapping, an iterable of pairs, or a callable
    taking the request path and headers).

    For every client, the server accepts the connection, performs the
    opening handshake, delegates to the handler, then performs the closing
    handshake.  Handler exceptions cannot usefully propagate, so they are
    logged to the ``'websockets.server'`` logger; configure logging to see
    them::

        import logging
        logger = logging.getLogger('websockets.server')
        logger.setLevel(logging.ERROR)
        logger.addHandler(logging.StreamHandler())

    """
    event_loop = asyncio.get_event_loop() if loop is None else loop

    ws_server = WebSocketServer(event_loop)

    # TLS is in effect exactly when create_server() receives an SSL context.
    secure = kwds.get('ssl') is not None

    def protocol_factory():
        # One protocol instance per incoming connection.
        return klass(
            ws_handler, ws_server,
            host=host, port=port, secure=secure,
            timeout=timeout, max_size=max_size, max_queue=max_queue,
            loop=event_loop, legacy_recv=legacy_recv,
            origins=origins, subprotocols=subprotocols,
            extra_headers=extra_headers,
        )

    server = yield from event_loop.create_server(
        protocol_factory, host, port, **kwds)
    ws_server.wrap(server)

    return ws_server

Example 32

Project: discord.py
Source File: gateway.py
View license
    @asyncio.coroutine
    def received_message(self, msg):
        """Process one message received on the Discord gateway WebSocket.

        Decompresses/decodes the raw payload if needed, handles the
        protocol-level opcodes (RECONNECT, HEARTBEAT / HEARTBEAT_ACK,
        HELLO, INVALIDATE_SESSION) and routes DISPATCH events both to the
        connection state's ``parse_*`` handlers and to any registered
        wait-for listeners.

        Raises ReconnectWebSocket or ResumeWebSocket to signal the owning
        client that the connection must be re-established.
        """
        self._dispatch('socket_raw_receive', msg)

        if isinstance(msg, bytes):
            msg = zlib.decompress(msg, 15, 10490000) # ~10 MiB initial buffer
            msg = msg.decode('utf-8')

        msg = json.loads(msg)
        state = self._connection

        log.debug('WebSocket Event: {}'.format(msg))
        self._dispatch('socket_response', msg)

        op = msg.get('op')
        data = msg.get('d')

        # Track the most recent sequence number; it is used by heartbeats
        # and session resumption.
        if 's' in msg:
            state.sequence = msg['s']

        if op == self.RECONNECT:
            # "reconnect" can only be handled by the Client
            # so we terminate our connection and raise an
            # internal exception signalling to reconnect.
            log.info('Received RECONNECT opcode.')
            yield from self.close()
            raise ReconnectWebSocket()

        if op == self.HEARTBEAT_ACK:
            return # disable noisy logging for now

        if op == self.HEARTBEAT:
            # Server asked for an immediate heartbeat.
            beat = self._keep_alive.get_payload()
            yield from self.send_as_json(beat)
            return

        if op == self.HELLO:
            # Interval arrives in milliseconds; KeepAliveHandler wants seconds.
            interval = data['heartbeat_interval'] / 1000.0
            self._keep_alive = KeepAliveHandler(ws=self, interval=interval)
            self._keep_alive.start()
            return

        if op == self.INVALIDATE_SESSION:
            state.sequence = None
            state.session_id = None
            # d == True indicates the session can be resumed.
            if data == True:
                yield from self.close()
                raise ResumeWebSocket()

            yield from self.identify()
            return

        if op != self.DISPATCH:
            log.info('Unhandled op {}'.format(op))
            return

        event = msg.get('t')
        is_ready = event == 'READY'

        if is_ready:
            state.clear()
            state.sequence = msg['s']
            state.session_id = data['session_id']

        parser = 'parse_' + event.lower()

        try:
            func = getattr(self._connection, parser)
        except AttributeError:
            log.info('Unhandled event {}'.format(event))
        else:
            func(data)

        # remove the dispatched listeners
        removed = []
        for index, entry in enumerate(self._dispatch_listeners):
            if entry.event != event:
                continue

            future = entry.future
            # NOTE(review): a cancelled future is marked for removal but the
            # predicate still runs below; set_exception/set_result on a
            # cancelled future would raise InvalidStateError — confirm.
            if future.cancelled():
                removed.append(index)

            try:
                valid = entry.predicate(data)
            except Exception as e:
                future.set_exception(e)
                removed.append(index)
            else:
                if valid:
                    ret = data if entry.result is None else entry.result(data)
                    future.set_result(ret)
                    removed.append(index)

        # Delete in reverse so earlier indices stay valid.
        for index in reversed(removed):
            del self._dispatch_listeners[index]

Example 33

Project: discord.py
Source File: http.py
View license
    @asyncio.coroutine
    def request(self, method, url, *, bucket=None, **kwargs):
        """Perform an HTTP request, serialized per rate-limit bucket.

        ``method`` and ``url`` are passed straight to the aiohttp session.
        Requests sharing a ``bucket`` run under a common
        :class:`asyncio.Lock`, so only one is in flight at a time.  A
        ``json`` kwarg is serialized into the request body with the
        matching Content-Type; all other kwargs are forwarded unchanged.

        Returns the decoded JSON (or raw text) body for 2xx responses.
        Retries on 429 (after the server-advertised ``retry_after`` delay)
        and on 502, up to 5 attempts total; raises Forbidden (403),
        NotFound (404) or HTTPException (other error statuses).

        NOTE(review): if all 5 attempts are consumed by 429/502 retries,
        the loop falls through and this coroutine implicitly returns None.
        """
        lock = self._locks.get(bucket)
        if lock is None:
            lock = asyncio.Lock(loop=self.loop)
            if bucket is not None:
                self._locks[bucket] = lock

        # header creation
        headers = {
            'User-Agent': self.user_agent,
        }

        if self.token is not None:
            headers['Authorization'] = 'Bot ' + self.token if self.bot_token else self.token

        # some checking if it's a JSON request
        if 'json' in kwargs:
            headers['Content-Type'] = 'application/json'
            kwargs['data'] = utils.to_json(kwargs.pop('json'))

        kwargs['headers'] = headers
        with (yield from lock):
            for tries in range(5):
                r = yield from self.session.request(method, url, **kwargs)
                log.debug(self.REQUEST_LOG.format(method=method, url=url, status=r.status, json=kwargs.get('data')))
                try:
                    # even errors have text involved in them so this is safe to call
                    data = yield from json_or_text(r)

                    # the request was successful so just return the text/json
                    if 300 > r.status >= 200:
                        log.debug(self.SUCCESS_LOG.format(method=method, url=url, text=data))
                        return data

                    # we are being rate limited
                    if r.status == 429:
                        fmt = 'We are being rate limited. Retrying in {:.2} seconds. Handled under the bucket "{}"'

                        # sleep a bit; the server reports retry_after in ms
                        retry_after = data['retry_after'] / 1000.0
                        log.info(fmt.format(retry_after, bucket))
                        yield from asyncio.sleep(retry_after)
                        continue

                    # we've received a 502, unconditional retry
                    # NOTE(review): tries <= 5 is always true inside
                    # range(5), so every 502 is retried on all attempts.
                    if r.status == 502 and tries <= 5:
                        yield from asyncio.sleep(1 + tries * 2)
                        continue

                    # the usual error cases
                    if r.status == 403:
                        raise Forbidden(r, data)
                    elif r.status == 404:
                        raise NotFound(r, data)
                    else:
                        raise HTTPException(r, data)
                finally:
                    # clean-up just in case
                    yield from r.release()

Example 34

Project: discord.py
Source File: voice_client.py
View license
    @asyncio.coroutine
    def create_ytdl_player(self, url, *, ytdl_options=None, **kwargs):
        """|coro|

        Creates a stream player for youtube or other services that launches
        in a separate thread to play the audio.

        The player uses the ``youtube_dl`` python library to get the information
        required to get audio from the URL. Since this uses an external library,
        you must install it yourself. You can do so by calling
        ``pip install youtube_dl``.

        You must have the ffmpeg or avconv executable in your path environment
        variable in order for this to work.

        The operations that can be done on the player are the same as those in
        :meth:`create_stream_player`. The player has been augmented and enhanced
        to have some info extracted from the URL. If youtube-dl fails to extract
        the information then the attribute is ``None``. The ``yt``, ``url``, and
        ``download_url`` attributes are always available.

        +---------------------+---------------------------------------------------------+
        |      Operation      |                       Description                       |
        +=====================+=========================================================+
        | player.yt           | The `YoutubeDL <ytdl>` instance.                        |
        +---------------------+---------------------------------------------------------+
        | player.url          | The URL that is currently playing.                      |
        +---------------------+---------------------------------------------------------+
        | player.download_url | The URL that is currently being downloaded to ffmpeg.   |
        +---------------------+---------------------------------------------------------+
        | player.title        | The title of the audio stream.                          |
        +---------------------+---------------------------------------------------------+
        | player.description  | The description of the audio stream.                    |
        +---------------------+---------------------------------------------------------+
        | player.uploader     | The uploader of the audio stream.                       |
        +---------------------+---------------------------------------------------------+
        | player.upload_date  | A datetime.date object of when the stream was uploaded. |
        +---------------------+---------------------------------------------------------+
        | player.duration     | The duration of the audio in seconds.                   |
        +---------------------+---------------------------------------------------------+
        | player.likes        | How many likes the audio stream has.                    |
        +---------------------+---------------------------------------------------------+
        | player.dislikes     | How many dislikes the audio stream has.                 |
        +---------------------+---------------------------------------------------------+
        | player.is_live      | Checks if the audio stream is currently livestreaming.  |
        +---------------------+---------------------------------------------------------+
        | player.views        | How many views the audio stream has.                    |
        +---------------------+---------------------------------------------------------+

        .. _ytdl: https://github.com/rg3/youtube-dl/blob/master/youtube_dl/YoutubeDL.py#L128-L278

        Examples
        ----------

        Basic usage: ::

            voice = await client.join_voice_channel(channel)
            player = await voice.create_ytdl_player('https://www.youtube.com/watch?v=d62TYemN6MQ')
            player.start()

        Parameters
        -----------
        url : str
            The URL that ``youtube_dl`` will take and download audio to pass
            to ``ffmpeg`` or ``avconv`` to convert to PCM bytes.
        ytdl_options : dict
            A dictionary of options to pass into the ``YoutubeDL`` instance.
            See `the documentation <ytdl>`_ for more details.
        \*\*kwargs
            The rest of the keyword arguments are forwarded to
            :func:`create_ffmpeg_player`.

        Raises
        -------
        ClientException
            Popen failure from either ``ffmpeg``/``avconv``.

        Returns
        --------
        StreamPlayer
            An augmented StreamPlayer that uses ffmpeg.
            See :meth:`create_stream_player` for base operations.
        """
        import youtube_dl

        use_avconv = kwargs.get('use_avconv', False)
        opts = {
            'format': 'webm[abr>0]/bestaudio/best',
            'prefer_ffmpeg': not use_avconv
        }

        if ytdl_options is not None and isinstance(ytdl_options, dict):
            opts.update(ytdl_options)

        ydl = youtube_dl.YoutubeDL(opts)
        # extract_info is blocking; run it in an executor to keep the loop free.
        func = functools.partial(ydl.extract_info, url, download=False)
        info = yield from self.loop.run_in_executor(None, func)
        if "entries" in info:
            # Playlists return a list of entries; play the first one.
            info = info['entries'][0]

        log.info('playing URL {}'.format(url))
        download_url = info['url']
        player = self.create_ffmpeg_player(download_url, **kwargs)

        # set the dynamic attributes from the info extraction
        player.download_url = download_url
        player.url = url
        player.yt = ydl
        player.views = info.get('view_count')
        player.is_live = bool(info.get('is_live'))
        player.likes = info.get('like_count')
        player.dislikes = info.get('dislike_count')
        player.duration = info.get('duration')
        player.uploader = info.get('uploader')

        is_twitch = 'twitch' in url
        if is_twitch:
            # twitch has 'title' and 'description' sort of mixed up.
            player.title = info.get('description')
            player.description = None
        else:
            player.title = info.get('title')
            player.description = info.get('description')

        # upload date handling
        date = info.get('upload_date')
        if date:
            try:
                # Bug fix: youtube-dl reports upload_date as YYYYMMDD. The
                # previous format string used '%Y%M%d' — %M is *minutes*,
                # not month — so parsed dates always came back with
                # month=January. '%Y%m%d' parses the field correctly.
                date = datetime.datetime.strptime(date, '%Y%m%d').date()
            except ValueError:
                date = None

        player.upload_date = date
        return player

Example 35

Project: swift-nbd-server
Source File: server.py
View license
    @asyncio.coroutine
    def handler(self, reader, writer):
        """Handle one NBD client connection.

        Runs the new-style NBD handshake (fixed and non-fixed), the option
        negotiation phase (EXPORTNAME / LIST / ABORT) and then the data
        phase (WRITE / READ / FLUSH / DISC) against the negotiated store.
        The store is locked for the duration of the session; the lock is
        released and the socket closed on the way out.

        NOTE(review): if get_extra_info("peername") fails on the first
        line, ``host``/``port``/``store`` are unbound and the except /
        finally clauses below would raise NameError — confirm this cannot
        happen in practice.
        """
        try:
            host, port = writer.get_extra_info("peername")
            store, container = None, None
            self.log.info("Incoming connection from %s:%s" % (host,port))

            # initial handshake: magic, server handshake constant and flags
            writer.write(b"NBDMAGIC" + struct.pack(">QH", self.NBD_HANDSHAKE, self.NBD_HANDSHAKE_FLAGS))
            yield from writer.drain()

            data = yield from reader.readexactly(4)
            try:
                client_flag = struct.unpack(">L", data)[0]
            except struct.error:
                raise IOError("Handshake failed, disconnecting")

            # we support both fixed and unfixed new-style handshake
            if client_flag == 0:
                fixed = False
                self.log.warning("Client using new-style non-fixed handshake")
            elif client_flag & 1:
                fixed = True
            else:
                raise IOError("Handshake failed, disconnecting")

            # negotiation phase: each option is (magic, opt, length[, data])
            while True:
                header = yield from reader.readexactly(16)
                try:
                    (magic, opt, length) = struct.unpack(">QLL", header)
                except struct.error as ex:
                    raise IOError("Negotiation failed: Invalid request, disconnecting")

                if magic != self.NBD_HANDSHAKE:
                    raise IOError("Negotiation failed: bad magic number: %s" % magic)

                if length:
                    data = yield from reader.readexactly(length)
                    if(len(data) != length):
                        raise IOError("Negotiation failed: %s bytes expected" % length)
                else:
                    data = None

                self.log.debug("[%s:%s]: opt=%s, len=%s, data=%s" % (host, port, opt, length, data))

                if opt == self.NBD_OPT_EXPORTNAME:
                    if not data:
                        raise IOError("Negotiation failed: no export name was provided")

                    data = data.decode("utf-8")
                    if data not in self.stores:
                        # Non-fixed clients cannot be told "unsupported", so
                        # the only option is to drop the connection.
                        if not fixed:
                            raise IOError("Negotiation failed: unknown export name")

                        writer.write(struct.pack(">QLLL", self.NBD_REPLY, opt, self.NBD_REP_ERR_UNSUP, 0))
                        yield from writer.drain()
                        continue

                    # we have negotiated a store and it will be used
                    # until the client disconnects
                    store = self.stores[data]
                    store.lock("%s:%s" % (host, port))

                    self.log.info("[%s:%s] Negotiated export: %s" % (host, port, store.container))

                    export_flags = self.NBD_EXPORT_FLAGS
                    if store.read_only:
                        # XOR toggles the read-only bit; assumes the bit is
                        # clear in NBD_EXPORT_FLAGS — confirm.
                        export_flags ^= self.NBD_RO_FLAG
                        self.log.info("[%s:%s] %s is read only" % (host, port, store.container))
                    writer.write(struct.pack('>QH', store.size, export_flags))
                    # 124 zero bytes of reserved padding per the NBD protocol
                    writer.write(b"\x00"*124)
                    yield from writer.drain()

                    break

                elif opt == self.NBD_OPT_LIST:
                    # One NBD_REP_SERVER reply per export, then an ACK.
                    for container in self.stores.keys():
                        writer.write(struct.pack(">QLLL", self.NBD_REPLY, opt, self.NBD_REP_SERVER, len(container) + 4))
                        container_encoded = container.encode("utf-8")
                        writer.write(struct.pack(">L", len(container_encoded)))
                        writer.write(container_encoded)
                        yield from writer.drain()

                    writer.write(struct.pack(">QLLL", self.NBD_REPLY, opt, self.NBD_REP_ACK, 0))
                    yield from writer.drain()

                elif opt == self.NBD_OPT_ABORT:
                    writer.write(struct.pack(">QLLL", self.NBD_REPLY, opt, self.NBD_REP_ACK, 0))
                    yield from writer.drain()

                    raise AbortedNegotiationError()
                else:
                    # we don't support any other option
                    if not fixed:
                        raise IOError("Unsupported option")

                    writer.write(struct.pack(">QLLL", self.NBD_REPLY, opt, self.NBD_REP_ERR_UNSUP, 0))
                    yield from writer.drain()

            # operation phase: requests are (magic, cmd, handle, offset, length)
            while True:
                header = yield from reader.readexactly(28)
                try:
                    (magic, cmd, handle, offset, length) = struct.unpack(">LLQQL", header)
                except struct.error:
                    raise IOError("Invalid request, disconnecting")

                if magic != self.NBD_REQUEST:
                    raise IOError("Bad magic number, disconnecting")

                self.log.debug("[%s:%s]: cmd=%s, handle=%s, offset=%s, len=%s" % (host, port, cmd, handle, offset, length))

                if cmd == self.NBD_CMD_DISC:
                    self.log.info("[%s:%s] disconnecting" % (host, port))
                    break

                elif cmd == self.NBD_CMD_WRITE:
                    data = yield from reader.readexactly(length)
                    if(len(data) != length):
                        raise IOError("%s bytes expected, disconnecting" % length)

                    try:
                        store.seek(offset)
                        store.write(data)
                    except IOError as ex:
                        self.log.error("[%s:%s] %s" % (host, port, ex))
                        # Report the errno to the client and keep serving.
                        yield from self.nbd_response(writer, handle, error=ex.errno)
                        continue

                    self.stats[store].bytes_in += length
                    yield from self.nbd_response(writer, handle)

                elif cmd == self.NBD_CMD_READ:
                    try:
                        store.seek(offset)
                        data = store.read(length)
                    except IOError as ex:
                        self.log.error("[%s:%s] %s" % (host, port, ex))
                        yield from self.nbd_response(writer, handle, error=ex.errno)
                        continue

                    if data:
                        self.stats[store].bytes_out += len(data)
                    yield from self.nbd_response(writer, handle, data=data)

                elif cmd == self.NBD_CMD_FLUSH:
                    store.flush()
                    yield from self.nbd_response(writer, handle)

                else:
                    self.log.warning("[%s:%s] Unknown cmd %s, disconnecting" % (host, port, cmd))
                    break

        except AbortedNegotiationError:
            self.log.info("[%s:%s] Client aborted negotiation" % (host, port))

        except (asyncio.IncompleteReadError, IOError) as ex:
            self.log.error("[%s:%s] %s" % (host, port, ex))

        finally:
            # Always release the store lock (if one was negotiated) and
            # close the transport.
            if store:
                try:
                    store.unlock()
                except IOError as ex:
                    self.log.error(ex)

            writer.close()

Example 36

Project: Harness
Source File: PSHandler.py
View license
    @asyncio.coroutine
    def handle_client(self, SID, client_reader, client_writer):
        

        # Handle the requests for a specific client with a line oriented protocol
        while self.isrunning():

            '''

                SEND DATA

            '''
            
            cmd = yield from self.get_input(SID)

            if cmd:

                # Begin Remote Module Loading Code
                if cmd != "\r\n" and cmd != "\n":
                    cmd_parts = cmd.split()

                    if cmd_parts[0] == "^import-module":

                        client_writer.write("<rf>".encode())
                        yield from asyncio.sleep(1)

                        try:

                            with open(cmd_parts[1], 'rb') as fin, open(cmd_parts[1]+".base64", 'wb') as fout:
                                self.print_output("Encoding file for transfer")
                                fout.write(base64.b64encode(fin.read()))

                            with open(cmd_parts[1]+".base64", 'rb') as f:
                                self.print_output("Sending " + cmd_parts[1])
                                client_writer.writelines(f)

                                yield from client_writer.drain()

                            os.remove(cmd_parts[1]+".base64")

                        except OSError:
                            self.print_error("File not found")
                            pass

                        yield from asyncio.sleep(1)             # Give buffer chance to flush before sending closing tag         
                        client_writer.write("</rf>".encode())   # signal that we're done transfering the module

                    else:
                        client_writer.write(cmd.encode())

                # End Remote Module Loading code. For a plain TCP handler remove this block
                # and just include the following code after the else
                else:
                    
                    client_writer.write(cmd.encode())

                    if cmd.lower() == "exit":
                        break

            '''

                RECEIVE DATA

            '''
                    
            while self.isrunning():
                try:
                    if client_reader.at_eof():
                        raise ConnectionError

                    _data = (yield from asyncio.wait_for(client_reader.read(1024),timeout=0.1))
                    self.print(_data.decode(),end="",flush=True)

                except ConnectionError:

                    return
                    
                except:
                    
                    # Did not received any new data break out
                    break

Example 37

Project: pycoinnet
Source File: InvCollector.py
View license
    @asyncio.coroutine
    def fetch(self, inv_item, peer_timeout=10):
        # create the queue of peers that have this inv_item available
        q = asyncio.Queue()
        items = sorted(self.inv_item_db[inv_item.data].items(), key=lambda pair: pair[-1])
        for peer, when in items:
            q.put_nowait(peer)
        # make the queue available to the object so if more peers
        # announce they have it, they can be queried
        self.inv_item_peers_q[inv_item.data] = q

        # this is the set of futures that we are trying to fetch the item from
        pending_fetchers = set()

        # this async method delays a specific amount of time, then
        # fetches a new peer from the queue
        @asyncio.coroutine
        def _wait_for_timeout_and_peer(q, initial_delay=0):
            yield from asyncio.sleep(initial_delay)
            while True:
                peer = yield from q.get()
                fetcher = self.fetchers_by_peer.get(peer)
                if fetcher:
                    break
            logging.debug("requesting %s from %s", inv_item, peer)
            return asyncio.Task(fetcher.fetch(inv_item))

        # the loop works like this:
        #   request the item from a peer
        #   wait 10 s or for response (either notfound or found)
        #   if found, done
        #   if notfound or time out, get another peer

        most_recent_fetcher = None

        while True:
            if most_recent_fetcher is None and q.qsize() > 0:
                most_recent_fetcher = yield from _wait_for_timeout_and_peer(q)
                timer_future = asyncio.Task(_wait_for_timeout_and_peer(q, initial_delay=peer_timeout))

            futures = pending_fetchers.union(set([timer_future]))

            if most_recent_fetcher:
                futures.add(most_recent_fetcher)

            done, pending_fetchers = \
                yield from asyncio.wait(list(futures), return_when=asyncio.FIRST_COMPLETED)

            # is it time to try a new fetcher?
            if timer_future in done:
                # we timed out, so we need to queue up another peer
                most_recent_fetcher = timer_future.result()
                logging.debug("timeout, need to request from a new peer, %s", inv_item)
                timer_future = asyncio.Task(_wait_for_timeout_and_peer(q, initial_delay=peer_timeout))
                # we have a new peer available as the result of get_fetcher_future
                continue

            # is the most recent done?
            if most_recent_fetcher and most_recent_fetcher.done():
                r = most_recent_fetcher.result()
                if r:
                    return r
                # we got a "notfound" from this one
                # queue up another peer
                logging.debug("got a notfound, need to try a new peer for %s", inv_item)
                timer_future.cancel()
                pending_fetchers.discard(timer_future)
                most_recent_fetcher = None
                timer_future = asyncio.Task(_wait_for_timeout_and_peer(q, initial_delay=0))
                # we have a new peer available as the result of timer_future
                continue

            # one or more fetchers is done
            # if any of them have a non-None result, we're golden
            for f in done:
                r = f.result()
                if r:
                    logging.info("Got %s", r)
                    return r

Example 38

Project: asyncssh
Source File: server.py
View license
    @classmethod
    @asyncio.coroutine
    def asyncSetUpClass(cls):
        """Set up keys, an SSH server, and an SSH agent for the tests to use"""

        ckey_dsa = asyncssh.generate_private_key('ssh-dss')
        ckey_dsa.write_private_key('ckey_dsa')
        ckey_dsa.write_public_key('ckey_dsa.pub')

        ckey = asyncssh.generate_private_key('ssh-rsa')
        ckey.write_private_key('ckey')
        ckey.write_public_key('ckey.pub')

        ckey_cert = ckey.generate_user_certificate(ckey, 'name')
        ckey_cert.write_certificate('ckey-cert.pub')

        skey = asyncssh.generate_private_key('ssh-rsa')
        skey.write_private_key('skey')
        skey.write_public_key('skey.pub')

        skey_cert = skey.generate_host_certificate(skey, 'name')
        skey_cert.write_certificate('skey-cert.pub')

        exp_cert = skey.generate_host_certificate(skey, 'name',
                                                  valid_after='-2d',
                                                  valid_before='-1d')
        skey.write_private_key('exp_skey')
        exp_cert.write_certificate('exp_skey-cert.pub')

        run('chmod 600 ckey_dsa ckey skey exp_skey')

        run('mkdir .ssh')
        run('chmod 700 .ssh')
        run('cp ckey_dsa .ssh/id_dsa')
        run('cp ckey_dsa.pub .ssh/id_dsa.pub')
        run('cp ckey .ssh/id_rsa')
        run('cp ckey.pub .ssh/id_rsa.pub')

        run('printf "cert-authority,principals=\"ckey\" " > authorized_keys')
        run('cat ckey.pub >> authorized_keys')
        run('printf "permitopen=\":*\" " >> authorized_keys')
        run('cat ckey.pub >> authorized_keys')
        run('cat ckey_dsa.pub >> authorized_keys')

        cls._server = yield from cls.start_server()

        sock = cls._server.sockets[0]
        cls._server_addr, cls._server_port = sock.getsockname()[:2]

        run('printf "[%s]:%s " > .ssh/known_hosts' % (cls._server_addr,
                                                      cls._server_port))
        run('cat skey.pub >> .ssh/known_hosts')

        output = run('ssh-agent -a agent 2>/dev/null')
        cls._agent_pid = int(output.splitlines()[2].split()[3][:-1])

        os.environ['SSH_AUTH_SOCK'] = 'agent'

        agent = yield from asyncssh.connect_agent()
        yield from agent.add_keys([ckey_dsa, (ckey, ckey_cert)])
        agent.close()

        os.environ['LOGNAME'] = 'guest'
        os.environ['HOME'] = '.'

Example 39

Project: umongo
Source File: test_motor_asyncio.py
View license
    def test_cursor(self, loop, classroom_model):
        Student = classroom_model.Student

        @asyncio.coroutine
        def do_test():
            Student.collection.drop()

            for i in range(10):
                yield from Student(name='student-%s' % i).commit()
            cursor = Student.find(limit=5, skip=6)
            count = yield from cursor.count()
            assert count == 10
            count_with_limit_and_skip = yield from cursor.count(with_limit_and_skip=True)
            assert count_with_limit_and_skip == 4

            # Make sure returned documents are wrapped
            names = []
            for elem in (yield from cursor.to_list(length=100)):
                assert isinstance(elem, Student)
                names.append(elem.name)
            assert sorted(names) == ['student-%s' % i for i in range(6, 10)]

            # Try with fetch_next as well
            names = []
            cursor.rewind()
            while (yield from cursor.fetch_next):
                elem = cursor.next_object()
                assert isinstance(elem, Student)
                names.append(elem.name)
            assert sorted(names) == ['student-%s' % i for i in range(6, 10)]

            # Try with each as well
            names = []
            cursor.rewind()
            future = asyncio.Future()

            def callback(result, error):
                if error:
                    future.set_exception(error)
                elif result is None:
                    # Iteration complete
                    future.set_result(names)
                else:
                    names.append(result.name)

            cursor.each(callback=callback)
            yield from future
            assert sorted(names) == ['student-%s' % i for i in range(6, 10)]

            # Make sure this kind of notation doesn't create new cursor
            cursor = Student.find()
            cursor_limit = cursor.limit(5)
            cursor_skip = cursor.skip(6)
            assert cursor is cursor_limit is cursor_skip

            # Test clone&rewind as well
            cursor = Student.find()
            cursor2 = cursor.clone()
            yield from cursor.fetch_next
            yield from cursor2.fetch_next
            cursor_student = cursor.next_object()
            cursor2_student = cursor2.next_object()
            assert cursor_student == cursor2_student

        loop.run_until_complete(do_test())

Example 40

Project: umongo
Source File: test_motor_asyncio.py
View license
    @pytest.mark.xfail
    def test_unique_index_inheritance(self, loop, instance):

        @asyncio.coroutine
        def do_test():

            @instance.register
            class UniqueIndexParentDoc(Document):
                not_unique = fields.StrField(unique=False)
                unique = fields.IntField(unique=True)

                class Meta:
                    collection = 'unique_index_inheritance_doc'
                    allow_inheritance = True

            @instance.register
            class UniqueIndexChildDoc(UniqueIndexParentDoc):
                child_not_unique = fields.StrField(unique=False)
                child_unique = fields.IntField(unique=True)
                manual_index = fields.IntField()

                class Meta:
                    indexes = ['manual_index']

            UniqueIndexChildDoc.collection.drop_indexes()

            # Now ask for indexes building
            UniqueIndexChildDoc.ensure_indexes()
            indexes = [e for e in UniqueIndexChildDoc.collection.list_indexes()]
            expected_indexes = [
                {
                    'key': {'_id': 1},
                    'name': '_id_',
                    'ns': '%s.unique_index_inheritance_doc' % TEST_DB,
                    'v': 1
                },
                {
                    'v': 1,
                    'key': {'unique': 1},
                    'name': 'unique_1',
                    'unique': True,
                    'ns': '%s.unique_index_inheritance_doc' % TEST_DB
                },
                {
                    'v': 1,
                    'key': {'manual_index': 1, '_cls': 1},
                    'name': 'manual_index_1__cls_1',
                    'ns': '%s.unique_index_inheritance_doc' % TEST_DB
                },
                {
                    'v': 1,
                    'key': {'_cls': 1},
                    'name': '_cls_1',
                    'unique': True,
                    'ns': '%s.unique_index_inheritance_doc' % TEST_DB
                },
                {
                    'v': 1,
                    'key': {'child_unique': 1, '_cls': 1},
                    'name': 'child_unique_1__cls_1',
                    'unique': True,
                    'ns': '%s.unique_index_inheritance_doc' % TEST_DB
                }
            ]
            assert name_sorted(indexes) == name_sorted(expected_indexes)

            # Redoing indexes building should do nothing
            UniqueIndexChildDoc.ensure_indexes()
            indexes = [e for e in UniqueIndexChildDoc.collection.list_indexes()]
            assert name_sorted(indexes) == name_sorted(expected_indexes)

        loop.run_until_complete(do_test())

Example 41

Project: umongo
Source File: motor_asyncio.py
View license
    @asyncio.coroutine
    def commit(self, io_validate_all=False, conditions=None):
        """
        Commit the document in database.
        If the document doesn't already exist it will be inserted, otherwise
        it will be updated.

        :param io_validate_all:
        :param conditions: only perform commit if matching record in db
            satisfies condition(s) (e.g. version number).
            Raises :class:`umongo.exceptions.UpdateError` if the
            conditions are not satisfied.
        :return: Update result dict returned by underlaying driver or
            ObjectId of the inserted document.
        """
        try:
            if self.is_created:
                if self.is_modified():
                    query = conditions or {}
                    query['_id'] = self.pk
                    # pre_update can provide additional query filter and/or
                    # modify the fields' values
                    additional_filter = yield from self.__coroutined_pre_update()
                    if additional_filter:
                        query.update(map_query(additional_filter, self.schema.fields))
                    yield from self.io_validate(validate_all=io_validate_all)
                    payload = self._data.to_mongo(update=True)
                    ret = yield from self.collection.update(query, payload)
                    if ret.get('ok') != 1 or ret.get('n') != 1:
                        raise UpdateError(ret)
                    yield from self.__coroutined_post_update(ret)
                else:
                    ret = None
            elif conditions:
                raise RuntimeError('Document must already exist in database to use `conditions`.')
            else:
                yield from self.__coroutined_pre_insert()
                self.required_validate()
                yield from self.io_validate(validate_all=io_validate_all)
                payload = self._data.to_mongo(update=False)
                ret = yield from self.collection.insert(payload)
                # TODO: check ret ?
                self._data.set_by_mongo_name('_id', ret)
                self.is_created = True
                yield from self.__coroutined_post_insert(ret)
        except DuplicateKeyError as exc:
            # Need to dig into error message to find faulting index
            errmsg = exc.details['errmsg']
            for index in self.opts.indexes:
                if ('.$%s' % index.document['name'] in errmsg or
                        ' %s ' % index.document['name'] in errmsg):
                    keys = index.document['key'].keys()
                    if len(keys) == 1:
                        key = tuple(keys)[0]
                        msg = self.schema.fields[key].error_messages['unique']
                        raise ValidationError({key: msg})
                    else:
                        fields = self.schema.fields
                        # Compound index (sort value to make testing easier)
                        keys = sorted(keys)
                        raise ValidationError({k: fields[k].error_messages[
                            'unique_compound'].format(fields=keys) for k in keys})
            # Unknown index, cannot wrap the error so just reraise it
            raise
        self._data.clear_modified()
        return ret

Example 42

Project: netwrok-server
Source File: mailer.py
View license
@asyncio.coroutine
def mailer():
    """
    Periodicaly poll the database mailqueue table and send emails.
    """
    while True: 
        with (yield from nwdb.connection()) as conn:
            cursor = yield from conn.cursor()
            cursor.execute("rollback")
            cursor.execute("begin")
            yield from cursor.execute("""
            select id, member_id, address, subject, body
            from mailqueue where sent = false and error = false
            order by created limit 1
            """)
            rs = yield from cursor.fetchone()
            if rs is None:
                yield from asyncio.sleep(config["MAIL"]["MAILER_IDLE_TIME"])
            else:
                sent = False
                error = False
                try:
                    server = smtplib.SMTP(config["MAIL"]["SERVER"])
                    fromaddr = config["MAIL"]["FROM_ADDRESS"]
                    toaddrs = [rs[2]]
                    msg = "From: %s\r\nTo: %s\r\nSubject: %s\r\n\r\n%s"%(fromaddr, rs[2], rs[3], rs[4])
                    server.sendmail(fromaddr, toaddrs, msg)
                    server.quit()
                    print ("Email Sent: (" + str(rs[0]) + ") to " + rs[2])  
                except Exception as e:
                    print(type(e), str(e))
                    error = True
                    sent = False
                else:
                    error = False
                    sent = True

                yield from cursor.execute("""
                update mailqueue set sent = %s, error = %s where id = %s
                """, [sent, error, rs[0]])
                yield from cursor.execute("""
                delete from mailqueue where sent = true and now() - created > interval '1 days'
                """)
                cursor.execute("commit")
                yield from asyncio.sleep(0.1)

Example 43

Project: aiohttp_utils
Source File: negotiation.py
View license
def negotiation_middleware(
    renderers=DEFAULTS['RENDERERS'],
    negotiator=DEFAULTS['NEGOTIATOR'],
    force_negotiation=DEFAULTS['FORCE_NEGOTIATION']
):
    """Middleware which selects a renderer for a given request then renders
    a handler's data to a `aiohttp.web.Response`.
    """
    @asyncio.coroutine
    def factory(app, handler):

        @asyncio.coroutine
        def middleware(request):
            content_type, renderer = negotiator(
                request,
                renderers,
                force_negotiation,
            )
            request['selected_media_type'] = content_type
            response = yield from handler(request)

            if getattr(response, 'data', None):
                # Render data with the selected renderer
                if asyncio.iscoroutinefunction(renderer):
                    render_result = yield from renderer(request, response.data)
                else:
                    render_result = renderer(request, response.data)
            else:
                render_result = response

            if isinstance(render_result, web.Response):
                return render_result

            if getattr(response, 'data', None):
                response.body = render_result
                response.content_type = content_type

            return response
        return middleware
    return factory

Example 44

Project: PushBank2
Source File: hana.py
View license
@asyncio.coroutine
def query(account, password, resident):
    """
    하나은행 계좌 잔액 빠른조회. 빠른조회 서비스에 등록이 되어있어야 사용 가능.
    빠른조회 서비스:
    https://open.hanabank.com/flex/quick/quickService.do?oid=quickservice

    account  -- 계좌번호 ('-' 제외)
    password -- 계좌 비밀번호 (숫자 4자리)
    resident -- 주민등록번호 앞 6자리
    """

    if len(password) != 4 or not password.isdigit():
        raise ValueError("password: 비밀번호는 숫자 4자리여야 합니다.")

    if len(resident) != 6 or not resident.isdigit():
        raise ValueError("resident: 주민등록번호 앞 6자리를 입력해주세요.")

    params = {
        'ajax': 'true',
        'acctNo': account,
        'acctPw': password,
        'bkfgResRegNo': resident,
        'curCd': '',
        'inqStrDt': (datetime.now(_kst_timezone) - timedelta(days=14)).strftime('%Y%m%d'),
        'inqEndDt': datetime.now(_kst_timezone).strftime('%Y%m%d'),
        'rvSeqInqYn': 'Y',
        'rcvWdrwDvCd': '',
        'rqstNcnt': '30',
        'maxRowCount': '700',
        'rqstPage': '1',
        'acctType': '01',
        'language': 'KOR'
    }

    try:
        r = _session.get(_url, params=params, timeout=10)
        data = r.text
        success = True
    except:
        success = False

    d = {
        'success': success,
        'account': account,
    }
    if success:
        data = data.replace('&nbsp;', '')
        data = BeautifulSoup(data)
        balance = data.select('table.tbl_col01' +
                              ' tr:nth-of-type(2) td')[0].text.strip()
        balance = int(balance.replace(',', ''))
        history = [
            [y.text.strip() for y in x.select('td')]
            for x in data.select('table.tbl_col01')[1].select('tbody tr')
        ]

        '''
        순서:
            거래일, 구분, 적요, 입금액, 출금액, 잔액, 거래시간, 거래점
        '''

        d['balance'] = balance
        d['history'] = [{
            'date': datetime.strptime('{0},{1}'.format(x[0], x[6]),
                                      '%Y-%m-%d,%H:%M').date(),
            'type': x[1],
            'depositor': x[2],
            'withdraw': int(x[3].replace(',', '') if x[3] else '0'),
            'pay': int(x[4].replace(',', '') if x[4] else '0'),
            'balance': int(x[5].replace(',', '')),
            'distributor': x[7],
        } for x in history]

    return d

Example 45

Project: PushBank2
Source File: kbstar.py
View license
@asyncio.coroutine
def query(account, password, resident, username):
    """
    국민은행 계좌 잔액 빠른조회. 빠른조회 서비스에 등록이 되어있어야 사용 가능.
    빠른조회 서비스: https://obank.kbstar.com/quics?page=C018920

    account  -- 계좌번호 ('-' 제외)
    password -- 계좌 비밀번호 (숫자 4자리)
    resident -- 주민등록번호 끝 7자리
    username -- 인터넷 뱅킹 ID (대문자)
    """

    if len(password) != 4 or not password.isdigit():
        raise ValueError("password: 비밀번호는 숫자 4자리여야 합니다.")

    if len(resident) != 7 or not resident.isdigit():
        raise ValueError("resident: 주민등록번호 끝 7자리를 입력해주세요.")

    params = {
        '다음거래년월일키': '',
        '다음거래일련번호키': '',
        '계좌번호': account,
        '비밀번호': password,
        '조회시작일': (datetime.now(_kst_timezone) - timedelta(days=14)).strftime('%Y%m%d'),
        '조회종료일': datetime.now(_kst_timezone).strftime('%Y%m%d'),
        '주민사업자번호': '000000' + resident,
        '고객식별번호': username.upper(),
        '응답방법': '2',
        '조회구분': '2',
        'USER_TYPE': '02',
        '_FILE_NAME': 'KB_거래내역빠른조회.html',
        '_LANG_TYPE': 'KOR'
    }

    try:
        r = _session.get(_url, params=params, timeout=10)
        data = r.text
        success = True
    except:
        success = False

    d = {
        'success': success,
        'account': account,
    }
    if success:
        data = data.replace('&nbsp;', '')
        data = BeautifulSoup(data)
        balance = data.select('table table:nth-of-type(1)' +
                              ' tr:nth-of-type(3) td')[-1].text
        balance = int(balance.replace(',', ''))
        history = [
            [y.text.strip() for y in x.select('td')]
            for x in
            data.select('table table:nth-of-type(2) tr[align="center"]')
        ]

        '''
        순서:
            거래일, 적요, 의뢰인/수치인, 내통장표시, 출금금액, 입금금액, 잔액, 취급점, 구분
        '''

        d['balance'] = balance
        d['history'] = [{
            'date': datetime.strptime(x[0], '%Y.%m.%d%H:%M:%S').date(),
            'type': x[1],
            'depositor': x[2],
            'pay': int(x[4].replace(',', '')),
            'withdraw': int(x[5].replace(',', '')),
            'balance': int(x[6].replace(',', '')),
            'distributor': x[7],
        } for x in history]

    return d

Example 46

Project: PushBank2
Source File: nhbank.py
View license
@asyncio.coroutine
def query(account, password, resident):
    """Fetch the balance and last 14 days of transactions of an NH bank account.

    Uses Nonghyup's quick-lookup web service; the account must be enrolled
    for the service first:
    https://banking.nonghyup.com/servlet/IPAM0011I.view

    account  -- account number (without '-')
    password -- account password (4 digits)
    resident -- first 6 digits of the resident registration number

    Returns a dict with 'success' and 'account' keys; on success it also
    contains 'balance' (int) and 'history' (list of transaction dicts).
    Raises ValueError when password/resident have the wrong format.
    """

    # User-facing validation messages are intentionally kept in Korean.
    if len(password) != 4 or not password.isdigit():
        raise ValueError("password: 비밀번호는 숫자 4자리여야 합니다.")

    if len(resident) != 6 or not resident.isdigit():
        raise ValueError("resident: 주민등록번호 앞 6자리를 입력해주세요.")

    # Session/CSRF tokens required by the bank's endpoint.
    tokens = _acquire_tokens()
    start_date = (datetime.now(_kst_timezone) - timedelta(days=14)).strftime('%Y%m%d')
    end_date = datetime.now(_kst_timezone).strftime('%Y%m%d')

    payload = {
        'GjaGbn': '1',
        'InqGjaNbr': account,
        'GjaSctNbr': password,
        'rlno1': resident,
        'InqGbn_2': '2',
        'InqGbn': '1',
        'InqFdt': start_date,
        'InqEndDat': end_date,
        'InqDat': start_date,
        'EndDat': end_date,
        'SESSION_TOKEN': tokens[0],
        'TOKEN': tokens[1],
    }

    try:
        response = _session.post('https://banking.nonghyup.com/servlet/IPMS0012R.frag', data=payload, timeout=10)
        data = response.text
        success = True

        if '<div class="error">' in data:  # e.g. maintenance page
            success = False
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
    # swallowed; any request failure is reported as success=False (best-effort).
    except Exception:
        success = False

    result = {
        'success': success,
        'account': account,
    }

    if success:
        data = data.replace('<br>', ' ')
        soup = BeautifulSoup(data)

        balance = soup.select('.tb_row tr')[1].select('td')[1].text.strip()
        transactions = [
            [td.text.strip() for td in tr.select('td')]
            for tr in soup.select('#listTable tbody tr')
        ]

        # Column order on the page:
        #   seq, date, withdrawal, deposit, balance after tx, description,
        #   memo/depositor, branch
        result['balance'] = _as_int(balance)
        result['history'] = [{
            'date': datetime.strptime(transaction[1], '%Y/%m/%d %H:%M:%S').date(),
            'withdraw': _as_int(transaction[2]),
            'pay': _as_int(transaction[3]),
            'balance': _as_int(transaction[4]),
            'type': transaction[5],
            'depositor': transaction[6],
            'distributor': transaction[7],
        } for transaction in transactions]

    return result

Example 47

Project: hbmqtt
Source File: broker.py
View license
    @asyncio.coroutine
    def start(self):
        """
            Start the broker to serve with the given configuration

            Start method opens network sockets and will start listening for incoming connections.

            This method is a *coroutine*.

            Raises BrokerException if the broker state machine refuses the
            start transition or if any listener fails to come up.
        """
        try:
            # Reset per-run broker state; transitions.start() raises
            # MachineError when the broker is not in a startable state.
            self._sessions = dict()
            self._subscriptions = dict()
            self._retained_messages = dict()
            self.transitions.start()
            self.logger.debug("Broker starting")
        except MachineError as me:
            self.logger.warn("[WARN-0001] Invalid method call at this moment: %s" % me)
            raise BrokerException("Broker instance can't be started: %s" % me)

        yield from self.plugins_manager.fire_event(EVENT_BROKER_PRE_START)
        try:
            # Start network listeners
            for listener_name in self.listeners_config:
                listener = self.listeners_config[listener_name]

                if 'bind' not in listener:
                    # A listener without a 'bind' entry is skipped, not an error.
                    self.logger.debug("Listener configuration '%s' is not bound" % listener_name)
                else:
                    # Max connections (-1 means unlimited)
                    try:
                        max_connections = listener['max_connections']
                    except KeyError:
                        max_connections = -1

                    # SSL Context
                    sc = None

                    # accept string "on" / "off" or boolean
                    ssl_active = listener.get('ssl', False)
                    if isinstance(ssl_active, str):
                        ssl_active = ssl_active.upper() == 'ON'

                    if ssl_active:
                        try:
                            # CLIENT_AUTH purpose: this context authenticates
                            # the server side of incoming connections.
                            sc = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
                            sc.load_cert_chain(listener['certfile'], listener['keyfile'])
                            sc.verify_mode = ssl.CERT_OPTIONAL
                        except KeyError as ke:
                            raise BrokerException("'certfile' or 'keyfile' configuration parameter missing: %s" % ke)
                        except FileNotFoundError as fnfe:
                            raise BrokerException("Can't read cert files '%s' or '%s' : %s" %
                                                  (listener['certfile'], listener['keyfile'], fnfe))

                    # 'bind' is "address:port"; the port part must be an int.
                    address, s_port = listener['bind'].split(':')
                    port = 0
                    try:
                        port = int(s_port)
                    except ValueError as ve:
                        raise BrokerException("Invalid port value in bind value: %s" % listener['bind'])

                    if listener['type'] == 'tcp':
                        # Bind listener_name into the connection callback so the
                        # handler knows which listener the client arrived on.
                        cb_partial = partial(self.stream_connected, listener_name=listener_name)
                        instance = yield from asyncio.start_server(cb_partial,
                                                                   address,
                                                                   port,
                                                                   ssl=sc,
                                                                   loop=self._loop)
                        self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)
                    elif listener['type'] == 'ws':
                        cb_partial = partial(self.ws_connected, listener_name=listener_name)
                        instance = yield from websockets.serve(cb_partial, address, port, ssl=sc, loop=self._loop,
                                                               subprotocols=['mqtt'])
                        self._servers[listener_name] = Server(listener_name, instance, max_connections, self._loop)

                    self.logger.info("Listener '%s' bind to %s (max_connections=%d)" %
                                     (listener_name, listener['bind'], max_connections))

            self.transitions.starting_success()
            yield from self.plugins_manager.fire_event(EVENT_BROKER_POST_START)

            # Start broadcast loop: forwards published messages to subscribers
            # in the background for the broker's whole lifetime.
            self._broadcast_task = ensure_future(self._broadcast_loop(), loop=self._loop)

            self.logger.debug("Broker started")
        except Exception as e:
            # Any startup failure moves the state machine to the failed state
            # and is surfaced to the caller as a BrokerException.
            self.logger.error("Broker startup failed: %s" % e)
            self.transitions.starting_fail()
            raise BrokerException("Broker instance can't be started: %s" % e)

Example 48

Project: hbmqtt
Source File: broker.py
View license
    @asyncio.coroutine
    def client_connected(self, listener_name, reader: ReaderAdapter, writer: WriterAdapter):
        """Handle the full lifecycle of one client connection.

        Waits for a connection slot on the named listener, performs the MQTT
        CONNECT handshake and authentication, then loops dispatching
        disconnects, (un)subscriptions and message deliveries until the
        client disconnects. This method is a *coroutine*.

        :param listener_name: name of the listener the client connected on
        :param reader: stream adapter for incoming packets
        :param writer: stream adapter for outgoing packets
        :raises BrokerException: if listener_name is unknown
        """
        # Wait for connection available on listener
        server = self._servers.get(listener_name, None)
        if not server:
            raise BrokerException("Invalid listener name '%s'" % listener_name)
        yield from server.acquire_connection()

        remote_address, remote_port = writer.get_peer_info()
        self.logger.info("Connection from %s:%d on listener '%s'" % (remote_address, remote_port, listener_name))

        # Wait for first packet and expect a CONNECT
        try:
            handler, client_session = yield from BrokerProtocolHandler.init_from_connect(reader, writer, self.plugins_manager, loop=self._loop)
        except HBMQTTException as exc:
            self.logger.warn("[MQTT-3.1.0-1] %s: Can't read first packet an CONNECT: %s" %
                             (format_client_message(address=remote_address, port=remote_port), exc))
            #yield from writer.close()
            self.logger.debug("Connection closed")
            return
        except MQTTException as me:
            self.logger.error('Invalid connection from %s : %s' %
                              (format_client_message(address=remote_address, port=remote_port), me))
            yield from writer.close()
            self.logger.debug("Connection closed")
            return

        if client_session.clean_session:
            # Delete existing session and create a new one
            if client_session.client_id is not None:
                self.delete_session(client_session.client_id)
            else:
                client_session.client_id = gen_client_id()
            client_session.parent = 0
        else:
            # Get session from cache; parent=1 marks a resumed session so the
            # CONNACK can report "session present".
            if client_session.client_id in self._sessions:
                self.logger.debug("Found old session %s" % repr(self._sessions[client_session.client_id]))
                (client_session,h) = self._sessions[client_session.client_id]
                client_session.parent = 1
            else:
                client_session.parent = 0
        if client_session.keep_alive > 0:
            # Grace period added on top of the client's requested keep-alive.
            client_session.keep_alive += self.config['timeout-disconnect-delay']
        self.logger.debug("Keep-alive timeout=%d" % client_session.keep_alive)

        handler.attach(client_session, reader, writer)
        self._sessions[client_session.client_id] = (client_session, handler)

        authenticated = yield from self.authenticate(client_session, self.listeners_config[listener_name])
        if not authenticated:
            yield from writer.close()
            return

        # Retry the connect transition until the session state machine
        # accepts it (throttles clients reconnecting too fast).
        while True:
            try:
                client_session.transitions.connect()
                break
            except MachineError:
                self.logger.warning("Client %s is reconnecting too quickly, make it wait" % client_session.client_id)
                # Wait a bit may be client is reconnecting too fast
                yield from asyncio.sleep(1, loop=self._loop)
        yield from handler.mqtt_connack_authorize(authenticated)

        yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_CONNECTED, client_id=client_session.client_id)

        self.logger.debug("%s Start messages handling" % client_session.client_id)
        yield from handler.start()
        self.logger.debug("Retained messages queue size: %d" % client_session.retained_messages.qsize())
        yield from self.publish_session_retained_messages(client_session)

        # Init and start loop for handling client messages (publish, subscribe/unsubscribe, disconnect)
        disconnect_waiter = ensure_future(handler.wait_disconnect(), loop=self._loop)
        subscribe_waiter = ensure_future(handler.get_next_pending_subscription(), loop=self._loop)
        unsubscribe_waiter = ensure_future(handler.get_next_pending_unsubscription(), loop=self._loop)
        wait_deliver = ensure_future(handler.mqtt_deliver_next_message(), loop=self._loop)
        connected = True
        while connected:
            try:
                # Wait on all four event sources at once; each completed waiter
                # is handled below and then re-armed with a fresh Task.
                done, pending = yield from asyncio.wait(
                    [disconnect_waiter, subscribe_waiter, unsubscribe_waiter, wait_deliver],
                    return_when=asyncio.FIRST_COMPLETED, loop=self._loop)
                if disconnect_waiter in done:
                    result = disconnect_waiter.result()
                    self.logger.debug("%s Result from wait_diconnect: %s" % (client_session.client_id, result))
                    if result is None:
                        self.logger.debug("Will flag: %s" % client_session.will_flag)
                        # Connection closed abnormally (no DISCONNECT packet):
                        # publish the client's will message if one was set.
                        if client_session.will_flag:
                            self.logger.debug("Client %s disconnected abnormally, sending will message" %
                                              format_client_message(client_session))
                            yield from self._broadcast_message(
                                client_session,
                                client_session.will_topic,
                                client_session.will_message,
                                client_session.will_qos)
                            if client_session.will_retain:
                                self.retain_message(client_session,
                                                    client_session.will_topic,
                                                    client_session.will_message,
                                                    client_session.will_qos)
                    self.logger.debug("%s Disconnecting session" % client_session.client_id)
                    yield from self._stop_handler(handler)
                    client_session.transitions.disconnect()
                    yield from self.plugins_manager.fire_event(EVENT_BROKER_CLIENT_DISCONNECTED, client_id=client_session.client_id)
                    connected = False
                if unsubscribe_waiter in done:
                    self.logger.debug("%s handling unsubscription" % client_session.client_id)
                    unsubscription = unsubscribe_waiter.result()
                    for topic in unsubscription['topics']:
                        self._del_subscription(topic, client_session)
                        yield from self.plugins_manager.fire_event(
                            EVENT_BROKER_CLIENT_UNSUBSCRIBED,
                            client_id=client_session.client_id,
                            topic=topic)
                    yield from handler.mqtt_acknowledge_unsubscription(unsubscription['packet_id'])
                    unsubscribe_waiter = asyncio.Task(handler.get_next_pending_unsubscription(), loop=self._loop)
                if subscribe_waiter in done:
                    self.logger.debug("%s handling subscription" % client_session.client_id)
                    subscriptions = subscribe_waiter.result()
                    return_codes = []
                    for subscription in subscriptions['topics']:
                        return_codes.append(self.add_subscription(subscription, client_session))
                    yield from handler.mqtt_acknowledge_subscription(subscriptions['packet_id'], return_codes)
                    for index, subscription in enumerate(subscriptions['topics']):
                        # 0x80 is the SUBACK failure return code; only deliver
                        # retained messages for accepted subscriptions.
                        if return_codes[index] != 0x80:
                            yield from self.plugins_manager.fire_event(
                                EVENT_BROKER_CLIENT_SUBSCRIBED,
                                client_id=client_session.client_id,
                                topic=subscription[0],
                                qos=subscription[1])
                            yield from self.publish_retained_messages_for_subscription(subscription, client_session)
                    subscribe_waiter = asyncio.Task(handler.get_next_pending_subscription(), loop=self._loop)
                    self.logger.debug(repr(self._subscriptions))
                if wait_deliver in done:
                    if self.logger.isEnabledFor(logging.DEBUG):
                        self.logger.debug("%s handling message delivery" % client_session.client_id)
                    app_message = wait_deliver.result()
                    if not app_message.topic:
                        self.logger.warn("[MQTT-4.7.3-1] - %s invalid TOPIC sent in PUBLISH message, closing connection" % client_session.client_id)
                        break
                    # Wildcards are not allowed in PUBLISH topic names.
                    if "#" in app_message.topic or "+" in app_message.topic:
                        self.logger.warn("[MQTT-3.3.2-2] - %s invalid TOPIC sent in PUBLISH message, closing connection" % client_session.client_id)
                        break
                    yield from self.plugins_manager.fire_event(EVENT_BROKER_MESSAGE_RECEIVED,
                                                               client_id=client_session.client_id,
                                                               message=app_message)
                    yield from self._broadcast_message(client_session, app_message.topic, app_message.data)
                    if app_message.publish_packet.retain_flag:
                        self.retain_message(client_session, app_message.topic, app_message.data, app_message.qos)
                    wait_deliver = asyncio.Task(handler.mqtt_deliver_next_message(), loop=self._loop)
            except asyncio.CancelledError:
                self.logger.debug("Client loop cancelled")
                break
        # Cancel any still-pending waiters before releasing the connection slot.
        disconnect_waiter.cancel()
        subscribe_waiter.cancel()
        unsubscribe_waiter.cancel()
        wait_deliver.cancel()

        self.logger.debug("%s Client disconnected" % client_session.client_id)
        server.release_connection()
Example 49

Project: hbmqtt
Source File: broker_handler.py
View license
    @classmethod
    @asyncio.coroutine
    def init_from_connect(cls, reader: ReaderAdapter, writer: WriterAdapter, plugins_manager, loop=None):
        """Read a CONNECT packet from the stream and build the broker-side session.

        Validates the packet against MQTT 3.1.1; on a recoverable protocol
        violation a refusing CONNACK is sent before the connection is closed
        and MQTTException is raised.

        :param reader: stream adapter the CONNECT packet is read from
        :param writer: stream adapter used to send an error CONNACK, if any
        :param plugins_manager: plugin manager used to fire packet events
        :param loop: event loop the new session is bound to
        :return: tuple (handler, session)
        :raises MQTTException: if the CONNECT packet violates the spec
        """
        remote_address, remote_port = writer.get_peer_info()
        connect = yield from ConnectPacket.from_stream(reader)
        yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_RECEIVED, packet=connect)
        if connect.payload.client_id is None:
            raise MQTTException('[[MQTT-3.1.3-3]] : Client identifier must be present')

        if connect.variable_header.will_flag:
            if connect.payload.will_topic is None or connect.payload.will_message is None:
                raise MQTTException('will flag set, but will topic/message not present in payload')

        if connect.variable_header.reserved_flag:
            raise MQTTException('[MQTT-3.1.2-3] CONNECT reserved flag must be set to 0')
        if connect.proto_name != "MQTT":
            raise MQTTException('[MQTT-3.1.2-1] Incorrect protocol name: "%s"' % connect.proto_name)

        connack = None
        error_msg = None
        if connect.proto_level != 4:
            # only MQTT 3.1.1 supported
            error_msg = 'Invalid protocol from %s: %d' % \
                        (format_client_message(address=remote_address, port=remote_port), connect.proto_level)
            connack = ConnackPacket.build(0, UNACCEPTABLE_PROTOCOL_VERSION)  # [MQTT-3.2.2-4] session_parent=0
        elif not connect.username_flag and connect.password_flag:
            # Bug fix: these two flag-mismatch branches previously left
            # error_msg as None, so MQTTException was raised with no message.
            error_msg = '[MQTT-3.1.2-22] Password flag set without username flag from %s' % \
                        format_client_message(address=remote_address, port=remote_port)
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.1.2-22]
        elif connect.username_flag and not connect.password_flag:
            error_msg = '[MQTT-3.1.2-22] Username flag set without password flag from %s' % \
                        format_client_message(address=remote_address, port=remote_port)
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.1.2-22]
        elif connect.username_flag and connect.username is None:
            error_msg = 'Invalid username from %s' % \
                              (format_client_message(address=remote_address, port=remote_port))
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
        elif connect.password_flag and connect.password is None:
            error_msg = 'Invalid password %s' % (format_client_message(address=remote_address, port=remote_port))
            connack = ConnackPacket.build(0, BAD_USERNAME_PASSWORD)  # [MQTT-3.2.2-4] session_parent=0
        elif connect.clean_session_flag is False and (connect.payload.client_id is None or connect.payload.client_id == ""):
            error_msg = '[MQTT-3.1.3-8] [MQTT-3.1.3-9] %s: No client Id provided (cleansession=0)' % \
                              format_client_message(address=remote_address, port=remote_port)
            connack = ConnackPacket.build(0, IDENTIFIER_REJECTED)
        if connack is not None:
            # Refuse the connection: send the CONNACK, close, and raise.
            yield from plugins_manager.fire_event(EVENT_MQTT_PACKET_SENT, packet=connack)
            yield from connack.to_stream(writer)
            yield from writer.close()
            raise MQTTException(error_msg)

        # Copy the accepted CONNECT fields onto a fresh session object.
        incoming_session = Session(loop)
        incoming_session.client_id = connect.client_id
        incoming_session.clean_session = connect.clean_session_flag
        incoming_session.will_flag = connect.will_flag
        incoming_session.will_retain = connect.will_retain_flag
        incoming_session.will_qos = connect.will_qos
        incoming_session.will_topic = connect.will_topic
        incoming_session.will_message = connect.will_message
        incoming_session.username = connect.username
        incoming_session.password = connect.password
        if connect.keep_alive > 0:
            incoming_session.keep_alive = connect.keep_alive
        else:
            incoming_session.keep_alive = 0

        handler = cls(plugins_manager, loop=loop)
        return handler, incoming_session

Example 50

Project: hangups
Source File: channel.py
View license
    @asyncio.coroutine
    def _longpoll_request(self):
        """Open a long-polling request and receive arrays.

        This method uses keep-alive to make re-opening the request faster, but
        the remote server will set the "Connection: close" header once an hour.

        Reads the streamed response in chunks and feeds each one to
        self._on_push_data until the server ends the response.

        Raises hangups.NetworkError or UnknownSIDError.
        """
        params = {
            'VER': 8,  # channel protocol version
            'gsessionid': self._gsessionid_param,
            'RID': 'rpc',  # request identifier
            't': 1,  # trial
            'SID': self._sid_param,  # session ID
            'CI': 0,  # 0 if streaming/chunked requests should be used
            'ctype': 'hangouts',  # client type
            'TYPE': 'xmlhttp',  # type of request
        }
        headers = get_authorization_headers(self._cookies['SAPISID'])
        logger.info('Opening new long-polling request')
        try:
            # Bound only the connection/response-header phase with
            # CONNECT_TIMEOUT; body reads have their own PUSH_TIMEOUT below.
            res = yield from asyncio.wait_for(aiohttp.request(
                'get', CHANNEL_URL_PREFIX.format('channel/bind'),
                params=params, cookies=self._cookies, headers=headers,
                connector=self._connector
            ), CONNECT_TIMEOUT)
        except asyncio.TimeoutError:
            raise exceptions.NetworkError('Request timed out')
        except aiohttp.ClientError as e:
            raise exceptions.NetworkError('Request connection error: {}'
                                          .format(e))
        # NOTE(review): in recent aiohttp versions ServerDisconnectedError is a
        # subclass of ClientError, which would make this handler unreachable —
        # confirm against the pinned aiohttp version.
        except aiohttp.ServerDisconnectedError as e:
            raise exceptions.NetworkError('Server disconnected error: {}'
                                          .format(e))
        # The server signals an expired/invalid session with 400 "Unknown SID";
        # callers treat UnknownSIDError as "re-create the channel".
        if res.status == 400 and res.reason == 'Unknown SID':
            raise UnknownSIDError('SID became invalid')
        elif res.status != 200:
            raise exceptions.NetworkError(
                'Request return unexpected status: {}: {}'
                .format(res.status, res.reason)
            )
        while True:
            try:
                chunk = yield from asyncio.wait_for(
                    res.content.read(MAX_READ_BYTES), PUSH_TIMEOUT
                )
            except asyncio.TimeoutError:
                raise exceptions.NetworkError('Request timed out')
            except aiohttp.ClientError as e:
                raise exceptions.NetworkError('Request connection error: {}'
                                              .format(e))
            except aiohttp.ServerDisconnectedError as e:
                raise exceptions.NetworkError('Server disconnected error: {}'
                                              .format(e))
            except asyncio.CancelledError:
                # Prevent ResourceWarning when channel is disconnected.
                res.close()
                raise
            if chunk:
                yield from self._on_push_data(chunk)
            else:
                # An empty chunk means EOF: close the response to allow the
                # connection to be reused for the next request.
                res.close()
                break