python/11501/aiohttp-devtools/aiohttp_devtools/runserver/serve.py

serve.py
import asyncio
import json
import mimetypes
import os
import sys
from importlib import import_module
from pathlib import Path

import aiohttp_debugtoolbar
from aiohttp import FileSender, WSMsgType, web
from aiohttp.hdrs import CONTENT_ENCODING, LAST_MODIFIED
from aiohttp.web_exceptions import HTTPNotFound, HTTPNotModified
from aiohttp.web_urldispatcher import StaticResource

from ..logs import rs_aux_logger as logger
from ..logs import rs_dft_logger as dft_logger
from ..logs import setup_logging
from .log_handlers import fmt_size

# script tag appended to html responses, the "%d" placeholder is filled with the auxiliary server's port
LIVE_RELOAD_SNIPPET = b'\n<script src="http://localhost:%d/livereload.js"></script>\n'
# NOTE(review): not referenced in the code visible here, presumably the app key used by aiohttp_jinja2 - confirm
JINJA_ENV = 'aiohttp_jinja2_environment'


def modify_main_app(app, static_url, livereload, debug_toolbar, aux_port):
    """
    Modify the main application in place so it cooperates with the development server.

    :param app: application to modify
    :param static_url: url prefix of static files, used to build ``app['static_root_url']``
    :param livereload: if true, the livereload script tag is appended to html responses
    :param debug_toolbar: if true, aiohttp_debugtoolbar is set up on the app
    :param aux_port: port used in the livereload script url and in the static root url
    """
    dft_logger.debug('livereload enabled: %s', '✓' if livereload else '✖')
    if livereload:
        livereload_snippet = LIVE_RELOAD_SNIPPET % aux_port

        async def on_prepare(request, response):
            if request.path.startswith('/_debugtoolbar') or 'text/html' not in response.content_type:
                return
            # responses without a body attribute, or whose body is None, have nothing to append to;
            # "None += bytes" would raise a TypeError while the response is being prepared
            body = getattr(response, 'body', None)
            if body is not None:
                response.body = body + livereload_snippet
        app.on_response_prepare.append(on_prepare)

    static_url = 'http://localhost:{}/{}'.format(aux_port, static_url.strip('/'))
    app['static_root_url'] = static_url
    dft_logger.debug('app attribute static_root_url="%s" set', static_url)

    if debug_toolbar:
        aiohttp_debugtoolbar.setup(app, intercept_redirects=False)


def create_main_app(*,
                    app_path: str,
                    app_factory: str=None,
                    static_url: str='/static/',
                    livereload: bool=True,
                    debug_toolbar: bool=True,
                    aux_port: int=8001,
                    loop: asyncio.AbstractEventLoop=None):
    """
    Import the app factory, call it and prepare the resulting app for the development server.

    :param app_path: path to the python module containing the app factory
    :param app_factory: name of the factory attribute in that module, passed on to import_string
    :param static_url: passed on to modify_main_app
    :param livereload: passed on to modify_main_app
    :param debug_toolbar: passed on to modify_main_app
    :param aux_port: passed on to modify_main_app
    :param loop: event loop handed to the factory, a new one is created when none is given
    :return: the application returned by the factory, after modify_main_app has been applied to it
    """
    factory = import_string(app_path, app_factory)[0]

    if not loop:
        loop = asyncio.new_event_loop()
    main_app = factory(loop=loop)

    modify_main_app(main_app, static_url, livereload, debug_toolbar, aux_port)
    return main_app


def serve_main_app(*, main_port: int=8000, verbose: bool=False, **config):
    """
    Create the main app and serve it on ``main_port`` until interrupted.

    :param main_port: port the server listens on (bound to 0.0.0.0)
    :param verbose: passed to setup_logging
    :param config: keyword arguments passed on to create_main_app
    """
    setup_logging(verbose)
    app = create_main_app(**config)
    loop = app.loop
    handler = app.make_handler(access_log_format='%r %s %b')
    co = asyncio.gather(loop.create_server(handler, '0.0.0.0', main_port), app.startup(), loop=loop)
    server, startup_res = loop.run_until_complete(co)

    try:
        loop.run_forever()
    except KeyboardInterrupt:  # pragma: no cover
        # ctrl+c is the normal way to stop the server, fall through to the shutdown sequence below
        pass
    finally:
        server.close()
        loop.run_until_complete(server.wait_closed())
        loop.run_until_complete(app.shutdown())
        loop.run_until_complete(app.cleanup())
        loop.run_until_complete(handler.finish_connections(0.01))
    loop.close()


# key under which the auxiliary app stores its list of (websocket, url) tuples
WS = 'websockets'


class AuxiliaryApplication(web.Application):
    """
    Application which keeps track of connected livereload websockets (under the WS key) and
    can prompt them to reload.
    """
    def src_reload(self, path: str=None):
        """
        Prompt connected browsers to reload.

        :param path: path of the changed file inside ``self['static_path']``, if omitted every client
            is asked to reload its own url
        :return: number of connected clients, 0 if there are none
        """
        cli_count = len(self[WS])
        if cli_count == 0:
            return 0

        is_html = None
        if path:
            # convert the file system path into the url path at which the file is served
            path = str(Path(self['static_url']) / Path(path).relative_to(self['static_path']))
            is_html = mimetypes.guess_type(path)[0] == 'text/html'

        reloads = 0
        for ws, url in self[WS]:
            # a changed html file only concerns the clients currently displaying that page
            if path and is_html and path not in {url, url + '.html', url + '/index.html'}:
                logger.debug('skipping reload for client at %s', url)
                continue
            reloads += 1
            logger.debug('reload client at %s', url)
            data = {
                'command': 'reload',
                'path': path or url,
                'liveCSS': True,
                'liveImg': True,
            }
            try:
                ws.send_str(json.dumps(data))
            except RuntimeError as e:
                # eg. "RuntimeError: websocket connection is closing"
                logger.error('Error broadcasting change to %s, RuntimeError: %s', path or url, e)

        if reloads:
            logger.info('prompted reload of %s on %d client%s', path or 'page', reloads, '' if reloads == 1 else 's')
        return cli_count

    async def cleanup(self):
        """
        Close all registered websockets, then run the standard application cleanup.
        """
        logger.debug('closing %d websockets...', len(self[WS]))
        coros = [ws.close() for ws, _ in self[WS]]
        await asyncio.gather(*coros, loop=self._loop)
        return await super().cleanup()


def create_auxiliary_app(*, static_path, port, static_url='/', livereload=True, loop=None):
    """
    Build the auxiliary application which serves static files and the livereload endpoints.

    :param static_path: directory to serve static files from, no static resource is registered if falsey
    :param port: port put into the livereload script tag appended to html files
    :param static_url: url prefix under which the static files are served
    :param livereload: whether to register the livereload.js and websocket routes
    :param loop: event loop passed to the application
    :return: AuxiliaryApplication instance
    """
    aux_app = AuxiliaryApplication(loop=loop)
    aux_app[WS] = []
    aux_app.update(static_path=static_path, static_url=static_url)

    tail_snippet = None
    if livereload:
        router = aux_app.router
        router.add_route('GET', '/livereload.js', livereload_js)
        router.add_route('GET', '/livereload', websocket_handler)
        tail_snippet = LIVE_RELOAD_SNIPPET % port
        logger.debug('enabling livereload on auxiliary app')

    if static_path:
        static_resource = CustomStaticResource(
            static_url,
            static_path + '/',
            name='static-router',
            tail_snippet=tail_snippet,
        )
        aux_app.router._reg_resource(static_resource)

    return aux_app


async def livereload_js(request):
    """
    Serve the livereload.js file which sits next to this module.

    Any request carrying an If-Modified-Since header gets a 304, the script content is cached on the
    app after it has been read once.
    """
    if request.if_modified_since:
        logger.debug('> %s %s %s 0B', request.method, request.path, 304)
        raise HTTPNotModified()

    cache_key = 'livereload_script'
    script = request.app.get(cache_key)
    if script is None:
        script_path = Path(__file__).absolute().parent / 'livereload.js'
        with script_path.open('rb') as script_file:
            script = script_file.read()
            request.app[cache_key] = script

    logger.debug('> %s %s %s %s', request.method, request.path, 200, fmt_size(len(script)))
    # fixed modification date, together with the check above it lets browsers cache the script
    response_headers = {LAST_MODIFIED: 'Fri, 01 Jan 2016 00:00:00 GMT'}
    return web.Response(body=script, content_type='application/javascript', headers=response_headers)

# maps the value of every websocket message type to its name, used when logging unexpected messages
WS_TYPE_LOOKUP = {member.value: name for name, member in WSMsgType.__members__.items()}


async def websocket_handler(request):
    """
    Handle a livereload websocket connection.

    Answers the "hello" handshake and, once the browser has sent its "info" message, registers
    the (websocket, url) pair on the app under the WS key so the client can be prompted to reload.
    The pair is removed again when the connection ends.

    :return: the WebSocketResponse
    """
    ws = web.WebSocketResponse(timeout=0.01)
    url = None
    await ws.prepare(request)

    async for msg in ws:
        if msg.tp == WSMsgType.TEXT:
            try:
                data = json.loads(msg.data)
            except json.JSONDecodeError as e:
                logger.error('JSON decode error: %s', str(e))
            else:
                command = data['command']
                if command == 'hello':
                    if 'http://livereload.com/protocols/official-7' not in data['protocols']:
                        logger.error('live reload protocol 7 not supported by client %s', msg.data)
                        # close() is a coroutine (it is awaited via gather in AuxiliaryApplication.cleanup),
                        # without the await the connection would never actually be closed
                        await ws.close()
                    else:
                        handshake = {
                            'command': 'hello',
                            'protocols': [
                                'http://livereload.com/protocols/official-7',
                            ],
                            'serverName': 'livereload-aiohttp',
                        }
                        ws.send_str(json.dumps(handshake))
                elif command == 'info':
                    logger.debug('browser connected: %s', data)
                    # keep only the path part of the url, eg. "http://localhost:8000/foo/bar" > "/foo/bar"
                    url = '/' + data['url'].split('/', 3)[-1]
                    request.app[WS].append((ws, url))
                else:
                    logger.error('Unknown ws message %s', msg.data)
        elif msg.tp == WSMsgType.ERROR:
            logger.error('ws connection closed with exception %s', ws.exception())
        else:
            logger.error('unknown websocket message type %s, data: %s', WS_TYPE_LOOKUP[msg.tp], msg.data)

    if url is None:
        logger.warning('browser disconnected, appears no websocket connection was made')
    else:
        logger.debug('browser disconnected')
        request.app[WS].remove((ws, url))
    return ws


class CustomFileSender(FileSender):
    """
    FileSender which appends ``tail_snippet`` to every html file it sends.
    """
    def __init__(self, *args, **kwargs):
        # tail_snippet is a required keyword argument, it is removed before the remaining
        # arguments are passed on to FileSender
        self.tail_snippet = kwargs.pop('tail_snippet')
        self.tail_snippet_len = len(self.tail_snippet)
        super().__init__(*args, **kwargs)

    async def send(self, request, filepath):
        """
        Send filepath to client using request.

        As with super except:
        * adds tail_snippet_length to content_length and writes tail_snippet to the tail of the response.
        * html files are never answered with 304, regardless of the If-Modified-Since header

        :param request: request being answered
        :param filepath: Path of the file to send
        :return: the prepared response
        :raises HTTPNotModified: for non html files not modified since the request's If-Modified-Since
        """

        ct, encoding = mimetypes.guess_type(str(filepath))
        if not ct:
            ct = 'application/octet-stream'
        is_html = ct == 'text/html'

        st = filepath.stat()
        modsince = request.if_modified_since
        if not is_html and modsince is not None and st.st_mtime <= modsince.timestamp():
            raise HTTPNotModified()

        resp = self._response_factory()
        resp.content_type = ct
        if encoding:
            resp.headers[CONTENT_ENCODING] = encoding
        resp.last_modified = st.st_mtime

        file_size = st.st_size
        # the snippet is only written after html files, so only those get the longer content length
        resp.content_length = (file_size + self.tail_snippet_len) if is_html else file_size
        try:
            with filepath.open('rb') as f:
                await self._sendfile_fallback(request, resp, f, file_size)
            if is_html:
                resp.write(self.tail_snippet)
                await resp.drain()
        finally:
            resp.set_tcp_nodelay(True)

        return resp


class CustomStaticResource(StaticResource):
    """
    Static resource which applies common path conventions, logs every request and, when a
    ``tail_snippet`` is given, sends files with a CustomFileSender so the snippet is appended to html files.
    """
    def __init__(self, *args, **kwargs):
        self._astet_path = None  # TODO
        # tail_snippet is a required keyword argument, removed before StaticResource sees the kwargs
        tail_snippet = kwargs.pop('tail_snippet')
        super().__init__(*args, **kwargs)
        self._show_index = True
        if tail_snippet:
            self._file_sender = CustomFileSender(resp_factory=self._file_sender._response_factory,
                                                 chunk_size=self._file_sender._chunk_size,
                                                 tail_snippet=tail_snippet)

    @staticmethod
    def _existing_path(path):
        """
        Return ``path`` resolved if it exists on disk, else None.

        Path.resolve() only raises FileNotFoundError on python 3.5, from 3.6 it is not strict by default
        and returns a path for missing files too, hence the explicit exists() check.
        """
        try:
            resolved = path.resolve()
        except FileNotFoundError:
            return None
        return resolved if resolved.exists() else None

    def modify_request(self, request):
        """
        Apply common path conventions eg. / > /index.html, /foobar > /foobar.html

        ``request.match_info['filename']`` is rewritten in place when a convention applies.
        """
        filename = request.match_info['filename']
        raw_path = self._directory.joinpath(filename)
        filepath = self._existing_path(raw_path)
        if filepath is None:
            # the requested file doesn't exist, try the same name with ".html" appended
            try:
                html_file = self._existing_path(raw_path.with_name(raw_path.name + '.html'))
                if html_file is not None:
                    request.match_info['filename'] = str(html_file.relative_to(self._directory))
            except ValueError:
                # html file is not relative to self._directory (or the name was invalid), leave the request alone
                pass
        elif filepath.is_dir():
            index_file = filepath / 'index.html'
            if index_file.exists():
                try:
                    request.match_info['filename'] = str(index_file.relative_to(self._directory))
                except ValueError:
                    # path is not relative to self._directory
                    pass

    async def _handle(self, request):
        """
        Serve the request as StaticResource does, but log the outcome and answer missing files
        with a plain text 404 response instead of raising HTTPNotFound.
        """
        self.modify_request(request)
        status, length = 'unknown', ''
        try:
            response = await super()._handle(request)
        except HTTPNotModified:
            status, length = 304, 0
            raise
        except HTTPNotFound:
            _404_msg = '404: Not Found\n\n' + _get_astet_content(self._astet_path)
            response = web.Response(body=_404_msg.encode(), status=404, content_type='text/plain')
            status, length = response.status, response.content_length
        else:
            status, length = response.status, response.content_length
        finally:
            # runs for every outcome, including the re-raised 304 and unexpected errors (status "unknown")
            log = logger.info if status in {200, 304} else logger.warning
            log('> %s %s %s %s', request.method, request.path, status, fmt_size(length))
        return response


def _get_astet_content(astet_path):
    if not astet_path:
        return ''
    with astet_path.open() as f:
        return 'astet file contents:\n\n{}'.format(f.read())


# attribute names tried, in this order, by find_attr when no attribute name is supplied
APP_FACTORY_NAMES = [
    'app',
    'app_factory',
    'get_app',
    'create_app',
]


def import_string(file_path, attr_name=None, _trying_again=False):
    """
    Import an attribute/class from a python module. Raise ImportError if the import failed.

    Approximately stolen from django.

    :param file_path: path to python module, has to be inside the current working directory
    :param attr_name: attribute to get from module
    :param _trying_again: internal flag, True for the second attempt which is made after the current
        working directory has been added to sys.path
    :return: (attribute, Path object for directory of file)
    """
    try:
        Path(file_path).resolve().relative_to(Path('.').resolve())
    except ValueError as e:
        raise ImportError('unable to import "%s" path is not relative '
                          'to the current working directory' % file_path) from e

    # only strip a trailing ".py", replacing every occurrence would mangle names which merely
    # contain it, eg. "pkg.pyramid_app" would become "pkgramid_app"
    module_path = file_path[:-3] if file_path.endswith('.py') else file_path
    module_path = module_path.replace('/', '.')

    try:
        module = import_module(module_path)
    except ImportError:
        if _trying_again:
            raise
        # add current working directory to pythonpath and try again
        p = os.getcwd()
        dft_logger.debug('adding current working director %s to pythonpath and reattempting import', p)
        sys.path.append(p)
        return import_string(file_path, attr_name, True)
    return find_attr(attr_name, module, module_path)


def find_attr(attr_name, module, module_path):
    """
    Get an attribute from an imported module.

    :param attr_name: name of the attribute, if None the first name from APP_FACTORY_NAMES which the
        module defines is used
    :param module: the imported module
    :param module_path: dotted path of the module, only used in error messages
    :return: (attribute, Path object for directory of the module's file)
    :raises ImportError: if no default name is found or the module doesn't define the attribute
    """
    if attr_name is None:
        try:
            attr_name = next(an for an in APP_FACTORY_NAMES if hasattr(module, an))
        except StopIteration as e:
            raise ImportError('No name supplied and no default app factory found in "%s"' % module_path) from e
        else:
            dft_logger.debug('found default attribute "%s" in module "%s"', attr_name, module)

    try:
        attr = getattr(module, attr_name)
    except AttributeError as e:
        raise ImportError('Module "%s" does not define a "%s" attribute/class' % (module_path, attr_name)) from e

    directory = Path(module.__file__).parent
    return attr, directory