six.iteritems

Here are examples of the Python API six.iteritems taken from open source projects. By voting up, you can indicate which examples are most useful and appropriate.
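
For quick reference, six.iteritems(d) returns an iterator over the (key, value) pairs of the mapping d: on Python 2 it calls d.iteritems(), on Python 3 it calls d.items(). The related helpers six.iterkeys and six.itervalues, which also appear in several of the examples below, behave analogously. A minimal sketch of the usual pattern (the dictionary and variable names here are hypothetical, for illustration only):

import six

settings = {"host": "localhost", "port": 8080}

# Iterate lazily over key/value pairs on both Python 2 and Python 3.
for key, value in six.iteritems(settings):
    print(key, value)

# Companion helpers seen in the examples below.
keys = list(six.iterkeys(settings))      # iterates over keys only
values = list(six.itervalues(settings))  # iterates over values only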

200 Examples

Example 1

Project: cleaREST
Source File: core.py
def application(environ, start_response):
    def parse_content_type(value):
        parts = value.split(";")
        return (parts[0], {k.lstrip(): v for k, v in
                           (kv.split("=") for kv in parts[1:])}) if len(parts) > 1 else (value, {})

    def parse_www_form(input_file, n, extras):
        return parse_qs(input_file.read(n))

    def parse_form_data(input_file, n, extras):
        assert "boundary" in extras
        kwargs = {}
        state = 0  # TODO: enum
        name = None
        for line in input_file.read(n).splitlines():
            if state == 0 and line == extras["boundary"]:
                state = 1
            elif state == 1 and line.startswith(CONTENT_DISPOSITION):
                name = re.match(CONTENT_DISPOSITION + ": form\-data; name=\"(.*)\"", line).group(1)
                state = 2
            elif state == 2 and not len(line):
                state = 3
            elif state == 3:
                kwargs[name] = [line]
                state = 0
        return kwargs

    def parse_json(input_file, n, extras):
        encoding = "utf-8" if "encoding" not in extras else extras["encoding"]
        return {k: [v] for k, v in six.iteritems(json.loads(input_file.read(n), encoding=encoding))}

    def parse_accept():
        if HTTP_ACCEPT in environ:
            result = []
            parts = (x.split(";") for x in environ[HTTP_ACCEPT].split(","))
            for part in parts:
                if len(part) == 1:
                    result.append((1.0, part[0]))
                else:
                    mime, q = part
                    result.append((float(q[2:]), mime))
            return tuple(value for (weight, value) in sorted(result, key=lambda x: x[0], reverse=True))
        else:
            return {}

    content_types = {MIME_WWW_FORM_URLENCODED: parse_www_form,
                     MIME_FORM_DATA: parse_form_data,
                     MIME_JSON: parse_json}
    specials = {ACCEPT_MIMES: parse_accept}
    try:
        if environ[REQUEST_METHOD] == HTTP_GET and environ[PATH_INFO] in _static_files:
            content_type, content = _static_files[environ[PATH_INFO]]
            start_response(STATUS_FMT.format(*HTTP_OK), [(CONTENT_TYPE, content_type)])
            return [content]
        elif environ[REQUEST_METHOD] in all_registered():
            path = tuple(environ[PATH_INFO][1:].split("/"))
            query = parse_qs(environ[QUERY_STRING]) if QUERY_STRING in environ else {}
            if (
                                environ[REQUEST_METHOD] == HTTP_GET and
                                MIME_XHTML_XML in parse_accept() and
                            path == ("",)
            ):
                start_response(STATUS_FMT.format(*HTTP_OK), [(CONTENT_TYPE, MIME_TEXT_HTML)])
                for method in six.iterkeys(BaseDecorator.registered):
                    all_ = [(desc, method, signature_to_path(signature), fn.__doc__)
                            for signature, (fn, args, status, desc) in
                            six.iteritems(BaseDecorator.registered[method])]
                    return [generate_index(all_).encode("utf-8")]
            elif (
                                WSGI_CONTENT_TYPE in environ and
                            environ[WSGI_CONTENT_TYPE] and
                            environ[WSGI_CONTENT_TYPE] != MIME_TEXT_PLAIN
            ):
                content_type, extras_ = parse_content_type(environ[WSGI_CONTENT_TYPE])
                if content_type not in content_types:
                    raise HttpUnsupportedMediaType()
                query.update(content_types[content_type](environ[WSGI_INPUT],
                                                         int(environ[WSGI_CONTENT_LENGTH]),
                                                         extras_))
            for signature, (fn, args, status, desc) in six.iteritems(BaseDecorator.registered[environ[REQUEST_METHOD]]):
                if is_matching(signature, args, path, query):
                    try:
                        updated_query = query.copy()
                        updated_query.update({key.name: [value]
                                              for key, value in zip(signature, path)
                                              if isinstance(key, Key)})
                        parsed_args = parse_args(args, path, updated_query, specials)
                    except HttpError:
                        raise
                    except Exception as e:
                        logging.exception(e)
                        raise HttpBadRequest()
                    result = fn(**parsed_args)
                    for type_ in six.iterkeys(_content_types):
                        if isinstance(result, type_):
                            content_type, content_handler = _content_types[type_]
                            break
                    else:
                        raise HttpNotImplemented()
                    start_response(STATUS_FMT.format(*status),
                                   [(CONTENT_TYPE, content_type)])
                    return [content_handler(result)]
        raise HttpNotFound()
    except HttpError as error:
        status = STATUS_FMT.format(error.code, error.msg)
        start_response(status, [(CONTENT_TYPE, MIME_TEXT_PLAIN)])
        return [status]

Example 2

Project: cryptography
Source File: utils.py
def load_pkcs1_vectors(vector_data):
    """
    Loads data out of RSA PKCS #1 vector files.
    """
    private_key_vector = None
    public_key_vector = None
    attr = None
    key = None
    example_vector = None
    examples = []
    vectors = []
    for line in vector_data:
        if (
            line.startswith("# PSS Example") or
            line.startswith("# OAEP Example") or
            line.startswith("# PKCS#1 v1.5")
        ):
            if example_vector:
                for key, value in six.iteritems(example_vector):
                    hex_str = "".join(value).replace(" ", "").encode("ascii")
                    example_vector[key] = hex_str
                examples.append(example_vector)

            attr = None
            example_vector = collections.defaultdict(list)

        if line.startswith("# Message"):
            attr = "message"
            continue
        elif line.startswith("# Salt"):
            attr = "salt"
            continue
        elif line.startswith("# Seed"):
            attr = "seed"
            continue
        elif line.startswith("# Signature"):
            attr = "signature"
            continue
        elif line.startswith("# Encryption"):
            attr = "encryption"
            continue
        elif (
            example_vector and
            line.startswith("# =============================================")
        ):
            for key, value in six.iteritems(example_vector):
                hex_str = "".join(value).replace(" ", "").encode("ascii")
                example_vector[key] = hex_str
            examples.append(example_vector)
            example_vector = None
            attr = None
        elif example_vector and line.startswith("#"):
            continue
        else:
            if attr is not None and example_vector is not None:
                example_vector[attr].append(line.strip())
                continue

        if (
            line.startswith("# Example") or
            line.startswith("# =============================================")
        ):
            if key:
                assert private_key_vector
                assert public_key_vector

                for key, value in six.iteritems(public_key_vector):
                    hex_str = "".join(value).replace(" ", "")
                    public_key_vector[key] = int(hex_str, 16)

                for key, value in six.iteritems(private_key_vector):
                    hex_str = "".join(value).replace(" ", "")
                    private_key_vector[key] = int(hex_str, 16)

                private_key_vector["examples"] = examples
                examples = []

                assert (
                    private_key_vector['public_exponent'] ==
                    public_key_vector['public_exponent']
                )

                assert (
                    private_key_vector['modulus'] ==
                    public_key_vector['modulus']
                )

                vectors.append(
                    (private_key_vector, public_key_vector)
                )

            public_key_vector = collections.defaultdict(list)
            private_key_vector = collections.defaultdict(list)
            key = None
            attr = None

        if private_key_vector is None or public_key_vector is None:
            continue

        if line.startswith("# Private key"):
            key = private_key_vector
        elif line.startswith("# Public key"):
            key = public_key_vector
        elif line.startswith("# Modulus:"):
            attr = "modulus"
        elif line.startswith("# Public exponent:"):
            attr = "public_exponent"
        elif line.startswith("# Exponent:"):
            if key is public_key_vector:
                attr = "public_exponent"
            else:
                assert key is private_key_vector
                attr = "private_exponent"
        elif line.startswith("# Prime 1:"):
            attr = "p"
        elif line.startswith("# Prime 2:"):
            attr = "q"
        elif line.startswith("# Prime exponent 1:"):
            attr = "dmp1"
        elif line.startswith("# Prime exponent 2:"):
            attr = "dmq1"
        elif line.startswith("# Coefficient:"):
            attr = "iqmp"
        elif line.startswith("#"):
            attr = None
        else:
            if key is not None and attr is not None:
                key[attr].append(line.strip())
    return vectors

Example 3

Project: iris
Source File: name_loaders.py
def _generate_cubes(header, column_headings, coords, data_arrays,
                    cell_methods=None):
    """
    Yield :class:`iris.cube.Cube` instances given
    the headers, column headings, coords and data_arrays extracted
    from a NAME file.

    """
    for i, data_array in enumerate(data_arrays):
        # Turn the dictionary of column headings with a list of header
        # information for each field into a dictionary of headings for
        # just this field.
        field_headings = {k: v[i] for k, v in six.iteritems(column_headings)}

        # Make a cube.
        cube = iris.cube.Cube(data_array)

        # Determine the name and units.
        name = '{} {}'.format(field_headings['Species'],
                              field_headings['Quantity'])
        name = name.upper().replace(' ', '_')
        cube.rename(name)

        # Some units are not in SI units, are missing spaces or typed
        # in the wrong case. _parse_units returns units that are
        # recognised by Iris.
        cube.units = _parse_units(field_headings['Units'])

        # Define and add the singular coordinates of the field (flight
        # level, time etc.)
        if 'Z' in field_headings:
            upper_bound, = [field_headings['... to [Z]']
                            if '... to [Z]' in field_headings else None]
            lower_bound, = [field_headings['... from [Z]']
                            if '... from [Z]' in field_headings else None]
            z_coord = _cf_height_from_name(field_headings['Z'],
                                           upper_bound=upper_bound,
                                           lower_bound=lower_bound)
            cube.add_aux_coord(z_coord)

        # Define the time unit and use it to serialise the datetime for
        # the time coordinate.
        time_unit = cf_units.Unit(
            'hours since epoch', calendar=cf_units.CALENDAR_GREGORIAN)

        # Build time, height, latitude and longitude coordinates.
        for coord in coords:
            pts = coord.values
            coord_sys = None
            if coord.name == 'latitude' or coord.name == 'longitude':
                coord_units = 'degrees'
                coord_sys = iris.coord_systems.GeogCS(EARTH_RADIUS)
            if coord.name == 'projection_x_coordinate' \
                    or coord.name == 'projection_y_coordinate':
                coord_units = 'm'
                coord_sys = iris.coord_systems.OSGB()
            if coord.name == 'height':
                coord_units = 'm'
                long_name = 'height above ground level'
                pts = coord.values
            if coord.name == 'altitude':
                coord_units = 'm'
                long_name = 'altitude above sea level'
                pts = coord.values
            if coord.name == 'air_pressure':
                coord_units = 'Pa'
                pts = coord.values
            if coord.name == 'flight_level':
                pts = coord.values
                long_name = 'flight_level'
                coord_units = _parse_units('FL')
            if coord.name == 'time':
                coord_units = time_unit
                pts = time_unit.date2num(coord.values)

            if coord.dimension is not None:
                if coord.name == 'longitude':
                    circular = iris.util._is_circular(pts, 360.0)
                else:
                    circular = False
                if coord.name == 'flight_level':
                    icoord = DimCoord(points=pts,
                                      units=coord_units,
                                      long_name=long_name,)
                else:
                    icoord = DimCoord(points=pts,
                                      standard_name=coord.name,
                                      units=coord_units,
                                      coord_system=coord_sys,
                                      circular=circular)
                if coord.name == 'height' or coord.name == 'altitude':
                    icoord.long_name = long_name
                if coord.name == 'time' and 'Av or Int period' in \
                        field_headings:
                    dt = coord.values - \
                        field_headings['Av or Int period']
                    bnds = time_unit.date2num(
                        np.vstack((dt, coord.values)).T)
                    icoord.bounds = bnds
                else:
                    icoord.guess_bounds()
                cube.add_dim_coord(icoord, coord.dimension)
            else:
                icoord = AuxCoord(points=pts[i],
                                  standard_name=coord.name,
                                  coord_system=coord_sys,
                                  units=coord_units)
                if coord.name == 'time' and 'Av or Int period' in \
                        field_headings:
                    dt = coord.values - \
                        field_headings['Av or Int period']
                    bnds = time_unit.date2num(
                        np.vstack((dt, coord.values)).T)
                    icoord.bounds = bnds[i, :]
                cube.add_aux_coord(icoord)

        # Headings/column headings which are encoded elsewhere.
        headings = ['X', 'Y', 'Z', 'Time', 'T', 'Units',
                    'Av or Int period',
                    '... from [Z]', '... to [Z]',
                    'X grid origin', 'Y grid origin',
                    'X grid size', 'Y grid size',
                    'X grid resolution', 'Y grid resolution',
                    'Number of field cols', 'Number of preliminary cols',
                    'Number of fields', 'Number of series',
                    'Output format', ]

        # Add the Main Headings as attributes.
        for key, value in six.iteritems(header):
            if value is not None and value != '' and \
                    key not in headings:
                cube.attributes[key] = value

        # Add the Column Headings as attributes
        for key, value in six.iteritems(field_headings):
            if value is not None and value != '' and \
                    key not in headings:
                cube.attributes[key] = value

        if cell_methods is not None:
            cube.add_cell_method(cell_methods[i])

        yield cube

Example 4

Project: frontera
Source File: hbase.py
    def get_next_requests(self, max_n_requests, partition_id, **kwargs):
        """
        Tries to get a new batch from the priority queue. It makes up to self.GET_RETRIES attempts, trying to satisfy
        all parameters. Every new iteration evaluates a deeper batch. After a batch is requested it is removed from the queue.

        :param max_n_requests: maximum number of requests
        :param partition_id: partition id to get batch from
        :param min_requests: minimum number of requests
        :param min_hosts: minimum number of hosts
        :param max_requests_per_host: maximum number of requests per host
        :return: list of :class:`Request <frontera.core.models.Request>` objects.
        """
        min_requests = kwargs.pop('min_requests')
        min_hosts = kwargs.pop('min_hosts')
        max_requests_per_host = kwargs.pop('max_requests_per_host')
        assert(max_n_requests > min_requests)
        table = self.connection.table(self.table_name)

        meta_map = {}
        queue = {}
        limit = min_requests
        tries = 0
        count = 0
        prefix = '%d_' % partition_id
        now_ts = int(time())
        filter = "PrefixFilter ('%s') AND SingleColumnValueFilter ('f', 't', <=, 'binary:%d')" % (prefix, now_ts)
        while tries < self.GET_RETRIES:
            tries += 1
            limit *= 5.5 if tries > 1 else 1.0
            self.logger.debug("Try %d, limit %d, last attempt: requests %d, hosts %d",
                              tries, limit, count, len(queue.keys()))
            meta_map.clear()
            queue.clear()
            count = 0
            for rk, data in table.scan(limit=int(limit), batch_size=256, filter=filter):
                for cq, buf in six.iteritems(data):
                    if cq == b'f:t':
                        continue
                    stream = BytesIO(buf)
                    unpacker = Unpacker(stream)
                    for item in unpacker:
                        fprint, host_crc32, _, _ = item
                        if host_crc32 not in queue:
                            queue[host_crc32] = []
                        if max_requests_per_host is not None and len(queue[host_crc32]) > max_requests_per_host:
                            continue
                        queue[host_crc32].append(fprint)
                        count += 1

                        if fprint not in meta_map:
                            meta_map[fprint] = []
                        meta_map[fprint].append((rk, item))
                if count > max_n_requests:
                    break

            if min_hosts is not None and len(queue.keys()) < min_hosts:
                continue

            if count < min_requests:
                continue
            break

        self.logger.debug("Finished: tries %d, hosts %d, requests %d", tries, len(queue.keys()), count)

        # For every fingerprint collect its row keys and return all fingerprints from them
        fprint_map = {}
        for fprint, meta_list in six.iteritems(meta_map):
            for rk, _ in meta_list:
                fprint_map.setdefault(rk, []).append(fprint)

        results = []
        trash_can = set()

        for _, fprints in six.iteritems(queue):
            for fprint in fprints:
                for rk, _ in meta_map[fprint]:
                    if rk in trash_can:
                        continue
                    for rk_fprint in fprint_map[rk]:
                        _, item = meta_map[rk_fprint][0]
                        _, _, encoded, score = item
                        request = self.decoder.decode_request(encoded)
                        request.meta[b'score'] = score
                        results.append(request)
                    trash_can.add(rk)

        with table.batch(transaction=True) as b:
            for rk in trash_can:
                b.delete(rk)
        self.logger.debug("%d row keys removed", len(trash_can))
        return results

Example 5

Project: routes
Source File: mapper.py
    def resource(self, member_name, collection_name, **kwargs):
        """Generate routes for a controller resource

        The member_name name should be the appropriate singular version
        of the resource given your locale and used with members of the
        collection. The collection_name name will be used to refer to
        the resource collection methods and should be a plural version
        of the member_name argument. By default, the member_name name
        will also be assumed to map to a controller you create.

        The concept of a web resource maps somewhat directly to 'CRUD'
        operations. The overriding thing to keep in mind is that
        mapping a resource is about handling creating, viewing, and
        editing that resource.

        All keyword arguments are optional.

        ``controller``
            If specified in the keyword args, the controller will be
            the actual controller used, but the rest of the naming
            conventions used for the route names and URL paths are
            unchanged.

        ``collection``
            Additional action mappings used to manipulate/view the
            entire set of resources provided by the controller.

            Example::

                map.resource('message', 'messages', collection={'rss':'GET'})
                # GET /message/rss (maps to the rss action)
                # also adds named route "rss_message"

        ``member``
            Additional action mappings used to access an individual
            'member' of this controller's resources.

            Example::

                map.resource('message', 'messages', member={'mark':'POST'})
                # POST /message/1/mark (maps to the mark action)
                # also adds named route "mark_message"

        ``new``
            Action mappings that involve dealing with a new member in
            the controller resources.

            Example::

                map.resource('message', 'messages', new={'preview':'POST'})
                # POST /message/new/preview (maps to the preview action)
                # also adds a url named "preview_new_message"

        ``path_prefix``
            Prepends the URL path for the Route with the path_prefix
            given. This is most useful for cases where you want to mix
            resources or relations between resources.

        ``name_prefix``
            Prepends the route names that are generated with the
            name_prefix given. Combined with the path_prefix option,
            it's easy to generate route names and paths that represent
            resources that are in relations.

            Example::

                map.resource('message', 'messages', controller='categories',
                    path_prefix='/category/:category_id',
                    name_prefix="category_")
                # GET /category/7/message/1
                # has named route "category_message"

        ``requirements``

           A dictionary that restricts the matching of a
           variable. Can be used when matching variables with path_prefix.

           Example::

                map.resource('message', 'messages',
                     path_prefix='{project_id}/',
                     requirements={"project_id": R"\d+"})
                # POST /01234/message
                #    success, project_id is set to "01234"
                # POST /foo/message
                #    404 not found, won't be matched by this route


        ``parent_resource``
            A ``dict`` containing information about the parent
            resource, for creating a nested resource. It should contain
            the ``member_name`` and ``collection_name`` of the parent
            resource. This ``dict`` will
            be available via the associated ``Route`` object which can
            be accessed during a request via
            ``request.environ['routes.route']``

            If ``parent_resource`` is supplied and ``path_prefix``
            isn't, ``path_prefix`` will be generated from
            ``parent_resource`` as
            "<parent collection name>/:<parent member name>_id".

            If ``parent_resource`` is supplied and ``name_prefix``
            isn't, ``name_prefix`` will be generated from
            ``parent_resource`` as  "<parent member name>_".

            Example::

                >>> from routes.util import url_for
                >>> m = Mapper()
                >>> m.resource('location', 'locations',
                ...            parent_resource=dict(member_name='region',
                ...                                 collection_name='regions'))
                >>> # path_prefix is "regions/:region_id"
                >>> # name prefix is "region_"
                >>> url_for('region_locations', region_id=13)
                '/regions/13/locations'
                >>> url_for('region_new_location', region_id=13)
                '/regions/13/locations/new'
                >>> url_for('region_location', region_id=13, id=60)
                '/regions/13/locations/60'
                >>> url_for('region_edit_location', region_id=13, id=60)
                '/regions/13/locations/60/edit'

            Overriding generated ``path_prefix``::

                >>> m = Mapper()
                >>> m.resource('location', 'locations',
                ...            parent_resource=dict(member_name='region',
                ...                                 collection_name='regions'),
                ...            path_prefix='areas/:area_id')
                >>> # name prefix is "region_"
                >>> url_for('region_locations', area_id=51)
                '/areas/51/locations'

            Overriding generated ``name_prefix``::

                >>> m = Mapper()
                >>> m.resource('location', 'locations',
                ...            parent_resource=dict(member_name='region',
                ...                                 collection_name='regions'),
                ...            name_prefix='')
                >>> # path_prefix is "regions/:region_id"
                >>> url_for('locations', region_id=51)
                '/regions/51/locations'

        """
        collection = kwargs.pop('collection', {})
        member = kwargs.pop('member', {})
        new = kwargs.pop('new', {})
        path_prefix = kwargs.pop('path_prefix', None)
        name_prefix = kwargs.pop('name_prefix', None)
        parent_resource = kwargs.pop('parent_resource', None)

        # Generate ``path_prefix`` if ``path_prefix`` wasn't specified and
        # ``parent_resource`` was. Likewise for ``name_prefix``. Make sure
        # that ``path_prefix`` and ``name_prefix`` *always* take precedence if
        # they are specified--in particular, we need to be careful when they
        # are explicitly set to "".
        if parent_resource is not None:
            if path_prefix is None:
                path_prefix = '%s/:%s_id' % (parent_resource['collection_name'],
                                             parent_resource['member_name'])
            if name_prefix is None:
                name_prefix = '%s_' % parent_resource['member_name']
        else:
            if path_prefix is None:
                path_prefix = ''
            if name_prefix is None:
                name_prefix = ''

        # Ensure the edit and new actions are in and GET
        member['edit'] = 'GET'
        new.update({'new': 'GET'})

        # Make new dict's based off the old, except the old values become keys,
        # and the old keys become items in a list as the value
        def swap(dct, newdct):
            """Swap the keys and values in the dict, and uppercase the values
            from the dict during the swap."""
            for key, val in six.iteritems(dct):
                newdct.setdefault(val.upper(), []).append(key)
            return newdct
        collection_methods = swap(collection, {})
        member_methods = swap(member, {})
        new_methods = swap(new, {})

        # Insert create, update, and destroy methods
        collection_methods.setdefault('POST', []).insert(0, 'create')
        member_methods.setdefault('PUT', []).insert(0, 'update')
        member_methods.setdefault('DELETE', []).insert(0, 'delete')

        # If there's a path prefix option, use it with the controller
        controller = strip_slashes(collection_name)
        path_prefix = strip_slashes(path_prefix)
        path_prefix = '/' + path_prefix
        if path_prefix and path_prefix != '/':
            path = path_prefix + '/' + controller
        else:
            path = '/' + controller
        collection_path = path
        new_path = path + "/new"
        member_path = path + "/:(id)"

        options = {
            'controller': kwargs.get('controller', controller),
            '_member_name': member_name,
            '_collection_name': collection_name,
            '_parent_resource': parent_resource,
            '_filter': kwargs.get('_filter')
        }
        if 'requirements' in kwargs:
            options['requirements'] = kwargs['requirements']

        def requirements_for(meth):
            """Returns a new dict to be used for all route creation as the
            route options"""
            opts = options.copy()
            if method != 'any':
                opts['conditions'] = {'method': [meth.upper()]}
            return opts

        # Add the routes for handling collection methods
        for method, lst in six.iteritems(collection_methods):
            primary = (method != 'GET' and lst.pop(0)) or None
            route_options = requirements_for(method)
            for action in lst:
                route_options['action'] = action
                route_name = "%s%s_%s" % (name_prefix, action, collection_name)
                self.connect("formatted_" + route_name, "%s/%s.:(format)" %
                             (collection_path, action), **route_options)
                self.connect(route_name, "%s/%s" % (collection_path, action),
                             **route_options)
            if primary:
                route_options['action'] = primary
                self.connect("%s.:(format)" % collection_path, **route_options)
                self.connect(collection_path, **route_options)

        # Specifically add in the built-in 'index' collection method and its
        # formatted version
        self.connect("formatted_" + name_prefix + collection_name,
                     collection_path + ".:(format)", action='index',
                     conditions={'method': ['GET']}, **options)
        self.connect(name_prefix + collection_name, collection_path,
                     action='index', conditions={'method': ['GET']}, **options)

        # Add the routes that deal with new resource methods
        for method, lst in six.iteritems(new_methods):
            route_options = requirements_for(method)
            for action in lst:
                name = "new_" + member_name
                route_options['action'] = action
                if action == 'new':
                    path = new_path
                    formatted_path = new_path + '.:(format)'
                else:
                    path = "%s/%s" % (new_path, action)
                    name = action + "_" + name
                    formatted_path = "%s/%s.:(format)" % (new_path, action)
                self.connect("formatted_" + name_prefix + name, formatted_path,
                             **route_options)
                self.connect(name_prefix + name, path, **route_options)

        requirements_regexp = '[^\/]+(?<!\\\)'

        # Add the routes that deal with member methods of a resource
        for method, lst in six.iteritems(member_methods):
            route_options = requirements_for(method)
            route_options['requirements'] = {'id': requirements_regexp}
            if method not in ['POST', 'GET', 'any']:
                primary = lst.pop(0)
            else:
                primary = None
            for action in lst:
                route_options['action'] = action
                self.connect("formatted_%s%s_%s" % (name_prefix, action,
                                                    member_name),
                             "%s/%s.:(format)" % (member_path, action),
                             **route_options)
                self.connect("%s%s_%s" % (name_prefix, action, member_name),
                             "%s/%s" % (member_path, action), **route_options)
            if primary:
                route_options['action'] = primary
                self.connect("%s.:(format)" % member_path, **route_options)
                self.connect(member_path, **route_options)

        # Specifically add the member 'show' method
        route_options = requirements_for('GET')
        route_options['action'] = 'show'
        route_options['requirements'] = {'id': requirements_regexp}
        self.connect("formatted_" + name_prefix + member_name,
                     member_path + ".:(format)", **route_options)
        self.connect(name_prefix + member_name, member_path, **route_options)

Example 6

Project: asv
Source File: publish.py
    @classmethod
    def run(cls, conf, env_spec=None):
        params = {}
        graphs = GraphSet()
        machines = {}
        benchmark_names = set()

        log.set_nitems(6 + len(list(util.iter_subclasses(OutputPublisher))))

        if os.path.exists(conf.html_dir):
            util.long_path_rmtree(conf.html_dir)

        environments = list(environment.get_environments(conf, env_spec))
        repo = get_repo(conf)
        benchmarks = Benchmarks.load(conf, repo, environments)

        template_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), '..', 'www')
        shutil.copytree(template_dir, conf.html_dir)

        log.step()
        log.info("Loading machine info")
        with log.indent():
            for path in iter_machine_files(conf.results_dir):
                d = util.load_json(path)
                machines[d['machine']] = d

        log.step()
        log.info("Getting params, commits, tags and branches")
        with log.indent():
            # Determine first the set of all parameters and all commits
            hash_to_date = {}
            for results in iter_results(conf.results_dir):
                hash_to_date[results.commit_hash] = results.date
                for key, val in six.iteritems(results.params):
                    if val is None:
                        # Backward compatibility -- null means ''
                        val = ''

                    params.setdefault(key, set())
                    params[key].add(val)

            repo.pull()
            tags = repo.get_tags()
            revisions = repo.get_revisions(set(hash_to_date.keys()) | set(tags.values()))

            for tag, commit_hash in list(tags.items()):
                # Map to revision number instead of commit hash and add tags to hash_to_date
                tags[tag] = revisions[tags[tag]]
                hash_to_date[commit_hash] = repo.get_date_from_name(commit_hash)

            revision_to_date = dict((r, hash_to_date[h]) for h, r in six.iteritems(revisions))

            branches = dict(
                (branch, repo.get_branch_commits(branch))
                for branch in conf.branches)

        log.step()
        log.info("Loading results")
        with log.indent():
            # Generate all graphs
            for results in iter_results(conf.results_dir):
                log.dot()

                for key, val in six.iteritems(results.results):
                    b = benchmarks.get(key)
                    result = compatible_results(val, b)

                    benchmark_names.add(key)

                    for branch in [
                        branch for branch, commits in branches.items()
                        if results.commit_hash in commits
                    ]:
                        cur_params = dict(results.params)
                        cur_params['branch'] = repo.get_branch_name(branch)

                        # Backward compatibility, see above
                        for param_key, param_value in list(cur_params.items()):
                            if param_value is None:
                                cur_params[param_key] = ''

                        # Fill in missing params
                        for param_key in params.keys():
                            if param_key not in cur_params:
                                cur_params[param_key] = None
                                params[param_key].add(None)

                        # Create graph
                        graph = graphs.get_graph(key, cur_params)
                        graph.add_data_point(revisions[results.commit_hash], result)

            # Get the parameter sets for all graphs
            graph_param_list = []
            for path, graph in graphs:
                if 'summary' not in graph.params:
                    if graph.params not in graph_param_list:
                        graph_param_list.append(graph.params)

        log.step()
        log.info("Detecting steps")
        with log.indent():
            n_processes = multiprocessing.cpu_count()
            pool = multiprocessing.Pool(n_processes)
            try:
                graphs.detect_steps(pool, dots=log.dot)
            finally:
                pool.terminate()

        log.step()
        log.info("Generating graphs")
        with log.indent():
            # Save files
            graphs.save(conf.html_dir, dots=log.dot)

        pages = []
        classes = sorted(util.iter_subclasses(OutputPublisher),
                         key=lambda cls: cls.order)
        for cls in classes:
            log.step()
            log.info("Generating output for {0}".format(cls.__name__))
            with log.indent():
                cls.publish(conf, repo, benchmarks, graphs, revisions)
                pages.append([cls.name, cls.button_label, cls.description])

        log.step()
        log.info("Writing index")
        benchmark_map = dict(benchmarks)
        for key in six.iterkeys(benchmark_map):
            check_benchmark_params(key, benchmark_map[key])
        for key, val in six.iteritems(params):
            val = list(val)
            val.sort(key=lambda x: '[none]' if x is None else str(x))
            params[key] = val
        params['branch'] = [repo.get_branch_name(branch) for branch in conf.branches]
        revision_to_hash = dict((r, h) for h, r in six.iteritems(revisions))
        util.write_json(os.path.join(conf.html_dir, "index.json"), {
            'project': conf.project,
            'project_url': conf.project_url,
            'show_commit_url': conf.show_commit_url,
            'hash_length': conf.hash_length,
            'revision_to_hash': revision_to_hash,
            'revision_to_date': revision_to_date,
            'params': params,
            'graph_param_list': graph_param_list,
            'benchmarks': benchmark_map,
            'machines': machines,
            'tags': tags,
            'pages': pages,
        })

Example 7

Project: flywheel
Source File: engine.py
    def sync(self, items, raise_on_conflict=None, consistent=False,
             constraints=None, no_read=False):
        """
        Sync model changes back to database

        This will push any updates to the database, and ensure that all of the
        synced items have the most up-to-date data.

        Parameters
        ----------
        items : list or :class:`~flywheel.models.Model`
            Models to sync
        raise_on_conflict : bool, optional
            If True, raise exception if any of the fields that are being
            updated were concurrently changed in the database (default set by
            :attr:`.default_conflict`)
        consistent : bool, optional
            If True, force a consistent read from the db. This will only take
            effect if the sync is only performing a read. (default False)
        constraints : list, optional
            List of more complex constraints that must pass for the update to
            complete. Must be used with raise_on_conflict=True. Format is the
            same as query filters (e.g. Model.fieldname > 5)
        no_read : bool, optional
            If True, don't perform a GET on models with no changes. (default False)

        Raises
        ------
        exc : :class:`dynamo3.CheckFailed`
            If raise_on_conflict=True and the data in dynamo fails the
            constraint checks.

        """
        if raise_on_conflict is None:
            raise_on_conflict = self.default_conflict in ('update', 'raise')
        if constraints is not None and not raise_on_conflict:
            raise ValueError("Cannot pass constraints to sync() when raise_on_conflict is False")
        if isinstance(items, Model):
            items = [items]
        refresh_models = []
        for item in items:
            # Look for any mutable fields (e.g. sets) that have changed
            for name, field in six.iteritems(item.meta_.fields):
                if name in item.__dirty__ or name in item.__incrs__:
                    continue
                if field.is_mutable:
                    cached_var = item.cached_(name)
                    if field.resolve(item) != cached_var:
                        for related in item.meta_.related_fields[name]:
                            item.__dirty__.add(related)

            if not item.__dirty__ and not item.__incrs__:
                refresh_models.append(item)
                continue
            fields = item.__dirty__
            item.pre_save_(self)

            keywords = {}
            constrained_fields = set()
            if raise_on_conflict and constraints is not None:
                for constraint in constraints:
                    constrained_fields.update(constraint.eq_fields.keys())
                    constrained_fields.update(constraint.fields.keys())
                    keywords.update(constraint.scan_kwargs())

            updates = []
            # Set dynamo keys
            for name in fields:
                field = item.meta_.fields.get(name)
                value = getattr(item, name)
                kwargs = {}
                if raise_on_conflict and name not in constrained_fields:
                    kwargs = {'eq': item.ddb_dump_cached_(name)}
                update = ItemUpdate.put(name, item.ddb_dump_field_(name),
                                        **kwargs)
                updates.append(update)

            # Atomic increment fields
            for name, value in six.iteritems(item.__incrs__):
                kwargs = {}
                # We don't need to ddb_dump because we know they're all native
                if isinstance(value, SetDelta):
                    update = ItemUpdate(value.action, name, value.values)
                else:
                    update = ItemUpdate.add(name, value)
                updates.append(update)

            # Perform sync
            ret = self.dynamo.update_item(
                item.meta_.ddb_tablename(self.namespace), item.pk_dict_,
                updates, returns=ALL_NEW, **keywords)

            # Load updated data back into object
            with item.loading_(self):
                for key, val in six.iteritems(ret):
                    item.set_ddb_val_(key, val)

            item.post_save_()

        # Handle items that didn't have any fields to update

        # If the item isn't known to exist in the db, try to save it first
        for item in refresh_models:
            if not item.persisted_:
                try:
                    self.save(item, overwrite=False)
                except CheckFailed:
                    pass
        # Refresh item data
        if not no_read:
            self.refresh(refresh_models, consistent=consistent)

Example 8

Project: kiel
Source File: producer.py
    @gen.coroutine
    def flush(self):
        """
        Transforms the ``unsent`` structure to produce requests and sends them.

        The first order of business is to order the pending messages in
        ``unsent`` based on partition leader.  If a message's partition leader
        is not a known broker, the message is queued up to be retried and the
        flag denoting that a cluster ``heal()`` call is needed is set.

        Once the legitimate messages are ordered, instances of ProduceRequest
        are created for each broker and sent.
        """
        if not self.unsent:
            return

        # leader -> topic -> partition -> message list
        ordered = collections.defaultdict(
            lambda: collections.defaultdict(
                lambda: collections.defaultdict(list)
            )
        )

        to_retry = collections.defaultdict(list)

        for topic, msgs in drain(self.unsent):
            for msg in msgs:
                partition = self.partitioner(
                    msg.key, self.cluster.topics[topic]
                )
                leader = self.cluster.get_leader(topic, partition)
                if leader not in self.cluster:
                    to_retry[topic].append(msg)
                    continue
                ordered[leader][topic][partition].append(msg)

        requests = {}
        for leader, topics in six.iteritems(ordered):
            requests[leader] = produce_api.ProduceRequest(
                required_acks=self.required_acks,
                timeout=self.ack_timeout,
                topics=[]
            )
            for topic, partitions in six.iteritems(topics):
                requests[leader].topics.append(
                    produce_api.TopicRequest(name=topic, partitions=[])
                )
                for partition_id, msgs in six.iteritems(partitions):
                    requests[leader].topics[-1].partitions.append(
                        produce_api.PartitionRequest(
                            partition_id=partition_id,
                            message_set=messages.MessageSet.compressed(
                                self.compression, msgs
                            )
                        )
                    )
                    self.sent[
                        requests[leader].correlation_id
                    ][topic][partition_id] = msgs

        for topic, msgs in six.iteritems(to_retry):
            self.queue_retries(topic, msgs)

        yield self.send(requests)

Example 9

    def collect(self, objs, source=None, nullable=False, collect_related=True,
                source_attr=None, reverse_dependency=False, keep_parents=False):
        """
        Adds 'objs' to the collection of objects to be deleted as well as all
        parent instances.  'objs' must be a homogeneous iterable collection of
        model instances (e.g. a QuerySet).  If 'collect_related' is True,
        related objects will be handled by their respective on_delete handler.

        If the call is the result of a cascade, 'source' should be the model
        that caused it and 'nullable' should be set to True, if the relation
        can be null.

        If 'reverse_dependency' is True, 'source' will be deleted before the
        current model, rather than after. (Needed for cascading to parent
        models, the one case in which the cascade follows the forwards
        direction of an FK rather than the reverse direction.)

        If 'keep_parents' is True, data of parent models will not be deleted.
        """
        if self.can_fast_delete(objs):
            self.fast_deletes.append(objs)
            return
        new_objs = self.add(objs, source, nullable,
                            reverse_dependency=reverse_dependency)
        if not new_objs:
            return

        concrete_model_objs = {}
        for obj in new_objs:
            model = obj.__class__
            concrete_model = model._meta.concrete_model
            concrete_model_objs.setdefault(concrete_model, {})
            concrete_model_objs[concrete_model].setdefault(model, [])
            concrete_model_objs[concrete_model][model].append(obj)

        for concrete_model, model_objs in six.iteritems(concrete_model_objs):
            if not keep_parents:
                parent_objs = []
                for model, new_objs in six.iteritems(model_objs):
                    # Recursively collect concrete model's parent models, but not their
                    # related objects. These will be found by meta.get_all_related_objects()
                    for ptr in six.itervalues(concrete_model._meta.parents):
                        if ptr:
                            # FIXME: This seems to be buggy and execute a query for each
                            # parent object fetch. We have the parent data in the obj,
                            # but we don't have a nice way to turn that data into parent
                            # object instance.
                            parent_objs += [getattr(obj, ptr.name) for obj in new_objs]
                if parent_objs:
                    self.collect(parent_objs, source=model,
                                 source_attr=ptr.rel.related_name,
                                 collect_related=False,
                                 reverse_dependency=True)

            if collect_related:
                for model, new_objs in six.iteritems(model_objs):
                    for related in deletion.get_candidate_relations_to_delete(model._meta):
                        field = related.field
                        if field.rel.on_delete == deletion.DO_NOTHING:
                            continue
                        batches = self.get_del_batches(new_objs, field)
                        for batch in batches:
                            sub_objs = self.related_objects(related, batch)
                            if self.can_fast_delete(sub_objs, from_field=field):
                                self.fast_deletes.append(sub_objs)
                            elif sub_objs:
                                field.rel.on_delete(self, field, sub_objs, self.using)
                    for field in model._meta.virtual_fields:
                        if hasattr(field, 'bulk_related_objects'):
                            # It's something like a generic foreign key.
                            sub_objs = field.bulk_related_objects(new_objs, self.using)
                            self.collect(sub_objs,
                                         source=model,
                                         source_attr=field.rel.related_name,
                                         nullable=True)

Example 10

Project: claviger
Source File: config.py
def load(path):
    """ Loads the configuration file.
    
        A lot of the work is done by YAML.  We validate the easy bits with
        a JSON schema. The rest by hand. """
    # TODO Cache schema and configuration file
    l.debug('loading configuration file ...')
    with open(path) as f:
        cfg = yaml.load(f)

    if not isinstance(cfg, dict):
        raise ConfigurationError('Configuration file is empty')

    l.debug('  - checking schema')
    # First small fixes which the schema can't handle
    cfg.setdefault('servers', {})
    cfg['servers'].setdefault('$default', {})
    for key in cfg['servers']:
        if cfg['servers'][key] is None:
            cfg['servers'][key] = dict()

    # Now check the schema
    jsonschema.validate(cfg, get_schema())
    # TODO format into pretty error message

    l.debug('  - processing keys')
    new_keys = {}
    cfg.setdefault('keys', {})
    for key_name, key in six.iteritems(cfg['keys']):
        # TODO handle error
        entry = claviger.authorized_keys.Entry.parse(key)
        new_key = {'key': entry.key,
                   'options': entry.options,
                   'comment': entry.comment,
                   'keytype': entry.keytype}
        new_keys[key_name] = new_key
    cfg['keys'] = new_keys

    l.debug('  - processing server stanza short-hands')
    new_servers = {}
    for server_key, server in six.iteritems(cfg['servers']):
        parsed_server_key = parse_server_key(server_key)
        server.setdefault('name', server_key)
        server_name = server['name']
        server.setdefault('port', parsed_server_key.port)
        server.setdefault('user', parsed_server_key.user)
        server.setdefault('hostname', parsed_server_key.hostname)
        server.setdefault('ssh_user', server['user'])
        server.setdefault('present', [])
        server.setdefault('absent', [])
        server.setdefault('allow', [])
        server.setdefault('keepOtherKeys')
        server.setdefault('like', '$default' if server_key != '$default'
                                        else None)
        server.setdefault('abstract', parsed_server_key.abstract)
        prabsent = frozenset(server['present']) & frozenset(server['absent'])
        if prabsent:
            raise ConfigurationError(
                "Keys {0} are required to be both present and absent on {1}"
                    .format(tuple(prabsent), server_name))
        ablow = frozenset(server['allow']) & frozenset(server['absent'])
        if ablow:
            raise ConfigurationError(
                "Keys {0} are listed allowed and absent on {1}"
                    .format(tuple(ablow), server_name))
        for key_name in itertools.chain(server['present'], server['absent'],
                                        server['allow']):
            if key_name not in cfg['keys']:
                raise ConfigurationError(
                    "Key {0} (on {1}) does not exist".format(key_name,
                                                             server_name))
        if server_name in new_servers:
            raise ConfigurationError(
                "Duplicate server name {0}".format(server_name))
        new_servers[server_name] = server
    cfg['servers'] = new_servers

    l.debug('  - resolving server stanza inheritance')
    # create dependency graph and use Tarjan's algorithm to find a possible
    # order to evaluate the server stanzas.
    server_dg = {server_name: [server['like']] if server['like'] else []
                    for server_name, server in six.iteritems(cfg['servers'])}
    for server_cycle_names in tarjan.tarjan(server_dg):
        if len(server_cycle_names) != 1:
            raise ConfigurationError(
                    "There is a cyclic dependacy among the servers {0}".format(
                                server_cycle_names))
        target_server = cfg['servers'][server_cycle_names[0]]
        if not target_server['like']:
            continue
        if not target_server['like'] in cfg['servers']:
            pass
        source_server = cfg['servers'][target_server['like']]

        # First the simple attributes
        for attr in ('port', 'user', 'hostname', 'ssh_user',
                        'keepOtherKeys'):
            if attr in source_server:
                if target_server[attr] is None:
                    target_server[attr] = source_server[attr]

        # Now, the present/absent/allow lists
        for key in source_server['present']:
            if key in target_server['absent']:
                continue
            if key not in target_server['present']:
                target_server['present'].append(key)
        for key in source_server['absent']:
            if (key in target_server['present']
                    or key in target_server['allow']):
                continue
            if key not in target_server['absent']:
                target_server['absent'].append(key)
        for key in source_server['allow']:
            if key in target_server['absent']:
                continue
            if key not in target_server['allow']:
                target_server['allow'].append(key)

    l.debug('  - setting defaults on server stanzas')
    for server in six.itervalues(cfg['servers']):
        for attr, dflt in (('port', 22),
                           ('user', 'root'),
                           ('keepOtherKeys', True)):
            if server[attr] is None:
                server[attr] = dflt
        
    l.debug('         ... done')

    return cfg

Example 11

Project: python-foreman
Source File: client.py
View license
    def _generate_api_defs(self, use_cache=True, strict_cache=True):
        """
        This method populates the class with the api definitions.

        :param use_cache: If set, will try to get the definitions from the
            local cache first, then from the remote server, and as a last
            resort will try to get the closest one from the local cache
        :param strict_cache: If True, will not accept a similar version cached
            definitions file as valid
        """
        data = self._get_defs(use_cache, strict_cache=strict_cache)

        resource_defs = {}
        # parse all the defs first, as they may define methods cross-resource
        for res_name, res_dct in six.iteritems(data["docs"]["resources"]):
            new_resource, extra_foreign_methods = parse_resource_definition(
                res_name.lower(),
                res_dct,
            )
            # if the resource did already exist (for example, was defined
            # through a foreign method by another resource), complain if it
            # overwrites any methods
            if res_name in resource_defs:
                old_res = resource_defs[res_name]
                for prop_name, prop_val in six.iteritems(new_resource):
                    if (
                        prop_name == '_own_methods' and
                        prop_name in new_resource
                    ):
                        # merge in place; union() alone would discard its result
                        old_res[prop_name].update(prop_val)
                        continue
                    # skip internal/private/magic methods
                    if prop_name.startswith('_'):
                        continue
                    if prop_name in old_res:
                        logger.warning(
                            "There is conflict trying to redefine method "
                            "(%s) with foreign method: \n"
                            "\tapipie_resource: %s\n",
                            prop_name,
                            res_name,
                        )
                        continue
                    old_res[prop_name] = prop_val
            else:
                resource_defs[res_name] = new_resource

            # update the other resources with the foreign methods, create
            # the resources if not there yet, merge if it already exists
            for f_res_name, f_methods in six.iteritems(extra_foreign_methods):
                methods = resource_defs.setdefault(
                    f_res_name,
                    {'_own_methods': set()},
                )

                for f_mname, f_method in six.iteritems(f_methods):
                    if f_mname in methods:
                        logger.warning(
                            "There is conflict trying to redefine method "
                            "(%s) with foreign method: \n"
                            "\tapipie_resource: %s\n",
                            f_mname,
                            f_res_name,
                        )
                        continue
                    methods[f_mname] = f_method
                    methods['_own_methods'].add(f_mname)

        # Finally create the resource classes for all the collected resources,
        # instantiate them and bind them to this class
        for resource_name, resource_data in six.iteritems(resource_defs):
            new_resource = ResourceMeta.__new__(
                ResourceMeta,
                str(resource_name),
                (Resource,),
                resource_data,
            )
            if not resource_data['_own_methods']:
                logger.debug('Skipping empty resource %s' % resource_name)
                continue
            instance = new_resource(self)
            setattr(self, resource_name, instance)
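
The final loop is the part worth isolating if you only care about the six.iteritems pattern: it walks a {resource_name: {method_name: callable}} mapping and attaches one dynamically built class instance per resource to the client. A rough stand-alone sketch of that shape, using plain type() instead of the project's ResourceMeta (all names here are invented):

import six

def make_method(label):
    def method(self):
        return 'called %s' % label
    return method

resource_defs = {
    'hosts': {'index': make_method('hosts.index'), 'show': make_method('hosts.show')},
    'users': {'index': make_method('users.index')},
}

class Client(object):
    def __init__(self):
        for res_name, methods in six.iteritems(resource_defs):
            # build one class per resource and bind an instance to the client
            resource_cls = type(str(res_name), (object,), dict(methods))
            setattr(self, res_name, resource_cls())

client = Client()
print(client.hosts.index())   # -> called hosts.index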

Example 12

Project: django-extensions
Source File: shells.py
View license
def import_objects(options, style):
    from django.apps import apps
    from django import setup
    if not apps.ready:
        setup()

    def get_apps_and_models():
        for app in apps.get_app_configs():
            if app.models_module:
                yield app.models_module, app.get_models()

    mongoengine = False
    try:
        from mongoengine.base import _document_registry
        mongoengine = True
    except ImportError:
        pass

    from django.conf import settings
    imported_objects = {}

    dont_load_cli = options.get('dont_load')  # optparse will set this to [] if it doesn't exist
    dont_load_conf = getattr(settings, 'SHELL_PLUS_DONT_LOAD', [])
    dont_load = dont_load_cli + dont_load_conf
    quiet_load = options.get('quiet_load')

    model_aliases = getattr(settings, 'SHELL_PLUS_MODEL_ALIASES', {})
    app_prefixes = getattr(settings, 'SHELL_PLUS_APP_PREFIXES', {})

    # Perform pre-imports before any other imports
    SHELL_PLUS_PRE_IMPORTS = getattr(settings, 'SHELL_PLUS_PRE_IMPORTS', {})
    if SHELL_PLUS_PRE_IMPORTS:
        if not quiet_load:
            print(style.SQL_TABLE("# Shell Plus User Imports"))
        imports = import_items(SHELL_PLUS_PRE_IMPORTS, style, quiet_load=quiet_load)
        for k, v in six.iteritems(imports):
            imported_objects[k] = v

    load_models = {}

    if mongoengine:
        for name, mod in six.iteritems(_document_registry):
            name = name.split('.')[-1]
            app_name = mod.__module__.split('.')[-2]
            if app_name in dont_load or ("%s.%s" % (app_name, name)) in dont_load:
                continue

            load_models.setdefault(mod.__module__, [])
            load_models[mod.__module__].append(name)

    for app_mod, app_models in get_apps_and_models():
        if not app_models:
            continue

        app_name = app_mod.__name__.split('.')[-2]
        if app_name in dont_load:
            continue

        for mod in app_models:
            if "%s.%s" % (app_name, mod.__name__) in dont_load:
                continue

            if mod.__module__:
                # Only add the module to the dict if `__module__` is not empty.
                load_models.setdefault(mod.__module__, [])
                load_models[mod.__module__].append(mod.__name__)

    if not quiet_load:
        print(style.SQL_TABLE("# Shell Plus Model Imports"))

    for app_mod, models in sorted(six.iteritems(load_models)):
        try:
            app_name = app_mod.split('.')[-2]
        except IndexError:
            # Some weird model naming scheme like in Sentry.
            app_name = app_mod
        app_aliases = model_aliases.get(app_name, {})
        prefix = app_prefixes.get(app_name)
        model_labels = []

        for model_name in sorted(models):
            try:
                imported_object = getattr(__import__(app_mod, {}, {}, [model_name]), model_name)

                if "%s.%s" % (app_name, model_name) in dont_load:
                    continue

                alias = app_aliases.get(model_name)

                if not alias:
                    if prefix:
                        alias = "%s_%s" % (prefix, model_name)
                    else:
                        alias = model_name

                imported_objects[alias] = imported_object
                if model_name == alias:
                    model_labels.append(model_name)
                else:
                    model_labels.append("%s (as %s)" % (model_name, alias))

            except AttributeError as e:
                if options.get("traceback"):
                    traceback.print_exc()
                if not quiet_load:
                    print(style.ERROR("Failed to import '%s' from '%s' reason: %s" % (model_name, app_mod, str(e))))
                continue

        if not quiet_load:
            print(style.SQL_COLTYPE("from %s import %s" % (app_mod, ", ".join(model_labels))))

    # Imports often used from Django
    if getattr(settings, 'SHELL_PLUS_DJANGO_IMPORTS', True):
        if not quiet_load:
            print(style.SQL_TABLE("# Shell Plus Django Imports"))
        from django import VERSION as DJANGO_VERSION
        SHELL_PLUS_DJANGO_IMPORTS = {
            'django.core.cache': ['cache'],
            'django.conf': ['settings'],
            'django.db': ['transaction'],
            'django.db.models': [
                'Avg', 'Case', 'Count', 'F', 'Max', 'Min', 'Prefetch', 'Q',
                'Sum', 'When',
            ],
            'django.utils': ['timezone'],
        }
        if DJANGO_VERSION < (1, 10):
            SHELL_PLUS_DJANGO_IMPORTS.update({
                'django.core.urlresolvers': ['reverse'],
            })
        else:
            SHELL_PLUS_DJANGO_IMPORTS.update({
                'django.urls': ['reverse'],
            })
        imports = import_items(SHELL_PLUS_DJANGO_IMPORTS.items(), style, quiet_load=quiet_load)
        for k, v in six.iteritems(imports):
            imported_objects[k] = v

    # Perform post-imports after any other imports
    SHELL_PLUS_POST_IMPORTS = getattr(settings, 'SHELL_PLUS_POST_IMPORTS', {})
    if SHELL_PLUS_POST_IMPORTS:
        if not quiet_load:
            print(style.SQL_TABLE("# Shell Plus User Imports"))
        imports = import_items(SHELL_PLUS_POST_IMPORTS, style, quiet_load=quiet_load)
        for k, v in six.iteritems(imports):
            imported_objects[k] = v

    return imported_objects
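
Each of the three import blocks above finishes with the same idiom: copy the {name: object} dict returned by import_items() into imported_objects pair by pair. Assuming import_items() really does return a plain dict (which is what the loops imply), the copy is just an explicit spelling of dict.update() that reads the same under Python 2 and 3:

import six

imported_objects = {}
imports = {'settings': object(), 'timezone': object()}   # stand-in for import_items() output

for k, v in six.iteritems(imports):
    imported_objects[k] = v

# equivalent one-liner
imported_objects.update(imports)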

Example 13

Project: ete
Source File: scheduler.py
View license
def schedule(workflow_task_processor, pending_tasks, schedule_time, execution, debug, norender):
    # Adjust debug mode
    if debug == "all":
        log.setLevel(10)
    pending_tasks = set(pending_tasks)

    ## ===================================
    ## INITIALIZE BASIC VARS
    execution, run_detached = execution
    thread2tasks = defaultdict(list)
    for task in pending_tasks:
        thread2tasks[task.configid].append(task)
    expected_threads = set(thread2tasks.keys())
    past_threads = {}
    thread_errors = defaultdict(list)
    ## END OF VARS AND SHORTCUTS
    ## ===================================

    cores_total = GLOBALS["_max_cores"]
    if cores_total > 0:
        job_queue = Queue()

        back_launcher = Process(target=background_job_launcher,
                                args=(job_queue, run_detached,
                                      GLOBALS["launch_time"], cores_total))
        back_launcher.start()
    else:
        job_queue = None
        back_launcher = None

    GLOBALS["_background_scheduler"] = back_launcher
    GLOBALS["_job_queue"] = job_queue
    # Captures Ctrl-C for debugging
    #signal.signal(signal.SIGINT, control_c)

    last_report_time = None

    BUG = set()
    try:
        # Enters into task scheduling
        while pending_tasks:
            wtime = schedule_time

            # ask SGE for running jobs
            if execution == "sge":
                #sgeid2jobs = db.get_sge_tasks()
                #qstat_jobs = sge.qstat()
                pass
            else:
                qstat_jobs = None

            # Show summary of pending tasks per thread
            thread2tasks = defaultdict(list)
            for task in pending_tasks:
                thread2tasks[task.configid].append(task)
            set_logindent(0)
            log.log(28, "@@13: Updating tasks status:@@1: (%s)" % (ctime()))
            info_lines = []
            for tid, tlist in six.iteritems(thread2tasks):
                threadname = GLOBALS[tid]["_name"]
                sizelist = ["%s" %getattr(_ts, "size", "?") for _ts in tlist]
                info = "Thread @@13:%[email protected]@1:: pending tasks: @@8:%[email protected]@1: of sizes: %s" %(
                    threadname, len(tlist), ', '.join(sizelist))
                info_lines.append(info)

            for line in info_lines:
                log.log(28, line)

            if GLOBALS["email"]  and last_report_time is None:
                last_report_time = time()
                send_mail(GLOBALS["email"], "Your NPR process has started", '\n'.join(info_lines))

            ## ================================
            ## CHECK AND UPDATE CURRENT TASKS
            checked_tasks = set()
            check_start_time = time()
            to_add_tasks = set()

            GLOBALS["cached_status"] = {}
            for task in sorted(pending_tasks, key=cmp_to_key(sort_tasks)):
                # Avoids endless periods without new job submissions
                elapsed_time = time() - check_start_time
                #if not back_launcher and pending_tasks and \
                #        elapsed_time > schedule_time * 2:
                #    log.log(26, "@@8:Interrupting task checks to schedule new [email protected]@1:")
                #    db.commit()
                #    wtime = launch_jobs(sorted(pending_tasks, sort_tasks),
                #                        execution, run_detached)
                #    check_start_time = time()

                # Enter debugging mode if necessary
                if debug and log.level > 10 and task.taskid.startswith(debug):
                    log.setLevel(10)
                    log.debug("ENTERING IN DEBUGGING MODE")
                thread2tasks[task.configid].append(task)

                # Update tasks and job statuses

                if task.taskid not in checked_tasks:
                    try:
                        show_task_info(task)
                        task.status = task.get_status(qstat_jobs)
                        db.dataconn.commit()
                        if back_launcher and task.status not in set("DE"):
                            for j, cmd in task.iter_waiting_jobs():
                                j.status = "Q"
                                GLOBALS["cached_status"][j.jobid] = "Q"
                                if j.jobid not in BUG:
                                    if not os.path.exists(j.jobdir):
                                        os.makedirs(j.jobdir)
                                    for ifile, outpath in six.iteritems(j.input_files):
                                        try:
                                            _tid, _did = ifile.split(".")
                                            _did = int(_did)
                                        except (IndexError, ValueError):
                                            dataid = ifile
                                        else:
                                            dataid = db.get_dataid(_tid, _did)

                                        if not outpath:
                                            outfile = pjoin(GLOBALS["input_dir"], ifile)
                                        else:
                                            outfile = pjoin(outpath, ifile)

                                        if not os.path.exists(outfile):
                                            open(outfile, "w").write(db.get_data(dataid))

                                    log.log(24, "  @@8:Queueing @@1: %s from %s" %(j, task))
                                    if execution:
                                        with open(pjoin(GLOBALS[task.configid]["_outpath"], "commands.log"), "a") as CMD_LOGGER:
                                            print('\t'.join([task.tname, task.taskid, j.jobname, j.jobid, j.get_launch_cmd()]), file=CMD_LOGGER)
                                            
                                        job_queue.put([j.jobid, j.cores, cmd, j.status_file])
                                BUG.add(j.jobid)

                        update_task_states_recursively(task)
                        db.commit()
                        checked_tasks.add(task.taskid)
                    except TaskError as e:
                        log.error("Errors found in %s" %task)
                        import traceback
                        traceback.print_exc()
                        if GLOBALS["email"]:
                            threadname = GLOBALS[task.configid]["_name"]
                            send_mail(GLOBALS["email"], "Errors found in %s!" %threadname,
                                      '\n'.join(map(str, [task, e.value, e.msg])))
                        pending_tasks.discard(task)
                        thread_errors[task.configid].append([task, e.value, e.msg])
                        continue
                else:
                    # Set temporary Queued state to avoid launching
                    # jobs from clones
                    task.status = "Q"
                    if log.level < 24:
                        show_task_info(task)

                if task.status == "D":
                    #db.commit()
                    show_task_info(task)
                    logindent(3)


                    # Log commands of every task
                    # if 'cmd_log_file' not in GLOBALS[task.configid]:
                    #      GLOBALS[task.configid]['cmd_log_file'] = pjoin(GLOBALS[task.configid]["_outpath"], "cmd.log")
                    #      O = open(GLOBALS[task.configid]['cmd_log_file'], "w")
                    #      O.close()

                    # cmd_lines =  get_cmd_log(task)
                    # CMD_LOG = open(GLOBALS[task.configid]['cmd_log_file'], "a")
                    # print(task, file=CMD_LOG)
                    # for c in cmd_lines:
                    #     print('   '+'\t'.join(map(str, c)), file=CMD_LOG)
                    # CMD_LOG.close()
                    #

                    try:
                        #wkname = GLOBALS[task.configid]['_name']
                        create_tasks = workflow_task_processor(task, task.target_wkname)
                    except TaskError as e:
                        log.error("Errors found in %s" %task)
                        pending_tasks.discard(task)
                        thread_errors[task.configid].append([task, e.value, e.msg])
                        continue
                    else:
                        logindent(-3)

                        to_add_tasks.update(create_tasks)
                        pending_tasks.discard(task)

                elif task.status == "E":
                    log.error("task contains errors: %s " %task)
                    log.error("Errors found in %s")
                    pending_tasks.discard(task)
                    thread_errors[task.configid].append([task, None, "Found (E) task status"])

            #db.commit()
            #if not back_launcher:
            #    wtime = launch_jobs(sorted(pending_tasks, sort_tasks),
            #                    execution, run_detached)

            # Update global task list with recently added jobs to be checked
            # during next cycle
            pending_tasks.update(to_add_tasks)

            ## END CHECK AND UPDATE CURRENT TASKS
            ## ================================

            if wtime:
                set_logindent(0)
                log.log(28, "@@13:Waiting %s [email protected]@1:" %wtime)
                sleep(wtime)
            else:
                sleep(schedule_time)

            # Dump / show ended threads
            error_lines = []
            for configid, etasks in six.iteritems(thread_errors):
                error_lines.append("Thread @@10:%[email protected]@1: contains errors:" %\
                            (GLOBALS[configid]["_name"]))
                for error in etasks:
                    error_lines.append(" ** %s" %error[0])
                    e_obj = error[1] if error[1] else error[0]
                    error_path = e_obj.jobdir if isjob(e_obj) else e_obj.taskid
                    if e_obj is not error[0]:
                        error_lines.append("      -> %s" %e_obj)
                    error_lines.append("      -> %s" %error_path)
                    error_lines.append("        -> %s" %error[2])
            for eline in error_lines:
                log.error(eline)

            pending_threads = set([ts.configid for ts in pending_tasks])
            finished_threads = expected_threads - (pending_threads | set(thread_errors.keys()))
            just_finished_lines = []
            finished_lines = []
            for configid in finished_threads:
                # configid is the same as threadid in master tasks
                final_tree_file = pjoin(GLOBALS[configid]["_outpath"],
                                        GLOBALS["inputname"] + ".final_tree")
                threadname = GLOBALS[configid]["_name"]

                if configid in past_threads:
                    log.log(28, "Done thread @@12:%[email protected]@1: in %d iteration(s)",
                            threadname, past_threads[configid])
                    finished_lines.append("Finished %s in %d iteration(s)" %(
                            threadname, past_threads[configid]))
                else:

                    log.log(28, "Assembling final tree...")
                    main_tree, treeiters =  assembly_tree(configid)
                    past_threads[configid] = treeiters - 1

                    log.log(28, "Done thread @@12:%[email protected]@1: in %d iteration(s)",
                            threadname, past_threads[configid])


                    log.log(28, "Writing final tree for @@13:%[email protected]@1:\n   %s\n   %s",
                            threadname, final_tree_file+".nw",
                            final_tree_file+".nwx (newick extended)")
                    main_tree.write(outfile=final_tree_file+".nw")
                    main_tree.write(outfile=final_tree_file+ ".nwx", features=[],
                                    format_root_node=True)

                    if hasattr(main_tree, "tree_phylip_alg"):
                        log.log(28, "Writing final tree alignment @@13:%[email protected]@1:\n   %s",
                                threadname, final_tree_file+".used_alg.fa")

                        alg = SeqGroup(get_stored_data(main_tree.tree_phylip_alg), format="iphylip_relaxed")
                        OUT = open(final_tree_file+".used_alg.fa", "w")
                        for name, seq, comments in alg:
                            realname = db.get_seq_name(name)
                            print(">%s\n%s" %(realname, seq), file=OUT)
                        OUT.close()

                    
                    if hasattr(main_tree, "alg_path"):
                        log.log(28, "Writing root node alignment @@13:%[email protected]@1:\n   %s",
                                threadname, final_tree_file+".fa")

                        alg = SeqGroup(get_stored_data(main_tree.alg_path))
                        OUT = open(final_tree_file+".fa", "w")
                        for name, seq, comments in alg:
                            realname = db.get_seq_name(name)
                            print(">%s\n%s" %(realname, seq), file=OUT)
                        OUT.close()

                    if hasattr(main_tree, "clean_alg_path"):
                        log.log(28, "Writing root node trimmed alignment @@13:%[email protected]@1:\n   %s",
                                threadname, final_tree_file+".trimmed.fa")

                        alg = SeqGroup(get_stored_data(main_tree.clean_alg_path))
                        OUT = open(final_tree_file+".trimmed.fa", "w")
                        for name, seq, comments in alg:
                            realname = db.get_seq_name(name)
                            print(">%s\n%s" %(realname, seq), file=OUT)
                        OUT.close()

                    if norender == False:
                        log.log(28, "Generating tree image for @@13:%[email protected]@1:\n   %s",
                                threadname, final_tree_file+".png")
                        for lf in main_tree:
                            lf.add_feature("sequence", alg.get_seq(lf.safename))
                        try:
                            from .visualize import draw_tree
                            draw_tree(main_tree, GLOBALS[configid], final_tree_file+".png")
                        except Exception as e:
                            log.warning('@@8:something went wrong when generating the tree image. Try manually :(@@1:')
                            if DEBUG:
                                import traceback, sys
                                traceback.print_exc(file=sys.stdout)

                    just_finished_lines.append("Finished %s in %d iteration(s)" %(
                            threadname, past_threads[configid]))
            if GLOBALS["email"]:
                if not pending_tasks:
                    all_lines = finished_lines + just_finished_lines + error_lines
                    send_mail(GLOBALS["email"], "Your NPR process has ended", '\n'.join(all_lines))

                elif GLOBALS["email_report_time"] and time() - last_report_time >= \
                        GLOBALS["email_report_time"]:
                    all_lines = info_lines + error_lines + just_finished_lines
                    send_mail(GLOBALS["email"], "Your NPR report", '\n'.join(all_lines))
                    last_report_time = time()

                elif just_finished_lines:
                    send_mail(GLOBALS["email"], "Finished threads!",
                              '\n'.join(just_finished_lines))

            log.log(26, "")
    except:
        raise

    if thread_errors:
        log.error("Done with ERRORS")
    else:
        log.log(28, "Done")

    return thread_errors
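
The per-thread status summary near the top of the loop is a plain group-by: tasks are bucketed by configid into a defaultdict(list) and the buckets are then walked with six.iteritems to emit one log line per thread. A stand-alone version of that pattern (the Task tuple is a made-up stand-in for the real task objects):

import six
from collections import defaultdict, namedtuple

Task = namedtuple('Task', 'configid size')
pending_tasks = [Task('thread-a', 10), Task('thread-a', 3), Task('thread-b', 7)]

thread2tasks = defaultdict(list)
for task in pending_tasks:
    thread2tasks[task.configid].append(task)

for tid, tlist in six.iteritems(thread2tasks):
    sizes = ', '.join(str(t.size) for t in tlist)
    print("Thread %s: pending tasks: %s of sizes: %s" % (tid, len(tlist), sizes))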

Example 14

Project: gnlpy
Source File: netlink.py
View license
def create_attr_list_type(class_name, *fields):
    """Create a new attr_list_type which is a class offering get and set
    methods which is capable of serializing and deserializing itself from
    netlink message.  The fields are a bunch of tuples of name and a class
    which should provide pack and unpack (except for in the case where we
    know it will be used exclusively for serialization or deserialization).
    attr_list_types can be used as packers in other attr_list_types.  The
    names and packers of the field should be taken from the appropriate
    linux kernel header and source files.
    """
    name_to_key = {}
    key_to_name = {}
    key_to_packer = {}
    for i, (name, packer) in enumerate(fields):
        key = i + 1
        name_to_key[name.upper()] = key
        key_to_name[key] = name
        key_to_packer[key] = packer

    class AttrListType(AttrListPacker):
        def __init__(self, **kwargs):
            self.attrs = {}
            for k, v in six.iteritems(kwargs):
                if v is not None:
                    self.set(k, v)

        def set(self, key, value):
            if not isinstance(key, int):
                key = name_to_key[key.upper()]
            self.attrs[key] = value

        def get(self, key, default=_unset):
            try:
                if not isinstance(key, int):
                    key = name_to_key[key.upper()]
                return self.attrs[key]
            except KeyError:
                if default is not _unset:
                    return default
                raise

        def __repr__(self):
            attrs = ['%s=%s' % (key_to_name[k].lower(), repr(v))
                     for k, v in six.iteritems(self.attrs)]
            return '%s(%s)' % (class_name, ', '.join(attrs))

        @staticmethod
        def pack(attr_list):
            packed = array.array(str('B'))
            for k, v in six.iteritems(attr_list.attrs):
                if key_to_packer[k] == RecursiveSelf:
                    x = AttrListType.pack(v)
                else:
                    x = key_to_packer[k].pack(v)
                alen = len(x) + 4

                # TODO(agartrell): This is scary.  In theory, we should OR
                # 1 << 15 into the length if it is an instance of
                # AttrListPacker, but this didn't work for some reason, so
                # we're not going to.

                packed.fromstring(struct.pack(str('=HH'), alen, k))
                packed.fromstring(x)
                packed.fromstring('\0' * ((4 - (len(x) % 4)) & 0x3))
            return packed

        @staticmethod
        def unpack(data):
            global global_nest
            attr_list = AttrListType()
            while len(data) > 0:
                alen, k = struct.unpack(str('=HH'), data[:4])
                alen = alen & 0x7fff
                if key_to_packer[k] == RecursiveSelf:
                    v = AttrListType.unpack(data[4:alen])
                else:
                    v = key_to_packer[k].unpack(data[4:alen])
                attr_list.set(k, v)
                data = data[((alen + 3) & (~3)):]
            return attr_list

    return AttrListType
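
A small usage sketch of the factory above, sticking to the get/set/repr behaviour the docstring describes (U32Packer is invented here for the example; the real gnlpy module ships its own packer classes, and pack()/unpack() additionally need the module's AttrListPacker machinery):

import struct

class U32Packer(object):
    @staticmethod
    def pack(value):
        return struct.pack('=I', value)
    @staticmethod
    def unpack(data):
        return struct.unpack('=I', data)[0]

# field order decides the numeric netlink attribute keys (1, 2, ...)
Endpoint = create_attr_list_type('Endpoint', ('addr', U32Packer), ('port', U32Packer))

ep = Endpoint(addr=0x7f000001, port=8080)
ep.set('PORT', 9090)       # field names are case-insensitive; int keys also work
print(ep.get('port'))      # -> 9090
print(ep)                  # -> Endpoint(addr=2130706433, port=9090)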

Example 15

Project: pyinfra
Source File: cli.py
View license
def make_inventory(
    inventory_filename, deploy_dir=None, limit=None,
    ssh_user=None, ssh_key=None, ssh_key_password=None, ssh_port=None, ssh_password=None
):
    '''
    Builds a ``pyinfra.api.Inventory`` from the filesystem. If the file does not exist
    and the name doesn't contain a /, the name is used as the only hostname.
    '''

    if ssh_port is not None:
        ssh_port = int(ssh_port)

    file_groupname = None

    try:
        attrs = exec_file(inventory_filename, return_locals=True)

        groups = {
            key: value
            for key, value in six.iteritems(attrs)
            if is_inventory_group(key, value)
        }

        # Used to add all the hosts to an additional group named after the filename,
        # ie inventories/dev.py puts all the hosts in the dev group (if not already defined)
        file_groupname = path.basename(inventory_filename).split('.')[0].upper()

    except IOError as e:
        # If a /, definitely not a hostname
        if '/' in inventory_filename:
            raise CliError('{0}: {1}'.format(e.strerror, inventory_filename))

        # Otherwise we assume the inventory is actually a hostname or list of hostnames
        groups = {
            'all': inventory_filename.split(',')
        }

    all_data = {}

    if 'all' in groups:
        all_hosts = groups.pop('all')

        if isinstance(all_hosts, tuple):
            all_hosts, all_data = all_hosts

    # Build all out of the existing hosts if not defined
    else:
        all_hosts = []
        for hosts in groups.values():
            # Groups can be a list of hosts or tuple of (hosts, data)
            hosts = hosts[0] if isinstance(hosts, tuple) else hosts

            for host in hosts:
                # Hosts can be a hostname or tuple of (hostname, data)
                hostname = host[0] if isinstance(host, tuple) else host

                if hostname not in all_hosts:
                    all_hosts.append(hostname)

    groups['all'] = (all_hosts, all_data)

    # Apply the filename group if not already defined
    if file_groupname and file_groupname not in groups:
        groups[file_groupname] = all_hosts

    # In pyinfra an inventory is a combination of (hostnames + data). However, in CLI
    # mode we want to define this in separate files (inventory / group data). The
    # issue is we want inventory access within the group data files - but at this point
    # we're not ready to make an Inventory. So here we just create a fake one, and attach
    # it to pseudo_inventory while we import the data files.
    fake_groups = {
        # In API mode groups *must* be tuples of (hostnames, data)
        name: group if isinstance(group, tuple) else (group, {})
        for name, group in six.iteritems(groups)
    }
    fake_inventory = Inventory((all_hosts, all_data), **fake_groups)
    pseudo_inventory.set(fake_inventory)

    # For each group load up any data
    for name, hosts in six.iteritems(groups):
        data = {}

        if isinstance(hosts, tuple):
            hosts, data = hosts

        data_filename = path.join(
            deploy_dir, 'group_data', '{0}.py'.format(name.lower())
        )
        logger.debug('Looking for group data: {0}'.format(data_filename))

        if path.exists(data_filename):
            # Read the file's locals into a dict
            attrs = exec_file(data_filename, return_locals=True)

            data.update({
                key: value
                for key, value in six.iteritems(attrs)
                if is_group_data(key, value)
            })

        # Attach to group object
        groups[name] = (hosts, data)

    # Reset the pseudo inventory
    pseudo_inventory.reset()

    # Apply any limit to all_hosts
    if limit:
        # Limits can be groups
        limit_groupname = limit.upper()
        if limit_groupname in groups:
            all_hosts = [
                host[0] if isinstance(host, tuple) else host
                for host in groups[limit_groupname][0]
            ]

        # Or hostnames w/*wildcards
        else:
            limits = limit.split(',')

            all_hosts = [
                host for host in all_hosts
                if (
                    isinstance(host, tuple)
                    and any(fnmatch(host[0], limit) for limit in limits)
                )
                or (
                    isinstance(host, six.string_types)
                    and any(fnmatch(host, limit) for limit in limits)
                )
            ]

        # Reassign the all group w/limit
        groups['all'] = (all_hosts, all_data)

    return Inventory(
        groups.pop('all'),
        ssh_user=ssh_user,
        ssh_key=ssh_key,
        ssh_key_password=ssh_key_password,
        ssh_port=ssh_port,
        ssh_password=ssh_password,
        **groups
    ), file_groupname and file_groupname.lower()
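
The comments spell out the shapes the group dict can hold: a group is either a plain list of hosts or a (hosts, data) tuple, and a host is either a hostname or a (hostname, data) tuple. The normalisation into fake_groups then just pads the missing data dict; a tiny reproduction of that step with invented hostnames:

import six

groups = {
    'web': ['web-01', 'web-02'],                      # plain host list
    'db': ([('db-01', {'ssh_user': 'admin'})],        # (hostname, data) host
           {'backup': True}),                         # group data
}

fake_groups = {
    name: group if isinstance(group, tuple) else (group, {})
    for name, group in six.iteritems(groups)
}
print(fake_groups['web'])   # -> (['web-01', 'web-02'], {})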

Example 16

Project: formencode
Source File: variabledecode.py
View license
def variable_decode(d, dict_char='.', list_char='-'):
    """Decode the flat dictionary d into a nested structure."""
    result = {}
    dicts_to_sort = set()
    known_lengths = {}
    for key, value in six.iteritems(d):
        keys = key.split(dict_char)
        new_keys = []
        was_repetition_count = False
        for key in keys:
            if key.endswith('--repetitions'):
                key = key[:-len('--repetitions')]
                new_keys.append(key)
                known_lengths[tuple(new_keys)] = int(value)
                was_repetition_count = True
                break
            elif list_char in key:
                maybe_key, index = key.split(list_char, 1)
                if not index.isdigit():
                    new_keys.append(key)
                else:
                    key = maybe_key
                    new_keys.append(key)
                    dicts_to_sort.add(tuple(new_keys))
                    new_keys.append(int(index))
            else:
                new_keys.append(key)
        if was_repetition_count:
            continue

        place = result
        for i in range(len(new_keys) - 1):
            try:
                if not isinstance(place[new_keys[i]], dict):
                    place[new_keys[i]] = {None: place[new_keys[i]]}
                place = place[new_keys[i]]
            except KeyError:
                place[new_keys[i]] = {}
                place = place[new_keys[i]]
        if new_keys[-1] in place:
            if isinstance(place[new_keys[-1]], dict):
                place[new_keys[-1]][None] = value
            elif isinstance(place[new_keys[-1]], list):
                if isinstance(value, list):
                    place[new_keys[-1]].extend(value)
                else:
                    place[new_keys[-1]].append(value)
            else:
                if isinstance(value, list):
                    place[new_keys[-1]] = [place[new_keys[-1]]]
                    place[new_keys[-1]].extend(value)
                else:
                    place[new_keys[-1]] = [place[new_keys[-1]], value]
        else:
            place[new_keys[-1]] = value

    to_sort_list = sorted(dicts_to_sort, key=len, reverse=True)
    for key in to_sort_list:
        to_sort = result
        source = None
        last_key = None
        for sub_key in key:
            source = to_sort
            last_key = sub_key
            to_sort = to_sort[sub_key]
        if None in to_sort:
            none_values = [(0, x) for x in to_sort.pop(None)]
            none_values.extend(six.iteritems(to_sort))
            to_sort = none_values
        else:
            to_sort = six.iteritems(to_sort)
        to_sort = [x[1] for x in sorted(to_sort, key=_sort_key)]
        if key in known_lengths:
            if len(to_sort) < known_lengths[key]:
                to_sort.extend([''] * (known_lengths[key] - len(to_sort)))
        source[last_key] = to_sort

    return result
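
A quick round trip makes the docstring concrete: dict_char-separated keys become nested dicts and list_char-suffixed indices become list positions. Assuming the function is importable from formencode.variabledecode:

from formencode.variabledecode import variable_decode

flat = {
    'person.name': 'Ada',
    'person.pets-0': 'cat',
    'person.pets-1': 'dog',
}
print(variable_decode(flat))
# -> {'person': {'name': 'Ada', 'pets': ['cat', 'dog']}}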

Example 17

Project: swiftly
Source File: directclient.py
View license
    def request(self, method, path, contents, headers, decode_json=False,
                stream=False, query=None, cdn=False):
        """
        See :py:func:`swiftly.client.client.Client.request`
        """
        if query:
            path += '?' + '&'.join(
                ('%s=%s' % (quote(k), quote(v)) if v else quote(k))
                for k, v in sorted(six.iteritems(query)))
        reset_func = self._default_reset_func
        if isinstance(contents, six.string_types):
            contents = StringIO(contents)
        tell = getattr(contents, 'tell', None)
        seek = getattr(contents, 'seek', None)
        if tell and seek:
            try:
                orig_pos = tell()
                reset_func = lambda: seek(orig_pos)
            except Exception:
                tell = seek = None
        elif not contents:
            reset_func = lambda: None
        status = 0
        reason = 'Unknown'
        attempt = 0
        while attempt < self.attempts:
            attempt += 1
            if cdn:
                conn_path = self.cdn_path
            else:
                conn_path = self.storage_path
            titled_headers = dict((k.title(), v) for k, v in six.iteritems({
                'User-Agent': self.user_agent}))
            if headers:
                titled_headers.update(
                    (k.title(), v) for k, v in six.iteritems(headers))
            resp = None
            if not hasattr(contents, 'read'):
                if method not in self.no_content_methods and contents and \
                        'Content-Length' not in titled_headers and \
                        'Transfer-Encoding' not in titled_headers:
                    titled_headers['Content-Length'] = str(
                        len(contents or ''))
                req = self.Request.blank(
                    conn_path + path,
                    environ={'REQUEST_METHOD': method, 'swift_owner': True},
                    headers=titled_headers, body=contents)
                verbose_headers = '  '.join(
                    '%s: %s' % (k, v) for k, v in six.iteritems(titled_headers))
                self.verbose(
                    '> %s %s %s', method, conn_path + path, verbose_headers)
                resp = req.get_response(self.swift_proxy)
            else:
                req = self.Request.blank(
                    conn_path + path,
                    environ={'REQUEST_METHOD': method, 'swift_owner': True},
                    headers=titled_headers)
                content_length = None
                for h, v in six.iteritems(titled_headers):
                    if h.lower() == 'content-length':
                        content_length = int(v)
                    req.headers[h] = v
                if method not in self.no_content_methods and \
                        content_length is None:
                    titled_headers['Transfer-Encoding'] = 'chunked'
                    req.headers['Transfer-Encoding'] = 'chunked'
                else:
                    req.content_length = content_length
                req.body_file = contents
                verbose_headers = '  '.join(
                    '%s: %s' % (k, v) for k, v in six.iteritems(titled_headers))
                self.verbose(
                    '> %s %s %s', method, conn_path + path, verbose_headers)
                resp = req.get_response(self.swift_proxy)
            status = resp.status_int
            reason = resp.status.split(' ', 1)[1]
            hdrs = headers_to_dict(resp.headers.items())
            if stream:
                def iter_reader(size=-1):
                    if size == -1:
                        return ''.join(resp.app_iter)
                    else:
                        try:
                            return next(resp.app_iter)
                        except StopIteration:
                            return ''
                iter_reader.read = iter_reader
                value = iter_reader
            else:
                value = resp.body
            self.verbose('< %s %s', status, reason)
            if status and status // 100 != 5:
                if not stream and decode_json and status // 100 == 2:
                    if value:
                        value = json.loads(value)
                    else:
                        value = None
                return (status, reason, hdrs, value)
            if reset_func:
                reset_func()
            self.sleep(2 ** attempt)
        raise Exception('%s %s failed: %s %s' % (method, path, status, reason))
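
The query handling at the top is a handy pattern on its own: sorting six.iteritems(query) gives a deterministic query string, and an empty value emits a bare key. Stand-alone, with invented parameter names:

import six
from six.moves.urllib.parse import quote

query = {'prefix': 'photos/', 'limit': '100', 'reverse': ''}
path = '/v1/AUTH_test/container'
path += '?' + '&'.join(
    ('%s=%s' % (quote(k), quote(v)) if v else quote(k))
    for k, v in sorted(six.iteritems(query)))
print(path)   # -> /v1/AUTH_test/container?limit=100&prefix=photos/&reverse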

Example 18

Project: swiftly
Source File: standardclient.py
View license
    def request(self, method, path, contents, headers, decode_json=False,
                stream=False, query=None, cdn=False):
        """
        See :py:func:`swiftly.client.client.Client.request`
        """
        if query:
            path += '?' + '&'.join(
                ('%s=%s' % (quote(k), quote(v)) if v else quote(k))
                for k, v in sorted(six.iteritems(query)))
        reset_func = self._default_reset_func
        if isinstance(contents, six.string_types):
            contents = StringIO(contents)
        tell = getattr(contents, 'tell', None)
        seek = getattr(contents, 'seek', None)
        if tell and seek:
            try:
                orig_pos = tell()
                reset_func = lambda: seek(orig_pos)
            except Exception:
                tell = seek = None
        elif not contents:
            reset_func = lambda: None
        status = 0
        reason = 'Unknown'
        attempt = 0
        while attempt < self.attempts:
            attempt += 1
            if time() >= self.conn_discard:
                self.storage_conn = None
                self.cdn_conn = None
            if cdn:
                conn = self.cdn_conn
                conn_path = self.cdn_path
            else:
                conn = self.storage_conn
                conn_path = self.storage_path
            if not conn:
                parsed, conn = self._connect(cdn=cdn)
                if conn:
                    if cdn:
                        self.cdn_conn = conn
                        self.cdn_path = conn_path = parsed.path
                    else:
                        self.storage_conn = conn
                        self.storage_path = conn_path = parsed.path
                else:
                    raise self.HTTPException(
                        '%s %s failed: No connection' % (method, path))
            self.conn_discard = time() + 4
            titled_headers = dict((k.title(), v) for k, v in six.iteritems({
                'User-Agent': self.user_agent,
                'X-Auth-Token': self.auth_token}))
            if headers:
                titled_headers.update(
                    (k.title(), v) for k, v in six.iteritems(headers))
            try:
                if not hasattr(contents, 'read'):
                    if method not in self.no_content_methods and contents and \
                            'Content-Length' not in titled_headers and \
                            'Transfer-Encoding' not in titled_headers:
                        titled_headers['Content-Length'] = str(
                            len(contents or ''))
                    verbose_headers = '  '.join(
                        '%s: %s' % (k, v)
                        for k, v in sorted(six.iteritems(titled_headers)))
                    self.verbose(
                        '> %s %s %s', method, conn_path + path,
                        verbose_headers)
                    conn.request(
                        method, conn_path + path, contents, titled_headers)
                else:
                    conn.putrequest(method, conn_path + path)
                    content_length = None
                    for h, v in sorted(six.iteritems(titled_headers)):
                        if h == 'Content-Length':
                            content_length = int(v)
                        conn.putheader(h, v)
                    if method not in self.no_content_methods and \
                            content_length is None:
                        titled_headers['Transfer-Encoding'] = 'chunked'
                        conn.putheader('Transfer-Encoding', 'chunked')
                    conn.endheaders()
                    verbose_headers = '  '.join(
                        '%s: %s' % (k, v)
                        for k, v in sorted(six.iteritems(titled_headers)))
                    self.verbose(
                        '> %s %s %s', method, conn_path + path,
                        verbose_headers)
                    if method not in self.no_content_methods and \
                            content_length is None:
                        chunk = contents.read(self.chunk_size)
                        while chunk:
                            conn.send('%x\r\n%s\r\n' % (len(chunk), chunk))
                            chunk = contents.read(self.chunk_size)
                        conn.send('0\r\n\r\n')
                    else:
                        left = content_length or 0
                        while left > 0:
                            size = self.chunk_size
                            if size > left:
                                size = left
                            chunk = contents.read(size)
                            if not chunk:
                                raise IOError('Early EOF from input')
                            conn.send(chunk)
                            left -= len(chunk)
                resp = conn.getresponse()
                status = resp.status
                reason = resp.reason
                hdrs = headers_to_dict(resp.getheaders())
                if stream:
                    value = resp
                else:
                    value = resp.read()
                    resp.close()
            except Exception as err:
                status = 0
                reason = '%s %s' % (type(err), str(err))
                hdrs = {}
                value = None
            self.verbose('< %s %s', status or '-', reason)
            self.verbose('< %s', hdrs)
            if status == 401:
                if stream:
                    value.close()
                conn.close()
                self.auth()
                attempt -= 1
            elif status and status // 100 != 5:
                if not stream and decode_json and status // 100 == 2:
                    if value:
                        value = json.loads(value.decode('utf-8'))
                    else:
                        value = None
                self.conn_discard = time() + 4
                return (status, reason, hdrs, value)
            else:
                if stream and value:
                    value.close()
                conn.close()
            if reset_func:
                reset_func()
            self.sleep(2 ** attempt)
        raise self.HTTPException(
            '%s %s failed: %s %s' % (method, path, status, reason))

Example 19

Project: pandas-profiling
Source File: base.py
View license
def to_html(sample, stats_object):
    """Generate a HTML report from summary statistics and a given sample.

    Parameters
    ----------
    sample: DataFrame containing the sample you want to print
    stats_object: Dictionary containing summary statistics. Should be generated with an appropriate describe() function

    Returns
    -------
    str, containing profile report in HTML format
    """

    n_obs = stats_object['table']['n']

    value_formatters = formatters.value_formatters
    row_formatters = formatters.row_formatters

    if not isinstance(sample, pd.DataFrame):
        raise TypeError("sample must be of type pandas.DataFrame")

    if not isinstance(stats_object, dict):
        raise TypeError("stats_object must be of type dict. Did you generate this using the pandas_profiling.describe() function?")

    if set(stats_object.keys()) != {'table', 'variables', 'freq'}:
        raise TypeError("stats_object badly formatted. Did you generate this using the pandas_profiling-eda.describe() function?")

    def fmt(value, name):
        if pd.isnull(value):
            return ""
        if name in value_formatters:
            return value_formatters[name](value)
        elif isinstance(value, float):
            return value_formatters[formatters.DEFAULT_FLOAT_FORMATTER](value)
        else:
            if sys.version_info.major == 3:
                return str(value)
            else:
                return unicode(value)

    def _format_row(freq, label, max_freq, row_template, n, extra_class=''):
            width = int(freq / max_freq * 99) + 1
            if width > 20:
                label_in_bar = freq
                label_after_bar = ""
            else:
                label_in_bar = "&nbsp;"
                label_after_bar = freq

            return row_template.render(label=label,
                                       width=width,
                                       count=freq,
                                       percentage='{:2.1f}'.format(freq / n * 100),
                                       extra_class=extra_class,
                                       label_in_bar=label_in_bar,
                                       label_after_bar=label_after_bar)

    def freq_table(freqtable, n, table_template, row_template, max_number_to_print):

        freq_rows_html = u''

        if max_number_to_print > n:
            max_number_to_print = n

        if max_number_to_print < len(freqtable):
            freq_other = sum(freqtable.iloc[max_number_to_print:])
            min_freq = freqtable.values[max_number_to_print]
        else:
            freq_other = 0
            min_freq = 0

        freq_missing = n - sum(freqtable)
        max_freq = max(freqtable.values[0], freq_other, freq_missing)

        # TODO: Correctly sort missing and other

        for label, freq in six.iteritems(freqtable.iloc[0:max_number_to_print]):
            freq_rows_html += _format_row(freq, label, max_freq, row_template, n)

        if freq_other > min_freq:
            freq_rows_html += _format_row(freq_other,
                                         "Other values (%s)" % (freqtable.count() - max_number_to_print), max_freq, row_template, n,
                                         extra_class='other')

        if freq_missing > min_freq:
            freq_rows_html += _format_row(freq_missing, "(Missing)", max_freq, row_template, n, extra_class='missing')

        return table_template.render(rows=freq_rows_html, varid=hash(idx))

    def extreme_obs_table(freqtable, table_template, row_template, number_to_print, n, ascending = True):
        if ascending:
            obs_to_print = freqtable.sort_index().iloc[:number_to_print]
        else:
            obs_to_print = freqtable.sort_index().iloc[-number_to_print:]

        freq_rows_html = ''
        max_freq = max(obs_to_print.values)

        for label, freq in six.iteritems(obs_to_print):
            freq_rows_html += _format_row(freq, label, max_freq, row_template, n)

        return table_template.render(rows=freq_rows_html)

    # Variables
    rows_html = u""
    messages = []

    for idx, row in stats_object['variables'].iterrows():

        formatted_values = {'varname': idx, 'varid': hash(idx)}
        row_classes = {}

        for col, value in six.iteritems(row):
            formatted_values[col] = fmt(value, col)

        for col in set(row.index) & six.viewkeys(row_formatters):
            row_classes[col] = row_formatters[col](row[col])
            if row_classes[col] == "alert" and col in templates.messages:
                messages.append(templates.messages[col].format(formatted_values, varname = formatters.fmt_varname(idx)))

        if row['type'] == 'CAT':
            formatted_values['minifreqtable'] = freq_table(stats_object['freq'][idx], n_obs,
                                                           templates.template('mini_freq_table'), templates.template('mini_freq_table_row'), 3)

            if row['distinct_count'] > 50:
                messages.append(templates.messages['HIGH_CARDINALITY'].format(formatted_values, varname = formatters.fmt_varname(idx)))
                row_classes['distinct_count'] = "alert"
            else:
                row_classes['distinct_count'] = ""

        if row['type'] == 'UNIQUE':
            obs = stats_object['freq'][idx].index

            formatted_values['firstn'] = pd.DataFrame(obs[0:3], columns=["First 3 values"]).to_html(classes="example_values", index=False)
            formatted_values['lastn'] = pd.DataFrame(obs[-3:], columns=["Last 3 values"]).to_html(classes="example_values", index=False)

        if row['type'] in {'CORR', 'CONST'}:
            formatted_values['varname'] = formatters.fmt_varname(idx)
            messages.append(templates.messages[row['type']].format(formatted_values))
        else:
            formatted_values['freqtable'] = freq_table(stats_object['freq'][idx], n_obs,
                                                       templates.template('freq_table'), templates.template('freq_table_row'), 10)
            formatted_values['firstn_expanded'] = extreme_obs_table(stats_object['freq'][idx], templates.template('freq_table'), templates.template('freq_table_row'), 5, n_obs, ascending = True)
            formatted_values['lastn_expanded'] = extreme_obs_table(stats_object['freq'][idx], templates.template('freq_table'), templates.template('freq_table_row'), 5, n_obs, ascending = False)

        rows_html += templates.row_templates_dict[row['type']].render(values=formatted_values, row_classes=row_classes)

    # Overview
    formatted_values = {k: fmt(v, k) for k, v in six.iteritems(stats_object['table'])}

    row_classes={}
    for col in six.viewkeys(stats_object['table']) & six.viewkeys(row_formatters):
        row_classes[col] = row_formatters[col](stats_object['table'][col])
        if row_classes[col] == "alert" and col in templates.messages:
            messages.append(templates.messages[col].format(formatted_values, varname = formatters.fmt_varname(idx)))

    messages_html = u''
    for msg in messages:
        messages_html += templates.message_row.format(message=msg)

    overview_html = templates.template('overview').render(values=formatted_values, row_classes = row_classes, messages=messages_html)

    # Sample

    sample_html = templates.template('sample').render(sample_table_html=sample.to_html(classes="sample"))
    # TODO: should be done in the template
    return templates.template('base').render({'overview_html': overview_html, 'rows_html': rows_html, 'sample_html': sample_html})
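
For context, to_html() is the second half of a two-step pipeline; the stats_object it expects comes from the describe() function in the same module. A minimal driver, assuming describe() and the to_html() above are in scope (package-level import paths varied across the 1.x releases):

import pandas as pd

df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': ['x', 'y', 'x', 'y']})
stats = describe(df)                 # builds the stats_object dict expected here
html = to_html(df.head(3), stats)    # sample + stats -> full HTML report
with open('report.html', 'w') as f:
    f.write(html)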

Example 20

Project: makina-states.pack1
Source File: ms_firewalld.py
View license
def _main(vopts,
          jconfig,
          errors=None,
          changes=None,
          apply_retry=0,
          **kwargs):
    if errors is None:
        errors = []
    if changes is None:
        changes = []
    # be sure that the firewall client is available and ready
    fw()
    for z, zdata in six.iteritems(jconfig['zones']):
        try:
            masq = z in jconfig['public_zones']
            if not masq:
                masq = None
            # NEVER ACTIVATE GLOBAL MASQUERADE OR WE LL FIND YOU
            # AND WE WILL CUT YOUR FINGERS.
            # define_zone(z, zdata, masquerade=masq, errors=errors)
            # instead, use a rich rule to set masquerade via source/dest
            # matching to restrict correctly the application of the masquerade
            # perimeter
            define_zone(z, zdata, errors=errors, changes=changes,
                        apply_retry=apply_retry, **kwargs)
        except (Exception,) as ex:
            trace = traceback.format_exc()
            errors.append({'trace': trace,
                           'type': 'zone',
                           'id': z,
                           'exception': ex})
    for z, zdata in six.iteritems(jconfig['services']):
        try:
            define_service(z, zdata, errors=errors, changes=changes,
                           apply_retry=apply_retry, **kwargs)
        except (Exception,) as ex:
            trace = traceback.format_exc()
            errors.append({'trace': trace,
                           'type': 'service',
                           'id': z,
                           'exception': ex})
            log.error(trace)

    # for each of our known zones, collect interface mappings
    # this is an optimization as calling getZoneOfInterface is really slow
    interfaces = {}
    try:
        for z, izdata in six.iteritems(jconfig.get('zones', {})):
            for ifc in fw().getInterfaces(z):
                zns = interfaces.setdefault(ifc, [])
                zns.append(z)
    except (Exception,) as ex:
        trace = traceback.format_exc()
        log.error('zone may not exist yet')
        log.error(trace)

    for z, zdata in six.iteritems(jconfig['zones']):
        try:
            link_interfaces(z, zdata, interfaces, errors=errors,
                            changes=changes, apply_retry=apply_retry,
                            **kwargs)
        except (Exception,) as ex:
            trace = traceback.format_exc()
            errors.append({'trace': trace,
                           'type': 'interfaces',
                           'id': z,
                           'exception': ex})

    lazy_reload()

    for z, zdata in six.iteritems(jconfig['zones']):
        try:
            configure_rules(z, zdata, errors=errors, changes=changes,
                            apply_retry=apply_retry, **kwargs)
        except (Exception,) as ex:
            trace = traceback.format_exc()
            errors.append({'trace': trace,
                           'type': 'rules',
                           'id': z,
                           'exception': ex})

    try:
        configure_directs(jconfig, errors=errors, changes=changes,
                          apply_retry=apply_retry, **kwargs)
    except (Exception,) as ex:
        trace = traceback.format_exc()
        errors.append({'trace': trace,
                       'type': 'directs',
                       'id': z,
                       'exception': ex})

    log.info('end conf')
    errortypes = [a['type'] for a in errors]
    if 'zone' in errortypes:
        code = 1
    elif 'service' in errortypes:
        code = 2
    elif 'rules' in errortypes:
        code = 3
    elif 'rules/rule' in errortypes:
        code = 4
    elif 'interface' in errortypes:
        code = 5
    elif 'interfaces' in errortypes:
        code = 6
    elif 'direct/rule' in errortypes:
        code = 7
    elif 'directs' in errortypes:
        code = 8
    elif len(errors):
        code = 255
    else:
        code = 0
    return code
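
The pattern worth noting here is iterating a configuration mapping with six.iteritems and collecting per-item failures instead of aborting on the first one. A minimal sketch, with a hypothetical apply_zone helper standing in for the firewalld calls.

import traceback
import six

def apply_zone(name, data):
    # hypothetical stand-in for define_zone(); rejects zones with no interfaces
    if not data.get('interfaces'):
        raise ValueError('zone %s has no interfaces' % name)

jconfig = {'zones': {'public': {'interfaces': ['eth0']}, 'broken': {}}}
errors = []
for z, zdata in six.iteritems(jconfig['zones']):
    try:
        apply_zone(z, zdata)
    except Exception as ex:
        errors.append({'trace': traceback.format_exc(), 'type': 'zone',
                       'id': z, 'exception': ex})

print([e['id'] for e in errors])  # ['broken']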

Example 22

Project: wiring
Source File: configuration.py
View license
    def __new__(cls, module_name, bases, attributes):
        special_attributes = (
            'providers',
            'instances',
            'factories',
            'functions',
        )
        module = super(ModuleMetaclass, cls).__new__(
            cls,
            module_name,
            bases,
            {
                key: value for key, value in six.iteritems(attributes)
                if key not in special_attributes
            }
        )

        providers = {}

        for ancestor in reversed(inspect.getmro(module)):
            if cls._is_module_class(ancestor):
                providers.update(ancestor.providers)

        already_provided = set()

        providers_attribute = attributes.get('providers', {})
        providers.update(providers_attribute)
        already_provided.update(six.iterkeys(providers_attribute))

        def check_specification(key):
            if key in already_provided:
                raise InvalidConfigurationError(
                    module,
                    "Multiple sources defined for specification {spec}".format(
                        spec=repr(key)
                    )
                )
            already_provided.add(key)

        for key, value in six.iteritems(attributes.get('instances', {})):
            check_specification(key)
            providers[key] = InstanceProvider(value)
        for key, value in six.iteritems(attributes.get('factories', {})):
            check_specification(key)
            if not isinstance(value, collections.Iterable):
                value = [value]
            if len(value) < 1 or len(value) > 2:
                raise InvalidConfigurationError(
                    module,
                    (
                        "Wrong number of arguments for {spec} in"
                        " `factories`."
                    ).format(
                        spec=repr(key)
                    )
                )
            providers[key] = FactoryProvider(
                value[0],
                scope=(value[1] if len(value) > 1 else None)
            )
        for key, value in six.iteritems(attributes.get('functions', {})):
            check_specification(key)
            providers[key] = FunctionProvider(value)

        for key, value in six.iteritems(attributes):
            if hasattr(value, '__provides__'):
                check_specification(value.__provides__)

        module.providers = providers

        return module
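
The recurring idiom in this metaclass is filtering an attributes dict with a comprehension over six.iteritems. A minimal sketch of just that filtering step; the attribute names are hypothetical.

import six

special_attributes = ('providers', 'instances', 'factories', 'functions')
attributes = {'instances': {'db': object()}, 'helper': len, '__doc__': 'demo'}

# keep only the attributes the metaclass does not treat specially
plain_attributes = {
    key: value for key, value in six.iteritems(attributes)
    if key not in special_attributes
}
print(sorted(plain_attributes))  # ['__doc__', 'helper']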

Example 23

Project: anvil
Source File: base.py
View license
    def _run_many_phase(self, functors, group, instances, phase_name, *inv_phase_names):
        """Run a given 'functor' across all of the components, passing *all* instances to run."""

        # This phase recorder will be used to check if a given component
        # and action has run in the past; if so, that component's action
        # will not be run again. It will also be used to mark that a given
        # component has completed a phase (if that phase runs).
        if not phase_name:
            phase_recorder = phase.NullPhaseRecorder()
        else:
            phase_recorder = phase.PhaseRecorder(self._get_phase_filename(phase_name))

        # These phase recorders will be used to undo other actions' activities,
        # i.e. when an install completes you want the uninstall phase to be
        # removed from that action's phase file (and so on). This list will be
        # used to accomplish that.
        neg_phase_recs = []
        if inv_phase_names:
            for n in inv_phase_names:
                if not n:
                    neg_phase_recs.append(phase.NullPhaseRecorder())
                else:
                    neg_phase_recs.append(phase.PhaseRecorder(self._get_phase_filename(n)))

        def change_activate(instance, on_off):
            # Activate/deactivate a component instance and its siblings (if any)
            #
            # This is used when, say, you are looking at components
            # that have been activated before your component has been.
            #
            # Typically this is useful for checking if a previous component
            # has a shared dependency with your component; if so, there
            # is no need to reinstall said dependency...
            instance.activated = on_off
            for (_name, sibling_instance) in instance.siblings.items():
                sibling_instance.activated = on_off

        def run_inverse_recorders(c_name):
            for n in neg_phase_recs:
                n.unmark(c_name)

        # Reset all activations
        for c, instance in six.iteritems(instances):
            change_activate(instance, False)

        # Run all components which have not been run previously (due to phase tracking)
        instances_started = utils.OrderedDict()
        for c, instance in six.iteritems(instances):
            if c in SPECIAL_GROUPS:
                c = "%s_%s" % (c, group)
            if c in phase_recorder:
                LOG.debug("Skipping phase named %r for component %r since it already happened.", phase_name, c)
            else:
                try:
                    with phase_recorder.mark(c):
                        if functors.start:
                            functors.start(instance)
                        instances_started[c] = instance
                except excp.NoTraceException:
                    pass
        if functors.run:
            results = functors.run(list(six.itervalues(instances_started)))
        else:
            results = [None] * len(instances_started)
        instances_ran = instances_started
        for i, (c, instance) in enumerate(six.iteritems(instances_ran)):
            result = results[i]
            try:
                with phase_recorder.mark(c):
                    if functors.end:
                        functors.end(instance, result)
            except excp.NoTraceException:
                pass
        for c, instance in six.iteritems(instances_ran):
            change_activate(instance, True)
            run_inverse_recorders(c)
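
One detail above: six.iteritems is combined with enumerate so each result can be matched back to the component it came from, which only works because the same ordered mapping is iterated twice. A minimal sketch under that assumption, using hypothetical component names.

from collections import OrderedDict
import six

instances = OrderedDict([('db', 'db-instance'), ('web', 'web-instance')])

# phase 1: "start" every instance, remembering which ones started
started = OrderedDict((name, inst) for name, inst in six.iteritems(instances))

# phase 2: one run over all started instances produces one result each
results = ['ok-%s' % name for name in started]

# phase 3: pair each result with its instance, relying on stable iteration order
for i, (name, inst) in enumerate(six.iteritems(started)):
    print('%s %s %s' % (name, inst, results[i]))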

Example 24

Project: fuel-web
Source File: test_plugin_adapters.py
View license
    def test_get_metadata(self):
        plugin_metadata = self.env.get_default_plugin_metadata()
        attributes_metadata = self.env.get_default_plugin_env_config()
        roles_metadata = self.env.get_default_plugin_node_roles_config()
        volumes_metadata = self.env.get_default_plugin_volumes_config()
        network_roles_metadata = self.env.get_default_network_roles_config()
        deployment_tasks = self.env.get_default_plugin_deployment_tasks()
        tasks = self.env.get_default_plugin_tasks()
        components_metadata = self.env.get_default_components()

        nic_attributes_metadata = self.env.get_default_plugin_nic_config()
        bond_attributes_metadata = self.env.get_default_plugin_bond_config()
        node_attributes_metadata = self.env.get_default_plugin_node_config()

        plugin_metadata.update({
            'attributes_metadata': attributes_metadata,
            'roles_metadata': roles_metadata,
            'volumes_metadata': volumes_metadata,
            'network_roles_metadata': network_roles_metadata,
            'deployment_tasks': deployment_tasks,
            'tasks': tasks,
            'components_metadata': components_metadata,
            'nic_attributes_metadata': nic_attributes_metadata,
            'bond_attributes_metadata': bond_attributes_metadata,
            'node_attributes_metadata': node_attributes_metadata,
            'graphs': [{
                'type': 'custom',
                'name': 'custom',
                'tasks': [
                    {'id': 'task{}'.format(n), 'type': 'puppet'}
                    for n in range(2)
                ]
            }]
        })

        with mock.patch.object(
                self.plugin_adapter, 'loader') as loader:
            loader.load.return_value = (plugin_metadata, ReportNode())
            Plugin.update(self.plugin, self.plugin_adapter.get_metadata())
            for key, val in six.iteritems(
                {
                    k: v for (k, v) in six.iteritems(plugin_metadata)
                    if k not in ('deployment_tasks', 'graphs')
                }
            ):
                self.assertEqual(
                    getattr(self.plugin, key), val)

            self.assertEqual(
                self.plugin.attributes_metadata,
                attributes_metadata['attributes'])
            self.assertEqual(
                self.plugin.roles_metadata, roles_metadata)
            self.assertEqual(
                self.plugin.volumes_metadata, volumes_metadata)
            self.assertEqual(
                self.plugin.tasks, tasks)
            self.assertEqual(
                self.plugin.components_metadata, components_metadata)
            self.assertEqual(
                self.plugin.nic_attributes_metadata,
                nic_attributes_metadata)
            self.assertEqual(
                self.plugin.bond_attributes_metadata,
                bond_attributes_metadata)
            self.assertEqual(
                self.plugin.node_attributes_metadata,
                node_attributes_metadata)

            # check custom graph
            dg = DeploymentGraph.get_for_model(
                self.plugin, graph_type='custom'
            )
            self.assertEqual(dg.name, 'custom')
            self.assertItemsEqual(
                DeploymentGraph.get_tasks(dg),
                [
                    {
                        'id': 'task{}'.format(i),
                        'task_name':
                        'task{}'.format(i),
                        'type': 'puppet',
                        'version': '1.0.0'
                    } for i in range(2)
                ]
            )
            # deployment tasks return all non-defined fields, so the check
            # should differ from the JSON-stored fields
            plugin_tasks = self.env.get_default_plugin_deployment_tasks()
            self.assertGreater(len(plugin_tasks), 0)
            for k, v in six.iteritems(plugin_tasks[0]):
                # this field is updated by plugin adapter
                if k == 'parameters':
                    v.update({
                        'cwd': '/etc/fuel/plugins/testing_plugin-0.1/'
                    })
                self.assertEqual(
                    self.plugin_adapter.get_deployment_tasks()[0][k],
                    v)
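
The test combines two six.iteritems calls: an inner comprehension that drops a few keys, and an outer loop that compares what is left against object attributes. A minimal sketch with a hypothetical plugin object.

import six

class Plugin(object):
    name = 'testing_plugin'
    version = '0.1'

plugin_metadata = {'name': 'testing_plugin', 'version': '0.1', 'graphs': ['ignored']}
plugin = Plugin()

for key, val in six.iteritems(
        {k: v for (k, v) in six.iteritems(plugin_metadata) if k not in ('graphs',)}):
    assert getattr(plugin, key) == val
print('metadata matches')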

Example 25

Project: octavia
Source File: update_db.py
View license
    def update_health(self, health):
        """This function is to update db info based on amphora status

        :param health: map object that contains amphora, listener, member info
        :type health: dict
        :returns: None

        The input health data structure is shown as below:

        health = {
            "id": self.FAKE_UUID_1,
            "listeners": {
                "listener-id-1": {"status": constants.OPEN, "pools": {
                    "pool-id-1": {"status": constants.UP,
                                  "members": {"member-id-1": constants.ONLINE}
                                  }
                }
                }
            }
        }

        """
        session = db_api.get_session()

        # We need to see if all of the listeners are reporting in
        expected_listener_count = 0
        lbs_on_amp = self.amphora_repo.get_all_lbs_on_amphora(session,
                                                              health['id'])
        for lb in lbs_on_amp:
            listener_count = self.listener_repo.count(session,
                                                      load_balancer_id=lb.id)
            expected_listener_count += listener_count

        listeners = health['listeners']

        # Do not update amphora health if the reporting listener count
        # does not match the expected listener count
        if len(listeners) == expected_listener_count:

            # if the input amphora is healthy, we update its db info
            self.amphora_health_repo.replace(session, health['id'],
                                             last_update=(datetime.
                                                          datetime.utcnow()))
        else:
            LOG.warning(_LW('Amphora %(id)s health message reports %(found)i '
                            'listeners when %(expected)i expected'),
                        {'id': health['id'],
                         'found': len(listeners),
                         'expected': expected_listener_count})

        # We got a heartbeat so lb is healthy until proven otherwise
        lb_status = constants.ONLINE

        # update listener and nodes db information
        for listener_id, listener in six.iteritems(listeners):

            listener_status = None
            # OPEN = HAProxy listener status nbconn < maxconn
            if listener.get('status') == constants.OPEN:
                listener_status = constants.ONLINE
            # FULL = HAProxy listener status not nbconn < maxconn
            elif listener.get('status') == constants.FULL:
                listener_status = constants.DEGRADED
                if lb_status == constants.ONLINE:
                    lb_status = constants.DEGRADED
            else:
                LOG.warning(_LW('Listener %(list)s reported status of '
                                '%(status)s'), {'list': listener_id,
                            'status': listener.get('status')})

            try:
                if listener_status is not None:
                    self._update_status_and_emit_event(
                        session, self.listener_repo, constants.LISTENER,
                        listener_id, listener_status
                    )
            except sqlalchemy.orm.exc.NoResultFound:
                LOG.error(_LE("Listener %s is not in DB"), listener_id)

            pools = listener['pools']
            for pool_id, pool in six.iteritems(pools):

                pool_status = None
                # UP = HAProxy backend has working or no servers
                if pool.get('status') == constants.UP:
                    pool_status = constants.ONLINE
                # DOWN = HAProxy backend has no working servers
                elif pool.get('status') == constants.DOWN:
                    pool_status = constants.ERROR
                    lb_status = constants.ERROR
                else:
                    LOG.warning(_LW('Pool %(pool)s reported status of '
                                    '%(status)s'), {'pool': pool_id,
                                'status': pool.get('status')})

                members = pool['members']
                for member_id, status in six.iteritems(members):

                    member_status = None
                    if status == constants.UP:
                        member_status = constants.ONLINE
                    elif status == constants.DOWN:
                        member_status = constants.ERROR
                        if pool_status == constants.ONLINE:
                            pool_status = constants.DEGRADED
                            if lb_status == constants.ONLINE:
                                lb_status = constants.DEGRADED
                    elif status == constants.NO_CHECK:
                        member_status = constants.NO_MONITOR
                    else:
                        LOG.warning(_LW('Member %(mem)s reported status of '
                                        '%(status)s'), {'mem': member_id,
                                    'status': status})

                    try:
                        if member_status is not None:
                            self._update_status_and_emit_event(
                                session, self.member_repo, constants.MEMBER,
                                member_id, member_status
                            )
                    except sqlalchemy.orm.exc.NoResultFound:
                        LOG.error(_LE("Member %s is not able to update "
                                      "in DB"), member_id)

                try:
                    if pool_status is not None:
                        self._update_status_and_emit_event(
                            session, self.pool_repo, constants.POOL,
                            pool_id, pool_status
                        )
                except sqlalchemy.orm.exc.NoResultFound:
                    LOG.error(_LE("Pool %s is not in DB"), pool_id)

        # Update the load balancer status last
        # TODO(sbalukoff): This logic will need to be adjusted if we
        # start supporting multiple load balancers per amphora
        lb_id = self.amphora_repo.get(
            session, id=health['id']).load_balancer_id
        if lb_id is not None:
            try:
                self._update_status_and_emit_event(
                    session, self.loadbalancer_repo,
                    constants.LOADBALANCER, lb_id, lb_status
                )
            except sqlalchemy.orm.exc.NoResultFound:
                LOG.error(_LE("Load balancer %s is not in DB"), lb_id)

Example 26

Project: sqlalchemy-migrate
Source File: shell.py
View license
def main(argv=None, **kwargs):
    """Shell interface to :mod:`migrate.versioning.api`.

    kwargs are default options that can be overridden by passing
    --some_option as a command line option

    :param disable_logging: If True, skip configuring logging
    :type disable_logging: bool
    """
    if argv is None:
        argv = list(sys.argv[1:])
    commands = list(api.__all__)
    commands.sort()

    usage = """%%prog COMMAND ...

    Available commands:
        %s

    Enter "%%prog help COMMAND" for information on a particular command.
    """ % '\n\t'.join(["%s - %s" % (command.ljust(28), api.command_desc.get(command)) for command in commands])

    parser = PassiveOptionParser(usage=usage)
    parser.add_option("-d", "--debug",
                     action="store_true",
                     dest="debug",
                     default=False,
                     help="Shortcut to turn on DEBUG mode for logging")
    parser.add_option("-q", "--disable_logging",
                      action="store_true",
                      dest="disable_logging",
                      default=False,
                      help="Use this option to disable logging configuration")
    help_commands = ['help', '-h', '--help']
    HELP = False

    try:
        command = argv.pop(0)
        if command in help_commands:
            HELP = True
            command = argv.pop(0)
    except IndexError:
        parser.print_help()
        return

    command_func = getattr(api, command, None)
    if command_func is None or command.startswith('_'):
        parser.error("Invalid command %s" % command)

    parser.set_usage(inspect.getdoc(command_func))
    f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func)
    for arg in f_args:
        parser.add_option(
            "--%s" % arg,
            dest=arg,
            action='store',
            type="string")

    # display help of the current command
    if HELP:
        parser.print_help()
        return

    options, args = parser.parse_args(argv)

    # override kwargs with anonymous parameters
    override_kwargs = dict()
    for arg in list(args):
        if arg.startswith('--'):
            args.remove(arg)
            if '=' in arg:
                opt, value = arg[2:].split('=', 1)
            else:
                opt = arg[2:]
                value = True
            override_kwargs[opt] = value

    # override kwargs with options if user is overwriting
    for key, value in six.iteritems(options.__dict__):
        if value is not None:
            override_kwargs[key] = value

    # arguments that function accepts without passed kwargs
    f_required = list(f_args)
    candidates = dict(kwargs)
    candidates.update(override_kwargs)
    for key, value in six.iteritems(candidates):
        if key in f_args:
            f_required.remove(key)

    # map function arguments to parsed arguments
    for arg in args:
        try:
            kw = f_required.pop(0)
        except IndexError:
            parser.error("Too many arguments for command %s: %s" % (command,
                                                                    arg))
        kwargs[kw] = arg

    # apply overrides
    kwargs.update(override_kwargs)

    # configure options
    for key, value in six.iteritems(options.__dict__):
        kwargs.setdefault(key, value)

    # configure logging
    if not asbool(kwargs.pop('disable_logging', False)):
        # filter to log <= INFO into stdout and the rest to stderr
        class SingleLevelFilter(logging.Filter):
            def __init__(self, min=None, max=None):
                self.min = min or 0
                self.max = max or 100

            def filter(self, record):
                return self.min <= record.levelno <= self.max

        logger = logging.getLogger()
        h1 = logging.StreamHandler(sys.stdout)
        f1 = SingleLevelFilter(max=logging.INFO)
        h1.addFilter(f1)
        h2 = logging.StreamHandler(sys.stderr)
        f2 = SingleLevelFilter(min=logging.WARN)
        h2.addFilter(f2)
        logger.addHandler(h1)
        logger.addHandler(h2)

        if options.debug:
            logger.setLevel(logging.DEBUG)
        else:
            logger.setLevel(logging.INFO)

    log = logging.getLogger(__name__)

    # check if all args are given
    try:
        num_defaults = len(f_defaults)
    except TypeError:
        num_defaults = 0
    f_args_default = f_args[len(f_args) - num_defaults:]
    required = list(set(f_required) - set(f_args_default))
    required.sort()
    if required:
        parser.error("Not enough arguments for command %s: %s not specified" \
            % (command, ', '.join(required)))

    # handle command
    try:
        ret = command_func(**kwargs)
        if ret is not None:
            log.info(ret)
    except (exceptions.UsageError, exceptions.KnownError) as e:
        parser.error(e.args[0])
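
Here six.iteritems walks optparse's options.__dict__ twice: once to let explicit (non-None) command line values override kwargs, and once to fill in whatever is still unset. A minimal sketch of that merge with a hypothetical options object.

import six

class Options(object):
    def __init__(self, **kw):
        self.__dict__.update(kw)

options = Options(debug=True, url=None)
kwargs = {'url': 'sqlite://'}

# explicit command line values win over existing kwargs
for key, value in six.iteritems(options.__dict__):
    if value is not None:
        kwargs[key] = value

# anything still unset falls back to the parsed option (even if it is None)
for key, value in six.iteritems(options.__dict__):
    kwargs.setdefault(key, value)

print(kwargs == {'url': 'sqlite://', 'debug': True})  # True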

Example 27

Project: yaql
Source File: specs.py
View license
    def get_delegate(self, receiver, engine, context, args, kwargs):
        def checked(val, param):
            if not param.value_type.check(val, context, engine):
                raise exceptions.ArgumentException(param.name)

            def convert_arg_func(context2):
                try:
                    return param.value_type.convert(
                        val, receiver, context2, self, engine)
                except exceptions.ArgumentValueException:
                    raise exceptions.ArgumentException(param.name)
            return convert_arg_func

        kwargs = dict(kwargs)
        positional = 0
        for arg_name, p in six.iteritems(self.parameters):
            if p.position is not None and arg_name != '*':
                positional += 1

        positional_args = positional * [None]
        positional_fix_table = positional * [0]
        keyword_args = {}

        for p in six.itervalues(self.parameters):
            if p.position is not None and isinstance(
                    p.value_type, yaqltypes.HiddenParameterType):
                for index in range(p.position + 1, positional):
                    positional_fix_table[index] += 1

        for key, p in six.iteritems(self.parameters):
            arg_name = p.alias or p.name
            if p.position is not None and key != '*':
                if isinstance(p.value_type, yaqltypes.HiddenParameterType):
                    positional_args[p.position] = checked(None, p)
                    positional -= 1
                elif p.position - positional_fix_table[p.position] < len(
                        args) and args[p.position - positional_fix_table[
                            p.position]] is not utils.NO_VALUE:
                    if arg_name in kwargs:
                        raise exceptions.ArgumentException(p.name)
                    positional_args[p.position] = checked(
                        args[p.position - positional_fix_table[
                            p.position]], p)
                elif arg_name in kwargs:
                    positional_args[p.position] = checked(
                        kwargs.pop(arg_name), p)
                elif p.default is not NO_DEFAULT:
                    positional_args[p.position] = checked(p.default, p)
                else:
                    raise exceptions.ArgumentException(p.name)
            elif p.position is None and key != '**':
                if isinstance(p.value_type, yaqltypes.HiddenParameterType):
                    keyword_args[key] = checked(None, p)
                elif arg_name in kwargs:
                    keyword_args[key] = checked(kwargs.pop(arg_name), p)
                elif p.default is not NO_DEFAULT:
                    keyword_args[key] = checked(p.default, p)
                else:
                    raise exceptions.ArgumentException(p.name)
        if len(args) > positional:
            if '*' in self.parameters:
                argdef = self.parameters['*']
                positional_args.extend(
                    map(lambda t: checked(t, argdef), args[positional:]))
            else:
                raise exceptions.ArgumentException('*')
        if len(kwargs) > 0:
            if '**' in self.parameters:
                argdef = self.parameters['**']
                for key, value in six.iteritems(kwargs):
                    keyword_args[key] = checked(value, argdef)
            else:
                raise exceptions.ArgumentException('**')

        def func():
            new_context = context.create_child_context()
            result = self.payload(
                *tuple(map(lambda t: t(new_context),
                           positional_args)),
                **dict(map(lambda t: (t[0], t[1](new_context)),
                           six.iteritems(keyword_args)))
            )
            return result

        return func
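
Amid the argument-binding logic, the simplest six.iteritems uses are the passes over the parameter mapping and over leftover kwargs. A minimal sketch of the kwargs pass, with a hypothetical checked() validator in place of the yaql type machinery.

import six

def checked(value, expected_type):
    # hypothetical validator: reject values of the wrong type
    if not isinstance(value, expected_type):
        raise TypeError('expected %s, got %r' % (expected_type.__name__, value))
    return value

catch_all_type = int  # plays the role of the '**' parameter definition
kwargs = {'retries': 3, 'timeout': 30}

keyword_args = {}
for key, value in six.iteritems(kwargs):
    keyword_args[key] = checked(value, catch_all_type)
print(keyword_args)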

Example 28

Project: crowdin-cli
Source File: connection.py
View license
    def export_pattern_to_path(self, lang, download=None):
        # translation = {}
        lang_info = []
        get_sources_translations = self.get_files_source()
        for value_source, value_translation, translations_params in zip(get_sources_translations[::3],
                                                                        get_sources_translations[1::3],
                                                                        get_sources_translations[2::3]):
            translation = {}

            if '**' in value_translation:
                logger.info("Translation pattern `{0}` is not valid. The mask `**` "
                            "can't be used. When using `**` in 'translation' pattern it will always "
                            "contain sub-path from 'source' for certain file.".format(value_translation))

            for l in lang:
                path = value_source
                if '/' in path:
                    original_file_name = path[1:][path.rfind("/"):]
                    file_name = path[1:][path.rfind("/"):].split(".")[0]
                    original_path = path[:path.rfind("/")]
                else:
                    original_file_name = path
                    original_path = ''
                    file_name = path.split(".")[0]

                file_extension = path.split(".")[-1]

                pattern = {
                    '%original_file_name%': original_file_name,
                    '%original_path%': original_path,
                    '%file_extension%': file_extension,
                    '%file_name%': file_name,
                    '%language%': l['name'],
                    '%two_letters_code%': l['iso_639_1'],
                    '%three_letters_code%': l['iso_639_3'],
                    '%locale%': l['locale'],
                    '%crowdin_code%': l['crowdin_code'],
                    '%locale_with_underscore%': l['locale'].replace('-', '_'),
                    '%android_code%': self.android_locale_code(l['locale']),
                    '%osx_code%': self.osx_language_code(l['crowdin_code']) + '.lproj',
                    '%osx_xliff%': self.osx_language_code(l['crowdin_code']) + '.xliff',
                }
                if not download:
                    if 'languages_mapping' in translations_params:
                        try:
                            for i in six.iteritems(translations_params['languages_mapping']):
                                if i[1] is not None:
                                    true_key = ''.join(('%', i[0], '%'))
                                    for k, v in six.iteritems(i[1]):
                                        if l['crowdin_code'] == k:
                                            for key, value in pattern.items():
                                                if key == true_key:
                                                    pattern[key] = v

                        # for i in translations_params['languages_mapping'].iteritems():
                        #     if not i[1] is None:
                        #         rep = dict((re.escape(k), v) for k, v in i[1].iteritems())
                        #         patter = re.compile("|".join(rep.keys()))
                        #         true_key = ''.join(('%', i[0], '%'))
                        #         for key, value in pattern.items():
                        #             if key == true_key:
                        #                 pattern[key] = patter.sub(lambda m: rep[re.escape(m.group(0))], value)

                        except Exception as e:
                            print(e, 'It seems that languages_mapping is not set correctly')
                            exit(1)
                m = re.search("%[a-z0-9_]*?%", value_translation)
                if m.group(0) not in pattern:
                    print('Warning: {} is not valid variable supported by Crowdin. See '
                          'http://crowdin.com/page/cli-tool#configuration-file for more details.'.format(m.group(0)))
                    exit()
                path_lang = value_translation
                rep = dict((re.escape(k), v) for k, v in six.iteritems(pattern))
                patter = re.compile("|".join(rep.keys()))
                text = patter.sub(lambda m: rep[re.escape(m.group(0))], path_lang)
                if text not in translation:
                    translation[l['crowdin_code']] = (text[1:] if text[:1] == '/' else text).replace('//', '/', 1)

                if path not in lang_info:
                    lang_info.append(path)
                    lang_info.append(translation)
                    lang_info.append(translations_params)
        return lang_info
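
The interesting six.iteritems use at the end is building one regular expression out of all placeholder keys and substituting them in a single pass. A minimal sketch of that substitution; the placeholders are a small subset of the real ones.

import re
import six

pattern = {'%file_name%': 'strings', '%two_letters_code%': 'fr',
           '%file_extension%': 'po'}
value_translation = '/locale/%two_letters_code%/%file_name%.%file_extension%'

# escape every placeholder, join them into one alternation, substitute in one pass
rep = dict((re.escape(k), v) for k, v in six.iteritems(pattern))
placeholder_re = re.compile("|".join(rep.keys()))
text = placeholder_re.sub(lambda m: rep[re.escape(m.group(0))], value_translation)
print(text)  # /locale/fr/strings.po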

Example 29

Project: crowdin-cli
Source File: methods.py
View license
    def download_project(self):
        # GET https://api.crowdin.com/api/project/{project-identifier}/download/{package}.zip?key={project-key}
        self.build_project()
        base_path = os.path.normpath(Configuration(self.options_config).get_base_path()) + os.sep
        if self.any_options.dlanguage:
            lang = self.any_options.dlanguage
        else:
            lang = "all"
        url = {'post': 'GET', 'url_par1': '/api/project/', 'url_par2': True,
               'url_par3': '/download/{0}.zip'.format(lang), 'url_par4': True}
        params = {'json': 'json'}
        if self.any_options.branch:
            params['branch'] = self.any_options.branch
        # files that exist in the archive and don't match the current project configuration
        unmatched_files = []

        with zipfile.ZipFile(io.BytesIO(self.true_connection(url, params))) as z:
            # for i in self.exists(Configuration().get_files_source()):
            unzip_dict = {}
            lang = self.lang()
            translations_file = Configuration(self.options_config).export_pattern_to_path(lang, download=True)
            trans_file_no_mapping = Configuration(self.options_config).export_pattern_to_path(lang)
            for i, y in zip(translations_file[1::3], trans_file_no_mapping[1::3]):
                for k, v in six.iteritems(y):
                    for key, value in six.iteritems(i):
                        if k == key:
                            unzip_dict[value] = v
                            if self.any_options.branch:
                                unzip_dict[self.any_options.branch + '/' + value] = v

            initial_files = unzip_dict.keys()
            for target_lang in lang:
                for source_file in list(initial_files):
                    # change only for target_lang files
                    for lang_key in target_lang:
                        if target_lang[lang_key] in source_file:
                            if source_file == unzip_dict[source_file]:
                                f = os.path.basename(source_file)
                            else:
                                r_source = list(reversed(source_file.split('/')))
                                r_target = list(reversed(unzip_dict[source_file].split('/')))
                                f = ''
                                # print(r_source)
                                # print(r_target)
                                for i in range(len(r_target)-1):
                                    if r_target[i] == r_source[i]:
                                        f = '/' + r_target[i] + f

                            if not self.any_options.branch:
                                k = target_lang[lang_key] + '/' + f
                            else:
                                k = self.any_options.branch + '/' + target_lang[lang_key] + '/' + f
                            k = k.replace('//', '/')
                            unzip_dict[k] = unzip_dict[source_file]

            matched_files = []
            for structure in z.namelist():
                if not structure.endswith("/"):
                    for key, value in six.iteritems(unzip_dict):
                        if structure == key:
                            matched_files.append(structure)
                            source = z.open(structure)
                            target_path = os.path.join(base_path, value)
                            target_dir = os.path.dirname(target_path)
                            if not os.path.isdir(target_dir):
                                os.makedirs(target_dir)

                            target = open(target_path, "wb")
                            logger.info("Download: {0} to {1}".format(key, target_path))
                            with source, target:
                                shutil.copyfileobj(source, target)
                                # z.extract(structure, base_path)

                    if structure not in unmatched_files and structure not in matched_files:
                        unmatched_files.append(structure)

            if unmatched_files:
                logger.warning(
                    "\n Warning: Downloaded translations do not match current project configuration. "
                    "Some of the resulted files will be omitted."
                )
                for i in unmatched_files:
                    print(i)
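
Near the top of this method, nested six.iteritems calls pair mapped and unmapped translation paths by their language code. A minimal sketch of that key-matching step with two hypothetical mappings.

import six

mapped = {'fr': 'locale/fr-FR/strings.po', 'de': 'locale/de-DE/strings.po'}
unmapped = {'fr': 'translations/fr/strings.po', 'de': 'translations/de/strings.po'}

unzip_dict = {}
for k, v in six.iteritems(unmapped):
    for key, value in six.iteritems(mapped):
        if k == key:
            # archive path (with language mapping) -> local target path
            unzip_dict[value] = v
print(sorted(unzip_dict.items()))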

Example 30

Project: paramnormal
Source File: plot_directive.py
View license
def run(arguments, content, options, state_machine, state, lineno):
    # The user may provide a filename *or* Python code content, but not both
    if arguments and content:
        raise RuntimeError("plot:: directive can't have both args and content")

    document = state_machine.document
    config = document.settings.env.config
    nofigs = 'nofigs' in options

    options.setdefault('include-source', config.plot_include_source)
    keep_context = 'context' in options
    context_opt = None if not keep_context else options['context']

    rst_file = document.attributes['source']
    rst_dir = os.path.dirname(rst_file)

    if len(arguments):
        if not config.plot_basedir:
            source_file_name = os.path.join(setup.app.builder.srcdir,
                                            directives.uri(arguments[0]))
        else:
            source_file_name = os.path.join(setup.confdir, config.plot_basedir,
                                            directives.uri(arguments[0]))

        # If there is content, it will be passed as a caption.
        caption = '\n'.join(content)

        # If the optional function name is provided, use it
        if len(arguments) == 2:
            function_name = arguments[1]
        else:
            function_name = None

        with io.open(source_file_name, 'r', encoding='utf-8') as fd:
            code = fd.read()
        output_base = os.path.basename(source_file_name)
    else:
        source_file_name = rst_file
        code = textwrap.dedent("\n".join(map(str, content)))
        counter = document.attributes.get('_plot_counter', 0) + 1
        document.attributes['_plot_counter'] = counter
        base, ext = os.path.splitext(os.path.basename(source_file_name))
        output_base = '%s-%d.py' % (base, counter)
        function_name = None
        caption = ''

    base, source_ext = os.path.splitext(output_base)
    if source_ext in ('.py', '.rst', '.txt'):
        output_base = base
    else:
        source_ext = ''

    # ensure that LaTeX includegraphics doesn't choke on foo.bar.pdf filenames
    output_base = output_base.replace('.', '-')

    # is it in doctest format?
    is_doctest = contains_doctest(code)
    if 'format' in options:
        if options['format'] == 'python':
            is_doctest = False
        else:
            is_doctest = True

    # determine output directory name fragment
    source_rel_name = relpath(source_file_name, setup.confdir)
    source_rel_dir = os.path.dirname(source_rel_name)
    while source_rel_dir.startswith(os.path.sep):
        source_rel_dir = source_rel_dir[1:]

    # build_dir: where to place output files (temporarily)
    build_dir = os.path.join(os.path.dirname(setup.app.doctreedir),
                             'plot_directive',
                             source_rel_dir)
    # get rid of .. in paths, also changes pathsep
    # see note in Python docs for warning about symbolic links on Windows.
    # need to compare source and dest paths at end
    build_dir = os.path.normpath(build_dir)

    if not os.path.exists(build_dir):
        os.makedirs(build_dir)

    # output_dir: final location in the builder's directory
    dest_dir = os.path.abspath(os.path.join(setup.app.builder.outdir,
                                            source_rel_dir))
    if not os.path.exists(dest_dir):
        os.makedirs(dest_dir) # no problem here for me, but just use built-ins

    # how to link to files from the RST file
    dest_dir_link = os.path.join(relpath(setup.confdir, rst_dir),
                                 source_rel_dir).replace(os.path.sep, '/')
    build_dir_link = relpath(build_dir, rst_dir).replace(os.path.sep, '/')
    source_link = dest_dir_link + '/' + output_base + source_ext

    # make figures
    try:
        results = render_figures(code,
                                 source_file_name,
                                 build_dir,
                                 output_base,
                                 keep_context,
                                 function_name,
                                 config,
                                 context_reset=context_opt == 'reset',
                                 close_figs=context_opt == 'close-figs')
        errors = []
    except PlotError as err:
        reporter = state.memo.reporter
        sm = reporter.system_message(
            2, "Exception occurred in plotting %s\n from %s:\n%s" % (output_base,
                                                source_file_name, err),
            line=lineno)
        results = [(code, [])]
        errors = [sm]

    # Properly indent the caption
    caption = '\n'.join('      ' + line.strip()
                        for line in caption.split('\n'))

    # generate output restructuredtext
    total_lines = []
    for j, (code_piece, images) in enumerate(results):
        if options['include-source']:
            if is_doctest:
                lines = ['']
                lines += [row.rstrip() for row in code_piece.split('\n')]
            else:
                lines = ['.. code-block:: python', '']
                lines += ['    %s' % row.rstrip()
                          for row in code_piece.split('\n')]
            source_code = "\n".join(lines)
        else:
            source_code = ""

        if nofigs:
            images = []

        opts = [':%s: %s' % (key, val) for key, val in six.iteritems(options)
                if key in ('alt', 'height', 'width', 'scale', 'align', 'class')]

        only_html = ".. only:: html"
        only_latex = ".. only:: latex"
        only_texinfo = ".. only:: texinfo"

        # Not-None src_link signals the need for a source link in the generated
        # html
        if j == 0 and config.plot_html_show_source_link:
            src_link = source_link
        else:
            src_link = None

        result = format_template(
            config.plot_template or TEMPLATE,
            dest_dir=dest_dir_link,
            build_dir=build_dir_link,
            source_link=src_link,
            multi_image=len(images) > 1,
            only_html=only_html,
            only_latex=only_latex,
            only_texinfo=only_texinfo,
            options=opts,
            images=images,
            source_code=source_code,
            html_show_formats=config.plot_html_show_formats and not nofigs,
            caption=caption)

        total_lines.extend(result.split("\n"))
        total_lines.extend("\n")

    if total_lines:
        state_machine.insert_input(total_lines, source=source_file_name)

    # copy image files to builder's output directory, if necessary
    if not os.path.exists(dest_dir):
        cbook.mkdirs(dest_dir)

    for code_piece, images in results:
        for img in images:
            for fn in img.filenames():
                destimg = os.path.join(dest_dir, os.path.basename(fn))
                if fn != destimg:
                    shutil.copyfile(fn, destimg)

    # copy script (if necessary)
    target_name = os.path.join(dest_dir, output_base + source_ext)
    with io.open(target_name, 'w', encoding="utf-8") as f:
        if source_file_name == rst_file:
            code_escaped = unescape_doctest(code)
        else:
            code_escaped = code
        f.write(code_escaped)

    return errors
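
The single six.iteritems call in this directive turns the option dict into ':name: value' lines, keeping only image-related keys. A minimal sketch with hypothetical option values.

import six

options = {'alt': 'example plot', 'width': '400px', 'include-source': True}

opts = [':%s: %s' % (key, val) for key, val in six.iteritems(options)
        if key in ('alt', 'height', 'width', 'scale', 'align', 'class')]
print(sorted(opts))  # [':alt: example plot', ':width: 400px']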

Example 31

Project: pycounter
Source File: sushi.py
View license
def _raw_to_full(raw_report):
    """Convert a raw report to CounterReport.

    :param raw_report: raw XML report
    :return: a :class:`pycounter.report.CounterReport`
    """
    try:
        root = etree.fromstring(raw_report)
    except etree.XMLSyntaxError:
        logger.error("XML syntax error: %s", raw_report)
        raise pycounter.exceptions.SushiException(
            message="XML syntax error",
            raw=raw_report)
    o_root = objectify.fromstring(raw_report)
    rep = None
    try:
        rep = o_root.Body[_ns('sushicounter', "ReportResponse")]
        c_report = rep.Report[_ns('counter', 'Report')]
    except AttributeError:
        try:
            c_report = rep.Report[_ns('counter', 'Reports')].Report
        except AttributeError:
            logger.error("report not found in XML: %s", raw_report)
            raise pycounter.exceptions.SushiException(
                message="report not found in XML",
                raw=raw_report, xml=o_root)
    logger.debug("COUNTER report: %s", etree.tostring(c_report))
    start_date = datetime.datetime.strptime(
        root.find('.//%s' % _ns('sushi', 'Begin')).text,
        "%Y-%m-%d").date()

    end_date = datetime.datetime.strptime(
        root.find('.//%s' % _ns('sushi', 'End')).text,
        "%Y-%m-%d").date()

    report_data = {'period': (start_date, end_date)}

    rep_def = root.find('.//%s' % _ns('sushi', 'ReportDefinition'))
    report_data['report_version'] = int(rep_def.get('Release'))

    report_data['report_type'] = rep_def.get('Name')

    customer = root.find('.//%s' % _ns('counter', 'Customer'))
    try:
        report_data['customer'] = (customer.find('.//%s' %
                                                 _ns('counter', 'Name')).text)
    except AttributeError:
        report_data['customer'] = ""

    inst_id = customer.find('.//%s' % _ns('counter', 'ID')).text
    report_data['institutional_identifier'] = inst_id

    rep_root = root.find('.//%s' % _ns('counter', 'Report'))
    created_string = rep_root.get('Created')
    if created_string is not None:
        report_data['date_run'] = arrow.get(created_string)
    else:
        report_data['date_run'] = datetime.datetime.now()

    report = pycounter.report.CounterReport(**report_data)

    report.metric = pycounter.constants.METRICS.get(report_data['report_type'])

    for item in c_report.Customer.ReportItems:
        try:
            publisher_name = item.ItemPublisher.text
        except AttributeError:
            publisher_name = ""
        title = item.ItemName.text
        platform = item.ItemPlatform.text

        eissn = issn = isbn = ""

        try:
            for identifier in item.ItemIdentifier:
                if identifier.Type == "Print_ISSN":
                    issn = identifier.Value.text
                    if issn is None:
                        issn = ""
                elif identifier.Type == "Online_ISSN":
                    eissn = identifier.Value.text
                    if eissn is None:
                        eissn = ""
                elif identifier.Type == "Online_ISBN":
                    logging.debug("FOUND ISBN")
                    isbn = identifier.Value.text
                    if isbn is None:
                        isbn = ""

        except AttributeError:
            pass

        month_data = []
        html_usage = 0
        pdf_usage = 0

        metrics_for_db = collections.defaultdict(list)

        for perform_item in item.ItemPerformance:
            item_date = convert_date_run(perform_item.Period.Begin.text)
            logger.debug("perform_item date: %r", item_date)
            usage = None
            for inst in perform_item.Instance:
                if inst.MetricType == "ft_total":
                    usage = str(inst.Count)
                elif inst.MetricType == "ft_pdf":
                    pdf_usage += int(inst.Count)
                elif inst.MetricType == "ft_html":
                    html_usage += int(inst.Count)
                elif report.report_type.startswith('DB'):
                    metrics_for_db[inst.MetricType].append((item_date,
                                                            int(inst.Count)))
            if usage is not None:
                month_data.append((item_date, int(usage)))

        if report.report_type:
            if report.report_type.startswith('JR'):
                report.pubs.append(pycounter.report.CounterJournal(
                    title=title,
                    platform=platform,
                    publisher=publisher_name,
                    period=report.period,
                    metric=report.metric,
                    issn=issn,
                    eissn=eissn,
                    month_data=month_data,
                    html_total=html_usage,
                    pdf_total=pdf_usage
                ))
            elif report.report_type.startswith('BR'):
                report.pubs.append(
                    pycounter.report.CounterBook(
                        title=title,
                        platform=platform,
                        publisher=publisher_name,
                        period=report.period,
                        metric=report.metric,
                        issn=issn,
                        isbn=isbn,
                        month_data=month_data,
                    ))
            elif report.report_type.startswith('DB'):
                for metric_code, month_data in six.iteritems(metrics_for_db):
                    metric = pycounter.constants.DB_METRIC_MAP[metric_code]
                    report.pubs.append(
                        pycounter.report.CounterDatabase(
                            title=title,
                            platform=platform,
                            publisher=publisher_name,
                            period=report.period,
                            metric=metric,
                            month_data=month_data
                        ))

    return report
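
The DB-report branch groups monthly counts per metric in a collections.defaultdict(list) and then walks it with six.iteritems. A minimal sketch of that grouping and iteration, with hypothetical metric codes and counts.

import collections
import datetime
import six

metrics_for_db = collections.defaultdict(list)
samples = [('search_reg', datetime.date(2016, 1, 1), 12),
           ('search_reg', datetime.date(2016, 2, 1), 7),
           ('record_view', datetime.date(2016, 1, 1), 40)]

for metric_code, month, count in samples:
    metrics_for_db[metric_code].append((month, count))

for metric_code, month_data in six.iteritems(metrics_for_db):
    print('%s: %s' % (metric_code, month_data))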

Example 32

Project: pyomo
Source File: computeconf.py
View license
def run_conf(scenario_instance_factory,
             index_list,
             num_scenarios_for_solution,
             num_scenarios_per_sample,
             full_scenario_tree,
             xhat_ph,
             options):

    if options.MRP_directory_basename is None:
        AllInOne = True
    else:
        AllInOne = False

    sense = xhat_ph._scenario_tree._scenarios[0]._objective_sense

    # in order to handle the case of scenarios that are not equally
    # likely, we will split the expectations for Gsupk
    # BUT we are going to assume that the groups themselves are
    # equally likely and just scale by n_g and n_g-1 for Gbar and VarG

    # really not always needed...
    # http://www.eecs.berkeley.edu/~mhoemmen/cs194/Tutorials/variance.pdf
    g_supk_of_xhat = []
    g_bar = 0
    sum_xstar_obj_given_xhat = 0
    n_g = options.n_g

    for k in range(1, n_g+1):

        gk_ph = None
        try:

            if AllInOne:

                start_index = num_scenarios_for_solution + \
                              (k-1)*num_scenarios_per_sample
                stop_index = start_index + num_scenarios_per_sample

                print("")
                print("Computing statistics for sample k="+str(k)+".")
                if options.verbose:
                    print("Bundle start index="+str(start_index)
                          +", stop index="+str(stop_index)+".")

                # compute this xstar solution for the EF associated with
                # sample k.

                print("Loading scenario instances and initializing "
                      "scenario tree for xstar scenario bundle.")

                gk_ph = ph_for_bundle(start_index,
                                      stop_index,
                                      scenario_instance_factory,
                                      full_scenario_tree,
                                      index_list,
                                      options)

            else:

                options.instance_directory = \
                    options.MRP_directory_basename+str(k)

                gk_ph = PHFromScratch(options)

            print("Creating the xstar extensive form.")
            print("")
            print("Composite scenarios:")
            for scenario in gk_ph._scenario_tree._scenarios:
                print (scenario._name)
            print("")
            gk_ef = ExtensiveFormAlgorithm(gk_ph,
                                           options._ef_options,
                                           prefix="ef_")
            gk_ef.build_ef()
            print("Solving the xstar extensive form.")
            # Instance preprocessing is managed within the
            # ph object automatically when required for a
            # solve. Since we are solving the instances
            # outside of the ph object, we will inform it
            # that it should complete the instance
            # preprocessing early
            gk_ph._preprocess_scenario_instances()
            gk_ef.solve(io_options=\
                        {'output_fixed_variable_bounds':
                         options.write_fixed_variables})
            xstar_obj = gk_ef.objective
            # assuming this is the absolute gap
            xstar_obj_gap = gk_ef.gap

            """
            gk_ef = create_ef_instance(gk_ph._scenario_tree,
                                       generate_weighted_cvar=options.generate_weighted_cvar,
                                       cvar_weight=options.cvar_weight,
                                       risk_alpha=options.risk_alpha)
            print("Solving the xstar extensive form.")

            # Instance preprocessing is managed within the ph object
            # automatically when required for a solve. Since we are
            # solving the instances outside of the ph object, we will
            # inform it that it should complete the instance preprocessing
            # early
            gk_ph._preprocess_scenario_instances()

            ef_results = solve_ef(gk_ef, options)

            # as in the computation of xhat, the following is required to form a
            # solution to the extensive form in the scenario tree itself.
            gk_ph._scenario_tree.pullScenarioSolutionsFromInstances()
            gk_ph._scenario_tree.snapshotSolutionFromScenarios()

            # extract the objective function value corresponding to the
            # xstar solution, along with any gap information.

            xstar_obj = gk_ph._scenario_tree.findRootNode().computeExpectedNodeCost()
            # assuming this is the absolute gap
            xstar_obj_gap = gk_ef.solutions[0].gap# ef_results.solution(0).gap
            """

            print("Sample extensive form objective value="+str(xstar_obj))


            # CVARHACK: if CPLEX barfed, keep trucking and bury our head
            # in the sand.
            if type(xstar_obj_gap) is UndefinedData:
                xstar_obj_bound = xstar_obj
                #EW#print("xstar_obj_bound= "+str(xstar_obj_bound))
            else:
                if sense == minimize:
                    xstar_obj_bound = xstar_obj - xstar_obj_gap
                else:
                    xstar_obj_bound = xstar_obj + xstar_obj_gap
                #EW#print("xstar_obj_bound= "+str(xstar_obj_bound))
                #EW#print("xstar_obj = "+str(xstar_obj))
                #EW#print("xstar_obj_gap = "+str(xstar_obj_gap))
            # TBD: ADD VERBOSE OUTPUT HERE

            # to get f(xhat) for this sample, fix the first-stage
            # variables and re-solve the extensive form.  note that the
            # fixing yields side-effects on the original gk_ef, but that
            # is fine as it isn't used after this point.
            print("Solving the extensive form given the xhat solution.")
            #xhat = pyomo.pysp.phboundbase.ExtractInternalNodeSolutionsforInner(xhat_ph)
            #
            # fix the first stage variables
            #
            gk_root_node = gk_ph._scenario_tree.findRootNode()
            #root_xhat = xhat[gk_root_node._name]
            root_xhat = xhat_ph._scenario_tree.findRootNode()._solution
            for variable_id in gk_root_node._standard_variable_ids:
                gk_root_node.fix_variable(variable_id,
                                          root_xhat[variable_id])

            # Push fixed variable statuses on instances (or
            # transmit to the phsolverservers), since we are not
            # calling the solve method on the ph object, we
            # need to do this manually
            gk_ph._push_fix_queue_to_instances()
            gk_ph._preprocess_scenario_instances()

            gk_ef.solve(io_options=\
                        {'output_fixed_variable_bounds':
                         options.write_fixed_variables})
            #ef_results = solve_ef(gk_ef, options)

            # we don't need the solution - just the objective value.
            #objective_name = "MASTER"
            #objective = gk_ef.find_component(objective_name)
            xstar_obj_given_xhat = gk_ef.objective

            print("Sample extensive form objective value given xhat="
                  +str(xstar_obj_given_xhat))

            #g_supk_of_xhat.append(xstar_obj_given_xhat - xstar_obj_bound)
            if sense == minimize:
                g_supk_of_xhat.append(xstar_obj_given_xhat - xstar_obj_bound)
            else:
                g_supk_of_xhat.append(- xstar_obj_given_xhat + xstar_obj_bound)
            g_bar += g_supk_of_xhat[k-1]
            sum_xstar_obj_given_xhat += xstar_obj_given_xhat

        finally:

            if gk_ph is not None:

                # we are using the PHCleanup function for
                # convenience, but we need to prevent it
                # from shutting down the scenario_instance_factory
                # as it is managed outside this function
                if gk_ph._scenario_tree._scenario_instance_factory is \
                   scenario_instance_factory:
                    gk_ph._scenario_tree._scenario_instance_factory = None
                PHCleanup(gk_ph)

    g_bar /= n_g
    # second pass for variance calculation (because we like storing
    # the g_supk)
    g_var = 0.0
    for k in range(0, n_g):
        print("g_supk_of_xhat[%d]=%12.6f"
              % (k+1, g_supk_of_xhat[k]))
        g_var = g_var + (g_supk_of_xhat[k] - g_bar) * \
                (g_supk_of_xhat[k] - g_bar)
    if n_g != 1:
        # sample var
        g_var = g_var / (n_g - 1)
    print("")
    print("Raw results:")
    print("g_bar= "+str(g_bar))
    print("g_stddev= "+str(math.sqrt(g_var)))
    print("Average f(xhat)= "+str(sum_xstar_obj_given_xhat / n_g))

    if n_g in t_table_values:
        print("")
        print("Results summary:")
        t_table_entries = t_table_values[n_g]
        for key in sorted(iterkeys(t_table_entries)):
            print("Confidence interval width for alpha="+str(key)
                  +" is "+str(g_bar + (t_table_entries[key] * \
                                       math.sqrt(g_var) / \
                                       math.sqrt(n_g))))
    else:
        print("No built-in t-table entries for "+str(n_g)
              +" degrees of freedom - cannot calculate confidence interval width")

    if options.write_xhat_solution:
        print("")
        print("xhat solution:")
        scenario_tree = xhat_ph._scenario_tree
        first_stage = scenario_tree._stages[0]
        root_node = first_stage._tree_nodes[0]
        for key, val in iteritems(root_node._solutions):
            for idx in val:
                if val[idx] != 0.0:
                    print("%s %s %s" % (str(key), str(idx), str(val[idx]())))

    scenario_count = len(full_scenario_tree._stages[-1]._tree_nodes)
    if options.append_file is not None:
        output_file = open(options.append_file, "a")
        output_file.write("\ninstancedirectory, "
                          +str(options.instance_directory)
                          +", seed, "+str(options.random_seed)
                          +", N, "+str(scenario_count)
                          +", hatn, "+str(num_scenarios_for_solution)
                          +", n_g, "+str(options.n_g)
                          +", Eoffofxhat, "
                          +str(sum_xstar_obj_given_xhat / n_g)
                          +", gbar, "+str(g_bar)+", sg, "
                          +str(math.sqrt(g_var))+", objforxhat, "
                          +str(xhat_obj)+", n,"
                          +str(num_scenarios_per_sample))

        if n_g in t_table_values:
            t_table_entries = t_table_values[n_g]
            for key in sorted(iterkeys(t_table_entries)):
                output_file.write(" , alpha="+str(key)+" , "
                                  +str(g_bar + (t_table_entries[key] * \
                                                math.sqrt(g_var) / \
                                                math.sqrt(n_g))))

        if options.write_xhat_solution:
            output_file.write(" , ")
            scenario_tree = xhat_ph._scenario_tree
            first_stage = scenario_tree._stages[0]
            root_node = first_stage._tree_nodes[0]
            for key, val in iteritems(root_node._solutions):
                for idx in val:
                    if val[idx] != 0.0:
                        output_file.write("%s %s %s"
                                          % (str(key),
                                             str(idx),
                                             str(val[idx]())))
        output_file.close()
        print("")
        print("Results summary appended to file="
              +options.append_file)

    xhat_ph.release_components()
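
The summary statistics at the end of run_conf boil down to a short calculation: average the per-sample gaps to get g_bar, form the sample variance, and print a confidence-interval width of g_bar + t * sqrt(g_var) / sqrt(n_g) for each tabulated t-value, iterating the table with six's iterkeys. The following rough standalone sketch follows those assumptions; the gap values and t-table entries are illustrative, not the table shipped with pyomo.

import math

import six

g_supk_of_xhat = [1.2, 0.8, 1.5, 1.1]   # hypothetical per-sample gaps
n_g = len(g_supk_of_xhat)
g_bar = sum(g_supk_of_xhat) / float(n_g)
g_var = sum((g - g_bar) ** 2 for g in g_supk_of_xhat) / (n_g - 1)

# Illustrative alpha -> t-value entries (roughly n_g - 1 degrees of freedom).
t_table_entries = {0.90: 1.638, 0.95: 2.353, 0.99: 4.541}
for alpha in sorted(six.iterkeys(t_table_entries)):
    width = g_bar + t_table_entries[alpha] * math.sqrt(g_var) / math.sqrt(n_g)
    print("alpha=%s interval width=%.4f" % (alpha, width))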

Example 33

Project: pyomo
Source File: ampl_repn.py
View license
def _generate_ampl_repn(exp):
    ampl_repn = AmplRepn()

    # We need to do this not at the global scope in case someone changed
    # the mode after importing the environment.
    _using_pyomo4_trees = expr_common.mode == expr_common.Mode.pyomo4_trees

    exp_type = type(exp)
    if exp_type in native_numeric_types:
        ampl_repn._constant = value(exp)
        return ampl_repn

    #
    # Expression
    #
    elif exp.is_expression():

        #
        # Sum
        #
        if _using_pyomo4_trees and (exp_type is Expr._LinearExpression):
            ampl_repn._constant = value(exp._const)
            ampl_repn._nonlinear_expr = None
            for child_exp in exp._args:
                exp_coef = value(exp._coef[id(child_exp)])
                if exp_coef != 0:
                    child_repn = _generate_ampl_repn(child_exp)
                    # adjust the constant
                    ampl_repn._constant += exp_coef * child_repn._constant

                    # adjust the linear terms
                    for var_ID in child_repn._linear_vars:
                        if var_ID in ampl_repn._linear_terms_coef:
                            ampl_repn._linear_terms_coef[var_ID] += \
                                exp_coef * child_repn._linear_terms_coef[var_ID]
                        else:
                            ampl_repn._linear_terms_coef[var_ID] = \
                                exp_coef * child_repn._linear_terms_coef[var_ID]
                    # adjust the linear vars
                    ampl_repn._linear_vars.update(child_repn._linear_vars)

                    # adjust the nonlinear terms
                    if not child_repn._nonlinear_expr is None:
                        if ampl_repn._nonlinear_expr is None:
                            ampl_repn._nonlinear_expr = \
                                [(exp_coef, child_repn._nonlinear_expr)]
                        else:
                            ampl_repn._nonlinear_expr.append(
                                (exp_coef, child_repn._nonlinear_expr))
                    # adjust the nonlinear vars
                    ampl_repn._nonlinear_vars.update(child_repn._nonlinear_vars)

            return ampl_repn

        elif _using_pyomo4_trees and (exp_type is Expr._SumExpression):
            ampl_repn._constant = 0.0
            ampl_repn._nonlinear_expr = None
            for child_exp in exp._args:
                child_repn = _generate_ampl_repn(child_exp)
                # adjust the constant
                ampl_repn._constant += child_repn._constant

                # adjust the linear terms
                for var_ID in child_repn._linear_vars:
                    if var_ID in ampl_repn._linear_terms_coef:
                        ampl_repn._linear_terms_coef[var_ID] += \
                            child_repn._linear_terms_coef[var_ID]
                    else:
                        ampl_repn._linear_terms_coef[var_ID] = \
                            child_repn._linear_terms_coef[var_ID]
                # adjust the linear vars
                ampl_repn._linear_vars.update(child_repn._linear_vars)

                # adjust the nonlinear terms
                if not child_repn._nonlinear_expr is None:
                    if ampl_repn._nonlinear_expr is None:
                        ampl_repn._nonlinear_expr = \
                            [(1, child_repn._nonlinear_expr)]
                    else:
                        ampl_repn._nonlinear_expr.append(
                            (1, child_repn._nonlinear_expr))
                # adjust the nonlinear vars
                ampl_repn._nonlinear_vars.update(child_repn._nonlinear_vars)
            return ampl_repn

        elif exp_type is Expr._SumExpression:
            assert not _using_pyomo4_trees
            ampl_repn._constant = exp._const
            ampl_repn._nonlinear_expr = None
            for i in xrange(len(exp._args)):
                exp_coef = exp._coef[i]
                if exp_coef != 0:
                    child_exp = exp._args[i]
                    child_repn = _generate_ampl_repn(child_exp)
                    # adjust the constant
                    ampl_repn._constant += exp_coef * child_repn._constant

                    # adjust the linear terms
                    for var_ID in child_repn._linear_vars:
                        if var_ID in ampl_repn._linear_terms_coef:
                            ampl_repn._linear_terms_coef[var_ID] += \
                                exp_coef * child_repn._linear_terms_coef[var_ID]
                        else:
                            ampl_repn._linear_terms_coef[var_ID] = \
                                exp_coef * child_repn._linear_terms_coef[var_ID]
                    # adjust the linear vars
                    ampl_repn._linear_vars.update(child_repn._linear_vars)

                    # adjust the nonlinear terms
                    if not child_repn._nonlinear_expr is None:
                        if ampl_repn._nonlinear_expr is None:
                            ampl_repn._nonlinear_expr = \
                                [(exp_coef, child_repn._nonlinear_expr)]
                        else:
                            ampl_repn._nonlinear_expr.append(
                                (exp_coef, child_repn._nonlinear_expr))
                    # adjust the nonlinear vars
                    ampl_repn._nonlinear_vars.update(child_repn._nonlinear_vars)
            return ampl_repn

        #
        # Product
        #
        elif (not _using_pyomo4_trees) and \
             (exp_type is Expr._ProductExpression):
            #
            # Iterate through the denominator.  If they
            # aren't all constants, then simply return this
            # expression.
            #
            denom = 1.0
            for e in exp._denominator:
                if e.is_fixed():
                    denom *= value(e)
                else:
                    ampl_repn._nonlinear_expr = exp
                    break
                if denom == 0.0:
                    raise ZeroDivisionError(
                        "Divide-by-zero error - offending sub-expression: "+str(e))

            if ampl_repn._nonlinear_expr is not None:
                # we have a nonlinear expression ... build up all the vars
                for e in exp._denominator:
                    arg_repn = _generate_ampl_repn(e)
                    ampl_repn._nonlinear_vars.update(arg_repn._linear_vars)
                    ampl_repn._nonlinear_vars.update(arg_repn._nonlinear_vars)

                for e in exp._numerator:
                    arg_repn = _generate_ampl_repn(e)
                    ampl_repn._nonlinear_vars.update(arg_repn._linear_vars)
                    ampl_repn._nonlinear_vars.update(arg_repn._nonlinear_vars)
                return ampl_repn

            #
            # OK, the denominator is a constant.
            #
            # build up the ampl_repns for the numerator
            n_linear_args = 0
            n_nonlinear_args = 0
            arg_repns = list()
            for e in exp._numerator:
                e_repn = _generate_ampl_repn(e)
                arg_repns.append(e_repn)
                # count arguments whose representation is nonlinear
                if e_repn._nonlinear_expr is not None:
                    n_nonlinear_args += 1
                # otherwise, count arguments that carry linear (non-constant) terms
                elif len(e_repn._linear_vars) > 0:
                    n_linear_args += 1
                # At this point we do not have a nonlinear
                # expression and there are no linear
                # terms. If the expression constant is zero,
                # then we have a zero term in the product
                # expression, so the entire product
                # expression becomes trivial.
                elif e_repn._constant == 0.0:
                    ampl_repn = e_repn
                    return ampl_repn

            is_nonlinear = False
            if n_linear_args > 1 or n_nonlinear_args > 0:
                is_nonlinear = True

            if is_nonlinear:
                # do like AMPL and simply return the expression
                # without extracting the potentially linear part
                ampl_repn = AmplRepn()
                ampl_repn._nonlinear_expr = exp
                for repn in arg_repns:
                    ampl_repn._nonlinear_vars.update(repn._linear_vars)
                    ampl_repn._nonlinear_vars.update(repn._nonlinear_vars)
                return ampl_repn

            else: # is linear or constant
                ampl_repn = current_repn = arg_repns[0]
                for i in xrange(1,len(arg_repns)):
                    e_repn = arg_repns[i]
                    ampl_repn = AmplRepn()

                    # const_c * const_e
                    ampl_repn._constant = current_repn._constant * e_repn._constant

                    # const_e * L_c
                    if e_repn._constant != 0.0:
                        for (var_ID, var) in iteritems(current_repn._linear_vars):
                            ampl_repn._linear_terms_coef[var_ID] = \
                                current_repn._linear_terms_coef[var_ID] * \
                                e_repn._constant
                        ampl_repn._linear_vars.update(current_repn._linear_vars)

                    # const_c * L_e
                    if current_repn._constant != 0.0:
                        for (e_var_ID,e_var) in iteritems(e_repn._linear_vars):
                            if e_var_ID in ampl_repn._linear_vars:
                                ampl_repn._linear_terms_coef[e_var_ID] += \
                                    current_repn._constant * \
                                    e_repn._linear_terms_coef[e_var_ID]
                            else:
                                ampl_repn._linear_terms_coef[e_var_ID] = \
                                    current_repn._constant * \
                                    e_repn._linear_terms_coef[e_var_ID]
                        ampl_repn._linear_vars.update(e_repn._linear_vars)
                    current_repn = ampl_repn

            # now deal with the product expression's coefficient that needs
            # to be applied to all parts of the ampl_repn
            ampl_repn._constant *= exp._coef/denom
            for var_ID in ampl_repn._linear_terms_coef:
                ampl_repn._linear_terms_coef[var_ID] *= exp._coef/denom

            return ampl_repn

        elif _using_pyomo4_trees and (exp_type is Expr._ProductExpression):
            # It is assumed this is a binary operator
            # (x=args[0], y=args[1])
            assert len(exp._args) == 2

            n_linear_args = 0
            n_nonlinear_args = 0
            arg_repns = list()
            for e in exp._args:
                e_repn = _generate_ampl_repn(e)
                arg_repns.append(e_repn)
                # count arguments whose representation is nonlinear
                if e_repn._nonlinear_expr is not None:
                    n_nonlinear_args += 1
                # otherwise, count arguments that carry linear (non-constant) terms
                elif len(e_repn._linear_vars) > 0:
                    n_linear_args += 1
                # At this point we do not have a nonlinear
                # expression and there are no linear
                # terms. If the expression constant is zero,
                # then we have a zero term in the product
                # expression, so the entire product
                # expression becomes trivial.
                elif e_repn._constant == 0.0:
                    ampl_repn = e_repn
                    return ampl_repn

            is_nonlinear = False
            if n_linear_args > 1 or n_nonlinear_args > 0:
                is_nonlinear = True

            if is_nonlinear:
                # do like AMPL and simply return the expression
                # without extracting the potentially linear part
                ampl_repn = AmplRepn()
                ampl_repn._nonlinear_expr = exp
                for repn in arg_repns:
                    ampl_repn._nonlinear_vars.update(repn._linear_vars)
                    ampl_repn._nonlinear_vars.update(repn._nonlinear_vars)
                return ampl_repn

            # is linear or constant
            ampl_repn = current_repn = arg_repns[0]
            for i in xrange(1,len(arg_repns)):
                e_repn = arg_repns[i]
                ampl_repn = AmplRepn()

                # const_c * const_e
                ampl_repn._constant = current_repn._constant * e_repn._constant

                # const_e * L_c
                if e_repn._constant != 0.0:
                    for (var_ID, var) in iteritems(current_repn._linear_vars):
                        ampl_repn._linear_terms_coef[var_ID] = \
                            current_repn._linear_terms_coef[var_ID] * \
                            e_repn._constant
                    ampl_repn._linear_vars.update(current_repn._linear_vars)

                # const_c * L_e
                if current_repn._constant != 0.0:
                    for (e_var_ID,e_var) in iteritems(e_repn._linear_vars):
                        if e_var_ID in ampl_repn._linear_vars:
                            ampl_repn._linear_terms_coef[e_var_ID] += \
                                current_repn._constant * \
                                e_repn._linear_terms_coef[e_var_ID]
                        else:
                            ampl_repn._linear_terms_coef[e_var_ID] = \
                                current_repn._constant * \
                                e_repn._linear_terms_coef[e_var_ID]
                    ampl_repn._linear_vars.update(e_repn._linear_vars)
                current_repn = ampl_repn

            return ampl_repn

        elif _using_pyomo4_trees and (exp_type is Expr._DivisionExpression):
            # It is assumed this is a binary operator
            # (numerator=args[0], denominator=args[1])
            assert len(exp._args) == 2

            #
            # Check the denominator, if it is not constant,
            # then simply return this expression.
            #
            numerator, denominator = exp._args
            if not is_fixed(denominator):
                ampl_repn._nonlinear_expr = exp
                # we have a nonlinear expression, so build up all the vars
                for e in exp._args:
                    arg_repn = _generate_ampl_repn(e)
                    ampl_repn._nonlinear_vars.update(arg_repn._linear_vars)
                    ampl_repn._nonlinear_vars.update(arg_repn._nonlinear_vars)
                return ampl_repn

            denominator = value(denominator)
            if denominator == 0:
                raise ZeroDivisionError(
                    "Divide-by-zero error - offending sub-expression: "+str(exp._args[1]))

            #
            # OK, the denominator is a constant.
            #

            # build up the ampl_repn for the numerator
            ampl_repn = _generate_ampl_repn(numerator)
            # if the numerator is nonlinear, the whole expression remains nonlinear
            if ampl_repn._nonlinear_expr is not None:
                # do like AMPL and simply return the expression
                # without extracting the potentially linear part
                # (be sure to set this to the original expression,
                # not just the numerators)
                ampl_repn._nonlinear_expr = exp
                return ampl_repn

            #
            # OK, we have a linear numerator with a constant denominator
            #

            # update any constants and coefficients by dividing
            # by the fixed denominator
            ampl_repn._constant /= denominator
            for var_ID in ampl_repn._linear_terms_coef:
                ampl_repn._linear_terms_coef[var_ID] /= denominator

            return ampl_repn

        elif _using_pyomo4_trees and (exp_type is Expr._NegationExpression):
            assert len(exp._args) == 1
            ampl_repn = _generate_ampl_repn(exp._args[0])
            if ampl_repn._nonlinear_expr is not None:
                # do like AMPL and simply return the expression
                # without extracting the potentially linear part
                ampl_repn._nonlinear_expr = exp
                return ampl_repn

            # this subexpression is linear, so update any
            # constants and coefficients by negating them
            ampl_repn._constant *= -1
            for var_ID in ampl_repn._linear_terms_coef:
                ampl_repn._linear_terms_coef[var_ID] *= -1

            return ampl_repn

        #
        # Power Expressions
        #
        elif exp_type is Expr._PowExpression:
            assert(len(exp._args) == 2)
            base_repn = _generate_ampl_repn(exp._args[0])
            base_repn_fixed = base_repn.is_fixed()
            exponent_repn = _generate_ampl_repn(exp._args[1])
            exponent_repn_fixed = exponent_repn.is_fixed()

            if base_repn_fixed and exponent_repn_fixed:
                ampl_repn._constant = base_repn._constant**exponent_repn._constant
            elif exponent_repn_fixed and exponent_repn._constant == 1.0:
                ampl_repn = base_repn
            elif exponent_repn_fixed and exponent_repn._constant == 0.0:
                ampl_repn._constant = 1.0
            else:
                # instead, let's just return the expression we are given and only
                # use the ampl_repn for the vars
                ampl_repn._nonlinear_expr = exp
                ampl_repn._nonlinear_vars = base_repn._nonlinear_vars
                ampl_repn._nonlinear_vars.update(exponent_repn._nonlinear_vars)
                ampl_repn._nonlinear_vars.update(base_repn._linear_vars)
                ampl_repn._nonlinear_vars.update(exponent_repn._linear_vars)
            return ampl_repn

        #
        # External Functions
        #
        elif exp_type is Expr._ExternalFunctionExpression:
            # In theory, if the arguments are all fixed, we can simply evaluate this expression
            if exp.is_fixed():
                ampl_repn._constant = value(exp)
                return ampl_repn

            # this is inefficient since it is using much more than what we need
            ampl_repn._nonlinear_expr = exp
            for arg in exp._args:
                if isinstance(arg, basestring):
                    continue
                child_repn = _generate_ampl_repn(arg)
                ampl_repn._nonlinear_vars.update(child_repn._nonlinear_vars)
                ampl_repn._nonlinear_vars.update(child_repn._linear_vars)
            return ampl_repn

        #
        # Intrinsic Functions
        #
        elif isinstance(exp,Expr._IntrinsicFunctionExpression):
            # Checking isinstance above accounts for the fact that _AbsExpression
            # is a subclass of _IntrinsicFunctionExpression
            assert(len(exp._args) == 1)

            # the argument is fixed, we can simply evaluate this expression
            if exp._args[0].is_fixed():
                ampl_repn._constant = value(exp)
                return ampl_repn

            # this is inefficient since it is using much more than what we need
            child_repn = _generate_ampl_repn(exp._args[0])

            ampl_repn._nonlinear_expr = exp
            ampl_repn._nonlinear_vars = child_repn._nonlinear_vars
            ampl_repn._nonlinear_vars.update(child_repn._linear_vars)
            return ampl_repn

        #
        # AMPL-style If-Then-Else expression
        #
        elif exp_type is Expr.Expr_if:
            if exp._if.is_fixed():
                if exp._if():
                    ampl_repn = _generate_ampl_repn(exp._then)
                else:
                    ampl_repn = _generate_ampl_repn(exp._else)
            else:
                if_repn = _generate_ampl_repn(exp._if)
                then_repn = _generate_ampl_repn(exp._then)
                else_repn = _generate_ampl_repn(exp._else)
                ampl_repn._nonlinear_expr = exp
                ampl_repn._nonlinear_vars = if_repn._nonlinear_vars
                ampl_repn._nonlinear_vars.update(then_repn._nonlinear_vars)
                ampl_repn._nonlinear_vars.update(else_repn._nonlinear_vars)
                ampl_repn._nonlinear_vars.update(if_repn._linear_vars)
                ampl_repn._nonlinear_vars.update(then_repn._linear_vars)
                ampl_repn._nonlinear_vars.update(else_repn._linear_vars)
            return ampl_repn
        elif (exp_type is Expr._InequalityExpression) or \
             (exp_type is Expr._EqualityExpression):
            for arg in exp._args:
                arg_repn = _generate_ampl_repn(arg)
                ampl_repn._nonlinear_vars.update(arg_repn._nonlinear_vars)
                ampl_repn._nonlinear_vars.update(arg_repn._linear_vars)
            ampl_repn._nonlinear_expr = exp
            return ampl_repn
        elif exp.is_fixed():
            ampl_repn._constant = value(exp)
            return ampl_repn

        #
        # Expression (the component)
        #
        elif isinstance(exp, _ExpressionData):
            ampl_repn = _generate_ampl_repn(exp.expr)
            return ampl_repn

        #
        # ERROR
        #
        else:
            raise ValueError("Unsupported expression type: "+str(type(exp))+" ("+str(exp)+")")

    #
    # Constant
    #
    elif exp.is_fixed():
        ### GAH: Why were we even checking this
        #if not exp.value.__class__ in native_numeric_types:
        #    ampl_repn = _generate_ampl_repn(exp.value)
        #    return ampl_repn
        ampl_repn._constant = value(exp)
        return ampl_repn

    #
    # Variable
    #
    elif isinstance(exp, _VarData):
        if exp.fixed:
            ampl_repn._constant = exp.value
            return ampl_repn
        var_ID = id(exp)
        ampl_repn._linear_terms_coef[var_ID] = 1.0
        ampl_repn._linear_vars[var_ID] = exp
        return ampl_repn

    #
    # ERROR
    #
    else:
        raise ValueError("Unexpected expression type: "+str(exp))

Example 34

Project: pyomo
Source File: matrix.py
View license
def compile_block_linear_constraints(parent_block,
                                     constraint_name,
                                     skip_trivial_constraints=False,
                                     single_precision_storage=False,
                                     verbose=False,
                                     descend_into=True):

    if verbose:
        print("")
        print("Compiling linear constraints on block with name: %s"
              % (parent_block.name))

    if not parent_block.is_constructed():
        raise RuntimeError(
            "Attempting to compile block '%s' with unconstructed "
            "component(s)" % (parent_block.name))

    #
    # Linear MatrixConstraint in CSR format
    #
    SparseMat_pRows = []
    SparseMat_jCols = []
    SparseMat_Vals = []
    Ranges = []
    RangeTypes = []

    def _get_bound(exp):
        if exp is None:
            return None
        if is_fixed(exp):
            return value(exp)
        raise ValueError("non-fixed bound: " + str(exp))

    start_time = time.time()
    if verbose:
        print("Sorting active blocks...")

    sortOrder = SortComponents.indices | SortComponents.alphabetical
    all_blocks = [_b for _b in parent_block.block_data_objects(
        active=True,
        sort=sortOrder,
        descend_into=descend_into)]

    stop_time = time.time()
    if verbose:
        print("Time to sort active blocks: %.2f seconds"
              % (stop_time-start_time))

    start_time = time.time()
    if verbose:
        print("Collecting variables on active blocks...")

    #
    # First Pass: assign each variable a deterministic id
    #             (an index in a list)
    #
    VarSymbolToVarObject = []
    for block in all_blocks:
        VarSymbolToVarObject.extend(
            block.component_data_objects(Var,
                                         sort=sortOrder,
                                         descend_into=False))
    VarIDToVarSymbol = \
        dict((id(vardata), index)
             for index, vardata in enumerate(VarSymbolToVarObject))

    stop_time = time.time()
    if verbose:
        print("Time to collect variables on active blocks: %.2f seconds"
              % (stop_time-start_time))

    start_time = time.time()
    if verbose:
        print("Compiling active linear constraints...")

    #
    # Second Pass: collect and remove active linear constraints
    #
    constraint_data_to_remove = []
    empty_constraint_containers_to_remove = []
    constraint_containers_to_remove = []
    constraint_containers_to_check = set()
    referenced_variable_symbols = set()
    nnz = 0
    nrows = 0
    SparseMat_pRows = [0]
    for block in all_blocks:

        if hasattr(block, '_canonical_repn'):
            del block._canonical_repn
        if hasattr(block, '_ampl_repn'):
            del block._ampl_repn

        for constraint in block.component_objects(Constraint,
                                                  active=True,
                                                  sort=sortOrder,
                                                  descend_into=False):

            assert not isinstance(constraint, MatrixConstraint)

            if len(constraint) == 0:

                empty_constraint_containers_to_remove.append((block, constraint))

            else:

                singleton = isinstance(constraint, SimpleConstraint)

                for index, constraint_data in iteritems(constraint):

                    if constraint_data.body.polynomial_degree() <= 1:

                        # collect for removal
                        if singleton:
                            constraint_containers_to_remove.append((block, constraint))
                        else:
                            constraint_data_to_remove.append((constraint, index))
                            constraint_containers_to_check.add((block, constraint))

                        canonical_repn = generate_canonical_repn(constraint_data.body)

                        assert isinstance(canonical_repn, LinearCanonicalRepn)

                        row_variable_symbols = []
                        row_coefficients = []
                        if canonical_repn.variables is None:
                            if skip_trivial_constraints:
                                continue
                        else:
                            row_variable_symbols = \
                                [VarIDToVarSymbol[id(vardata)]
                                 for vardata in canonical_repn.variables]
                            referenced_variable_symbols.update(
                                row_variable_symbols)
                            assert canonical_repn.linear is not None
                            row_coefficients = canonical_repn.linear

                        SparseMat_pRows.append(SparseMat_pRows[-1] + \
                                               len(row_variable_symbols))
                        SparseMat_jCols.extend(row_variable_symbols)
                        SparseMat_Vals.extend(row_coefficients)

                        nnz += len(row_variable_symbols)
                        nrows += 1

                        L = _get_bound(constraint_data.lower)
                        U = _get_bound(constraint_data.upper)
                        constant = value(canonical_repn.constant)
                        if constant is None:
                            constant = 0

                        Ranges.append(L - constant if (L is not None) else 0)
                        Ranges.append(U - constant if (U is not None) else 0)
                        if (L is not None) and \
                           (U is not None) and \
                           (not constraint_data.equality):
                            RangeTypes.append(MatrixConstraint.LowerBound |
                                              MatrixConstraint.UpperBound)
                        elif constraint_data.equality:
                            RangeTypes.append(MatrixConstraint.Equality)
                        elif L is not None:
                            assert U is None
                            RangeTypes.append(MatrixConstraint.LowerBound)
                        else:
                            assert U is not None
                            RangeTypes.append(MatrixConstraint.UpperBound)

                        # Start freeing up memory
                        constraint_data.set_value(None)

    ncols = len(referenced_variable_symbols)

    stop_time = time.time()
    if verbose:
        print("Time to compile active linear constraints: %.2f seconds"
              % (stop_time-start_time))

    start_time = time.time()
    if verbose:
        print("Removing compiled constraint objects...")

    #
    # Remove compiled constraints
    #
    constraints_removed = 0
    constraint_containers_removed = 0
    for block, constraint in empty_constraint_containers_to_remove:
        block.del_component(constraint)
        constraint_containers_removed += 1
    for constraint, index in constraint_data_to_remove:
        del constraint[index]
        constraints_removed += 1
    for block, constraint in constraint_containers_to_remove:
        block.del_component(constraint)
        constraints_removed += 1
        constraint_containers_removed += 1
    for block, constraint in constraint_containers_to_check:
        if len(constraint) == 0:
            block.del_component(constraint)
            constraint_containers_removed += 1

    stop_time = time.time()
    if verbose:
        print("Eliminated %s constraints and %s Constraint container objects"
              % (constraints_removed, constraint_containers_removed))
        print("Time to remove compiled constraint objects: %.2f seconds"
              % (stop_time-start_time))

    start_time = time.time()
    if verbose:
        print("Assigning variable column indices...")

    #
    # Assign a column index to the set of referenced variables
    #
    ColumnIndexToVarSymbol = sorted(referenced_variable_symbols)
    VarSymbolToColumnIndex = dict((symbol, column)
                                  for column, symbol in enumerate(ColumnIndexToVarSymbol))
    SparseMat_jCols = [VarSymbolToColumnIndex[symbol] for symbol in SparseMat_jCols]
    del VarSymbolToColumnIndex
    ColumnIndexToVarObject = [VarSymbolToVarObject[var_symbol]
                              for var_symbol in ColumnIndexToVarSymbol]

    stop_time = time.time()
    if verbose:
        print("Time to assign variable column indices: %.2f seconds"
              % (stop_time-start_time))

    start_time = time.time()
    if verbose:
        print("Converting compiled constraint data to array storage...")
        print("  - Using %s precision for numeric values"
              % ('single' if single_precision_storage else 'double'))

    #
    # Convert to array storage
    #

    number_storage = 'f' if single_precision_storage else 'd'
    SparseMat_pRows = array.array('L', SparseMat_pRows)
    SparseMat_jCols = array.array('L', SparseMat_jCols)
    SparseMat_Vals = array.array(number_storage, SparseMat_Vals)
    Ranges = array.array(number_storage, Ranges)
    RangeTypes = array.array('B', RangeTypes)

    stop_time = time.time()
    if verbose:
        storage_bytes = \
            SparseMat_pRows.buffer_info()[1] * SparseMat_pRows.itemsize + \
            SparseMat_jCols.buffer_info()[1] * SparseMat_jCols.itemsize + \
            SparseMat_Vals.buffer_info()[1] * SparseMat_Vals.itemsize + \
            Ranges.buffer_info()[1] * Ranges.itemsize + \
            RangeTypes.buffer_info()[1] * RangeTypes.itemsize
        print("Sparse Matrix Dimension:")
        print("  - Rows: "+str(nrows))
        print("  - Cols: "+str(ncols))
        print("  - Nonzeros: "+str(nnz))
        print("Compiled Data Storage: "+str(_label_bytes(storage_bytes)))
        print("Time to convert compiled constraint data to "
              "array storage: %.2f seconds" % (stop_time-start_time))

    parent_block.add_component(constraint_name,
                               MatrixConstraint(nrows, ncols, nnz,
                                                SparseMat_pRows,
                                                SparseMat_jCols,
                                                SparseMat_Vals,
                                                Ranges,
                                                RangeTypes,
                                                ColumnIndexToVarObject))
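
The compiled constraints are stored in CSR form: SparseMat_pRows keeps a running total of nonzeros per row, SparseMat_jCols the column indices, and SparseMat_Vals the coefficients. The tiny sketch below shows only that CSR append step with a made-up coefficient map per constraint row; here six.iteritems walks those per-row maps purely for illustration, whereas the example above uses it to walk constraint containers.

import six

# Hypothetical per-row coefficient maps for two linear constraints.
rows = [{"x": 2.0, "y": -1.0}, {"y": 3.0}]
col_index = {"x": 0, "y": 1}

SparseMat_pRows, SparseMat_jCols, SparseMat_Vals = [0], [], []
for row in rows:
    for name, coef in six.iteritems(row):
        SparseMat_jCols.append(col_index[name])
        SparseMat_Vals.append(coef)
    SparseMat_pRows.append(SparseMat_pRows[-1] + len(row))  # cumulative nonzero count

print(SparseMat_pRows, SparseMat_jCols, SparseMat_Vals)   # e.g. [0, 2, 3] ...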

Example 35

Project: pyomo
Source File: canonical_repn.py
View license
def pyomo4_generate_canonical_repn(exp, idMap=None, compute_values=True):
    # A **very** special case
    if TreeWalkerHelper.typeList.get(exp.__class__,0) == 4: # _LinearExpression:
        ans = CompiledLinearCanonicalRepn()

        # old format
        ans.constant = exp._const
        ans.variables = list( exp._args )
        _l = exp._coef
        ans.linear = [_l[id(v)] for v in exp._args]

        if idMap:
            if None not in idMap:
                idMap[None] = {}
            _test = idMap[None]
            _key = len(idMap) - 1
            for v in exp._args:
                if id(v) not in _test:
                    _test[id(v)] = _key
                    idMap[_key] = v
                    _key += 1
        return ans
    else:
        degree = exp.polynomial_degree()

    if degree == 1:
        _typeList = TreeWalkerHelper.typeList
        _stackMax = len(_stack)
        _stackIdx = 0
        _stackPtr = _stack[0]

        _stackPtr[0] = exp
        try:
            _stackPtr[1] = exp._args
        except AttributeError:
            ans = CompiledLinearCanonicalRepn()
            ans.variables.append(exp)
            # until we can redefine CompiledLinearCanonicalRepn, restore
            # old format
            #ans.linear[id(exp)] = 1.
            ans.linear = [1.]
            return ans
        try:
            _stackPtr[2] = _type = _typeList[exp.__class__]
            if _stackPtr[2] == 2:
                _stackPtr[5].constant = 1.
        except KeyError:
            _stackPtr[2] = _type = 0
        _stackPtr[3] = len(_stackPtr[1])
        _stackPtr[4] = 0
        #_stackPtr[5] = CompiledLinearCanonicalRepn()

        if _type == 4: # _LinearExpression
            _stackPtr[4] = _stackPtr[3]
            _stackPtr[5].constant = exp._const
            _stackPtr[5].linear = dict(exp._coef)
            _stackPtr[5].variables = list(exp._args)

        while 1: # Note: 1 is faster than True for Python 2.x
            if _stackPtr[4] < _stackPtr[3]:
                _sub = _stackPtr[1][_stackPtr[4]]
                _stackPtr[4] += 1
                _test = _sub.__class__ in native_numeric_types
                if _test or not _sub.is_expression():
                    if not _test and _sub.is_fixed():
                        _sub = value(_sub)
                        _test = 1 # True
                    if _test:
                        if _type == 2:
                            _stackPtr[5].constant *= _sub
                            _l = _stackPtr[5].linear
                            if _l:
                                for _id in _l:
                                    _l[_id] *= _sub
                        elif _type == 1:
                            _stackPtr[5].constant += _sub
                        elif _type == 3:
                            _stackPtr[5].constant = -1. * _sub
                        else:
                            raise RuntimeError("HELP")
                    else:
                        _id = id(_sub)
                        if _type == 2:
                            _lcr = _stackPtr[5]
                            _lcr.variables.append(_sub)
                            _lcr.linear[_id] = _lcr.constant
                            _lcr.constant = 0
                        elif _type == 1:
                            if _id in _stackPtr[5].linear:
                                _stackPtr[5].linear[_id] += 1.
                            else:
                                _stackPtr[5].variables.append(_sub)
                                _stackPtr[5].linear[_id] = 1.
                        elif _type == 3:
                            _lcr = _stackPtr[5]
                            _lcr.variables.append(_sub)
                            _lcr.linear[_id] = -1.
                        else:
                            raise RuntimeError("HELP")
                else:
                    _stackIdx += 1
                    if _stackMax == _stackIdx:
                        _stackMax += 1
                        _stack.append([0,0,0,0,0, CompiledLinearCanonicalRepn()])
                    _stackPtr = _stack[_stackIdx]

                    _stackPtr[0] = _sub
                    _stackPtr[1] = _sub._args
                    #_stackPtr[2] = _type = _typeList.get(_sub.__class__, 0)
                    #if _type == 2:
                    #    _stackPtr[5].constant = 1.
                    try:
                        _stackPtr[2] = _type = _typeList[_sub.__class__]
                        if _type == 2:
                            _stackPtr[5].constant = 1.
                    except KeyError:
                        _stackPtr[2] = _type = 0
                    _stackPtr[3] = len(_stackPtr[1])
                    _stackPtr[4] = 0
                    #_stackPtr[5] = CompiledLinearCanonicalRepn()

                    if _type == 4: # _LinearExpression
                        _stackPtr[4] = _stackPtr[3]
                        _stackPtr[5].constant = _sub._const
                        _stackPtr[5].linear = dict(_sub._coef)
                        _stackPtr[5].variables = list(_sub._args)
            else:
                old = _stackPtr[5]
                if not _type:
                    old.constant = _stackPtr[0]._apply_operation(old.variables)
                    old.variables = []

                if _stackIdx == 0:
                    ans = CompiledLinearCanonicalRepn()
                    ans.variables, old.variables = old.variables, ans.variables
                    ans.linear, old.linear = old.linear, ans.linear
                    ans.constant, old.constant = old.constant, ans.constant
                    # until we can redefine CompiledLinearCanonicalRepn, restore
                    # old format
                    ans.linear = [ans.linear[id(v)] for v in ans.variables]

                    if idMap:
                        if None not in idMap:
                            idMap[None] = {}
                        _test = idMap[None]
                        _key = len(idMap) - 1
                        for v in ans.variables:
                            if id(v) not in _test:
                                _test[id(v)] = _key
                                idMap[_key] = v
                                _key += 1
                    return ans

                _stackIdx -= 1
                _stackPtr = _stack[_stackIdx]
                new = _stackPtr[5]
                _type = _stackPtr[2]
                if _type == 1:
                    new.constant += old.constant
                    _nl = new.linear
                    # Note: append the variables in the order that they
                    # were originally added to the CompiledLinearCanonicalRepn.
                    # This keeps things deterministic.
                    for v in old.variables:
                        _id = id(v)
                        if _id in _nl:
                            _nl[_id] += old.linear[_id]
                        else:
                            new.variables.append(v)
                            _nl[_id] = old.linear[_id]
                    old.constant = 0.
                    old.variables = []
                    old.linear = {}
                elif _type == 2:
                    if old.variables:
                        old.variables, new.variables = new.variables, old.variables
                        old.linear, new.linear = new.linear, old.linear
                        old.constant, new.constant = new.constant, old.constant
                    _c = old.constant
                    new.constant *= _c
                    _nl = new.linear
                    for _id in _nl:
                        _nl[_id] *= _c
                    old.constant = 0.
                elif _type == 3:
                    old.variables, new.variables = new.variables, old.variables
                    old.linear, new.linear = new.linear, old.linear
                    new.constant = -1 * old.constant
                    old.constant = 0.
                    _nl = new.linear
                    for _id in _nl:
                        _nl[_id] *= -1
                else:
                    raise RuntimeError("HELP")

    elif degree == 0:
        if CompiledLinearCanonicalRepn_Pool:
            ans = CompiledLinearCanonicalRepn_Pool.pop()
            ans.__init__()
        else:
            ans = CompiledLinearCanonicalRepn()
        ans.constant = value(exp)
        return ans

    # **Py3k: degree > 1 comparison will error if degree is None
    elif degree and degree > 1:
        ans = collect_general_canonical_repn(exp, idMap, compute_values)
        if 1 in ans:
            linear_terms = {}
            for key, coef in iteritems(ans[1]):
                linear_terms[list(key.keys())[0]] = coef
            ans[1] = linear_terms
        return GeneralCanonicalRepn(ans)
    else:
        return GeneralCanonicalRepn(
            { None: exp, -1 : collect_variables(exp, idMap) } )
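
In the degree > 1 branch just above, each key of ans[1] is a one-variable mapping, and iteritems flattens it into a plain {variable_key: coefficient} dict. The small self-contained sketch below shows that flattening; the _HashableMap class is only a hypothetical stand-in for the hashable keys the general canonical representation actually uses.

import six

class _HashableMap(dict):
    """Hypothetical hashable one-variable key, standing in for pyomo's keys."""
    def __hash__(self):
        return hash(tuple(sorted(self.items())))

# Hypothetical ans[1]: one-variable keys mapped to their coefficients.
ans_1 = {_HashableMap({7: 1}): 2.5, _HashableMap({9: 1}): -4.0}

linear_terms = {}
for key, coef in six.iteritems(ans_1):
    linear_terms[list(key.keys())[0]] = coef
print(linear_terms)   # e.g. {7: 2.5, 9: -4.0}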

Example 36

Project: pyomo
Source File: baron_writer.py
View license
    def __call__(self,
                 model,
                 output_filename,
                 solver_capability,
                 io_options):

        # Make sure not to modify the user's dictionary, they may be
        # reusing it outside of this call
        io_options = dict(io_options)

        # NOTE: io_options is a simple dictionary of keyword-value
        #       pairs specific to this writer.
        symbolic_solver_labels = \
            io_options.pop("symbolic_solver_labels", False)
        labeler = io_options.pop("labeler", None)

        # How much effort do we want to put into ensuring the
        # BAR file is written deterministically for a Pyomo model:
        #    0 : None
        #    1 : sort keys of indexed components (default)
        #    2 : sort keys AND sort names (over declaration order)
        file_determinism = io_options.pop("file_determinism", 1)

        sorter = SortComponents.unsorted
        if file_determinism >= 1:
            sorter = sorter | SortComponents.indices
            if file_determinism >= 2:
                sorter = sorter | SortComponents.alphabetical

        # TODO
        #output_fixed_variable_bounds = \
        #    io_options.pop("output_fixed_variable_bounds", False)

        # Skip writing constraints whose body section is fixed (i.e.,
        # no variables)
        skip_trivial_constraints = \
            io_options.pop("skip_trivial_constraints", False)

        # Note: Baron does not allow runtime options to be specified
        #       outside of this file, so we add support for them here
        solver_options = io_options.pop("solver_options", {})

        if len(io_options):
            raise ValueError(
                "ProblemWriter_baron_writer passed unrecognized io_options:\n\t" +
                "\n\t".join("%s = %s" % (k,v) for k,v in iteritems(io_options)))

        if symbolic_solver_labels and (labeler is not None):
            raise ValueError("Baron problem writer: Using both the "
                             "'symbolic_solver_labels' and 'labeler' "
                             "I/O options is forbidden")

        if output_filename is None:
            output_filename = model.name + ".bar"

        output_file=open(output_filename, "w")

        # Process the options. Rely on baron to catch
        # and reset bad option values
        output_file.write("OPTIONS {\n")
        summary_found = False
        if len(solver_options):
            for key, val in iteritems(solver_options):
                if (key.lower() == 'summary'):
                    summary_found = True
                if key.endswith("Name"):
                    output_file.write(key+": \""+str(val)+"\";\n")
                else:
                    output_file.write(key+": "+str(val)+";\n")
        if not summary_found:
            # The 'summary' option defaults to 0 so that no summary
            # file is generated in the directory where the user calls
            # baron, unless the user explicitly asked for one above.
            output_file.write("Summary: 0;\n")
        output_file.write("}\n\n")

        if symbolic_solver_labels:
            labeler = AlphaNumTextLabeler()
        elif labeler is None:
            labeler = NumericLabeler('x')

        symbol_map = SymbolMap()
        sm_bySymbol = symbol_map.bySymbol
        referenced_variable_ids = set()

        #cache frequently called functions
        create_symbol_func = SymbolMap.createSymbol
        create_symbols_func = SymbolMap.createSymbols
        alias_symbol_func = SymbolMap.alias

        # Cache the list of model blocks so we don't have to call
        # model.block_data_objects() many many times, which is slow
        # for indexed blocks
        all_blocks_list = list(model.block_data_objects(active=True,
                                                        sort=sorter,
                                                        descend_into=True))
        active_components_data_var = {}
        for block in all_blocks_list:
            tmp = active_components_data_var[id(block)] = \
                  list(obj for obj in block.component_data_objects(Var,
                                                                   active=True,
                                                                   sort=sorter,
                                                                   descend_into=False))
            create_symbols_func(symbol_map, tmp, labeler)

            # GAH: Not sure this is necessary, and also it would break for
            #      non-mutable indexed params so I am commenting out for now.
            #for param_data in active_components_data(block, Param, sort=sorter):
                #instead of checking if param_data._mutable:
                #if not param_data.is_constant():
                #    create_symbol_func(symbol_map, param_data, labeler)

        symbol_map_variable_ids = set(symbol_map.byObject.keys())
        object_symbol_dictionary = symbol_map.byObject

        def _skip_trivial(constraint_data):
            if skip_trivial_constraints:
                if isinstance(constraint_data, LinearCanonicalRepn):
                    if constraint_data.variables is None:
                        return True
                else:
                    if constraint_data.body.polynomial_degree() == 0:
                        return True
            return False

        #
        # Check for active suffixes to export
        #
        r_o_eqns = []
        c_eqns = []
        l_eqns = []
        branching_priorities_suffixes = []
        for block in all_blocks_list:
            for name, suffix in active_export_suffix_generator(block):
                if name == 'branching_priorities':
                    branching_priorities_suffixes.append(suffix)
                elif name == 'constraint_types':
                    for constraint_data, constraint_type in iteritems(suffix):
                        if not _skip_trivial(constraint_data):
                            if constraint_type.lower() == 'relaxationonly':
                                r_o_eqns.append(constraint_data)
                            elif constraint_type.lower() == 'convex':
                                c_eqns.append(constraint_data)
                            elif constraint_type.lower() == 'local':
                                l_eqns.append(constraint_data)
                            else:
                                raise ValueError(
                                    "A suffix '%s' contained an invalid value: %s\n"
                                    "Choices are: [relaxationonly, convex, local]"
                                    % (suffix.name, constraint_type))
                else:
                    raise ValueError(
                        "The BARON writer can not export suffix with name '%s'. "
                        "Either remove it from block '%s' or deactivate it."
                        % (block.name, name))

        non_standard_eqns = r_o_eqns + c_eqns + l_eqns

        # GAH 1/5/15: Substituting all non-alphanumeric characters for underscore
        #             in labeler so this manual update should no longer be needed
        #
        # If the text labeler is used, correct the labels to be
        # baron-allowed variable names
        # Change '(' and ')' to '__'
        # This way, for simple variable names like 'x(1_2)' --> 'x__1_2__'
        # FIXME: 7/21/14 This may break if users give variable names
        #        with two or more underscores together
        #if symbolic_solver_labels:
        #    for key,label in iteritems(object_symbol_dictionary):
        #        label = label.replace('(','___')
        #        object_symbol_dictionary[key] = label.replace(')','__')

        #
        # BINARY_VARIABLES, INTEGER_VARIABLES, POSITIVE_VARIABLES, VARIABLES
        #

        BinVars = []
        IntVars = []
        PosVars = []
        Vars = []
        for block in all_blocks_list:
            for var_data in active_components_data_var[id(block)]:

                if isinstance(var_data.domain, BooleanSet):
                    TypeList = BinVars
                elif isinstance(var_data.domain, IntegerSet):
                    TypeList = IntVars
                elif isinstance(var_data.domain, RealSet) and \
                     (var_data.lb is not None) and \
                     (var_data.lb >= 0):
                    TypeList = PosVars
                else:
                    TypeList = Vars

                var_name = object_symbol_dictionary[id(var_data)]
                #if len(var_name) > 15:
                #    logger.warning(
                #        "Variable symbol '%s' for variable %s exceeds maximum "
                #        "character limit for BARON. Solver may fail"
                #        % (var_name, var_data.name))

                TypeList.append(var_name)

        if len(BinVars) > 0:
            output_file.write('BINARY_VARIABLES ')
            for var_name in BinVars[:-1]:
                output_file.write(str(var_name)+', ')
            output_file.write(str(BinVars[-1])+';\n\n')
        if len(IntVars) > 0:
            output_file.write('INTEGER_VARIABLES ')
            for var_name in IntVars[:-1]:
                output_file.write(str(var_name)+', ')
            output_file.write(str(IntVars[-1])+';\n\n')

        output_file.write('POSITIVE_VARIABLES ')
        output_file.write('ONE_VAR_CONST__')
        for var_name in PosVars:
            output_file.write(', '+str(var_name))
        output_file.write(';\n\n')

        if len(Vars) > 0:
            output_file.write('VARIABLES ')
            for var_name in Vars[:-1]:
                output_file.write(str(var_name)+', ')
            output_file.write(str(Vars[-1])+';\n\n')

        #
        # LOWER_BOUNDS
        #

        LowerBoundHeader = False
        for block in all_blocks_list:
            for var_data in active_components_data_var[id(block)]:
                if var_data.fixed:
                    var_data_lb = var_data.value
                else:
                    var_data_lb = var_data.lb
                    if var_data_lb == -infinity:
                        var_data_lb = None

                if var_data_lb is not None:
                    if LowerBoundHeader is False:
                        output_file.write("LOWER_BOUNDS{\n")
                        LowerBoundHeader = True
                    name_to_output = object_symbol_dictionary[id(var_data)]
                    lb_string_template = '%s: %'+self._precision_string+';\n'
                    output_file.write(lb_string_template
                                      % (name_to_output, var_data_lb))

        if LowerBoundHeader:
            output_file.write("}\n\n")

        #
        # UPPER_BOUNDS
        #

        UpperBoundHeader = False
        for block in all_blocks_list:
            for var_data in active_components_data_var[id(block)]:
                if var_data.fixed:
                    var_data_ub = var_data.value
                else:
                    var_data_ub = var_data.ub
                    if var_data_ub == infinity:
                        var_data_ub = None

                if var_data_ub is not None:
                    if UpperBoundHeader is False:
                        output_file.write("UPPER_BOUNDS{\n")
                        UpperBoundHeader = True
                    name_to_output = object_symbol_dictionary[id(var_data)]
                    ub_string_template = '%s: %'+self._precision_string+';\n'
                    output_file.write(ub_string_template
                                      % (name_to_output, var_data_ub))

        if UpperBoundHeader:
            output_file.write("}\n\n")

        #
        # BRANCHING_PRIORITIES
        #

        # Specifying priorities requires that the Pyomo model has established an
        # EXTERNAL, float suffix called 'branching_priorities' on the model
        # object, indexed by the relevant variable
        BranchingPriorityHeader = False
        for suffix in branching_priorities_suffixes:
            for var_data, priority in iteritems(suffix):
                if priority is not None:
                    if not BranchingPriorityHeader:
                        output_file.write('BRANCHING_PRIORITIES{\n')
                        BranchingPriorityHeader = True
                    name_to_output = object_symbol_dictionary[id(var_data)]
                    output_file.write(name_to_output+': '+str(priority)+';\n')

        if BranchingPriorityHeader:
            output_file.write("}\n\n")

        #
        # EQUATIONS
        #

        #Equation Declaration
        n_roeqns = len(r_o_eqns)
        n_ceqns = len(c_eqns)
        n_leqns = len(l_eqns)
        eqns = []

        # Alias the constraints by declaration order since Baron does not
        # include the constraint names in the solution file. It is important
        # that this alias not clash with any real constraint labels, hence
        # the use of the ".c<integer>" template. It is not possible to declare
        # a component having this type of name when using standard syntax.
        # There are ways to do it, but it is unlikely someone will.
        order_counter = 0
        alias_template = ".c%d"
        output_file.write('EQUATIONS ')
        output_file.write("c_e_FIX_ONE_VAR_CONST__")
        order_counter += 1
        for block in all_blocks_list:

            for constraint_data in block.component_data_objects(Constraint,
                                                                active=True,
                                                                sort=sorter,
                                                                descend_into=False):

                if (not _skip_trivial(constraint_data)) and \
                   (constraint_data not in non_standard_eqns):

                    eqns.append(constraint_data)

                    con_symbol = \
                        create_symbol_func(symbol_map, constraint_data, labeler)
                    assert not con_symbol.startswith('.')
                    assert con_symbol != "c_e_FIX_ONE_VAR_CONST__"

                    alias_symbol_func(symbol_map,
                                      constraint_data,
                                      alias_template % order_counter)
                    output_file.write(", "+str(con_symbol))
                    order_counter += 1

        output_file.write(";\n\n")

        if n_roeqns > 0:
            output_file.write('RELAXATION_ONLY_EQUATIONS ')
            for i, constraint_data in enumerate(r_o_eqns):
                con_symbol = create_symbol_func(symbol_map, constraint_data, labeler)
                assert not con_symbol.startswith('.')
                assert con_symbol != "c_e_FIX_ONE_VAR_CONST__"
                alias_symbol_func(symbol_map,
                                  constraint_data,
                                  alias_template % order_counter)
                if i == n_roeqns-1:
                    output_file.write(str(con_symbol)+';\n\n')
                else:
                    output_file.write(str(con_symbol)+', ')
                order_counter += 1

        if n_ceqns > 0:
            output_file.write('CONVEX_EQUATIONS ')
            for i, constraint_data in enumerate(c_eqns):
                con_symbol = create_symbol_func(symbol_map, constraint_data, labeler)
                assert not con_symbol.startswith('.')
                assert con_symbol != "c_e_FIX_ONE_VAR_CONST__"
                alias_symbol_func(symbol_map,
                                  constraint_data,
                                  alias_template % order_counter)
                if i == n_ceqns-1:
                    output_file.write(str(con_symbol)+';\n\n')
                else:
                    output_file.write(str(con_symbol)+', ')
                order_counter += 1

        if n_leqns > 0:
            output_file.write('LOCAL_EQUATIONS ')
            for i, constraint_data in enumerate(l_eqns):
                con_symbol = create_symbol_func(symbol_map, constraint_data, labeler)
                assert not con_symbol.startswith('.')
                assert con_symbol != "c_e_FIX_ONE_VAR_CONST__"
                alias_symbol_func(symbol_map,
                                  constraint_data,
                                  alias_template % order_counter)
                if i == n_leqns-1:
                    output_file.write(str(con_symbol)+';\n\n')
                else:
                    output_file.write(str(con_symbol)+', ')
                order_counter += 1

        # Create a dictionary of baron variable names to match to the
        # strings that constraint.to_string() prints. An important
        # note is that the variable strings are padded by spaces so
        # that whole variable names are recognized, and simple
        # variable names are not identified inside longer names.
        # Example: ' x[1] ' -> ' x3 '
        #FIXME: 7/18/14 CLH: This may cause mistakes if spaces in
        #                    variable names are allowed
        vstring_to_bar_dict = {}
        pstring_to_bar_dict = {}
        for block in all_blocks_list:

            for var_data in active_components_data_var[id(block)]:
                variable_stream = StringIO()
                var_data.to_string(ostream=variable_stream, verbose=False)
                variable_string = variable_stream.getvalue()

                variable_string = ' '+variable_string+' '
                vstring_to_bar_dict[variable_string] = \
                    ' '+object_symbol_dictionary[id(var_data)]+' '

            for param in block.component_objects(Param, active=True):
                if param._mutable and param.is_indexed():
                    param_data_iter = \
                        (param_data for index, param_data in iteritems(param))
                elif not param.is_indexed():
                    param_data_iter = iter([param])
                else:
                    param_data_iter = iter([])

                for param_data in param_data_iter:
                    param_stream = StringIO()
                    param.to_string(ostream=param_stream, verbose=False)
                    param_string = param_stream.getvalue()

                    param_string = ' '+param_string+' '
                    pstring_to_bar_dict[param_string] = ' '+str(param_data())+' '

        # Equation Definition
        string_template = '%'+self._precision_string
        output_file.write('c_e_FIX_ONE_VAR_CONST__:  ONE_VAR_CONST__  == 1;\n');
        for constraint_data in itertools.chain(eqns,
                                               r_o_eqns,
                                               c_eqns,
                                               l_eqns):

            #########################
            #CLH: The section below is kind of a hacky way to use
            #     the expr.to_string function to print
            #     expressions. A stream is created, written to, and
            #     then the string is recovered and stored in
            #     eqn_body. Then the variable names are converted
            #     to match the variable names that are used in the
            #     bar file.

            # Fill in the body of the equation
            body_string_buffer = StringIO()

            constraint_data.body.to_string(ostream=body_string_buffer,
                                           verbose=False)
            eqn_body = body_string_buffer.getvalue()

            # First, pad the equation so that a variable name at the
            # start or end of the equation is still surrounded by
            # spaces and can be identified.

            # Second, change Pyomo's ** to Baron's ^, also with
            # padding so that variables can always be found with
            # spaces around them

            # Third, add more padding around multiplication. Pyomo
            # already puts spaces around variable-on-variable
            # multiplication, but not around constants multiplying variables
            eqn_body = ' '+eqn_body+' '
            eqn_body = eqn_body.replace('**',' ^ ')
            eqn_body = eqn_body.replace('*', ' * ')


            #
            # FIXME: The following block of code is extremely inefficient.
            #        We are looping through every parameter and variable in
            #        the model each time we write a constraint expression.
            #
            ################################################
            vnames = [(variable_string, bar_string)
                      for variable_string, bar_string in iteritems(vstring_to_bar_dict)
                      if variable_string in eqn_body]
            for variable_string, bar_string in vnames:
                eqn_body = eqn_body.replace(variable_string, bar_string)
            for param_string, bar_string in iteritems(pstring_to_bar_dict):
                eqn_body = eqn_body.replace(param_string, bar_string)
            referenced_variable_ids.update(
                id(sm_bySymbol[bar_string.strip()]())
                for variable_string, bar_string in vnames)
            ################################################

            if len(vnames) == 0:
                assert not skip_trivial_constraints
                eqn_body += "+ 0 * ONE_VAR_CONST__ "

            # 7/29/14 CLH:
            #FIXME: Baron doesn't handle many of the
            #       intrinsic_functions available in pyomo. The
            #       error message given by baron is also very
            #       weak.  Either a function here to re-write
            #       unallowed expressions or a way to track solver
            #       capability by intrinsic_expression would be
            #       useful.
            ##########################

            con_symbol = object_symbol_dictionary[id(constraint_data)]
            output_file.write(str(con_symbol) + ': ')

            # Fill in the left and right hand side (constants) of
            #  the equations

            # Equality constraint
            if constraint_data.equality:
                eqn_lhs = ''
                eqn_rhs = ' == ' + \
                          str(string_template
                              % self._get_bound(constraint_data.upper))

            # Greater than constraint
            elif constraint_data.upper is None:
                eqn_rhs = ' >= ' + \
                          str(string_template
                              % self._get_bound(constraint_data.lower))
                eqn_lhs = ''

            # Less than constraint
            elif constraint_data.lower is None:
                eqn_rhs = ' <= ' + \
                          str(string_template
                              % self._get_bound(constraint_data.upper))
                eqn_lhs = ''

            # Double-sided constraint
            elif (constraint_data.upper is not None) and \
                 (constraint_data.lower is not None):
                eqn_lhs = str(string_template
                              % self._get_bound(constraint_data.lower)) + \
                          ' <= '
                eqn_rhs = ' <= ' + \
                          str(string_template
                              % self._get_bound(constraint_data.upper))

            eqn_string = eqn_lhs + eqn_body + eqn_rhs + ';\n'
            output_file.write(eqn_string)

        #
        # OBJECTIVE
        #

        output_file.write("\nOBJ: ")

        n_objs = 0
        for block in all_blocks_list:

            for objective_data in block.component_data_objects(Objective,
                                                               active=True,
                                                               sort=sorter,
                                                               descend_into=False):

                n_objs += 1
                if n_objs > 1:
                    raise ValueError("The BARON writer has detected multiple active "
                                     "objective functions on model %s, but "
                                     "currently only handles a single objective."
                                     % (model.name))

                # create symbol
                create_symbol_func(symbol_map, objective_data, labeler)
                alias_symbol_func(symbol_map, objective_data, "__default_objective__")

                if objective_data.is_minimizing():
                    output_file.write("minimize ")
                else:
                    output_file.write("maximize ")

                #FIXME 7/18/14 See above, constraint writing
                #              section. Will cause problems if there
                #              are spaces in variable names
                # Similar to the constraints section above, the
                # objective is generated from the expr.to_string
                # function.
                obj_stream = StringIO()
                objective_data.expr.to_string(ostream=obj_stream, verbose=False)

                obj_string = ' '+obj_stream.getvalue()+' '
                obj_string = obj_string.replace('**',' ^ ')
                obj_string = obj_string.replace('*', ' * ')

                #
                # FIXME: The following block of code is extremely inefficient.
                #        We are looping through every parameter and variable in
                #        the model each time we write an expression.
                #
                ################################################
                vnames = [(variable_string, bar_string)
                          for variable_string, bar_string in iteritems(vstring_to_bar_dict)
                          if variable_string in obj_string]
                for variable_string, bar_string in vnames:
                    obj_string = obj_string.replace(variable_string, bar_string)
                for param_string, bar_string in iteritems(pstring_to_bar_dict):
                    obj_string = obj_string.replace(param_string, bar_string)
                referenced_variable_ids.update(
                    id(sm_bySymbol[bar_string.strip()]())
                    for variable_string, bar_string in vnames)
                ################################################

        output_file.write(obj_string+";\n\n")

        #
        # STARTING_POINT
        #
        output_file.write('STARTING_POINT{\nONE_VAR_CONST__: 1;\n')
        string_template = '%s: %'+self._precision_string+';\n'
        for block in all_blocks_list:
            for var_data in active_components_data_var[id(block)]:
                starting_point = var_data.value
                if starting_point is not None:
                    var_name = object_symbol_dictionary[id(var_data)]
                    output_file.write(string_template % (var_name, starting_point))

        output_file.write('}\n\n')

        output_file.close()

        # Clean up the symbol map to only contain variables referenced
        # in the active constraints
        vars_to_delete = symbol_map_variable_ids - referenced_variable_ids
        sm_byObject = symbol_map.byObject
        for varid in vars_to_delete:
            symbol = sm_byObject[varid]
            del sm_byObject[varid]
            del sm_bySymbol[symbol]

        del symbol_map_variable_ids
        del referenced_variable_ids

        return output_filename, symbol_map
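
Stripped of the BARON-specific bookkeeping, every iteritems call in the writer above follows the same pattern: walk the (key, value) pairs of a plain dict and format each pair. Below is a minimal, self-contained sketch of the solver-options loop; the option names and the quoting rule for keys ending in "Name" are illustrative assumptions mirroring the writer, not part of the six API:

import six

def format_solver_options(solver_options):
    # Build "key: value;" option lines from a plain dict, quoting
    # values whose key ends in "Name" (an illustrative convention
    # borrowed from the writer above, not anything in six itself).
    lines = []
    for key, val in six.iteritems(solver_options):
        if key.endswith("Name"):
            lines.append('%s: "%s";' % (key, val))
        else:
            lines.append('%s: %s;' % (key, val))
    return "\n".join(lines)

print(format_solver_options({"MaxTime": 60, "ResName": "res.lst"}))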

Example 37

Project: pyomo
Source File: cpxlp.py
View license
    def _print_expr_canonical(self,
                              x,
                              output_file,
                              object_symbol_dictionary,
                              variable_symbol_dictionary,
                              is_objective,
                              column_order,
                              force_objective_constant=False):

        """
        Return an expression as a string in LP format.

        Note that this function does not handle any differences in LP format
        interpretation by the solvers (e.g. CPlex vs GLPK).  That decision is
        left up to the caller.

        required arguments:
          x: A Pyomo canonical expression to write in LP format
        """
        assert (not force_objective_constant) or (is_objective)

        # cache - this is referenced numerous times.
        if isinstance(x, LinearCanonicalRepn):
            var_hashes = None # not needed
        else:
            var_hashes = x[-1]

        #
        # Linear
        #
        linear_coef_string_template = '%+'+self._precision_string+' %s\n'
        if isinstance(x, LinearCanonicalRepn):

            #
            # optimization (these might be generated on the fly)
            #
            coefficients = x.linear
            if coefficients is not None:
                variables = x.variables

                # the 99% case is when the input instance is a linear
                # canonical expression, so the exception should be rare.
                for vardata in variables:
                    self._referenced_variable_ids[id(vardata)] = vardata

                if column_order is None:
                    sorted_names = [(variable_symbol_dictionary[id(variables[i])],
                                     coefficients[i])
                                    for i in xrange(0,len(coefficients))]
                    sorted_names.sort()
                else:
                    sorted_names = [(variables[i], coefficients[i])
                                    for i in xrange(0,len(coefficients))]
                    sorted_names.sort(key=lambda _x: column_order[_x[0]])
                    sorted_names = [(variable_symbol_dictionary[id(var)], coef)
                                    for var, coef in sorted_names]

                for name, coef in sorted_names:
                    output_file.write(linear_coef_string_template % (coef, name))

            elif not is_objective:
                # If we made it to here we are outputting a
                # trivial constraint; place 0 *
                # ONE_VAR_CONSTANT on this side of the
                # constraint for the benefit of solvers like
                # GLPK that cannot parse an LP file without
                # a variable on the left-hand side.
                output_file.write(linear_coef_string_template
                                  % (0, 'ONE_VAR_CONSTANT'))

        elif 1 in x:

            for var_hash in x[1]:
                vardata = var_hashes[var_hash]
                self._referenced_variable_ids[id(vardata)] = vardata

            if column_order is None:
                sorted_names = [(variable_symbol_dictionary[id(var_hashes[var_hash])],
                                 var_coefficient)
                                for var_hash, var_coefficient in iteritems(x[1])]
                sorted_names.sort()
            else:
                sorted_names = [(var_hashes[var_hash], var_coefficient)
                                for var_hash, var_coefficient in iteritems(x[1])]
                sorted_names.sort(key=lambda _x: column_order[_x[0]])
                sorted_names = [(variable_symbol_dictionary[id(var)], coef)
                                for var, coef in sorted_names]

            for name, coef in sorted_names:
                output_file.write(linear_coef_string_template % (coef, name))

        #
        # Quadratic
        #
        quad_coef_string_template = '%+'+self._precision_string+' '
        if canonical_degree(x) == 2:

            # first, make sure there is something to output
            # - it is possible for all terms to have
            # coefficients equal to 0.0, in which case you
            # don't want to get into the bracket notation at
            # all.
            # NOTE: if the coefficient is really 0.0, it
            #       should be preprocessed out by the
            #       canonical expression generator!
            found_nonzero_term = False # until proven otherwise
            for var_hash, var_coefficient in iteritems(x[2]):
                for var in var_hash:
                    vardata = var_hashes[var]

                if math.fabs(var_coefficient) != 0.0:
                    found_nonzero_term = True
                    break

            if found_nonzero_term:

                output_file.write("+ [\n")

                num_output = 0

                var_hashes_order = list(iterkeys(x[2]))
                # sort by the sorted tuple of symbols (or column assignments)
                # for the variables appearing in the term
                if column_order is None:
                    var_hashes_order.sort(
                        key=lambda term: \
                          sorted(variable_symbol_dictionary[id(var_hashes[vh])]
                                 for vh in term))
                else:
                    var_hashes_order.sort(
                        key=lambda term: sorted(column_order[var_hashes[vh]]
                                                for vh in term))

                for var_hash in var_hashes_order:

                    coefficient = x[2][var_hash]

                    if is_objective:
                        coefficient *= 2

                    # times 2 because LP format requires /2 for all the quadratic
                    # terms /of the objective only/.  Discovered the last bit through
                    # trial and error.  Obnoxious.
                    # Ref: ILOG CPLEX 8.0 User's Manual, p. 197.

                    output_file.write(quad_coef_string_template % coefficient)
                    term_variables = []

                    var_hash_order = list(iterkeys(var_hash))
                    # sort by symbols (or column assignments)
                    if column_order is None:
                        var_hash_order.sort(
                            key=lambda vh: \
                              variable_symbol_dictionary[id(var_hashes[vh])])
                    else:
                        var_hash_order.sort(
                            key=lambda vh: column_order[var_hashes[vh]])

                    # sort the term for consistent output
                    for var in var_hash_order:
                        vardata = var_hashes[var]
                        self._referenced_variable_ids[id(vardata)] = vardata
                        name = variable_symbol_dictionary[id(vardata)]
                        term_variables.append(name)

                    if len(term_variables) == 2:
                        output_file.write("%s * %s"
                                          % (term_variables[0], term_variables[1]))
                    else:
                        output_file.write("%s ^ 2" % (term_variables[0]))
                    output_file.write("\n")

                output_file.write("]")

                if is_objective:
                    output_file.write(' / 2\n')
                    # divide by 2 because LP format requires /2 for all the quadratic
                    # terms.  Weird.  Ref: ILOG CPLEX 8.0 User's Manual, p. 197
                else:
                    output_file.write("\n")

        #
        # Constant offset
        #
        if isinstance(x, LinearCanonicalRepn):
            constant = x.constant
        else:
            if 0 in x:
                constant = x[0][None]
            else:
                constant = None

        if constant is not None:
            offset = constant
        else:
            offset=0.0

        # Currently, it appears that we only need to print
        # the constant offset term for objectives.
        obj_string_template = '%+'+self._precision_string+' %s\n'
        if is_objective and (force_objective_constant or (offset != 0.0)):
            output_file.write(obj_string_template
                              % (offset, 'ONE_VAR_CONSTANT'))

        #
        # Return constant offset
        #
        return offset
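
The heart of the linear section in this example is iterating a {name: coefficient} mapping with six.iteritems and sorting the resulting pairs so that the written LP file is deterministic. Here is a small standalone sketch of that idiom, with invented variable names and an assumed precision format:

import sys

import six

def write_linear_terms(coef_by_name, out):
    # Iterate the (name, coefficient) pairs with six.iteritems and sort
    # them by name so the emitted lines do not depend on dict ordering.
    for name, coef in sorted(six.iteritems(coef_by_name)):
        out.write('%+.17g %s\n' % (coef, name))

write_linear_terms({'x2': -1.0, 'x1': 3.5}, sys.stdout)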

Example 38

Project: pyomo
Source File: CPLEX.py
View license
    def process_soln_file(self,results):

        # the only suffixes that we extract from CPLEX are
        # constraint duals, constraint slacks, and variable
        # reduced-costs. scan through the solver suffix list
        # and throw an exception if the user has specified
        # any others.
        extract_duals = False
        extract_slacks = False
        extract_reduced_costs = False
        extract_rc = False
        extract_lrc = False
        extract_urc = False
        for suffix in self._suffixes:
            flag=False
            if re.match(suffix,"dual"):
                extract_duals = True
                flag=True
            if re.match(suffix,"slack"):
                extract_slacks = True
                flag=True
            if re.match(suffix,"rc"):
                extract_reduced_costs = True
                extract_rc = True
                flag=True
            if re.match(suffix,"lrc"):
                extract_reduced_costs = True
                extract_lrc = True
                flag=True
            if re.match(suffix,"urc"):
                extract_reduced_costs = True
                extract_urc = True
                flag=True
            if not flag:
                raise RuntimeError("***The CPLEX solver plugin cannot extract solution suffix="+suffix)

        # check for existence of the solution file
        # not sure why we just return - would think that we
        # would want to indicate some sort of error
        if not os.path.exists(self._soln_file):
            return

        range_duals = {}
        range_slacks = {}
        soln = Solution()
        soln.objective['__default_objective__'] = {'Value':None}

        # caching for efficiency
        soln_variables = soln.variable
        soln_constraints = soln.constraint

        INPUT = open(self._soln_file, "r")
        results.problem.number_of_objectives=1
        time_limit_exceeded = False
        mip_problem=False
        for line in INPUT:
            line = line.strip()
            line = line.lstrip('<?/')
            line = line.rstrip('/>?')
            tokens=line.split(' ')

            if tokens[0] == "variable":
                variable_name = None
                variable_value = None
                variable_reduced_cost = None
                variable_status = None
                for i in xrange(1,len(tokens)):
                    field_name =  tokens[i].split('=')[0]
                    field_value = tokens[i].split('=')[1].lstrip("\"").rstrip("\"")
                    if field_name == "name":
                        variable_name = field_value
                    elif field_name == "value":
                        variable_value = field_value
                    elif (extract_reduced_costs is True) and (field_name == "reducedCost"):
                        variable_reduced_cost = field_value
                    elif (extract_reduced_costs is True) and (field_name == "status"):
                        variable_status = field_value

                # skip the "constant-one" variable, used to capture/retain objective offsets in the CPLEX LP format.
                if variable_name != "ONE_VAR_CONSTANT":
                    variable = soln_variables[variable_name] = {"Value" : float(variable_value)}
                    if (variable_reduced_cost is not None) and (extract_reduced_costs is True):
                        try:
                            if extract_rc is True:
                                variable["Rc"] = float(variable_reduced_cost)
                            if variable_status is not None:
                                if extract_lrc is True:
                                    if variable_status == "LL":
                                        variable["Lrc"] = float(variable_reduced_cost)
                                    else:
                                        variable["Lrc"] = 0.0
                                if extract_urc is True:
                                    if variable_status == "UL":
                                        variable["Urc"] = float(variable_reduced_cost)
                                    else:
                                        variable["Urc"] = 0.0
                        except:
                            raise ValueError("Unexpected reduced-cost value="+str(variable_reduced_cost)+" encountered for variable="+variable_name)
            elif (tokens[0] == "constraint") and ((extract_duals is True) or (extract_slacks is True)):
                is_range = False
                rlabel = None
                rkey = None
                for i in xrange(1,len(tokens)):
                    field_name =  tokens[i].split('=')[0]
                    field_value = tokens[i].split('=')[1].lstrip("\"").rstrip("\"")
                    if field_name == "name":
                        if field_value.startswith('c_'):
                            constraint = soln_constraints[field_value] = {}
                        elif field_value.startswith('r_l_'):
                            is_range = True
                            rlabel = field_value[4:]
                            rkey = 0
                        elif field_value.startswith('r_u_'):
                            is_range = True
                            rlabel = field_value[4:]
                            rkey = 1
                    elif (extract_duals is True) and (field_name == "dual"): # for LPs
                        if is_range is False:
                            constraint["Dual"] = float(field_value)
                        else:
                            range_duals.setdefault(rlabel,[0,0])[rkey] = float(field_value)
                    elif (extract_slacks is True) and (field_name == "slack"): # for MIPs
                        if is_range is False:
                            constraint["Slack"] = float(field_value)
                        else:
                            range_slacks.setdefault(rlabel,[0,0])[rkey] = float(field_value)
            elif tokens[0].startswith("problemName"):
                filename = (tokens[0].split('=')[1].strip()).lstrip("\"").rstrip("\"")
                results.problem.name = os.path.basename(filename)
                if '.' in results.problem.name:
                    results.problem.name = results.problem.name.split('.')[0]
                tINPUT=open(filename,"r")
                for tline in tINPUT:
                    tline = tline.strip()
                    if tline == "":
                        continue
                    tokens = re.split('[\t ]+',tline)
                    if tokens[0][0] in ['\\', '*']:
                        continue
                    elif tokens[0] == "NAME":
                        results.problem.name = tokens[1]
                    else:
                        sense = tokens[0].lower()
                        if sense in ['max','maximize']:
                            results.problem.sense = ProblemSense.maximize
                        if sense in ['min','minimize']:
                            results.problem.sense = ProblemSense.minimize
                    break
                tINPUT.close()

            elif tokens[0].startswith("objectiveValue"):
                objective_value = (tokens[0].split('=')[1].strip()).lstrip("\"").rstrip("\"")
                soln.objective['__default_objective__']['Value'] = float(objective_value)
            elif tokens[0].startswith("solutionStatusValue"):
               pieces = tokens[0].split("=")
               solution_status = eval(pieces[1])
               # solution status = 1 => optimal
               # solution status = 3 => infeasible
               if soln.status == SolutionStatus.unknown:
                  if solution_status == 1:
                    soln.status = SolutionStatus.optimal
                  elif solution_status == 3:
                    soln.status = SolutionStatus.infeasible
                    soln.gap = None
                  else:
                      # we are flagging anything with a solution status >= 4 as an error, to possibly
                      # be overridden as we learn more about the status (e.g., due to time limit exceeded).
                      soln.status = SolutionStatus.error
                      soln.gap = None
            elif tokens[0].startswith("solutionStatusString"):
                solution_status = ((" ".join(tokens).split('=')[1]).strip()).lstrip("\"").rstrip("\"")
                if solution_status in ["optimal", "integer optimal solution", "integer optimal, tolerance"]:
                    soln.status = SolutionStatus.optimal
                    soln.gap = 0.0
                    results.problem.lower_bound = soln.objective['__default_objective__']['Value']
                    results.problem.upper_bound = soln.objective['__default_objective__']['Value']
                    if "integer" in solution_status:
                        mip_problem=True
                elif solution_status in ["infeasible"]:
                    soln.status = SolutionStatus.infeasible
                    soln.gap = None
                elif solution_status in ["time limit exceeded"]:
                    # we need to know if the solution is primal feasible, and if it is, set the solution status accordingly.
                    # for now, just set the flag so we can trigger the logic when we see the primalFeasible keyword.
                    time_limit_exceeded = True
            elif tokens[0].startswith("MIPNodes"):
                if mip_problem:
                    n = eval(eval((" ".join(tokens).split('=')[1]).strip()).lstrip("\"").rstrip("\""))
                    results.solver.statistics.branch_and_bound.number_of_created_subproblems=n
                    results.solver.statistics.branch_and_bound.number_of_bounded_subproblems=n
            elif tokens[0].startswith("primalFeasible") and (time_limit_exceeded is True):
                primal_feasible = int(((" ".join(tokens).split('=')[1]).strip()).lstrip("\"").rstrip("\""))
                if primal_feasible == 1:
                    soln.status = SolutionStatus.feasible
                    if (results.problem.sense == ProblemSense.minimize):
                        results.problem.upper_bound = soln.objective['__default_objective__']['Value']
                    else:
                        results.problem.lower_bound = soln.objective['__default_objective__']['Value']
                else:
                    soln.status = SolutionStatus.infeasible


        if self._best_bound is not None:
            if results.problem.sense == ProblemSense.minimize:
                results.problem.lower_bound = self._best_bound
            else:
                results.problem.upper_bound = self._best_bound
        if self._gap is not None:
            soln.gap = self._gap

        # For the range constraints, supply only the dual with the largest
        # magnitude (at least one should always be numerically zero)
        for key,(ld,ud) in iteritems(range_duals):
            if abs(ld) > abs(ud):
                soln_constraints['r_l_'+key] = {"Dual" : ld}
            else:
                soln_constraints['r_l_'+key] = {"Dual" : ud}                # Use the same key
        # slacks
        for key,(ls,us) in iteritems(range_slacks):
            if abs(ls) > abs(us):
                soln_constraints.setdefault('r_l_'+key,{})["Slack"] = ls
            else:
                soln_constraints.setdefault('r_l_'+key,{})["Slack"] = us    # Use the same key

        if not results.solver.status is SolverStatus.error:
            if results.solver.termination_condition in [TerminationCondition.unknown,
                                                        #TerminationCondition.maxIterations,
                                                        #TerminationCondition.minFunctionValue,
                                                        #TerminationCondition.minStepLength,
                                                        TerminationCondition.globallyOptimal,
                                                        TerminationCondition.locallyOptimal,
                                                        TerminationCondition.optimal,
                                                        #TerminationCondition.maxEvaluations,
                                                        TerminationCondition.other]:
                results.solution.insert(soln)
            elif (results.solver.termination_condition is \
                  TerminationCondition.maxTimeLimit) and \
                  (soln.status is not SolutionStatus.infeasible):
                results.solution.insert(soln)

        INPUT.close()
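
The post-processing loops at the end of this example show another common iteritems idiom: unpacking a tuple value directly in the for statement. A minimal sketch of the range-dual collapse, using made-up constraint names and dual values:

import six

def collapse_range_duals(range_duals):
    # Keep, for each range constraint, only the dual with the larger
    # magnitude; at least one of the pair is expected to be zero.
    collapsed = {}
    for key, (ld, ud) in six.iteritems(range_duals):
        collapsed['r_l_' + key] = ld if abs(ld) > abs(ud) else ud
    return collapsed

print(collapse_range_duals({'c1': (0.0, -2.5), 'c2': (1.25, 0.0)}))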

Example 39

Project: pyomo
Source File: GLPK_old.py
View license
    def process_soln_file(self, results):
        soln  = None
        pdata = self._glpfile
        psoln = self._rawfile

        prob = results.problem
        solv = results.solver

        prob.name = 'unknown'   # will ostensibly get updated

        # Step 1: Make use of the GLPK's machine parseable format (--wglp) to
        #    collect variable and constraint names.
        glp_line_count = ' -- File not yet opened'

        # The trick for getting the variable names correctly matched to their
        # values is that the --wglp option outputs them in the same
        # order as the --write output.
        # Note that documentation for these formats is available from the GLPK
        # documentation of 'glp_read_prob' and 'glp_write_sol'
        variable_names = dict()    # cols
        constraint_names = dict()  # rows
        obj_name = 'objective'

        try:
            f = open(pdata, 'r')

            glp_line_count = 1
            pprob, ptype, psense, prows, pcols, pnonz = f.readline().split()
            prows = int(prows)  # fails if not a number; intentional
            pcols = int(pcols)  # fails if not a number; intentional
            pnonz = int(pnonz)  # fails if not a number; intentional

            if pprob != 'p' or \
               ptype not in ('lp', 'mip') or \
               psense not in ('max', 'min') or \
               prows < 0 or pcols < 0 or pnonz < 0:
                raise ValueError

            self.is_integer = (ptype == 'mip')
            prob.sense = ProblemSense.minimize if psense == 'min' else ProblemSense.maximize
            prob.number_of_constraints = prows
            prob.number_of_nonzeros    = pnonz
            prob.number_of_variables   = pcols

            extract_duals = False
            extract_reduced_costs = False
            for suffix in self._suffixes:
                flag = False
                if re.match(suffix, "dual"):
                    if not self.is_integer:
                        flag = True
                        extract_duals = True
                if re.match(suffix, "rc"):
                    if not self.is_integer:
                        flag = True
                        extract_reduced_costs = True
                if not flag:
                    # TODO: log a warning
                    pass

            for line in f:
                glp_line_count += 1
                tokens = line.split()
                switch = tokens.pop(0)

                if switch in ('a', 'e', 'i', 'j'):
                    pass
                elif 'n' == switch:  # naming some attribute
                    ntype = tokens.pop(0)
                    name  = tokens.pop()
                    if 'i' == ntype:      # row
                        row = tokens.pop()
                        constraint_names[int(row)] = name
                        # --write order == --wglp order; store name w/ row no
                    elif 'j' == ntype:    # var
                        col = tokens.pop()
                        variable_names[int(col)] = name
                        # --write order == --wglp order; store name w/ col no
                    elif 'z' == ntype:    # objective
                        obj_name = name
                    elif 'p' == ntype:    # problem name
                        prob.name = name
                    else:                 # anything else is incorrect.
                        raise ValueError

                else:
                    raise ValueError
        except Exception:
            e = sys.exc_info()[1]
            msg = "Error parsing solution description file, line %s: %s"
            raise ValueError(msg % (glp_line_count, str(e)))
        finally:
            f.close()

        range_duals = {}
        # Step 2: Make use of the GLPK's machine parseable format (--write) to
        #    collect solution variable and constraint values.
        raw_line_count = ' -- File not yet opened'
        try:
            f = open(psoln, 'r')

            raw_line_count = 1
            prows, pcols = f.readline().split()
            prows = int(prows)  # fails if not a number; intentional
            pcols = int(pcols)  # fails if not a number; intentional

            raw_line_count = 2
            if self.is_integer:
                pstat, obj_val = f.readline().split()
            else:
                pstat, dstat, obj_val = f.readline().split()
                dstat = float(dstat) # dual status of basic solution.  Ignored.

            pstat = float(pstat)       # fails if not a number; intentional
            obj_val = float(obj_val)   # fails if not a number; intentional
            soln_status = self._glpk_get_solution_status(pstat)

            if soln_status is SolutionStatus.infeasible:
                solv.termination_condition = TerminationCondition.infeasible

            elif soln_status is SolutionStatus.unbounded:
                solv.termination_condition = TerminationCondition.unbounded

            elif soln_status is SolutionStatus.other:
                if solv.termination_condition == TerminationCondition.unknown:
                    solv.termination_condition = TerminationCondition.other

            elif soln_status in (SolutionStatus.optimal, SolutionStatus.feasible):
                soln   = results.solution.add()
                soln.status = soln_status

                prob.lower_bound = obj_val
                prob.upper_bound = obj_val

                # TODO: Does a 'feasible' status mean that we're optimal?
                soln.gap=0.0
                solv.termination_condition = TerminationCondition.optimal

                # I'd like to choose the correct answer rather than just doing
                # something like commenting the obj_name line.  The point is that
                # we ostensibly could or should make use of the user's choice in
                # objective name.  In that vein I'd like to set the objective value
                # to the objective name.  This would make parsing on the user end
                # less 'arbitrary', as in the yaml key 'f'.  Weird
                soln.objective[obj_name] = {'Value': obj_val}

                if (self.is_integer is True) or (extract_duals is False):
                    # we use nothing from this section so just read in the
                    # lines and throw them away
                    for mm in range(1, prows +1):
                        raw_line_count += 1
                        f.readline()
                else:
                    for mm in range(1, prows +1):
                        raw_line_count += 1

                        rstat, rprim, rdual = f.readline().split()
                        rstat = float(rstat)

                        cname = constraint_names[mm]
                        if 'ONE_VAR_CONSTANT' == cname[-16:]: continue

                        if cname.startswith('c_'):
                            soln.constraint[cname] = {"Dual":float(rdual)}
                        elif cname.startswith('r_l_'):
                            range_duals.setdefault(cname[4:],[0,0])[0] = float(rdual)
                        elif cname.startswith('r_u_'):
                            range_duals.setdefault(cname[4:],[0,0])[1] = float(rdual)

                for nn in range(1, pcols +1):
                    raw_line_count += 1
                    if self.is_integer:
                        cprim = f.readline()      # should be a single number
                    else:
                        cstat, cprim, cdual = f.readline().split()
                        cstat = float(cstat)  # fails if not a number; intentional

                    vname = variable_names[nn]
                    if 'ONE_VAR_CONSTANT' == vname: continue
                    cprim = float(cprim)
                    if extract_reduced_costs is False:
                        soln.variable[vname] = {"Value" : cprim}
                    else:
                        soln.variable[vname] = {"Value" : cprim,
                                                "Rc" : float(cdual)}

        except Exception:
            print(sys.exc_info()[1])
            msg = "Error parsing solution data file, line %d" % raw_line_count
            raise ValueError(msg)
        finally:
            f.close()

        if not soln is None:
            # For the range constraints, supply only the dual with the largest
            # magnitude (at least one should always be numerically zero)
            scon = soln.Constraint
            for key,(ld,ud) in iteritems(range_duals):
                if abs(ld) > abs(ud):
                    scon['r_l_'+key] = {"Dual":ld}
                else:
                    scon['r_l_'+key] = {"Dual":ud}      # Use the same key
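
This example closes with the same iteritems-over-range_duals loop as the CPLEX plugin above. For reference, six.iteritems(d) dispatches to d.iteritems() on Python 2 and to iter(d.items()) on Python 3, so a loop like the following runs unchanged on both interpreters:

from __future__ import print_function

import six

d = {'a': 1, 'b': 2}

# six.iteritems(d) calls d.iteritems() on Python 2 and iter(d.items())
# on Python 3; on Python 2 this avoids building the intermediate list
# that d.items() would create.
for key, value in six.iteritems(d):
    print(key, value)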

Example 40

Project: pyomo
Source File: gurobi_direct.py
View license
    def _populate_gurobi_instance (self, pyomo_instance):

        from pyomo.core.base import Var, Objective, Constraint, SOSConstraint
        from pyomo.repn import LinearCanonicalRepn, canonical_degree

        try:
            grbmodel = Model(name=pyomo_instance.name)
        except Exception:
            e = sys.exc_info()[1]
            msg = 'Unable to create Gurobi model.  Have you installed the Python'\
            '\n       bindings for Gurobi?\n\n\tError message: %s'
            raise Exception(msg % e)

        if self._symbolic_solver_labels:
            labeler = TextLabeler()
        else:
            labeler = NumericLabeler('x')
        # cache to avoid dictionary getitem calls in the loops below.
        self_symbol_map = self._symbol_map = SymbolMap()
        pyomo_instance.solutions.add_symbol_map(self_symbol_map)
        self._smap_id = id(self_symbol_map)

        # we use this when iterating over the constraints because it
        # will have a much smaller hash table; we also use this for
        # the warm start code after it is cleaned to only contain
        # variables referenced in the constraints
        self_variable_symbol_map = self._variable_symbol_map = SymbolMap()
        var_symbol_pairs = []

        # maps _VarData labels to the corresponding Gurobi variable object
        pyomo_gurobi_variable_map = {}

        self._referenced_variable_ids.clear()

        # cache to avoid dictionary getitem calls in the loop below.
        grb_infinity = GRB.INFINITY

        for var_value in pyomo_instance.component_data_objects(Var, active=True):

            lb = -grb_infinity
            ub = grb_infinity

            if (var_value.lb is not None) and (var_value.lb != -infinity):
                lb = value(var_value.lb)
            if (var_value.ub is not None) and (var_value.ub != infinity):
                ub = value(var_value.ub)

            # _VarValue objects will not be in the symbol map yet, so
            # avoid some checks.
            var_value_label = self_symbol_map.createSymbol(var_value, labeler)
            var_symbol_pairs.append((var_value, var_value_label))

            # be sure to impart the integer and binary nature of any variables
            if var_value.is_integer():
                var_type = GRB.INTEGER
            elif var_value.is_binary():
                var_type = GRB.BINARY
            elif var_value.is_continuous():
                var_type = GRB.CONTINUOUS
            else:
                raise TypeError("Invalid domain type for variable with name '%s'. "
                                "Variable is not continuous, integer, or binary.")

            pyomo_gurobi_variable_map[var_value_label] = \
                grbmodel.addVar(lb=lb, \
                                ub=ub, \
                                vtype=var_type, \
                                name=var_value_label)

        self_variable_symbol_map.addSymbols(var_symbol_pairs)

        grbmodel.update()

        # The next loop collects the following component types from the model:
        #  - SOSConstraint
        #  - Objective
        #  - Constraint
        sos1 = self._capabilities.sos1
        sos2 = self._capabilities.sos2
        modelSOS = ModelSOS()
        objective_cntr = 0
        # Track the range constraints and their associated variables added by gurobi
        self._last_native_var_idx = grbmodel.NumVars-1
        range_var_idx = grbmodel.NumVars
        _self_range_con_var_pairs = self._range_con_var_pairs = []
        for block in pyomo_instance.block_data_objects(active=True):

            gen_obj_canonical_repn = \
                getattr(block, "_gen_obj_canonical_repn", True)
            gen_con_canonical_repn = \
                getattr(block, "_gen_con_canonical_repn", True)
            # Get/Create the ComponentMap for the repn
            if not hasattr(block,'_canonical_repn'):
                block._canonical_repn = ComponentMap()
            block_canonical_repn = block._canonical_repn

            # SOSConstraints
            for soscondata in block.component_data_objects(SOSConstraint,
                                                           active=True,
                                                           descend_into=False):
                level = soscondata.level
                if (level == 1 and not sos1) or \
                   (level == 2 and not sos2) or \
                   (level > 2):
                    raise RuntimeError(
                        "Solver does not support SOS level %s constraints" % (level,))
                modelSOS.count_constraint(self_symbol_map,
                                          labeler,
                                          self_variable_symbol_map,
                                          pyomo_gurobi_variable_map,
                                          soscondata)

            # Objective
            for obj_data in block.component_data_objects(Objective,
                                                         active=True,
                                                         descend_into=False):

                objective_cntr += 1
                if objective_cntr > 1:
                    raise ValueError(
                        "Multiple active objectives found on Pyomo instance '%s'. "
                        "Solver '%s' will only handle a single active objective" \
                        % (pyomo_instance.name, self.type))

                sense = GRB_MIN if (obj_data.is_minimizing()) else GRB_MAX
                grbmodel.ModelSense = sense
                obj_expr = LinExpr()

                if gen_obj_canonical_repn:
                    obj_repn = generate_canonical_repn(obj_data.expr)
                    block_canonical_repn[obj_data] = obj_repn
                else:
                    obj_repn = block_canonical_repn[obj_data]

                if isinstance(obj_repn, LinearCanonicalRepn):

                    if obj_repn.constant is not None:
                        obj_expr.addConstant(obj_repn.constant)

                    if obj_repn.linear is not None:

                        for i in xrange(len(obj_repn.linear)):
                            var_coefficient = obj_repn.linear[i]
                            var_value = obj_repn.variables[i]
                            self._referenced_variable_ids.add(id(var_value))
                            label = self_variable_symbol_map.getSymbol(var_value)
                            obj_expr.addTerms(var_coefficient,
                                              pyomo_gurobi_variable_map[label])
                else:

                    if 0 in obj_repn: # constant term
                        obj_expr.addConstant(obj_repn[0][None])

                    if 1 in obj_repn: # first-order terms
                        hash_to_variable_map = obj_repn[-1]
                        for var_hash, var_coefficient in iteritems(obj_repn[1]):
                            vardata = hash_to_variable_map[var_hash]
                            self._referenced_variable_ids.add(id(vardata))
                            label = self_variable_symbol_map.getSymbol(vardata)
                            obj_expr.addTerms(var_coefficient,
                                              pyomo_gurobi_variable_map[label])

                    if 2 in obj_repn:
                        obj_expr = QuadExpr(obj_expr)
                        hash_to_variable_map = obj_repn[-1]
                        for quad_repn, coef in iteritems(obj_repn[2]):
                            gurobi_expr = QuadExpr(coef)
                            for var_hash, exponent in iteritems(quad_repn):
                                vardata = hash_to_variable_map[var_hash]
                                self._referenced_variable_ids.add(id(vardata))
                                gurobi_var = pyomo_gurobi_variable_map\
                                             [self_variable_symbol_map.\
                                              getSymbol(vardata)]
                                gurobi_expr *= gurobi_var
                                if exponent == 2:
                                    gurobi_expr *= gurobi_var
                            obj_expr += gurobi_expr

                    degree = canonical_degree(obj_repn)
                    if (degree is None) or (degree > 2):
                        raise ValueError(
                            "gurobi_direct plugin does not support general nonlinear "
                            "objective expressions (only linear or quadratic).\n"
                            "Objective: %s" % (obj_data.name))

                # need to cache the objective label, because the
                # GUROBI python interface doesn't track this.
                # _ObjectiveData objects will not be in the symbol map
                # yet, so avoid some checks.
                self._objective_label = \
                    self_symbol_map.createSymbol(obj_data, labeler)

                grbmodel.setObjective(obj_expr, sense=sense)

            # Constraint
            for constraint_data in block.component_data_objects(Constraint,
                                                                active=True,
                                                                descend_into=False):

                if (constraint_data.lower is None) and \
                   (constraint_data.upper is None):
                    continue  # not binding at all, don't bother

                con_repn = None
                if isinstance(constraint_data, LinearCanonicalRepn):
                    con_repn = constraint_data
                else:
                    if gen_con_canonical_repn:
                        con_repn = generate_canonical_repn(constraint_data.body)
                        block_canonical_repn[constraint_data] = con_repn
                    else:
                        con_repn = block_canonical_repn[constraint_data]

                offset = 0.0
                # _ConstraintData objects will not be in the symbol
                # map yet, so avoid some checks.
                constraint_label = \
                    self_symbol_map.createSymbol(constraint_data, labeler)

                trivial = False
                if isinstance(con_repn, LinearCanonicalRepn):

                    #
                    # optimization (these might be generated on the fly)
                    #
                    constant = con_repn.constant
                    coefficients = con_repn.linear
                    variables = con_repn.variables

                    if constant is not None:
                        offset = constant
                    expr = LinExpr() + offset

                    if coefficients is not None:

                        linear_coefs = list()
                        linear_vars = list()

                        for i in xrange(len(coefficients)):

                            var_coefficient = coefficients[i]
                            var_value = variables[i]
                            self._referenced_variable_ids.add(id(var_value))
                            label = self_variable_symbol_map.getSymbol(var_value)
                            linear_coefs.append(var_coefficient)
                            linear_vars.append(pyomo_gurobi_variable_map[label])

                        expr += LinExpr(linear_coefs, linear_vars)

                    else:

                        trivial = True

                else:

                    if 0 in con_repn:
                        offset = con_repn[0][None]
                    expr = LinExpr() + offset

                    if 1 in con_repn: # first-order terms

                        linear_coefs = list()
                        linear_vars = list()

                        hash_to_variable_map = con_repn[-1]
                        for var_hash, var_coefficient in iteritems(con_repn[1]):
                            var = hash_to_variable_map[var_hash]
                            self._referenced_variable_ids.add(id(var))
                            label = self_variable_symbol_map.getSymbol(var)
                            linear_coefs.append(var_coefficient)
                            linear_vars.append(pyomo_gurobi_variable_map[label])

                        expr += LinExpr(linear_coefs, linear_vars)

                    if 2 in con_repn: # quadratic constraint
                        if _GUROBI_VERSION_MAJOR < 5:
                            raise ValueError(
                                "The gurobi_direct plugin does not handle quadratic "
                                "constraint expressions for Gurobi major versions "
                                "< 5. Current version: Gurobi %s.%s%s"
                                % (gurobi.version()))

                        expr = QuadExpr(expr)
                        hash_to_variable_map = con_repn[-1]
                        for quad_repn, coef in iteritems(con_repn[2]):
                            gurobi_expr = QuadExpr(coef)
                            for var_hash, exponent in iteritems(quad_repn):
                                vardata = hash_to_variable_map[var_hash]
                                self._referenced_variable_ids.add(id(vardata))
                                gurobi_var = pyomo_gurobi_variable_map\
                                             [self_variable_symbol_map.\
                                              getSymbol(vardata)]
                                gurobi_expr *= gurobi_var
                                if exponent == 2:
                                    gurobi_expr *= gurobi_var
                            expr += gurobi_expr

                    degree = canonical_degree(con_repn)
                    if (degree is None) or (degree > 2):
                        raise ValueError(
                            "gurobi_direct plugin does not support general nonlinear "
                            "constraint expressions (only linear or quadratic).\n"
                            "Constraint: %s" % (constraint_data.name))

                if (not trivial) or (not self._skip_trivial_constraints):

                    if constraint_data.equality:
                        sense = GRB.EQUAL
                        bound = self._get_bound(constraint_data.lower)
                        grbmodel.addConstr(lhs=expr,
                                           sense=sense,
                                           rhs=bound,
                                           name=constraint_label)
                    else:
                        # L <= body <= U
                        if (constraint_data.upper is not None) and \
                           (constraint_data.lower is not None):
                            grb_con = grbmodel.addRange(
                                expr,
                                self._get_bound(constraint_data.lower),
                                self._get_bound(constraint_data.upper),
                                constraint_label)
                            _self_range_con_var_pairs.append((grb_con,range_var_idx))
                            range_var_idx += 1
                        # body <= U
                        elif constraint_data.upper is not None:
                            bound = self._get_bound(constraint_data.upper)
                            if bound < float('inf'):
                                grbmodel.addConstr(
                                    lhs=expr,
                                    sense=GRB.LESS_EQUAL,
                                    rhs=bound,
                                    name=constraint_label
                                    )
                        # L <= body
                        else:
                            bound = self._get_bound(constraint_data.lower)
                            if bound > -float('inf'):
                                grbmodel.addConstr(
                                    lhs=expr,
                                    sense=GRB.GREATER_EQUAL,
                                    rhs=bound,
                                    name=constraint_label
                                    )

        if modelSOS.sosType:
            for key in modelSOS.sosType:
                grbmodel.addSOS(modelSOS.sosType[key], \
                                modelSOS.varnames[key], \
                                modelSOS.weights[key] )
                self._referenced_variable_ids.update(modelSOS.varids[key])

        for var_id in self._referenced_variable_ids:
            varname = self._variable_symbol_map.byObject[var_id]
            vardata = self._variable_symbol_map.bySymbol[varname]()
            if vardata.fixed:
                if not self._output_fixed_variable_bounds:
                    raise ValueError("Encountered a fixed variable (%s) inside an active objective "
                                     "or constraint expression on model %s, which is usually indicative of "
                                     "a preprocessing error. Use the IO-option 'output_fixed_variable_bounds=True' "
                                     "to suppress this error and fix the variable by overwriting its bounds in "
                                     "the Gurobi instance."
                                     % (vardata.name,pyomo_instance.name,))

                grbvar = pyomo_gurobi_variable_map[varname]
                grbvar.setAttr(GRB.Attr.UB, vardata.value)
                grbvar.setAttr(GRB.Attr.LB, vardata.value)

        grbmodel.update()

        self._gurobi_instance = grbmodel
        self._pyomo_gurobi_variable_map = pyomo_gurobi_variable_map
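
The pattern worth noting above is how the first-order and quadratic terms of the canonical representation are walked with iteritems. A minimal, self-contained sketch of that iteration follows; the term map and the hash-to-variable lookup are made-up stand-ins, not Pyomo or Gurobi objects.

import six

# Hypothetical first-order term map from a canonical representation:
# variable hash -> coefficient, plus a hash -> variable-label lookup.
first_order_terms = {101: 2.0, 202: -1.5}
hash_to_variable = {101: 'x1', 202: 'x2'}

linear_terms = []
for var_hash, coefficient in six.iteritems(first_order_terms):
    # Resolve the hash to its variable label before recording the term.
    linear_terms.append((coefficient, hash_to_variable[var_hash]))

print(sorted(linear_terms))  # [(-1.5, 'x2'), (2.0, 'x1')]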

Example 41

Project: python-escpos
Source File: cli.py
View license
def main():
    """

    Handles loading of configuration and creating and processing of command
    line arguments. Called when run from a CLI.

    """

    parser = argparse.ArgumentParser(
        description='CLI for python-escpos',
        epilog='Printer configuration is defined in the python-escpos config '
        'file. See documentation for details.',
    )

    parser.register('type', 'bool', str_to_bool)

    # Allow config file location to be passed
    parser.add_argument(
        '-c', '--config',
        help='Alternate path to the configuration file',
    )

    # Everything interesting runs off of a subparser so we can use the format
    # cli [subparser] -args
    command_subparsers = parser.add_subparsers(
        title='ESCPOS Command',
        dest='parser',
    )
    # fix inconsistencies in the behaviour of some versions of argparse
    command_subparsers.required = False   # missing command is handled explicitly below

    # Build the ESCPOS command arguments
    for command in ESCPOS_COMMANDS:
        parser_command = command_subparsers.add_parser(**command['parser'])
        parser_command.set_defaults(**command['defaults'])
        for argument in command['arguments']:
            option_strings = argument.pop('option_strings')
            parser_command.add_argument(*option_strings, **argument)

    # Build any custom arguments
    parser_command_demo = command_subparsers.add_parser('demo',
                                                        help='Demonstrates various functions')
    parser_command_demo.set_defaults(func='demo')
    demo_group = parser_command_demo.add_mutually_exclusive_group()
    demo_group.add_argument(
        '--barcodes-a',
        help='Print demo barcodes for function type A',
        action='store_true',
    )
    demo_group.add_argument(
        '--barcodes-b',
        help='Print demo barcodes for function type B',
        action='store_true',
    )
    demo_group.add_argument(
        '--qr',
        help='Print some demo QR codes',
        action='store_true',
    )
    demo_group.add_argument(
        '--text',
        help='Print some demo text',
        action='store_true',
    )

    parser_command_version = command_subparsers.add_parser('version',
                                                           help='Print the version of python-escpos')
    parser_command_version.set_defaults(version=True)

    # hook in argcomplete
    if 'argcomplete' in globals():
        argcomplete.autocomplete(parser)

    # Get only arguments actually passed
    args_dict = vars(parser.parse_args())
    if not args_dict:
        parser.print_help()
        sys.exit()
    command_arguments = {k: v for k, v in six.iteritems(args_dict) if v is not None}

    # If version should be printed, do this, then exit
    print_version = command_arguments.pop('version', None)
    if print_version:
        print(version.version)
        sys.exit()

    # If there was a config path passed, grab it
    config_path = command_arguments.pop('config', None)

    # Load the configuration and defined printer
    saved_config = config.Config()
    saved_config.load(config_path)
    printer = saved_config.printer()

    if not printer:
        raise Exception('No printers loaded from config')

    target_command = command_arguments.pop('func')

    # remove helper-argument 'parser' from dict
    command_arguments.pop('parser', None)

    if hasattr(printer, target_command):
        # print command with args
        getattr(printer, target_command)(**command_arguments)
        if target_command in REQUIRES_NEWLINE:
            printer.text("\n")
    else:
        command_arguments['printer'] = printer
        globals()[target_command](**command_arguments)
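
The iteritems call above is what discards arguments the user never passed. A minimal sketch of that idiom, with a made-up args_dict standing in for vars(parser.parse_args()):

import six

# Hypothetical parsed arguments; options that were not passed are None.
args_dict = {'func': 'text', 'txt': 'Hello', 'align': None, 'config': None}

# Keep only the arguments that were actually supplied on the command line.
command_arguments = {k: v for k, v in six.iteritems(args_dict) if v is not None}

print(command_arguments)  # {'func': 'text', 'txt': 'Hello'}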

Example 42

Project: pgcontents
Source File: test_hybrid_manager.py
View license
    def test_get(self):
        cm = self.contents_manager

        untitled_nb = 'Untitled.ipynb'
        untitled_txt = 'untitled.txt'
        for prefix, real_dir in iteritems(self.temp_dir_names):
            # Create a notebook
            model = cm.new_untitled(path=prefix, type='notebook')
            name = model['name']
            path = model['path']

            self.assertEqual(name, untitled_nb)
            self.assertEqual(path, pjoin(prefix, untitled_nb))
            self.assertTrue(
                exists(osjoin(real_dir, untitled_nb))
            )

            # Check that we can 'get' on the notebook we just created
            model2 = cm.get(path)
            assert isinstance(model2, dict)
            self.assertDictContainsSubset(
                {'name': name, 'path': path},
                model2,
            )

            nb_as_file = cm.get(path, content=True, type='file')
            self.assertDictContainsSubset(
                {'name': name, 'path': path, 'format': 'text'},
                nb_as_file,
            )
            self.assertNotIsInstance(nb_as_file['content'], dict)

            nb_as_bin_file = cm.get(
                path=path,
                content=True,
                type='file',
                format='base64'
            )
            self.assertDictContainsSubset(
                {'name': name, 'path': path, 'format': 'base64'},
                nb_as_bin_file,
            )
            self.assertNotIsInstance(nb_as_bin_file['content'], dict)

            # Test notebook in sub-directory
            sub_dir = 'foo'
            mkdir(osjoin(real_dir, sub_dir))
            prefixed_sub_dir = pjoin(prefix, sub_dir)

            cm.new_untitled(path=prefixed_sub_dir, ext='.ipynb')
            self.assertTrue(exists(osjoin(real_dir, sub_dir, untitled_nb)))

            sub_dir_nbpath = pjoin(prefixed_sub_dir, untitled_nb)
            model2 = cm.get(sub_dir_nbpath)
            self.assertDictContainsSubset(
                {
                    'type': 'notebook',
                    'format': 'json',
                    'name': untitled_nb,
                    'path': sub_dir_nbpath,
                },
                model2,
            )
            self.assertIn('content', model2)

            # Test .txt in sub-directory.
            cm.new_untitled(path=prefixed_sub_dir, ext='.txt')
            self.assertTrue(exists(osjoin(real_dir, sub_dir, untitled_txt)))

            sub_dir_txtpath = pjoin(prefixed_sub_dir, untitled_txt)
            file_model = cm.get(path=sub_dir_txtpath)
            self.assertDictContainsSubset(
                {
                    'content': '',
                    'format': 'text',
                    'mimetype': 'text/plain',
                    'name': 'untitled.txt',
                    'path': sub_dir_txtpath,
                    'type': 'file',
                    'writable': True,
                },
                file_model,
            )
            self.assertIn('created', file_model)
            self.assertIn('last_modified', file_model)

            # Test directory in sub-directory.
            sub_sub_dirname = 'bar'
            sub_sub_dirpath = pjoin(prefixed_sub_dir, sub_sub_dirname)
            cm.save(
                {'type': 'directory', 'path': sub_sub_dirpath},
                sub_sub_dirpath,
            )
            self.assertTrue(exists(osjoin(real_dir, sub_dir, sub_sub_dirname)))
            sub_sub_dir_model = cm.get(sub_sub_dirpath)
            self.assertDictContainsSubset(
                {
                    'type': 'directory',
                    'format': 'json',
                    'name': sub_sub_dirname,
                    'path': sub_sub_dirpath,
                    'content': [],
                },
                sub_sub_dir_model,
            )

            # Test list with content on prefix/foo.
            dirmodel = cm.get(prefixed_sub_dir)
            self.assertDictContainsSubset(
                {
                    'type': 'directory',
                    'path': prefixed_sub_dir,
                    'name': sub_dir,
                },
                dirmodel,
            )
            self.assertIsInstance(dirmodel['content'], list)
            self.assertEqual(len(dirmodel['content']), 3)

            # Request each item in the subdirectory with no content.
            nbmodel_no_content = cm.get(sub_dir_nbpath, content=False)
            file_model_no_content = cm.get(sub_dir_txtpath, content=False)
            sub_sub_dir_no_content = cm.get(sub_sub_dirpath, content=False)

            for entry in dirmodel['content']:
                # Order isn't guaranteed by the spec, so this is a hacky way of
                # verifying that all entries are matched.
                if entry['path'] == sub_sub_dir_no_content['path']:
                    self.assertEqual(entry, sub_sub_dir_no_content)
                elif entry['path'] == nbmodel_no_content['path']:
                    self.assertEqual(entry, nbmodel_no_content)
                elif entry['path'] == file_model_no_content['path']:
                    self.assertEqual(entry, file_model_no_content)
                else:
                    self.fail("Unexpected directory entry: %s" % entry)
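
Here iteritems drives one pass of the test per (prefix, directory) pair. A minimal sketch of that loop, with a made-up mapping standing in for self.temp_dir_names:

import os
import six

# Hypothetical mapping of API path prefixes to the directories backing them.
temp_dir_names = {'root0': '/tmp/hybrid0', 'root1': '/tmp/hybrid1'}

for prefix, real_dir in six.iteritems(temp_dir_names):
    # Paths seen by the contents manager are prefixed; files live in real_dir.
    api_path = '/'.join([prefix, 'Untitled.ipynb'])
    disk_path = os.path.join(real_dir, 'Untitled.ipynb')
    print(api_path, '->', disk_path)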

Example 43

Project: zipline
Source File: test_finance.py
View license
    def transaction_sim(self, **params):
        """This is a utility method that asserts expected
        results for conversion of orders to transactions given a
        trade history
        """
        trade_count = params['trade_count']
        trade_interval = params['trade_interval']
        order_count = params['order_count']
        order_amount = params['order_amount']
        order_interval = params['order_interval']
        expected_txn_count = params['expected_txn_count']
        expected_txn_volume = params['expected_txn_volume']

        # optional parameters
        # ---------------------
        # if present, alternate between long and short sales
        alternate = params.get('alternate')

        # if present, expect transaction amounts to match orders exactly.
        complete_fill = params.get('complete_fill')

        sid = 1
        metadata = make_simple_equity_info([sid], self.start, self.end)
        with TempDirectory() as tempdir, \
                tmp_trading_env(equities=metadata) as env:

            if trade_interval < timedelta(days=1):
                sim_params = factory.create_simulation_parameters(
                    start=self.start,
                    end=self.end,
                    data_frequency="minute"
                )

                minutes = self.trading_calendar.minutes_window(
                    sim_params.first_open,
                    int((trade_interval.total_seconds() / 60) * trade_count)
                    + 100)

                price_data = np.array([10.1] * len(minutes))
                assets = {
                    sid: pd.DataFrame({
                        "open": price_data,
                        "high": price_data,
                        "low": price_data,
                        "close": price_data,
                        "volume": np.array([100] * len(minutes)),
                        "dt": minutes
                    }).set_index("dt")
                }

                write_bcolz_minute_data(
                    self.trading_calendar,
                    self.trading_calendar.sessions_in_range(
                        self.trading_calendar.minute_to_session_label(
                            minutes[0]
                        ),
                        self.trading_calendar.minute_to_session_label(
                            minutes[-1]
                        )
                    ),
                    tempdir.path,
                    iteritems(assets),
                )

                equity_minute_reader = BcolzMinuteBarReader(tempdir.path)

                data_portal = DataPortal(
                    env.asset_finder, self.trading_calendar,
                    first_trading_day=equity_minute_reader.first_trading_day,
                    equity_minute_reader=equity_minute_reader,
                )
            else:
                sim_params = factory.create_simulation_parameters(
                    data_frequency="daily"
                )

                days = sim_params.sessions

                assets = {
                    1: pd.DataFrame({
                        "open": [10.1] * len(days),
                        "high": [10.1] * len(days),
                        "low": [10.1] * len(days),
                        "close": [10.1] * len(days),
                        "volume": [100] * len(days),
                        "day": [day.value for day in days]
                    }, index=days)
                }

                path = os.path.join(tempdir.path, "testdata.bcolz")
                BcolzDailyBarWriter(path, self.trading_calendar, days[0],
                                    days[-1]).write(
                    assets.items()
                )

                equity_daily_reader = BcolzDailyBarReader(path)

                data_portal = DataPortal(
                    env.asset_finder, self.trading_calendar,
                    first_trading_day=equity_daily_reader.first_trading_day,
                    equity_daily_reader=equity_daily_reader,
                )

            if "default_slippage" not in params or \
               not params["default_slippage"]:
                slippage_func = FixedSlippage()
            else:
                slippage_func = None

            blotter = Blotter(sim_params.data_frequency, self.env.asset_finder,
                              slippage_func)

            start_date = sim_params.first_open

            if alternate:
                alternator = -1
            else:
                alternator = 1

            tracker = PerformanceTracker(sim_params, self.trading_calendar,
                                         self.env)

            # replicate what tradesim does by going through every minute or day
            # of the simulation and processing open orders each time
            if sim_params.data_frequency == "minute":
                ticks = minutes
            else:
                ticks = days

            transactions = []

            order_list = []
            order_date = start_date
            for tick in ticks:
                blotter.current_dt = tick
                if tick >= order_date and len(order_list) < order_count:
                    # place an order
                    direction = alternator ** len(order_list)
                    order_id = blotter.order(
                        blotter.asset_finder.retrieve_asset(sid),
                        order_amount * direction,
                        MarketOrder())
                    order_list.append(blotter.orders[order_id])
                    order_date = order_date + order_interval
                    # move after-hours orders to just after the next
                    # market open.
                    if order_date.hour >= 21:
                        if order_date.minute >= 00:
                            order_date = order_date + timedelta(days=1)
                            order_date = order_date.replace(hour=14, minute=30)
                else:
                    bar_data = BarData(
                        data_portal=data_portal,
                        simulation_dt_func=lambda: tick,
                        data_frequency=sim_params.data_frequency,
                        trading_calendar=self.trading_calendar,
                        restrictions=NoRestrictions(),
                    )
                    txns, _, closed_orders = blotter.get_transactions(bar_data)
                    for txn in txns:
                        tracker.process_transaction(txn)
                        transactions.append(txn)

                    blotter.prune_orders(closed_orders)

            for i in range(order_count):
                order = order_list[i]
                self.assertEqual(order.sid, sid)
                self.assertEqual(order.amount, order_amount * alternator ** i)

            if complete_fill:
                self.assertEqual(len(transactions), len(order_list))

            total_volume = 0
            for i in range(len(transactions)):
                txn = transactions[i]
                total_volume += txn.amount
                if complete_fill:
                    order = order_list[i]
                    self.assertEqual(order.amount, txn.amount)

            self.assertEqual(total_volume, expected_txn_volume)

            self.assertEqual(len(transactions), expected_txn_count)

            cumulative_pos = tracker.position_tracker.positions[sid]
            if total_volume == 0:
                self.assertIsNone(cumulative_pos)
            else:
                self.assertEqual(total_volume, cumulative_pos.amount)

            # the open orders should not contain sid.
            oo = blotter.open_orders
            self.assertNotIn(sid, oo, "Entry is removed when no open orders")
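
Note that write_bcolz_minute_data is handed iteritems(assets) rather than the dict itself, since the writer only needs an iterable of (sid, frame) pairs. A minimal sketch of that hand-off, with a hypothetical consume_pairs standing in for the writer:

import six

# Hypothetical sid -> data mapping, standing in for the assets dict above.
assets = {1: 'minute frame for sid 1', 2: 'minute frame for sid 2'}

def consume_pairs(pairs):
    # The real writer iterates (sid, frame) pairs in the same way.
    for sid, frame in pairs:
        print('writing sid %d: %s' % (sid, frame))

consume_pairs(six.iteritems(assets))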

Example 44

Project: zipline
Source File: test_restrictions.py
View license
    @parameter_space(
        date_offset=(
            pd.Timedelta(0),
            pd.Timedelta('1 minute'),
            pd.Timedelta('15 hours 5 minutes')
        ),
        restriction_order=(
            list(range(6)),      # Keep restrictions in order.
            [0, 2, 1, 3, 5, 4],  # Re-order within asset.
            [0, 3, 1, 4, 2, 5],  # Scramble assets, maintain per-asset order.
            [0, 5, 2, 3, 1, 4],  # Scramble assets and per-asset order.
        ),
        __fail_fast=True,
    )
    def test_historical_restrictions(self, date_offset, restriction_order):
        """
        Test historical restrictions for both interday and intraday
        restrictions, as well as restrictions defined in/not in order, for both
        single- and multi-asset queries
        """
        def rdate(s):
            """Convert a date string into a restriction for that date."""
            # Add date_offset to check that we handle intraday changes.
            return str_to_ts(s) + date_offset

        base_restrictions = [
            Restriction(self.ASSET1, rdate('2011-01-04'), FROZEN),
            Restriction(self.ASSET1, rdate('2011-01-05'), ALLOWED),
            Restriction(self.ASSET1, rdate('2011-01-06'), FROZEN),
            Restriction(self.ASSET2, rdate('2011-01-05'), FROZEN),
            Restriction(self.ASSET2, rdate('2011-01-06'), ALLOWED),
            Restriction(self.ASSET2, rdate('2011-01-07'), FROZEN),
        ]
        # Scramble the restrictions based on restriction_order to check that we
        # don't depend on the order in which restrictions are provided to us.
        all_restrictions = [base_restrictions[i] for i in restriction_order]

        restrictions_by_asset = groupby(lambda r: r.asset, all_restrictions)

        rl = HistoricalRestrictions(all_restrictions)
        assert_not_restricted = partial(self.assert_not_restricted, rl)
        assert_is_restricted = partial(self.assert_is_restricted, rl)
        assert_all_restrictions = partial(self.assert_all_restrictions, rl)

        # Check individual restrictions.
        for asset, r_history in iteritems(restrictions_by_asset):
            freeze_dt, unfreeze_dt, re_freeze_dt = (
                sorted([r.effective_date for r in r_history])
            )

            # Starts implicitly unrestricted. Restricted on or after the freeze
            assert_not_restricted(asset, freeze_dt - MINUTE)
            assert_is_restricted(asset, freeze_dt)
            assert_is_restricted(asset, freeze_dt + MINUTE)

            # Unrestricted on or after the unfreeze
            assert_is_restricted(asset, unfreeze_dt - MINUTE)
            assert_not_restricted(asset, unfreeze_dt)
            assert_not_restricted(asset, unfreeze_dt + MINUTE)

            # Restricted again on or after the freeze
            assert_not_restricted(asset, re_freeze_dt - MINUTE)
            assert_is_restricted(asset, re_freeze_dt)
            assert_is_restricted(asset, re_freeze_dt + MINUTE)

            # Should stay restricted for the rest of time
            assert_is_restricted(asset, re_freeze_dt + MINUTE * 1000000)

        # Check vectorized restrictions.
        # Expected results for [self.ASSET1, self.ASSET2, self.ASSET3],
        # ASSET3 is always False as it has no defined restrictions

        # 01/04 XX:00 ASSET1: ALLOWED --> FROZEN; ASSET2: ALLOWED
        d0 = rdate('2011-01-04')
        assert_all_restrictions([False, False, False], d0 - MINUTE)
        assert_all_restrictions([True, False, False], d0)
        assert_all_restrictions([True, False, False], d0 + MINUTE)

        # 01/05 XX:00 ASSET1: FROZEN --> ALLOWED; ASSET2: ALLOWED --> FROZEN
        d1 = rdate('2011-01-05')
        assert_all_restrictions([True, False, False], d1 - MINUTE)
        assert_all_restrictions([False, True, False], d1)
        assert_all_restrictions([False, True, False], d1 + MINUTE)

        # 01/06 XX:00 ASSET1: ALLOWED --> FROZEN; ASSET2: FROZEN --> ALLOWED
        d2 = rdate('2011-01-06')
        assert_all_restrictions([False, True, False], d2 - MINUTE)
        assert_all_restrictions([True, False, False], d2)
        assert_all_restrictions([True, False, False], d2 + MINUTE)

        # 01/07 XX:00 ASSET1: FROZEN; ASSET2: ALLOWED --> FROZEN
        d3 = rdate('2011-01-07')
        assert_all_restrictions([True, False, False], d3 - MINUTE)
        assert_all_restrictions([True, True, False], d3)
        assert_all_restrictions([True, True, False], d3 + MINUTE)

        # Should stay restricted for the rest of time
        assert_all_restrictions(
            [True, True, False],
            d3 + (MINUTE * 10000000)
        )
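
The per-asset check above iterates the grouped restrictions with iteritems and sorts each history by effective date. A minimal sketch of that step, using plain date strings instead of Restriction objects:

import six

# Hypothetical per-asset restriction histories (effective dates only),
# standing in for the groupby(lambda r: r.asset, ...) result above.
restrictions_by_asset = {
    'ASSET1': ['2011-01-06', '2011-01-04', '2011-01-05'],
    'ASSET2': ['2011-01-07', '2011-01-05', '2011-01-06'],
}

for asset, history in six.iteritems(restrictions_by_asset):
    freeze_dt, unfreeze_dt, re_freeze_dt = sorted(history)
    print(asset, freeze_dt, unfreeze_dt, re_freeze_dt)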

Example 45

Project: rabix
Source File: main.py
View license
def main():
    disable_warnings()
    logging.basicConfig(level=logging.WARN)
    if len(sys.argv) == 1:
        print(USAGE)
        return

    usage = USAGE.format(resources=make_resources_usage_string(),
                         inputs='<inputs>')
    app_usage = usage

    if len(sys.argv) == 2 and \
            (sys.argv[1] == '--help' or sys.argv[1] == '-h'):
        print(USAGE)
        return

    dry_run_args = dry_run_parse()
    if not dry_run_args:
        print(USAGE)
        return

    if not (dry_run_args['<tool>']):
        print('You have to specify a tool with the --tool option')
        print(usage)
        return

    tool = get_tool(dry_run_args)
    if not tool:
        fail("Couldn't find tool.")

    if isinstance(tool, list):
        tool = loader.index.get('#main')

    if 'class' not in tool:
        fail("Document must have a 'class' field")

    if 'id' not in tool:
        tool['id'] = dry_run_args['<tool>']

    context = init_context(tool)

    app = process_builder(context, tool)
    job = None

    if isinstance(app, Job):
        job = app
        app = job.app

    rabix.expressions.update_engines(app)

    if dry_run_args['--install']:
        app.install()
        print("Install successful.")
        return

    if dry_run_args['--conformance-test']:
        job_dict = from_url(dry_run_args['<job>'])
        conformance_test(context, app, job_dict, dry_run_args.get('--basedir'))
        return

    try:
        args = docopt.docopt(usage, version=version, help=False)
        job_dict = copy.deepcopy(TEMPLATE_JOB)
        logging.root.setLevel(log_level(dry_run_args['--verbose']))

        input_file_path = args.get('<inp>') or args.get('--inp-file')
        if input_file_path:
            basedir = os.path.dirname(os.path.abspath(input_file_path))
            input_file = from_url(input_file_path)
            inputs = get_inputs(input_file, app.inputs, basedir)
            job_dict['inputs'].update(inputs)

        input_usage = job_dict['inputs']

        if job:
            basedir = os.path.dirname(args.get('<tool>'))
            job.inputs = get_inputs(job.inputs, app.inputs, basedir)
            input_usage.update(job.inputs)

        app_inputs_usage = make_app_usage_string(
            app, template=TOOL_TEMPLATE, inp=input_usage)

        app_usage = make_app_usage_string(app, USAGE, job_dict['inputs'])

        try:
            app_inputs = docopt.docopt(app_inputs_usage, args['<inputs>'])
        except docopt.DocoptExit:
            if not job:
                raise
            for inp in job.app.inputs:
                if inp.required and inp.id not in job.inputs:
                    raise
            app_inputs = {}

        if args['--help']:
            print(app_usage)
            return
        # trim leading --, and ignore empty arrays
        app_inputs = {
            k[2:]: v
            for k, v in six.iteritems(app_inputs)
            if v != []
        }

        inp = get_inputs(app_inputs, app.inputs)
        if not job:
            job_dict['id'] = args.get('--outdir') or args.get('--dir')
            job_dict['app'] = app
            job = Job.from_dict(context, job_dict)

        job.inputs.update(inp)

        if args['--print-cli']:
            if not isinstance(app, CommandLineTool):
                fail(dry_run_args['<tool>'] + " is not a command line app")

            print(CLIJob(job).cmd_line())
            return

        if args['--pretty-print']:
            fmt = partial(result_str, job.id)
        else:
            fmt = lambda result: json.dumps(context.to_primitive(result))

        if not job.inputs and not args['--'] and not args['--quiet']:
            print(app_usage)
            return

        try:
            context.executor.execute(job, lambda _, result: print(fmt(result)))
        except RabixError as err:
            fail(err.message)

    except docopt.DocoptExit:
        fail(app_usage)
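
The comprehension over six.iteritems(app_inputs) strips the leading '--' from docopt option names and drops repeatable options that were never given. A minimal sketch with a made-up docopt result:

import six

# Hypothetical docopt output: option names keep their leading '--', and
# unset repeatable options come back as empty lists.
app_inputs = {'--reads': 'sample.fastq', '--threads': '4', '--tag': []}

cleaned = {
    k[2:]: v
    for k, v in six.iteritems(app_inputs)
    if v != []
}

print(cleaned)  # {'reads': 'sample.fastq', 'threads': '4'}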

Example 46

Project: ansible-kafka
Source File: rax.py
View license
def _list_into_cache(regions):
    groups = collections.defaultdict(list)
    hostvars = collections.defaultdict(dict)
    images = {}
    cbs_attachments = collections.defaultdict(dict)

    prefix = get_config(p, 'rax', 'meta_prefix', 'RAX_META_PREFIX', 'meta')

    networks = get_config(p, 'rax', 'access_network', 'RAX_ACCESS_NETWORK',
                          'public', islist=True)
    try:
        ip_versions = map(int, get_config(p, 'rax', 'access_ip_version',
                                          'RAX_ACCESS_IP_VERSION', 4,
                                          islist=True))
    except:
        ip_versions = [4]
    else:
        ip_versions = [v for v in ip_versions if v in [4, 6]]
        if not ip_versions:
            ip_versions = [4]

    # Go through all the regions looking for servers
    for region in regions:
        # Connect to the region
        cs = pyrax.connect_to_cloudservers(region=region)
        if cs is None:
            warnings.warn(
                'Connecting to Rackspace region "%s" has caused Pyrax to '
                'return None. Is this a valid region?' % region,
                RuntimeWarning)
            continue
        for server in cs.servers.list():
            # Create a group on region
            groups[region].append(server.name)

            # Check whether a 'group' key is present in the server's metadata
            group = server.metadata.get('group')
            if group:
                groups[group].append(server.name)

            for extra_group in server.metadata.get('groups', '').split(','):
                if extra_group:
                    groups[extra_group].append(server.name)

            # Add host metadata
            for key, value in to_dict(server).items():
                hostvars[server.name][key] = value

            hostvars[server.name]['rax_region'] = region

            for key, value in iteritems(server.metadata):
                groups['%s_%s_%s' % (prefix, key, value)].append(server.name)

            groups['instance-%s' % server.id].append(server.name)
            groups['flavor-%s' % server.flavor['id']].append(server.name)

            # Handle boot from volume
            if not server.image:
                if not cbs_attachments[region]:
                    cbs = pyrax.connect_to_cloud_blockstorage(region)
                    for vol in cbs.list():
                        if mk_boolean(vol.bootable):
                            for attachment in vol.attachments:
                                metadata = vol.volume_image_metadata
                                server_id = attachment['server_id']
                                cbs_attachments[region][server_id] = {
                                    'id': metadata['image_id'],
                                    'name': slugify(metadata['image_name'])
                                }
                image = cbs_attachments[region].get(server.id)
                if image:
                    server.image = {'id': image['id']}
                    hostvars[server.name]['rax_image'] = server.image
                    hostvars[server.name]['rax_boot_source'] = 'volume'
                    images[image['id']] = image['name']
            else:
                hostvars[server.name]['rax_boot_source'] = 'local'

            try:
                imagegroup = 'image-%s' % images[server.image['id']]
                groups[imagegroup].append(server.name)
                groups['image-%s' % server.image['id']].append(server.name)
            except KeyError:
                try:
                    image = cs.images.get(server.image['id'])
                except cs.exceptions.NotFound:
                    groups['image-%s' % server.image['id']].append(server.name)
                else:
                    images[image.id] = image.human_id
                    groups['image-%s' % image.human_id].append(server.name)
                    groups['image-%s' % server.image['id']].append(server.name)

            # And finally, add an IP address
            ansible_ssh_host = None
            # use accessIPv[46] instead of looping over addresses for 'public'
            for network_name in networks:
                if ansible_ssh_host:
                    break
                if network_name == 'public':
                    for version_name in ip_versions:
                        if ansible_ssh_host:
                            break
                        if version_name == 6 and server.accessIPv6:
                            ansible_ssh_host = server.accessIPv6
                        elif server.accessIPv4:
                            ansible_ssh_host = server.accessIPv4
                if not ansible_ssh_host:
                    addresses = server.addresses.get(network_name, [])
                    for address in addresses:
                        for version_name in ip_versions:
                            if ansible_ssh_host:
                                break
                            if address.get('version') == version_name:
                                ansible_ssh_host = address.get('addr')
                                break
            if ansible_ssh_host:
                hostvars[server.name]['ansible_ssh_host'] = ansible_ssh_host

    if hostvars:
        groups['_meta'] = {'hostvars': hostvars}

    with open(get_cache_file_path(regions), 'w') as cache_file:
        json.dump(groups, cache_file)
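
iteritems(server.metadata) is what turns every metadata key/value pair into an inventory group name. A minimal sketch of that grouping, with made-up metadata in place of a real Rackspace server:

import collections
import six

prefix = 'meta'
server_name = 'web01'
# Hypothetical server metadata, standing in for server.metadata above.
metadata = {'group': 'web', 'env': 'prod'}

groups = collections.defaultdict(list)
for key, value in six.iteritems(metadata):
    # One inventory group per metadata key/value pair, e.g. 'meta_env_prod'.
    groups['%s_%s_%s' % (prefix, key, value)].append(server_name)

print(sorted(groups))  # ['meta_env_prod', 'meta_group_web']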

Example 47

Project: bugwarrior
Source File: db.py
View license
def find_local_uuid(tw, keys, issue, legacy_matching=False):
    """ For a given issue issue, find its local UUID.

    Assembles a list of task IDs existing in taskwarrior
    matching the supplied issue (`issue`) on the combination of any
    set of supplied unique identifiers (`keys`) or, optionally,
    the task's description field (should `legacy_matching` be `True`).

    :params:
    * `tw`: An instance of `taskw.TaskWarriorShellout`
    * `keys`: A mapping from service name to the list of fields used to
      uniquely identify an issue for that service.  To clarify the
      structure, assume that there are two services, one having a single
      primary key field -- 'serviceAid' -- and another having a pair of
      fields composing its primary key -- 'serviceBproject' and
      'serviceBnumber' --; the incoming data for this argument would be::

        {
            'serviceA': ['serviceAid'],
            'serviceB': ['serviceBproject', 'serviceBnumber'],
        }

    * `issue`: An instance of a subclass of `bugwarrior.services.Issue`.
    * `legacy_matching`: By default, this is disabled, and it allows
      the matching algorithm to -- in addition to searching by stored
      issue keys -- search using the task's description for a match.
      It is prone to error and should be avoided if possible.

    :returns:
    * A single string UUID.

    :raises:
    * `bugwarrior.db.MultipleMatches`: if multiple matches were found.
    * `bugwarrior.db.NotFound`: if an issue was not found.

    """
    if not issue['description']:
        raise ValueError('Issue %s has no description.' % issue)

    possibilities = set([])

    if legacy_matching:
        legacy_description = issue.get_default_description().rsplit('..', 1)[0]
        # Furthermore, we have to kill off any single quotes which break in
        # task-2.4.x, as much as it saddens me.
        legacy_description = legacy_description.split("'")[0]
        results = tw.filter_tasks({
            'description.startswith': legacy_description,
            'or': [
                ('status', 'pending'),
                ('status', 'waiting'),
            ],
        })
        possibilities = possibilities | set([
            task['uuid'] for task in results
        ])

    for service, key_list in six.iteritems(keys):
        if any([key in issue for key in key_list]):
            results = tw.filter_tasks({
                'and': [("%s.is" % key, issue[key]) for key in key_list],
                'or': [
                    ('status', 'pending'),
                    ('status', 'waiting'),
                ],
            })
            possibilities = possibilities | set([
                task['uuid'] for task in results
            ])

    if len(possibilities) == 1:
        return possibilities.pop()

    if len(possibilities) > 1:
        raise MultipleMatches(
            "Issue %s matched multiple IDs: %s" % (
                issue['description'],
                possibilities
            )
        )

    raise NotFound(
        "No issue was found matching %s" % issue
    )
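
The loop over six.iteritems(keys) tries each service's unique-key fields in turn and only builds a query when the issue carries at least one of them. A minimal sketch of that matching step, with a made-up keys mapping and issue dict:

import six

# Hypothetical service -> unique-key-fields mapping and a matching issue.
keys = {
    'serviceA': ['serviceAid'],
    'serviceB': ['serviceBproject', 'serviceBnumber'],
}
issue = {'serviceBproject': 'infra', 'serviceBnumber': 42}

for service, key_list in six.iteritems(keys):
    if any(key in issue for key in key_list):
        # Build the filter terms that would be handed to tw.filter_tasks.
        terms = [('%s.is' % key, issue[key]) for key in key_list if key in issue]
        print(service, terms)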

Example 48

Project: attention-lvcsr
Source File: svhn.py
View license
@check_exists(required_files=FORMAT_1_FILES)
def convert_svhn_format_1(directory, output_directory,
                          output_filename='svhn_format_1.hdf5'):
    """Converts the SVHN dataset (format 1) to HDF5.

    This method assumes the existence of the files
    `{train,test,extra}.tar.gz`, which are accessible through the
    official website [SVHNSITE].

    .. [SVHNSITE] http://ufldl.stanford.edu/housenumbers/

    Parameters
    ----------
    directory : str
        Directory in which input files reside.
    output_directory : str
        Directory in which to save the converted dataset.
    output_filename : str, optional
        Name of the saved dataset. Defaults to 'svhn_format_1.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path to the converted dataset.

    """
    try:
        output_path = os.path.join(output_directory, output_filename)
        h5file = h5py.File(output_path, mode='w')
        TMPDIR = tempfile.mkdtemp()

        # Every image has three channels (RGB) and variable height and width.
        # It features a variable number of bounding boxes that identify the
        # location and label of digits. The bounding box location is specified
        # using the x and y coordinates of its top left corner along with its
        # width and height.
        BoundingBoxes = namedtuple(
            'BoundingBoxes', ['labels', 'heights', 'widths', 'lefts', 'tops'])
        sources = ('features',) + tuple('bbox_{}'.format(field)
                                        for field in BoundingBoxes._fields)
        source_dtypes = dict([(source, 'uint8') for source in sources[:2]] +
                             [(source, 'uint16') for source in sources[2:]])
        source_axis_labels = {
            'features': ('channel', 'height', 'width'),
            'bbox_labels': ('bounding_box', 'index'),
            'bbox_heights': ('bounding_box', 'height'),
            'bbox_widths': ('bounding_box', 'width'),
            'bbox_lefts': ('bounding_box', 'x'),
            'bbox_tops': ('bounding_box', 'y')}

        # The dataset is split into three sets: the training set, the test set
        # and an extra set of examples that are somewhat less difficult but
        # can be used as extra training data. These sets are stored separately
        # as 'train.tar.gz', 'test.tar.gz' and 'extra.tar.gz'. Each file
        # contains a directory named after the split it stores. The examples
        # are stored in that directory as PNG images. The directory also
        # contains a 'digitStruct.mat' file with all the bounding box and
        # label information.
        splits = ('train', 'test', 'extra')
        file_paths = dict(zip(splits, FORMAT_1_FILES))
        for split, path in file_paths.items():
            file_paths[split] = os.path.join(directory, path)
        digit_struct_paths = dict(
            [(split, os.path.join(TMPDIR, split, 'digitStruct.mat'))
             for split in splits])

        # We first extract the data files in a temporary directory. While doing
        # that, we also count the number of examples for each split. Files are
        # extracted individually, which lets us display a progress bar. Since
        # the splits will be concatenated in the HDF5 file, we also compute the
        # start and stop intervals of each split within the concatenated array.
        def extract_tar(split):
            with tarfile.open(file_paths[split], 'r:gz') as f:
                members = f.getmembers()
                num_examples = sum(1 for m in members if '.png' in m.name)
                progress_bar_context = progress_bar(
                    name='{} file'.format(split), maxval=len(members),
                    prefix='Extracting')
                with progress_bar_context as bar:
                    for i, member in enumerate(members):
                        f.extract(member, path=TMPDIR)
                        bar.update(i)
            return num_examples

        examples_per_split = OrderedDict(
            [(split, extract_tar(split)) for split in splits])
        cumulative_num_examples = numpy.cumsum(
            [0] + list(examples_per_split.values()))
        num_examples = cumulative_num_examples[-1]
        intervals = zip(cumulative_num_examples[:-1],
                        cumulative_num_examples[1:])
        split_intervals = dict(zip(splits, intervals))

        # The start and stop indices are used to create a split dict that will
        # be parsed into the split array required by the H5PYDataset interface.
        # The split dict is organized as follows:
        #
        #     dict(split -> dict(source -> (start, stop)))
        #
        split_dict = OrderedDict([
            (split, OrderedDict([(s, split_intervals[split])
                                 for s in sources]))
            for split in splits])
        h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict)

        # We then prepare the HDF5 dataset. This involves creating datasets to
        # store data sources and datasets to store auxiliary information
        # (namely the shapes for variable-length axes, and labels to indicate
        # what these variable-length axes represent).
        def make_vlen_dataset(source):
            # Create a variable-length 1D dataset
            dtype = h5py.special_dtype(vlen=numpy.dtype(source_dtypes[source]))
            dataset = h5file.create_dataset(
                source, (num_examples,), dtype=dtype)
            # Create a dataset to store variable-length shapes.
            axis_labels = source_axis_labels[source]
            dataset_shapes = h5file.create_dataset(
                '{}_shapes'.format(source), (num_examples, len(axis_labels)),
                dtype='uint16')
            # Create a dataset to store labels for variable-length axes.
            dataset_vlen_axis_labels = h5file.create_dataset(
                '{}_vlen_axis_labels'.format(source), (len(axis_labels),),
                dtype='S{}'.format(
                    numpy.max([len(label) for label in axis_labels])))
            # Fill variable-length axis labels
            dataset_vlen_axis_labels[...] = [
                label.encode('utf8') for label in axis_labels]
            # Attach auxiliary datasets as dimension scales of the
            # variable-length 1D dataset. This is in accordance with the
            # H5PYDataset interface.
            dataset.dims.create_scale(dataset_shapes, 'shapes')
            dataset.dims[0].attach_scale(dataset_shapes)
            dataset.dims.create_scale(dataset_vlen_axis_labels, 'shape_labels')
            dataset.dims[0].attach_scale(dataset_vlen_axis_labels)
            # Tag fixed-length axis with its label
            dataset.dims[0].label = 'batch'

        for source in sources:
            make_vlen_dataset(source)

        # The "fun" part begins: we extract the bounding box and label
        # information contained in 'digitStruct.mat'. This is a version 7.3
        # Matlab file, which uses HDF5 under the hood, albeit with a very
        # convoluted layout.
        def get_boxes(split):
            boxes = []
            with h5py.File(digit_struct_paths[split], 'r') as f:
                bar_name = '{} digitStruct'.format(split)
                bar_maxval = examples_per_split[split]
                with progress_bar(bar_name, bar_maxval) as bar:
                    for image_number in range(examples_per_split[split]):
                        # The 'digitStruct' group is the main group of the HDF5
                        # file. It contains two datasets: 'bbox' and 'name'.
                        # The 'name' dataset isn't of interest to us, as it
                        # stores file names and there's already a one-to-one
                        # mapping between row numbers and image names (e.g.
                        # row 0 corresponds to '1.png', row 1 corresponds to
                        # '2.png', and so on).
                        main_group = f['digitStruct']
                        # The 'bbox' dataset contains the bounding box and
                        # label information we're after. It has as many rows
                        # as there are images, and one column. Elements of the
                        # 'bbox' dataset are object references that point to
                        # (yet another) group that contains the information
                        # for the corresponding image.
                        image_reference = main_group['bbox'][image_number, 0]

                        # There are five datasets contained in that group:
                        # 'label', 'height', 'width', 'left' and 'top'. Each of
                        # those datasets has as many rows as there are bounding
                        # boxes in the corresponding image, and one column.
                        def get_dataset(name):
                            return main_group[image_reference][name][:, 0]
                        names = ('label', 'height', 'width', 'left', 'top')
                        datasets = dict(
                            [(name, get_dataset(name)) for name in names])

                        # If there is only one bounding box, the information is
                        # stored directly in the datasets. If there are
                        # multiple bounding boxes, elements of those datasets
                        # are object references pointing to 1x1 datasets that
                        # store the information (fortunately, it's the last
                        # hop we need to make).
                        def get_elements(dataset):
                            if len(dataset) > 1:
                                return [int(main_group[reference][0, 0])
                                        for reference in dataset]
                            else:
                                return [int(dataset[0])]
                        # Names are pluralized in the BoundingBoxes named tuple.
                        kwargs = dict(
                            [(name + 's', get_elements(dataset))
                             for name, dataset in iteritems(datasets)])
                        boxes.append(BoundingBoxes(**kwargs))
                        if bar:
                            bar.update(image_number)
            return boxes

        split_boxes = dict([(split, get_boxes(split)) for split in splits])

        # The final step is to fill the HDF5 file.
        def fill_split(split, bar=None):
            for image_number in range(examples_per_split[split]):
                image_path = os.path.join(
                    TMPDIR, split, '{}.png'.format(image_number + 1))
                image = numpy.asarray(
                    Image.open(image_path)).transpose(2, 0, 1)
                bounding_boxes = split_boxes[split][image_number]
                num_boxes = len(bounding_boxes.labels)
                index = image_number + split_intervals[split][0]

                h5file['features'][index] = image.flatten()
                h5file['features'].dims[0]['shapes'][index] = image.shape
                for field in BoundingBoxes._fields:
                    name = 'bbox_{}'.format(field)
                    h5file[name][index] = numpy.maximum(0,
                                                        getattr(bounding_boxes,
                                                                field))
                    h5file[name].dims[0]['shapes'][index] = [num_boxes, 1]

                # Replace label '10' with '0'.
                labels = h5file['bbox_labels'][index]
                labels[labels == 10] = 0
                h5file['bbox_labels'][index] = labels

                if image_number % 1000 == 0:
                    h5file.flush()
                if bar:
                    bar.update(index)

        with progress_bar('SVHN format 1', num_examples) as bar:
            for split in splits:
                fill_split(split, bar=bar)
    finally:
        if os.path.isdir(TMPDIR):
            shutil.rmtree(TMPDIR)
        h5file.flush()
        h5file.close()

    return (output_path,)
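
The kwargs construction near the end of this example is a common iteritems idiom: walk a dict once and build a new mapping with transformed keys. Below is a minimal, self-contained sketch of that pattern; the BoundingBoxes fields mirror the example, but the sample data is made up for illustration.

from collections import namedtuple

from six import iteritems

# Made-up stand-in for the per-image datasets pulled out of digitStruct.mat.
BoundingBoxes = namedtuple('BoundingBoxes',
                           ['labels', 'heights', 'widths', 'lefts', 'tops'])
datasets = {'label': [1, 9], 'height': [30, 28],
            'width': [12, 14], 'left': [40, 55], 'top': [7, 6]}

# Pluralize each key while iterating, as in the example above.
kwargs = dict((name + 's', values) for name, values in iteritems(datasets))
box = BoundingBoxes(**kwargs)
print(box.labels)  # [1, 9]

This works unchanged on Python 2 and 3 because six.iteritems dispatches to dict.iteritems or dict.items as appropriate.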

Example 49

Project: attention-lvcsr
Source File: profilemode.py
View license
    @staticmethod
    def print_summary_(fct_name, compile_time, fct_call_time, fct_call,
                       apply_time, apply_cimpl, message, variable_shape,
                       local_time, other_time,
                       n_apply_to_print=config.ProfileMode.n_apply_to_print,
                       n_ops_to_print=config.ProfileMode.n_ops_to_print,
                       print_apply=True,
                       min_memory_size=config.ProfileMode.min_memory_size,
                       ):
        """
        Do the actual printing of print_summary and print_diff_summary.

        Parameters
        ----------
        n_apply_to_print
            The number of applies to print. Default 15.
        n_ops_to_print
            The number of ops to print. Default 20.
        min_memory_size
            Don't print the memory profile of applies whose output memory
            size is lower than this.

        """

        print("ProfileMode is deprecated! Use the new profiler.")
        print(" The Theano flags to enable it ise: profile=True")
        print(" The Theano flags for the memory profile to it is: "
              "profile_memory=True")

        total_time = time.time() - import_time
        total_fct_time = sum(fct_call_time.values())
        total_fct_call = sum(fct_call.values())
        unknown_time = total_time - total_fct_time - compile_time
        overhead_time = total_fct_time - local_time
        if total_fct_time > 0:
            time_pr_in_fct = local_time / total_fct_time * 100
            overhead_time_pourcent_fct_time = (overhead_time / total_fct_time *
                                               100)
            time_per_call = total_fct_time / total_fct_call
        else:
            time_pr_in_fct = 0
            overhead_time_pourcent_fct_time = 0
            time_per_call = 0

        print()
        print('ProfileMode.%s(%s)' % (fct_name, message))
        print('---------------------------')
        print()
        print('Time since import %.3fs' % (total_time))
        print('Theano compile time: %.3fs (%.1f%% since import)' %
              (compile_time, compile_time / total_time * 100))
        print('    Optimization time: %.3fs' % (other_time['optimizer_time']))
        print('    Linker time: %.3fs' % (other_time['linker_time']))
        print('Theano fct call %.3fs (%.1f%% since import)' %
              (total_fct_time, total_fct_time / total_time * 100))
        print('   Theano Op time %.3fs %.1f%%(since import) %.1f%%'
              '(of fct call)' % (local_time, local_time / total_time * 100,
                                 time_pr_in_fct))
        print('   Theano function overhead in ProfileMode %.3fs %.1f%%'
              '(since import) %.1f%%(of fct call)' % (
                  overhead_time, overhead_time / total_time * 100,
                  overhead_time_pourcent_fct_time))
        print('%i Theano fct call, %.3fs per call' %
              (total_fct_call, time_per_call))
        print('Rest of the time since import %.3fs %.1f%%' %
              (unknown_time, unknown_time / total_time * 100))

        print()
        print('Theano fct summary:')
        print('<% total fct time> <total time> <time per call> <nb call> '
              '<fct name>')
        for key in fct_call:
            if fct_call[key] > 0:
                print('   %4.1f%% %.3fs %.2es %d %s' %
                      (fct_call_time[key] / total_fct_time * 100,
                       fct_call_time[key],
                       fct_call_time[key] / fct_call[key],
                       fct_call[key],
                       key.name))
            else:
                print('   NOT CALLED', key.name)

        # Compute stats per op.
        op_time = {}
        op_call = {}
        op_apply = {}
        op_cimpl = {}
        sop_apply = {}
        for (i, a), t in iteritems(apply_time):
            op = a.op
            op_time.setdefault(op, 0)
            op_call.setdefault(op, 0)
            op_apply.setdefault(op, 0)
            sop_apply.setdefault(type(a.op), 0)
            op_time[op] += t
            nb_call = [v for k, v in iteritems(fct_call)
                       if k.maker.fgraph is a.fgraph][0]
            op_cimpl.setdefault(a.op, True)
            op_cimpl[a.op] = op_cimpl[a.op] and apply_cimpl.get(a, False)
            if t == 0:
                assert nb_call == 0, nb_call
            else:
                op_call[op] += nb_call
                op_apply[op] += 1
                sop_apply[type(a.op)] += 1

        # Compute stats per op class
        sop_time = {}
        sop_call = {}
        sop_op = {}
        # map each op class to Bool. True iff all applies were done in c.
        sop_cimpl = {}
        for a, t in iteritems(op_time):
            typ = type(a)
            sop_time.setdefault(typ, 0)
            sop_time[typ] += t
            sop_op.setdefault(typ, 0)
            sop_op[typ] += 1
            sop_cimpl.setdefault(typ, True)
            sop_cimpl[typ] = sop_cimpl[typ] and op_cimpl.get(a, False)
            sop_call[typ] = sop_call.get(typ, 0) + op_call[a]

        # Print the summary per op class.
        print()
        print('Single Op-wise summary:')
        print('<% of local_time spent on this kind of Op> <cumulative %> '
              '<self seconds> <cumulative seconds> <time per call> [*] '
              '<nb_call> <nb_op> <nb_apply> <Op name>')
        sotimes = [(t * 100 / local_time, t, a, sop_cimpl[a], sop_call[a],
                    sop_op[a], sop_apply[a]) for a, t in iteritems(sop_time)]
        sotimes.sort()
        sotimes.reverse()
        tot = 0
        for f, t, a, ci, nb_call, nb_op, nb_apply in sotimes[:n_ops_to_print]:
            if nb_call == 0:
                assert t == 0
                continue
            tot += t
            ftot = tot * 100 / local_time
            if ci:
                msg = '*'
            else:
                msg = ' '
            print('   %4.1f%%  %5.1f%%  %5.3fs  %5.3fs  %.2es %s %5d %2d '
                  '%2d %s' % (f, ftot, t, tot, t / nb_call, msg, nb_call,
                              nb_op, nb_apply, a))
        print('   ... (remaining %i single Op account for %.2f%%(%.2fs) of '
              'the runtime)' %
              (max(0, len(sotimes) - n_ops_to_print),
               sum(soinfo[0] for soinfo in sotimes[n_ops_to_print:]),
               sum(soinfo[1] for soinfo in sotimes[n_ops_to_print:])))

        print('(*) Op is running a c implementation')

        # The summary per op
        op_flops = {}
        for a, t in iteritems(op_time):
            if hasattr(a, 'flops'):
                op_flops[a] = a.flops * op_call[a] / t / 1e6
        flops_msg = ''
        if op_flops:
            flops_msg = ' <MFlops/s>'
            print("\nHACK WARNING: we print the flops for some OP, but the "
                  "logic doesn't always work. You need to know the "
                  "internals of Theano to make it work correctly. "
                  "Otherwise don't use it!")
        print()
        print('Op-wise summary:')
        print('<%% of local_time spent on this kind of Op> <cumulative %%> '
              '<self seconds> <cumulative seconds> <time per call> [*] %s '
              '<nb_call> <nb apply> <Op name>' % (flops_msg))

        otimes = [(t * 100 / local_time, t, a, op_cimpl.get(a, 0),
                   op_call.get(a, 0), op_apply.get(a, 0))
                  for a, t in iteritems(op_time)]
        otimes.sort()
        otimes.reverse()
        tot = 0
        for f, t, a, ci, nb_call, nb_apply in otimes[:n_ops_to_print]:
            if nb_call == 0:
                assert t == 0
                continue
            tot += t
            ftot = tot * 100 / local_time
            if ci:
                msg = '*'
            else:
                msg = ' '
            if op_flops:
                print('   %4.1f%%  %5.1f%%  %5.3fs  %5.3fs  %.2es %s %7.1f '
                      '%5d %2d %s' % (f, ftot, t, tot, t / nb_call, msg,
                                      op_flops.get(a, -1), nb_call, nb_apply,
                                      a))
            else:
                print('   %4.1f%%  %5.1f%%  %5.3fs  %5.3fs  %.2es %s %5d %2d '
                      '%s' % (f, ftot, t, tot, t / nb_call, msg, nb_call,
                              nb_apply, a))
        print('   ... (remaining %i Op account for %6.2f%%(%.2fs) of the '
              'runtime)' %
              (max(0, len(otimes) - n_ops_to_print),
               sum(f for f, t, a, ci, nb_call, nb_op in
                   otimes[n_ops_to_print:]),
               sum(t for f, t, a, ci, nb_call, nb_op in
                   otimes[n_ops_to_print:])))
        print('(*) Op is running a c implementation')

        if print_apply:
            print()
            print('Apply-wise summary:')
            print('<% of local_time spent at this position> <cumulative %> '
                  '<apply time> <cumulative seconds> <time per call> [*] '
                  '<nb_call> <Apply position> <Apply Op name>')
            atimes = [(t * 100 / local_time, t, a,
                       [v for k, v in iteritems(fct_call)
                        if k.maker.fgraph is a[1].fgraph][0])
                      for a, t in iteritems(apply_time)]
            atimes.sort()
            atimes.reverse()
            tot = 0
            for f, t, a, nb_call in atimes[:n_apply_to_print]:
                tot += t
                ftot = tot * 100 / local_time
                if nb_call == 0:
                    continue
                if apply_cimpl.get(a[1], False):
                    msg = '*'
                else:
                    msg = ' '
                print('   %4.1f%%  %5.1f%%  %5.3fs  %5.3fs %.2es  %s %i  '
                      '%2i %s' %
                      (f, ftot, t, tot, t / nb_call, msg, nb_call, a[0],
                       str(a[1])))
            print('   ... (remaining %i Apply instances account for '
                  '%.2f%%(%.2fs) of the runtime)' %
                  (max(0, len(atimes) - n_apply_to_print),
                   sum(f for f, t, a, nb_call in atimes[n_apply_to_print:]),
                   sum(t for f, t, a, nb_call in atimes[n_apply_to_print:])))
            print('(*) Op is running a c implementation')
        for printer in profiler_printers:
            printer(fct_name, compile_time, fct_call_time, fct_call,
                    apply_time, apply_cimpl, message, variable_shape,
                    other_time)

        if not variable_shape:
            print("\nProfile of Theano intermediate memory disabled. "
                  "To enable, set the Theano flag ProfileMode.profile_memory "
                  "to True.")
        else:
            print("""
            The memory profile in ProfileMode is removed!
            Use the new profiler. Use the Theano flags
            profile=True,profile_memory=True to enable it.""")

        print()
        print("""Here are tips to potentially make your code run faster
(if you think of new ones, suggest them on the mailing list).
Test them first, as they are not guaranteed to always provide a speedup.""")
        from theano import tensor as T
        from theano.tensor.raw_random import RandomFunction
        import theano
        import theano.scalar as scal
        scalar_op_amdlibm_no_speed_up = [scal.LT, scal.GT, scal.LE, scal.GE,
                                         scal.EQ, scal.NEQ, scal.InRange,
                                         scal.Switch, scal.OR, scal.XOR,
                                         scal.AND, scal.Invert, scal.Maximum,
                                         scal.Minimum, scal.Add, scal.Mul,
                                         scal.Sub, scal.TrueDiv, scal.IntDiv,
                                         scal.Clip, scal.Second, scal.Identity,
                                         scal.Cast, scal.Sgn, scal.Neg,
                                         scal.Inv, scal.Sqr]
        scalar_op_amdlibm_speed_up = [scal.Mod, scal.Pow, scal.Ceil,
                                      scal.Floor, scal.RoundHalfToEven,
                                      scal.RoundHalfAwayFromZero, scal.Log,
                                      scal.Log2, scal.Log10, scal.Log1p,
                                      scal.Exp, scal.Sqrt, scal.Abs, scal.Cos,
                                      scal.Sin, scal.Tan, scal.Tanh,
                                      scal.Cosh, scal.Sinh,
                                      T.nnet.sigm.ScalarSigmoid,
                                      T.nnet.sigm.ScalarSoftplus]

        def get_scalar_ops(s):
            if isinstance(s, theano.scalar.Composite):
                l = []
                for node in s.fgraph.toposort():
                    l += get_scalar_ops(node.op)
                return l
            else:
                return [s]

        def list_scalar_op(op):
            if isinstance(op.scalar_op, theano.scalar.Composite):
                return get_scalar_ops(op.scalar_op)
            else:
                return [op.scalar_op]

        def amdlibm_speed_up(op):
            if not isinstance(op, T.Elemwise):
                return False
            else:
                l = list_scalar_op(op)
                for s_op in l:
                    if s_op.__class__ in scalar_op_amdlibm_speed_up:
                        return True
                    elif s_op.__class__ not in scalar_op_amdlibm_no_speed_up:
                        print("We don't know if amdlibm will accelerate "
                              "this scalar op.", s_op)
                return False

        def exp_float32_op(op):
            if not isinstance(op, T.Elemwise):
                return False
            else:
                l = list_scalar_op(op)
                return any([s_op.__class__ in [scal.Exp] for s_op in l])

        printed_tip = False
        # tip 1
        if config.floatX == 'float64':
            print("  - Try the Theano flag floatX=float32")
            printed_tip = True

        # tip 2
        if not config.lib.amdlibm and any([amdlibm_speed_up(a.op) for i, a
                                           in apply_time]):
            print("  - Try installing amdlibm and set the Theano flag "
                  "lib.amdlibm=True. This speeds up only some Elemwise "
                  "operation.")
            printed_tip = True

        # tip 3
        if not config.lib.amdlibm and any([exp_float32_op(a.op) and
                                           a.inputs[0].dtype == 'float32'
                                           for i, a in apply_time]):
            print("  - With the default gcc libm, exp in float32 is slower "
                  "than in float64! Try Theano flag floatX=float64, or "
                  "install amdlibm and set the theano flags lib.amdlibm=True")
            printed_tip = True

        # tip 4
        for a, t in iteritems(apply_time):
            node = a[1]
            if (isinstance(node.op, T.Dot) and
                    all([len(i.type.broadcastable) == 2
                         for i in node.inputs])):
                print("  - You have a dot operation that was not optimized to"
                      " dot22 (which is faster). Make sure the inputs are "
                      "float32 or float64, and are the same for both inputs. "
                      "Currently they are: %s" %
                      [i.type for i in node.inputs])
                printed_tip = True

        # tip 5
        for a, t in iteritems(apply_time):
            node = a[1]
            if isinstance(node.op, RandomFunction):
                printed_tip = True
                print("  - Replace the default random number generator by "
                      "'from theano.sandbox.rng_mrg import MRG_RandomStreams "
                      "as RandomStreams', as this is is faster. It is still "
                      "experimental, but seems to work correctly.")
                if config.device.startswith("gpu"):
                    print("     - MRG_RandomStreams is the only random number"
                          " generator supported on the GPU.")
                break

        if not printed_tip:
            print("  Sorry, no tip for today.")

Example 50

Project: attention-lvcsr
Source File: formatting.py
View license
    def __call__(self, fct, graph=None):
        """Create pydot graph from function.

        Parameters
        ----------
        fct : theano.compile.function_module.Function
            A compiled Theano function, variable, apply or a list of variables.
        graph: pydot.Dot
            `pydot` graph to which nodes are added. Creates new one if
            undefined.

        Returns
        -------
        pydot.Dot
            Pydot graph of `fct`
        """
        if graph is None:
            graph = pd.Dot()

        self.__nodes = {}

        profile = None
        if isinstance(fct, Function):
            mode = fct.maker.mode
            if (not isinstance(mode, ProfileMode) or
                    fct not in mode.profile_stats):
                mode = None
            if mode:
                profile = mode.profile_stats[fct]
            else:
                profile = getattr(fct, "profile", None)
            outputs = fct.maker.fgraph.outputs
            topo = fct.maker.fgraph.toposort()
        elif isinstance(fct, gof.FunctionGraph):
            outputs = fct.outputs
            topo = fct.toposort()
        else:
            if isinstance(fct, gof.Variable):
                fct = [fct]
            elif isinstance(fct, gof.Apply):
                fct = fct.outputs
            assert isinstance(fct, (list, tuple))
            assert all(isinstance(v, gof.Variable) for v in fct)
            fct = gof.FunctionGraph(inputs=gof.graph.inputs(fct),
                                    outputs=fct)
            outputs = fct.outputs
            topo = fct.toposort()
        outputs = list(outputs)

        # Loop over apply nodes
        for node in topo:
            nparams = {}
            __node_id = self.__node_id(node)
            nparams['name'] = __node_id
            nparams['label'] = apply_label(node)
            nparams['profile'] = apply_profile(node, profile)
            nparams['node_type'] = 'apply'
            nparams['apply_op'] = nparams['label']
            nparams['shape'] = self.shapes['apply']

            use_color = None
            for opName, color in iteritems(self.apply_colors):
                if opName in node.op.__class__.__name__:
                    use_color = color
            if use_color:
                nparams['style'] = 'filled'
                nparams['fillcolor'] = use_color
                nparams['type'] = 'colored'

            pd_node = dict_to_pdnode(nparams)
            graph.add_node(pd_node)

            # Loop over input nodes
            for id, var in enumerate(node.inputs):
                var_id = self.__node_id(var.owner if var.owner else var)
                if var.owner is None:
                    vparams = {'name': var_id,
                               'label': var_label(var),
                               'node_type': 'input'}
                    if isinstance(var, gof.Constant):
                        vparams['node_type'] = 'constant_input'
                    elif isinstance(var, theano.tensor.sharedvar.
                                    TensorSharedVariable):
                        vparams['node_type'] = 'shared_input'
                    vparams['dtype'] = type_to_str(var.type)
                    vparams['tag'] = var_tag(var)
                    vparams['style'] = 'filled'
                    vparams['fillcolor'] = self.node_colors[
                        vparams['node_type']]
                    vparams['shape'] = self.shapes['input']
                    pd_var = dict_to_pdnode(vparams)
                    graph.add_node(pd_var)

                edge_params = {}
                if hasattr(node.op, 'view_map') and \
                        id in reduce(list.__add__,
                                     itervalues(node.op.view_map), []):
                    edge_params['color'] = self.node_colors['output']
                elif hasattr(node.op, 'destroy_map') and \
                        id in reduce(list.__add__,
                                     itervalues(node.op.destroy_map), []):
                    edge_params['color'] = 'red'

                edge_label = vparams['dtype']
                if len(node.inputs) > 1:
                    edge_label = str(id) + ' ' + edge_label
                pdedge = pd.Edge(var_id, __node_id, label=edge_label,
                                 **edge_params)
                graph.add_edge(pdedge)

            # Loop over output nodes
            for id, var in enumerate(node.outputs):
                var_id = self.__node_id(var)

                if var in outputs or len(var.clients) == 0:
                    vparams = {'name': var_id,
                               'label': var_label(var),
                               'node_type': 'output',
                               'dtype': type_to_str(var.type),
                               'tag': var_tag(var),
                               'style': 'filled'}
                    if len(var.clients) == 0:
                        vparams['fillcolor'] = self.node_colors['unused']
                    else:
                        vparams['fillcolor'] = self.node_colors['output']
                    vparams['shape'] = self.shapes['output']
                    pd_var = dict_to_pdnode(vparams)
                    graph.add_node(pd_var)

                    graph.add_edge(pd.Edge(__node_id, var_id,
                                           label=vparams['dtype']))
                elif var.name or not self.compact:
                    graph.add_edge(pd.Edge(__node_id, var_id,
                                           label=vparams['dtype']))

            # Create sub-graph for OpFromGraph nodes
            if isinstance(node.op, builders.OpFromGraph):
                subgraph = pd.Cluster(__node_id)
                gf = PyDotFormatter()
                # Use different node prefix for sub-graphs
                gf.__node_prefix = __node_id
                gf(node.op.fn, subgraph)
                graph.add_subgraph(subgraph)
                pd_node.get_attributes()['subg'] = subgraph.get_name()

                def format_map(m):
                    return str([list(x) for x in m])

                # Inputs mapping
                ext_inputs = [self.__node_id(x) for x in node.inputs]
                int_inputs = [gf.__node_id(x)
                              for x in node.op.fn.maker.fgraph.inputs]
                assert len(ext_inputs) == len(int_inputs)
                h = format_map(zip(ext_inputs, int_inputs))
                pd_node.get_attributes()['subg_map_inputs'] = h

                # Outputs mapping
                ext_outputs = []
                for n in topo:
                    for i in n.inputs:
                        h = i.owner if i.owner else i
                        if h is node:
                            ext_outputs.append(self.__node_id(n))
                int_outputs = node.op.fn.maker.fgraph.outputs
                int_outputs = [gf.__node_id(x) for x in int_outputs]
                assert len(ext_outputs) == len(int_outputs)
                h = format_map(zip(int_outputs, ext_outputs))
                pd_node.get_attributes()['subg_map_outputs'] = h

        return graph
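
The apply_colors loop above is another typical iteritems use: scan a small name-to-value mapping and remember the last match. A short sketch of the same lookup is given below; the color table is assumed for illustration and is not Theano's actual default.

from six import iteritems

# Assumed, illustrative mapping from Op class-name fragments to fill colors.
apply_colors = {'GpuFromHost': 'red', 'HostFromGpu': 'red',
                'Scan': 'yellow', 'Shape': 'brown'}

def pick_color(op_class_name, colors=apply_colors):
    """Return the last color whose key appears in the Op class name."""
    use_color = None
    for op_name, color in iteritems(colors):
        if op_name in op_class_name:
            use_color = color
    return use_color

print(pick_color('GpuFromHost'))    # 'red'
print(pick_color('Elemwise{add}'))  # None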