contextlib.contextmanager

Here are examples of the Python API `contextlib.contextmanager` taken from open-source projects. By voting up, you can indicate which examples are the most useful and appropriate.

200 Examples 7

Example 151

Project: djangae
Source File: sandbox.py
View license
@contextlib.contextmanager
def activate(sandbox_name, add_sdk_to_path=False, new_env_vars=None, app_id=None, **overrides):
    """Context manager for command-line scripts started outside of dev_appserver.

    :param sandbox_name: str, one of 'local', 'remote' or 'test'
    :param add_sdk_to_path: bool, optionally adds the App Engine SDK to sys.path
    :param new_env_vars: dict or None, environment variables merged over the ones
        declared for the first module in app.yaml
    :param app_id: str or None, passed to ApplicationConfiguration when given
    :param overrides: values assigned onto the default dev_appserver options; every
        key must already be an attribute of those options
    :raises RuntimeError: for an unknown sandbox name, or when no recent enough
        App Engine SDK can be imported
    :raises ValueError: for an override that is not a dev_appserver option

    sys.path is restored to its value at entry when the context exits. This also
    happens when setup fails before the body runs.

    Available sandboxes:

      local: Adds libraries specified in app.yaml to the path and initializes local service stubs as though
             dev_appserver were running.

      remote: Adds libraries specified in app.yaml to the path and initializes remote service stubs.

      test: Adds libraries specified in app.yaml to the path and sets up no service stubs. Use this
            with `google.appengine.ext.testbed` to provide isolation for tests.

    Example usage:

        import djangae.sandbox as sandbox

        with sandbox.activate('local'):
            from django.core.management import execute_from_command_line
            execute_from_command_line(sys.argv)

    """
    if sandbox_name not in SANDBOXES:
        raise RuntimeError('Unknown sandbox "{}"'.format(sandbox_name))

    # Module-level state written below so it can be accessed later.
    global _OPTIONS
    global _CONFIG

    project_root = environment.get_application_root()

    # Store our original sys.path before we do anything, this must be tacked
    # onto the end of the other paths so we can access globally installed things (e.g. ipdb etc.)
    original_path = sys.path[:]

    # Everything below may rewrite sys.path. Restore it on every exit path,
    # including setup failures (bad override, missing app.yaml, import errors).
    try:
        # Setup paths as though we were running dev_appserver. This is similar to
        # what the App Engine script wrappers do.
        if add_sdk_to_path:
            try:
                import wrapper_util  # Already on sys.path
            except ImportError:
                sys.path[0:0] = [_find_sdk_from_path()]
                import wrapper_util
        else:
            try:
                import wrapper_util
            except ImportError:
                raise RuntimeError("Couldn't find a recent enough Google App Engine SDK, make sure you are using at least 1.9.6")

        sdk_path = _find_sdk_from_python_path()
        _PATHS = wrapper_util.Paths(sdk_path)

        # We need to look at the original path, and make sure that any paths
        # which are under the project root are first, then any other paths
        # are added after the SDK ones.
        # Compare whole path components: a plain string prefix would treat
        # "/srv/app2" as living inside "/srv/app".
        root_prefix = os.path.join(project_root, '')
        project_paths = []  # Paths under the application root
        system_paths = []  # All other paths
        for path in _PATHS.scrub_path(_SCRIPT_NAME, original_path):
            if path == project_root or path.startswith(root_prefix):
                project_paths.append(path)
            else:
                system_paths.append(path)

        # We build a list of SDK paths, and add any additional ones required for
        # the oauth client
        appengine_paths = _PATHS.script_paths(_SCRIPT_NAME)
        for path in _PATHS.oauth_client_extra_paths:
            if path not in appengine_paths:
                appengine_paths.append(path)

        # Now, we make sure that paths within the project take precedence, followed
        # by the SDK, then finally any paths from the system Python (for stuff like
        # ipdb etc.)
        sys.path = (
            project_paths +
            appengine_paths +
            system_paths
        )

        # Gotta set the runtime properly otherwise it changes appengine imports, like webapp
        # when you are not running dev_appserver
        import yaml
        with open(os.path.join(project_root, 'app.yaml'), 'r') as app_yaml_file:
            # safe_load: app.yaml needs no arbitrary object construction, and
            # yaml.load() without an explicit Loader is deprecated (PyYAML 5.1)
            # and rejected (PyYAML 6).
            app_yaml = yaml.safe_load(app_yaml_file)
        os.environ['APPENGINE_RUNTIME'] = app_yaml.get('runtime', '')

        # Initialize as though `dev_appserver.py` is about to run our app, using all the
        # configuration provided in app.yaml.
        import google.appengine.tools.devappserver2.application_configuration as application_configuration
        import google.appengine.tools.devappserver2.python.sandbox as sandbox
        import google.appengine.tools.devappserver2.devappserver2 as devappserver2
        import google.appengine.tools.devappserver2.wsgi_request_info as wsgi_request_info
        import google.appengine.ext.remote_api.remote_api_stub as remote_api_stub
        import google.appengine.api.apiproxy_stub_map as apiproxy_stub_map

        # The argparser is the easiest way to get the default options.
        options = devappserver2.PARSER.parse_args([project_root])
        options.enable_task_running = False  # Disable task running by default, it won't work without a running server
        options.skip_sdk_update_check = True

        for option in overrides:
            if not hasattr(options, option):
                raise ValueError("Unrecognized sandbox option: {}".format(option))

            setattr(options, option, overrides[option])

        if app_id:
            configuration = application_configuration.ApplicationConfiguration(options.config_paths, app_id=app_id)
        else:
            configuration = application_configuration.ApplicationConfiguration(options.config_paths)

        # Enable built-in libraries from app.yaml without enabling the full sandbox.
        module = configuration.modules[0]
        for lib_path in sandbox._enable_libraries(module.normalized_libraries):
            sys.path.insert(1, lib_path)

        # Propagate provided environment variables to the sandbox.
        # This is required for the runserver management command settings flag,
        # which sets an environment variable needed by Django.
        from google.appengine.api.appinfo import EnvironmentVariables
        old_env_vars = module.env_variables or {}
        if new_env_vars is None:
            new_env_vars = {}
        module._app_info_external.env_variables = EnvironmentVariables.Merge(
            old_env_vars,
            new_env_vars,
        )

        _CONFIG = configuration
        _OPTIONS = options  # Store the options globally so they can be accessed later
        kwargs = dict(
            devappserver2=devappserver2,
            configuration=configuration,
            options=options,
            wsgi_request_info=wsgi_request_info,
            remote_api_stub=remote_api_stub,
            apiproxy_stub_map=apiproxy_stub_map,
        )
        with SANDBOXES[sandbox_name](**kwargs):
            yield

    finally:
        sys.path = original_path

Example 152

Project: pymo
Source File: test_support.py
View license
@contextlib.contextmanager
def transient_internet(resource_name, timeout=30.0, errnos=()):
    """Context manager that converts flaky-network IOErrors into ResourceDenied.

    While active, the default socket timeout is set to ``timeout`` (unless it is
    None). It is put back afterwards. An IOError escaping the block is unwrapped
    through its ``args`` and classified by its innermost IOError:

    * a ``socket.timeout``,
    * a ``socket.gaierror`` whose errno is a known resolver failure, or
    * an error whose errno is in ``errnos`` (or a default set of
      connection errnos when ``errnos`` is empty).

    Such errors are replaced by ResourceDenied. If ``verbose`` is false, the
    message is also written to stderr. Any other IOError propagates unchanged.
    """
    default_errnos = [
        ('ECONNREFUSED', 111),
        ('ECONNRESET', 104),
        ('EHOSTUNREACH', 113),
        ('ENETUNREACH', 101),
        ('ETIMEDOUT', 110),
    ]
    default_gai_errnos = [
        ('EAI_NONAME', -2),
        ('EAI_NODATA', -5),
    ]

    denied = ResourceDenied("Resource '%s' is not available" % resource_name)

    if errnos:
        # Caller-supplied errnos replace the defaults; no resolver codes then.
        captured_errnos = errnos
        gai_errnos = []
    else:
        # Prefer the platform's symbolic values, falling back to Linux numbers.
        captured_errnos = [getattr(errno, symbol, fallback)
                           for symbol, fallback in default_errnos]
        gai_errnos = [getattr(socket, symbol, fallback)
                      for symbol, fallback in default_gai_errnos]

    def innermost(exc):
        # urllib may wrap the real socket error several times, either as
        # args[0] or (e.g. IOError('socket error', msg)) as args[1].
        while True:
            wrapped = exc.args
            if wrapped and isinstance(wrapped[0], IOError):
                exc = wrapped[0]
            elif len(wrapped) > 1 and isinstance(wrapped[1], IOError):
                exc = wrapped[1]
            else:
                return exc

    def is_transient(exc):
        code = getattr(exc, 'errno', None)
        if isinstance(exc, socket.timeout):
            return True
        if isinstance(exc, socket.gaierror) and code in gai_errnos:
            return True
        return code in captured_errnos

    previous_timeout = socket.getdefaulttimeout()
    try:
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    except IOError as exc:
        if is_transient(innermost(exc)):
            if not verbose:
                sys.stderr.write(denied.args[0] + "\n")
            raise denied
        raise
    # XXX should we catch generic exceptions and look for their
    # __cause__ or __context__?
    finally:
        socket.setdefaulttimeout(previous_timeout)

Example 153

Project: iris
Source File: pp.py
View license
    @contextlib.contextmanager
    def cube_save_test(self, reference_txt_path, reference_cubes=None, reference_pp_path=None, **kwargs):
        """
        Context manager that checks PP data saved inside the block against a
        textual reference file.

        Args:

        * reference_txt_path:
            Path of the textual PP reference. If the file does not exist, it is
            created first from ``reference_cubes`` or ``reference_pp_path``.

        Kwargs:

        * reference_cubes:
            Cube(s) used to rebuild a missing textual reference. They are saved
            to a temporary PP file, with ``**kwargs`` forwarded to ``iris.save``.
        * reference_pp_path:
            PP file used to rebuild a missing textual reference.
            Ignored when ``reference_cubes`` is given.

        Yields the name of a temporary PP file that the block should save into.
        When the block finishes without raising, the fields in that file are
        loaded and compared with the reference text. The temporary file is
        removed in all cases.

        Raises ValueError when the reference text is missing and neither
        ``reference_cubes`` nor ``reference_pp_path`` is supplied.

        Example::
            with self.cube_save_test(reference_txt_path, reference_cubes=cubes) as temp_pp_path:
                iris.save(cubes, temp_pp_path)

        """
        if not os.path.isfile(reference_txt_path):
            tests.logger.warning('Creating result file: %s', reference_txt_path)
            if not reference_cubes and not reference_pp_path:
                raise ValueError('Missing all of reference txt file, cubes, and PP path.')
            if reference_cubes:
                # Round-trip the cubes through a scratch PP file to build the text.
                scratch_path = iris.util.create_temp_filename(".pp")
                try:
                    iris.save(reference_cubes, scratch_path, **kwargs)
                    self._create_reference_txt(reference_txt_path, scratch_path)
                finally:
                    os.remove(scratch_path)
            else:
                self._create_reference_txt(reference_txt_path, reference_pp_path)

        output_path = iris.util.create_temp_filename(".pp")
        try:
            yield output_path

            loaded_fields = list(iris.fileformats.pp.load(output_path))
            for field in loaded_fields:
                field.data  # attribute access forces the deferred data to load
            with open(reference_txt_path, 'r') as reference_file:
                expected_text = reference_file.read()
            self._assert_str_same(expected_text + '\n', str(loaded_fields) + '\n',
                                  reference_txt_path, type_comparison_name='PP files')
        finally:
            os.remove(output_path)

Example 154

Project: maya-capture
Source File: capture.py
View license
@contextlib.contextmanager
def _independent_panel(width, height, off_screen=False):
    """Create capture-window context without decorations

    Arguments:
        width (int): Width of panel
        height (int): Height of panel
        off_screen (bool, optional): Create the window hidden and never show
            it. Defaults to False.

    Yields:
        The model panel created inside a temporary, centred, undecorated
        window. The panel and the window are deleted when the context exits,
        including when setting them up fails.

    Example:
        >>> with _independent_panel(800, 600):
        ...   cmds.capture()

    """

    # center panel on screen
    screen_width, screen_height = _get_screen_size()
    topLeft = [int((screen_height-height)/2.0),
               int((screen_width-width)/2.0)]

    window = cmds.window(width=width,
                         height=height,
                         topLeftCorner=topLeft,
                         menuBarVisible=False,
                         titleBar=False,
                         visible=not off_screen)
    panel = None
    # The window exists from here on. Guard all remaining setup so a failure
    # in any of it still deletes the window instead of leaking it.
    try:
        cmds.paneLayout()
        panel = cmds.modelPanel(menuBarVisible=False,
                                label='CapturePanel')

        # Hide icons under panel menus
        bar_layout = cmds.modelPanel(panel, q=True, barLayout=True)
        cmds.frameLayout(bar_layout, edit=True, collapse=True)

        if not off_screen:
            cmds.showWindow(window)

        # Set the modelEditor of the modelPanel as the active view so it takes
        # the playback focus. Does seem redundant with the `refresh` added in.
        editor = cmds.modelPanel(panel, query=True, modelEditor=True)
        cmds.modelEditor(editor, edit=True, activeView=True)

        # Force a draw refresh of Maya so it keeps focus on the new panel
        # This focus is required to force preview playback in the independent panel
        cmds.refresh(force=True)

        yield panel
    finally:
        # Delete the panel to fix memory leak (about 5 mb per capture)
        if panel is not None:
            cmds.deleteUI(panel, panel=True)
        cmds.deleteUI(window)

Example 155

Project: pyvisa
Source File: rname.py
View license
def filter2(resources, query, open_resource):
    """Filter a list of resources according to a query expression.

    It accepts the optional part of the expression: a ``{...}`` suffix holding
    a boolean expression over ``VI_ATTR_*`` attributes. In that expression,
    ``&&``, ``||`` and ``!`` are rewritten to ``and``, ``or`` and ``not``.

    .. warning: This function is experimental and unsafe as it uses eval,
                It also might require to open the resource.

    :param resources: iterable of resources.
    :param query: query expression.
    :param open_resource: function to open the resource.
    :return: the result of ``filter(resources, query)`` when there is no
             optional part. Otherwise, a list of the matching resource names
             for which the optional expression is true.
    :raises errors.VisaIOError: (VI_ERROR_INV_EXPR) for malformed braces.
    """

    if '{' in query:
        try:
            query, optional = query.split('{')
            optional, _ = optional.split('}')
        except ValueError:
            raise errors.VisaIOError(constants.VI_ERROR_INV_EXPR)
    else:
        optional = None

    filtered = filter(resources, query)

    if not optional:
        return filtered

    # Translate VISA boolean syntax into Python, and route every attribute
    # name through the ``res`` object handed to eval below.
    optional = optional.replace('&&', 'and').replace('||', 'or').replace('!', 'not ')
    optional = optional.replace('VI_', 'res.VI_')

    # Attributes that can be answered from the parsed resource name alone,
    # without opening the resource.
    parsed_attribute_readers = {
        'VI_ATTR_INTF_NUM': lambda parsed: int(parsed.board),
        'VI_ATTR_MANF_ID': lambda parsed: parsed.manufacturer_id,
        'VI_ATTR_MODEL_CODE': lambda parsed: parsed.model_code,
        'VI_ATTR_USB_SERIAL_NUM': lambda parsed: parsed.serial_number,
        'VI_ATTR_USB_INTFC_NUM': lambda parsed: int(parsed.board),
        'VI_ATTR_TCPIP_ADDR': lambda parsed: parsed.host_address,
        'VI_ATTR_TCPIP_DEVICE_NAME': lambda parsed: parsed.lan_device_name,
        'VI_ATTR_TCPIP_PORT': lambda parsed: int(parsed.port),
        'VI_ATTR_GPIB_PRIMARY_ADDR': lambda parsed: int(parsed.primary_address),
        'VI_ATTR_GPIB_SECONDARY_ADDR': lambda parsed: int(parsed.secondary_address),
        'VI_ATTR_PXI_CHASSIS': lambda parsed: parsed.chassis_number,
        'VI_ATTR_MAINFRAME_LA': lambda parsed: parsed.vxi_logical_address,
    }

    class AttrGetter():
        """Resolves ``res.VI_ATTR_*`` lookups for a single resource name."""

        def __init__(self, resource_name):
            self.resource_name = resource_name
            self.parsed = parse_resource_name(resource_name)
            self.resource = None  # opened lazily, only if needed

        def __getattr__(self, item):
            reader = parsed_attribute_readers.get(item)
            if reader is not None:
                return reader(self.parsed)

            # Anything else needs the live resource; open it once, on demand.
            if self.resource is None:
                self.resource = open_resource(self.resource_name)

            return self.resource.get_visa_attribute(item)

    @contextlib.contextmanager
    def open_close(resource_name):
        getter = AttrGetter(resource_name)
        try:
            yield getter
        finally:
            # Close a resource the getter opened, even if evaluation raised.
            if getter.resource is not None:
                getter.resource.close()

    selected = []
    for rn in filtered:
        with open_close(rn) as getter:
            # SECURITY: eval of caller-supplied text (see warning above);
            # never pass an untrusted query here.
            if eval(optional, None, dict(res=getter)):
                selected.append(rn)

    return selected

Example 156

Project: kozmic-ci
Source File: tasks.py
View license
@contextlib.contextmanager
def _run(publisher, stall_timeout, clone_url, commit_sha,
         docker_image, script, deploy_key=None, remove_container=True):
    """Run `script` for `commit_sha` of `clone_url` in `docker_image` and yield the outcome.

    NOTE: Python 2 only (`Queue` module, three-argument `raise` statement).

    A `Builder` (started with `.start()` and awaited with `.join()`, so
    presumably a thread) receives the build parameters and a temporary working
    directory. Once it reports its container on a queue, a `Tailer` is started
    for that container. Presumably the Tailer streams `script.log` to
    `publisher` and kills the container after `stall_timeout`; TODO confirm
    against Tailer.

    Yields:
        On success, one `(return_code, output, container)` tuple. `output` is
        the contents of `script.log` (if that file exists), followed by a notice
        when the tailer killed a stalled container. When `remove_container` is
        true, the container has already been removed by the time it is yielded.

    On any exception (including one re-raised from the builder), an apology
    is appended to the collected output. If nothing has been yielded yet,
    `(1, output, None)` is yielded so the caller's block still runs. The
    exception is then re-raised.
    """
    yielded = False  # becomes True once the caller's block has been entered
    stdout = ''
    try:
        with create_temp_dir() as working_dir:
            # Handshake with Builder: it hands back its container through this
            # queue, and task_done() below is how we tell it to continue.
            message_queue = Queue.Queue()
            builder = Builder(
                docker=docker._get_current_object(),  # `docker` is a local proxy
                deploy_key=deploy_key,
                clone_url=clone_url,
                commit_sha=commit_sha,
                docker_image=docker_image,
                script=script,
                working_dir=working_dir,
                message_queue=message_queue)

            log_path = os.path.join(working_dir, 'script.log')
            stop_reason = ''
            try:
                # Start Builder and wait (at most 60 seconds) until it has created the container
                builder.start()
                container = message_queue.get(block=True, timeout=60)

                # Now the container id is known and we can pass it to Tailer
                tailer = Tailer(
                    log_path=log_path,
                    publisher=publisher,
                    container=container,
                    kill_timeout=stall_timeout)
                tailer.start()
                try:
                    # Tell Builder to continue and wait for it to finish
                    message_queue.task_done()
                    builder.join()
                finally:
                    tailer.stop()
                    if tailer.has_killed_container:
                        stop_reason = '\nSorry, your script has stalled and been killed.\n'
            finally:
                # Cleanup and result collection run before anything is yielded.
                if builder.container and remove_container:
                    docker.remove_container(builder.container)

                if os.path.exists(log_path):
                    with open(log_path, 'r') as log:
                        stdout = log.read()

                # Exactly one of return_code / exc_info must be set by Builder.
                assert ((builder.return_code is not None) ^
                        (builder.exc_info is not None))
                if builder.exc_info:
                    # Re-raise exception happened in builder
                    # (it will be caught in the outer try-except)
                    raise builder.exc_info[1], None, builder.exc_info[2]
                else:
                    try:
                        yield (builder.return_code,
                               stdout + stop_reason,
                               builder.container)
                    # Bare re-raise: exists only so that the finally below
                    # marks the result as delivered.
                    except:
                        raise
                    finally:
                        yielded = True  # otherwise we get "generator didn't
                                        # stop after throw()" error if nested
                                        # code raised exception
    except:
        stdout += ('\nSorry, something went wrong. We are notified of '
                   'the issue and will fix it soon.')
        # Give the caller a failure result only if its block has not run yet.
        if not yielded:
            yield 1, stdout, None
        raise

Example 157

Project: deep_recommend_system
Source File: test_util.py
View license
  @contextlib.contextmanager
  def test_session(self,
                   graph=None,
                   config=None,
                   use_gpu=False,
                   force_gpu=False):
    """Returns a TensorFlow Session for use in executing tests.

    This method should be used for all functional tests.

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/gpu:0`. Otherwise, if `use_gpu`
    is True, TensorFlow tries to run as many ops on the GPU as possible. If both
    `force_gpu` and `use_gpu` are False, all ops are pinned to the CPU.

    When `graph` is None, a session cached on the test case is reused, and
    `config` is only consulted when that session is first created. Otherwise a
    new Session is created for `graph` and used as a context manager around
    the block.

    Example:

      class MyOperatorTest(test_util.TensorFlowTestCase):
        def testMyOperator(self):
          with self.test_session(use_gpu=True):
            valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
            result = MyOperator(valid_input).eval()
            self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
            invalid_input = [-1.0, 2.0, 7.0]
            with self.assertRaisesOpError("negative input not supported"):
              MyOperator(invalid_input).eval()

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session. Its optimizer opt_level is set to -1 in place. The exception
        is when `force_gpu` is True and the config allows soft placement: then
        a modified copy is used and the caller's config is left untouched.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/gpu:0`.

    Returns:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

    def prepare_config(config):
      """Returns the ConfigProto for a test session, as described under Args."""
      if config is None:
        config = config_pb2.ConfigProto()
        config.allow_soft_placement = not force_gpu
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
      elif force_gpu and config.allow_soft_placement:
        # Work on a copy so the caller's config keeps soft placement.
        # CopyFrom() fills its receiver and returns None, so it cannot be
        # chained onto the constructor.
        config_copy = config_pb2.ConfigProto()
        config_copy.CopyFrom(config)
        config = config_copy
        config.allow_soft_placement = False
      # Don't perform optimizations for tests so we don't inadvertently run
      # gpu ops on cpu
      config.graph_options.optimizer_options.opt_level = -1
      return config

    if graph is None:
      if self._cached_session is None:
        self._cached_session = session.Session(graph=None,
                                               config=prepare_config(config))
      sess = self._cached_session
      with sess.graph.as_default(), sess.as_default():
        if force_gpu:
          with sess.graph.device("/gpu:0"):
            yield sess
        elif use_gpu:
          yield sess
        else:
          with sess.graph.device("/cpu:0"):
            yield sess
    else:
      with session.Session(graph=graph, config=prepare_config(config)) as sess:
        if force_gpu:
          with sess.graph.device("/gpu:0"):
            yield sess
        elif use_gpu:
          yield sess
        else:
          with sess.graph.device("/cpu:0"):
            yield sess

Example 158

Project: djangocms-helper
Source File: main.py
View license
def core(args, application):
    """Prepare Django for ``application`` and run the requested helper command.

    ``args`` is the parsed command-line mapping (``--option``, ``<positional>``
    and sub-command keys).

    Static and media directories are created as follows:

    * with ``persistent_dir``, under the ``--persistent`` path (or ``data``),
      when ``--persistent`` is set;
    * otherwise with ``temp_dir`` under ``/dev/shm``.

    Their paths are written back into ``args``.

    If ``<command>`` is given, the remaining options are forwarded to Django's
    ``execute_from_command_line``. Otherwise the first matching sub-command
    runs: ``test`` exits the process with the failure count, and ``pyflakes``,
    ``authors`` and ``setup`` return their helper's result.
    """
    from django.conf import settings

    # Escalate Django's naive-datetime RuntimeWarning from model fields to an error.
    warnings.filterwarnings(
        'error', r'DateTimeField received a naive datetime',
        RuntimeWarning, r'django\.db\.models\.fields')

    persistent = args['--persistent']
    if not persistent:
        create_dir = temp_dir
        parent_path = '/dev/shm'
    else:
        create_dir = persistent_dir
        # A bare ``--persistent`` flag is True; otherwise it carries a path.
        parent_path = 'data' if persistent is True else persistent

    with create_dir('static', parent_path) as static_root, \
            create_dir('media', parent_path) as media_root:
        args['MEDIA_ROOT'] = media_root
        args['STATIC_ROOT'] = static_root
        if args['cms_check']:
            args['--cms'] = True

        if args['<command>']:
            # Hand everything except helper-only flags over to Django.
            from django.core.management import execute_from_command_line
            django_argv = [
                opt for opt in args['options']
                if opt != '--cms' and '--extra-settings' not in opt
            ]
            _make_settings(args, application, settings, static_root, media_root)
            execute_from_command_line(django_argv)
            return None

        _make_settings(args, application, settings, static_root, media_root)

        def pick_runner():
            # The first explicitly requested runner wins; DiscoverRunner otherwise.
            if args['--nose-runner']:
                return 'django_nose.NoseTestSuiteRunner'
            if args['--simple-runner']:
                return 'django.test.simple.DjangoTestSuiteRunner'
            return args['--runner'] or 'django.test.runner.DiscoverRunner'

        @contextlib.contextmanager
        def no_display():
            yield

        if args['test']:
            runner = pick_runner()

            # make "Address already in use" errors less likely, see Django
            # docs for more details on this env variable.
            os.environ.setdefault(
                'DJANGO_LIVE_TEST_SERVER_ADDRESS',
                'localhost:8000-9000'
            )
            if args['--xvfb']:  # pragma: no cover
                import xvfbwrapper

                display = xvfbwrapper.Xvfb(width=1280, height=720)
            else:
                display = no_display()

            with display:
                failures = test(args['<test-label>'], application,
                                args['--failfast'], runner,
                                args['--runner-options'], args.get('--verbose', 1))
                sys.exit(failures)
        elif args['server']:
            server(
                args['--bind'], args['--port'], args.get('--migrate', True),
                args.get('--verbose', 1)
            )
        elif args['cms_check']:
            cms_check(args.get('--migrate', True))
        elif args['compilemessages']:
            compilemessages(application)
        elif args['makemessages']:
            makemessages(application, locale=args['--locale'])
        elif args['makemigrations']:
            makemigrations(application, merge=args['--merge'], dry_run=args['--dry-run'],
                           empty=args['--empty'],
                           extra_applications=args['<extra-applications>'])
        elif args['pyflakes']:
            return static_analisys(application)
        elif args['authors']:
            return generate_authors()
        elif args['setup']:
            return setup_env(settings)

Example 159

Project: django
Source File: schema.py
View license
    def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
        """
        Shortcut to transform a model from old_model into new_model

        The essential steps are:
          1. rename the model's existing table, e.g. "app_model" to "app_model__old"
          2. create a table with the updated definition called "app_model"
          3. copy the data from the old renamed table to the new table
          4. delete the "app_model__old" table

        Args:
            model: the model class whose table is rebuilt.
            create_field: optional field being added. For a concrete,
                non-M2M field, existing rows receive its effective default.
            delete_field: optional field being removed. If it is an M2M field
                with an auto-created through model, that through model's table
                is deleted and the method returns early.
            alter_field: optional ``(old_field, new_field)`` pair. When the
                field becomes NOT NULL, existing NULLs are replaced by the new
                field's effective default during the copy.
        """
        # Self-referential fields must be recreated rather than copied from
        # the old model to ensure their remote_field.field_name doesn't refer
        # to an altered field.
        def is_self_referential(f):
            return f.is_relation and f.remote_field.model is model
        # Work out the new fields dict / mapping
        body = {
            f.name: f.clone() if is_self_referential(f) else f
            for f in model._meta.local_concrete_fields
        }
        # Since mapping might mix column names and default values,
        # its values must be already quoted.
        mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
        # This maps field names (not columns) for things like unique_together
        rename_mapping = {}
        # If any of the new or altered fields is introducing a new PK,
        # remove the old one
        restore_pk_field = None
        if getattr(create_field, 'primary_key', False) or (
                alter_field and getattr(alter_field[1], 'primary_key', False)):
            for name, field in list(body.items()):
                if field.primary_key:
                    field.primary_key = False
                    restore_pk_field = field
                    if field.auto_created:
                        del body[name]
                        del mapping[field.column]
        # Add in any created fields
        if create_field:
            body[create_field.name] = create_field
            # Choose a default and insert it into the copy map
            if not create_field.many_to_many and create_field.concrete:
                mapping[create_field.column] = self.quote_value(
                    self.effective_default(create_field)
                )
        # Add in any altered fields
        if alter_field:
            old_field, new_field = alter_field
            body.pop(old_field.name, None)
            mapping.pop(old_field.column, None)
            body[new_field.name] = new_field
            if old_field.null and not new_field.null:
                case_sql = "coalesce(%(col)s, %(default)s)" % {
                    'col': self.quote_name(old_field.column),
                    'default': self.quote_value(self.effective_default(new_field))
                }
                mapping[new_field.column] = case_sql
            else:
                mapping[new_field.column] = self.quote_name(old_field.column)
            rename_mapping[old_field.name] = new_field.name
        # Remove any deleted fields
        if delete_field:
            del body[delete_field.name]
            del mapping[delete_field.column]
            # Remove any implicit M2M tables
            if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created:
                return self.delete_model(delete_field.remote_field.through)
        # Work inside a new app registry
        apps = Apps()

        # Provide isolated instances of the fields to the new model body so
        # that the existing model's internals aren't interfered with when
        # the dummy model is constructed.
        body = copy.deepcopy(body)

        # Work out the new value of unique_together, taking renames into
        # account
        unique_together = [
            [rename_mapping.get(n, n) for n in unique]
            for unique in model._meta.unique_together
        ]

        # Work out the new value for index_together, taking renames into
        # account
        index_together = [
            [rename_mapping.get(n, n) for n in index]
            for index in model._meta.index_together
        ]

        indexes = model._meta.indexes
        if delete_field:
            indexes = [
                index for index in indexes
                if delete_field.name not in index.fields
            ]

        # Construct a new model for the new state
        meta_contents = {
            'app_label': model._meta.app_label,
            'db_table': model._meta.db_table,
            'unique_together': unique_together,
            'index_together': index_together,
            'indexes': indexes,
            'apps': apps,
        }
        meta = type("Meta", tuple(), meta_contents)
        body['Meta'] = meta
        body['__module__'] = model.__module__

        temp_model = type(model._meta.object_name, model.__bases__, body)

        # We need to modify model._meta.db_table, but everything explodes
        # if the change isn't reversed before the end of this method. This
        # context manager helps us avoid that situation.
        @contextlib.contextmanager
        def altered_table_name(model, temporary_table_name):
            original_table_name = model._meta.db_table
            model._meta.db_table = temporary_table_name
            try:
                yield
            finally:
                # Restore even when a statement in the body fails; otherwise
                # the model class keeps pointing at the "__old" table name.
                model._meta.db_table = original_table_name

        with altered_table_name(model, model._meta.db_table + "__old"):
            # Rename the old table to make way for the new
            self.alter_db_table(model, temp_model._meta.db_table, model._meta.db_table)

            # Create a new table with the updated schema. We remove things
            # from the deferred SQL that match our table name, too
            self.deferred_sql = [x for x in self.deferred_sql if temp_model._meta.db_table not in x]
            self.create_model(temp_model)

            # Copy data from the old table into the new table
            field_maps = list(mapping.items())
            self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
                self.quote_name(temp_model._meta.db_table),
                ', '.join(self.quote_name(x) for x, y in field_maps),
                ', '.join(y for x, y in field_maps),
                self.quote_name(model._meta.db_table),
            ))

            # Delete the old table
            self.delete_model(model, handle_autom2m=False)

        # Run deferred SQL on correct table
        for sql in self.deferred_sql:
            self.execute(sql)
        self.deferred_sql = []
        # Fix any PK-removed field
        if restore_pk_field:
            restore_pk_field.primary_key = True

Example 160

Project: imagrium
Source File: test_support.py
View license
@contextlib.contextmanager
def transient_internet(resource_name, timeout=30.0, errnos=()):
    """Return a context manager that raises ResourceDenied when various issues
    with the Internet connection manifest themselves as exceptions.

    While the block runs, the process-wide default socket timeout is set to
    *timeout* (left alone when *timeout* is None); the previous default is
    restored on exit.  An IOError escaping the block is unwrapped first --
    urllib may nest the real socket error in ``args[0]`` or ``args[1]``,
    possibly several levels deep -- and ResourceDenied is raised instead when
    the innermost error is a socket timeout, a getaddrinfo error with a known
    errno, or carries one of the transient errnos.  Other IOErrors are
    re-raised unchanged.

    :param resource_name: name used in the ResourceDenied message.
    :param timeout: default socket timeout inside the block, or None.
    :param errnos: errno values treated as transient.  When empty, a built-in
        set of connection errnos plus a separate getaddrinfo errno set is
        used; when given, the getaddrinfo set is left empty.
    """
    # (symbolic name, fallback number) pairs; the platform's own value is
    # used whenever the errno/socket module defines the name.
    conn_errno_names = [
        ('ECONNREFUSED', 111),
        ('ECONNRESET', 104),
        ('EHOSTUNREACH', 113),
        ('ENETUNREACH', 101),
        ('ETIMEDOUT', 110),
    ]
    gai_errno_names = [
        ('EAI_AGAIN', -3),
        ('EAI_FAIL', -4),
        ('EAI_NONAME', -2),
        ('EAI_NODATA', -5),
    ]

    denied = ResourceDenied("Resource '%s' is not available" % resource_name)
    if errnos:
        captured_errnos = errnos
        gai_errnos = []
    else:
        captured_errnos = [getattr(errno, sym, fallback)
                           for sym, fallback in conn_errno_names]
        gai_errnos = [getattr(socket, sym, fallback)
                      for sym, fallback in gai_errno_names]

    def innermost(exc):
        # Follow IOError-inside-args wrapping down to the original error.
        while True:
            inner_args = exc.args
            if inner_args and isinstance(inner_args[0], IOError):
                exc = inner_args[0]
            elif len(inner_args) > 1 and isinstance(inner_args[1], IOError):
                exc = inner_args[1]
            else:
                return exc

    def looks_transient(exc):
        code = getattr(exc, 'errno', None)
        return (isinstance(exc, socket.timeout)
                or (isinstance(exc, socket.gaierror) and code in gai_errnos)
                or code in captured_errnos)

    saved_timeout = socket.getdefaulttimeout()
    try:
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    except IOError as exc:
        if looks_transient(innermost(exc)):
            if not verbose:
                sys.stderr.write(denied.args[0] + "\n")
            raise denied
        raise
    # XXX should we catch generic exceptions and look for their
    # __cause__ or __context__?
    finally:
        socket.setdefaulttimeout(saved_timeout)

Example 161

Project: hitch
Source File: testing.py
View license
    @contextlib.contextmanager
    def isolation(self, input=None, env=None, color=False):
        """A context manager that sets up the isolation for invoking of a
        command line tool.  This sets up stdin with the given input data
        and `os.environ` with the overrides from the given dictionary.
        This also rebinds some internals in Click to be mocked (like the
        prompt functionality).

        This is automatically done in the :meth:`invoke` method.

        .. versionadded:: 4.0
           The ``color`` parameter was added.

        :param input: the input stream to put into sys.stdin.
        :param env: the environment overrides as dictionary.  A value of
                    ``None`` removes the variable for the duration of the
                    block; every overridden variable is put back on exit.
        :param color: whether the output should contain color codes. The
                      application can still override this explicitly.
        """
        input = make_input_stream(input, self.charset)

        old_stdin = sys.stdin
        old_stdout = sys.stdout
        old_stderr = sys.stderr

        env = self.make_env(env)

        if PY2:
            sys.stdout = sys.stderr = bytes_output = StringIO()
            if self.echo_stdin:
                input = EchoingStdin(input, bytes_output)
        else:
            bytes_output = io.BytesIO()
            if self.echo_stdin:
                input = EchoingStdin(input, bytes_output)
            input = io.TextIOWrapper(input, encoding=self.charset)
            sys.stdout = sys.stderr = io.TextIOWrapper(
                bytes_output, encoding=self.charset)

        sys.stdin = input

        def visible_input(prompt=None):
            sys.stdout.write(prompt or '')
            val = input.readline().rstrip('\r\n')
            sys.stdout.write(val + '\n')
            sys.stdout.flush()
            return val

        def hidden_input(prompt=None):
            sys.stdout.write((prompt or '') + '\n')
            sys.stdout.flush()
            return input.readline().rstrip('\r\n')

        def _getchar(echo):
            char = sys.stdin.read(1)
            if echo:
                sys.stdout.write(char)
                sys.stdout.flush()
            return char

        default_color = color
        def should_strip_ansi(stream=None, color=None):
            if color is None:
                return not default_color
            return not color

        old_visible_prompt_func = clickpkg.termui.visible_prompt_func
        old_hidden_prompt_func = clickpkg.termui.hidden_prompt_func
        old__getchar_func = clickpkg.termui._getchar
        old_should_strip_ansi = clickpkg.utils.should_strip_ansi
        clickpkg.termui.visible_prompt_func = visible_input
        clickpkg.termui.hidden_prompt_func = hidden_input
        clickpkg.termui._getchar = _getchar
        clickpkg.utils.should_strip_ansi = should_strip_ansi

        def set_env_var(key, value):
            # None means "unset"; removal is best-effort because the
            # variable may not exist.
            if value is None:
                try:
                    del os.environ[key]
                except Exception:
                    pass
            else:
                os.environ[key] = value

        old_env = {}
        try:
            for key, value in iteritems(env):
                # Remember the variable's current value, looked up by *key*,
                # so the exact prior state is restored on exit (None means
                # it was unset).
                old_env[key] = os.environ.get(key)
                set_env_var(key, value)
            yield bytes_output
        finally:
            for key, value in iteritems(old_env):
                set_env_var(key, value)
            sys.stdout = old_stdout
            sys.stderr = old_stderr
            sys.stdin = old_stdin
            clickpkg.termui.visible_prompt_func = old_visible_prompt_func
            clickpkg.termui.hidden_prompt_func = old_hidden_prompt_func
            clickpkg.termui._getchar = old__getchar_func
            clickpkg.utils.should_strip_ansi = old_should_strip_ansi

Example 162

Project: openmetadata
Source File: testing.py
View license
    @contextlib.contextmanager
    def isolation(self, input=None, env=None):
        """A context manager that sets up the isolation for invoking of a
        command line tool.  This sets up stdin with the given input data
        and `os.environ` with the overrides from the given dictionary.
        This also rebinds some internals in Click to be mocked (like the
        prompt functionality).

        This is automatically done in the :meth:`invoke` method.

        :param input: the input stream to put into sys.stdin.
        :param env: the environment overrides as dictionary.  A value of
                    ``None`` removes the variable for the duration of the
                    block; every overridden variable is put back on exit.
        """
        input = make_input_stream(input, self.charset)

        old_stdin = sys.stdin
        old_stdout = sys.stdout
        old_stderr = sys.stderr

        env = self.make_env(env)

        if PY2:
            sys.stdout = sys.stderr = bytes_output = StringIO()
            if self.echo_stdin:
                input = EchoingStdin(input, bytes_output)
        else:
            bytes_output = io.BytesIO()
            if self.echo_stdin:
                input = EchoingStdin(input, bytes_output)
            input = io.TextIOWrapper(input, encoding=self.charset)
            sys.stdout = sys.stderr = io.TextIOWrapper(
                bytes_output, encoding=self.charset)

        sys.stdin = input

        def visible_input(prompt=None):
            sys.stdout.write(prompt or '')
            val = input.readline().rstrip('\r\n')
            sys.stdout.write(val + '\n')
            sys.stdout.flush()
            return val

        def hidden_input(prompt=None):
            sys.stdout.write((prompt or '') + '\n')
            sys.stdout.flush()
            return input.readline().rstrip('\r\n')

        def _getchar(echo):
            char = sys.stdin.read(1)
            if echo:
                sys.stdout.write(char)
                sys.stdout.flush()
            return char

        old_visible_prompt_func = click.termui.visible_prompt_func
        old_hidden_prompt_func = click.termui.hidden_prompt_func
        old__getchar_func = click.termui._getchar
        click.termui.visible_prompt_func = visible_input
        click.termui.hidden_prompt_func = hidden_input
        click.termui._getchar = _getchar

        def set_env_var(key, value):
            # None means "unset"; removal is best-effort because the
            # variable may not exist.
            if value is None:
                try:
                    del os.environ[key]
                except Exception:
                    pass
            else:
                os.environ[key] = value

        old_env = {}
        try:
            for key, value in iteritems(env):
                # Remember the variable's current value, looked up by *key*,
                # so the exact prior state is restored on exit (None means
                # it was unset).
                old_env[key] = os.environ.get(key)
                set_env_var(key, value)
            yield bytes_output
        finally:
            for key, value in iteritems(old_env):
                set_env_var(key, value)
            sys.stdout = old_stdout
            sys.stderr = old_stderr
            sys.stdin = old_stdin
            click.termui.visible_prompt_func = old_visible_prompt_func
            click.termui.hidden_prompt_func = old_hidden_prompt_func
            click.termui._getchar = old__getchar_func

Example 163

Project: openwrt-mt7620
Source File: test_support.py
View license
@contextlib.contextmanager
def transient_internet(resource_name, timeout=30.0, errnos=()):
    """Return a context manager that raises ResourceDenied when various issues
    with the Internet connection manifest themselves as exceptions.

    While the block runs, the process-wide default socket timeout is set to
    *timeout* (left alone when *timeout* is None); the previous default is
    restored on exit.  An IOError escaping the block is unwrapped first --
    urllib may nest the real socket error in ``args[0]`` or ``args[1]``,
    possibly several levels deep -- and ResourceDenied is raised instead when
    the innermost error is a socket timeout, a getaddrinfo error with a known
    errno, or carries one of the transient errnos.  Other IOErrors are
    re-raised unchanged.

    :param resource_name: name used in the ResourceDenied message.
    :param timeout: default socket timeout inside the block, or None.
    :param errnos: errno values treated as transient.  When empty, a built-in
        set of connection errnos plus a separate getaddrinfo errno set is
        used; when given, the getaddrinfo set is left empty.
    """
    # (symbolic name, fallback number) pairs; the platform's own value is
    # used whenever the errno/socket module defines the name.
    conn_errno_names = [
        ('ECONNREFUSED', 111),
        ('ECONNRESET', 104),
        ('EHOSTUNREACH', 113),
        ('ENETUNREACH', 101),
        ('ETIMEDOUT', 110),
    ]
    gai_errno_names = [
        ('EAI_AGAIN', -3),
        ('EAI_FAIL', -4),
        ('EAI_NONAME', -2),
        ('EAI_NODATA', -5),
    ]

    denied = ResourceDenied("Resource '%s' is not available" % resource_name)
    if errnos:
        captured_errnos = errnos
        gai_errnos = []
    else:
        captured_errnos = [getattr(errno, sym, fallback)
                           for sym, fallback in conn_errno_names]
        gai_errnos = [getattr(socket, sym, fallback)
                      for sym, fallback in gai_errno_names]

    def innermost(exc):
        # Follow IOError-inside-args wrapping down to the original error.
        while True:
            inner_args = exc.args
            if inner_args and isinstance(inner_args[0], IOError):
                exc = inner_args[0]
            elif len(inner_args) > 1 and isinstance(inner_args[1], IOError):
                exc = inner_args[1]
            else:
                return exc

    def looks_transient(exc):
        code = getattr(exc, 'errno', None)
        return (isinstance(exc, socket.timeout)
                or (isinstance(exc, socket.gaierror) and code in gai_errnos)
                or code in captured_errnos)

    saved_timeout = socket.getdefaulttimeout()
    try:
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    except IOError as exc:
        if looks_transient(innermost(exc)):
            if not verbose:
                sys.stderr.write(denied.args[0] + "\n")
            raise denied
        raise
    # XXX should we catch generic exceptions and look for their
    # __cause__ or __context__?
    finally:
        socket.setdefaulttimeout(saved_timeout)

Example 164

Project: openwrt-mt7620
Source File: test_support.py
View license
@contextlib.contextmanager
def transient_internet(resource_name, timeout=30.0, errnos=()):
    """Return a context manager that raises ResourceDenied when various issues
    with the Internet connection manifest themselves as exceptions.

    While the block runs, the process-wide default socket timeout is set to
    *timeout* (left alone when *timeout* is None); the previous default is
    restored on exit.  An IOError escaping the block is unwrapped first --
    urllib may nest the real socket error in ``args[0]`` or ``args[1]``,
    possibly several levels deep -- and ResourceDenied is raised instead when
    the innermost error is a socket timeout, a getaddrinfo error with a known
    errno, or carries one of the transient errnos.  Other IOErrors are
    re-raised unchanged.

    :param resource_name: name used in the ResourceDenied message.
    :param timeout: default socket timeout inside the block, or None.
    :param errnos: errno values treated as transient.  When empty, a built-in
        set of connection errnos plus a separate getaddrinfo errno set is
        used; when given, the getaddrinfo set is left empty.
    """
    # (symbolic name, fallback number) pairs; the platform's own value is
    # used whenever the errno/socket module defines the name.
    conn_errno_names = [
        ('ECONNREFUSED', 111),
        ('ECONNRESET', 104),
        ('EHOSTUNREACH', 113),
        ('ENETUNREACH', 101),
        ('ETIMEDOUT', 110),
    ]
    gai_errno_names = [
        ('EAI_AGAIN', -3),
        ('EAI_FAIL', -4),
        ('EAI_NONAME', -2),
        ('EAI_NODATA', -5),
    ]

    denied = ResourceDenied("Resource '%s' is not available" % resource_name)
    if errnos:
        captured_errnos = errnos
        gai_errnos = []
    else:
        captured_errnos = [getattr(errno, sym, fallback)
                           for sym, fallback in conn_errno_names]
        gai_errnos = [getattr(socket, sym, fallback)
                      for sym, fallback in gai_errno_names]

    def innermost(exc):
        # Follow IOError-inside-args wrapping down to the original error.
        while True:
            inner_args = exc.args
            if inner_args and isinstance(inner_args[0], IOError):
                exc = inner_args[0]
            elif len(inner_args) > 1 and isinstance(inner_args[1], IOError):
                exc = inner_args[1]
            else:
                return exc

    def looks_transient(exc):
        code = getattr(exc, 'errno', None)
        return (isinstance(exc, socket.timeout)
                or (isinstance(exc, socket.gaierror) and code in gai_errnos)
                or code in captured_errnos)

    saved_timeout = socket.getdefaulttimeout()
    try:
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    except IOError as exc:
        if looks_transient(innermost(exc)):
            if not verbose:
                sys.stderr.write(denied.args[0] + "\n")
            raise denied
        raise
    # XXX should we catch generic exceptions and look for their
    # __cause__ or __context__?
    finally:
        socket.setdefaulttimeout(saved_timeout)

Example 165

Project: phabricator-tools
Source File: phlsys_fs.py
View license
@contextlib.contextmanager
def lockfile_context(filename):
    """Create 'filename' exclusively during context if poss. Fail otherwise.

    A lockfile is a file used as a mutex on the file system, to guarantee
    exclusive access to another resource.

    Manage lockfiles easily with this contextmanager.  It will exclusively
    create the lockfile on entering the context and destroy it when the context
    is left.  If the lockfile cannot be exclusively created then raise
    LockfileExistsError.

    Creating files may be done atomically on a POSIX file system if the correct
    flags are used (O_CREAT | O_EXCL).
    http://pubs.opengroup.org/onlinepubs/9699919799/functions/open.html

    Note that if an unexpected exception is raised while the context is being
    entered then it's possible the lock will be acquired and 'leaked'. e.g. if
    the program is terminated whilst entering the context.

    :param filename: path of the lockfile to create.
    :raises LockfileExistsError: if 'filename' already exists.
    :raises OSError: for any other failure to create the file.

    """
    import errno

    try:
        handle = os.open(filename, os.O_CREAT | os.O_EXCL)
    except OSError as e:
        # Use the symbolic constant rather than a platform-specific number.
        if e.errno == errno.EEXIST:
            raise LockfileExistsError()
        raise

    # XXX: note that if we are interrupted here (e.g. by program termination)
    #      then the file will still exist and the lock will be erroneously
    #      still acquired after program exit.
    #
    #      there doesn't seem to be a good way to completely exclude this
    #      possibility - there's always the space between os.open() and
    #      assignment to 'handle'.

    try:
        yield
    finally:
        # Remove the lockfile even if closing the descriptor fails, otherwise
        # the lock would stay held after the context exits.
        try:
            os.close(handle)
        finally:
            os.remove(filename)

Example 166

Project: LibrERP
Source File: http.py
View license
@contextlib.contextmanager
def session_context(request, storage_path, session_cookie='sessionid'):
    """Bind ``request.session`` to a werkzeug filesystem session for the block.

    One (store, lock) pair is cached per *storage_path* in ``STORES``.  The
    session id is read from the *session_cookie* cookie; an existing session
    is loaded, otherwise a new one is created.  The session is yielded.

    On exit (including on error), OpenERPSession entries are pruned: those
    flagged ``_suicide``, and those with no ``_uid``, no ``jsonp_requests``
    and created more than five minutes ago.  When the request carried a
    session id, the stored copy is re-read and merged in (``contexts_store``,
    ``domains_store``, ``jsonp_requests`` and keys this request lacks, except
    pruned ones) before the session is saved.
    """
    session_store, session_lock = STORES.get(storage_path, (None, None))
    if not session_store:
        session_store = werkzeug.contrib.sessions.FilesystemSessionStore(
            storage_path)
        session_lock = threading.Lock()
        STORES[storage_path] = session_store, session_lock

    sid = request.cookies.get(session_cookie)
    with session_lock:
        if sid:
            request.session = session_store.get(sid)
        else:
            request.session = session_store.new()

    try:
        yield request.session
    finally:
        # Remove all OpenERPSession instances with no uid, they're generated
        # either by login process or by HTTP requests without an OpenERP
        # session id, and are generally noise
        removed_sessions = set()
        # Iterate over a snapshot because entries are deleted in the loop.
        for key, value in list(request.session.items()):
            if not isinstance(value, session.OpenERPSession):
                continue
            # ``jsonp_requests`` may be missing on older session objects (the
            # merge below back-fills it), so read it defensively here instead
            # of raising AttributeError from inside this ``finally``.
            if getattr(value, '_suicide', False) or (
                        not value._uid
                    and not getattr(value, 'jsonp_requests', None)
                    # FIXME do not use a fixed value
                    and value._creation_time + (60*5) < time.time()):
                _logger.debug('remove session %s', key)
                removed_sessions.add(key)
                del request.session[key]

        with session_lock:
            if sid:
                # Re-load sessions from storage and merge non-literal
                # contexts and domains (they're indexed by hash of the
                # content so conflicts should auto-resolve), otherwise if
                # two requests alter those concurrently the last to finish
                # will overwrite the previous one, leading to loss of data
                # (a non-literal is lost even though it was sent to the
                # client and client errors)
                #
                # note that domains_store and contexts_store are append-only (we
                # only ever add items to them), so we can just update one with the
                # other to get the right result, if we want to merge the
                # ``context`` dict we'll need something smarter
                in_store = session_store.get(sid)
                for k, v in request.session.iteritems():
                    stored = in_store.get(k)
                    if stored and isinstance(v, session.OpenERPSession):
                        v.contexts_store.update(stored.contexts_store)
                        v.domains_store.update(stored.domains_store)
                        if not hasattr(v, 'jsonp_requests'):
                            v.jsonp_requests = {}
                        v.jsonp_requests.update(getattr(
                            stored, 'jsonp_requests', {}))

                # add missing keys
                for k, v in in_store.iteritems():
                    if k not in request.session and k not in removed_sessions:
                        request.session[k] = v

            session_store.save(request.session)

Example 167

Project: HealthStarter
Source File: schema.py
View license
    def _remake_table(self, model, create_fields=None, delete_fields=None, alter_fields=None,
                      override_uniques=None, override_indexes=None):
        """
        Shortcut to transform a model from old_model into new_model

        The essential steps are:
          1. rename the model's existing table, e.g. "app_model" to "app_model__old"
          2. create a table with the updated definition called "app_model"
          3. copy the data from the old renamed table to the new table
          4. delete the "app_model__old" table

        :param model: model class whose table is rebuilt.
        :param create_fields: fields to add; existing rows receive each
            field's effective default.
        :param delete_fields: fields to drop.
        :param alter_fields: ``(old_field, new_field)`` pairs to change.
        :param override_uniques: replacement ``unique_together``; when None it
            is derived from the model with field renames applied.
        :param override_indexes: replacement ``index_together``; derived the
            same way when None.

        The model's own field may have ``primary_key`` temporarily cleared and
        ``model._meta.db_table`` is temporarily renamed; both are put back
        even if a step below raises.
        """
        # None defaults avoid sharing mutable list defaults between calls.
        create_fields = [] if create_fields is None else create_fields
        delete_fields = [] if delete_fields is None else delete_fields
        alter_fields = [] if alter_fields is None else alter_fields

        # Self-referential fields must be recreated rather than copied from
        # the old model to ensure their remote_field.field_name doesn't refer
        # to an altered field.
        def is_self_referential(f):
            return f.is_relation and f.remote_field.model is model
        # Work out the new fields dict / mapping
        body = {
            f.name: f.clone() if is_self_referential(f) else f
            for f in model._meta.local_concrete_fields
        }
        # Since mapping might mix column names and default values,
        # its values must be already quoted.
        mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
        # This maps field names (not columns) for things like unique_together
        rename_mapping = {}
        # If any of the new or altered fields is introducing a new PK,
        # remove the old one. Note this mutates the model's live field
        # object, so it must be restored no matter how this method exits.
        restore_pk_field = None
        if any(f.primary_key for f in create_fields) or any(n.primary_key for o, n in alter_fields):
            for name, field in list(body.items()):
                if field.primary_key:
                    field.primary_key = False
                    restore_pk_field = field
                    if field.auto_created:
                        del body[name]
                        del mapping[field.column]
        try:
            # Add in any created fields
            for field in create_fields:
                body[field.name] = field
                # Choose a default and insert it into the copy map
                if not field.many_to_many and field.concrete:
                    mapping[field.column] = self.quote_value(
                        self.effective_default(field)
                    )
            # Add in any altered fields
            for (old_field, new_field) in alter_fields:
                body.pop(old_field.name, None)
                mapping.pop(old_field.column, None)
                body[new_field.name] = new_field
                if old_field.null and not new_field.null:
                    case_sql = "coalesce(%(col)s, %(default)s)" % {
                        'col': self.quote_name(old_field.column),
                        'default': self.quote_value(self.effective_default(new_field))
                    }
                    mapping[new_field.column] = case_sql
                else:
                    mapping[new_field.column] = self.quote_name(old_field.column)
                rename_mapping[old_field.name] = new_field.name
            # Remove any deleted fields
            for field in delete_fields:
                del body[field.name]
                del mapping[field.column]
                # Remove any implicit M2M tables
                if field.many_to_many and field.remote_field.through._meta.auto_created:
                    return self.delete_model(field.remote_field.through)
            # Work inside a new app registry
            apps = Apps()

            # Provide isolated instances of the fields to the new model body so
            # that the existing model's internals aren't interfered with when
            # the dummy model is constructed.
            body = copy.deepcopy(body)

            # Work out the new value of unique_together, taking renames into
            # account
            if override_uniques is None:
                override_uniques = [
                    [rename_mapping.get(n, n) for n in unique]
                    for unique in model._meta.unique_together
                ]

            # Work out the new value for index_together, taking renames into
            # account
            if override_indexes is None:
                override_indexes = [
                    [rename_mapping.get(n, n) for n in index]
                    for index in model._meta.index_together
                ]

            # Construct a new model for the new state
            meta_contents = {
                'app_label': model._meta.app_label,
                'db_table': model._meta.db_table,
                'unique_together': override_uniques,
                'index_together': override_indexes,
                'apps': apps,
            }
            meta = type("Meta", tuple(), meta_contents)
            body['Meta'] = meta
            body['__module__'] = model.__module__

            temp_model = type(model._meta.object_name, model.__bases__, body)

            # We need to modify model._meta.db_table, but everything explodes
            # if the change isn't reversed before the end of this method. The
            # try/finally guarantees the reversal even when a step fails.
            @contextlib.contextmanager
            def altered_table_name(model, temporary_table_name):
                original_table_name = model._meta.db_table
                model._meta.db_table = temporary_table_name
                try:
                    yield
                finally:
                    model._meta.db_table = original_table_name

            with altered_table_name(model, model._meta.db_table + "__old"):
                # Rename the old table to make way for the new
                self.alter_db_table(model, temp_model._meta.db_table, model._meta.db_table)

                # Create a new table with the updated schema. We remove things
                # from the deferred SQL that match our table name, too
                self.deferred_sql = [x for x in self.deferred_sql if temp_model._meta.db_table not in x]
                self.create_model(temp_model)

                # Copy data from the old table into the new table
                field_maps = list(mapping.items())
                self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
                    self.quote_name(temp_model._meta.db_table),
                    ', '.join(self.quote_name(x) for x, y in field_maps),
                    ', '.join(y for x, y in field_maps),
                    self.quote_name(model._meta.db_table),
                ))

                # Delete the old table
                self.delete_model(model, handle_autom2m=False)

            # Run deferred SQL on correct table
            for sql in self.deferred_sql:
                self.execute(sql)
            self.deferred_sql = []
        finally:
            # Fix any PK-removed field, whether or not the remake succeeded
            if restore_pk_field:
                restore_pk_field.primary_key = True

Example 168

Project: brython
Source File: util.py
View license
@contextlib.contextmanager
def create_modules(*names):
    """Temporarily create each named module with an attribute (named 'attr')
    that contains the name passed into the context manager that caused the
    creation of the module.

    All files are created in a temporary directory returned by
    tempfile.mkdtemp(). This directory is used as the import path
    (``util.import_state(path=[temp_dir])``) while the context is active.
    When the context manager exits, the whole temporary directory is
    removed with ``support.rmtree``.

    No magic is performed when creating packages! This means that if you create
    a module within a package you must also create the package's __init__ as
    well.

    Yields a dict mapping each requested name to its source file path, plus
    the key ``'.root'`` mapping to the temporary directory.

    """
    source = 'attr = {0!r}'
    mapping = {}
    state_manager = None
    uncache_manager = None
    # Bound before the try so the finally clause never hits an unbound name
    # if mkdtemp() itself fails.
    temp_dir = None
    try:
        temp_dir = tempfile.mkdtemp()
        mapping['.root'] = temp_dir
        import_names = set()
        for name in names:
            if not name.endswith('__init__'):
                import_name = name
            else:
                import_name = name[:-len('.__init__')]
            import_names.add(import_name)
            if import_name in sys.modules:
                del sys.modules[import_name]
            name_parts = name.split('.')
            file_path = temp_dir
            for directory in name_parts[:-1]:
                file_path = os.path.join(file_path, directory)
                if not os.path.exists(file_path):
                    os.mkdir(file_path)
            file_path = os.path.join(file_path, name_parts[-1] + '.py')
            with open(file_path, 'w') as file:
                file.write(source.format(name))
            mapping[name] = file_path
        uncache_manager = util.uncache(*import_names)
        uncache_manager.__enter__()
        state_manager = util.import_state(path=[temp_dir])
        state_manager.__enter__()
        yield mapping
    finally:
        if state_manager is not None:
            state_manager.__exit__(None, None, None)
        if uncache_manager is not None:
            uncache_manager.__exit__(None, None, None)
        if temp_dir is not None:
            support.rmtree(temp_dir)

Example 169

Project: jinja-to-js
Source File: __init__.py
View license
    @contextlib.contextmanager
    def _scoped_variables(self, nodes_list, **kwargs):
        """
        Context manager for creating scoped variables defined by the nodes in `nodes_list`.
        These variables will be added to the context, and when the context manager exits the
        context object will be restored to it's previous state.
        """

        tmp_vars = []
        for node in nodes_list:

            is_assign_node = isinstance(node, nodes.Assign)
            name = node.target.name if is_assign_node else node.name

            # create a temp variable name
            tmp_var = next(self.temp_var_names)

            # save previous context value
            with self._execution():

                # save the current value of this name
                self.output.write('var %s = %s.%s;' % (tmp_var, self.context_name, name))

                # add new value to context
                self.output.write('%s.%s = ' % (self.context_name, name))

                if is_assign_node:
                    self._process_node(node.node, **kwargs)
                else:
                    self.output.write(node.name)

                self.output.write(';')

            tmp_vars.append((tmp_var, name))

        yield

        # restore context
        for tmp_var, name in tmp_vars:
            with self._execution():
                self.output.write('%s.%s = %s;' % (self.context_name, name, tmp_var))

Example 170

Project: cloudpulse
Source File: utils.py
View license
@contextlib.contextmanager
def temporary_mutation(obj, **kwargs):
    """Temporarily change object attribute.

    Temporarily set the attr on a particular object to a given value then
    revert when finished.

    One use of this is to temporarily set the read_deleted flag on a context
    object:

        with temporary_mutation(context, read_deleted="yes"):
            do_something_that_needed_deleted_objects()

    Dicts (and other objects exposing ``has_key``) are mutated by key;
    everything else by attribute.  Keys/attributes that did not exist before
    are removed again on exit.  If applying a mutation fails, the ones that
    were already applied are reverted before the error propagates.
    """
    def is_dict_like(thing):
        # ``has_key`` only exists on Python 2 mappings, so plain dicts are
        # recognised explicitly to keep working on Python 3.
        return hasattr(thing, 'has_key') or isinstance(thing, dict)

    def get(thing, attr, default):
        if is_dict_like(thing):
            return thing.get(attr, default)
        else:
            return getattr(thing, attr, default)

    def set_value(thing, attr, val):
        if is_dict_like(thing):
            thing[attr] = val
        else:
            setattr(thing, attr, val)

    def delete(thing, attr):
        if is_dict_like(thing):
            del thing[attr]
        else:
            delattr(thing, attr)

    NOT_PRESENT = object()

    old_values = {}
    try:
        for attr, new_value in kwargs.items():
            previous = get(obj, attr, NOT_PRESENT)
            set_value(obj, attr, new_value)
            # Record only after a successful set, so the rollback below
            # touches exactly the mutations that were applied.
            old_values[attr] = previous
        yield
    finally:
        for attr, old_value in old_values.items():
            if old_value is NOT_PRESENT:
                delete(obj, attr)
            else:
                set_value(obj, attr, old_value)

Example 171

Project: setuptools
Source File: test.py
View license
    @contextlib.contextmanager
    def project_on_sys_path(self, include_dists=[]):
        """Build the project and expose it on ``sys.path`` for the block.

        When 2to3 applies (``six.PY3`` and the distribution sets
        ``use_2to3``), sources are built out of place: ``build_py`` runs,
        ``egg_info`` is pointed at the build directory, and ``build_ext``
        runs with ``inplace=0``.  Otherwise ``egg_info`` runs as configured
        and extensions are built in place.

        On exit ``sys.path`` and ``sys.modules`` are restored to their
        snapshots and the global ``working_set`` is re-initialised.
        ``include_dists`` is accepted but not used by this implementation.
        """
        use_2to3 = six.PY3 and getattr(self.distribution, 'use_2to3', False)

        if not use_2to3:
            # Without 2to3 an in-place build is fine.
            self.run_command('egg_info')
            self.reinitialize_command('build_ext', inplace=1)
            self.run_command('build_ext')
        else:
            # 2to3 output must not overwrite the sources: build out of place
            # and make sure the metadata is refreshed there.
            self.reinitialize_command('build_py', inplace=0)
            self.run_command('build_py')
            build_lib = normalize_path(
                self.get_finalized_command("build_py").build_lib)

            self.reinitialize_command('egg_info', egg_base=build_lib)
            self.run_command('egg_info')

            self.reinitialize_command('build_ext', inplace=0)
            self.run_command('build_ext')

        egg_info = self.get_finalized_command("egg_info")

        saved_path = list(sys.path)
        saved_modules = dict(sys.modules)

        try:
            project_path = normalize_path(egg_info.egg_base)
            sys.path.insert(0, project_path)
            working_set.__init__()
            add_activation_listener(lambda dist: dist.activate())
            require('%s==%s' % (egg_info.egg_name, egg_info.egg_version))
            with self.paths_on_pythonpath([project_path]):
                yield
        finally:
            sys.path[:] = saved_path
            sys.modules.clear()
            sys.modules.update(saved_modules)
            working_set.__init__()

Example 172

Project: Nuitka
Source File: SconsInterface.py
View license
@contextlib.contextmanager
def setupSconsEnvironment():
    """Temporarily prepare ``os.environ`` for running the scons build.

    Sets ``NUITKA_SCONS`` and, on Windows, ``SCONS_LIB_DIR``,
    ``NUITKA_PYTHON_DLL_PATH`` and ``NUITKA_PYTHON_EXE_PATH``. When
    ``python_version >= 300``, ``PYTHONPATH`` and ``PYTHONHOME`` are removed
    for the duration of the block and put back afterwards.

    Cleanup runs in a ``finally`` clause, so the caller's environment is
    restored even when the build inside the ``with`` block raises.
    ``NUITKA_SCONS`` and ``SCONS_LIB_DIR`` are left set on exit.
    """
    # For the scons file to find the static C++ files and include path. The
    # scons file is unable to use __file__ for the task.
    os.environ["NUITKA_SCONS"] = getSconsDataPath()

    is_windows = Utils.getOS() == "Windows"

    if is_windows:
        # On Windows this Scons variable must be set by us.
        os.environ["SCONS_LIB_DIR"] = Utils.joinpath(
            getSconsInlinePath(),
            "lib",
            "scons-2.3.2"
        )

        # On Windows, we use the Python.DLL path for some things. We pass it
        # via environment variable
        os.environ["NUITKA_PYTHON_DLL_PATH"] = getTargetPythonDLLPath()

        os.environ["NUITKA_PYTHON_EXE_PATH"] = sys.executable

    # Remove environment variables that can only harm if we have to switch
    # major Python versions, these cannot help Python2 to execute scons, this
    # is a bit of noise, but helpful. ``None`` means "was not set".
    old_pythonpath = None
    old_pythonhome = None

    if python_version >= 300:
        old_pythonpath = os.environ.pop("PYTHONPATH", None)
        old_pythonhome = os.environ.pop("PYTHONHOME", None)

    try:
        yield
    finally:
        if old_pythonpath is not None:
            os.environ["PYTHONPATH"] = old_pythonpath

        if old_pythonhome is not None:
            os.environ["PYTHONHOME"] = old_pythonhome

        if is_windows:
            # pop() rather than del: the build may already have removed them.
            os.environ.pop("NUITKA_PYTHON_DLL_PATH", None)
            os.environ.pop("NUITKA_PYTHON_EXE_PATH", None)

Example 173

View license
@contextlib.contextmanager
def run_server(loop, *, listen_addr=('127.0.0.1', 0),
               use_ssl=False, router=None):
    """Serve HTTP from a background thread for the duration of the block.

    A new event loop is created in a separate thread and a
    ``TestHttpServer`` is started on ``listen_addr``. When ``use_ssl`` is
    true, it uses TLS with ``sample.key``/``sample.crt`` from the ``tests``
    directory next to this file. Each request body is read and passed to
    ``router(handler, properties, transport, message, body)``, and the
    returned object's ``dispatch()`` is called.

    :param loop: the caller's event loop, used to wait for server start-up.
    :param listen_addr: ``(host, port)`` to bind.
    :param use_ssl: serve HTTPS instead of HTTP.
    :param router: called once per request as described above.
    :yields: an ``HttpRequestHandler`` exposing ``host``, ``port``,
        ``address``, ``url(*suffix)`` and item access to a ``properties``
        dict shared with the server. A truthy ``'close'`` entry makes
        ``handle_request`` return without dispatching.
    :raises: any exception raised while the server thread starts the server
        (for example an ``OSError`` when the address is in use). Previously
        such a failure left the caller blocked forever.
    """
    properties = {}
    transports = []

    class HttpRequestHandler:

        def __init__(self, addr):
            host, port = addr
            self.host = host
            self.port = port
            self.address = addr
            self._url = '{}://{}:{}'.format(
                'https' if use_ssl else 'http', host, port)

        def __getitem__(self, key):
            return properties[key]

        def __setitem__(self, key, value):
            properties[key] = value

        def url(self, *suffix):
            return urllib.parse.urljoin(
                self._url, '/'.join(str(s) for s in suffix))

    class TestHttpServer(server.ServerHttpProtocol):

        def connection_made(self, transport):
            transports.append(transport)

            super().connection_made(transport)

        def handle_request(self, message, payload):
            if properties.get('close', False):
                return

            for hdr, val in message.headers.items():
                if (hdr.upper() == 'EXPECT') and (val == '100-continue'):
                    self.transport.write(b'HTTP/1.0 100 Continue\r\n\r\n')
                    break

            body = yield from payload.read()

            rob = router(
                self, properties, self.transport, message, body)
            rob.dispatch()

    if use_ssl:
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        keyfile = os.path.join(here, 'sample.key')
        certfile = os.path.join(here, 'sample.crt')
        sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        sslcontext.load_cert_chain(certfile, keyfile)
    else:
        sslcontext = None

    def run(loop, fut):
        thread_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(thread_loop)

        host, port = listen_addr
        try:
            server_coroutine = thread_loop.create_server(
                lambda: TestHttpServer(keep_alive=0.5),
                host, port, ssl=sslcontext)
            server = thread_loop.run_until_complete(server_coroutine)
        except Exception as exc:
            # Hand the start-up failure to the caller's loop. Without this,
            # ``fut`` would never complete and the caller would block forever
            # in ``loop.run_until_complete(fut)``.
            loop.call_soon_threadsafe(fut.set_exception, exc)
            thread_loop.close()
            return

        waiter = helpers.create_future(thread_loop)
        loop.call_soon_threadsafe(
            fut.set_result, (thread_loop, waiter,
                             server.sockets[0].getsockname()))

        try:
            thread_loop.run_until_complete(waiter)
        finally:
            # call pending connection_made if present
            run_briefly(thread_loop)

            # close opened transports
            for tr in transports:
                tr.close()

            run_briefly(thread_loop)  # call close callbacks

            server.close()
            thread_loop.stop()
            thread_loop.close()
            gc.collect()

    fut = helpers.create_future(loop)
    server_thread = threading.Thread(target=run, args=(loop, fut))
    server_thread.start()

    thread_loop, waiter, addr = loop.run_until_complete(fut)
    try:
        yield HttpRequestHandler(addr)
    finally:
        # Wake the server thread so it tears down, then wait for it.
        thread_loop.call_soon_threadsafe(waiter.set_result, None)
        server_thread.join()

Example 174

Project: QGIS-CKAN-Browser
Source File: windows.py
View license
def init_windows_clipboard(cygwin=False):
    """Build copy/paste functions backed by the Win32 clipboard API.

    :param cygwin: when true, the Win32 functions are loaded through
        ``ctypes.cdll`` instead of ``ctypes.windll`` (untested, see TODO).
    :returns: a ``(copy_windows, paste_windows)`` tuple.
    """
    if cygwin:
        windll = ctypes.cdll  # TODO: This is untested
    else:
        windll = ctypes.windll

    safeCreateWindowExA = CheckedCall(windll.user32.CreateWindowExA)
    safeCreateWindowExA.argtypes = [DWORD, LPCSTR, LPCSTR, DWORD, INT, INT, INT, INT, HWND, HMENU, HINSTANCE, LPVOID]
    safeCreateWindowExA.restype = HWND

    safeDestroyWindow = CheckedCall(windll.user32.DestroyWindow)
    safeDestroyWindow.argtypes = [HWND]
    safeDestroyWindow.restype = BOOL

    OpenClipboard = windll.user32.OpenClipboard
    OpenClipboard.argtypes = [HWND]
    OpenClipboard.restype = BOOL

    safeCloseClipboard = CheckedCall(windll.user32.CloseClipboard)
    safeCloseClipboard.argtypes = []
    safeCloseClipboard.restype = BOOL

    safeEmptyClipboard = CheckedCall(windll.user32.EmptyClipboard)
    safeEmptyClipboard.argtypes = []
    safeEmptyClipboard.restype = BOOL

    safeGetClipboardData = CheckedCall(windll.user32.GetClipboardData)
    safeGetClipboardData.argtypes = [UINT]
    safeGetClipboardData.restype = HANDLE

    safeSetClipboardData = CheckedCall(windll.user32.SetClipboardData)
    safeSetClipboardData.argtypes = [UINT, HANDLE]
    safeSetClipboardData.restype = HANDLE

    safeGlobalAlloc = CheckedCall(windll.kernel32.GlobalAlloc)
    safeGlobalAlloc.argtypes = [UINT, c_size_t]
    safeGlobalAlloc.restype = HGLOBAL

    safeGlobalLock = CheckedCall(windll.kernel32.GlobalLock)
    safeGlobalLock.argtypes = [HGLOBAL]
    safeGlobalLock.restype = LPVOID

    safeGlobalUnlock = CheckedCall(windll.kernel32.GlobalUnlock)
    safeGlobalUnlock.argtypes = [HGLOBAL]
    safeGlobalUnlock.restype = BOOL

    msvcrt = ctypes.cdll.msvcrt

    # Length of a string in wchar_t units, as the C runtime counts it. With a
    # 2-byte wchar_t, characters outside the BMP (e.g. emoji) take two units
    # (a surrogate pair), so this can exceed len(text).
    wcslen = msvcrt.wcslen
    wcslen.argtypes = [c_wchar_p]
    wcslen.restype = c_size_t

    wcscpy_s = msvcrt.wcscpy_s
    wcscpy_s.argtypes = [c_wchar_p, c_size_t, c_wchar_p]
    # wcscpy_s returns an errno_t (0 on success). Declaring the result as
    # c_wchar_p would make ctypes dereference a non-zero error code as a
    # pointer instead of letting us detect the failure.
    wcscpy_s.restype = ctypes.c_int

    GMEM_MOVEABLE = 0x0002
    CF_UNICODETEXT = 13

    @contextlib.contextmanager
    def window():
        """
        Context that provides a valid Windows hwnd.
        """
        # we really just need the hwnd, so setting "STATIC" as predefined lpClass is just fine.
        hwnd = safeCreateWindowExA(0, b"STATIC", None, 0, 0, 0, 0, 0, None, None, None, None)
        try:
            yield hwnd
        finally:
            safeDestroyWindow(hwnd)

    @contextlib.contextmanager
    def clipboard(hwnd):
        """
        Context manager that opens the clipboard and prevents other applications from modifying the clipboard content.
        """
        # We may not get the clipboard handle immediately because some other application is accessing it (?)
        # We try for at least 500ms to get the clipboard.
        t = time.time() + 0.5
        success = False
        while time.time() < t:
            success = OpenClipboard(hwnd)
            if success:
                break
            time.sleep(0.01)
        if not success:
            raise PyperclipWindowsException("Error calling OpenClipboard")

        try:
            yield
        finally:
            safeCloseClipboard()

    def copy_windows(text):
        # This function is heavily based on
        # https://msdn.microsoft.com/de-de/library/windows/desktop/ms649016(v=vs.85).aspx#_win32_Copying_Information_to_the_Clipboard

        with window() as hwnd:
            # https://msdn.microsoft.com/de-de/library/windows/desktop/ms649048(v=vs.85).aspx
            # > If an application calls OpenClipboard with hwnd set to NULL, EmptyClipboard sets the clipboard owner to
            # > NULL; this causes SetClipboardData to fail.
            # => We need a valid hwnd to copy something.
            with clipboard(hwnd):
                safeEmptyClipboard()

                if text:
                    # https://msdn.microsoft.com/de-de/library/windows/desktop/ms649051(v=vs.85).aspx
                    # > If the hMem parameter identifies a memory object, the object must have been allocated using the
                    # > function with the GMEM_MOVEABLE flag.
                    # Size in wchar_t units (not Python characters), +1 for the terminating NUL.
                    count = wcslen(text) + 1
                    handle = safeGlobalAlloc(GMEM_MOVEABLE, count * sizeof(c_wchar))
                    locked_handle = safeGlobalLock(handle)
                    try:
                        if wcscpy_s(c_wchar_p(locked_handle), count, c_wchar_p(text)):
                            raise PyperclipWindowsException("Error calling wcscpy_s")
                    finally:
                        # Always balance the GlobalLock, even if the copy failed.
                        safeGlobalUnlock(handle)
                    safeSetClipboardData(CF_UNICODETEXT, handle)

    def paste_windows():
        with clipboard(None):
            handle = safeGetClipboardData(CF_UNICODETEXT)
            if not handle:
                # GetClipboardData may return NULL with errno == NO_ERROR if the clipboard is empty.
                # (Also, it may return a handle to an empty buffer, but technically that's not empty)
                return ""
            return c_wchar_p(handle).value

    return copy_windows, paste_windows

Example 175

Project: pyface
Source File: gui_test_assistant.py
View license
    @contextlib.contextmanager
    def assertTraitChangesInEventLoop(self, obj, trait, condition, count=1,
                                      timeout=10.0):
        """Run the real Qt event loop until ``condition(obj)`` holds, while
        collecting change events for ``trait`` on ``obj``.

        The change collector is yielded to the ``with`` block. Once the block
        finishes, the event loop runs until the condition is satisfied. If
        ``timeout`` elapses first, the test fails via ``self.fail``.
        Collection is stopped on every exit path.

        Parameters
        ----------
        obj : HasTraits
            The HasTraits instance whose trait will change.
        trait : str
            The extended trait name of the trait changes to listen to.
        condition : callable
            Stop criterion. It is called with ``obj`` as its only argument
            and should return true once the expected changes have happened.
        count : int
            The number of events the caller expects (default one). It only
            appears in the failure message; it does not decide when the loop
            stops.
        timeout : float
            Number of seconds to run the event loop before giving up if the
            condition never becomes true.
        """
        def stop_condition():
            return condition(obj)

        collector = TraitsChangeCollector(obj=obj, trait=trait)
        collector.start_collecting()
        try:
            yield collector
            self.event_loop_helper.event_loop_until_condition(
                stop_condition, timeout=timeout)
        except ConditionTimeoutError:
            template = ("Expected {} event on {} to be fired at least {} "
                        "times, but the event was only fired {} times "
                        "before timeout ({} seconds).")
            self.fail(template.format(
                trait, obj, count, collector.event_count, timeout))
        finally:
            collector.stop_collecting()

Example 176

Project: oslo.concurrency
Source File: lockutils.py
View license
@contextlib.contextmanager
def lock(name, lock_file_prefix=None, external=False, lock_path=None,
         do_log=True, semaphores=None, delay=0.01):
    """Context based lock.

    The in-process semaphore for ``name`` is always taken first. If
    ``external`` is true and process locking has not been disabled through
    ``CONF.oslo_concurrency.disable_process_locking``, an inter-process lock
    is acquired as well and that lock is yielded. Otherwise the in-process
    semaphore itself is yielded.

    :param lock_file_prefix: Prefix for the lock files on disk, so that they
      carry a meaningful name.

    :param external: Whether the lock must also exclude other processes. For
      example, when two workers both run a method decorated with
      @synchronized('mylock', external=True), only one of them executes it at
      a time.

    :param lock_path: Directory that holds the external lock files. It must be
      the same for every user of the lock for external locking to work.

    :param do_log: Whether to emit the acquire/release debug messages. This is
      mostly turned off by the `synchronized` decorator to avoid duplicated
      log lines.

    :param semaphores: Container supplying the in-process semaphores. It keeps
        threads of the same application from colliding, because external
        process locks know nothing about a process's threads.

    :param delay: Seconds to wait between attempts to acquire the external
      lock.

    .. versionchanged:: 0.2
       Added *do_log* optional parameter.

    .. versionchanged:: 0.3
       Added *delay* and *semaphores* optional parameters.
    """
    semaphore = internal_lock(name, semaphores=semaphores)
    with semaphore:
        if do_log:
            LOG.debug('Acquired semaphore "%(lock)s"', {'lock': name})
        try:
            use_file_lock = (external and
                             not CONF.oslo_concurrency.disable_process_locking)
            if not use_file_lock:
                yield semaphore
            else:
                file_lock = external_lock(name, lock_file_prefix, lock_path)
                file_lock.acquire(delay=delay)
                try:
                    yield file_lock
                finally:
                    file_lock.release()
        finally:
            if do_log:
                LOG.debug('Releasing semaphore "%(lock)s"', {'lock': name})

Example 177

Project: pandashells
Source File: utils_lib.py
View license
@contextlib.contextmanager
def Timer(name='', silent=False, pretty=False, header=True):
    """
    A context manager for timing sections of code.
    :type name: str
    :param name: The name you want to give the contextified code
    :type silent: bool
    :param silent: Setting this to true will mute all printing
    :type pretty: bool
    :param pretty: When set to true, prints elapsed time in hh:mm:ss.mmmmmm
    :type header: bool
    :param header: When set to false, sets
        ``OutStream.header_needs_printing = False``. This is a class-level
        attribute, so it persists for later timers too (presumably it
        suppresses the header line OutStream would otherwise print — confirm
        against OutStream).

    The yielded result object gets ``ending`` and ``seconds`` filled in when
    the block exits, including when the block raises. In that case the timing
    line is still written unless ``silent`` is set. Every timing line ends
    with a newline, even when ``name`` is empty.

    Example
    ---------------------------------------------------------------------------
    # Example code for timing different parts of your code
    import time
    from pandashells import Timer
    with Timer('entire script'):
        for nn in range(3):
            with Timer('loop {}'.format(nn + 1)):
                time.sleep(.1 * nn)
    # Will generate the following output on stdout
    #     col1: a string that is easily found with grep
    #     col2: the time in seconds (or in hh:mm:ss if pretty=True)
    #     col3: the value passed to the 'name' argument of Timer

    __time__,2.6e-05,loop 1
    __time__,0.105134,loop 2
    __time__,0.204489,loop 3
    __time__,0.310102,entire script

    ---------------------------------------------------------------------------
    # Example for measuring how a piece of code scales (measuring "big-O")
    import time
    from pandashells import Timer

    # initialize a list to hold results
    results = []

    # run a piece of code with different values of the var you want to scale
    for nn in range(3):
        # time each iteration
        with Timer('loop {}'.format(nn + 1), silent=True) as timer:
            time.sleep(.1 * nn)
        # add results
        results.append((nn, timer))

    # print csv compatible text for further pandashells processing/plotting
    print('nn,seconds')
    for nn, timer in results:
        print('{},{}'.format(nn, timer.seconds))
    """
    if not header:
        OutStream.header_needs_printing = False
    stream = OutStream()
    result = TimerResult(name, starting=datetime.datetime.now())
    try:
        yield result
    finally:
        # Finalize the result even if the timed block raised.
        result.ending = datetime.datetime.now()
        dt = result.ending - result.starting
        result.seconds = dt.total_seconds()
        dt = dt if pretty else result.seconds
        if not silent:
            stream.write('__time__,{},'.format(dt))
            # Always terminate the line; previously an unnamed timer left the
            # line open and the next output was glued onto it.
            stream.write('%s\n' % name)

Example 178

Project: python-cinderclient
Source File: base.py
View license
    @contextlib.contextmanager
    def completion_cache(self, cache_type, obj_class, mode):
        """
        Open the bash-completion cache file for ``obj_class`` and expose it
        on ``self`` as ``_<cache_type>_cache`` for the duration of the block.

        The cache holds items useful for bash autocompletion, such as UUIDs
        or human-friendly IDs. A resource listing clears and repopulates it,
        and a resource create appends to it. Deletes are not handled, since
        listings are assumed to happen often enough to keep it reasonably
        current.

        Failing to create the cache directory or to open the file is
        tolerated. In the latter case no attribute is set. On exit, an opened
        file is closed and the attribute removed.
        """
        cache_root = utils.env('CINDERCLIENT_UUID_CACHE_DIR',
                               default="~/.cinderclient")

        # NOTE(sirp): Keep separate UUID caches for each username + endpoint
        # pair
        user = utils.env('OS_USERNAME', 'CINDER_USERNAME')
        endpoint = utils.env('OS_URL', 'CINDER_URL')
        digest = hashlib.md5(user.encode('utf-8') +
                             endpoint.encode('utf-8')).hexdigest()

        cache_dir = os.path.expanduser(os.path.join(cache_root, digest))

        try:
            os.makedirs(cache_dir, 0o755)
        except OSError:
            # NOTE(kiall): Usually the directory already exists or we lack
            #              permission to create it; neither is fatal.
            pass

        cache_file = os.path.join(
            cache_dir,
            "%s-%s-cache" % (obj_class.__name__.lower(),
                             cache_type.replace('_', '-')))
        attr_name = "_%s_cache" % cache_type

        try:
            setattr(self, attr_name, open(cache_file, mode))
        except IOError:
            # NOTE(kiall): Usually a permission problem writing the cache
            #              file; run without a cache in that case.
            pass

        try:
            yield
        finally:
            handle = getattr(self, attr_name, None)
            if handle:
                handle.close()
                delattr(self, attr_name)

Example 179

Project: python-novaclient
Source File: base.py
View license
    @contextlib.contextmanager
    def completion_cache(self, cache_type, obj_class, mode):
        """Expose the bash-completion cache file as ``self._<cache_type>_cache``.

        The cache holds items useful for bash autocompletion, such as UUIDs
        or human-friendly IDs. A resource listing clears and repopulates it,
        and a resource create appends to it. Deletes are not handled, since
        listings are assumed to happen often enough to keep it reasonably
        current.

        ``self.cache_lock`` is held for the whole block. Failing to create
        the cache directory or to open the file is tolerated; in the latter
        case no attribute is set. On exit, an opened file is closed and the
        attribute removed.
        """
        # NOTE(wryan): self.cache_lock serializes every read and write of the
        # completion caches, so it stays held until the block finishes.
        with self.cache_lock:
            cache_root = utils.env('NOVACLIENT_UUID_CACHE_DIR',
                                   default="~/.novaclient")

            # NOTE(sirp): One UUID cache per username + endpoint pair.
            user = utils.env('OS_USERNAME', 'NOVA_USERNAME')
            endpoint = utils.env('OS_URL', 'NOVA_URL')
            digest = hashlib.md5(user.encode('utf-8') +
                                 endpoint.encode('utf-8')).hexdigest()

            cache_dir = os.path.expanduser(os.path.join(cache_root, digest))

            try:
                os.makedirs(cache_dir, 0o755)
            except OSError:
                # NOTE(kiall): Usually the directory already exists or we
                #              lack permission to create it; neither is fatal.
                pass

            cache_file = os.path.join(
                cache_dir,
                "%s-%s-cache" % (obj_class.__name__.lower(),
                                 cache_type.replace('_', '-')))
            attr_name = "_%s_cache" % cache_type

            try:
                setattr(self, attr_name, open(cache_file, mode))
            except IOError:
                # NOTE(kiall): Usually a permission problem writing the cache
                #              file; run without a cache in that case.
                pass

            try:
                yield
            finally:
                handle = getattr(self, attr_name, None)
                if handle:
                    handle.close()
                    delattr(self, attr_name)

Example 180

Project: FIDDLE
Source File: arg_scope.py
View license
@contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
  """Stores the default arguments for the given set of list_ops.

  For usage, please see examples at top of the file.

  Args:
    list_ops_or_scope: List or tuple of operations to set argument scope for or
      a dictionary containing the current scope. When list_ops_or_scope is a
      dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,
      then every op in it needs to be decorated with @add_arg_scope to work.
    **kwargs: keyword=value that will define the defaults for each op in
              list_ops. All the ops need to accept the given set of arguments.

  Yields:
    the current_scope, which is a dictionary of {op: {arg: value}}
  Raises:
    TypeError: if list_ops_or_scope is not a list, a tuple or a dict.
    ValueError: if list_ops_or_scope is a dict and kwargs is not empty, or if
      any op in list_ops_or_scope has not been decorated with @add_arg_scope.
  """
  if isinstance(list_ops_or_scope, dict):
    # Assumes that list_ops_or_scope is a scope that is being reused.
    if kwargs:
      raise ValueError('When attempting to re-use a scope by supplying a '
                       'dictionary, kwargs must be empty.')
    current_scope = list_ops_or_scope.copy()
  else:
    # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
    if not isinstance(list_ops_or_scope, (list, tuple)):
      raise TypeError('list_ops_or_scope must either be a list/tuple or reused '
                      'scope (i.e. dict)')
    current_scope = _current_arg_scope().copy()
    for op in list_ops_or_scope:
      key_op = _key_op(op)
      if not has_arg_scope(op):
        raise ValueError('%s is not decorated with @add_arg_scope' %
                         _name_op(op))
      if key_op in current_scope:
        current_kwargs = current_scope[key_op].copy()
        current_kwargs.update(kwargs)
        current_scope[key_op] = current_kwargs
      else:
        current_scope[key_op] = kwargs.copy()
  # Push only after the new scope has been fully validated and built. A
  # failure above must not reach the pop below, which would otherwise remove
  # (and so corrupt) an enclosing scope that this call never pushed.
  _get_arg_stack().append(current_scope)
  try:
    yield current_scope
  finally:
    _get_arg_stack().pop()

Example 181

View license
@contextlib.contextmanager
def run_server(loop, *, listen_addr=('127.0.0.1', 0),
               use_ssl=False, router=None):
    """Serve HTTP from a background thread for the duration of the block.

    A new event loop is created in a separate thread and a
    ``TestHttpServer`` is started on ``listen_addr``. When ``use_ssl`` is
    true, it uses TLS with ``sample.key``/``sample.crt`` from the ``tests``
    directory next to this file. Each request body is read and passed to
    ``router(handler, properties, transport, message, body)``, and the
    returned object's ``dispatch()`` is called.

    :param loop: the caller's event loop, used to wait for server start-up.
    :param listen_addr: ``(host, port)`` to bind.
    :param use_ssl: serve HTTPS instead of HTTP.
    :param router: called once per request as described above.
    :yields: an ``HttpRequestHandler`` exposing ``host``, ``port``,
        ``address``, ``url(*suffix)`` and item access to a ``properties``
        dict shared with the server. A truthy ``'close'`` entry makes
        ``handle_request`` return without dispatching.
    :raises: any exception raised while the server thread starts the server
        (for example an ``OSError`` when the address is in use). Previously
        such a failure left the caller blocked forever.
    """
    properties = {}
    transports = []

    class HttpRequestHandler:

        def __init__(self, addr):
            host, port = addr
            self.host = host
            self.port = port
            self.address = addr
            self._url = '{}://{}:{}'.format(
                'https' if use_ssl else 'http', host, port)

        def __getitem__(self, key):
            return properties[key]

        def __setitem__(self, key, value):
            properties[key] = value

        def url(self, *suffix):
            return urllib.parse.urljoin(
                self._url, '/'.join(str(s) for s in suffix))

    class TestHttpServer(server.ServerHttpProtocol):

        def connection_made(self, transport):
            transports.append(transport)

            super().connection_made(transport)

        def handle_request(self, message, payload):
            if properties.get('close', False):
                return

            for hdr, val in message.headers.items():
                if (hdr.upper() == 'EXPECT') and (val == '100-continue'):
                    self.transport.write(b'HTTP/1.0 100 Continue\r\n\r\n')
                    break

            body = yield from payload.read()

            rob = router(
                self, properties, self.transport, message, body)
            rob.dispatch()

    if use_ssl:
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        keyfile = os.path.join(here, 'sample.key')
        certfile = os.path.join(here, 'sample.crt')
        sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        sslcontext.load_cert_chain(certfile, keyfile)
    else:
        sslcontext = None

    def run(loop, fut):
        thread_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(thread_loop)

        host, port = listen_addr
        try:
            server_coroutine = thread_loop.create_server(
                lambda: TestHttpServer(keep_alive=0.5),
                host, port, ssl=sslcontext)
            server = thread_loop.run_until_complete(server_coroutine)
        except Exception as exc:
            # Hand the start-up failure to the caller's loop. Without this,
            # ``fut`` would never complete and the caller would block forever
            # in ``loop.run_until_complete(fut)``.
            loop.call_soon_threadsafe(fut.set_exception, exc)
            thread_loop.close()
            return

        waiter = helpers.create_future(thread_loop)
        loop.call_soon_threadsafe(
            fut.set_result, (thread_loop, waiter,
                             server.sockets[0].getsockname()))

        try:
            thread_loop.run_until_complete(waiter)
        finally:
            # call pending connection_made if present
            run_briefly(thread_loop)

            # close opened transports
            for tr in transports:
                tr.close()

            run_briefly(thread_loop)  # call close callbacks

            server.close()
            thread_loop.stop()
            thread_loop.close()
            gc.collect()

    fut = helpers.create_future(loop)
    server_thread = threading.Thread(target=run, args=(loop, fut))
    server_thread.start()

    thread_loop, waiter, addr = loop.run_until_complete(fut)
    try:
        yield HttpRequestHandler(addr)
    finally:
        # Wake the server thread so it tears down, then wait for it.
        thread_loop.call_soon_threadsafe(waiter.set_result, None)
        server_thread.join()

Example 182

Project: xtraceback
Source File: test_support.py
View license
@contextlib.contextmanager
def transient_internet(resource_name, timeout=30.0, errnos=()):
    """Turn flaky-network exceptions raised inside the block into
    ResourceDenied.

    While the block runs, the default socket timeout is set to ``timeout``
    (unless it is None); the previous default is restored afterwards. An
    IOError escaping the block is unwrapped (urllib may nest the real error
    in ``args[0]`` or ``args[1]``). ResourceDenied is raised instead when the
    innermost error is a socket timeout, a getaddrinfo error with a known
    errno, or carries one of the captured errnos. Any other IOError is
    re-raised unchanged.

    ``errnos`` overrides the default errno list. When it is given, no
    getaddrinfo errnos are treated as transient.
    """
    default_errnos = [
        ('ECONNREFUSED', 111),
        ('ECONNRESET', 104),
        ('EHOSTUNREACH', 113),
        ('ENETUNREACH', 101),
        ('ETIMEDOUT', 110),
    ]
    default_gai_errnos = [
        ('EAI_NONAME', -2),
        ('EAI_NODATA', -5),
    ]

    denied = ResourceDenied("Resource '%s' is not available" % resource_name)
    if errnos:
        captured_errnos = errnos
        gai_errnos = []
    else:
        captured_errnos = [getattr(errno, label, fallback)
                           for (label, fallback) in default_errnos]
        gai_errnos = [getattr(socket, label, fallback)
                      for (label, fallback) in default_gai_errnos]

    def _innermost(exc):
        # urllib can wrap original socket errors multiple times (!); the
        # original may sit in args[0] or, as with
        #    raise IOError('socket error', msg)
        # in args[1].
        while True:
            wrapped = exc.args
            if wrapped and isinstance(wrapped[0], IOError):
                exc = wrapped[0]
            elif len(wrapped) > 1 and isinstance(wrapped[1], IOError):
                exc = wrapped[1]
            else:
                return exc

    def _deny_if_transient(exc):
        code = getattr(exc, 'errno', None)
        transient = (isinstance(exc, socket.timeout)
                     or (isinstance(exc, socket.gaierror) and code in gai_errnos)
                     or code in captured_errnos)
        if not transient:
            return
        if not verbose:
            sys.stderr.write(denied.args[0] + "\n")
        raise denied

    saved_timeout = socket.getdefaulttimeout()
    try:
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    except IOError as err:
        _deny_if_transient(_innermost(err))
        raise
    # XXX should we catch generic exceptions and look for their
    # __cause__ or __context__?
    finally:
        socket.setdefaulttimeout(saved_timeout)

Example 183

Project: xtraceback
Source File: support.py
View license
@contextlib.contextmanager
def transient_internet(resource_name, *, timeout=30.0, errnos=()):
    """Turn flaky-network exceptions raised inside the block into
    ResourceDenied.

    While the block runs, the default socket timeout is set to ``timeout``
    (unless it is None); the previous default is restored afterwards. An
    IOError escaping the block is unwrapped (urllib may nest the real error
    in ``args[0]`` or ``args[1]``). ResourceDenied, chained from the
    innermost error, is raised instead when that error is a socket timeout,
    a getaddrinfo error with a known errno, or carries one of the captured
    errnos. Any other IOError is re-raised unchanged.

    ``errnos`` overrides the default errno list. When it is given, no
    getaddrinfo errnos are treated as transient.
    """
    default_errnos = [
        ('ECONNREFUSED', 111),
        ('ECONNRESET', 104),
        ('EHOSTUNREACH', 113),
        ('ENETUNREACH', 101),
        ('ETIMEDOUT', 110),
    ]
    default_gai_errnos = [
        ('EAI_NONAME', -2),
        ('EAI_NODATA', -5),
    ]

    denied = ResourceDenied("Resource '%s' is not available" % resource_name)
    if errnos:
        captured_errnos = errnos
        gai_errnos = []
    else:
        captured_errnos = [getattr(errno, label, fallback)
                           for (label, fallback) in default_errnos]
        gai_errnos = [getattr(socket, label, fallback)
                      for (label, fallback) in default_gai_errnos]

    def _innermost(exc):
        # urllib can wrap original socket errors multiple times (!); the
        # original may sit in args[0] or, as with
        #    raise IOError('socket error', msg).with_traceback(...)
        # in args[1].
        while True:
            wrapped = exc.args
            if wrapped and isinstance(wrapped[0], IOError):
                exc = wrapped[0]
            elif len(wrapped) > 1 and isinstance(wrapped[1], IOError):
                exc = wrapped[1]
            else:
                return exc

    def _deny_if_transient(exc):
        code = getattr(exc, 'errno', None)
        transient = (isinstance(exc, socket.timeout)
                     or (isinstance(exc, socket.gaierror) and code in gai_errnos)
                     or code in captured_errnos)
        if not transient:
            return
        if not verbose:
            sys.stderr.write(denied.args[0] + "\n")
        raise denied from exc

    saved_timeout = socket.getdefaulttimeout()
    try:
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    except IOError as err:
        _deny_if_transient(_innermost(err))
        raise
    # XXX should we catch generic exceptions and look for their
    # __cause__ or __context__?
    finally:
        socket.setdefaulttimeout(saved_timeout)

Example 184

Project: osf.io
Source File: utils.py
View license
@contextlib.contextmanager
def mock_archive(project, schema=None, auth=None, data=None, parent=None,
                 embargo=False, embargo_end_date=None,
                 retraction=False, justification=None, autoapprove_retraction=False,
                 autocomplete=True, autoapprove=False):
    """ A context manager for registrations. When you want to call Node#register_node in
    a test but do not want to deal with any of the side effects of archiver, this
    helper allows for creating a registration in a safe fashion.

    :param project: the node to register
    :param schema: registration metaschema (default: DEFAULT_METASCHEMA)
    :param auth: Auth used to register (default: Auth(project.creator))
    :param data: registration data (default: '')
    :param parent: passed through to Node#register_node
    :param bool embargo: embargo the registration (rather than RegistrationApproval)
    :param embargo_end_date: end of the embargo (default: now + 20 days)
    :param bool autocomplete: automatically finish archival?
    :param bool autoapprove: automatically approve registration approval?
    :param bool retraction: retract the registration?
    :param str justification: a justification for the retraction
        (default: "Because reasons")
    :param bool autoapprove_retraction: automatically approve retraction?

    Yields the new registration.

    Example use:

    project = ProjectFactory()
    with mock_archive(project, autocomplete=False) as registration:
        assert_true(registration.is_registration)
        assert_true(registration.archiving)
        assert_true(registration.is_pending_registration)

    # autocomplete defaults to True
    with mock_archive(project) as registration:
        assert_true(registration.is_registration)
        assert_false(registration.archiving)
        assert_true(registration.is_pending_registration)

    with mock_archive(project, autocomplete=True, autoapprove=True) as registration:
        assert_true(registration.is_registration)
        assert_false(registration.archiving)
        assert_false(registration.is_pending_registration)
    """
    schema = schema or DEFAULT_METASCHEMA
    auth = auth or Auth(project.creator)
    data = data or ''

    # Keep register_node from enqueueing real celery tasks.
    with mock.patch('framework.celery_tasks.handlers.enqueue_task'):
        registration = project.register_node(
            schema=schema,
            auth=auth,
            data=data,
            parent=parent,
        )
    if embargo:
        embargo_end_date = embargo_end_date or (
            datetime.datetime.now() + datetime.timedelta(days=20)
        )
        registration.root.embargo_registration(
            project.creator,
            embargo_end_date
        )
    else:
        registration.root.require_approval(project.creator)
    if autocomplete:
        # Mark the root archive job as successfully finished, then fire the
        # archive callback with the archiver's own follow-up work patched out.
        root_job = registration.root.archive_job
        root_job.status = ARCHIVER_SUCCESS
        root_job.sent = False
        root_job.done = True
        root_job.save()
        # contextlib.nested no longer exists in Python 3; a multi-item with
        # statement enters the same patches in the same order.
        with mock.patch.object(root_job, 'archive_tree_finished', mock.Mock(return_value=True)), \
                mock.patch('website.archiver.tasks.archive_success.delay', mock.Mock()):
            archiver_listeners.archive_callback(registration)
    if autoapprove:
        sanction = registration.root.sanction
        sanction.state = Sanction.APPROVED
        sanction._on_complete(project.creator)
        sanction.save()

    if retraction:
        justification = justification or "Because reasons"
        retraction_sanction = registration.retract_registration(
            project.creator, justification=justification)
        if autoapprove_retraction:
            retraction_sanction.state = Sanction.APPROVED
            retraction_sanction._on_complete(project.creator)
        retraction_sanction.save()
        registration.save()
    yield registration

Example 185

Project: WeasyPrint
Source File: __init__.py
View license
@contextlib.contextmanager
def _select_source(guess=None, filename=None, url=None, file_obj=None,
                   string=None, tree=None, base_url=None,
                   url_fetcher=default_url_fetcher, check_css_mime_type=False):
    """
    Pick the single source that was given and yield it.

    Exactly one of ``guess``, ``filename``, ``url``, ``file_obj``, ``string``
    or ``tree`` must be provided; otherwise :exc:`TypeError` is raised,
    naming the sources that were given (or "nothing").

    Yields a ``(source_type, source, base_url, protocol_encoding)`` tuple.
    ``source_type`` is ``'file_obj'``, ``'string'`` or ``'tree'``.
    ``base_url`` is passed through ``ensure_url`` when given. Otherwise it is
    derived from the filename, the fetched URL, or the file object's name
    where possible.

    ``guess`` is classified as a file object (it has ``read``), an absolute
    URL, or a filename, and is then resolved recursively.

    """
    if base_url is not None:
        base_url = ensure_url(base_url)

    candidates = (
        ('guess', guess), ('filename', filename), ('url', url),
        ('file_obj', file_obj), ('string', string), ('tree', tree))
    provided = [label for label, value in candidates if value is not None]
    if len(provided) != 1:
        raise TypeError('Expected exactly one source, got ' + (
            ', '.join(provided) or 'nothing'))
    source_kind = provided[0]

    if source_kind == 'guess':
        if hasattr(guess, 'read'):
            guessed_kind = 'file_obj'
        elif url_is_absolute(guess):
            guessed_kind = 'url'
        else:
            guessed_kind = 'filename'
        inner = _select_source(
            base_url=base_url, url_fetcher=url_fetcher,
            check_css_mime_type=check_css_mime_type,
            # Use str() to work around http://bugs.python.org/issue4978
            # See https://github.com/Kozea/WeasyPrint/issues/97
            **{str(guessed_kind): guess})
        with inner as selected:
            yield selected
    elif source_kind == 'filename':
        if base_url is None:
            base_url = path2url(filename)
        with open(filename, 'rb') as handle:
            yield 'file_obj', handle, base_url, None
    elif source_kind == 'url':
        with fetch(url_fetcher, url) as response:
            if check_css_mime_type and response['mime_type'] != 'text/css':
                LOGGER.warning(
                    'Unsupported stylesheet type %s for %s',
                    response['mime_type'], response['redirected_url'])
                yield 'string', '', base_url, None
                return
            encoding = response.get('encoding')
            if base_url is None:
                base_url = response.get('redirected_url', url)
            if 'string' in response:
                yield 'string', response['string'], base_url, encoding
            else:
                yield 'file_obj', response['file_obj'], base_url, encoding
    elif source_kind == 'file_obj':
        if base_url is None:
            # Real files expose their path as ``name``; pseudo-streams use
            # names such as '<stdin>' that are not paths.
            name = getattr(file_obj, 'name', None)
            if name and not name.startswith('<'):
                base_url = ensure_url(name)
        yield 'file_obj', file_obj, base_url, None
    elif source_kind == 'string':
        yield 'string', string, base_url, None
    else:
        yield 'tree', tree, base_url, None

Example 186

Project: seqmagick
Source File: common.py
View license
@contextlib.contextmanager
def atomic_write(path, permissions=None, file_factory=None, **kwargs):
    """
    Open a file for atomic writing.

    Writes go to a temporary file created in the same directory as ``path``.
    When the ``with`` block exits normally, the temporary file is given
    ``permissions`` and then moved over ``path`` in a single rename.
    If the block (or any commit step) raises, the temporary file is removed
    and ``path`` is left untouched.

    A ``path`` of ``'-'`` yields ``sys.stdout`` instead.

    Arguments:
    ``permissions``: Permissions to set (default: umask)
    ``file_factory``: If given, the handle yielded will be the result of
        calling file_factory(temp_path), where temp_path is the path of the
        temporary file

    Additional arguments are passed to tempfile.NamedTemporaryFile
    (``suffix`` is always set to the basename of ``path``).
    """
    if permissions is None:
        permissions = apply_umask()
    # Handle stdout:
    if path == '-':
        yield sys.stdout
        return

    base_dir = os.path.dirname(path)
    kwargs['suffix'] = os.path.basename(path)
    tf = tempfile.NamedTemporaryFile(dir=base_dir, delete=False, **kwargs)
    temp_name = tf.name
    # os.replace (Python 3.3+) also overwrites an existing target on Windows;
    # os.rename only does so on POSIX.
    replace = getattr(os, 'replace', os.rename)
    try:
        # If a file_factory is given, close, and re-open a handle using the
        # file_factory
        if file_factory is not None:
            tf.close()
            tf = file_factory(temp_name)
        with tf:
            yield tf
        # Apply permissions before the move so the final file never appears
        # with the temporary file's private mode.
        os.chmod(temp_name, permissions)
        replace(temp_name, path)
    except BaseException:
        # Best-effort cleanup. The temp file may already be gone, and a
        # failure here must not mask the original exception.
        try:
            os.remove(temp_name)
        except OSError:
            pass
        raise

Example 187

Project: WeasyPrint
Source File: __init__.py
View license
@contextlib.contextmanager
def _select_source(guess=None, filename=None, url=None, file_obj=None,
                   string=None, tree=None, base_url=None,
                   url_fetcher=default_url_fetcher, check_css_mime_type=False):
    """
    Pick the single source that was given and yield it.

    Exactly one of ``guess``, ``filename``, ``url``, ``file_obj``, ``string``
    or ``tree`` must be provided; otherwise :exc:`TypeError` is raised,
    naming the sources that were given (or "nothing").

    Yields a ``(source_type, source, base_url, protocol_encoding)`` tuple.
    ``source_type`` is ``'file_obj'``, ``'string'`` or ``'tree'``.
    ``base_url`` is passed through ``ensure_url`` when given. Otherwise it is
    derived from the filename, the fetched URL, or the file object's name
    where possible.

    ``guess`` is classified as a file object (it has ``read``), an absolute
    URL, or a filename, and is then resolved recursively.

    """
    if base_url is not None:
        base_url = ensure_url(base_url)

    candidates = (
        ('guess', guess), ('filename', filename), ('url', url),
        ('file_obj', file_obj), ('string', string), ('tree', tree))
    provided = [label for label, value in candidates if value is not None]
    if len(provided) != 1:
        raise TypeError('Expected exactly one source, got ' + (
            ', '.join(provided) or 'nothing'))
    source_kind = provided[0]

    if source_kind == 'guess':
        if hasattr(guess, 'read'):
            guessed_kind = 'file_obj'
        elif url_is_absolute(guess):
            guessed_kind = 'url'
        else:
            guessed_kind = 'filename'
        inner = _select_source(
            base_url=base_url, url_fetcher=url_fetcher,
            check_css_mime_type=check_css_mime_type,
            # Use str() to work around http://bugs.python.org/issue4978
            # See https://github.com/Kozea/WeasyPrint/issues/97
            **{str(guessed_kind): guess})
        with inner as selected:
            yield selected
    elif source_kind == 'filename':
        if base_url is None:
            base_url = path2url(filename)
        with open(filename, 'rb') as handle:
            yield 'file_obj', handle, base_url, None
    elif source_kind == 'url':
        with fetch(url_fetcher, url) as response:
            if check_css_mime_type and response['mime_type'] != 'text/css':
                LOGGER.warning(
                    'Unsupported stylesheet type %s for %s',
                    response['mime_type'], response['redirected_url'])
                yield 'string', '', base_url, None
                return
            encoding = response.get('encoding')
            if base_url is None:
                base_url = response.get('redirected_url', url)
            if 'string' in response:
                yield 'string', response['string'], base_url, encoding
            else:
                yield 'file_obj', response['file_obj'], base_url, encoding
    elif source_kind == 'file_obj':
        if base_url is None:
            # Real files expose their path as ``name``; pseudo-streams use
            # names such as '<stdin>' that are not paths.
            name = getattr(file_obj, 'name', None)
            if name and not name.startswith('<'):
                base_url = ensure_url(name)
        yield 'file_obj', file_obj, base_url, None
    elif source_kind == 'string':
        yield 'string', string, base_url, None
    else:
        yield 'tree', tree, base_url, None

Example 188

Project: teuthology
Source File: hadoop.py
View license
@contextlib.contextmanager
def install_hadoop(ctx, config):
    """Install and configure Hadoop on every hadoop node; remove it on exit.

    Each step below runs on all hadoop nodes in parallel:

      1. download the Hadoop 2.x tarball into the test directory;
      2. unpack it into ``<testdir>/hadoop`` and delete the tarball;
      3. create ``<testdir>/hadoop_tmp``;
      4. unless ``config.get('hdfs')`` is true, fetch the cephfs-hadoop jar and
         copy the libcephfs JNI library and ``libcephfs.jar`` into the install;
      5. call ``configure(ctx, config, hadoops)``.

    Cleanup (``rm -rf`` of the tarball, install and tmp directories) runs in a
    ``finally`` that covers setup as well as the body. A failure part-way
    through setup therefore still removes whatever had been created.

    :param ctx: teuthology run context
    :param config: task configuration; only the ``hdfs`` key is read directly
        here, the whole mapping is passed on to ``configure``
    :raises UnsupportedPackageTypeError: if a node is neither rpm- nor
        deb-based (only checked when installing the CephFS bindings)
    """
    testdir = teuthology.get_testdir(ctx)
    hadoops = ctx.cluster.only(is_hadoop_type(''))

    hadoop_tarball = "{tdir}/hadoop.tar.gz".format(tdir=testdir)
    hadoop_dir = "{tdir}/hadoop".format(tdir=testdir)
    hadoop_tmp_dir = "{tdir}/hadoop_tmp".format(tdir=testdir)
    hadoop_common_dir = "{tdir}/hadoop/share/hadoop/common/".format(tdir=testdir)

    def run_on_hadoops(args):
        # Start the command on every hadoop node at once, then block until
        # all of them have finished.
        run.wait(hadoops.run(args=args, wait=False))

    try:
        log.info("Downloading Hadoop...")
        run_on_hadoops(['wget', '-nv', '-O', hadoop_tarball, HADOOP_2x_URL])

        log.info("Create directory for Hadoop install...")
        run_on_hadoops(['mkdir', hadoop_dir])

        log.info("Unpacking Hadoop...")
        run_on_hadoops(['tar', 'xzf', hadoop_tarball,
                        '--strip-components=1', '-C', hadoop_dir])

        log.info("Removing Hadoop download...")
        run_on_hadoops(['rm', hadoop_tarball])

        log.info("Create Hadoop temporary directory...")
        run_on_hadoops(['mkdir', hadoop_tmp_dir])

        if not config.get('hdfs', False):
            log.info("Fetching cephfs-hadoop...")

            _sha1, url = teuthology.get_ceph_binary_url(
                package="hadoop",
                format="jar",
                dist="precise",
                arch="x86_64",
                flavor="basic",
                branch="master")

            cephfs_jar = "{tdir}/cephfs-hadoop.jar".format(tdir=testdir)  # FIXME
            run_on_hadoops(['wget', '-nv', '-O', cephfs_jar,
                            url + "/cephfs-hadoop-0.80.6.jar"])  # FIXME
            run_on_hadoops(['mv', cephfs_jar, hadoop_common_dir])

            # Copy JNI native bits. Need to do this explicitly because the
            # handling is dependent on the os-type.
            libcephfs_jni_fname = "libcephfs_jni.so"
            for remote in hadoops.remotes:
                if remote.os.package_type == 'rpm':
                    libcephfs_jni_path = "/usr/lib64/libcephfs_jni.so.1.0.0"
                elif remote.os.package_type == 'deb':
                    libcephfs_jni_path = "/usr/lib/jni/libcephfs_jni.so"
                else:
                    raise UnsupportedPackageTypeError(remote)

                remote.run(
                    args=[
                        'cp',
                        libcephfs_jni_path,
                        "{tdir}/hadoop/lib/native/{fname}".format(
                            tdir=testdir, fname=libcephfs_jni_fname),
                    ])

            run_on_hadoops(['cp', "/usr/share/java/libcephfs.jar",
                            hadoop_common_dir])

        configure(ctx, config, hadoops)

        yield
    finally:
        # -f makes this safe when setup stopped before creating some paths.
        run_on_hadoops(['rm', '-rf', hadoop_tarball, hadoop_dir,
                        hadoop_tmp_dir])

Example 189

Project: sqlalchemy-utils
Source File: mock.py
View license
@contextlib.contextmanager
def mock_engine(engine, stream=None):
    """Mocks out the engine specified in the passed bind expression.

    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.

    The expression is evaluated in the first calling frame where it
    succeeds. That name is rebound to a mock engine built by
    ``create_mock_engine`` for the duration of the ``with`` block. The
    original engine is put back on exit, even when the block raises.

    :param engine: A python expression that represents the engine to mock.
    :param stream: Render all DDL operations to the stream.
    :raises ValueError: if no calling frame can evaluate ``engine``.
    """

    # Create a stream if not present.
    if stream is None:
        stream = six.moves.cStringIO()

    # Navigate the stack and find the calling frame that allows the
    # expression to execute.
    for frame_info in inspect.stack()[1:]:
        frame = frame_info[0]
        try:
            six.exec_('__target = %s' % engine, frame.f_globals, frame.f_locals)
            target = frame.f_locals['__target']
            break
        except Exception:
            # Not resolvable in this frame; try the next one out. (Only
            # Exception is caught so KeyboardInterrupt/SystemExit propagate.)
            continue
    else:
        raise ValueError('Not a valid python expression', engine)

    # Replace the target with our mock.
    frame.f_locals['__mock'] = create_mock_engine(target, stream)
    six.exec_('%s = __mock' % engine, frame.f_globals, frame.f_locals)

    try:
        # Give control back.
        yield stream
    finally:
        # Put the target engine back, even if the with-block raised.
        frame.f_locals['__target'] = target
        six.exec_('%s = __target' % engine, frame.f_globals, frame.f_locals)
        six.exec_('del __target', frame.f_globals, frame.f_locals)
        six.exec_('del __mock', frame.f_globals, frame.f_locals)

Example 190

Project: gaupol
Source File: util.py
View license
@contextlib.contextmanager
def atomic_open(path, mode="w", *args, **kwargs):
    """
    Open ``path`` for writing so that its contents are replaced atomically.

    Writes go to a hidden temporary file in the same directory as ``path``.
    After the ``with`` block completes, that file is flushed and fsynced, then
    renamed over ``path``. This should (probably) be atomic on any Unix
    system; on Windows it should (probably) be atomic with Python 3.3 or
    greater.

    If ``path`` already exists, its permission bits are copied to the new
    file; owner, group and all other metadata are not preserved.

    If the rename fails, a non-atomic ``shutil.move`` is used as a fallback.
    The temporary file is removed on every exit path, with errors ignored.
    """
    path = os.path.realpath(path)
    alphabet = list("abcdefghijklmnopqrstuvwxyz0123456789")
    parent = os.path.dirname(path)
    name = os.path.basename(path)

    def candidate():
        # Hidden name, so no stray file flickers in an open file browser.
        tag = "".join(random.sample(alphabet, 8))
        return os.path.join(parent, ".{}.tmp{}".format(name, tag))

    temp_path = candidate()
    while os.path.isfile(temp_path):
        temp_path = candidate()

    try:
        if os.path.isfile(path):
            # Create the temp file up front so it can inherit the
            # existing file's mode bits.
            with open(temp_path, "w"):
                pass
            os.chmod(temp_path, stat.S_IMODE(os.stat(path).st_mode))
        with open(temp_path, mode, *args, **kwargs) as handle:
            yield handle
            handle.flush()
            os.fsync(handle.fileno())
        # os.replace (Python 3.3+) is atomic on Windows too; os.rename is
        # atomic on Unix but fails on Windows if the target exists.
        move_into_place = getattr(os, "replace", os.rename)
        try:
            move_into_place(temp_path, path)
        except OSError:
            # Rename fails when both paths are not on the same device (e.g.
            # separate branches of a union mount). Fall back to a non-atomic
            # move, which on Windows requires the destination to be absent.
            if sys.platform == "win32" and os.path.isfile(path):
                os.remove(path)
            shutil.move(temp_path, path)
    finally:
        with silent(Exception):
            os.remove(temp_path)

Example 191

Project: ganeti
Source File: instance_helpervm.py
View license
@contextlib.contextmanager
def HelperVM(lu, instance, vm_image, startup_timeout, vm_timeout,
             log_prefix=None, feedback_fn=None):
  """Runs a given helper VM for a given instance.

  Steps:
    1. Creates a temporary disk (DT_PLAIN, DISK_RDWR) sized for C{vm_image}.
    2. Images the disk with C{vm_image}.
    3. Starts the instance and waits up to C{startup_timeout} for it to run.
    4. Runs the caller's C{with} block.
    5. Waits up to C{vm_timeout} for the helper VM to stop by itself,
       forcing a shutdown if it does not.

  @type lu: L{LogicalUnit}
  @param lu: the lu on whose behalf we execute
  @type instance: L{objects.Instance}
  @param instance: the instance definition
  @type vm_image: string
  @param vm_image: the name of the helper VM image to dump on a temporary disk
  @type startup_timeout: int
  @param startup_timeout: how long to wait for the helper VM to start up
  @type vm_timeout: int
  @param vm_timeout: how long to wait for the helper VM to finish its work
  @type log_prefix: string
  @param log_prefix: a prefix for all log messages
  @type feedback_fn: function
  @param feedback_fn: Function used to log progress
  @raise errors.OpExecError: if the image size cannot be determined, if the
      helper VM does not come up within C{startup_timeout}, or if (after a
      successful block) it had to be shut down forcibly

  """
  if log_prefix:
    add_prefix = lambda msg: "%s: %s" % (log_prefix, msg)
  else:
    add_prefix = lambda msg: msg

  if feedback_fn is not None:
    log_feedback = lambda msg: feedback_fn(add_prefix(msg))
  else:
    log_feedback = lambda _: None

  try:
    disk_size = DetermineImageSize(lu, vm_image, instance.primary_node)
  except errors.OpExecError as err:
    # Interpolate the cause into the message; passing it as a second
    # argument would leave the literal "%s" in the error text.
    raise errors.OpExecError("Could not create temporary disk: %s" % err)

  with TemporaryDisk(lu,
                     instance,
                     [(constants.DT_PLAIN, constants.DISK_RDWR, disk_size)],
                     log_feedback):
    log_feedback("Activating helper VM's temporary disks")
    StartInstanceDisks(lu, instance, False)

    log_feedback("Imaging temporary disks with image %s" % (vm_image, ))
    ImageDisks(lu, instance, vm_image)

    log_feedback("Starting helper VM")
    result = lu.rpc.call_instance_start(instance.primary_node,
                                        (instance, [], []),
                                        False, lu.op.reason)
    result.Raise(add_prefix("Could not start helper VM with image %s" %
                            (vm_image, )))

    # First wait for the instance to start up
    running_check = lambda: IsInstanceRunning(lu, instance, prereq=False)
    instance_up = retry.SimpleRetry(True, running_check, 5.0,
                                    startup_timeout)
    if not instance_up:
      raise errors.OpExecError(add_prefix("Could not boot instance using"
                                          " image %s" % (vm_image, )))

    log_feedback("Helper VM is up")

    def cleanup():
      """Wait for the helper VM to stop; force a shutdown on timeout.

      @return: the shutdown RPC result if a forced shutdown was issued,
          C{None} if the VM stopped by itself

      """
      log_feedback("Waiting for helper VM to finish")

      # Then for it to be finished, detected by its shutdown
      instance_up = retry.SimpleRetry(False, running_check, 20.0, vm_timeout)
      if instance_up:
        lu.LogWarning(add_prefix("Helper VM has not finished within the"
                                 " timeout; shutting it down forcibly"))
        return \
          lu.rpc.call_instance_shutdown(instance.primary_node,
                                        instance,
                                        constants.DEFAULT_SHUTDOWN_TIMEOUT,
                                        lu.op.reason)
      else:
        return None

    # Run the inner block and handle possible errors
    try:
      yield
    except Exception:
      # The block already failed: only warn about cleanup problems, then
      # re-raise the original error.
      result = cleanup()
      if result:
        result.Warn(add_prefix("Could not shut down helper VM with image"
                               " %s within timeout" % (vm_image, )))
        log_feedback("Error running helper VM with image %s" %
                     (vm_image, ))
      raise
    else:
      result = cleanup()
      # A forced shutdown means the helper VM did not finish its work in
      # time, so report that as an error.
      if result:
        result.Raise(add_prefix("Could not shut down helper VM with image %s"
                                " within timeout" % (vm_image, )))
        raise errors.OpExecError("Error running helper VM with image %s" %
                                 (vm_image, ))

  log_feedback("Helper VM execution completed")

Example 192

Project: pyperclip
Source File: windows.py
View license
def init_windows_clipboard():
    """Build ``(copy, paste)`` functions backed by the Win32 clipboard API.

    The needed user32/kernel32 entry points are bound through ctypes with
    explicit ``argtypes``/``restype``. Most are wrapped in ``CheckedCall``,
    presumably so that a failed call raises (TODO confirm against
    CheckedCall). ``OpenClipboard`` is left unwrapped because ``clipboard()``
    retries it.

    Returns a ``(copy_windows, paste_windows)`` pair of closures.
    """
    from ctypes.wintypes import (HGLOBAL, LPVOID, DWORD, LPCSTR, INT, HWND,
                                 HINSTANCE, HMENU, BOOL, UINT, HANDLE)

    windll = ctypes.windll

    safeCreateWindowExA = CheckedCall(windll.user32.CreateWindowExA)
    safeCreateWindowExA.argtypes = [DWORD, LPCSTR, LPCSTR, DWORD, INT, INT,
                                    INT, INT, HWND, HMENU, HINSTANCE, LPVOID]
    safeCreateWindowExA.restype = HWND

    safeDestroyWindow = CheckedCall(windll.user32.DestroyWindow)
    safeDestroyWindow.argtypes = [HWND]
    safeDestroyWindow.restype = BOOL

    OpenClipboard = windll.user32.OpenClipboard
    OpenClipboard.argtypes = [HWND]
    OpenClipboard.restype = BOOL

    safeCloseClipboard = CheckedCall(windll.user32.CloseClipboard)
    safeCloseClipboard.argtypes = []
    safeCloseClipboard.restype = BOOL

    safeEmptyClipboard = CheckedCall(windll.user32.EmptyClipboard)
    safeEmptyClipboard.argtypes = []
    safeEmptyClipboard.restype = BOOL

    safeGetClipboardData = CheckedCall(windll.user32.GetClipboardData)
    safeGetClipboardData.argtypes = [UINT]
    safeGetClipboardData.restype = HANDLE

    safeSetClipboardData = CheckedCall(windll.user32.SetClipboardData)
    safeSetClipboardData.argtypes = [UINT, HANDLE]
    safeSetClipboardData.restype = HANDLE

    safeGlobalAlloc = CheckedCall(windll.kernel32.GlobalAlloc)
    safeGlobalAlloc.argtypes = [UINT, c_size_t]
    safeGlobalAlloc.restype = HGLOBAL

    safeGlobalLock = CheckedCall(windll.kernel32.GlobalLock)
    safeGlobalLock.argtypes = [HGLOBAL]
    safeGlobalLock.restype = LPVOID

    safeGlobalUnlock = CheckedCall(windll.kernel32.GlobalUnlock)
    safeGlobalUnlock.argtypes = [HGLOBAL]
    safeGlobalUnlock.restype = BOOL

    GMEM_MOVEABLE = 0x0002
    CF_UNICODETEXT = 13

    @contextlib.contextmanager
    def window():
        """
        Context that provides a valid Windows hwnd.
        """
        # we really just need the hwnd, so setting "STATIC"
        # as predefined lpClass is just fine.
        hwnd = safeCreateWindowExA(0, b"STATIC", None, 0, 0, 0, 0, 0,
                                   None, None, None, None)
        try:
            yield hwnd
        finally:
            safeDestroyWindow(hwnd)

    @contextlib.contextmanager
    def clipboard(hwnd):
        """
        Context manager that opens the clipboard and prevents
        other applications from modifying the clipboard content.
        """
        # We may not get the clipboard handle immediately because
        # some other application is accessing it (?)
        # We try for at least 500ms to get the clipboard.
        t = time.time() + 0.5
        success = False
        while time.time() < t:
            success = OpenClipboard(hwnd)
            if success:
                break
            time.sleep(0.01)
        if not success:
            raise PyperclipWindowsException("Error calling OpenClipboard")

        try:
            yield
        finally:
            safeCloseClipboard()

    def copy_windows(text):
        """Replace the clipboard contents with ``text`` (CF_UNICODETEXT).

        An empty ``text`` just empties the clipboard.
        """
        # This function is heavily based on
        # http://msdn.com/ms649016#_win32_Copying_Information_to_the_Clipboard
        with window() as hwnd:
            # http://msdn.com/ms649048
            # If an application calls OpenClipboard with hwnd set to NULL,
            # EmptyClipboard sets the clipboard owner to NULL;
            # this causes SetClipboardData to fail.
            # => We need a valid hwnd to copy something.
            with clipboard(hwnd):
                safeEmptyClipboard()

                if text:
                    # http://msdn.com/ms649051
                    # If the hMem parameter identifies a memory object,
                    # the object must have been allocated using the
                    # function with the GMEM_MOVEABLE flag.
                    #
                    # Size the buffer in UTF-16 code units, not characters.
                    # c_wchar is 2 bytes on Windows, so characters outside
                    # the BMP (e.g. emoji) need two units each. Using
                    # len(text) would under-allocate, truncate the copy and
                    # drop the terminating NUL. The +1 is that NUL.
                    count = len(text.encode('utf-16-le')) // 2 + 1
                    handle = safeGlobalAlloc(GMEM_MOVEABLE,
                                             count * sizeof(c_wchar))
                    locked_handle = safeGlobalLock(handle)

                    ctypes.memmove(c_wchar_p(locked_handle), c_wchar_p(text), count * sizeof(c_wchar))

                    safeGlobalUnlock(handle)
                    safeSetClipboardData(CF_UNICODETEXT, handle)

    def paste_windows():
        """Return the clipboard's CF_UNICODETEXT contents, or "" if none."""
        with clipboard(None):
            handle = safeGetClipboardData(CF_UNICODETEXT)
            if not handle:
                # GetClipboardData may return NULL with errno == NO_ERROR
                # if the clipboard is empty.
                # (Also, it may return a handle to an empty buffer,
                # but technically that's not empty)
                return ""
            return c_wchar_p(handle).value

    return copy_windows, paste_windows

Example 193

Project: entropy
Source File: multifetch.py
View license
    def _download_files(self, url_data, resume = True, repository_id = None):
        """
        Effectively fetch the package files.
        """
        self._setup_url_directories(url_data)

        @contextlib.contextmanager
        def download_context(path):
            lock = None
            try:
                lock = self.path_lock(path)
                with lock.exclusive():
                    yield  # hooks running inside here
            finally:
                if lock is not None:
                    lock.close()

        # set of paths that have been verified and don't need any
        # firther match_checksum() call.
        validated_download_ids_lock = threading.Lock()
        validated_download_ids = set()

        # Note: the following two hooks are running in separate threads.

        def pre_download_hook(path, download_id):
            path_data = url_data[download_id - 1]
            (_hook_package_id, hook_repository_id, _hook_url,
             hook_download_path, hook_cksum, hook_signs) = path_data

            if self._stat_path(hook_download_path):
                verify_st = self._match_checksum(
                    hook_download_path,
                    hook_repository_id,
                    hook_cksum,
                    hook_signs)
                if verify_st == 0:
                    # UrlFetcher returns the md5 checksum on success
                    with validated_download_ids_lock:
                        validated_download_ids.add(download_id)
                    return hook_cksum

            # request the download
            return None

        def post_download_hook(_path, _status, download_id):
            path_data = url_data[download_id - 1]
            (_hook_package_id, hook_repository_id, _hook_url,
             hook_download_path, hook_cksum, hook_signs) = path_data

            with validated_download_ids_lock:
                if hook_download_path in validated_download_ids:
                    # nothing to check, path already verified
                    return

            if not self._stat_path(hook_download_path):
                return

            verify_st = self._match_checksum(
                hook_download_path,
                hook_repository_id,
                hook_cksum,
                hook_signs)
            if verify_st == 0:
                with validated_download_ids_lock:
                    validated_download_ids.add(download_id)

        url_path_list = []
        last_repos_id = None
        for pkg_id, repository_id, url, download_path, _cksum, _sig in url_data:
            url_path_list.append((url, download_path))

            lock = None
            try:
                # hold a lock against the download path, like fetch.py does.
                lock = self.path_lock(download_path)
                with lock.exclusive():

                    self._setup_differential_download(
                        self._entropy._multiple_url_fetcher, url,
                        resume, download_path,
                        repository_id, pkg_id)

                    # This is horrible but probably
                    # MultipleUrlFetcher must be reimplemented.
                    last_repos_id = repository_id

            finally:
                if lock is not None:
                    lock.close()

        avail_data = self._settings['repositories']['available']
        repo_data = avail_data[last_repos_id]
        basic_user = repo_data.get('username')
        basic_pwd = repo_data.get('password')
        https_validate_cert = not repo_data.get('https_validate_cert') == "false"

        fetch_abort_function = self._meta.get('fetch_abort_function')
        fetch_intf = self._entropy._multiple_url_fetcher(
            url_path_list, resume = resume,
            abort_check_func = fetch_abort_function,
            url_fetcher_class = self._entropy._url_fetcher,
            download_context_func = download_context,
            pre_download_hook = pre_download_hook,
            post_download_hook = post_download_hook,
            http_basic_user = basic_user,
            http_basic_pwd = basic_pwd,
            https_validate_cert = https_validate_cert)
        try:
            # make sure that we don't need to abort already
            # doing the check here avoids timeouts
            if fetch_abort_function != None:
                fetch_abort_function()

            data = fetch_intf.download()
        except (KeyboardInterrupt, InterruptError):
            return -100, {}, 0

        failed_map = {}
        for download_id, tup in enumerate(url_data, 1):

            if download_id in validated_download_ids:
                # valid, nothing to do
                continue

            (_pkg_id, repository_id, _url,
             _download_path, _ignore_checksum, signatures) = tup

            # use the outcome returned by download(), it
            # contains an error code if download failed.
            val = data.get(download_id)
            failed_map[url_path_list[download_id - 1][0]] = (
                val, signatures)

        exit_st = 0
        if failed_map:
            exit_st = -1
        # determine if we got a -100, KeyboardInterrupt
        for _key, (val, _signs) in tuple(failed_map.items()):
            if val == -100:
                exit_st = -100
                break

        return exit_st, failed_map, fetch_intf.get_transfer_rate()

Example 194

Project: bsd-cloudinit
Source File: testutils.py
View license
    @contextlib.contextmanager
    def assert_raises_windows_message(
            self, expected_msg, error_code,
            exc=exception.WindowsCloudbaseInitException):
        """Assert that the with-block raises ``exc`` with a Windows error text.

        Context-manager-only variant of :meth:`~assertRaises`. The block must
        raise ``exc``.

        While the block runs, ``ctypes.FormatError`` and
        ``ctypes.GetLastError`` (as seen from ``cloudbaseinit.exception``)
        are patched; ``FormatError`` returns ``"description"``. Afterwards
        the following are checked:

        - ``FormatError`` was called exactly once, with ``error_code``, or
          with ``GetLastError()``'s result if ``GetLastError`` was used
          (which must then have been called exactly once);
        - the exception's first argument equals
          ``expected_msg % "description"``.
        """
        # The patches are applied with ``with`` rather than as decorators:
        # a decorator would undo them as soon as this generator yields, and
        # the caller's with-block would then see the real functions.
        with self.assertRaises(exc) as raised, \
                mock.patch('cloudbaseinit.exception.'
                           'ctypes.FormatError',
                           create=True) as fake_format_error, \
                mock.patch('cloudbaseinit.exception.ctypes.'
                           'GetLastError',
                           create=True) as fake_last_error:
            fake_format_error.return_value = "description"
            yield

        if fake_last_error.called:
            # GetLastError may be consulted when no error code was supplied;
            # we can't control that, but it must happen only once.
            fake_last_error.assert_called_once_with()
            expected_format_arg = fake_last_error.return_value
        else:
            expected_format_arg = error_code
        fake_format_error.assert_called_once_with(expected_format_arg)

        self.assertEqual(expected_msg % fake_format_error.return_value,
                         raised.exception.args[0])

Example 195

Project: cclyzer
Source File: logging_utils.py
View license
@contextlib.contextmanager
def setup_logging(lvl=logging.INFO):
    """Configure root logging for the duration of the ``with`` block.

    Sets the root logger's level to ``lvl`` and attaches:

    * a rotating file handler writing to the runtime environment's
      ``user_log_file`` (2**20 bytes per file, 7 backups);
    * a syslog handler on ``/dev/log`` whose messages are prefixed with the
      application name (best effort: skipped with a warning if it cannot
      be created);
    * a stderr handler that only emits WARNING and above.

    Yields the configured root logger. On exit, whether normal or through
    an exception, the handlers added here are detached and closed and the
    root logger's previous level is restored. Repeated use therefore does
    not accumulate duplicate handlers or leave the log file open.
    """
    # Get runtime environment and settings
    env = runtime.Environment()
    app_name = settings.APP_NAME

    root_logger = logging.getLogger()
    previous_level = root_logger.level
    added_handlers = []

    try:
        root_logger.setLevel(lvl)

        # Rotating file handler
        file_formatter = logging.Formatter(
            "[PID %(process)d - %(asctime)s %(levelname)5.5s] "
            "%(pathname)s: Line %(lineno)d: %(message)s")
        file_handler = RotatingFileHandler(
            env.user_log_file, maxBytes=(2 ** 20), backupCount=7)
        file_handler.setFormatter(file_formatter)
        root_logger.addHandler(file_handler)
        added_handlers.append(file_handler)

        # System log handler: best effort, since /dev/log may be missing.
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
        # propagate.
        try:
            syslog_formatter = logging.Formatter(
                "[{}] %(message)s".format(app_name))
            syslog_handler = SysLogHandler(address='/dev/log')
        except Exception:
            root_logger.warning('Cannot add system logger')
        else:
            syslog_handler.setFormatter(syslog_formatter)
            root_logger.addHandler(syslog_handler)
            added_handlers.append(syslog_handler)

        # stderr handler, warnings and above only
        stderr_handler = StreamHandler(stream=sys.stderr)
        stderr_handler.setFormatter(
            ConsoleFormatter("%(levelname)s (%(name)s): %(message)s"))
        stderr_handler.setLevel(logging.WARNING)
        root_logger.addHandler(stderr_handler)
        added_handlers.append(stderr_handler)

        # Start executing task
        yield root_logger
    finally:
        # Undo everything this context installed, newest first.
        for handler in reversed(added_handlers):
            root_logger.removeHandler(handler)
            handler.close()
        root_logger.setLevel(previous_level)

Example 196

Project: pre-commit
Source File: staged_files_only.py
View license
@contextlib.contextmanager
def staged_files_only(cmd_runner):
    """Hide unstaged changes in the git working directory for the context.

    If ``git diff`` reports unstaged changes, they are written to a
    timestamped patch file (located via ``cmd_runner.path``) and discarded
    from the working tree with ``git checkout -- .``. When the context
    exits, normally or through an exception, the patch is re-applied. If
    re-applying raises ``CalledProcessError`` (e.g. because code run inside
    the context modified the same files), the working tree is reset again
    and the patch is applied on top of the clean tree.

    When there are no unstaged changes the context does nothing special.

    Args:
        cmd_runner - PrefixedCommandRunner
    """
    # `--exit-code` makes git exit non-zero when a diff exists. retcode=None
    # presumably stops the runner from treating that as a failure; the raw
    # output (encoding=None) is later written to the patch file in binary mode.
    exit_code, unstaged_diff, _ = cmd_runner.run(
        [
            'git', 'diff', '--ignore-submodules', '--binary', '--exit-code',
            '--no-color', '--no-ext-diff',
        ],
        retcode=None,
        encoding=None,
    )

    if not (exit_code and unstaged_diff.strip()):
        # There weren't any unstaged changes, so there is nothing to stash.
        yield
        return

    patch_path = cmd_runner.path('patch{}'.format(int(time.time())))
    logger.warning('Unstaged files detected.')
    logger.info('Stashing unstaged files to {}.'.format(patch_path))
    with io.open(patch_path, 'wb') as stash:
        stash.write(unstaged_diff)

    # Drop the unstaged changes from the working tree.
    cmd_runner.run(['git', 'checkout', '--', '.'])
    try:
        yield
    finally:
        try:
            cmd_runner.run(('git', 'apply', patch_path), encoding=None)
        except CalledProcessError:
            logger.warning(
                'Stashed changes conflicted with hook auto-fixes... '
                'Rolling back fixes...'
            )
            # Discard whatever changed inside the context so the saved
            # patch applies cleanly, then restore the stashed changes.
            cmd_runner.run(['git', 'checkout', '--', '.'])
            cmd_runner.run(('git', 'apply', patch_path), encoding=None)
        logger.info('Restored changes from {}.'.format(patch_path))

Example 197

Project: sqlalchemy-utils
Source File: mock.py
View license
@contextlib.contextmanager
def mock_engine(engine, stream=None):
    """Mocks out the engine specified in the passed bind expression.

    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.

    The expression is executed in each calling frame in turn (innermost
    first) until one succeeds. In that frame the expression is rebound to a
    mock engine for the duration of the ``with`` block and then restored to
    the original object, even if the block raises.

    :param engine: A python expression that represents the engine to mock.
    :param stream: Render all DDL operations to the stream. A new
        ``six.moves.cStringIO()`` buffer is created when omitted.
    :returns: (yielded) the stream the mock renders DDL into.
    :raises ValueError: if no calling frame can evaluate ``engine``.
    """
    # SECURITY: ``engine`` is interpolated into source code and exec'd in
    # the caller's namespace; never pass untrusted input here.

    # Create a stream if not present.
    if stream is None:
        stream = six.moves.cStringIO()

    # Walk outwards through the calling frames and use the first one in
    # which the expression can be evaluated.
    for frame_info in inspect.stack()[1:]:
        frame = frame_info[0]
        try:
            six.exec_('__target = %s' % engine,
                      frame.f_globals, frame.f_locals)
            target = frame.f_locals['__target']
            break
        except Exception:
            # Not resolvable in this frame; try the next one out. (Catching
            # Exception rather than everything lets Ctrl-C / SystemExit
            # propagate.)
            continue
    else:
        raise ValueError('Not a valid python expression', engine)

    # Replace the target with our mock.
    frame.f_locals['__mock'] = create_mock_engine(target, stream)
    six.exec_('%s = __mock' % engine, frame.f_globals, frame.f_locals)

    try:
        # Give control back.
        yield stream
    finally:
        # Put the target engine back and remove the helper names, even when
        # the managed block raised.
        frame.f_locals['__target'] = target
        six.exec_('%s = __target' % engine, frame.f_globals, frame.f_locals)
        six.exec_('del __target', frame.f_globals, frame.f_locals)
        six.exec_('del __mock', frame.f_globals, frame.f_locals)

Example 198

Project: geofront-cli
Source File: client.py
View license
    @contextlib.contextmanager
    def request(self, method, url, data=None, headers=None):
        """Send an HTTP request to the Geofront server and yield the response.

        :param method: HTTP method name passed to ``Request``.
        :param url: a URL string resolved against ``self.server_url``, or a
                    tuple of path segments that is turned into
                    ``./seg1/seg2/`` first.
        :param data: optional request body passed to ``Request``.
        :param headers: optional mapping of extra request headers; entries
                        override the default ``User-Agent``/``Accept``.

        HTTP error responses are not raised as ``HTTPError``. The error
        object is used as the response instead.

        :raises ProtocolVersionError: if the ``X-Geofront-Version`` response
            header is missing, malformed, or outside
            ``MIN_PROTOCOL_VERSION``..``MAX_PROTOCOL_VERSION``.
        :raises ExpiredTokenIdError: on a JSON 404 ``token-not-found`` or
            410 ``expired-token`` error.
        :raises UnfinishedAuthenticationError: on a JSON 412
            ``unfinished-authentication`` error.

        For other JSON 4xx responses the body is read eagerly and a
        ``BufferedResponse`` is yielded instead of the raw response. The
        underlying response is always closed when the context exits,
        including when an exception is raised.
        """
        if isinstance(url, tuple):
            url = './{0}/'.format('/'.join(url))
        url = urljoin(self.server_url, url)
        # sys.version[:3] truncates "3.10" to "3.1"; build major.minor
        # from version_info instead.
        python_version = '{0}.{1}'.format(*sys.version_info[:2])
        h = {
            'User-Agent': 'geofront-cli/{0} (Python-urllib/{1})'.format(
                VERSION, python_version
            ),
            'Accept': 'application/json'
        }
        if headers:
            h.update(headers)
        request = Request(url, method=method, data=data, headers=h)
        try:
            response = urlopen(request)
        except HTTPError as e:
            # The error object carries headers, code and body, so it is
            # inspected exactly like a successful response.
            response = e
        try:
            server_version = response.headers.get('X-Geofront-Version')
            if not server_version:
                raise ProtocolVersionError(
                    None,
                    'the server did not send the protocol version '
                    '(X-Geofront-Version)'
                )
            try:
                server_version_info = tuple(
                    map(int, server_version.strip().split('.'))
                )
            except ValueError:
                raise ProtocolVersionError(
                    None,
                    'the protocol version number the server sent is not '
                    'a valid format: ' + repr(server_version)
                )
            if not (MIN_PROTOCOL_VERSION <=
                    server_version_info <=
                    MAX_PROTOCOL_VERSION):
                raise ProtocolVersionError(
                    server_version_info,
                    'the server protocol version ({0}) is '
                    'incompatible'.format(server_version)
                )
            mimetype, _ = parse_mimetype(response.headers['Content-Type'])
            if mimetype == 'application/json' and 400 <= response.code < 500:
                read = response.read()
                body = json.loads(read.decode('utf-8'))
                response.close()
                error = isinstance(body, dict) and body.get('error')
                if ((response.code == 404 and error == 'token-not-found') or
                        (response.code == 410 and error == 'expired-token')):
                    raise ExpiredTokenIdError('token id seems expired')
                elif (response.code == 412 and
                      error == 'unfinished-authentication'):
                    raise UnfinishedAuthenticationError(body['message'])
                buffered = BufferedResponse(response.code, response.headers,
                                            read)
                try:
                    yield buffered
                finally:
                    buffered.close()
                return
            yield response
        finally:
            # Idempotent; also covers the early-raise paths above.
            response.close()

Example 199

Project: teuthology
Source File: install.py
View license
@contextlib.contextmanager
def task(ctx, config):
    """
    Install packages for a given project.

    tasks:
    - install:
        project: ceph
        branch: bar
    - install:
        project: samba
        branch: foo
        extra_packages: ['samba']
    - install:
        rhbuild: 1.3.0
        playbook: downstream_setup.yml
        vars:
           yum_repos:
             - url: "http://location.repo"
               name: "ceph_repo"

    Overrides are project specific:

    overrides:
      install:
        ceph:
          sha1: ...


    Debug packages may optionally be installed:

    overrides:
      install:
        ceph:
          debuginfo: true


    Default package lists (which come from packages.yaml) may be overridden:

    overrides:
      install:
        ceph:
          packages:
            deb:
            - ceph-osd
            - ceph-mon
            rpm:
            - ceph-devel
            - rbd-fuse

    When tag, branch and sha1 do not reference the same commit hash, the
    tag takes precedence over the branch and the branch takes precedence
    over the sha1.

    When the overrides have a sha1 that is different from the sha1 of
    the project to be installed, it will be a noop if the project has
    a branch or tag, because they take precedence over the sha1. For
    instance:

    overrides:
      install:
        ceph:
          sha1: 1234

    tasks:
    - install:
        project: ceph
        sha1: 4567
        branch: foobar # which has sha1 4567

    The override will transform the tasks as follows:

    tasks:
    - install:
        project: ceph
        sha1: 1234
        branch: foobar # which has sha1 4567

    But the branch takes precedence over the sha1 and foobar
    will be installed. The override of the sha1 has no effect.

    When passed 'rhbuild' as a key, it will attempt to install an rh ceph build using ceph-deploy

    Reminder regarding teuthology-suite side effects:

    The teuthology-suite command always adds the following:

    overrides:
      install:
        ceph:
          sha1: 1234

    where sha1 matches the --ceph argument. For instance if
    teuthology-suite is called with --ceph master, the sha1 will be
    the tip of master. If called with --ceph v0.94.1, the sha1 will be
    the v0.94.1 (as returned by git rev-parse v0.94.1 which is not to
    be confused with git rev-parse v0.94.1^{commit})

    Implementation notes: the project's entry under ``overrides: install:``
    is deep-merged into ``config`` in place. With ``rhbuild`` set,
    ``rh_install`` and ``ship_utilities`` run as nested context managers,
    preceded by ``ansible.CephLab`` when a ``playbook`` is given (using a
    copy of the config without the ``rhbuild`` key). Otherwise ``install``
    and ``ship_utilities`` run nested. Control is yielded inside them.

    :param ctx: the argparse.Namespace object
    :param config: the config dict
    """
    if config is None:
        config = {}
    assert isinstance(config, dict), \
        "task install only supports a dictionary for configuration"

    project = config.get('project', 'ceph')
    log.debug('project %s', project)
    overrides = ctx.config.get('overrides')
    if overrides:
        install_overrides = overrides.get('install', {})
        teuthology.deep_merge(config, install_overrides.get(project, {}))
    log.debug('config %s', config)

    rhbuild = config.get('rhbuild')
    if rhbuild:
        log.info("Build is %s ", rhbuild)

    flavor = get_flavor(config)
    log.info("Using flavor: %s", flavor)
    ctx.summary['flavor'] = flavor

    if rhbuild:
        nested_tasks = [lambda: rh_install(ctx=ctx, config=config),
                        lambda: ship_utilities(ctx=ctx, config=None)]
        if config.get('playbook'):
            ansible_config = dict(config)
            # remove key not required by ansible task
            del ansible_config['rhbuild']
            nested_tasks.insert(
                0, lambda: ansible.CephLab(ctx, config=ansible_config))
        with contextutil.nested(*nested_tasks):
            yield
    else:
        with contextutil.nested(
            lambda: install(ctx=ctx, config=dict(
                branch=config.get('branch'),
                tag=config.get('tag'),
                sha1=config.get('sha1'),
                debuginfo=config.get('debuginfo'),
                flavor=flavor,
                extra_packages=config.get('extra_packages', []),
                exclude_packages=config.get('exclude_packages', []),
                extras=config.get('extras', None),
                wait_for_package=config.get('wait_for_package', False),
                project=project,
                packages=config.get('packages', dict()),
            )),
            lambda: ship_utilities(ctx=ctx, config=None),
        ):
            yield

Example 200

Project: pgi
Source File: test_pgi_codegen_backends.py
View license
@contextlib.contextmanager
def executor(backend, type_):
    """Compiles and executes the generated code on demand and takes and
    returns real values::

        with executor(backend, GITypeTag.INT32) as var:
            self.assertEqual(var.pack_in(4), 4)
            self.assertEqual(var.unpack_out(var.pack_out(42)), 42)

    Attribute access on the yielded wrapper returns a callable. Calling it
    with real values does the following:

    * takes one fresh name from ``backend.var()`` per argument;
    * calls the same-named method of the backend type object with those
      names;
    * wraps the code accumulated in the type's ``block`` in a
      ``func(<args>)`` that returns the produced name(s);
    * compiles it and calls it with the real values.

    The type's ``block`` is cleared after each call.
    """

    class Compiler(object):
        """Callable that generates, compiles and runs one type method."""

        def __init__(self, backend, var, key):
            self.backend = backend
            self.var = var
            self.key = key

        def __call__(self, *args):
            arg_names = [self.backend.var() for _ in args]
            out_names = getattr(self.var, self.key)(*arg_names)
            if not isinstance(out_names, (list, tuple)):
                out_names = [out_names]

            # Use this Compiler's own backend (not the enclosing function's
            # argument) so the object is self-contained.
            block, _ = self.backend.parse("""
def func($args):
    $body
    return $out
""", args=", ".join(arg_names), body=self.var.block, out=", ".join(out_names))

            # Reset the generated code so the next call starts fresh.
            self.var.block.clear()
            return block.compile()["func"](*args)

    class VarWrapper(object):
        """Proxy turning attribute access into a Compiler for that name."""

        def __init__(self, backend, var):
            self.backend = backend
            self.var = var

        def __getattr__(self, key):
            return Compiler(self.backend, self.var, key)

    yield VarWrapper(backend, backend.get_type(FakeTypeInfo(type_)))