functools.partial

Here are examples of the Python API functools.partial taken from open source projects. By voting up, you can indicate which examples are most useful and appropriate.

200 Examples
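
functools.partial wraps a callable and pre-fills some of its positional and keyword arguments, returning a new callable that takes the rest. A minimal sketch before the examples:

import functools

def power(base, exponent):
    return base ** exponent

# partial freezes `exponent`, leaving a one-argument callable behind.
square = functools.partial(power, exponent=2)
print(square(5))  # -> 25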

Example 1

Project: raspberry_pwn
Source File: target.py
def _setRequestParams():
    """
    Check and set the parameters and perform checks on 'data' option for
    HTTP method POST.
    """

    if conf.direct:
        conf.parameters[None] = "direct connection"
        return

    testableParameters = False

    # Perform checks on GET parameters
    if conf.parameters.get(PLACE.GET):
        parameters = conf.parameters[PLACE.GET]
        paramDict = paramToDict(PLACE.GET, parameters)

        if paramDict:
            conf.paramDict[PLACE.GET] = paramDict
            testableParameters = True

    # Perform checks on POST parameters
    if conf.method == HTTPMETHOD.POST and conf.data is None:
        errMsg = "HTTP POST method depends on HTTP data value to be posted"
        raise SqlmapSyntaxException(errMsg)

    if conf.data is not None:
        conf.method = HTTPMETHOD.POST if not conf.method or conf.method == HTTPMETHOD.GET else conf.method

        def process(match, repl):
            retVal = match.group(0)

            if not (conf.testParameter and match.group("name") not in conf.testParameter):
                retVal = repl
                while True:
                    _ = re.search(r"\\g<([^>]+)>", retVal)
                    if _:
                        retVal = retVal.replace(_.group(0), match.group(int(_.group(1)) if _.group(1).isdigit() else _.group(1)))
                    else:
                        break

            return retVal

        if kb.processUserMarks is None and CUSTOM_INJECTION_MARK_CHAR in conf.data:
            message = "custom injection marking character ('%s') found in option " % CUSTOM_INJECTION_MARK_CHAR
            message += "'--data'. Do you want to process it? [Y/n/q] "
            test = readInput(message, default="Y")
            if test and test[0] in ("q", "Q"):
                raise SqlmapUserQuitException
            else:
                kb.processUserMarks = not test or test[0] not in ("n", "N")

                if kb.processUserMarks and "=%s" % CUSTOM_INJECTION_MARK_CHAR in conf.data:
                    warnMsg = "it seems that you've provided empty parameter value(s) "
                    warnMsg += "for testing. Please, always use only valid parameter values "
                    warnMsg += "so sqlmap could be able to run properly"
                    logger.warn(warnMsg)

        if not (kb.processUserMarks and CUSTOM_INJECTION_MARK_CHAR in conf.data):
            if re.search(JSON_RECOGNITION_REGEX, conf.data):
                message = "JSON data found in %s data. " % conf.method
                message += "Do you want to process it? [Y/n/q] "
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                elif test[0] not in ("n", "N"):
                    conf.data = conf.data.replace(CUSTOM_INJECTION_MARK_CHAR, ASTERISK_MARKER)
                    conf.data = re.sub(r'("(?P<name>[^"]+)"\s*:\s*"[^"]+)"', functools.partial(process, repl=r'\g<1>%s"' % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    conf.data = re.sub(r'("(?P<name>[^"]+)"\s*:\s*)(-?\d[\d\.]*\b)', functools.partial(process, repl=r'\g<0>%s' % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    match = re.search(r'(?P<name>[^"]+)"\s*:\s*\[([^\]]+)\]', conf.data)
                    if match and not (conf.testParameter and match.group("name") not in conf.testParameter):
                        _ = match.group(2)
                        _ = re.sub(r'("[^"]+)"', '\g<1>%s"' % CUSTOM_INJECTION_MARK_CHAR, _)
                        _ = re.sub(r'(\A|,|\s+)(-?\d[\d\.]*\b)', '\g<0>%s' % CUSTOM_INJECTION_MARK_CHAR, _)
                        conf.data = conf.data.replace(match.group(0), match.group(0).replace(match.group(2), _))
                    kb.postHint = POST_HINT.JSON

            elif re.search(JSON_LIKE_RECOGNITION_REGEX, conf.data):
                message = "JSON-like data found in %s data. " % conf.method
                message += "Do you want to process it? [Y/n/q] "
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                elif test[0] not in ("n", "N"):
                    conf.data = conf.data.replace(CUSTOM_INJECTION_MARK_CHAR, ASTERISK_MARKER)
                    conf.data = re.sub(r"('(?P<name>[^']+)'\s*:\s*'[^']+)'", functools.partial(process, repl=r"\g<1>%s'" % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    conf.data = re.sub(r"('(?P<name>[^']+)'\s*:\s*)(-?\d[\d\.]*\b)", functools.partial(process, repl=r"\g<0>%s" % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    kb.postHint = POST_HINT.JSON_LIKE

            elif re.search(ARRAY_LIKE_RECOGNITION_REGEX, conf.data):
                message = "Array-like data found in %s data. " % conf.method
                message += "Do you want to process it? [Y/n/q] "
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                elif test[0] not in ("n", "N"):
                    conf.data = conf.data.replace(CUSTOM_INJECTION_MARK_CHAR, ASTERISK_MARKER)
                    conf.data = re.sub(r"(=[^%s]+)" % DEFAULT_GET_POST_DELIMITER, r"\g<1>%s" % CUSTOM_INJECTION_MARK_CHAR, conf.data)
                    kb.postHint = POST_HINT.ARRAY_LIKE

            elif re.search(XML_RECOGNITION_REGEX, conf.data):
                message = "SOAP/XML data found in %s data. " % conf.method
                message += "Do you want to process it? [Y/n/q] "
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                elif test[0] not in ("n", "N"):
                    conf.data = conf.data.replace(CUSTOM_INJECTION_MARK_CHAR, ASTERISK_MARKER)
                    conf.data = re.sub(r"(<(?P<name>[^>]+)( [^<]*)?>)([^<]+)(</\2)", functools.partial(process, repl=r"\g<1>\g<4>%s\g<5>" % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    kb.postHint = POST_HINT.SOAP if "soap" in conf.data.lower() else POST_HINT.XML

            elif re.search(MULTIPART_RECOGNITION_REGEX, conf.data):
                message = "Multipart like data found in %s data. " % conf.method
                message += "Do you want to process it? [Y/n/q] "
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                elif test[0] not in ("n", "N"):
                    conf.data = conf.data.replace(CUSTOM_INJECTION_MARK_CHAR, ASTERISK_MARKER)
                    conf.data = re.sub(r"(?si)((Content-Disposition[^\n]+?name\s*=\s*[\"'](?P<name>[^\n]+?)[\"']).+?)(((\r)?\n)+--)", functools.partial(process, repl=r"\g<1>%s\g<4>" % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    kb.postHint = POST_HINT.MULTIPART

        if not kb.postHint:
            if CUSTOM_INJECTION_MARK_CHAR in conf.data:  # later processed
                pass
            else:
                place = PLACE.POST

                conf.parameters[place] = conf.data
                paramDict = paramToDict(place, conf.data)

                if paramDict:
                    conf.paramDict[place] = paramDict
                    testableParameters = True
        else:
            if CUSTOM_INJECTION_MARK_CHAR not in conf.data:  # in case no usable parameter values have been found
                conf.parameters[PLACE.POST] = conf.data

    kb.processUserMarks = True if (kb.postHint and CUSTOM_INJECTION_MARK_CHAR in conf.data) else kb.processUserMarks

    if re.search(URI_INJECTABLE_REGEX, conf.url, re.I) and not any(place in conf.parameters for place in (PLACE.GET, PLACE.POST)) and not kb.postHint and not CUSTOM_INJECTION_MARK_CHAR in (conf.data or ""):
        warnMsg = "you've provided target URL without any GET "
        warnMsg += "parameters (e.g. www.site.com/article.php?id=1) "
        warnMsg += "and without providing any POST parameters "
        warnMsg += "through --data option"
        logger.warn(warnMsg)

        message = "do you want to try URI injections "
        message += "in the target URL itself? [Y/n/q] "
        test = readInput(message, default="Y")

        if not test or test[0] not in ("n", "N"):
            conf.url = "%s%s" % (conf.url, CUSTOM_INJECTION_MARK_CHAR)
            kb.processUserMarks = True
        elif test[0] in ("q", "Q"):
            raise SqlmapUserQuitException

    for place, value in ((PLACE.URI, conf.url), (PLACE.CUSTOM_POST, conf.data), (PLACE.CUSTOM_HEADER, str(conf.httpHeaders))):
        _ = re.sub(PROBLEMATIC_CUSTOM_INJECTION_PATTERNS, "", value or "") if place == PLACE.CUSTOM_HEADER else value or ""
        if CUSTOM_INJECTION_MARK_CHAR in _:
            if kb.processUserMarks is None:
                lut = {PLACE.URI: '-u', PLACE.CUSTOM_POST: '--data', PLACE.CUSTOM_HEADER: '--headers/--user-agent/--referer/--cookie'}
                message = "custom injection marking character ('%s') found in option " % CUSTOM_INJECTION_MARK_CHAR
                message += "'%s'. Do you want to process it? [Y/n/q] " % lut[place]
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                else:
                    kb.processUserMarks = not test or test[0] not in ("n", "N")

                    if kb.processUserMarks and "=%s" % CUSTOM_INJECTION_MARK_CHAR in _:
                        warnMsg = "it seems that you've provided empty parameter value(s) "
                        warnMsg += "for testing. Please, always use only valid parameter values "
                        warnMsg += "so sqlmap could be able to run properly"
                        logger.warn(warnMsg)

            if not kb.processUserMarks:
                if place == PLACE.URI:
                    query = urlparse.urlsplit(value).query
                    if query:
                        parameters = conf.parameters[PLACE.GET] = query
                        paramDict = paramToDict(PLACE.GET, parameters)

                        if paramDict:
                            conf.url = conf.url.split('?')[0]
                            conf.paramDict[PLACE.GET] = paramDict
                            testableParameters = True
                elif place == PLACE.CUSTOM_POST:
                    conf.parameters[PLACE.POST] = conf.data
                    paramDict = paramToDict(PLACE.POST, conf.data)

                    if paramDict:
                        conf.paramDict[PLACE.POST] = paramDict
                        testableParameters = True

            else:
                conf.parameters[place] = value
                conf.paramDict[place] = OrderedDict()

                if place == PLACE.CUSTOM_HEADER:
                    for index in xrange(len(conf.httpHeaders)):
                        header, value = conf.httpHeaders[index]
                        if CUSTOM_INJECTION_MARK_CHAR in re.sub(PROBLEMATIC_CUSTOM_INJECTION_PATTERNS, "", value):
                            parts = value.split(CUSTOM_INJECTION_MARK_CHAR)
                            for i in xrange(len(parts) - 1):
                                conf.paramDict[place]["%s #%d%s" % (header, i + 1, CUSTOM_INJECTION_MARK_CHAR)] = "%s,%s" % (header, "".join("%s%s" % (parts[j], CUSTOM_INJECTION_MARK_CHAR if i == j else "") for j in xrange(len(parts))))
                            conf.httpHeaders[index] = (header, value.replace(CUSTOM_INJECTION_MARK_CHAR, ""))
                else:
                    parts = value.split(CUSTOM_INJECTION_MARK_CHAR)

                    for i in xrange(len(parts) - 1):
                        conf.paramDict[place]["%s#%d%s" % (("%s " % kb.postHint) if kb.postHint else "", i + 1, CUSTOM_INJECTION_MARK_CHAR)] = "".join("%s%s" % (parts[j], CUSTOM_INJECTION_MARK_CHAR if i == j else "") for j in xrange(len(parts)))

                    if place == PLACE.URI and PLACE.GET in conf.paramDict:
                        del conf.paramDict[PLACE.GET]
                    elif place == PLACE.CUSTOM_POST and PLACE.POST in conf.paramDict:
                        del conf.paramDict[PLACE.POST]

                testableParameters = True

    if kb.processUserMarks:
        for item in ("url", "data", "agent", "referer", "cookie"):
            if conf.get(item):
                conf[item] = conf[item].replace(CUSTOM_INJECTION_MARK_CHAR, "")

    # Perform checks on Cookie parameters
    if conf.cookie:
        conf.parameters[PLACE.COOKIE] = conf.cookie
        paramDict = paramToDict(PLACE.COOKIE, conf.cookie)

        if paramDict:
            conf.paramDict[PLACE.COOKIE] = paramDict
            testableParameters = True

    # Perform checks on header values
    if conf.httpHeaders:
        for httpHeader, headerValue in conf.httpHeaders:
            # URL encoding of the header values should be avoided
            # Reference: http://stackoverflow.com/questions/5085904/is-ok-to-urlencode-the-value-in-headerlocation-value

            httpHeader = httpHeader.title()

            if httpHeader == HTTP_HEADER.USER_AGENT:
                conf.parameters[PLACE.USER_AGENT] = urldecode(headerValue)

                condition = any((not conf.testParameter, intersect(conf.testParameter, USER_AGENT_ALIASES)))

                if condition:
                    conf.paramDict[PLACE.USER_AGENT] = {PLACE.USER_AGENT: headerValue}
                    testableParameters = True

            elif httpHeader == HTTP_HEADER.REFERER:
                conf.parameters[PLACE.REFERER] = urldecode(headerValue)

                condition = any((not conf.testParameter, intersect(conf.testParameter, REFERER_ALIASES)))

                if condition:
                    conf.paramDict[PLACE.REFERER] = {PLACE.REFERER: headerValue}
                    testableParameters = True

            elif httpHeader == HTTP_HEADER.HOST:
                conf.parameters[PLACE.HOST] = urldecode(headerValue)

                condition = any((not conf.testParameter, intersect(conf.testParameter, HOST_ALIASES)))

                if condition:
                    conf.paramDict[PLACE.HOST] = {PLACE.HOST: headerValue}
                    testableParameters = True

    if not conf.parameters:
        errMsg = "you did not provide any GET, POST and Cookie "
        errMsg += "parameter, neither an User-Agent, Referer or Host header value"
        raise SqlmapGenericException(errMsg)

    elif not testableParameters:
        errMsg = "all testable parameters you provided are not present "
        errMsg += "within the given request data"
        raise SqlmapGenericException(errMsg)

    if conf.csrfToken:
        if not any(conf.csrfToken in _ for _ in (conf.paramDict.get(PLACE.GET, {}), conf.paramDict.get(PLACE.POST, {}))) and not conf.csrfToken in set(_[0].lower() for _ in conf.httpHeaders) and not conf.csrfToken in conf.paramDict.get(PLACE.COOKIE, {}):
            errMsg = "CSRF protection token parameter '%s' not " % conf.csrfToken
            errMsg += "found in provided GET, POST, Cookie or header values"
            raise SqlmapGenericException(errMsg)
    else:
        for place in (PLACE.GET, PLACE.POST, PLACE.COOKIE):
            for parameter in conf.paramDict.get(place, {}):
                if any(parameter.lower().count(_) for _ in CSRF_TOKEN_PARAMETER_INFIXES):
                    message = "%s parameter '%s' appears to hold CSRF protection token. " % (place, parameter)
                    message += "Do you want sqlmap to automatically update it in further requests? [y/N] "
                    test = readInput(message, default="N")
                    if test and test[0] in ("y", "Y"):
                        conf.csrfToken = parameter
                    break
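
Example 1 passes functools.partial(process, repl=...) as the replacement argument of re.sub: re.sub accepts a callable that receives the match object, and partial pre-binds the extra repl keyword that process expects. A stripped-down sketch of the same pattern (the tag helper and the sample data are made up for illustration):

import functools
import re

def tag(match, marker):
    # re.sub calls this with only the match; `marker` was pre-bound.
    return match.group(0) + marker

data = "id=1&name=foo"
print(re.sub(r"=[^&]+", functools.partial(tag, marker="*"), data))
# -> id=1*&name=foo*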

Example 2

Project: cgstudiomap
Source File: writer.py
    @classmethod
    def serialize(cls, value):
        """
        Serializes the value to a string that's parsable by Python, along
        with any needed imports to make that string work.
        More advanced than repr() as it can encode things
        like datetime.datetime.now.
        """
        # FIXME: Ideally Promise would be reconstructible, but for now we
        # use force_text on them and defer to the normal string serialization
        # process.
        if isinstance(value, Promise):
            value = force_text(value)
        elif isinstance(value, LazyObject):
            # The unwrapped value is returned as the first item of the
            # arguments tuple.
            value = value.__reduce__()[1][0]

        # Sequences
        if isinstance(value, (frozenset, list, set, tuple)):
            imports = set()
            strings = []
            for item in value:
                item_string, item_imports = cls.serialize(item)
                imports.update(item_imports)
                strings.append(item_string)
            if isinstance(value, set):
                # Don't use the literal "{%s}" as it doesn't support empty set
                format = "set([%s])"
            elif isinstance(value, frozenset):
                format = "frozenset([%s])"
            elif isinstance(value, tuple):
                # When len(value)==0, the empty tuple should be serialized as
                # "()", not "(,)" because (,) is invalid Python syntax.
                format = "(%s)" if len(value) != 1 else "(%s,)"
            else:
                format = "[%s]"
            return format % (", ".join(strings)), imports
        # Dictionaries
        elif isinstance(value, dict):
            imports = set()
            strings = []
            for k, v in sorted(value.items()):
                k_string, k_imports = cls.serialize(k)
                v_string, v_imports = cls.serialize(v)
                imports.update(k_imports)
                imports.update(v_imports)
                strings.append((k_string, v_string))
            return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports
        # Datetimes
        elif isinstance(value, datetime.datetime):
            value_repr = cls.serialize_datetime(value)
            imports = ["import datetime"]
            if value.tzinfo is not None:
                imports.append("from django.utils.timezone import utc")
            return value_repr, set(imports)
        # Dates
        elif isinstance(value, datetime.date):
            value_repr = repr(value)
            if isinstance(value, datetime_safe.date):
                value_repr = "datetime.%s" % value_repr
            return value_repr, {"import datetime"}
        # Times
        elif isinstance(value, datetime.time):
            value_repr = repr(value)
            if isinstance(value, datetime_safe.time):
                value_repr = "datetime.%s" % value_repr
            return value_repr, {"import datetime"}
        # Timedeltas
        elif isinstance(value, datetime.timedelta):
            return repr(value), {"import datetime"}
        # Settings references
        elif isinstance(value, SettingsReference):
            return "settings.%s" % value.setting_name, {"from django.conf import settings"}
        # Simple types
        elif isinstance(value, float):
            if math.isnan(value) or math.isinf(value):
                return 'float("{}")'.format(value), set()
            return repr(value), set()
        elif isinstance(value, six.integer_types + (bool, type(None))):
            return repr(value), set()
        elif isinstance(value, six.binary_type):
            value_repr = repr(value)
            if six.PY2:
                # Prepend the `b` prefix since we're importing unicode_literals
                value_repr = 'b' + value_repr
            return value_repr, set()
        elif isinstance(value, six.text_type):
            value_repr = repr(value)
            if six.PY2:
                # Strip the `u` prefix since we're importing unicode_literals
                value_repr = value_repr[1:]
            return value_repr, set()
        # Decimal
        elif isinstance(value, decimal.Decimal):
            return repr(value), {"from decimal import Decimal"}
        # Django fields
        elif isinstance(value, models.Field):
            attr_name, path, args, kwargs = value.deconstruct()
            return cls.serialize_deconstructed(path, args, kwargs)
        # Classes
        elif isinstance(value, type):
            special_cases = [
                (models.Model, "models.Model", []),
            ]
            for case, string, imports in special_cases:
                if case is value:
                    return string, set(imports)
            if hasattr(value, "__module__"):
                module = value.__module__
                if module == six.moves.builtins.__name__:
                    return value.__name__, set()
                else:
                    return "%s.%s" % (module, value.__name__), {"import %s" % module}
        elif isinstance(value, models.manager.BaseManager):
            as_manager, manager_path, qs_path, args, kwargs = value.deconstruct()
            if as_manager:
                name, imports = cls._serialize_path(qs_path)
                return "%s.as_manager()" % name, imports
            else:
                return cls.serialize_deconstructed(manager_path, args, kwargs)
        elif isinstance(value, Operation):
            string, imports = OperationWriter(value, indentation=0).serialize()
            # Nested operation, trailing comma is handled in upper OperationWriter._write()
            return string.rstrip(','), imports
        elif isinstance(value, functools.partial):
            imports = {'import functools'}
            # Serialize functools.partial() arguments
            func_string, func_imports = cls.serialize(value.func)
            args_string, args_imports = cls.serialize(value.args)
            keywords_string, keywords_imports = cls.serialize(value.keywords)
            # Add any imports needed by arguments
            imports.update(func_imports)
            imports.update(args_imports)
            imports.update(keywords_imports)
            return (
                "functools.partial(%s, *%s, **%s)" % (
                    func_string, args_string, keywords_string,
                ),
                imports,
            )
        # Anything that knows how to deconstruct itself.
        elif hasattr(value, 'deconstruct'):
            return cls.serialize_deconstructed(*value.deconstruct())
        # Functions
        elif isinstance(value, (types.FunctionType, types.BuiltinFunctionType)):
            # @classmethod?
            if getattr(value, "__self__", None) and isinstance(value.__self__, type):
                klass = value.__self__
                module = klass.__module__
                return "%s.%s.%s" % (module, klass.__name__, value.__name__), {"import %s" % module}
            # Further error checking
            if value.__name__ == '<lambda>':
                raise ValueError("Cannot serialize function: lambda")
            if value.__module__ is None:
                raise ValueError("Cannot serialize function %r: No module" % value)
            # Python 3 is a lot easier, and only uses this branch if it's not local.
            if getattr(value, "__qualname__", None) and getattr(value, "__module__", None):
                if "<" not in value.__qualname__:  # Qualname can include <locals>
                    return "%s.%s" % (value.__module__, value.__qualname__), {"import %s" % value.__module__}
            # Python 2/fallback version
            module_name = value.__module__
            # Make sure it's actually there and not an unbound method
            module = import_module(module_name)
            if not hasattr(module, value.__name__):
                raise ValueError(
                    "Could not find function %s in %s.\n"
                    "Please note that due to Python 2 limitations, you cannot "
                    "serialize unbound method functions (e.g. a method "
                    "declared and used in the same class body). Please move "
                    "the function into the main module body to use migrations.\n"
                    "For more information, see "
                    "https://docs.djangoproject.com/en/%s/topics/migrations/#serializing-values"
                    % (value.__name__, module_name, get_docs_version()))
            # Needed on Python 2 only
            if module_name == '__builtin__':
                return value.__name__, set()
            return "%s.%s" % (module_name, value.__name__), {"import %s" % module_name}
        # Other iterables
        elif isinstance(value, collections.Iterable):
            imports = set()
            strings = []
            for item in value:
                item_string, item_imports = cls.serialize(item)
                imports.update(item_imports)
                strings.append(item_string)
            # When len(strings)==0, the empty iterable should be serialized as
            # "()", not "(,)" because (,) is invalid Python syntax.
            format = "(%s)" if len(strings) != 1 else "(%s,)"
            return format % (", ".join(strings)), imports
        # Compiled regex
        elif isinstance(value, (COMPILED_REGEX_TYPE, RegexObject)):
            imports = {"import re"}
            regex_pattern, pattern_imports = cls.serialize(value.pattern)
            regex_flags, flag_imports = cls.serialize(value.flags)
            imports.update(pattern_imports)
            imports.update(flag_imports)
            args = [regex_pattern]
            if value.flags:
                args.append(regex_flags)
            return "re.compile(%s)" % ', '.join(args), imports
        # Uh oh.
        else:
            raise ValueError(
                "Cannot serialize: %r\nThere are some values Django cannot serialize into "
                "migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
                "topics/migrations/#migration-serializing" % (value, get_docs_version())
            )
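
The functools.partial branch above works because a partial object exposes its pieces as func, args, and keywords, which the writer serializes back into a "functools.partial(%s, *%s, **%s)" string. A quick sketch of that round trip:

import functools

p = functools.partial(int, base=2)
print(p.func, p.args, p.keywords)  # <class 'int'> () {'base': 2}

# Rebuild the partial from its parts, as the generated migration code would:
q = functools.partial(p.func, *p.args, **p.keywords)
print(q("1010"))  # -> 10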

Example 3

Project: namesync
Source File: main.py
def main(argv=None, outfile=sys.stdout):
    parser = argparse.ArgumentParser(prog='namesync')
    parser.add_argument('-g', '--get', default=False, action='store_true',
                        help='save existing DNS records to RECORDS')
    parser.add_argument('-z', '--zone',
                        help='specify the zone instead of using the RECORDS filename')
    parser.add_argument('-y', '--yes', default=False, action='store_true',
                        help='sync records without prompting before making changes')
    parser.add_argument('-p', '--provider', default='cloudflare',
                        help='use the specified provider [default: cloudflare]')
    parser.add_argument('-c', '--config', default=DEFAULT_CONFIG_LOCATION,
                        help=(
                            'specify config location [default: ~/.namesync]'
                        ))
    parser.add_argument('records', metavar='RECORDS',
                        help=(
                            'file containing DNS records, one per line. '
                            'The zone is derived from the basename of this file. '
                            'For example, if "dns/example.com" is used then the zone is '
                            'assumed to be "example.com" unless the --zone option is used'
                        ))

    args = parser.parse_args(argv or sys.argv[1:])

    provider_class = get_provider(args.provider)

    config = environment_check(args.config, args.provider, provider_class)

    zone = args.zone if args.zone else os.path.basename(args.records)
    auto_commit = args.yes

    provider_config = config['providers'].get(args.provider, {})
    provider = provider_class(provider_config, zone)

    current_records = provider.records()

    def println(message):
        outfile.write(message)
        outfile.write('\n')

    if args.get:
        if os.path.exists(args.records):
            println('The file "{}" already exists. Refusing to overwrite!'.format(args.records))
            sys.exit(1)

        with open(args.records, 'w') as f:
            records_to_flatfile(current_records, f)

        sys.exit(0)

    with open(args.records) as f:
        new_records = flatfile_to_records(f)

    diff = diff_records(current_records, new_records)
    actions = []

    records = diff['add'] + diff['update'] + diff['remove']

    try:
        max_name_length = max(len(record.name) for record in records)
    except ValueError:
        max_name_length = 0

    for record in diff['add']:
        actions.append(Action(
            description=action_description('ADD', record, max_name_length),
            closure=functools.partial(provider.add, record),
        ))

    for record in diff['update']:
        actions.append(Action(
            description=action_description('UPDATE', record, max_name_length),
            closure=functools.partial(provider.update, record),
        ))

    for record in diff['remove']:
        actions.append(Action(
            description=action_description('REMOVE', record, max_name_length),
            closure=functools.partial(provider.delete, record),
        ))

    if not len(actions):
        println('All records up to date.')
    else:
        if not auto_commit:
            println('The following changes will be made:')

            for action in actions:
                println(action.description)

        commit = auto_commit or prompt_user_to_commit()

        if commit:
            for action in actions:
                println(action.description)
                action.apply()
        else:
            println('Abort.')
            sys.exit(1)
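
Here functools.partial turns each provider method plus its record into a zero-argument closure, so actions can be listed, confirmed, and applied later without re-supplying arguments. A condensed sketch with a stand-in Action class (the real one lives elsewhere in namesync; this one only assumes it pairs a description with a deferred closure, matching how it is used above):

import functools

class Action(object):
    def __init__(self, description, closure):
        self.description = description
        self.closure = closure

    def apply(self):
        # Invoke the deferred call with its pre-bound arguments.
        self.closure()

def add(record):
    print("ADD %s" % record)

actions = [Action("ADD www", functools.partial(add, "www"))]
for action in actions:
    action.apply()  # -> ADD www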

Example 4

Project: pyramid_redis_sessions
Source File: __init__.py
def RedisSessionFactory(
    secret,
    timeout=1200,
    cookie_name='session',
    cookie_max_age=None,
    cookie_path='/',
    cookie_domain=None,
    cookie_secure=False,
    cookie_httponly=True,
    cookie_on_exception=True,
    url=None,
    host='localhost',
    port=6379,
    db=0,
    password=None,
    socket_timeout=None,
    connection_pool=None,
    encoding='utf-8',
    encoding_errors='strict',
    unix_socket_path=None,
    client_callable=None,
    serialize=cPickle.dumps,
    deserialize=cPickle.loads,
    id_generator=_generate_session_id,
    ):
    """
    Constructs and returns a session factory that will provide session data
    from a Redis server. The returned factory can be supplied as the
    ``session_factory`` argument of a :class:`pyramid.config.Configurator`
    constructor, or used as the ``session_factory`` argument of the
    :meth:`pyramid.config.Configurator.set_session_factory` method.

    Parameters:

    ``secret``
    A string which is used to sign the cookie.

    ``timeout``
    A number of seconds of inactivity before a session times out.

    ``cookie_name``
    The name of the cookie used for sessioning. Default: ``session``.

    ``cookie_max_age``
    The maximum age of the cookie used for sessioning (in seconds).
    Default: ``None`` (browser scope).

    ``cookie_path``
    The path used for the session cookie. Default: ``/``.

    ``cookie_domain``
    The domain used for the session cookie. Default: ``None`` (no domain).

    ``cookie_secure``
    The 'secure' flag of the session cookie. Default: ``False``.

    ``cookie_httponly``
    The 'httpOnly' flag of the session cookie. Default: ``True``.

    ``cookie_on_exception``
    If ``True``, set a session cookie even if an exception occurs
    while rendering a view. Default: ``True``.

    ``url``
    A connection string for a Redis server, in the format:
    redis://username:password@localhost:6379/0
    Default: ``None``.

    ``host``
    A string representing the IP of your Redis server. Default: ``localhost``.

    ``port``
    An integer representing the port of your Redis server. Default: ``6379``.

    ``db``
    An integer to select a specific database on your Redis server.
    Default: ``0``

    ``password``
    A string password to connect to your Redis server/database if
    required. Default: ``None``.

    ``client_callable``
    A python callable that accepts a Pyramid `request` and Redis config options
    and returns a Redis client such as redis-py's `StrictRedis`.
    Default: ``None``.

    ``serialize``
    A function to serialize the session dict for storage in Redis.
    Default: ``cPickle.dumps``.

    ``deserialize``
    A function to deserialize the stored session data in Redis.
    Default: ``cPickle.loads``.

    ``id_generator``
    A function to create a unique ID to be used as the session key when a
    session is first created.
    Default: private function that uses sha1 with the time and random elements
    to create a 40 character unique ID.

    The following arguments are also passed straight to the ``StrictRedis``
    constructor and allow you to further configure the Redis client::

      socket_timeout
      connection_pool
      encoding
      encoding_errors
      unix_socket_path
    """
    def factory(request, new_session_id=get_unique_session_id):
        redis_options = dict(
            host=host,
            port=port,
            db=db,
            password=password,
            socket_timeout=socket_timeout,
            connection_pool=connection_pool,
            encoding=encoding,
            encoding_errors=encoding_errors,
            unix_socket_path=unix_socket_path,
            )

        # an explicit client callable gets priority over the default
        redis = client_callable(request, **redis_options) \
            if client_callable is not None \
            else get_default_connection(request, url=url, **redis_options)

        # attempt to retrieve a session_id from the cookie
        session_id_from_cookie = _get_session_id_from_cookie(
            request=request,
            cookie_name=cookie_name,
            secret=secret,
            )

        new_session = functools.partial(
            new_session_id,
            redis=redis,
            timeout=timeout,
            serialize=serialize,
            generator=id_generator,
            )

        if session_id_from_cookie and redis.exists(session_id_from_cookie):
            session_id = session_id_from_cookie
            session_cookie_was_valid = True
        else:
            session_id = new_session()
            session_cookie_was_valid = False

        session = RedisSession(
            redis=redis,
            session_id=session_id,
            new=not session_cookie_was_valid,
            new_session=new_session,
            serialize=serialize,
            deserialize=deserialize,
            )

        set_cookie = functools.partial(
            _set_cookie,
            session,
            cookie_name=cookie_name,
            cookie_max_age=cookie_max_age,
            cookie_path=cookie_path,
            cookie_domain=cookie_domain,
            cookie_secure=cookie_secure,
            cookie_httponly=cookie_httponly,
            secret=secret,
            )
        delete_cookie = functools.partial(
            _delete_cookie,
            cookie_name=cookie_name,
            cookie_path=cookie_path,
            cookie_domain=cookie_domain,
            )
        cookie_callback = functools.partial(
            _cookie_callback,
            session,
            session_cookie_was_valid=session_cookie_was_valid,
            cookie_on_exception=cookie_on_exception,
            set_cookie=set_cookie,
            delete_cookie=delete_cookie,
            )
        request.add_response_callback(cookie_callback)

        return session

    return factory
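
RedisSessionFactory uses partial three times to bake the factory's configuration into callbacks (set_cookie, delete_cookie, cookie_callback) that are invoked later with only per-request arguments. A reduced sketch of the idea (the names below are illustrative, not Pyramid's API):

import functools

def set_cookie(response, cookie_name, secret):
    print("set %s on %s (signed with %s)" % (cookie_name, response, secret))

# Configuration is bound once, when the factory is built...
callback = functools.partial(set_cookie, cookie_name="session", secret="s3cr3t")
# ...and the framework later supplies only the response.
callback("response")  # -> set session on response (signed with s3cr3t)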

Example 5

Project: golismero
Source File: target.py
def _setRequestParams():
    """
    Check and set the parameters and perform checks on 'data' option for
    HTTP method POST.
    """

    if conf.direct:
        conf.parameters[None] = "direct connection"
        return

    testableParameters = False

    # Perform checks on GET parameters
    if conf.parameters.get(PLACE.GET):
        parameters = conf.parameters[PLACE.GET]
        paramDict = paramToDict(PLACE.GET, parameters)

        if paramDict:
            conf.paramDict[PLACE.GET] = paramDict
            testableParameters = True

    # Perform checks on POST parameters
    if conf.method == HTTPMETHOD.POST and conf.data is None:
        errMsg = "HTTP POST method depends on HTTP data value to be posted"
        raise SqlmapSyntaxException(errMsg)

    if conf.data is not None:
        conf.method = HTTPMETHOD.POST if not conf.method or conf.method == HTTPMETHOD.GET else conf.method

        def process(match, repl):
            retVal = match.group(0)

            if not (conf.testParameter and match.group("name") not in conf.testParameter):
                retVal = repl
                while True:
                    _ = re.search(r"\\g<([^>]+)>", retVal)
                    if _:
                        retVal = retVal.replace(_.group(0), match.group(int(_.group(1)) if _.group(1).isdigit() else _.group(1)))
                    else:
                        break

            return retVal

        if kb.processUserMarks is None and CUSTOM_INJECTION_MARK_CHAR in conf.data:
            message = "custom injection marking character ('%s') found in option " % CUSTOM_INJECTION_MARK_CHAR
            message += "'--data'. Do you want to process it? [Y/n/q] "
            test = readInput(message, default="Y")
            if test and test[0] in ("q", "Q"):
                raise SqlmapUserQuitException
            else:
                kb.processUserMarks = not test or test[0] not in ("n", "N")

        if not (kb.processUserMarks and CUSTOM_INJECTION_MARK_CHAR in conf.data):
            if re.search(JSON_RECOGNITION_REGEX, conf.data):
                message = "JSON like data found in %s data. " % conf.method
                message += "Do you want to process it? [Y/n/q] "
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                elif test[0] not in ("n", "N"):
                    conf.data = conf.data.replace(CUSTOM_INJECTION_MARK_CHAR, ASTERISK_MARKER)
                    conf.data = re.sub(r'("(?P<name>[^"]+)"\s*:\s*"[^"]+)"', functools.partial(process, repl=r'\g<1>%s"' % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    conf.data = re.sub(r'("(?P<name>[^"]+)"\s*:\s*)(-?\d[\d\.]*\b)', functools.partial(process, repl=r'\g<0>%s' % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    kb.postHint = POST_HINT.JSON

            elif re.search(SOAP_RECOGNITION_REGEX, conf.data):
                message = "SOAP/XML like data found in %s data. " % conf.method
                message += "Do you want to process it? [Y/n/q] "
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                elif test[0] not in ("n", "N"):
                    conf.data = conf.data.replace(CUSTOM_INJECTION_MARK_CHAR, ASTERISK_MARKER)
                    conf.data = re.sub(r"(<(?P<name>[^>]+)( [^<]*)?>)([^<]+)(</\2)", functools.partial(process, repl=r"\g<1>\g<4>%s\g<5>" % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    kb.postHint = POST_HINT.SOAP if "soap" in conf.data.lower() else POST_HINT.XML

            elif re.search(MULTIPART_RECOGNITION_REGEX, conf.data):
                message = "Multipart like data found in %s data. " % conf.method
                message += "Do you want to process it? [Y/n/q] "
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                elif test[0] not in ("n", "N"):
                    conf.data = conf.data.replace(CUSTOM_INJECTION_MARK_CHAR, ASTERISK_MARKER)
                    conf.data = re.sub(r"(?si)((Content-Disposition[^\n]+?name\s*=\s*[\"'](?P<name>[^\n]+?)[\"']).+?)(((\r)?\n)+--)", functools.partial(process, repl=r"\g<1>%s\g<4>" % CUSTOM_INJECTION_MARK_CHAR), conf.data)
                    kb.postHint = POST_HINT.MULTIPART

        if not kb.postHint:
            if CUSTOM_INJECTION_MARK_CHAR in conf.data:  # later processed
                pass
            else:
                place = PLACE.POST

                conf.parameters[place] = conf.data
                paramDict = paramToDict(place, conf.data)

                if paramDict:
                    conf.paramDict[place] = paramDict
                    testableParameters = True
        else:
            if CUSTOM_INJECTION_MARK_CHAR not in conf.data:  # in case no usable parameter values have been found
                conf.parameters[PLACE.POST] = conf.data

    kb.processUserMarks = True if (kb.postHint and CUSTOM_INJECTION_MARK_CHAR in conf.data) else kb.processUserMarks

    if re.search(URI_INJECTABLE_REGEX, conf.url, re.I) and not any(place in conf.parameters for place in (PLACE.GET, PLACE.POST)) and not kb.postHint and not CUSTOM_INJECTION_MARK_CHAR in (conf.data or ""):
        warnMsg = "you've provided target URL without any GET "
        warnMsg += "parameters (e.g. www.site.com/article.php?id=1) "
        warnMsg += "and without providing any POST parameters "
        warnMsg += "through --data option"
        logger.warn(warnMsg)

        message = "do you want to try URI injections "
        message += "in the target URL itself? [Y/n/q] "
        test = readInput(message, default="Y")

        if not test or test[0] not in ("n", "N"):
            conf.url = "%s%s" % (conf.url, CUSTOM_INJECTION_MARK_CHAR)
            kb.processUserMarks = True
        elif test[0] in ("q", "Q"):
            raise SqlmapUserQuitException

    for place, value in ((PLACE.URI, conf.url), (PLACE.CUSTOM_POST, conf.data), (PLACE.CUSTOM_HEADER, str(conf.httpHeaders))):
        _ = re.sub(PROBLEMATIC_CUSTOM_INJECTION_PATTERNS, "", value or "") if place == PLACE.CUSTOM_HEADER else value or ""
        if CUSTOM_INJECTION_MARK_CHAR in _:
            if kb.processUserMarks is None:
                lut = {PLACE.URI: '-u', PLACE.CUSTOM_POST: '--data', PLACE.CUSTOM_HEADER: '--headers/--user-agent/--referer/--cookie'}
                message = "custom injection marking character ('%s') found in option " % CUSTOM_INJECTION_MARK_CHAR
                message += "'%s'. Do you want to process it? [Y/n/q] " % lut[place]
                test = readInput(message, default="Y")
                if test and test[0] in ("q", "Q"):
                    raise SqlmapUserQuitException
                else:
                    kb.processUserMarks = not test or test[0] not in ("n", "N")

            if not kb.processUserMarks:
                if place == PLACE.URI:
                    query = urlparse.urlsplit(value).query
                    if query:
                        parameters = conf.parameters[PLACE.GET] = query
                        paramDict = paramToDict(PLACE.GET, parameters)

                        if paramDict:
                            conf.url = conf.url.split('?')[0]
                            conf.paramDict[PLACE.GET] = paramDict
                            testableParameters = True
                elif place == PLACE.CUSTOM_POST:
                    conf.parameters[PLACE.POST] = conf.data
                    paramDict = paramToDict(PLACE.POST, conf.data)

                    if paramDict:
                        conf.paramDict[PLACE.POST] = paramDict
                        testableParameters = True

            else:
                conf.parameters[place] = value
                conf.paramDict[place] = OrderedDict()

                if place == PLACE.CUSTOM_HEADER:
                    for index in xrange(len(conf.httpHeaders)):
                        header, value = conf.httpHeaders[index]
                        if CUSTOM_INJECTION_MARK_CHAR in re.sub(PROBLEMATIC_CUSTOM_INJECTION_PATTERNS, "", value):
                            parts = value.split(CUSTOM_INJECTION_MARK_CHAR)
                            for i in xrange(len(parts) - 1):
                                conf.paramDict[place]["%s #%d%s" % (header, i + 1, CUSTOM_INJECTION_MARK_CHAR)] = "%s,%s" % (header, "".join("%s%s" % (parts[j], CUSTOM_INJECTION_MARK_CHAR if i == j else "") for j in xrange(len(parts))))
                            conf.httpHeaders[index] = (header, value.replace(CUSTOM_INJECTION_MARK_CHAR, ""))
                else:
                    parts = value.split(CUSTOM_INJECTION_MARK_CHAR)

                    for i in xrange(len(parts) - 1):
                        conf.paramDict[place]["%s#%d%s" % (("%s " % kb.postHint) if kb.postHint else "", i + 1, CUSTOM_INJECTION_MARK_CHAR)] = "".join("%s%s" % (parts[j], CUSTOM_INJECTION_MARK_CHAR if i == j else "") for j in xrange(len(parts)))

                    if place == PLACE.URI and PLACE.GET in conf.paramDict:
                        del conf.paramDict[PLACE.GET]
                    elif place == PLACE.CUSTOM_POST and PLACE.POST in conf.paramDict:
                        del conf.paramDict[PLACE.POST]

                testableParameters = True

    if kb.processUserMarks:
        for item in ("url", "data", "agent", "referer", "cookie"):
            if conf.get(item):
                conf[item] = conf[item].replace(CUSTOM_INJECTION_MARK_CHAR, "")

    # Perform checks on Cookie parameters
    if conf.cookie:
        conf.parameters[PLACE.COOKIE] = conf.cookie
        paramDict = paramToDict(PLACE.COOKIE, conf.cookie)

        if paramDict:
            conf.paramDict[PLACE.COOKIE] = paramDict
            testableParameters = True

    # Perform checks on header values
    if conf.httpHeaders:
        for httpHeader, headerValue in conf.httpHeaders:
            # URL encoding of the header values should be avoided
            # Reference: http://stackoverflow.com/questions/5085904/is-ok-to-urlencode-the-value-in-headerlocation-value

            httpHeader = httpHeader.title()

            if httpHeader == HTTP_HEADER.USER_AGENT:
                conf.parameters[PLACE.USER_AGENT] = urldecode(headerValue)

                condition = any((not conf.testParameter, intersect(conf.testParameter, USER_AGENT_ALIASES)))

                if condition:
                    conf.paramDict[PLACE.USER_AGENT] = {PLACE.USER_AGENT: headerValue}
                    testableParameters = True

            elif httpHeader == HTTP_HEADER.REFERER:
                conf.parameters[PLACE.REFERER] = urldecode(headerValue)

                condition = any((not conf.testParameter, intersect(conf.testParameter, REFERER_ALIASES)))

                if condition:
                    conf.paramDict[PLACE.REFERER] = {PLACE.REFERER: headerValue}
                    testableParameters = True

            elif httpHeader == HTTP_HEADER.HOST:
                conf.parameters[PLACE.HOST] = urldecode(headerValue)

                condition = any((not conf.testParameter, intersect(conf.testParameter, HOST_ALIASES)))

                if condition:
                    conf.paramDict[PLACE.HOST] = {PLACE.HOST: headerValue}
                    testableParameters = True

    if not conf.parameters:
        errMsg = "you did not provide any GET, POST and Cookie "
        errMsg += "parameter, neither an User-Agent, Referer or Host header value"
        raise SqlmapGenericException(errMsg)

    elif not testableParameters:
        errMsg = "all testable parameters you provided are not present "
        errMsg += "within the given request data"
        raise SqlmapGenericException(errMsg)
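
Note: the densest part of this example is the marker-splitting loop above, which turns a value containing N injection marks into N test variants, each keeping exactly one mark. A minimal sketch of that logic (MARK and the sample value are illustrative, not sqlmap's own):

MARK = '*'  # illustrative stand-in for CUSTOM_INJECTION_MARK_CHAR

value = 'id=1*&cat=2*'
parts = value.split(MARK)
# One variant per mark: keep the mark at position i, drop the others
variants = ["".join("%s%s" % (parts[j], MARK if i == j else "")
                    for j in range(len(parts)))
            for i in range(len(parts) - 1)]
print(variants)  # ['id=1*&cat=2', 'id=1&cat=2*']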

Example 7

Project: color_trace
Source File: color_trace_multi.py
View license
def get_args(cmdargs=None):
    """return parser and namespace of parsed command-line arguments

    cmdargs: if specified, a list of command-line arguments to use instead of
        those provided to this script (i.e. a string that has been shlex.split)
"""
    parser = argparse.ArgumentParser(description="trace a color image with "
        "potrace, output color SVG file", add_help=False, prefix_chars='-/')
    # help also accessible via /?
    parser.add_argument(
        '-h', '--help', '/?',
        action='help',
        help="show this help message and exit")
    # file io arguments
    parser.add_argument('-i',
        '--input', metavar='src', nargs='+', required=True,
        help="path of input image(s) to trace, supports * and ? wildcards")
    parser.add_argument('-o',
        '--output', metavar='dest',
        help="path of output image to save to, supports * wildcard")
    parser.add_argument('-d',
        '--directory', metavar='destdir',
        help="outputs to destdir")
    # processing arguments
    parser.add_argument('-C',
        '--cores', metavar='N',
        type=functools.partial(check_range, 0, None, int, "an integer"),
        help="number of cores to use for image processing. "
             "Ignored if processing a single file with 1 color "
             "(default tries to use all cores)")
    # color trace options
    #make colors & palette mutually exclusive
    color_palette_group = parser.add_mutually_exclusive_group(required=True)
    color_palette_group.add_argument('-c',
        '--colors', metavar='N',
        type=functools.partial(check_range, 0, 256, int, "an integer"),
        help="[required unless -p is used instead] "
             "number of colors to reduce each image to before tracing, up to 256. "
             "Value of 0 skips color reduction (not recommended unless images "
             "are already color-reduced)")
    parser.add_argument('-q',
        '--quantization', metavar='algorithm',
        choices=('mc','as','nq'), default='mc',
        help="color quantization algorithm: mc, as, or nq. "
            "'mc' (Median-Cut, default); "
            "'as' (Adaptive Spatial Subdivision, may result in fewer colors); "
            "'nq' (NeuQuant, for hundreds of colors). Disabled if --colors 0")
    #make --floydsteinberg and --riemersma dithering mutually exclusive
    dither_group = parser.add_mutually_exclusive_group()
    dither_group.add_argument('-fs',
        '--floydsteinberg', action='store_true',
        help="enable Floyd-Steinberg dithering (for any quantization or -p/--palette)."
            " Warning: any dithering will greatly increase output svg's size and complexity.")
    dither_group.add_argument('-ri',
        '--riemersma', action='store_true',
        help="enable Rimersa dithering (only for Adaptive Spatial Subdivision quantization or -p/--palette)")
    color_palette_group.add_argument('-r',
        '--remap', metavar='paletteimg',
        help=("use a custom palette image for color reduction [overrides -c "
              "and -q]"))
    # image options
    parser.add_argument('-s',
        '--stack',
        action='store_true',
        help="stack color traces (recommended for more accurate output)")
    parser.add_argument('-p',
        '--prescale', metavar='size',
        type=functools.partial(check_range, 0, None, float, "a floating-point number"), default=2,
        help="scale image this much before tracing for greater detail (default: 2). "
            "The image's output size is not changed. (2 is recommended, or 3 for smaller "
            "details.)")
    # potrace options
    parser.add_argument('-D',
        '--despeckle', metavar='size',
        type=functools.partial(check_range, 0, None, int, "an integer"), default=2,
        help='suppress speckles of this many pixels (default: 2)')
    parser.add_argument('-S',
        '--smoothcorners', metavar='threshold',
        type=functools.partial(check_range, 0, 1.334, float, "a floating-point number"), default=1.0,
        help="set corner smoothing: 0 for no smoothing, 1.334 for max "
            "(default: 1.0)")
    parser.add_argument('-O',
        '--optimizepaths', metavar='tolerance',
        type=functools.partial(check_range, 0, 5, float, "a floating-point number"), default=0.2,
        help="set Bezier curve optimization: 0 for least, 5 for most "
              "(default: 0.2)")
    # other options
    parser.add_argument('-v',
        '--verbose', action='store_true',
        help="print details about commands executed by this script")
    parser.add_argument('--version', action='version',
        version='%(prog)s {ver}'.format(ver=VERSION))

    if cmdargs is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(cmdargs)

    # with multiple inputs, --output must use at least one * wildcard
    multi_inputs = False
    for i, input_ in enumerate(get_inputs_outputs(args.input)):
        if i:
            multi_inputs = True
            break
    if multi_inputs and args.output is not None and '*' not in args.output:
        parser.error("argument -o/--output: must contain '*' wildcard when using multiple input files")

    # 'riemersma' dithering is only allowed with 'as' quantization or the --remap option
    if args.riemersma:
        if args.quantization != 'as' and args.remap is None:
            parser.error("argument -ri/--riemersma: only allowed with 'as' quantization or -r/--remap")

    return args
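
Note: this example uses functools.partial as an argparse type converter, front-loading the range bounds, caster and error label so argparse only has to pass the raw string. A minimal runnable sketch (the check_range below is an illustrative stand-in, not the project's own):

import argparse
import functools

def check_range(minimum, maximum, caster, description, string):
    # Hypothetical validator mirroring the pattern above: argparse calls
    # type(value_string), so partial() has already supplied everything else
    value = caster(string)
    if (minimum is not None and value < minimum) or \
       (maximum is not None and value > maximum):
        raise argparse.ArgumentTypeError(
            "expected %s between %s and %s" % (description, minimum, maximum))
    return value

parser = argparse.ArgumentParser()
parser.add_argument('--colors',
    type=functools.partial(check_range, 0, 256, int, "an integer"))
print(parser.parse_args(['--colors', '16']).colors)  # 16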

Example 9

Project: treadmill
Source File: _base_service.py
View license
    def run(self, watchdogs_dir, *impl_args, **impl_kwargs):
        """Run the service."""
        # Load the implementation
        if self._service_class is None:
            self._service_class = self._load_impl()
        impl = self._service_class(*impl_args, **impl_kwargs)

        # Setup the watchdog
        watchdogs = watchdog.Watchdog(os.path.realpath(watchdogs_dir))
        watchdog_lease = watchdogs.create(
            name='svc-{svc_name}'.format(svc_name=self.name),
            timeout='{hb:d}s'.format(hb=impl.WATCHDOG_HEARTBEAT_SEC),
            content='Service %r failed' % self.name
        )

        # Create the status socket
        ss = self._create_status_socket()

        # Run initialization
        impl.initialize(self._dir)

        watcher = idirwatch.DirWatcher(self._rsrc_dir)
        # Call all the callbacks with the implementation instance
        watcher.on_created = functools.partial(self._on_created, impl)
        watcher.on_deleted = functools.partial(self._on_deleted, impl)
        # NOTE: A modified request is treated as a brand new request
        watcher.on_modified = functools.partial(self._on_created, impl)
        self._io_eventfd = eventfd.eventfd(0, eventfd.EFD_CLOEXEC)

        # Before starting, check the request directory
        svcs = self._check_requests()
        # and "fake" a created event on all the existing requests
        for existing_svcs in svcs:
            self._on_created(impl, existing_svcs)

        # Before starting, make sure backend state and service state are
        # synchronized.
        impl.synchronize()

        # Report service status
        status_info = {}
        status_info.update(impl.report_status())

        # Setup the poll object
        loop_poll = select.poll()
        loop_callbacks = {}

        base_event_handlers = [
            (
                self._io_eventfd,
                select.POLLIN,
                functools.partial(
                    self._handle_queued_io_events,
                    watcher=watcher,
                    impl=impl,
                )
            ),
            (
                watcher.inotify,
                select.POLLIN,
                functools.partial(
                    self._handle_io_events,
                    watcher=watcher,
                    impl=impl,
                )
            ),
            (
                ss,
                select.POLLIN,
                functools.partial(
                    self._publish_status,
                    status_socket=ss,
                    status_info=status_info,
                )
            ),
        ]
        # Initial collection of the implementation's event handlers
        impl_event_handlers = impl.event_handlers()

        self._update_poll_registration(
            loop_poll,
            loop_callbacks,
            base_event_handlers + impl_event_handlers,
        )

        loop_timeout = impl.WATCHDOG_HEARTBEAT_SEC/2
        while not self._is_dead:

            # Check for events
            updated = self._run_events(
                loop_poll,
                loop_timeout,
                loop_callbacks,
            )

            if updated:
                # Report service status
                status_info.clear()
                status_info.update(impl.report_status())

                # Update poll registration if needed
                impl_event_handlers = impl.event_handlers()
                self._update_poll_registration(
                    loop_poll, loop_callbacks,
                    base_event_handlers + impl_event_handlers,
                )

            # Clean up stale requests
            self._check_requests()

            # Heartbeat
            watchdog_lease.heartbeat()

        _LOGGER.info('Shutting down %r service', self.name)
        # Remove the service heartbeat
        watchdog_lease.remove()
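
Note: here functools.partial bakes the shared impl object into the dirwatch and poll callbacks, both positionally and by keyword, so each callback matches the signature its caller expects. A minimal sketch of that binding (the Impl class and paths are illustrative):

import functools

class Impl(object):
    # Illustrative stand-in for the loaded service implementation
    def handle(self, path):
        print('created: %s' % path)

def on_created(impl, path):
    impl.handle(path)

impl = Impl()
callback = functools.partial(on_created, impl)      # positional binding
callback('/var/run/request-1')

handler = functools.partial(on_created, impl=impl)  # keyword binding
handler(path='/var/run/request-2')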

Example 11

Project: qmap
Source File: qmap.py
View license
    def initGui(self):
        """
        Create all the icons and set up the toolbars. Called by QGIS when
        loading. This is called before setupUI.
        """
        QApplication.setWindowIcon(QIcon(":/branding/logo"))
        self.mainwindow.findChildren(QMenuBar)[0].setVisible(False)
        self.mainwindow.setContextMenuPolicy(Qt.PreventContextMenu)
        self.mainwindow.setWindowTitle("IntraMaps Roam: Mobile Data Collection")
        
        # Disable QGIS logging window popups. We do our own logging
        QgsMessageLog.instance().messageReceived.disconnect()
        
        s = """
        QToolButton { 
            padding: 6px;
            color: #4f4f4f;
         }
        
        QToolButton:hover { 
            padding: 6px;
            background-color: rgb(211, 228, 255);
         }
         
        QToolBar {
         background: white;
        }
        
        QCheckBox::indicator {
             width: 40px;
             height: 40px;
         }
        
        QLabel {
            color: #4f4f4f;
        }
        
        QDialog { background-color: rgb(255, 255, 255); }
        
        QPushButton { 
            border: 1px solid #e1e1e1;
             padding: 6px;
            color: #4f4f4f;
         }
        
        QPushButton:hover { 
            border: 1px solid #e1e1e1;
             padding: 6px;
            background-color: rgb(211, 228, 255);
         }
        
        QCheckBox {
            color: #4f4f4f;
        }
        
        QComboBox::drop-down {
            width: 30px;
        }


        """
        self.mainwindow.setStyleSheet(s)
        
        mainwidget = self.mainwindow.centralWidget()
        mainwidget.setLayout(QGridLayout())
        mainwidget.layout().setContentsMargins(0,0,0,0)
        
        newlayout = QGridLayout()
        newlayout.setContentsMargins(0,0,0,0)
        newlayout.addWidget(self.iface.mapCanvas(), 0,0,2,1)
        newlayout.addWidget(self.iface.messageBar(), 0,0,1,1)
        
        wid = QWidget()
        wid.setLayout(newlayout)
        
        self.stack = QStackedWidget()
        self.messageBar = QgsMessageBar(wid)
        self.messageBar.setSizePolicy( QSizePolicy.Minimum, QSizePolicy.Fixed )
        self.errorreport = PopDownReport(self.messageBar)
         
        mainwidget.layout().addWidget(self.stack, 0,0,2,1)
        mainwidget.layout().addWidget(self.messageBar, 0,0,1,1)
        
        self.helppage = HelpPage()
        helppath = os.path.join(os.path.dirname(__file__) , 'help',"help.html")
        self.helppage.setHelpPage(helppath)
        
        self.projectwidget = ProjectsWidget()   
        self.projectwidget.requestOpenProject.connect(self.loadProject)
        self.stack.addWidget(wid)
        self.stack.addWidget(self.projectwidget)
        self.stack.addWidget(self.helppage)
                
        sys.excepthook = self.excepthook

        def createSpacer(width=30):
            widget = QWidget()
            widget.setMinimumWidth(width)
            return widget

        self.createToolBars()
        self.createActions()

        spacewidget = createSpacer(60)
        gpsspacewidget = createSpacer()
        gpsspacewidget.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

        self.moveaction.toggled.connect(functools.partial(self.setMapTool, self.movetool))
        showediting = (functools.partial(self.editingtoolbar.showToolbar, 
                                         self.editingmodeaction,
                                         self.editattributesaction))
        self.editingmodeaction.toggled.connect(showediting)

        self.addatgpsaction.triggered.connect(self.addAtGPS)
        self.addatgpsaction.setEnabled(self.gpsAction.isConnected)
        self.gpsAction.gpsfixed.connect(self.addatgpsaction.setEnabled)

        self.editingtoolbar.addToActionGroup(self.editattributesaction)
        self.editingtoolbar.addToActionGroup(self.moveaction)

        self.actionGroup.addAction(self.editingmodeaction)

        self.homeAction.triggered.connect(self.zoomToDefaultView)
        
        self.openProjectAction.triggered.connect(self.showOpenProjectDialog)
        self.openProjectAction.triggered.connect(functools.partial(self.stack.setCurrentIndex, 1))
        
        self.toggleRasterAction.triggered.connect(self.toggleRasterLayers)

        self.navtoolbar.insertAction(self.iface.actionZoomIn(), self.iface.actionTouch())
        self.navtoolbar.insertAction(self.iface.actionTouch(), self.homeAction)
        self.navtoolbar.insertAction(self.iface.actionTouch(), self.iface.actionZoomFullExtent())
        self.navtoolbar.insertAction(self.homeAction, self.iface.actionZoomFullExtent())

        self.navtoolbar.addAction(self.toggleRasterAction)
        self.navtoolbar.insertWidget(self.iface.actionZoomFullExtent(), spacewidget)
        self.toolbar.addAction(self.editingmodeaction)
        self.toolbar.addAction(self.syncAction)
        self.toolbar.addAction(self.gpsAction)
        self.toolbar.insertWidget(self.syncAction, gpsspacewidget)
        self.toolbar.insertSeparator(self.gpsAction)

        self.extraaddtoolbar.addAction(self.addatgpsaction)

        self.editingtoolbar.addAction(self.editattributesaction)
        self.editingtoolbar.addAction(self.moveaction)
        
        self.mapview = QAction(QIcon(":/icons/map"), "Map", self.menutoolbar)
        self.mapview.setCheckable(True)
        self.mapview.triggered.connect(functools.partial(self.stack.setCurrentIndex, 0))
        
        self.help = QAction(QIcon(":/icons/help"), "Help", self.menutoolbar)
        self.help.setCheckable(True)
        self.help.triggered.connect(functools.partial(self.stack.setCurrentIndex, 2))
        self.help.setVisible(False)
        
        self.userlabel = QLabel("Current User <br> {user}".format(user=getpass.getuser()))
        self.userlabel.setAlignment(Qt.AlignCenter)
        self.userlabel.setStyleSheet("""
            QLabel {
                    color: #8c8c8c;
                    font: 10px "Calibri" ;
                    }""")
        
        self.quit = QAction(QIcon(":/icons/quit"), "Quit", self.menutoolbar)
        self.quit.triggered.connect(self.iface.actionExit().trigger)

        self.menuGroup.addAction(self.mapview)
        self.menuGroup.addAction(self.openProjectAction)
        self.menuGroup.addAction(self.help)
        
        self.menutoolbar.addAction(self.mapview)
        self.menutoolbar.addAction(self.openProjectAction)
        self.menutoolbar.addAction(self.help)
        self.menutoolbar.addAction(self.quit)
        
        quitspacewidget = createSpacer()
        quitspacewidget.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        
        labelaction = self.menutoolbar.insertWidget(self.quit, self.userlabel)
        self.menutoolbar.insertWidget(labelaction, quitspacewidget)
        
        self.setupIcons()
        self.stack.currentChanged.connect(self.updateUIState)
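
Note: Qt's triggered/toggled signals carry no useful payload, so this example freezes the target argument (a map tool, a stack page index) into the slot with functools.partial. A self-contained sketch, with a fake signal standing in for Qt so it runs anywhere:

import functools

class FakeSignal(object):
    # Stand-in for a Qt signal so the sketch runs without Qt
    def __init__(self):
        self._slots = []
    def connect(self, slot):
        self._slots.append(slot)
    def emit(self):
        for slot in self._slots:
            slot()

def set_current_index(index):
    print('switching to page %d' % index)

triggered = FakeSignal()
triggered.connect(functools.partial(set_current_index, 1))
triggered.emit()  # switching to page 1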

Example 13

Project: openshot-qt
Source File: export.py
View license
    def __init__(self):

        # Create dialog class
        QDialog.__init__(self)

        # Load UI from designer
        ui_util.load_ui(self, self.ui_path)

        # Init UI
        ui_util.init_ui(self)

        # get translations
        app = get_app()
        _ = app._tr

        # Get settings
        self.s = settings.get_settings()

        # Track metrics
        track_metric_screen("export-screen")

        # Dynamically load tabs from settings data
        self.settings_data = settings.get_settings().get_all_settings()

        # Add buttons to interface
        self.export_button = QPushButton(_('Export Video'))
        self.buttonBox.addButton(self.export_button, QDialogButtonBox.AcceptRole)
        self.buttonBox.addButton(QPushButton(_('Cancel')), QDialogButtonBox.RejectRole)
        self.exporting = False

        # Get the original timeline settings
        self.original_width = get_app().window.timeline_sync.timeline.info.width
        self.original_height = get_app().window.timeline_sync.timeline.info.height
        fps = get_app().window.timeline_sync.timeline.info.fps
        self.original_fps = { "num" : fps.num, "den" : fps.den }
        self.original_sample_rate = get_app().window.timeline_sync.timeline.info.sample_rate
        self.original_channels = get_app().window.timeline_sync.timeline.info.channels
        self.original_channel_layout = get_app().window.timeline_sync.timeline.info.channel_layout

        # Default export path
        recommended_path = os.path.join(info.HOME_PATH)
        if app.project.current_filepath:
            recommended_path = os.path.dirname(app.project.current_filepath)

        export_path = get_app().project.get(["export_path"])
        if os.path.exists(export_path):
            # Use last selected export path
            self.txtExportFolder.setText(export_path)
        else:
            # Default to home dir
            self.txtExportFolder.setText(recommended_path)

        # Is this a saved project?
        if not get_app().project.current_filepath:
            # Not saved yet
            self.txtFileName.setText(_("Untitled Project"))
        else:
            # Yes, project is saved
            # Get just the filename
            parent_path, filename = os.path.split(get_app().project.current_filepath)
            filename, ext = os.path.splitext(filename)
            self.txtFileName.setText(filename.replace("_", " ").replace("-", " ").capitalize())

        # Default image type
        self.txtImageFormat.setText("%05.png")

        # Loop through Export To options
        export_options = [_("Video & Audio"), _("Image Sequence")]
        for option in export_options:
            # append profile to list
            self.cboExportTo.addItem(option)

        # Add channel layouts
        self.channel_layout_choices = []
        for layout in [(openshot.LAYOUT_MONO, _("Mono (1 Channel)")),
                       (openshot.LAYOUT_STEREO, _("Stereo (2 Channel)")),
                       (openshot.LAYOUT_SURROUND, _("Surround (3 Channel)")),
                       (openshot.LAYOUT_5POINT1, _("Surround (5.1 Channel)")),
                       (openshot.LAYOUT_7POINT1, _("Surround (7.1 Channel)"))]:
            log.info(layout)
            self.channel_layout_choices.append(layout[0])
            self.cboChannelLayout.addItem(layout[1], layout[0])

        # Connect signals
        self.btnBrowse.clicked.connect(functools.partial(self.btnBrowse_clicked))
        self.cboSimpleProjectType.currentIndexChanged.connect(
            functools.partial(self.cboSimpleProjectType_index_changed, self.cboSimpleProjectType))
        self.cboProfile.currentIndexChanged.connect(functools.partial(self.cboProfile_index_changed, self.cboProfile))
        self.cboSimpleTarget.currentIndexChanged.connect(
            functools.partial(self.cboSimpleTarget_index_changed, self.cboSimpleTarget))
        self.cboSimpleVideoProfile.currentIndexChanged.connect(
            functools.partial(self.cboSimpleVideoProfile_index_changed, self.cboSimpleVideoProfile))
        self.cboSimpleQuality.currentIndexChanged.connect(
            functools.partial(self.cboSimpleQuality_index_changed, self.cboSimpleQuality))
        self.cboChannelLayout.currentIndexChanged.connect(self.updateChannels)


        # ********* Advanced Profile List **********
        # Loop through profiles
        self.profile_names = []
        self.profile_paths = {}
        for profile_folder in [info.USER_PROFILES_PATH, info.PROFILES_PATH]:
            for file in os.listdir(profile_folder):
                # Load Profile
                profile_path = os.path.join(profile_folder, file)
                profile = openshot.Profile(profile_path)

                # Add description of Profile to list
                self.profile_names.append(profile.info.description)
                self.profile_paths[profile.info.description] = profile_path

        # Sort list
        self.profile_names.sort()

        # Loop through sorted profiles
        box_index = 0
        self.selected_profile_index = 0
        for profile_name in self.profile_names:

            # Add to dropdown
            self.cboProfile.addItem(profile_name, self.profile_paths[profile_name])

            # Set default (if it matches the project)
            if app.project.get(['profile']) == profile_name:
                self.selected_profile_index = box_index

            # increment item counter
            box_index += 1


        # ********* Simple Project Type **********
        # load the simple project type dropdown
        presets = []
        for file in os.listdir(info.EXPORT_PRESETS_DIR):
            xmldoc = xml.parse(os.path.join(info.EXPORT_PRESETS_DIR, file))
            type = xmldoc.getElementsByTagName("type")
            presets.append(_(type[0].childNodes[0].data))

        # Exclude duplicates
        type_index = 0
        selected_type = 0
        presets = list(set(presets))
        for item in sorted(presets):
            self.cboSimpleProjectType.addItem(item, item)
            if item == _("All Formats"):
                selected_type = type_index
            type_index += 1

        # Always select 'All Formats' option
        self.cboSimpleProjectType.setCurrentIndex(selected_type)


        # Populate all profiles
        self.populateAllProfiles(app.project.get(['profile']))

        # Connect framerate signals
        self.txtFrameRateNum.valueChanged.connect(self.updateFrameRate)
        self.txtFrameRateDen.valueChanged.connect(self.updateFrameRate)
        self.txtWidth.valueChanged.connect(self.updateFrameRate)
        self.txtHeight.valueChanged.connect(self.updateFrameRate)
        self.txtSampleRate.valueChanged.connect(self.updateFrameRate)
        self.txtChannels.valueChanged.connect(self.updateFrameRate)
        self.cboChannelLayout.currentIndexChanged.connect(self.updateFrameRate)

        # Determine the length of the timeline (in frames)
        self.updateFrameRate()
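
Note: one handler serves several combo boxes here, with functools.partial baking in which widget emitted the signal. A minimal sketch of that wiring (plain dicts stand in for Qt combo boxes):

import functools

def index_changed(widget, index):
    # One shared handler; partial() tells it which widget fired
    print('%s changed to row %d' % (widget['name'], index))

profile_box = {'name': 'cboProfile'}        # stand-in for a QComboBox
quality_box = {'name': 'cboSimpleQuality'}  # stand-in for a QComboBox

for box in (profile_box, quality_box):
    callback = functools.partial(index_changed, box)
    callback(0)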

Example 14

Project: openshot-qt
Source File: preferences.py
View license
    def __init__(self):

        # Create dialog class
        QDialog.__init__(self)

        # Load UI from designer
        ui_util.load_ui(self, self.ui_path)

        # Init UI
        ui_util.init_ui(self)

        # get translations
        app = get_app()
        _ = app._tr

        # Get settings
        self.s = settings.get_settings()

        # Dynamically load tabs from settings data
        self.settings_data = settings.get_settings().get_all_settings()

        # Track metrics
        track_metric_screen("preferences-screen")

        # Load all user values
        self.params = {}
        for item in self.settings_data:
            if "setting" in item and "value" in item:
                self.params[item["setting"]] = item

        self.requires_restart = False
        self.category_names = {}
        self.category_tabs = {}
        self.category_sort = {}

        # Loop through settings and find all unique categories
        for item in self.settings_data:
            category = item.get("category")
            setting_type = item.get("type")
            sort_category = item.get("sort")

            # Indicate sorted category
            if sort_category:
                self.category_sort[category] = sort_category

            if not setting_type == "hidden":
                # Load setting
                if not category in self.category_names:
                    self.category_names[category] = []

                    # Create scrollarea
                    scroll_area = QScrollArea(self)
                    scroll_area.setWidgetResizable(True)
                    scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
                    scroll_area.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)

                    # Create tab widget and layout
                    layout = QVBoxLayout()
                    tabWidget = QWidget(self)
                    tabWidget.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
                    tabWidget.setLayout(layout)
                    scroll_area.setWidget(tabWidget)

                    # Add tab
                    self.tabCategories.addTab(scroll_area, _(category))
                    self.category_tabs[category] = tabWidget

                # Append translated title
                item["title_tr"] = _(item.get("title"))

                # Append settings into correct category
                self.category_names[category].append(item)

        # Loop through each category setting, and add them to the tabs
        for category in self.category_tabs.keys():
            tabWidget = self.category_tabs[category]

            # Get list of items in category
            params = self.category_names[category]
            if self.category_sort.get(category):
                # Sort this category by translated title
                params.sort(key=operator.itemgetter("title_tr"))

            # Loop through settings for each category
            for param in params:

                # Create Label
                widget = None
                label = QLabel()
                label.setText(_(param["title"]))
                label.setToolTip(_(param["title"]))

                if param["type"] == "spinner":
                    # create QDoubleSpinBox
                    widget = QDoubleSpinBox()
                    widget.setMinimum(float(param["min"]))
                    widget.setMaximum(float(param["max"]))
                    widget.setValue(float(param["value"]))
                    widget.setSingleStep(1.0)
                    widget.setToolTip(param["title"])
                    widget.valueChanged.connect(functools.partial(self.spinner_value_changed, param))

                if param["type"] == "spinner-int":
                    # create QDoubleSpinBox
                    widget = QSpinBox()
                    widget.setMinimum(int(param["min"]))
                    widget.setMaximum(int(param["max"]))
                    widget.setValue(int(param["value"]))
                    widget.setSingleStep(1.0)
                    widget.setToolTip(param["title"])
                    widget.valueChanged.connect(functools.partial(self.spinner_value_changed, param))

                elif param["type"] == "text":
                    # create QLineEdit
                    widget = QLineEdit()
                    widget.setText(_(param["value"]))
                    widget.textChanged.connect(functools.partial(self.text_value_changed, widget, param))

                elif param["type"] == "bool":
                    # create checkbox
                    widget = QCheckBox()
                    if param["value"] == True:
                        widget.setCheckState(Qt.Checked)
                    else:
                        widget.setCheckState(Qt.Unchecked)
                    widget.stateChanged.connect(functools.partial(self.bool_value_changed, widget, param))

                elif param["type"] == "dropdown":

                    # create dropdown
                    widget = QComboBox()

                    # Get values
                    value_list = param["values"]
                    # Overwrite value list (for profile dropdown)
                    if param["setting"] == "default-profile":
                        value_list = []
                        # Loop through profiles
                        for profile_folder in [info.USER_PROFILES_PATH, info.PROFILES_PATH]:
                            for file in os.listdir(profile_folder):
                                # Load Profile and append description
                                profile_path = os.path.join(profile_folder, file)
                                profile = openshot.Profile(profile_path)
                                value_list.append({"name":profile.info.description, "value":profile.info.description})
                        # Sort profile list
                        value_list.sort(key=operator.itemgetter("name"))

                    # Overwrite value list (for language dropdown)
                    if param["setting"] == "default-language":
                        value_list = []
                        # Loop through languages
                        for locale, language, country in get_all_languages():
                            # Append language name and locale
                            if language:
                                lang_name = "%s (%s)" % (language, locale)
                                value_list.append({"name":lang_name, "value":locale})
                        # Sort language list
                        value_list.sort(key=operator.itemgetter("name"))
                        # Add Default to top of list
                        value_list.insert(0, {"name":_("Default"), "value":"Default"})


                    # Add normal values
                    box_index = 0
                    for value_item in value_list:
                        k = value_item["name"]
                        v = value_item["value"]
                        # add dropdown item
                        widget.addItem(_(k), v)

                        # select dropdown (if default)
                        if v == param["value"]:
                            widget.setCurrentIndex(box_index)
                        box_index = box_index + 1

                    widget.currentIndexChanged.connect(functools.partial(self.dropdown_index_changed, widget, param))


                # Add Label and Widget to the form
                if (widget and label):
                    # Add minimum size
                    label.setMinimumWidth(180)
                    label.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
                    widget.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred)

                    # Create HBox layout
                    layout_hbox = QHBoxLayout()
                    layout_hbox.addWidget(label)
                    layout_hbox.addWidget(widget)

                    # Add widget to layout
                    tabWidget.layout().addLayout(layout_hbox)
                elif (label):
                    # Add widget to layout
                    tabWidget.layout().addWidget(label)

            # Add stretch to bottom of layout
            tabWidget.layout().addStretch()
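
Note: every preference widget shares a generic value-changed handler, and functools.partial pins the settings entry (param) that the widget edits, so one function can update any setting. A minimal sketch:

import functools

def spinner_value_changed(param, value):
    # Generic handler; partial() pins which settings entry to update
    param['value'] = value
    print('%s -> %s' % (param['setting'], value))

params = [{'setting': 'default-fps', 'value': 30},
          {'setting': 'volume', 'value': 0.8}]

handlers = [functools.partial(spinner_value_changed, p) for p in params]
handlers[0](25)   # default-fps -> 25
handlers[1](0.5)  # volume -> 0.5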

Example 15

Project: openshot-qt
Source File: blender_treeview.py
View license
    def currentChanged(self, selected, deselected):
        # Get selected item
        self.selected = selected
        self.deselected = deselected

        # Get translation object
        _ = self.app._tr

        # Clear existing settings
        self.win.clear_effect_controls()

        # Get animation details
        animation = self.get_animation_details()
        self.selected_template = animation.get("service")

        # In newer versions of Qt, setting the model invokes the currentChanged signal,
        # but the selection is -1. So, just do nothing here.
        if not self.selected_template:
            return

        # Assign a new unique id for each template selected
        self.generateUniqueFolder()

        # Loop through params
        for param in animation.get("params",[]):
            log.info(param["title"])

            # Is Hidden Param?
            if param["name"] == "start_frame" or param["name"] == "end_frame":
                # add value to dictionary
                self.params[param["name"]] = int(param["default"])

                # skip to next param without rendering the controls
                continue

            # Create Label
            widget = None
            label = QLabel()
            label.setText(_(param["title"]))
            label.setToolTip(_(param["title"]))

            if param["type"] == "spinner":
                # add value to dictionary
                self.params[param["name"]] = float(param["default"])

                # create spinner
                widget = QDoubleSpinBox()
                widget.setMinimum(float(param["min"]))
                widget.setMaximum(float(param["max"]))
                widget.setValue(float(param["default"]))
                widget.setSingleStep(0.01)
                widget.setToolTip(param["title"])
                widget.valueChanged.connect(functools.partial(self.spinner_value_changed, param))

            elif param["type"] == "text":
                # add value to dictionary
                self.params[param["name"]] = _(param["default"])

                # create text field
                widget = QLineEdit()
                widget.setText(_(param["default"]))
                widget.textChanged.connect(functools.partial(self.text_value_changed, widget, param))

            elif param["type"] == "multiline":
                # add value to dictionary
                self.params[param["name"]] = _(param["default"])

                # create multiline text edit
                widget = QTextEdit()
                widget.setText(_(param["default"]).replace("\\n", "\n"))
                widget.textChanged.connect(functools.partial(self.text_value_changed, widget, param))

            elif param["type"] == "dropdown":
                # add value to dictionary
                self.params[param["name"]] = param["default"]

                # create dropdown
                widget = QComboBox()
                widget.currentIndexChanged.connect(functools.partial(self.dropdown_index_changed, widget, param))

                # Add values to dropdown
                if "project_files" in param["name"]:
                    # override files dropdown
                    param["values"] = {}
                    for file in File.filter():
                        if file.data["media_type"] in ("image", "video"):
                            (dirName, fileName) = os.path.split(file.data["path"])
                            (fileBaseName, fileExtension) = os.path.splitext(fileName)

                            if fileExtension.lower() not in (".svg"):
                                param["values"][fileName] = "|".join((file.data["path"], str(file.data["height"]),
                                                                      str(file.data["width"]), file.data["media_type"],
                                                                      str(file.data["fps"]["num"] / file.data["fps"][
                                                                          "den"])))

                # Add normal values
                box_index = 0
                for k, v in sorted(param["values"].items()):
                    # add dropdown item
                    widget.addItem(_(k), v)

                    # select dropdown (if default)
                    if v == param["default"]:
                        widget.setCurrentIndex(box_index)
                    box_index = box_index + 1

                if not param["values"]:
                    widget.addItem(_("No Files Found"), "")
                    widget.setEnabled(False)

            elif param["type"] == "color":
                # add value to dictionary
                color = QColor(param["default"])
                self.params[param["name"]] = [color.redF(), color.greenF(), color.blueF()]

                widget = QPushButton()
                widget.setText("")
                widget.setStyleSheet("background-color: {}".format(param["default"]))
                widget.clicked.connect(functools.partial(self.color_button_clicked, widget, param))

            # Add Label and Widget to the form
            if (widget and label):
                self.win.settingsContainer.layout().addRow(label, widget)
            elif (label):
                self.win.settingsContainer.layout().addRow(label)

        # Enable interface
        self.enable_interface()

        # Init slider values
        self.init_slider_values()
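
Note: because the widgets are created in a loop, functools.partial is the right tool here: it captures param by value at creation time, whereas a bare lambda would see only the loop's final param (the classic late-binding trap). A sketch contrasting the two:

import functools

def dropdown_index_changed(param, index):
    print('%s -> choice %d' % (param['name'], index))

params = [{'name': 'start_frame'}, {'name': 'end_frame'}]

# Late binding: every lambda sees the final value of 'param'
bad = [lambda i: dropdown_index_changed(param, i) for param in params]
# Early binding: each partial froze its own 'param' at creation time
good = [functools.partial(dropdown_index_changed, param) for param in params]

bad[0](0)   # end_frame -> choice 0  (wrong param!)
good[0](0)  # start_frame -> choice 0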

Example 16

Project: taskflow
Source File: server.py
View license
    def _process_request(self, request, message):
        """Process request message and reply back."""
        try:
            # NOTE(skudriashev): parse the broker message first to get
            # the `reply_to` and the `task_uuid` parameters so that it is
            # possible to reply back (if we can't parse, we can't respond
            # in the first place...).
            reply_to, task_uuid = self._parse_message(message)
        except ValueError:
            LOG.warn("Failed to parse request attributes from message '%s'",
                     ku.DelayedPretty(message), exc_info=True)
            return
        else:
            # prepare reply callback
            reply_callback = functools.partial(self._reply, True, reply_to,
                                               task_uuid)

        # Parse the request to get the activity/work to perform.
        try:
            work = pr.Request.from_dict(request, task_uuid=task_uuid)
        except ValueError:
            with misc.capture_failure() as failure:
                LOG.warn("Failed to parse request contents from message '%s'",
                         ku.DelayedPretty(message), exc_info=True)
                reply_callback(result=pr.failure_to_dict(failure))
                return

        # Now fetch the task endpoint (and action handler on it).
        try:
            endpoint = self._endpoints[work.task_cls]
        except KeyError:
            with misc.capture_failure() as failure:
                LOG.warn("The '%s' task endpoint does not exist, unable"
                         " to continue processing request message '%s'",
                         work.task_cls, ku.DelayedPretty(message),
                         exc_info=True)
                reply_callback(result=pr.failure_to_dict(failure))
                return
        else:
            try:
                handler = getattr(endpoint, work.action)
            except AttributeError:
                with misc.capture_failure() as failure:
                    LOG.warn("The '%s' handler does not exist on task endpoint"
                             " '%s', unable to continue processing request"
                             " message '%s'", work.action, endpoint,
                             ku.DelayedPretty(message), exc_info=True)
                    reply_callback(result=pr.failure_to_dict(failure))
                    return
            else:
                try:
                    task = endpoint.generate(name=work.task_name)
                except Exception:
                    with misc.capture_failure() as failure:
                        LOG.warn("The '%s' task '%s' generation for request"
                                 " message '%s' failed", endpoint, work.action,
                                 ku.DelayedPretty(message), exc_info=True)
                        reply_callback(result=pr.failure_to_dict(failure))
                        return
                else:
                    if not reply_callback(state=pr.RUNNING):
                        return

        # Associate *any* events this task emits with a proxy that will
        # emit them back to the engine... for handling at the engine side
        # of things...
        if task.notifier.can_be_registered(nt.Notifier.ANY):
            task.notifier.register(nt.Notifier.ANY,
                                   functools.partial(self._on_event,
                                                     reply_to, task_uuid))
        elif isinstance(task.notifier, nt.RestrictedNotifier):
            # Only proxy the allowable events then...
            for event_type in task.notifier.events_iter():
                task.notifier.register(event_type,
                                       functools.partial(self._on_event,
                                                         reply_to, task_uuid))

        # Perform the task action.
        try:
            result = handler(task, **work.arguments)
        except Exception:
            with misc.capture_failure() as failure:
                LOG.warn("The '%s' endpoint '%s' execution for request"
                         " message '%s' failed", endpoint, work.action,
                         ku.DelayedPretty(message), exc_info=True)
                reply_callback(result=pr.failure_to_dict(failure))
        else:
            # And be done with it!
            if isinstance(result, ft.Failure):
                reply_callback(result=result.to_dict())
            else:
                reply_callback(state=pr.SUCCESS, result=result)
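
The reply_callback above is the core trick of this example: functools.partial freezes the routing parameters (reply_to, task_uuid) once, and every later call only supplies the varying state or result keywords. A minimal sketch of that shape, with _reply as a hypothetical stand-in:

import functools

def _reply(capture, reply_to, task_uuid, state=None, result=None):
    print("reply ->", reply_to, task_uuid, state, result)
    return True

reply_callback = functools.partial(_reply, True, "queue-7", "uuid-42")
reply_callback(state="RUNNING")
reply_callback(state="SUCCESS", result={"answer": 42})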

Example 17

Project: glean
Source File: test_glean.py
View license
    @mock.patch('platform.dist', new_callable=mock.Mock)
    @mock.patch('subprocess.call', return_value=0, new_callable=mock.Mock)
    @mock.patch('subprocess.check_output', return_value=0,
                new_callable=mock.Mock)
    @mock.patch('os.unlink', return_value=0, new_callable=mock.Mock)
    @mock.patch('os.symlink', return_value=0, new_callable=mock.Mock)
    @mock.patch('os.path.exists', new_callable=mock.Mock)
    @mock.patch('os.listdir', new_callable=mock.Mock)
    @mock.patch('os.system', return_value=0, new_callable=mock.Mock)
    @mock.patch('glean.cmd.open', new_callable=mock.Mock)
    @mock.patch.object(sys, 'argv', ['./glean', '--hostname'])
    def _assert_distro_provider(self, distro, provider, interface,
                                mock_open,
                                mock_os_system,
                                mock_os_listdir,
                                mock_os_path_exists,
                                mock_os_symlink,
                                mock_os_unlink,
                                mock_check_output,
                                mock_call,
                                mock_platform_dist):
        """Main test function

        :param distro: distro to return from "platform.dist"
        :param provider: we will look in fixtures/provider for mocked
                         out files
        :param interface: --interface argument; None for no argument
        """

        mock_platform_dist.return_value = (distro, '', '')

        # These functions watch the path being accessed and fake
        # results accordingly.
        # XXX: There are several virtual file-systems available; we
        # might like to look into them and just point ourselves at a
        # testing file-system in the future if this becomes more
        # complex.
        mock_os_path_exists.side_effect = functools.partial(
            self.os_path_exists_side_effect, provider)
        mock_os_listdir.side_effect = functools.partial(
            self.os_listdir_side_effect, provider)
        mock_open.side_effect = functools.partial(
            self.open_side_effect, provider)

        if interface:
            sys.argv.append('--interface=%s' % interface)

        cmd.main()

        output_filename = '%s.%s.network.out' % (provider, distro.lower())
        output_path = os.path.join(sample_data_path, 'test', output_filename)

        # Generate a list of (dest, content) into write_blocks to assert
        write_blocks = []
        lines = open(output_path).readlines()
        write_dest = None
        write_content = None
        for line in lines:
            if line.startswith('### Write '):
                if write_dest is not None:
                    write_blocks.append((write_dest, write_content))
                write_dest = line[len('### Write '):-1]
                write_content = ''
            else:
                write_content += line
        if write_dest is not None:
            write_blocks.append((write_dest, write_content))

        for dest, content in write_blocks:
            if interface and interface not in dest:
                continue
            self.assertNotIn("eth2", dest)
            self.assertIn(dest, self.file_handle_mocks)
            write_handle = self.file_handle_mocks[dest].write
            write_handle.assert_called_once_with(content)

        if self._resolv_unlinked:
            mock_os_unlink.assert_called_once_with('/etc/resolv.conf')

        # Check hostname
        meta_data_path = 'mnt/config/openstack/latest/meta_data.json'
        hostname = None
        with open(os.path.join(sample_data_path, provider,
                               meta_data_path)) as fh:
            meta_data = json.load(fh)
            hostname = meta_data['name']

        mock_call.assert_has_calls([mock.call(['hostname', hostname])])
        if distro.lower() == 'gentoo':
            (self.file_handle_mocks['/etc/conf.d/hostname'].write.
                assert_has_calls([mock.call(hostname)]))
        else:
            self.file_handle_mocks['/etc/hostname'].write.assert_has_calls(
                [mock.call(hostname), mock.call('\n')])

        # Check hosts entry
        hostname_ip = ips[provider]
        calls = [mock.call('%s %s\n' % (hostname_ip, hostname)), ]
        short_hostname = hostname.split('.')[0]
        if hostname != short_hostname:
            calls.append(mock.call('%s %s\n' % (hostname_ip, short_hostname)))

        self.file_handle_mocks['/etc/hosts'].write.assert_has_calls(
            calls, any_order=True)
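
The side_effect assignments above use functools.partial to adapt two-argument helpers (provider, path) to the one-argument signature the mocked os functions are called with. A runnable sketch of the same idea; the fixture layout and helper below are hypothetical:

import functools
import os
from unittest import mock

def os_path_exists_side_effect(provider, path):
    # hypothetical helper: only paths under this provider's fixtures "exist"
    return path.startswith("/fixtures/%s/" % provider)

with mock.patch("os.path.exists") as mock_exists:
    mock_exists.side_effect = functools.partial(
        os_path_exists_side_effect, "rackspace")
    print(os.path.exists("/fixtures/rackspace/meta_data.json"))  # True
    print(os.path.exists("/etc/hosts"))                          # False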

Example 18

Project: persepolis
Source File: addlink.py
View license
    def __init__(self, callback, flashgot_add_link_dictionary=None):
        super().__init__()
        self.callback = callback
        # use None instead of a mutable default argument; keys are deleted
        # from this dictionary later in __init__
        self.flashgot_add_link_dictionary = flashgot_add_link_dictionary or {}

#add change_name field
        self.change_name_horizontalLayout = QHBoxLayout() 
        self.change_name_checkBox = QCheckBox(self.link_frame)
        self.change_name_checkBox.setText('Change File Name : ')
        self.change_name_horizontalLayout.addWidget(self.change_name_checkBox)



        self.change_name_lineEdit = QLineEdit(self.link_frame)
        self.change_name_horizontalLayout.addWidget(self.change_name_lineEdit)

        self.link_verticalLayout.addLayout(self.change_name_horizontalLayout)

#entry initialization            

        f = Open(setting_file)
        setting_file_lines = f.readlines()
        f.close()
        setting_dict_str = str(setting_file_lines[0].strip())
        setting_dict = ast.literal_eval(setting_dict_str) 
        global connections
        connections = int(setting_dict['connections'])
        global download_path
        download_path = str(setting_dict['download_path'])


        global init_file
        init_file = str(home_address) + "/.config/persepolis_download_manager/addlink_init_file"
        osCommands.touch(init_file)
        f = Open(init_file)
        init_file_lines = f.readlines()
        f.close()

#initialization
        self.connections_spinBox.setValue(connections)
        self.download_folder_lineEdit.setText(download_path)
        self.download_folder_lineEdit.setEnabled(False)

        self.ok_pushButton.setEnabled(False)
        self.download_later_pushButton.setEnabled(False)
        self.link_lineEdit.textChanged.connect(self.linkLineChanged)
#AddLink - checking clipboard for link!   
        if 'link' in self.flashgot_add_link_dictionary:
            self.link_lineEdit.setText(str(self.flashgot_add_link_dictionary['link']))
            del self.flashgot_add_link_dictionary['link']
        else:    
            clipboard = QApplication.clipboard()
            text = clipboard.text()
            try:
                if ("tp:/" in text[2:6]) or ("tps:/" in text[2:7]) :
                    self.link_lineEdit.setText(str(text))
            except:
                pass
#ip_lineEdit initialization 
        try:
            self.ip_lineEdit.setText(init_file_lines[0].strip())
        except Exception:
            pass

#proxy user lineEdit initialization 
        try:
            self.proxy_user_lineEdit.setText(init_file_lines[2].strip())
        except Exception:
            pass

#port_spinBox initialization 
        try:
            self.port_spinBox.setValue(int(init_file_lines[1].strip()))
        except Exception:
            pass


#download UserName initialization
        try:
            self.download_user_lineEdit.setText(init_file_lines[3].strip())
        except Exception:
            pass
#connect folder_pushButton
        self.folder_pushButton.clicked.connect(self.changeFolder)

#connect OK and cancel download_later buttons

        self.cancel_pushButton.clicked.connect(self.cancelButtonPressed)
        self.ok_pushButton.clicked.connect(functools.partial(self.okButtonPressed, download_later='no'))
        self.download_later_pushButton.clicked.connect(functools.partial(self.okButtonPressed, download_later='yes'))
#frames and checkBoxes
        self.proxy_frame.setEnabled(False)
        self.proxy_checkBox.toggled.connect(self.proxyFrame)

        self.download_frame.setEnabled(False)
        self.download_checkBox.toggled.connect(self.downloadFrame)

        self.limit_frame.setEnabled(False)
        self.limit_checkBox.toggled.connect(self.limitFrame)
    
        self.start_frame.setEnabled(False)
        self.start_checkBox.toggled.connect(self.startFrame)

        self.end_frame.setEnabled(False)
        self.end_checkBox.toggled.connect(self.endFrame)


        self.change_name_lineEdit.setEnabled(False)
        self.change_name_checkBox.toggled.connect(self.changeName)

#check name of flashgot link 
        if 'out' in self.flashgot_add_link_dictionary:
            self.change_name_lineEdit.setText(str(self.flashgot_add_link_dictionary['out']))
            self.change_name_checkBox.setChecked(True)
            del self.flashgot_add_link_dictionary['out']
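
The two connect() calls above reuse a single handler for both buttons, with the distinguishing flag frozen in as a keyword argument. A minimal sketch of that pattern; okButtonPressed here is a hypothetical stand-in for the dialog method:

import functools

def okButtonPressed(download_later):
    print("pressed, download_later =", download_later)

ok_handler = functools.partial(okButtonPressed, download_later='no')
later_handler = functools.partial(okButtonPressed, download_later='yes')
ok_handler()     # pressed, download_later = no
later_handler()  # pressed, download_later = yes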

Example 19

Project: mincss
Source File: app.py
View license
@app.route('/<path:path>')
def proxy(path):
    if path == 'favicon.ico':
        abort(404)
    url = path
    if not path.count('://'):
        url = 'http://' + url

    query = urlparse(request.url).query
    if query:
        url += '?%s' % query
    logging.info('Downloading %s' % url)
    t0 = time.time()
    html = download(url)
    t1 = time.time()
    print('%.4f seconds to download' % (t1 - t0))

    p = Processor(debug=False, optimize_lookup=True)
    # since we've already downloaded the HTML
    t0 = time.time()
    p.process_html(html, url)
    t1 = time.time()
    p.process()
    t2 = time.time()
    print('%.4f seconds to parse and process' % (t2 - t1))

    collect_stats = request.args.get('MINCSS_STATS', False)
    stats = []
    css_url_regex = re.compile(r'url\(([^\)]+)\)')

    def css_url_replacer(match, href=None):
        filename = match.groups()[0]
        bail = match.group()

        if (
            (filename.startswith('"') and filename.endswith('"')) or
            (filename.startswith("'") and filename.endswith("'"))
        ):
            filename = filename[1:-1]
        if 'data:image' in filename or '://' in filename:
            return bail
        if filename == '.':
            # this is a known IE hack in CSS
            return bail

        new_filename = urljoin(url, filename)
        return 'url("%s")' % new_filename

    for i, each in enumerate(p.inlines):
        # this should be using CSSSelector instead
        new_inline = each.after
        new_inline = css_url_regex.sub(
            functools.partial(css_url_replacer, href=url),
            new_inline
        )
        stats.append(
            ('inline %s' % (i + 1), each.before, each.after)
        )
        html = html.replace(each.before, new_inline)

    parser = etree.HTMLParser()
    stripped = html.strip()
    tree = etree.fromstring(stripped, parser).getroottree()
    page = tree.getroot()

    # lxml inserts a doctype if none exists, so only include it in
    # the root if it was in the original html.
    was_doctype = tree.docinfo.doctype

    links = dict((x.href, x) for x in p.links)

    for link in CSSSelector('link')(page):
        if (
            link.attrib.get('rel', '') == 'stylesheet' or
            link.attrib['href'].lower().split('?')[0].endswith('.css')
        ):
            hash_ = hashlib.md5(url + link.attrib['href']).hexdigest()[:7]
            now = datetime.date.today()
            destination_dir = os.path.join(
                CACHE_DIR,
                str(now.year),
                str(now.month),
                str(now.day),
            )
            mkdir(destination_dir)
            new_css = links[link.attrib['href']].after
            stats.append((
                link.attrib['href'],
                links[link.attrib['href']].before,
                links[link.attrib['href']].after
            ))
            new_css = css_url_regex.sub(
                functools.partial(
                    css_url_replacer,
                    href=link.attrib['href']
                ),
                new_css
            )
            destination = os.path.join(destination_dir, hash_ + '.css')

            with codecs.open(destination, 'w', 'utf-8') as f:
                f.write(new_css)

            link.attrib['href'] = (
                '/cache%s' % destination.replace(CACHE_DIR, '')
            )

    for img in CSSSelector('img, script')(page):
        if 'src' in img.attrib:
            orig_src = urljoin(url, img.attrib['src'])
            img.attrib['src'] = orig_src

    for a in CSSSelector('a')(page):
        if 'href' not in a.attrib:
            continue
        href = a.attrib['href']

        if (
            '://' in href or
            href.startswith('#') or
            href.startswith('javascript:')
        ):
            continue

        if href.startswith('/'):
            a.attrib['href'] = (
                '/' +
                urljoin(url, a.attrib['href'])
                .replace('http://', '')
            )
        if collect_stats:
            a.attrib['href'] = add_collect_stats_qs(
                a.attrib['href'],
                collect_stats
            )

    html = etree.tostring(page, method='html')
    if collect_stats:
        html = re.sub(
            '<body[^>]*>',
            lambda m: m.group() + summorize_stats_html(stats),
            html,
            flags=re.I | re.M,
            count=1
        )

    return (was_doctype and was_doctype or '') + '\n' + html
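
The re.sub calls above rely on the fact that re.sub only passes the match object to a replacement callable; functools.partial smuggles the extra href context in as a pre-bound keyword. A self-contained sketch of the same technique (the URL and CSS below are made up):

import functools
import re

css_url_regex = re.compile(r'url\(([^\)]+)\)')

def css_url_replacer(match, href=None):
    filename = match.group(1).strip('\'"')
    return 'url("%s/%s")' % (href.rstrip('/'), filename)

css = 'body { background: url(bg.png); }'
print(css_url_regex.sub(
    functools.partial(css_url_replacer, href='http://example.com/static'),
    css))
# body { background: url("http://example.com/static/bg.png"); }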

Example 20

Project: pypot
Source File: __init__.py
View license
def from_vrep(config, vrep_host='127.0.0.1', vrep_port=19997, scene=None,
              tracked_objects=[], tracked_collisions=[]):
    """ Create a robot from a V-REP instance.

    :param config: robot configuration (either the path to the json or directly the dictionary)
    :type config: str or dict
    :param str vrep_host: host of the V-REP server
    :param int vrep_port: port of the V-REP server
    :param str scene: path to the V-REP scene to load and start
    :param list tracked_objects: list of V-REP dummy object to track
    :param list tracked_collisions: list of V-REP collision to track

    This function tries to connect to a V-REP instance and expects to find motors with names corresponding to the ones found in the config.

    .. note:: The :class:`~pypot.robot.robot.Robot` returned will also provide a convenience reset_simulation method, which resets the simulation and the robot position to its initial stance.

    .. note:: Using the same configuration, you should be able to switch from a real to a simulated robot just by switching from :func:`~pypot.robot.config.from_config` to :func:`~pypot.vrep.from_vrep`.
        For instance::

            import json

            with open('my_config.json') as f:
                config = json.load(f)

            from pypot.robot import from_config
            from pypot.vrep import from_vrep

            real_robot = from_config(config)
            simulated_robot = from_vrep(config, '127.0.0.1', 19997, 'poppy.ttt')

    """
    vrep_io = VrepIO(vrep_host, vrep_port)

    vreptime = vrep_time(vrep_io)
    pypot_time.time = vreptime.get_time
    pypot_time.sleep = vreptime.sleep

    if isinstance(config, basestring):
        with open(config) as f:
            config = json.load(f, object_pairs_hook=OrderedDict)

    motors = [motor_from_confignode(config, name)
              for name in config['motors'].keys()]

    vc = VrepController(vrep_io, scene, motors)
    vc._init_vrep_streaming()

    sensor_controllers = []

    if tracked_objects:
        sensors = [ObjectTracker(name) for name in tracked_objects]
        vot = VrepObjectTracker(vrep_io, sensors)
        sensor_controllers.append(vot)

    if tracked_collisions:
        sensors = [VrepCollisionDetector(name) for name in tracked_collisions]
        vct = VrepCollisionTracker(vrep_io, sensors)
        sensor_controllers.append(vct)

    robot = Robot(motor_controllers=[vc],
                  sensor_controllers=sensor_controllers)

    for m in robot.motors:
        m.goto_behavior = 'minjerk'

    init_pos = {m: m.goal_position for m in robot.motors}

    make_alias(config, robot)

    def start_simu():
        vrep_io.start_simulation()

        for m, p in init_pos.iteritems():
            m.goal_position = p

        vc.start()

        if tracked_objects:
            vot.start()

        if tracked_collisions:
            vct.start()

        while vrep_io.get_simulation_current_time() < 1.:
            sys_time.sleep(0.1)

    def stop_simu():
        if tracked_objects:
            vot.stop()

        if tracked_collisions:
            vct.stop()

        vc.stop()
        vrep_io.stop_simulation()

    def reset_simu():
        stop_simu()
        sys_time.sleep(0.5)
        start_simu()

    robot.start_simulation = start_simu
    robot.stop_simulation = stop_simu
    robot.reset_simulation = reset_simu

    def current_simulation_time(robot):
        return robot._controllers[0].io.get_simulation_current_time()
    Robot.current_simulation_time = property(lambda robot: current_simulation_time(robot))

    def get_object_position(robot, object, relative_to_object=None):
        return vrep_io.get_object_position(object, relative_to_object)
    Robot.get_object_position = partial(get_object_position, robot)

    def get_object_orientation(robot, object, relative_to_object=None):
        return vrep_io.get_object_orientation(object, relative_to_object)
    Robot.get_object_orientation = partial(get_object_orientation, robot)

    return robot
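
The last two assignments above attach partially-applied functions to the Robot class. Because partial objects are not descriptors, they do not receive the instance automatically the way plain functions do, so the instance is baked in explicitly. A minimal sketch of that subtlety:

from functools import partial

class Robot:
    pass

def get_object_position(robot, obj, relative_to_object=None):
    return "position of %s for %r" % (obj, robot)

robot = Robot()
# a plain function assigned here would be called as a bound method;
# a partial is returned from attribute lookup unchanged
Robot.get_object_position = partial(get_object_position, robot)
print(robot.get_object_position("cube"))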

Example 21

Project: pysmt
Source File: parser.py
View license
    def __init__(self, environment=None, interactive=False):
        self.env = get_env() if environment is None else environment
        self.interactive = interactive

        # Placeholders for fields filled by self._reset
        self.cache = None
        self.logic = None
        self._reset()

        mgr = self.env.formula_manager

        # Fixing the issue with integer/real numbers on arithmetic
        # operators.
        #
        # We try to apply the operator as it is, in case of failure,
        # we try to interpret as reals the constant operands that are
        # integers
        def fix_real(op, *args):
            try:
                return op(*args)
            except PysmtTypeError:
                get_type = self.env.stc.get_type
                get_free_variables = self.env.fvo.get_free_variables
                new_args = []
                for x in args:
                    if get_type(x).is_int_type() and\
                       len(get_free_variables(x)) == 0:
                        new_args.append(mgr.ToReal(x))
                    else:
                        new_args.append(x)
                # args is a tuple; compare like with like, otherwise
                # this check is always False
                if new_args == list(args):
                    raise
                return op(*new_args)

        self.LT = functools.partial(fix_real, mgr.LT)
        self.GT = functools.partial(fix_real, mgr.GT)
        self.LE = functools.partial(fix_real, mgr.LE)
        self.GE = functools.partial(fix_real, mgr.GE)
        self.Equals = functools.partial(fix_real, mgr.Equals)
        self.EqualsOrIff = functools.partial(fix_real, mgr.EqualsOrIff)
        self.Plus = functools.partial(fix_real, mgr.Plus)
        self.Minus = functools.partial(fix_real, mgr.Minus)
        self.Times = functools.partial(fix_real, mgr.Times)
        self.Div = functools.partial(fix_real, mgr.Div)
        self.Ite = functools.partial(fix_real, mgr.Ite)
        self.AllDifferent = functools.partial(fix_real, mgr.AllDifferent)

        # Tokens representing interpreted functions appearing in expressions
        # Each token is handled by a dedicated function that takes the
        # recursion stack, the token stream and the parsed token
        # Common tokens are handled in the _reset function
        self.interpreted = {"let" : self._enter_let,
                            "!" : self._enter_annotation,
                            "exists" : self._enter_quantifier,
                            "forall" : self._enter_quantifier,
                            '+':self._operator_adapter(self.Plus),
                            '-':self._operator_adapter(self._minus_or_uminus),
                            '*':self._operator_adapter(self.Times),
                            '/':self._operator_adapter(self._division),
                            'pow':self._operator_adapter(mgr.Pow),
                            '>':self._operator_adapter(self.GT),
                            '<':self._operator_adapter(self.LT),
                            '>=':self._operator_adapter(self.GE),
                            '<=':self._operator_adapter(self.LE),
                            '=':self._operator_adapter(self._equals_or_iff),
                            'not':self._operator_adapter(mgr.Not),
                            'and':self._operator_adapter(mgr.And),
                            'or':self._operator_adapter(mgr.Or),
                            'xor':self._operator_adapter(mgr.Xor),
                            '=>':self._operator_adapter(mgr.Implies),
                            '<->':self._operator_adapter(mgr.Iff),
                            'ite':self._operator_adapter(self.Ite),
                            'distinct':self._operator_adapter(self.AllDifferent),
                            'to_real':self._operator_adapter(mgr.ToReal),
                            'concat':self._operator_adapter(mgr.BVConcat),
                            'bvnot':self._operator_adapter(mgr.BVNot),
                            'bvand':self._operator_adapter(mgr.BVAnd),
                            'bvor':self._operator_adapter(mgr.BVOr),
                            'bvneg':self._operator_adapter(mgr.BVNeg),
                            'bvadd':self._operator_adapter(mgr.BVAdd),
                            'bvmul':self._operator_adapter(mgr.BVMul),
                            'bvudiv':self._operator_adapter(mgr.BVUDiv),
                            'bvurem':self._operator_adapter(mgr.BVURem),
                            'bvshl':self._operator_adapter(mgr.BVLShl),
                            'bvlshr':self._operator_adapter(mgr.BVLShr),
                            'bvsub':self._operator_adapter(mgr.BVSub),
                            'bvult':self._operator_adapter(mgr.BVULT),
                            'bvxor':self._operator_adapter(mgr.BVXor),
                            '_':self._smtlib_underscore,
                            # Extended Functions
                            'bvnand':self._operator_adapter(mgr.BVNand),
                            'bvnor':self._operator_adapter(mgr.BVNor),
                            'bvxnor':self._operator_adapter(mgr.BVXnor),
                            'bvcomp':self._operator_adapter(mgr.BVComp),
                            'bvsdiv':self._operator_adapter(mgr.BVSDiv),
                            'bvsrem':self._operator_adapter(mgr.BVSRem),
                            'bvsmod':self._operator_adapter(mgr.BVSMod),
                            'bvashr':self._operator_adapter(mgr.BVAShr),
                            'bvule':self._operator_adapter(mgr.BVULE),
                            'bvugt':self._operator_adapter(mgr.BVUGT),
                            'bvuge':self._operator_adapter(mgr.BVUGE),
                            'bvslt':self._operator_adapter(mgr.BVSLT),
                            'bvsle':self._operator_adapter(mgr.BVSLE),
                            'bvsgt':self._operator_adapter(mgr.BVSGT),
                            'bvsge':self._operator_adapter(mgr.BVSGE),
                            # arrays
                            'select':self._operator_adapter(mgr.Select),
                            'store':self._operator_adapter(mgr.Store),
                            'as':self._enter_smtlib_as,
                            }

        # Command tokens
        self.commands = {smtcmd.ASSERT : self._cmd_assert,
                         smtcmd.CHECK_SAT : self._cmd_check_sat,
                         smtcmd.CHECK_SAT_ASSUMING : self._cmd_check_sat_assuming,
                         smtcmd.DECLARE_CONST : self._cmd_declare_const,
                         smtcmd.DECLARE_FUN : self._cmd_declare_fun,
                         smtcmd.DECLARE_SORT: self._cmd_declare_sort,
                         smtcmd.DEFINE_FUN : self._cmd_define_fun,
                         smtcmd.DEFINE_FUNS_REC : self._cmd_define_funs_rec,
                         smtcmd.DEFINE_FUN_REC : self._cmd_define_fun_rec,
                         smtcmd.DEFINE_SORT: self._cmd_define_sort,
                         smtcmd.ECHO : self._cmd_echo,
                         smtcmd.EXIT : self._cmd_exit,
                         smtcmd.GET_ASSERTIONS: self._cmd_get_assertions,
                         smtcmd.GET_ASSIGNMENT : self._cmd_get_assignment,
                         smtcmd.GET_INFO: self._cmd_get_info,
                         smtcmd.GET_MODEL: self._cmd_get_model,
                         smtcmd.GET_OPTION: self._cmd_get_option,
                         smtcmd.GET_PROOF: self._cmd_get_proof,
                         smtcmd.GET_UNSAT_ASSUMPTIONS : self._cmd_get_unsat_assumptions,
                         smtcmd.GET_UNSAT_CORE: self._cmd_get_unsat_core,
                         smtcmd.GET_VALUE : self._cmd_get_value,
                         smtcmd.POP : self._cmd_pop,
                         smtcmd.PUSH : self._cmd_push,
                         smtcmd.RESET : self._cmd_reset,
                         smtcmd.RESET_ASSERTIONS : self._cmd_reset_assertions,
                         smtcmd.SET_LOGIC : self._cmd_set_logic,
                         smtcmd.SET_OPTION : self._cmd_set_option,
                         smtcmd.SET_INFO : self._cmd_set_info,
                     }
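
The block of self.LT, self.GT, ... assignments above wraps a whole family of operators with one fix-up function, each wrapped callable keeping the original call signature. A minimal sketch of the wrapping, using int-to-float coercion as a stand-in for pysmt's Int/Real promotion:

import functools

def fix_real(op, *args):
    # try the operator as-is; on a type error, retry with coerced operands
    try:
        return op(*args)
    except TypeError:
        return op(*(float(a) for a in args))

def strict_add(a, b):
    if type(a) is not type(b):
        raise TypeError("mixed types")
    return a + b

Plus = functools.partial(fix_real, strict_add)
print(Plus(1, 2))    # 3
print(Plus(1, 2.5))  # 3.5 (operands coerced on retry)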

Example 22

Project: qutip
Source File: propagator.py
View license
def propagator(H, t, c_op_list=[], args={}, options=None,
               parallel=False, progress_bar=None, **kwargs):
    """
    Calculate the propagator U(t) for the density matrix or wave function such
    that :math:`\psi(t) = U(t)\psi(0)` or
    :math:`\\rho_{\mathrm vec}(t) = U(t) \\rho_{\mathrm vec}(0)`
    where :math:`\\rho_{\mathrm vec}` is the vector representation of the
    density matrix.

    Parameters
    ----------
    H : qobj or list
        Hamiltonian as a Qobj instance or a nested list of Qobjs and
        coefficients in the list-string or list-function format for
        time-dependent Hamiltonians (see description in :func:`qutip.mesolve`).

    t : float or array-like
        Time or list of times for which to evaluate the propagator.

    c_op_list : list
        List of qobj collapse operators.

    args : list/array/dictionary
        Parameters to callback functions for time-dependent Hamiltonians and
        collapse operators.

    options : :class:`qutip.Options`
        with options for the ODE solver.

    parallel : bool {False, True}
        Run the propagator in parallel mode.
    
    progress_bar: BaseProgressBar
        Optional instance of BaseProgressBar, or a subclass thereof, for
        showing the progress of the simulation. By default no progress bar
        is used, and if set to True a TextProgressBar will be used.

    Returns
    -------
    U : qobj
        Instance representing the propagator :math:`U(t)`.

    """
    kw = _default_kwargs()
    if 'num_cpus' in kwargs:
        num_cpus = kwargs['num_cpus']
    else:
        num_cpus = kw['num_cpus']
    
    if progress_bar is None:
        progress_bar = BaseProgressBar()
    elif progress_bar is True:
        progress_bar = TextProgressBar()

    if options is None:
        options = Options()
        options.rhs_reuse = True
        rhs_clear()

    if isinstance(t, (int, float, np.integer, np.floating)):
        tlist = [0, t]
    else:
        tlist = t

    td_type = _td_format_check(H, c_op_list, solver='me')[2]
    if td_type > 0:
        rhs_generate(H, c_op_list, args=args, options=options)
        
    if isinstance(H, (types.FunctionType, types.BuiltinFunctionType,
                      functools.partial)):
        H0 = H(0.0, args)
    elif isinstance(H, list):
        H0 = H[0][0] if isinstance(H[0], list) else H[0]
    else:
        H0 = H

    if len(c_op_list) == 0 and H0.isoper:
        # calculate propagator for the wave function

        N = H0.shape[0]
        dims = H0.dims
        u = np.zeros([N, N, len(tlist)], dtype=complex)
        
        if parallel:
            output = parallel_map(_parallel_sesolve,range(N),
                    task_args=(N,H,tlist,args,options),
                    progress_bar=progress_bar, num_cpus=num_cpus)
            for n in range(N):
                for k, t in enumerate(tlist):
                    u[:, n, k] = output[n].states[k].full().T 
        else:
            progress_bar.start(N)
            for n in range(0, N):
                progress_bar.update(n)
                psi0 = basis(N, n)
                output = sesolve(H, psi0, tlist, [], args, 
                                options, _safe_mode=False)
                for k, t in enumerate(tlist):
                    u[:, n, k] = output.states[k].full().T
            progress_bar.finished()

        # todo: evolving a batch of wave functions:
        # psi_0_list = [basis(N, n) for n in range(N)]
        # psi_t_list = mesolve(H, psi_0_list, [0, t], [], [], args, options)
        # for n in range(0, N):
        #    u[:,n] = psi_t_list[n][1].full().T

    elif len(c_op_list) == 0 and H0.issuper:
        # calculate the propagator for the vector representation of the
        # density matrix (a superoperator propagator)

        N = H0.shape[0]
        sqrt_N = int(np.sqrt(N))
        dims = H0.dims
        
        u = np.zeros([N, N, len(tlist)], dtype=complex)

        if parallel:
            output = parallel_map(_parallel_mesolve,range(N * N),
                    task_args=(sqrt_N,H,tlist,c_op_list,args,options),
                    progress_bar=progress_bar, num_cpus=num_cpus)
            for n in range(N * N):
                for k, t in enumerate(tlist):
                    u[:, n, k] = mat2vec(output[n].states[k].full()).T
        else:
            progress_bar.start(N)
            for n in range(0, N):
                progress_bar.update(n)
                col_idx, row_idx = np.unravel_index(n,(sqrt_N,sqrt_N))
                rho0 = Qobj(sp.csr_matrix(([1],([row_idx],[col_idx])), shape=(sqrt_N,sqrt_N), dtype=complex))
                output = mesolve(H, rho0, tlist, [], [], args, options, _safe_mode=False)
                for k, t in enumerate(tlist):
                    u[:, n, k] = mat2vec(output.states[k].full()).T
            progress_bar.finished()

    else:
        # calculate the propagator for the vector representation of the
        # density matrix (a superoperator propagator)

        N = H0.shape[0]
        dims = [H0.dims, H0.dims]

        u = np.zeros([N * N, N * N, len(tlist)], dtype=complex)
        
        if parallel:
            output = parallel_map(_parallel_mesolve,range(N * N),
                    task_args=(N,H,tlist,c_op_list,args,options),
                    progress_bar=progress_bar, num_cpus=num_cpus)
            for n in range(N * N):
                for k, t in enumerate(tlist):
                    u[:, n, k] = mat2vec(output[n].states[k].full()).T
        else:
            progress_bar.start(N * N)
            for n in range(N * N):
                progress_bar.update(n)
                col_idx, row_idx = np.unravel_index(n,(N,N))
                rho0 = Qobj(sp.csr_matrix(([1],([row_idx],[col_idx])), shape=(N,N), dtype=complex))
                output = mesolve(H, rho0, tlist, c_op_list, [], args, options, _safe_mode=False)
                for k, t in enumerate(tlist):
                    u[:, n, k] = mat2vec(output.states[k].full()).T
            progress_bar.finished()

    if len(tlist) == 2:
        return Qobj(u[:, :, 1], dims=dims)
    else:
        return np.array([Qobj(u[:, :, k], dims=dims) for k in range(len(tlist))], dtype=object)
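
Note the isinstance check near the top of this function: a functools.partial object is not a types.FunctionType, so code that accepts "a callable H(t, args)" must list partial explicitly (or simply test callable(H)). A minimal illustration:

import functools
import types

def hamiltonian(t, args, scale):
    return scale * t

H = functools.partial(hamiltonian, scale=2.0)
print(isinstance(H, types.FunctionType))  # False
print(isinstance(H, (types.FunctionType, types.BuiltinFunctionType,
                     functools.partial)))  # True
print(H(0.5, {}))  # 1.0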

Example 23

Project: rabix
Source File: main.py
View license
def main():
    disable_warnings()
    logging.basicConfig(level=logging.WARN)
    if len(sys.argv) == 1:
        print(USAGE)
        return

    usage = USAGE.format(resources=make_resources_usage_string(),
                         inputs='<inputs>')
    app_usage = usage

    if len(sys.argv) == 2 and \
            (sys.argv[1] == '--help' or sys.argv[1] == '-h'):
        print(USAGE)
        return

    dry_run_args = dry_run_parse()
    if not dry_run_args:
        print(USAGE)
        return

    if not (dry_run_args['<tool>']):
        print('You have to specify a tool with the --tool option')
        print(usage)
        return

    tool = get_tool(dry_run_args)
    if not tool:
        fail("Couldn't find tool.")

    if isinstance(tool, list):
        tool = loader.index.get('#main')

    if 'class' not in tool:
        fail("Document must have a 'class' field")

    if 'id' not in tool:
        tool['id'] = dry_run_args['<tool>']

    context = init_context(tool)

    app = process_builder(context, tool)
    job = None

    if isinstance(app, Job):
        job = app
        app = job.app

    rabix.expressions.update_engines(app)

    if dry_run_args['--install']:
        app.install()
        print("Install successful.")
        return

    if dry_run_args['--conformance-test']:
        job_dict = from_url(dry_run_args['<job>'])
        conformance_test(context, app, job_dict, dry_run_args.get('--basedir'))
        return

    try:
        args = docopt.docopt(usage, version=version, help=False)
        job_dict = copy.deepcopy(TEMPLATE_JOB)
        logging.root.setLevel(log_level(dry_run_args['--verbose']))

        input_file_path = args.get('<inp>') or args.get('--inp-file')
        if input_file_path:
            basedir = os.path.dirname(os.path.abspath(input_file_path))
            input_file = from_url(input_file_path)
            inputs = get_inputs(input_file, app.inputs, basedir)
            job_dict['inputs'].update(inputs)

        input_usage = job_dict['inputs']

        if job:
            basedir = os.path.dirname(args.get('<tool>'))
            job.inputs = get_inputs(job.inputs, app.inputs, basedir)
            input_usage.update(job.inputs)

        app_inputs_usage = make_app_usage_string(
            app, template=TOOL_TEMPLATE, inp=input_usage)

        app_usage = make_app_usage_string(app, USAGE, job_dict['inputs'])

        try:
            app_inputs = docopt.docopt(app_inputs_usage, args['<inputs>'])
        except docopt.DocoptExit:
            if not job:
                raise
            for inp in job.app.inputs:
                if inp.required and inp.id not in job.inputs:
                    raise
            app_inputs = {}

        if args['--help']:
            print(app_usage)
            return
        # trim leading '--' and ignore empty arrays
        app_inputs = {
            k[2:]: v
            for k, v in six.iteritems(app_inputs)
            if v != []
        }

        inp = get_inputs(app_inputs, app.inputs)
        if not job:
            job_dict['id'] = args.get('--outdir') or args.get('--dir')
            job_dict['app'] = app
            job = Job.from_dict(context, job_dict)

        job.inputs.update(inp)

        if args['--print-cli']:
            if not isinstance(app, CommandLineTool):
                fail(dry_run_args['<tool>'] + " is not a command line app")

            print(CLIJob(job).cmd_line())
            return

        if args['--pretty-print']:
            fmt = partial(result_str, job.id)
        else:
            fmt = lambda result: json.dumps(context.to_primitive(result))

        if not job.inputs and not args['--'] and not args['--quiet']:
            print(app_usage)
            return

        try:
            context.executor.execute(job, lambda _, result: print(fmt(result)))
        except RabixError as err:
            fail(err.message)

    except docopt.DocoptExit:
        fail(app_usage)
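
The fmt selection above gives both branches the same one-argument fmt(result) interface: partial pre-binds the job id in one branch, a lambda closes over nothing in the other. A minimal sketch with a hypothetical result_str:

from functools import partial
import json

def result_str(job_id, result):
    return "[%s] %s" % (job_id, json.dumps(result))

pretty = True
if pretty:
    fmt = partial(result_str, "job-001")
else:
    fmt = lambda result: json.dumps(result)

print(fmt({"status": "ok"}))  # [job-001] {"status": "ok"}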

Example 24

Project: discord.py
Source File: voice_client.py
View license
    @asyncio.coroutine
    def create_ytdl_player(self, url, *, ytdl_options=None, **kwargs):
        """|coro|

        Creates a stream player for youtube or other services that launches
        in a separate thread to play the audio.

        The player uses the ``youtube_dl`` python library to get the information
        required to get audio from the URL. Since this uses an external library,
        you must install it yourself. You can do so by calling
        ``pip install youtube_dl``.

        You must have the ffmpeg or avconv executable in your path environment
        variable in order for this to work.

        The operations that can be done on the player are the same as those in
        :meth:`create_stream_player`. The player has been augmented and enhanced
        to have some info extracted from the URL. If youtube-dl fails to extract
        the information then the attribute is ``None``. The ``yt``, ``url``, and
        ``download_url`` attributes are always available.

        +---------------------+---------------------------------------------------------+
        |      Operation      |                       Description                       |
        +=====================+=========================================================+
        | player.yt           | The `YoutubeDL <ytdl>` instance.                        |
        +---------------------+---------------------------------------------------------+
        | player.url          | The URL that is currently playing.                      |
        +---------------------+---------------------------------------------------------+
        | player.download_url | The URL that is currently being downloaded to ffmpeg.   |
        +---------------------+---------------------------------------------------------+
        | player.title        | The title of the audio stream.                          |
        +---------------------+---------------------------------------------------------+
        | player.description  | The description of the audio stream.                    |
        +---------------------+---------------------------------------------------------+
        | player.uploader     | The uploader of the audio stream.                       |
        +---------------------+---------------------------------------------------------+
        | player.upload_date  | A datetime.date object of when the stream was uploaded. |
        +---------------------+---------------------------------------------------------+
        | player.duration     | The duration of the audio in seconds.                   |
        +---------------------+---------------------------------------------------------+
        | player.likes        | How many likes the audio stream has.                    |
        +---------------------+---------------------------------------------------------+
        | player.dislikes     | How many dislikes the audio stream has.                 |
        +---------------------+---------------------------------------------------------+
        | player.is_live      | Checks if the audio stream is currently livestreaming.  |
        +---------------------+---------------------------------------------------------+
        | player.views        | How many views the audio stream has.                    |
        +---------------------+---------------------------------------------------------+

        .. _ytdl: https://github.com/rg3/youtube-dl/blob/master/youtube_dl/YoutubeDL.py#L128-L278

        Examples
        ----------

        Basic usage: ::

            voice = await client.join_voice_channel(channel)
            player = await voice.create_ytdl_player('https://www.youtube.com/watch?v=d62TYemN6MQ')
            player.start()

        Parameters
        -----------
        url : str
            The URL that ``youtube_dl`` will take and download audio to pass
            to ``ffmpeg`` or ``avconv`` to convert to PCM bytes.
        ytdl_options : dict
            A dictionary of options to pass into the ``YoutubeDL`` instance.
            See `the documentation <ytdl>`_ for more details.
        \*\*kwargs
            The rest of the keyword arguments are forwarded to
            :func:`create_ffmpeg_player`.

        Raises
        -------
        ClientException
            Popen failure from either ``ffmpeg``/``avconv``.

        Returns
        --------
        StreamPlayer
            An augmented StreamPlayer that uses ffmpeg.
            See :meth:`create_stream_player` for base operations.
        """
        import youtube_dl

        use_avconv = kwargs.get('use_avconv', False)
        opts = {
            'format': 'webm[abr>0]/bestaudio/best',
            'prefer_ffmpeg': not use_avconv
        }

        if ytdl_options is not None and isinstance(ytdl_options, dict):
            opts.update(ytdl_options)

        ydl = youtube_dl.YoutubeDL(opts)
        func = functools.partial(ydl.extract_info, url, download=False)
        info = yield from self.loop.run_in_executor(None, func)
        if "entries" in info:
            info = info['entries'][0]

        log.info('playing URL {}'.format(url))
        download_url = info['url']
        player = self.create_ffmpeg_player(download_url, **kwargs)

        # set the dynamic attributes from the info extraction
        player.download_url = download_url
        player.url = url
        player.yt = ydl
        player.views = info.get('view_count')
        player.is_live = bool(info.get('is_live'))
        player.likes = info.get('like_count')
        player.dislikes = info.get('dislike_count')
        player.duration = info.get('duration')
        player.uploader = info.get('uploader')

        is_twitch = 'twitch' in url
        if is_twitch:
            # twitch has 'title' and 'description' sort of mixed up.
            player.title = info.get('description')
            player.description = None
        else:
            player.title = info.get('title')
            player.description = info.get('description')

        # upload date handling
        date = info.get('upload_date')
        if date:
            try:
                date = datetime.datetime.strptime(date, '%Y%m%d').date()
            except ValueError:
                date = None

        player.upload_date = date
        return player
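
The run_in_executor call above is a classic use of functools.partial: loop.run_in_executor forwards positional arguments only, so the keyword argument download=False is bundled into a zero-argument callable first. A minimal sketch in modern asyncio syntax, with a hypothetical blocking extract_info:

import asyncio
import functools

def extract_info(url, download=True):
    # stands in for the blocking youtube_dl call
    return {"url": url, "download": download}

async def main():
    loop = asyncio.get_running_loop()
    func = functools.partial(extract_info, "https://example.com/v",
                             download=False)
    info = await loop.run_in_executor(None, func)
    print(info)  # {'url': 'https://example.com/v', 'download': False}

asyncio.run(main())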

Example 25

Project: reviewboard
Source File: chunk_generator.py
View license
    def generate_chunks(self, old, new):
        """Generate chunks for the difference between two strings.

        The strings will be normalized, ensuring they're of the proper
        encoding and ensuring they have consistent newlines. They're then
        syntax-highlighted (if requested).

        Once the strings are ready, chunks are built from the strings and
        yielded to the caller. Each chunk represents information on an
        equal, inserted, deleted, or replaced set of lines.

        The number of lines of each chunk type are stored in the
        :py:attr:`counts` dictionary, which can then be accessed after
        yielding all chunks.
        """
        is_lists = isinstance(old, list)
        assert is_lists == isinstance(new, list)

        if is_lists:
            if self.encoding_list:
                old = self.normalize_source_list(old)
                new = self.normalize_source_list(new)

            a = old
            b = new
        else:
            old, a = self.normalize_source_string(old)
            new, b = self.normalize_source_string(new)

        a_num_lines = len(a)
        b_num_lines = len(b)

        if is_lists:
            markup_a = a
            markup_b = b
        else:
            markup_a = None
            markup_b = None

            if self._get_enable_syntax_highlighting(old, new, a, b):
                source_file = \
                    self.normalize_path_for_display(self.orig_filename)
                dest_file = \
                    self.normalize_path_for_display(self.modified_filename)

                try:
                    # TODO: Try to figure out the right lexer for these files
                    #       once instead of twice.
                    if not source_file.endswith(self.STYLED_EXT_BLACKLIST):
                        markup_a = self._apply_pygments(old or '', source_file)

                    if not dest_file.endswith(self.STYLED_EXT_BLACKLIST):
                        markup_b = self._apply_pygments(new or '', dest_file)
                except Exception:
                    pass

            if not markup_a:
                markup_a = self.NEWLINES_RE.split(escape(old))

            if not markup_b:
                markup_b = self.NEWLINES_RE.split(escape(new))

        siteconfig = SiteConfiguration.objects.get_current()
        ignore_space = True

        for pattern in siteconfig.get('diffviewer_include_space_patterns'):
            if fnmatch.fnmatch(self.orig_filename, pattern):
                ignore_space = False
                break

        self.differ = get_differ(a, b, ignore_space=ignore_space,
                                 compat_version=self.diff_compat)
        self.differ.add_interesting_lines_for_headers(self.orig_filename)

        context_num_lines = siteconfig.get("diffviewer_context_num_lines")
        collapse_threshold = 2 * context_num_lines + 3

        line_num = 1
        opcodes_generator = self.get_opcode_generator()

        counts = {
            'equal': 0,
            'replace': 0,
            'insert': 0,
            'delete': 0,
        }

        for tag, i1, i2, j1, j2, meta in opcodes_generator:
            old_lines = markup_a[i1:i2]
            new_lines = markup_b[j1:j2]
            num_lines = max(len(old_lines), len(new_lines))

            lines = map(functools.partial(self._diff_line, tag, meta),
                        range(line_num, line_num + num_lines),
                        range(i1 + 1, i2 + 1), range(j1 + 1, j2 + 1),
                        a[i1:i2], b[j1:j2], old_lines, new_lines)

            counts[tag] += num_lines

            if tag == 'equal' and num_lines > collapse_threshold:
                last_range_start = num_lines - context_num_lines

                if line_num == 1:
                    yield self._new_chunk(lines, 0, last_range_start, True)
                    yield self._new_chunk(lines, last_range_start, num_lines)
                else:
                    yield self._new_chunk(lines, 0, context_num_lines)

                    if i2 == a_num_lines and j2 == b_num_lines:
                        yield self._new_chunk(lines, context_num_lines,
                                              num_lines, True)
                    else:
                        yield self._new_chunk(lines, context_num_lines,
                                              last_range_start, True)
                        yield self._new_chunk(lines, last_range_start,
                                              num_lines)
            else:
                yield self._new_chunk(lines, 0, num_lines, False, tag, meta)

            line_num += num_lines

        self.counts = counts
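
The map() call above freezes the loop-invariant arguments (tag, meta) with functools.partial and lets map supply the per-line arguments from parallel iterables. A minimal sketch of that shape, with a hypothetical diff_line:

import functools

def diff_line(tag, meta, line_num, old_text, new_text):
    return (tag, line_num, old_text, new_text, meta)

lines = map(functools.partial(diff_line, "replace", {"whitespace": False}),
            range(1, 3),            # line numbers
            ["old a", "old b"],     # old lines
            ["new a", "new b"])     # new lines
for line in lines:
    print(line)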

Example 26

Project: Kerminal
Source File: mechjeb.py
View license
@invalid_if_not_connected
def smartass(args, widget_proxy, form, stream):
    """\
sa

Utilize MechJeb SmartASS to lock orientation vector. Disable with the "off"
subcommand.

Usage:
  sa (off | node | prograde | retrograde | normalplus | normalminus |
      radialplus | radialminus | targetplus | targetminus | relativeplus |
      relativeminus | parallelplus | parallelminus)
  sa surface <heading> <pitch> <roll>

Standard Commands (no target required):
  off            Disable SmartASS.
  node           Orient along next node.
  prograde       Orient along orbital velocity prograde.
  retrograde     Orient along orbital velocity retrograde.
  normalplus     Orient along orbital velocity normal plus.
  normalminus    Orient along orbital velocity normal minus.
  radialplus     Orient along the orbital velocity radial plus.
  radialminus    Orient along the orbital velocity radial minus.
  surface <heading> <pitch> <roll>    Set orientation based on surface angles.

Targeted Commands (target required):
  targetplus       Orient towards the target.
  targetminus      Orient away from the target.
  relativeplus     Orient along relative velocity.
  relativeminus    Orient against relative velocity.
  parallelplus     Orient along parallel velocity component.
  parallelminus    Orient against parallel velocity component.

Comments:
  When in a prograde orbit, "normalplus" corresponds to "up"/"north" while
  "normalminus" corresponds to "down"/"south". This is reversed in a retrograde
  orbit.

  Conversely, "radialplus" will correspond to "away from orbited body" in both
  prograde and retrograde orbits and "radialminus" will correspond to "towards
  orbited body".

  Targeted Commands will not work unless there is a target selected. Selecting
  a target with Telemachus is unfortunately impossible at this time and must be
  done manually in KSP. In the future, Kerminal may warn about the use of these
  commands while nothing is targeted.
    """

    if args['surface']:
        try:
            heading = float(args['<heading>'])
        except ValueError:
            form.error('Value for heading must be a number!')
            return
        try:
            pitch = float(args['<pitch>'])
        except ValueError:
            form.error('Value for pitch must be a number!')
            return
        try:
            roll = float(args['<roll>'])
        except ValueError:
            form.error('Value for roll must be a number!')
            return
        signal = 'mj.surface2[{},{},{}]'.format(heading, pitch, roll)

    elif args['off']:
        signal = 'mj.smartassoff'

    elif args['node']:
        signal = 'mj.node'

    elif args['prograde']:
        signal = 'mj.prograde'

    elif args['retrograde']:
        signal = 'mj.retrograde'

    elif args['normalplus']:
        signal = 'mj.normalplus'

    elif args['normalminus']:
        signal = 'mj.normalminus'

    elif args['radialplus']:
        signal = 'mj.radialplus'

    elif args['radialminus']:
        signal = 'mj.radialminus'

    elif args['targetplus']:
        signal = 'mj.targetplus'

    elif args['targetminus']:
        signal = 'mj.targetminus'

    elif args['relativeplus']:
        signal = 'mj.relativeplus'

    elif args['relativeminus']:
        signal = 'mj.relativeminus'

    elif args['parallelplus']:
        signal = 'mj.parallelplus'

    elif args['parallelminus']:
        signal = 'mj.parallelminus'

    stream.msg_queue.put({'run': [signal]})
    stream.add_callback(partial(mj_callback, form, signal))
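
The add_callback call at the end pre-binds the form and the signal just sent, so that when the stream later invokes the callback with the response, the handler knows what it is reporting on. A minimal sketch; mj_callback and the response payload are hypothetical:

from functools import partial

def mj_callback(form, signal, response):
    print("form %s: %s -> %s" % (form, signal, response))

callbacks = []
callbacks.append(partial(mj_callback, "main-form", "mj.prograde"))
for cb in callbacks:
    cb({"status": "locked"})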

Example 27

Project: scikit-learn
Source File: funcsigs.py
View license
def signature(obj):
    '''Get a signature object for the passed callable.'''

    if not callable(obj):
        raise TypeError('{0!r} is not a callable object'.format(obj))

    if isinstance(obj, types.MethodType):
        sig = signature(obj.__func__)
        if obj.__self__ is None:
            # Unbound method: the first parameter becomes positional-only
            if sig.parameters:
                first = sig.parameters.values()[0].replace(
                    kind=_POSITIONAL_ONLY)
                return sig.replace(
                    parameters=(first,) + tuple(sig.parameters.values())[1:])
            else:
                return sig
        else:
            # In this case we skip the first parameter of the underlying
            # function (usually `self` or `cls`).
            return sig.replace(parameters=tuple(sig.parameters.values())[1:])

    try:
        sig = obj.__signature__
    except AttributeError:
        pass
    else:
        if sig is not None:
            return sig

    try:
        # Was this function wrapped by a decorator?
        wrapped = obj.__wrapped__
    except AttributeError:
        pass
    else:
        return signature(wrapped)

    if isinstance(obj, types.FunctionType):
        return Signature.from_function(obj)

    if isinstance(obj, functools.partial):
        sig = signature(obj.func)

        new_params = OrderedDict(sig.parameters.items())

        partial_args = obj.args or ()
        partial_keywords = obj.keywords or {}
        try:
            ba = sig.bind_partial(*partial_args, **partial_keywords)
        except TypeError:
            msg = 'partial object {0!r} has incorrect arguments'.format(obj)
            raise ValueError(msg)

        for arg_name, arg_value in ba.arguments.items():
            param = new_params[arg_name]
            if arg_name in partial_keywords:
                # We set a new default value, because the following code
                # is correct:
                #
                #   >>> def foo(a): print(a)
                #   >>> print(partial(partial(foo, a=10), a=20)())
                #   20
                #   >>> print(partial(partial(foo, a=10), a=20)(a=30))
                #   30
                #
                # So, with 'partial' objects, passing a keyword argument is
                # like setting a new default value for the corresponding
                # parameter
                #
                # We also mark this parameter with '_partial_kwarg'
                # flag.  Later, in '_bind', the 'default' value of this
                # parameter will be added to 'kwargs', to simulate
                # the 'functools.partial' real call.
                new_params[arg_name] = param.replace(default=arg_value,
                                                     _partial_kwarg=True)

            elif (param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL) and
                            not param._partial_kwarg):
                new_params.pop(arg_name)

        return sig.replace(parameters=new_params.values())

    sig = None
    if isinstance(obj, type):
        # obj is a class or a metaclass

        # First, let's see if it has an overloaded __call__ defined
        # in its metaclass
        call = _get_user_defined_method(type(obj), '__call__')
        if call is not None:
            sig = signature(call)
        else:
            # Now we check if the 'obj' class has a '__new__' method
            new = _get_user_defined_method(obj, '__new__')
            if new is not None:
                sig = signature(new)
            else:
                # Finally, we should have at least __init__ implemented
                init = _get_user_defined_method(obj, '__init__')
                if init is not None:
                    sig = signature(init)
    elif not isinstance(obj, _NonUserDefinedCallables):
        # An object with __call__
        # We also check that the 'obj' is not an instance of
        # _WrapperDescriptor or _MethodWrapper to avoid
        # infinite recursion (and even potential segfault)
        call = _get_user_defined_method(type(obj), '__call__', 'im_func')
        if call is not None:
            sig = signature(call)

    if sig is not None:
        # For classes and objects we skip the first parameter of their
        # __call__, __new__, or __init__ methods
        return sig.replace(parameters=tuple(sig.parameters.values())[1:])

    if isinstance(obj, types.BuiltinFunctionType):
        # Raise a nicer error message for builtins
        msg = 'no signature found for builtin function {0!r}'.format(obj)
        raise ValueError(msg)

    raise ValueError('callable {0!r} is not supported by signature'.format(obj))
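
This branch of signature() rebuilds a partial object's signature by binding the frozen arguments against the wrapped function's parameters: bound positionals disappear, and bound keywords become new defaults. funcsigs is the backport of Python 3's inspect.signature, so the standard library shows the same behavior; a minimal sketch with a toy function:

import functools
import inspect

def connect(host, port, timeout=10):
    """Toy function used only to illustrate signature reporting."""
    return (host, port, timeout)

# The pre-bound positional argument drops out of the reported signature.
local = functools.partial(connect, "localhost")
print(inspect.signature(local))  # (port, timeout=10)

# A pre-bound keyword becomes the parameter's new default, mirroring the
# nested-partial example in the comment above.
print(inspect.signature(functools.partial(connect, timeout=1)))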

Example 28

Project: GPy
Source File: likelihood_tests.py
View license
    def setUp(self):
        np.random.seed(fixed_seed)
        self.N = 15
        self.D = 3
        self.X = np.random.rand(self.N, self.D)*10

        self.real_std = 0.1
        noise = np.random.randn(*self.X[:, 0].shape)*self.real_std
        self.Y = (np.sin(self.X[:, 0]*2*np.pi) + noise)[:, None]
        self.f = np.random.rand(self.N, 1)
        self.binary_Y = np.asarray(np.random.rand(self.N) > 0.5, dtype=int)[:, None]
        self.binary_Y[self.binary_Y == 0.0] = -1.0
        self.positive_Y = np.exp(self.Y.copy())
        tmp = np.round(self.X[:, 0]*3-3)[:, None] + np.random.randint(0,3, self.X.shape[0])[:, None]
        self.integer_Y = np.where(tmp > 0, tmp, 0)
        self.ns = np.random.poisson(50, size=self.N)[:, None]
        p = np.abs(np.cos(2*np.pi*self.X + np.random.normal(scale=.2, size=(self.N, self.D)))).mean(1)
        self.binomial_Y = np.array([np.random.binomial(self.ns[i], p[i]) for i in range(p.shape[0])])[:, None]
        
        self.var = 0.2
        self.deg_free = 4.0

        #Make a bigger step as lower bound can be quite curved
        self.step = 1e-4

        """
        Dictionary where we nest models we would like to check
            Name: {
                "model": model_instance,
                "grad_params": {
                    "names": [names_of_params_we_want, to_grad_check],
                    "vals": [values_of_params, to_start_at],
                    "constrain": [constraint_wrappers, listed_here]
                    },
                "laplace": boolean_of_whether_model_should_work_for_laplace,
                "ep": boolean_of_whether_model_should_work_for_laplace,
                "link_f_constraints": [constraint_wrappers, listed_here]
                }
        """
        self.noise_models = {"Student_t_default": {
            "model": GPy.likelihoods.StudentT(deg_free=self.deg_free, sigma2=self.var),
            "grad_params": {
                "names": [".*t_scale2"],
                "vals": [self.var],
                "constraints": [(".*t_scale2", self.constrain_positive), (".*deg_free", self.constrain_fixed)]
            },
            "laplace": True
            },
            #"Student_t_deg_free": {
                #"model": GPy.likelihoods.StudentT(deg_free=self.deg_free, sigma2=self.var),
                #"grad_params": {
                    #"names": [".*deg_free"],
                    #"vals": [self.deg_free],
                    #"constraints": [(".*t_scale2", self.constrain_fixed), (".*deg_free", self.constrain_positive)]
                #},
                #"laplace": True
            #},
            "Student_t_1_var": {
                "model": GPy.likelihoods.StudentT(deg_free=self.deg_free, sigma2=self.var),
                "grad_params": {
                    "names": [".*t_scale2"],
                    "vals": [1.0],
                    "constraints": [(".*t_scale2", self.constrain_positive), (".*deg_free", self.constrain_fixed)]
                },
                "laplace": True
            },
            # FIXME: This is a known failure point: when the degrees of freedom
            # are very small and the variance is relatively small, the
            # likelihood is not log-concave and problems occur
            # "Student_t_small_deg_free": {
                # "model": GPy.likelihoods.StudentT(deg_free=1.5, sigma2=self.var),
                # "grad_params": {
                    # "names": [".*t_scale2"],
                    # "vals": [self.var],
                    # "constraints": [(".*t_scale2", self.constrain_positive), (".*deg_free", self.constrain_fixed)]
                # },
                # "laplace": True
            # },
            "Student_t_small_var": {
                "model": GPy.likelihoods.StudentT(deg_free=self.deg_free, sigma2=self.var),
                "grad_params": {
                    "names": [".*t_scale2"],
                    "vals": [0.001],
                    "constraints": [(".*t_scale2", self.constrain_positive), (".*deg_free", self.constrain_fixed)]
                },
                "laplace": True
            },
            "Student_t_large_var": {
                "model": GPy.likelihoods.StudentT(deg_free=self.deg_free, sigma2=self.var),
                "grad_params": {
                    "names": [".*t_scale2"],
                    "vals": [10.0],
                    "constraints": [(".*t_scale2", self.constrain_positive), (".*deg_free", self.constrain_fixed)]
                },
                "laplace": True
            },
            "Student_t_approx_gauss": {
                "model": GPy.likelihoods.StudentT(deg_free=1000, sigma2=self.var),
                "grad_params": {
                    "names": [".*t_scale2"],
                    "vals": [self.var],
                    "constraints": [(".*t_scale2", self.constrain_positive), (".*deg_free", self.constrain_fixed)]
                },
                "laplace": True
            },
            "Gaussian_default": {
                "model": GPy.likelihoods.Gaussian(variance=self.var),
                "grad_params": {
                    "names": [".*variance"],
                    "vals": [self.var],
                    "constraints": [(".*variance", self.constrain_positive)]
                },
                "laplace": True,
                "ep": False, # FIXME: Should be True when we have it working again
                "variational_expectations": True,
            },
            "Gaussian_log": {
                "model": GPy.likelihoods.Gaussian(gp_link=link_functions.Log(), variance=self.var),
                "grad_params": {
                    "names": [".*variance"],
                    "vals": [self.var],
                    "constraints": [(".*variance", self.constrain_positive)]
                },
                "laplace": True,
                "variational_expectations": True
            },
            #"Gaussian_probit": {
            #"model": GPy.likelihoods.gaussian(gp_link=link_functions.Probit(), variance=self.var, D=self.D, N=self.N),
            #"grad_params": {
            #"names": ["noise_model_variance"],
            #"vals": [self.var],
            #"constraints": [constrain_positive]
            #},
            #"laplace": True
            #},
            #"Gaussian_log_ex": {
            #"model": GPy.likelihoods.gaussian(gp_link=link_functions.Log_ex_1(), variance=self.var, D=self.D, N=self.N),
            #"grad_params": {
            #"names": ["noise_model_variance"],
            #"vals": [self.var],
            #"constraints": [constrain_positive]
            #},
            #"laplace": True
            #},
            "Bernoulli_default": {
                "model": GPy.likelihoods.Bernoulli(),
                "link_f_constraints": [partial(self.constrain_bounded, lower=0, upper=1)],
                "laplace": True,
                "Y": self.binary_Y,
                "ep": True, # FIXME: Should be True when we have it working again
                "variational_expectations": True
            },
            "Exponential_default": {
                "model": GPy.likelihoods.Exponential(),
                "link_f_constraints": [self.constrain_positive],
                "Y": self.positive_Y,
                "laplace": True,
            },
            "Poisson_default": {
                "model": GPy.likelihoods.Poisson(),
                "link_f_constraints": [self.constrain_positive],
                "Y": self.integer_Y,
                "laplace": True,
                "ep": False #Should work though...
            },
            "Binomial_default": {
                "model": GPy.likelihoods.Binomial(),
                "link_f_constraints": [partial(self.constrain_bounded, lower=0, upper=1)],
                "Y": self.binomial_Y,
                "Y_metadata": {'trials': self.ns},
                "laplace": True,
            },
            #,
            #GAMMA needs some work!"Gamma_default": {
            #"model": GPy.likelihoods.Gamma(),
            #"link_f_constraints": [constrain_positive],
            #"Y": self.positive_Y,
            #"laplace": True
            #}
        }
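
In the dictionary above, partial(self.constrain_bounded, lower=0, upper=1) turns a bounds-taking constraint into a one-argument callable, so bounded and positive constraints can be stored side by side and applied through a single interface. A minimal sketch of the pattern; the constrain_* helpers are simplified stand-ins, not the GPy implementations:

from functools import partial

def constrain_positive(value):
    return max(value, 1e-8)

def constrain_bounded(value, lower, upper):
    return min(max(value, lower), upper)

# Pre-binding the bounds gives every constraint the same one-argument
# shape, so the test harness can apply them uniformly.
link_f_constraints = [constrain_positive,
                      partial(constrain_bounded, lower=0, upper=1)]
for constrain in link_f_constraints:
    print(constrain(3.7))  # 3.7, then 1 (clamped into [0, 1])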

Example 29

Project: SickGear
Source File: curl_httpclient.py
View license
    def _curl_setup_request(self, curl, request, buffer, headers):
        curl.setopt(pycurl.URL, native_str(request.url))

        # libcurl's magic "Expect: 100-continue" behavior causes delays
        # with servers that don't support it (which include, among others,
        # Google's OpenID endpoint).  Additionally, this behavior has
        # a bug in conjunction with the curl_multi_socket_action API
        # (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
        # which increases the delays.  It's more trouble than it's worth,
        # so just turn off the feature (yes, setting Expect: to an empty
        # value is the official way to disable this)
        if "Expect" not in request.headers:
            request.headers["Expect"] = ""

        # libcurl adds Pragma: no-cache by default; disable that too
        if "Pragma" not in request.headers:
            request.headers["Pragma"] = ""

        curl.setopt(pycurl.HTTPHEADER,
                    ["%s: %s" % (native_str(k), native_str(v))
                     for k, v in request.headers.get_all()])

        curl.setopt(pycurl.HEADERFUNCTION,
                    functools.partial(self._curl_header_callback,
                                      headers, request.header_callback))
        if request.streaming_callback:
            def write_function(chunk):
                self.io_loop.add_callback(request.streaming_callback, chunk)
        else:
            write_function = buffer.write
        if bytes is str:  # py2
            curl.setopt(pycurl.WRITEFUNCTION, write_function)
        else:  # py3
            # Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
            # a fork/port.  That version has a bug in which it passes unicode
            # strings instead of bytes to the WRITEFUNCTION.  This means that
            # if you use a WRITEFUNCTION (which tornado always does), you cannot
            # download arbitrary binary data.  This needs to be fixed in the
            # ported pycurl package, but in the meantime this lambda will
            # make it work for downloading (utf8) text.
            curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
        curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
        curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
        curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
        curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
        if request.user_agent:
            curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
        else:
            curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
        if request.network_interface:
            curl.setopt(pycurl.INTERFACE, request.network_interface)
        if request.decompress_response:
            curl.setopt(pycurl.ENCODING, "gzip,deflate")
        else:
            curl.setopt(pycurl.ENCODING, "none")
        if request.proxy_host and request.proxy_port:
            curl.setopt(pycurl.PROXY, request.proxy_host)
            curl.setopt(pycurl.PROXYPORT, request.proxy_port)
            if request.proxy_username:
                credentials = '%s:%s' % (request.proxy_username,
                                         request.proxy_password)
                curl.setopt(pycurl.PROXYUSERPWD, credentials)
        else:
            curl.setopt(pycurl.PROXY, '')
            curl.unsetopt(pycurl.PROXYUSERPWD)
        if request.validate_cert:
            curl.setopt(pycurl.SSL_VERIFYPEER, 1)
            curl.setopt(pycurl.SSL_VERIFYHOST, 2)
        else:
            curl.setopt(pycurl.SSL_VERIFYPEER, 0)
            curl.setopt(pycurl.SSL_VERIFYHOST, 0)
        if request.ca_certs is not None:
            curl.setopt(pycurl.CAINFO, request.ca_certs)
        else:
            # There is no way to restore pycurl.CAINFO to its default value
            # (Using unsetopt makes it reject all certificates).
            # I don't see any way to read the default value from python so it
            # can be restored later.  We'll have to just leave CAINFO untouched
            # if no ca_certs file was specified, and require that if any
            # request uses a custom ca_certs file, they all must.
            pass

        if request.allow_ipv6 is False:
            # Curl behaves reasonably when DNS resolution gives an ipv6 address
            # that we can't reach, so allow ipv6 unless the user asks to disable.
            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
        else:
            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)

        # Set the request method through curl's irritating interface which makes
        # up names for almost every single method
        curl_options = {
            "GET": pycurl.HTTPGET,
            "POST": pycurl.POST,
            "PUT": pycurl.UPLOAD,
            "HEAD": pycurl.NOBODY,
        }
        custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
        for o in curl_options.values():
            curl.setopt(o, False)
        if request.method in curl_options:
            curl.unsetopt(pycurl.CUSTOMREQUEST)
            curl.setopt(curl_options[request.method], True)
        elif request.allow_nonstandard_methods or request.method in custom_methods:
            curl.setopt(pycurl.CUSTOMREQUEST, request.method)
        else:
            raise KeyError('unknown method ' + request.method)

        # Handle curl's cryptic options for every individual HTTP method
        if request.method == "GET":
            if request.body is not None:
                raise ValueError('Body must be None for GET request')
        elif request.method in ("POST", "PUT") or request.body:
            if request.body is None:
                raise ValueError(
                    'Body must not be None for "%s" request'
                    % request.method)

            request_buffer = BytesIO(utf8(request.body))

            def ioctl(cmd):
                if cmd == curl.IOCMD_RESTARTREAD:
                    request_buffer.seek(0)
            curl.setopt(pycurl.READFUNCTION, request_buffer.read)
            curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
            if request.method == "POST":
                curl.setopt(pycurl.POSTFIELDSIZE, len(request.body))
            else:
                curl.setopt(pycurl.UPLOAD, True)
                curl.setopt(pycurl.INFILESIZE, len(request.body))

        if request.auth_username is not None:
            userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')

            if request.auth_mode is None or request.auth_mode == "basic":
                curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
            elif request.auth_mode == "digest":
                curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
            else:
                raise ValueError("Unsupported auth_mode %s" % request.auth_mode)

            curl.setopt(pycurl.USERPWD, native_str(userpwd))
            curl_log.debug("%s %s (username: %r)", request.method, request.url,
                           request.auth_username)
        else:
            curl.unsetopt(pycurl.USERPWD)
            curl_log.debug("%s %s", request.method, request.url)

        if request.client_cert is not None:
            curl.setopt(pycurl.SSLCERT, request.client_cert)

        if request.client_key is not None:
            curl.setopt(pycurl.SSLKEY, request.client_key)

        if request.ssl_options is not None:
            raise ValueError("ssl_options not supported in curl_httpclient")

        if threading.activeCount() > 1:
            # libcurl/pycurl is not thread-safe by default.  When multiple threads
            # are used, signals should be disabled.  This has the side effect
            # of disabling DNS timeouts in some environments (when libcurl is
            # not linked against ares), so we don't do it when there is only one
            # thread.  Applications that use many short-lived threads may need
            # to set NOSIGNAL manually in a prepare_curl_callback since
            # there may not be any other threads running at the time we call
            # threading.activeCount.
            curl.setopt(pycurl.NOSIGNAL, 1)
        if request.prepare_curl_callback is not None:
            request.prepare_curl_callback(curl)
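
pycurl invokes the HEADERFUNCTION with a single header line, so the code above uses functools.partial to freeze the headers dict and the user's header_callback as leading arguments, adapting a three-argument method to pycurl's one-argument contract. A minimal sketch of that adapter pattern; the loop stands in for libcurl delivering header lines:

import functools

def header_callback(headers, user_callback, header_line):
    # headers and user_callback were frozen by partial; only header_line
    # is supplied by the transport, matching pycurl's contract.
    key, _, value = header_line.partition(":")
    headers[key.strip()] = value.strip()
    if user_callback is not None:
        user_callback(header_line)

headers = {}
on_header = functools.partial(header_callback, headers, print)

for line in ("Content-Type: text/html", "Content-Length: 42"):
    on_header(line)
print(headers)  # {'Content-Type': 'text/html', 'Content-Length': '42'}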

Example 30

Project: SickRage
Source File: curl_httpclient.py
View license
    def _curl_setup_request(self, curl, request, buffer, headers):
        curl.setopt(pycurl.URL, native_str(request.url))

        # libcurl's magic "Expect: 100-continue" behavior causes delays
        # with servers that don't support it (which include, among others,
        # Google's OpenID endpoint).  Additionally, this behavior has
        # a bug in conjunction with the curl_multi_socket_action API
        # (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
        # which increases the delays.  It's more trouble than it's worth,
        # so just turn off the feature (yes, setting Expect: to an empty
        # value is the official way to disable this)
        if "Expect" not in request.headers:
            request.headers["Expect"] = ""

        # libcurl adds Pragma: no-cache by default; disable that too
        if "Pragma" not in request.headers:
            request.headers["Pragma"] = ""

        curl.setopt(pycurl.HTTPHEADER,
                    ["%s: %s" % (native_str(k), native_str(v))
                     for k, v in request.headers.get_all()])

        curl.setopt(pycurl.HEADERFUNCTION,
                    functools.partial(self._curl_header_callback,
                                      headers, request.header_callback))
        if request.streaming_callback:
            def write_function(chunk):
                self.io_loop.add_callback(request.streaming_callback, chunk)
        else:
            write_function = buffer.write
        if bytes is str:  # py2
            curl.setopt(pycurl.WRITEFUNCTION, write_function)
        else:  # py3
            # Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
            # a fork/port.  That version has a bug in which it passes unicode
            # strings instead of bytes to the WRITEFUNCTION.  This means that
            # if you use a WRITEFUNCTION (which tornado always does), you cannot
            # download arbitrary binary data.  This needs to be fixed in the
            # ported pycurl package, but in the meantime this lambda will
            # make it work for downloading (utf8) text.
            curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
        curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
        curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
        curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
        curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
        if request.user_agent:
            curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
        else:
            curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
        if request.network_interface:
            curl.setopt(pycurl.INTERFACE, request.network_interface)
        if request.decompress_response:
            curl.setopt(pycurl.ENCODING, "gzip,deflate")
        else:
            curl.setopt(pycurl.ENCODING, "none")
        if request.proxy_host and request.proxy_port:
            curl.setopt(pycurl.PROXY, request.proxy_host)
            curl.setopt(pycurl.PROXYPORT, request.proxy_port)
            if request.proxy_username:
                credentials = '%s:%s' % (request.proxy_username,
                                         request.proxy_password)
                curl.setopt(pycurl.PROXYUSERPWD, credentials)
        else:
            curl.setopt(pycurl.PROXY, '')
            curl.unsetopt(pycurl.PROXYUSERPWD)
        if request.validate_cert:
            curl.setopt(pycurl.SSL_VERIFYPEER, 1)
            curl.setopt(pycurl.SSL_VERIFYHOST, 2)
        else:
            curl.setopt(pycurl.SSL_VERIFYPEER, 0)
            curl.setopt(pycurl.SSL_VERIFYHOST, 0)
        if request.ca_certs is not None:
            curl.setopt(pycurl.CAINFO, request.ca_certs)
        else:
            # There is no way to restore pycurl.CAINFO to its default value
            # (Using unsetopt makes it reject all certificates).
            # I don't see any way to read the default value from python so it
            # can be restored later.  We'll have to just leave CAINFO untouched
            # if no ca_certs file was specified, and require that if any
            # request uses a custom ca_certs file, they all must.
            pass

        if request.allow_ipv6 is False:
            # Curl behaves reasonably when DNS resolution gives an ipv6 address
            # that we can't reach, so allow ipv6 unless the user asks to disable.
            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
        else:
            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)

        # Set the request method through curl's irritating interface which makes
        # up names for almost every single method
        curl_options = {
            "GET": pycurl.HTTPGET,
            "POST": pycurl.POST,
            "PUT": pycurl.UPLOAD,
            "HEAD": pycurl.NOBODY,
        }
        custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
        for o in curl_options.values():
            curl.setopt(o, False)
        if request.method in curl_options:
            curl.unsetopt(pycurl.CUSTOMREQUEST)
            curl.setopt(curl_options[request.method], True)
        elif request.allow_nonstandard_methods or request.method in custom_methods:
            curl.setopt(pycurl.CUSTOMREQUEST, request.method)
        else:
            raise KeyError('unknown method ' + request.method)

        body_expected = request.method in ("POST", "PATCH", "PUT")
        body_present = request.body is not None
        if not request.allow_nonstandard_methods:
            # Some HTTP methods nearly always have bodies while others
            # almost never do. Fail in this case unless the user has
            # opted out of sanity checks with allow_nonstandard_methods.
            if ((body_expected and not body_present) or
                    (body_present and not body_expected)):
                raise ValueError(
                    'Body must %sbe None for method %s (unless '
                    'allow_nonstandard_methods is true)' %
                    ('not ' if body_expected else '', request.method))

        if body_expected or body_present:
            if request.method == "GET":
                # Even with `allow_nonstandard_methods` we disallow
                # GET with a body (because libcurl doesn't allow it
                # unless we use CUSTOMREQUEST). While the spec doesn't
                # forbid clients from sending a body, it arguably
                # disallows the server from doing anything with them.
                raise ValueError('Body must be None for GET request')
            request_buffer = BytesIO(utf8(request.body or ''))

            def ioctl(cmd):
                if cmd == curl.IOCMD_RESTARTREAD:
                    request_buffer.seek(0)
            curl.setopt(pycurl.READFUNCTION, request_buffer.read)
            curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
            if request.method == "POST":
                curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ''))
            else:
                curl.setopt(pycurl.UPLOAD, True)
                curl.setopt(pycurl.INFILESIZE, len(request.body or ''))

        if request.auth_username is not None:
            userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')

            if request.auth_mode is None or request.auth_mode == "basic":
                curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
            elif request.auth_mode == "digest":
                curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
            else:
                raise ValueError("Unsupported auth_mode %s" % request.auth_mode)

            curl.setopt(pycurl.USERPWD, native_str(userpwd))
            curl_log.debug("%s %s (username: %r)", request.method, request.url,
                           request.auth_username)
        else:
            curl.unsetopt(pycurl.USERPWD)
            curl_log.debug("%s %s", request.method, request.url)

        if request.client_cert is not None:
            curl.setopt(pycurl.SSLCERT, request.client_cert)

        if request.client_key is not None:
            curl.setopt(pycurl.SSLKEY, request.client_key)

        if request.ssl_options is not None:
            raise ValueError("ssl_options not supported in curl_httpclient")

        if threading.activeCount() > 1:
            # libcurl/pycurl is not thread-safe by default.  When multiple threads
            # are used, signals should be disabled.  This has the side effect
            # of disabling DNS timeouts in some environments (when libcurl is
            # not linked against ares), so we don't do it when there is only one
            # thread.  Applications that use many short-lived threads may need
            # to set NOSIGNAL manually in a prepare_curl_callback since
            # there may not be any other threads running at the time we call
            # threading.activeCount.
            curl.setopt(pycurl.NOSIGNAL, 1)
        if request.prepare_curl_callback is not None:
            request.prepare_curl_callback(curl)

Example 31

Project: txstatsd
Source File: service.py
View license
def createService(options):
    """Create a txStatsD service."""
    from carbon.routers import ConsistentHashingRouter
    from carbon.client import CarbonClientManager
    from carbon.conf import settings

    settings.MAX_QUEUE_SIZE = options["max-queue-size"]
    settings.MAX_DATAPOINTS_PER_MESSAGE = options["max-datapoints-per-message"]

    root_service = MultiService()
    root_service.setName("statsd")

    prefix = options["prefix"]
    if prefix is None:
        prefix = "statsd"

    instance_name = options["instance-name"]
    if not instance_name:
        instance_name = platform.node()

    # initialize plugins
    plugin_metrics = []
    for plugin in getPlugins(IMetricFactory):
        plugin.configure(options)
        plugin_metrics.append(plugin)

    processor = None
    if options["dump-mode"]:
        # LoggingMessageProcessor supersedes
        #  any other processor class in "dump-mode"
        assert not hasattr(log, 'info')
        log.info = log.msg  # for compatibility with LMP logger interface
        processor = functools.partial(LoggingMessageProcessor, logger=log)

    if options["statsd-compliance"]:
        processor = (processor or MessageProcessor)(plugins=plugin_metrics)
        input_router = Router(processor, options['routing'], root_service)
        connection = InternalClient(input_router)
        metrics = Metrics(connection)
    else:
        processor = (processor or ConfigurableMessageProcessor)(
            message_prefix=prefix,
            internal_metrics_prefix=prefix + "." + instance_name + ".",
            plugins=plugin_metrics)
        input_router = Router(processor, options['routing'], root_service)
        connection = InternalClient(input_router)
        metrics = ExtendedMetrics(connection)

    if not options["carbon-cache-host"]:
        options["carbon-cache-host"].append("127.0.0.1")
    if not options["carbon-cache-port"]:
        options["carbon-cache-port"].append(2004)
    if not options["carbon-cache-name"]:
        options["carbon-cache-name"].append(None)

    reporting = ReportingService(instance_name)
    reporting.setServiceParent(root_service)

    reporting.schedule(report_client_manager_stats,
                       options["flush-interval"] / 1000,
                       metrics.gauge)

    if options["report"] is not None:
        from txstatsd import process
        from twisted.internet import reactor

        reporting.schedule(
            process.report_reactor_stats(reactor), 60, metrics.gauge)
        reports = [name.strip() for name in options["report"].split(",")]
        for report_name in reports:
            if report_name == "reactor":
                inspector = ReactorInspectorService(reactor, metrics,
                                                    loop_time=0.05)
                inspector.setServiceParent(root_service)

            for reporter in getattr(process, "%s_STATS" %
                                    report_name.upper(), ()):
                reporting.schedule(reporter, 60, metrics.gauge)

    # XXX Make this configurable.
    router = ConsistentHashingRouter()
    carbon_client = CarbonClientManager(router)
    carbon_client.setServiceParent(root_service)

    for host, port, name in zip(options["carbon-cache-host"],
                                options["carbon-cache-port"],
                                options["carbon-cache-name"]):
        carbon_client.startClient((host, port, name))

    statsd_service = StatsDService(carbon_client, input_router,
                                   options["flush-interval"])
    statsd_service.setServiceParent(root_service)

    statsd_server_protocol = StatsDServerProtocol(
        input_router,
        monitor_message=options["monitor-message"],
        monitor_response=options["monitor-response"])

    listener = UDPServer(options["listen-port"], statsd_server_protocol)
    listener.setServiceParent(root_service)

    if options["listen-tcp-port"] is not None:
        statsd_tcp_server_factory = StatsDTCPServerFactory(
            input_router,
            monitor_message=options["monitor-message"],
            monitor_response=options["monitor-response"])

        listener = TCPServer(options["listen-tcp-port"],
                             statsd_tcp_server_factory)
        listener.setServiceParent(root_service)

    httpinfo_service = httpinfo.makeService(options, processor, statsd_service)
    httpinfo_service.setServiceParent(root_service)

    return root_service
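
Because a partial wrapping a class is callable just like the class, the (processor or MessageProcessor)(plugins=plugin_metrics) call above works whether processor is the logger-preset partial or a bare class. A minimal sketch of that class-or-partial factory pattern; the processor classes are simplified stand-ins:

from functools import partial

class MessageProcessor(object):
    def __init__(self, plugins=None):
        self.plugins = plugins or []

class LoggingMessageProcessor(MessageProcessor):
    def __init__(self, plugins=None, logger=None):
        super(LoggingMessageProcessor, self).__init__(plugins)
        self.logger = logger

def make_processor(dump_mode, logger):
    factory = None
    if dump_mode:
        # Freeze the logger now; plugins are supplied at call time.
        factory = partial(LoggingMessageProcessor, logger=logger)
    # Either branch yields a callable with the same (plugins=...) interface.
    return (factory or MessageProcessor)(plugins=["gauge"])

print(type(make_processor(True, logger=print)).__name__)   # LoggingMessageProcessor
print(type(make_processor(False, logger=print)).__name__)  # MessageProcessor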

Example 32

Project: twitter
Source File: archiver.py
View license
def main(args=sys.argv[1:]):
    options = {
        'oauth': False,
        'save-dir': ".",
        'api-rate': False,
        'timeline': "",
        'mentions': "",
        'dms': "",
        'favorites': False,
        'follow-redirects': False,
        'redirect-sites': None,
        'isoformat': False,
    }
    try:
        parse_args(args, options)
    except GetoptError as e:
        err("I can't do that, %s." % e)
        raise SystemExit(1)

    # exit if no user given
    # except if asking for API rate, or archive of timeline or mentions
    if not options['extra_args'] and not (options['api-rate'] or
                                          options['timeline'] or
                                          options['mentions'] or
                                          options['dms']):
        print(__doc__)
        return

    # authenticate using OAuth, asking for token if necessary
    if options['oauth']:
        oauth_filename = (os.environ.get('HOME', 
                          os.environ.get('USERPROFILE', '')) 
                          + os.sep
                          + '.twitter-archiver_oauth')
        
        if not os.path.exists(oauth_filename):
            oauth_dance("Twitter-Archiver", CONSUMER_KEY, CONSUMER_SECRET,
                        oauth_filename)
        oauth_token, oauth_token_secret = read_token_file(oauth_filename)
        auth = OAuth(oauth_token, oauth_token_secret, CONSUMER_KEY,
                     CONSUMER_SECRET)
    else:
        auth = NoAuth()

    twitter = Twitter(auth=auth, api_version='1.1', domain='api.twitter.com')

    if options['api-rate']:
        rate_limit_status(twitter)
        return

    global format_text
    if options['follow-redirects'] or options['redirect-sites'] :
        if options['redirect-sites']:
            hosts = parse_host_list(options['redirect-sites'])
        else:
            hosts = None
        format_text = functools.partial(expand_format_text, hosts)
    else:
        format_text = direct_format_text

    # save own timeline or mentions (the user used in OAuth)
    if options['timeline'] or options['mentions']:
        if isinstance(auth, NoAuth):
            err("You must be authenticated to save timeline or mentions.")
            raise SystemExit(1)

        if options['timeline']:
            filename = options['save-dir'] + os.sep + options['timeline']
            print("* Archiving own timeline in %s" % filename)
        elif options['mentions']:
            filename = options['save-dir'] + os.sep + options['mentions']
            print("* Archiving own mentions in %s" % filename)

        tweets = {}
        try:
            tweets = load_tweets(filename)
        except Exception as e:
            err("Error when loading saved tweets: %s - continuing without"
                % str(e))

        try:
            statuses(twitter, "", tweets, options['mentions'], options['favorites'], isoformat=options['isoformat'])
        except KeyboardInterrupt:
            err()
            err("Interrupted")
            raise SystemExit(1)

        save_tweets(filename, tweets)
        if options['timeline']:
            print("Total tweets in own timeline: %i" % len(tweets))
        elif options['mentions']:
            print("Total mentions: %i" % len(tweets))

    if options['dms']:
        if isinstance(auth, NoAuth):
            err("You must be authenticated to save DMs.")
            raise SystemExit(1)

        filename = options['save-dir'] + os.sep + options['dms']
        print("* Archiving own DMs in %s" % filename)

        dms = {}
        try:
            dms = load_tweets(filename)
        except Exception as e:
            err("Error when loading saved DMs: %s - continuing without"
                % str(e))

        try:
            statuses(twitter, "", dms, received_dms=True, isoformat=options['isoformat'])
            statuses(twitter, "", dms, received_dms=False, isoformat=options['isoformat'])
        except KeyboardInterrupt:
            err()
            err("Interrupted")
            raise SystemExit(1)

        save_tweets(filename, dms)
        print("Total DMs sent and received: %i" % len(dms))


    # read users from command-line or stdin
    users = options['extra_args']
    if len(users) == 1 and users[0] == "-":
        users = [line.strip() for line in sys.stdin.readlines()]

    # save tweets for every user
    total, total_new = 0, 0
    for user in users:
        filename = options['save-dir'] + os.sep + user
        if options['favorites']:
            filename = filename + "-favorites"
        print("* Archiving %s tweets in %s" % (user, filename))

        tweets = {}
        try:
            tweets = load_tweets(filename)
        except Exception as e:
            err("Error when loading saved tweets: %s - continuing without"
                % str(e))

        new = 0
        before = len(tweets)
        try:
            statuses(twitter, user, tweets, options['mentions'], options['favorites'], isoformat=options['isoformat'])
        except KeyboardInterrupt:
            err()
            err("Interrupted")
            raise SystemExit(1)

        save_tweets(filename, tweets)
        total += len(tweets)
        new = len(tweets) - before
        total_new += new
        print("Total tweets for %s: %i (%i new)" % (user, len(tweets), new))

    print("Total: %i tweets (%i new) for %i users"
          % (total, total_new, len(users)))
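
The archiver rebinds the module-level format_text name to either direct_format_text or expand_format_text with its hosts argument frozen in, so the rest of the script calls format_text(text) without knowing which variant is active. A minimal sketch of this configure-once pattern; both formatters are simplified stand-ins:

import functools

def direct_format_text(text):
    return text

def expand_format_text(hosts, text):
    # hosts is frozen by partial; only text varies per call.
    return "[redirects expanded for %s] %s" % (hosts or "all hosts", text)

def configure(follow_redirects, hosts=None):
    if follow_redirects:
        return functools.partial(expand_format_text, hosts)
    return direct_format_text

format_text = configure(follow_redirects=True, hosts=["t.co"])
print(format_text("hello"))  # callers never pass hosts themselves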

Example 33

Project: smartlearner
Source File: test_smartlearner.py
View license
def test_resume_experiment():
    # Loading dataset
    trainset, validset, testset = load_mnist()
    nb_classes = 10

    # Nested function to build a trainer.
    def _build_trainer(nb_epochs, optimizer_cls):
        print("Will build a trainer is going to train a Perceptron for {0} epochs.".format(nb_epochs))

        print("Building model")
        model = Perceptron(trainset.input_size, nb_classes)
        model.initialize(initer.UniformInitializer(random_seed=1234))

        print("Building optimizer")
        loss = NLL(model, trainset)
        optimizer = optimizer_cls(loss=loss)
        print("Optimizer: {}".format(type(optimizer).__name__))
        #optimizer = SGD(loss=loss)
        #optimizer.append_direction_modifier(ConstantLearningRate(0.1))

        # Use mini batches of 100 examples.
        batch_scheduler = MiniBatchScheduler(trainset, 100)

        print("Building trainer")
        trainer = Trainer(optimizer, batch_scheduler)

        # Print time for one epoch
        trainer.append_task(tasks.PrintEpochDuration())
        trainer.append_task(tasks.PrintTrainingDuration())

        # Log training error
        loss_monitor = views.MonitorVariable(loss.loss)
        avg_loss = tasks.AveragePerEpoch(loss_monitor)

        # Print NLL mean/stderror.
        nll = views.LossView(loss=NLL(model, validset), batch_scheduler=FullBatchScheduler(validset))
        logger = tasks.Logger(loss_monitor, avg_loss, nll.mean)
        trainer.append_task(logger, avg_loss)

        # Train for `nb_epochs` epochs (stopping criteria should be added at the end).
        trainer.append_task(stopping_criteria.MaxEpochStopping(nb_epochs))

        return trainer, nll, logger

    for optimizer_cls in [SGD, Adam, Adadelta,
                          partial(AdaGrad, lr=0.1), partial(RMSProp, lr=1)]:
        trainer1, nll1, logger1 = _build_trainer(nb_epochs=10, optimizer_cls=optimizer_cls)
        print("Compiling training graph")
        trainer1.build_theano_graph()

        print("Training")
        trainer1.train()

        trainer2a, nll2a, logger2a = _build_trainer(5, optimizer_cls)
        print("Compiling training graph")
        trainer2a.build_theano_graph()

        print("Training")
        trainer2a.train()

        # Save model halfway during training and resume it.
        with tempfile.TemporaryDirectory() as experiment_dir:
            print("Saving")
            # Save current state of the model (i.e. after 5 epochs).
            trainer2a.save(experiment_dir)

            print("Loading")
            # Load previous state from which training will resume.
            trainer2b, nll2b, logger2b = _build_trainer(10, optimizer_cls)
            trainer2b.load(experiment_dir)

            # Check we correctly reloaded the model.
            assert_equal(len(trainer2a._optimizer.loss.model.parameters),
                         len(trainer2b._optimizer.loss.model.parameters))
            for param1, param2 in zip(trainer2a._optimizer.loss.model.parameters,
                                      trainer2b._optimizer.loss.model.parameters):
                assert_array_equal(param1.get_value(), param2.get_value(), err_msg=param1.name)

            # Check that the `status` state after loading matches the one saved.
            assert_equal(trainer2b.status.current_epoch, trainer2a.status.current_epoch)
            assert_equal(trainer2b.status.current_update, trainer2a.status.current_update)
            assert_equal(trainer2b.status.current_update_in_epoch, trainer2a.status.current_update_in_epoch)
            assert_equal(trainer2b.status.training_time, trainer2a.status.training_time)
            assert_equal(trainer2b.status.done, trainer2a.status.done)
            assert_equal(trainer2b.status.extra, trainer2a.status.extra)

            # Check that the `batch_scheduler` state after loading matches the one saved.
            assert_equal(trainer2b._batch_scheduler.batch_size, trainer2a._batch_scheduler.batch_size)
            assert_equal(trainer2b._batch_scheduler.shared_batch_count.get_value(),
                         trainer2a._batch_scheduler.shared_batch_count.get_value())

            # Check that the `optimizer` state after loading matches the one saved.
            assert_equal(trainer2a._optimizer.getstate(), trainer2b._optimizer.getstate())

        print("Compiling training graph")
        trainer2b.build_theano_graph()

        print("Training")
        trainer2b.train()

        # Check we correctly resumed training.
        assert_equal(len(trainer1._optimizer.loss.model.parameters),
                     len(trainer2b._optimizer.loss.model.parameters))
        for param1, param2 in zip(trainer1._optimizer.loss.model.parameters,
                                  trainer2b._optimizer.loss.model.parameters):

            # Using `assert_array_*almost*_equal` because of float32. However, this is not needed in float64.
            assert_array_almost_equal(param1.get_value(), param2.get_value(), err_msg=param1.name)

        # Using `assert_array_*almost*_equal` because of float32. However, this is not needed in float64.
        assert_array_almost_equal(nll1.mean.view(trainer1.status), nll2b.mean.view(trainer2b.status))
        assert_array_almost_equal(nll1.stderror.view(trainer1.status), nll2b.stderror.view(trainer2b.status))

        # Using `assert_array_*almost*_equal` because of float32. However, this is not needed in float64.
        assert_array_almost_equal(logger1.get_variable_history(0), logger2b.get_variable_history(0))
        assert_array_almost_equal(logger1.get_variable_history(1), logger2b.get_variable_history(1))

        # Check that the _history of `logger2b` is the same as the one in `logger1`.
        for i in range(len(logger1[0])):
            assert_equal(logger2b._history[i], logger1._history[i])
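
The loop above mixes bare optimizer classes (SGD, Adam, Adadelta) with partials that preset hyperparameters (partial(AdaGrad, lr=0.1)); both kinds are instantiated identically as optimizer_cls(loss=loss). A minimal sketch with toy optimizer classes:

from functools import partial

class SGD(object):
    def __init__(self, loss, lr=0.01):
        self.loss, self.lr = loss, lr

class AdaGrad(SGD):
    pass

# Raw classes and pre-configured partials share one calling convention,
# so the test body stays the same across optimizers.
for optimizer_cls in [SGD, partial(AdaGrad, lr=0.1)]:
    optimizer = optimizer_cls(loss="nll")
    print(type(optimizer).__name__, optimizer.lr)  # SGD 0.01 / AdaGrad 0.1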

Example 34

Project: Clamp
Source File: engine.py
View license
def engine_start():
    _is_running = True

    context = {}  # {'bufname' : [tu, tick]}

    parse_thread = threading.Thread(target=ParseWorker)
    parse_thread.daemon = True
    parse_thread.start()

    nvim = attach('stdio')

    # nvim = attach('socket', path=sys.argv[1])
    # nvim.command('let g:clamp_channel=%d' % nvim.channel_id)

    sys.stderr.write('channel=%d\n' % nvim.channel_id)

    cindex.Config.set_library_file(nvim.vars['clamp_libclang_path'])
    sys.stderr.write('libclang=%s\n' % cindex.Config.library_file)

    occurrences_pri = nvim.vars['clamp_occurrence_priority']
    syntax_pri = nvim.vars['clamp_syntax_priority']

    _parse.idx = cindex.Index.create()
    _parse.cdb = compilation_database.CompilationDatabase.from_dir(
        nvim.call('getcwd'), nvim.vars['clamp_heuristic_compile_args'])
    _parse.global_args = nvim.vars['clamp_compile_args']

    _highlight.blacklist = nvim.vars['clamp_highlight_blacklist']

    unsaved = []
    while (_is_running):
        event = nvim.next_message()
        if not event:
            continue

        sys.stderr.write('event %s\n' % event[1])
        if event[1] == 'highlight':
            bufname = event[2][0]
            begin_line = event[2][1]
            end_line = event[2][2]
            row = event[2][3]
            col = event[2][4]

            if bufname not in context:
                event[3].send(0)
                continue

            tu, tick = context[bufname]

            symbol = clamp_helper.get_semantic_symbol_from_location(
                tu, bufname, row, col)
            syntax, occurrence = _highlight(
                tu, bufname, begin_line, end_line, symbol)

            nvim.call('ClampHighlight', bufname, [
                      [syntax_pri, syntax], [occurrences_pri, occurrence]])

            event[3].send(1)


        elif event[1] == 'parse':
            bufname = event[2][0]
            changedtick = event[2][1]
            content = event[2][2]

            _update_unsaved(bufname, unsaved, changedtick, content)
            #thread.start_new_thread(_parse_or_reparse_if_need, (bufname, unsaved, context, changedtick))

            with condition:
                func = functools.partial(_parse_or_reparse_if_need, bufname, unsaved, context, changedtick)
                parse_set.add(func)
                condition.notify()

            event[3].send(0)

        elif event[1] == 'rename':
            bufname = event[2][0]
            row = event[2][1]
            col = event[2][2]

            _update_unsaved_and_parse_all(nvim, unsaved, context)

            symbol = clamp_helper.get_semantic_symbol_from_location(
                context[bufname][0], bufname, row, col)
            if not symbol:
                sys.stderr.write('rename: no symbol found at cursor\n')
                event[3].send({})
                continue

            result = {'old': symbol.spelling, 'renames': {}}
            usr = symbol.get_usr()

            if usr:
                for bufname, [tu, tick] in context.iteritems():
                    locations = []
                    clamp_helper.search_referenced_tokens_by_usr(
                        tu, usr, locations, symbol.spelling)

                    if locations:
                        result['renames'][bufname] = locations

            event[3].send(result)
        elif event[1] == 'cursor_info':
            bufname = event[2][0]
            row = event[2][1]
            col = event[2][2]

            tu = context[bufname][0]
            cursor = clamp_helper.get_cursor(tu, bufname, row, col)

            if not cursor:
                event[3].send('None')
                continue

            result = {'cursor': str(cursor),
                      'cursor.kind': str(cursor.kind),
                      'cursor.type.kind': str(cursor.type.kind),
                      'cursor.spelling': cursor.spelling}
            event[3].send(result)
            
        elif event[1] == 'update_unsaved_all':
            _update_unsaved_all(nvim, unsaved)
            event[3].send('ok')

        elif event[1] == 'shutdown':
            nvim.session.stop()
            _is_running = False
            event[3].send('ok')
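
In the 'parse' branch above, functools.partial packages _parse_or_reparse_if_need together with its arguments into a zero-argument thunk that the worker thread pops off parse_set and calls later. A minimal sketch of that producer/worker handoff; the parse job is a stand-in for the clang reparse:

import functools
import threading

condition = threading.Condition()
parse_set = set()

def worker():
    with condition:
        while not parse_set:
            condition.wait()
        job = parse_set.pop()
    job()  # the thunk carries its own arguments

def parse(bufname, changedtick):
    print("parsing %s at tick %d" % (bufname, changedtick))

thread = threading.Thread(target=worker)
thread.start()
with condition:
    # Freeze the arguments now; the worker calls the thunk with none.
    parse_set.add(functools.partial(parse, "main.c", 7))
    condition.notify()
thread.join()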

Example 35

Project: utter-va
Source File: archiver.py
View license
(This main() is a line-for-line copy of the twitter project's archiver.py shown in Example 32 above; utter-va vendors the same file, so the listing is not repeated here.)

Example 36

Project: spinn
Source File: classifier.py
View license
def build_sentence_pair_model(cls, vocab_size, seq_length, tokens, transitions,
                     num_classes, training_mode, ground_truth_transitions_visible, vs,
                     initial_embeddings=None, project_embeddings=False, ss_mask_gen=None, ss_prob=0.0):
    """
    Construct a classifier which makes use of some hard-stack model.

    Args:
      cls: Hard stack class to use (from e.g. `spinn.stack`)
      vocab_size:
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch (integer matrix), `batch_size * seq_length`
      transitions: Theano batch (integer matrix), `batch_size * seq_length`
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
    """

    # Prepare layer which performs stack element composition.
    if cls is spinn.plain_rnn.RNN:
        compose_network = partial(util.LSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            compose_network = partial(util.TreeLSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    model_visible_dim = FLAGS.model_dim / 2 if FLAGS.lstm_composition else FLAGS.model_dim
    spec = util.ModelSpec(FLAGS.model_dim, FLAGS.word_embedding_dim,
                          FLAGS.batch_size, vocab_size, seq_length,
                          model_visible_dim=model_visible_dim)

    # Split the two sentences
    premise_tokens = tokens[:, :, 0]
    hypothesis_tokens = tokens[:, :, 1]

    premise_transitions = transitions[:, :, 0]
    hypothesis_transitions = transitions[:, :, 1]

    # TODO: Check non-Model0 support.
    recurrence = cls(spec, vs, compose_network,
                     use_context_sensitive_shift=FLAGS.context_sensitive_shift,
                     context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
                     use_tracking_lstm=FLAGS.use_tracking_lstm,
                     tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim)

    # Build two hard stack models which scan over input sequences.
    premise_model = ThinStack(spec, recurrence, embedding_projection_network,
        training_mode, ground_truth_transitions_visible, vs,
        X=premise_tokens,
        transitions=premise_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        use_input_batch_norm=False,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        name="premise")

    hypothesis_model = ThinStack(spec, recurrence, embedding_projection_network,
        training_mode, ground_truth_transitions_visible, vs,
        X=hypothesis_tokens,
        transitions=hypothesis_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        use_input_batch_norm=False,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        name="hypothesis")

    # Create standard MLP features.
    # NOTE: the sentence vectors are not defined in this excerpt; the
    # assignments below assume the ThinStack models expose the final
    # sentence representations the same way the fat-stack models do.
    premise_vector = premise_model.final_representations
    hypothesis_vector = hypothesis_model.final_representations
    sentence_vector_dim = model_visible_dim

    mlp_input = T.concatenate([premise_vector, hypothesis_vector], axis=1)
    mlp_input_dim = 2 * sentence_vector_dim

    if FLAGS.use_difference_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector - hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    if FLAGS.use_product_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector * hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    mlp_input = util.BatchNorm(mlp_input, mlp_input_dim, vs, "sentence_vectors", training_mode)
    mlp_input = util.Dropout(mlp_input, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Apply a combining MLP
    prev_features = mlp_input
    prev_features_dim = mlp_input_dim
    for layer in range(FLAGS.num_sentence_pair_combination_layers):
        prev_features = util.ReLULayer(prev_features, prev_features_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
            name="combining_mlp/" + str(layer),
            initializer=util.HeKaimingInitializer())
        prev_features_dim = FLAGS.sentence_pair_combination_layer_dim

        prev_features = util.BatchNorm(prev_features, prev_features_dim, vs, "combining_mlp/" + str(layer), training_mode)
        prev_features = util.Dropout(prev_features, FLAGS.semantic_classifier_keep_rate, training_mode)

    # Feed forward through a single output layer
    logits = util.Linear(
        prev_features, prev_features_dim, num_classes, vs,
        name="semantic_classifier", use_bias=True)

    def zero_fn():
        premise_model.zero()
        hypothesis_model.zero()

    return premise_model, hypothesis_model, logits, zero_fn
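
Each branch above wraps a different layer constructor in `partial` only to pre-set the initializer, so the stack model can call `compose_network` without knowing which layer class it received. A minimal runnable sketch of the pattern (stand-in names, not the real `spinn.util` signatures):

from functools import partial

def tree_lstm_layer(lstate, rstate, inp_dim, out_dim, initializer=None):
    # Stand-in for util.TreeLSTMLayer: compose two child states.
    print("compose %d -> %d with %r" % (inp_dim, out_dim, initializer))

# Pin the initializer now; downstream code calls compose_network with just
# the composition arguments, whichever layer type was selected.
compose_network = partial(tree_lstm_layer, initializer="he_kaiming")
compose_network("left", "right", 128, 128)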

Example 37

Project: spinn
Source File: fat_classifier.py
View license
def build_sentence_pair_model(cls, vocab_size, seq_length, tokens, transitions,
                     num_classes, training_mode, ground_truth_transitions_visible, vs,
                     initial_embeddings=None, project_embeddings=False, ss_mask_gen=None, ss_prob=0.0):
    """
    Construct a classifier which makes use of some hard-stack model.

    Args:
      cls: Hard stack class to use (from e.g. `spinn.fat_stack`)
      vocab_size: Size of the input vocabulary.
      seq_length: Length of each sequence provided to the stack model
      tokens: Theano batch (integer matrix), `batch_size * seq_length`
      transitions: Theano batch (integer matrix), `batch_size * seq_length`
      num_classes: Number of output classes
      training_mode: A Theano scalar indicating whether to act as a training model
        with dropout (1.0) or to act as an eval model with rescaling (0.0).
      ground_truth_transitions_visible: A Theano scalar. If set (1.0), allow the model access
        to ground truth transitions. This can be disabled at evaluation time to force Model 1
        (or 2S) to evaluate in the Model 2 style with predicted transitions. Has no effect on Model 0.
      vs: Variable store.
    """


    # Prepare layer which performs stack element composition.
    if cls is spinn.plain_rnn.RNN:
        if FLAGS.use_gru:
            compose_network = partial(util.GRULayer,
                                      initializer=util.HeKaimingInitializer())
        else:
            compose_network = partial(util.LSTMLayer,
                                      initializer=util.HeKaimingInitializer())
        embedding_projection_network = None
    elif cls is spinn.cbow.CBOW:
        compose_network = None
        embedding_projection_network = None
    else:
        if FLAGS.lstm_composition:
            if FLAGS.use_gru:
                compose_network = partial(util.TreeGRULayer,
                                          initializer=util.HeKaimingInitializer())
            else:
                compose_network = partial(util.TreeLSTMLayer,
                                          initializer=util.HeKaimingInitializer())
        else:
            assert not FLAGS.connect_tracking_comp, "Can only connect tracking and composition unit while using TreeLSTM"
            compose_network = partial(util.ReLULayer,
                                      initializer=util.HeKaimingInitializer())

        if project_embeddings:
            embedding_projection_network = util.Linear
        else:
            assert FLAGS.word_embedding_dim == FLAGS.model_dim, \
                "word_embedding_dim must equal model_dim unless a projection layer is used."
            embedding_projection_network = util.IdentityLayer

    # Split the two sentences
    premise_tokens = tokens[:, :, 0]
    hypothesis_tokens = tokens[:, :, 1]

    premise_transitions = transitions[:, :, 0]
    hypothesis_transitions = transitions[:, :, 1]

    # Build two hard stack models which scan over input sequences.
    premise_model = cls(
        FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size, seq_length,
        compose_network, embedding_projection_network, training_mode, ground_truth_transitions_visible, vs,
        predict_use_cell=FLAGS.predict_use_cell,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        X=premise_tokens,
        transitions=premise_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        connect_tracking_comp=FLAGS.connect_tracking_comp,
        context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_attention=FLAGS.use_attention,
        initialize_hyp_tracking_state=FLAGS.initialize_hyp_tracking_state)

    premise_stack_tops = premise_model.stack_tops if FLAGS.use_attention != "None" else None
    premise_tracking_c_state_final = premise_model.tracking_c_state_final if cls not in [spinn.plain_rnn.RNN, 
                                                                                            spinn.cbow.CBOW] else None
    hypothesis_model = cls(
        FLAGS.model_dim, FLAGS.word_embedding_dim, vocab_size, seq_length,
        compose_network, embedding_projection_network, training_mode, ground_truth_transitions_visible, vs,
        predict_use_cell=FLAGS.predict_use_cell,
        use_tracking_lstm=FLAGS.use_tracking_lstm,
        tracking_lstm_hidden_dim=FLAGS.tracking_lstm_hidden_dim,
        X=hypothesis_tokens,
        transitions=hypothesis_transitions,
        initial_embeddings=initial_embeddings,
        embedding_dropout_keep_rate=FLAGS.embedding_keep_rate,
        ss_mask_gen=ss_mask_gen,
        ss_prob=ss_prob,
        connect_tracking_comp=FLAGS.connect_tracking_comp,
        context_sensitive_shift=FLAGS.context_sensitive_shift,
        context_sensitive_use_relu=FLAGS.context_sensitive_use_relu,
        use_attention=FLAGS.use_attention,
        premise_stack_tops=premise_stack_tops,
        is_hypothesis=True,
        initialize_hyp_tracking_state=FLAGS.initialize_hyp_tracking_state,
        premise_tracking_c_state_final=premise_tracking_c_state_final)

    # Extract top element of final stack timestep.
    if FLAGS.use_attention == "None" or FLAGS.use_difference_feature or FLAGS.use_product_feature:
        premise_vector = premise_model.final_representations
        hypothesis_vector = hypothesis_model.final_representations

        if (FLAGS.lstm_composition and cls is not spinn.cbow.CBOW) or cls is spinn.plain_rnn.RNN:
            premise_vector = premise_vector[:,:FLAGS.model_dim / 2].reshape((-1, FLAGS.model_dim / 2))
            hypothesis_vector = hypothesis_vector[:,:FLAGS.model_dim / 2].reshape((-1, FLAGS.model_dim / 2))
            sentence_vector_dim = FLAGS.model_dim / 2
        else:
            premise_vector = premise_vector.reshape((-1, FLAGS.model_dim))
            hypothesis_vector = hypothesis_vector.reshape((-1, FLAGS.model_dim))
            sentence_vector_dim = FLAGS.model_dim

    if FLAGS.use_attention != "None":
        # Use the attention weighted representation
        h_dim = FLAGS.model_dim / 2
        mlp_input = hypothesis_model.final_weighed_representation.reshape((-1, h_dim))
        mlp_input_dim = h_dim
    else:
        # Create standard MLP features
        mlp_input = T.concatenate([premise_vector, hypothesis_vector], axis=1)
        mlp_input_dim = 2 * sentence_vector_dim

    if FLAGS.use_difference_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector - hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    if FLAGS.use_product_feature:
        mlp_input = T.concatenate([mlp_input, premise_vector * hypothesis_vector], axis=1)
        mlp_input_dim += sentence_vector_dim

    mlp_input = util.BatchNorm(mlp_input, mlp_input_dim, vs, "sentence_vectors", training_mode)
    mlp_input = util.Dropout(mlp_input, FLAGS.semantic_classifier_keep_rate, training_mode)

    if FLAGS.classifier_type == "ResNet":
        features = util.Linear(
            mlp_input, mlp_input_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
            name="resnet/linear", use_bias=True)
        features_dim = FLAGS.sentence_pair_combination_layer_dim

        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.HeKaimingResidualLayerSet(features, features_dim, vs, training_mode, name="resnet/" + str(layer), 
                dropout_keep_rate=FLAGS.semantic_classifier_keep_rate, depth=FLAGS.resnet_unit_depth, 
                initializer=util.HeKaimingInitializer())
            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode)
    elif FLAGS.classifier_type == "Highway":
        features = util.Linear(
            mlp_input, mlp_input_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
            name="resnet/linear", use_bias=True)
        features_dim = FLAGS.sentence_pair_combination_layer_dim

        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.HighwayLayer(features, features_dim, vs, training_mode, name="highway/" + str(layer), 
                dropout_keep_rate=FLAGS.semantic_classifier_keep_rate,
                initializer=util.HeKaimingInitializer())
            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode)
    else:    
        # Apply a combining MLP
        features = mlp_input
        features_dim = mlp_input_dim
        for layer in range(FLAGS.num_sentence_pair_combination_layers):
            features = util.ReLULayer(features, features_dim, FLAGS.sentence_pair_combination_layer_dim, vs,
                name="combining_mlp/" + str(layer),
                initializer=util.HeKaimingInitializer())
            features_dim = FLAGS.sentence_pair_combination_layer_dim

            features = util.BatchNorm(features, features_dim, vs, "combining_mlp/" + str(layer), training_mode)
            features = util.Dropout(features, FLAGS.semantic_classifier_keep_rate, training_mode) 

    # Feed forward through a single output layer
    logits = util.Linear(
        features, features_dim, num_classes, vs,
        name="semantic_classifier", use_bias=True)

    return premise_model.transitions_pred, hypothesis_model.transitions_pred, logits
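
The nested flag checks above all end by producing the same shape of callable. An equivalent way to express that selection (not how the project writes it) is a small dispatch table of partials keyed on the flags:

from functools import partial

def make_layer(kind, initializer):
    # Stand-in constructor; the real code would build a Theano layer.
    return "%s(initializer=%s)" % (kind, initializer)

composers = {
    (True, True):   partial(make_layer, "GRULayer", "he_kaiming"),
    (True, False):  partial(make_layer, "LSTMLayer", "he_kaiming"),
    (False, True):  partial(make_layer, "TreeGRULayer", "he_kaiming"),
    (False, False): partial(make_layer, "TreeLSTMLayer", "he_kaiming"),
}
print(composers[(True, False)]())  # keyed on (is_rnn, use_gru)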

Example 38

Project: streamlink
Source File: plugin.py
View license
    def streams(self, stream_types=None, sorting_excludes=None):
        """Attempts to extract available streams.

        Returns a :class:`dict` containing the streams, where the key is
        the name of the stream (most commonly the quality) and the value
        is a :class:`Stream` object.

        The result can contain the synonyms **best** and **worst**, which
        point to the streams that are likely to be of highest and
        lowest quality respectively.

        If multiple streams with the same name are found, the order of
        streams specified in *stream_types* will determine which stream
        gets to keep the name while the rest will be renamed to
        "<name>_<stream type>".

        The synonyms can be fine-tuned with the *sorting_excludes*
        parameter, which can be either of these types:

            - A list of filter expressions in the format
              *[operator]<value>*. For example, the filter ">480p" will
              exclude streams ranked higher than "480p" from the list
              used in the synonym ranking. Valid operators are >, >=, <
              and <=. If no operator is specified, equality is tested.

            - A function that is passed to filter() with a list of
              stream names as input.


        :param stream_types: A list of stream types to return.
        :param sorting_excludes: Specify which streams to exclude from
                                 the best/worst synonyms.


        .. versionchanged:: 1.4.2
           Added *priority* parameter.

        .. versionchanged:: 1.5.0
           Renamed *priority* to *stream_types* and changed behaviour
           slightly.

        .. versionchanged:: 1.5.0
           Added *sorting_excludes* parameter.

        .. versionchanged:: 1.6.0
           *sorting_excludes* can now be a list of filter expressions
           or a function that is passed to filter().


        """

        try:
            ostreams = self._get_streams()
            if isinstance(ostreams, dict):
                ostreams = ostreams.items()

            # Flatten the iterator to a list so we can reuse it.
            if ostreams:
                ostreams = list(ostreams)
        except NoStreamsError:
            return {}
        except (IOError, OSError, ValueError) as err:
            raise PluginError(err)

        if not ostreams:
            return {}

        if stream_types is None:
            stream_types = self.default_stream_types(ostreams)

        # Add streams depending on stream type and priorities
        sorted_streams = sorted(iterate_streams(ostreams),
                                key=partial(stream_type_priority,
                                            stream_types))

        streams = {}
        for name, stream in sorted_streams:
            stream_type = type(stream).shortname()

            if stream_type not in stream_types:
                continue

            existing = streams.get(name)
            if existing:
                existing_stream_type = type(existing).shortname()
                if existing_stream_type != stream_type:
                    name = "{0}_{1}".format(name, stream_type)

                if name in streams:
                    name = "{0}_alt".format(name)
                    num_alts = len(list(filter(lambda n: n.startswith(name), streams.keys())))

                    # We shouldn't need more than 2 alt streams
                    if num_alts >= 2:
                        continue
                    elif num_alts > 0:
                        name = "{0}{1}".format(name, num_alts + 1)

            # Validate stream name and discard the stream if it's bad.
            match = re.match(r"([A-Za-z0-9_+]+)", name)
            if match:
                name = match.group(1)
            else:
                self.logger.debug("The stream '{0}' has been ignored "
                                  "since it is badly named.", name)
                continue

            # Force lowercase name and replace space with underscore.
            streams[name.lower()] = stream

        # Create the best/worst synonyms
        stream_weight_only = lambda s: (self.stream_weight(s)[0] or
                                        (len(streams) == 1 and 1))
        stream_names = filter(stream_weight_only, streams.keys())
        sorted_streams = sorted(stream_names, key=stream_weight_only)

        if isinstance(sorting_excludes, list):
            for expr in sorting_excludes:
                filter_func = stream_sorting_filter(expr, self.stream_weight)
                sorted_streams = list(filter(filter_func, sorted_streams))
        elif callable(sorting_excludes):
            sorted_streams = list(filter(sorting_excludes, sorted_streams))

        if len(sorted_streams) > 0:
            best = sorted_streams[-1]
            worst = sorted_streams[0]
            streams["best"] = streams[best]
            streams["worst"] = streams[worst]

        return streams
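
The sort near the top of this method uses `partial(stream_type_priority, stream_types)` because `sorted()` hands its `key=` function exactly one argument. A runnable sketch of that adapter, with a simplified stand-in for streamlink's real `stream_type_priority`:

from functools import partial

def stream_type_priority(stream_types, item):
    # Stand-in ranking: earlier entries in stream_types win.
    name, stream_type = item
    return stream_types.index(stream_type)

stream_types = ["hls", "http"]
streams = [("720p", "http"), ("720p", "hls")]

# partial pre-binds the priority list so the two-argument ranking
# function fits sorted()'s one-argument key= contract.
print(sorted(streams, key=partial(stream_type_priority, stream_types)))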

Example 39

Project: sugar-toolkit-gtk3
Source File: activity.py
View license
    def __init__(self, handle, create_jobject=True):
        '''
        Initialise the Activity

        Args:

        handle (sugar3.activity.activityhandle.ActivityHandle)
            instance providing the activity id and access to the
            presence service which *may* provide sharing for this
            application
        create_jobject (boolean)
            DEPRECATED: formerly controlled whether a journal object should
            be created when not resuming. The parameter is now ignored; an
            object is always created in the Journal.

        Side effects:

            Sets the gdk screen DPI setting (resolution) to the
            Sugar screen resolution.

            Connects our "destroy" message to our _destroy_cb
            method.

            Creates a base Gtk.Window within this window.

            Creates an ActivityService (self._bus) servicing
            this application.

        Usage:
            If your Activity implements __init__(), it should call
            the base class __init__() before doing Activity-specific things.

        '''

        # Stuff that needs to be done early
        icons_path = os.path.join(get_bundle_path(), 'icons')
        Gtk.IconTheme.get_default().append_search_path(icons_path)

        sugar_theme = 'sugar-72'
        if 'SUGAR_SCALING' in os.environ:
            if os.environ['SUGAR_SCALING'] == '100':
                sugar_theme = 'sugar-100'

        # This code can be removed when we grow an xsettings daemon (the GTK+
        # init routines will then automatically figure out the font settings)
        settings = Gtk.Settings.get_default()
        settings.set_property('gtk-theme-name', sugar_theme)
        settings.set_property('gtk-icon-theme-name', 'sugar')
        settings.set_property('gtk-button-images', True)
        settings.set_property('gtk-font-name',
                              '%s %f' % (style.FONT_FACE, style.FONT_SIZE))

        Window.__init__(self)

        if 'SUGAR_ACTIVITY_ROOT' in os.environ:
            # If this activity runs inside Sugar, we want it to take all the
            # screen. Would be better if it was the shell to do this, but we
            # haven't found yet a good way to do it there. See #1263.
            self.connect('window-state-event', self.__window_state_event_cb)
            screen = Gdk.Screen.get_default()
            screen.connect('size-changed', self.__screen_size_changed_cb)
            self._adapt_window_to_screen()

        # process titles will only show 15 characters
        # but they get truncated anyway so if more characters
        # are supported in the future we will get a better view
        # of the processes
        proc_title = '%s <%s>' % (get_bundle_name(), handle.activity_id)
        util.set_proc_title(proc_title)

        self.connect('realize', self.__realize_cb)
        self.connect('delete-event', self.__delete_event_cb)

        self._active = False
        self._active_time = None
        self._spent_time = 0
        self._activity_id = handle.activity_id
        self.shared_activity = None
        self._join_id = None
        self._updating_jobject = False
        self._closing = False
        self._quit_requested = False
        self._deleting = False
        self._max_participants = None
        self._invites_queue = []
        self._jobject = None
        self._read_file_called = False

        self._session = _get_session()
        self._session.register(self)
        self._session.connect('quit-requested',
                              self.__session_quit_requested_cb)
        self._session.connect('quit', self.__session_quit_cb)

        accel_group = Gtk.AccelGroup()
        self.sugar_accel_group = accel_group
        self.add_accel_group(accel_group)

        self._bus = ActivityService(self)
        self._owns_file = False

        share_scope = SCOPE_PRIVATE

        if handle.object_id:
            self._jobject = datastore.get(handle.object_id)

            if 'share-scope' in self._jobject.metadata:
                share_scope = self._jobject.metadata['share-scope']

            if 'launch-times' in self._jobject.metadata:
                self._jobject.metadata['launch-times'] += ', %d' % \
                    int(time.time())
            else:
                self._jobject.metadata['launch-times'] = \
                    str(int(time.time()))

            if 'spent-times' in self._jobject.metadata:
                self._jobject.metadata['spent-times'] += ', 0'
            else:
                self._jobject.metadata['spent-times'] = '0'

        self.shared_activity = None
        self._join_id = None

        if handle.object_id is None:
            logging.debug('Creating a jobject.')
            self._jobject = self._initialize_journal_object()

        if handle.invited:
            wait_loop = GObject.MainLoop()
            self._client_handler = _ClientHandler(
                self.get_bundle_id(),
                partial(self.__got_channel_cb, wait_loop))
            # FIXME: The current API requires that self.shared_activity is set
            # before exiting from __init__, so we wait until we have got the
            # shared activity. http://bugs.sugarlabs.org/ticket/2168
            wait_loop.run()
        else:
            pservice = presenceservice.get_instance()
            mesh_instance = pservice.get_activity(self._activity_id,
                                                  warn_if_none=False)
            self._set_up_sharing(mesh_instance, share_scope)

        if self.shared_activity is not None:
            self._jobject.metadata['title'] = self.shared_activity.props.name
            self._jobject.metadata['icon-color'] = \
                self.shared_activity.props.color
        else:
            self._jobject.metadata.connect('updated',
                                           self.__jobject_updated_cb)
        self.set_title(self._jobject.metadata['title'])

        bundle = get_bundle_instance(get_bundle_path())
        self.set_icon_from_file(bundle.get_icon())
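
The `handle.invited` branch registers `partial(self.__got_channel_cb, wait_loop)` so the callback can stop the main loop it was created with, even though the invoker only passes the channel. A runnable sketch of that pre-binding (stand-ins, not the real GObject/Telepathy API):

from functools import partial

class FakeLoop(object):
    # Stand-in for GObject.MainLoop, just enough for the sketch.
    def quit(self):
        print("loop stopped")

def got_channel_cb(wait_loop, channel):
    # The caller supplies only `channel`; the loop to stop was pre-bound
    # with partial when the handler was registered.
    print("got channel:", channel)
    wait_loop.quit()

handler = partial(got_channel_cb, FakeLoop())
handler("text-channel")  # what the client handler would eventually invoke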

Example 40

Project: tornado-zh
Source File: curl_httpclient.py
View license
    def _curl_setup_request(self, curl, request, buffer, headers):
        curl.setopt(pycurl.URL, native_str(request.url))

        # libcurl's magic "Expect: 100-continue" behavior causes delays
        # with servers that don't support it (which include, among others,
        # Google's OpenID endpoint).  Additionally, this behavior has
        # a bug in conjunction with the curl_multi_socket_action API
        # (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
        # which increases the delays.  It's more trouble than it's worth,
        # so just turn off the feature (yes, setting Expect: to an empty
        # value is the official way to disable this)
        if "Expect" not in request.headers:
            request.headers["Expect"] = ""

        # libcurl adds Pragma: no-cache by default; disable that too
        if "Pragma" not in request.headers:
            request.headers["Pragma"] = ""

        curl.setopt(pycurl.HTTPHEADER,
                    ["%s: %s" % (native_str(k), native_str(v))
                     for k, v in request.headers.get_all()])

        curl.setopt(pycurl.HEADERFUNCTION,
                    functools.partial(self._curl_header_callback,
                                      headers, request.header_callback))
        if request.streaming_callback:
            def write_function(chunk):
                self.io_loop.add_callback(request.streaming_callback, chunk)
        else:
            write_function = buffer.write
        if bytes is str:  # py2
            curl.setopt(pycurl.WRITEFUNCTION, write_function)
        else:  # py3
            # Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
            # a fork/port.  That version has a bug in which it passes unicode
            # strings instead of bytes to the WRITEFUNCTION.  This means that
            # if you use a WRITEFUNCTION (which tornado always does), you cannot
            # download arbitrary binary data.  This needs to be fixed in the
            # ported pycurl package, but in the meantime this lambda will
            # make it work for downloading (utf8) text.
            curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
        curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
        curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
        curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
        curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
        if request.user_agent:
            curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
        else:
            curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
        if request.network_interface:
            curl.setopt(pycurl.INTERFACE, request.network_interface)
        if request.decompress_response:
            curl.setopt(pycurl.ENCODING, "gzip,deflate")
        else:
            curl.setopt(pycurl.ENCODING, "none")
        if request.proxy_host and request.proxy_port:
            curl.setopt(pycurl.PROXY, request.proxy_host)
            curl.setopt(pycurl.PROXYPORT, request.proxy_port)
            if request.proxy_username:
                credentials = '%s:%s' % (request.proxy_username,
                                         request.proxy_password)
                curl.setopt(pycurl.PROXYUSERPWD, credentials)
        else:
            curl.setopt(pycurl.PROXY, '')
            curl.unsetopt(pycurl.PROXYUSERPWD)
        if request.validate_cert:
            curl.setopt(pycurl.SSL_VERIFYPEER, 1)
            curl.setopt(pycurl.SSL_VERIFYHOST, 2)
        else:
            curl.setopt(pycurl.SSL_VERIFYPEER, 0)
            curl.setopt(pycurl.SSL_VERIFYHOST, 0)
        if request.ca_certs is not None:
            curl.setopt(pycurl.CAINFO, request.ca_certs)
        else:
            # There is no way to restore pycurl.CAINFO to its default value
            # (Using unsetopt makes it reject all certificates).
            # I don't see any way to read the default value from python so it
            # can be restored later.  We'll have to just leave CAINFO untouched
            # if no ca_certs file was specified, and require that if any
            # request uses a custom ca_certs file, they all must.
            pass

        if request.allow_ipv6 is False:
            # Curl behaves reasonably when DNS resolution gives an ipv6 address
            # that we can't reach, so allow ipv6 unless the user asks to disable.
            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
        else:
            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)

        # Set the request method through curl's irritating interface which makes
        # up names for almost every single method
        curl_options = {
            "GET": pycurl.HTTPGET,
            "POST": pycurl.POST,
            "PUT": pycurl.UPLOAD,
            "HEAD": pycurl.NOBODY,
        }
        custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
        for o in curl_options.values():
            curl.setopt(o, False)
        if request.method in curl_options:
            curl.unsetopt(pycurl.CUSTOMREQUEST)
            curl.setopt(curl_options[request.method], True)
        elif request.allow_nonstandard_methods or request.method in custom_methods:
            curl.setopt(pycurl.CUSTOMREQUEST, request.method)
        else:
            raise KeyError('unknown method ' + request.method)

        body_expected = request.method in ("POST", "PATCH", "PUT")
        body_present = request.body is not None
        if not request.allow_nonstandard_methods:
            # Some HTTP methods nearly always have bodies while others
            # almost never do. Fail in this case unless the user has
            # opted out of sanity checks with allow_nonstandard_methods.
            if ((body_expected and not body_present) or
                    (body_present and not body_expected)):
                raise ValueError(
                    'Body must %sbe None for method %s (unless '
                    'allow_nonstandard_methods is true)' %
                    ('not ' if body_expected else '', request.method))

        if body_expected or body_present:
            if request.method == "GET":
                # Even with `allow_nonstandard_methods` we disallow
                # GET with a body (because libcurl doesn't allow it
                # unless we use CUSTOMREQUEST). While the spec doesn't
                # forbid clients from sending a body, it arguably
                # disallows the server from doing anything with them.
                raise ValueError('Body must be None for GET request')
            request_buffer = BytesIO(utf8(request.body or ''))

            def ioctl(cmd):
                if cmd == curl.IOCMD_RESTARTREAD:
                    request_buffer.seek(0)
            curl.setopt(pycurl.READFUNCTION, request_buffer.read)
            curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
            if request.method == "POST":
                curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ''))
            else:
                curl.setopt(pycurl.UPLOAD, True)
                curl.setopt(pycurl.INFILESIZE, len(request.body or ''))

        if request.auth_username is not None:
            userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')

            if request.auth_mode is None or request.auth_mode == "basic":
                curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
            elif request.auth_mode == "digest":
                curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
            else:
                raise ValueError("Unsupported auth_mode %s" % request.auth_mode)

            curl.setopt(pycurl.USERPWD, native_str(userpwd))
            curl_log.debug("%s %s (username: %r)", request.method, request.url,
                           request.auth_username)
        else:
            curl.unsetopt(pycurl.USERPWD)
            curl_log.debug("%s %s", request.method, request.url)

        if request.client_cert is not None:
            curl.setopt(pycurl.SSLCERT, request.client_cert)

        if request.client_key is not None:
            curl.setopt(pycurl.SSLKEY, request.client_key)

        if request.ssl_options is not None:
            raise ValueError("ssl_options not supported in curl_httpclient")

        if threading.activeCount() > 1:
            # libcurl/pycurl is not thread-safe by default.  When multiple threads
            # are used, signals should be disabled.  This has the side effect
            # of disabling DNS timeouts in some environments (when libcurl is
            # not linked against ares), so we don't do it when there is only one
            # thread.  Applications that use many short-lived threads may need
            # to set NOSIGNAL manually in a prepare_curl_callback since
            # there may not be any other threads running at the time we call
            # threading.activeCount.
            curl.setopt(pycurl.NOSIGNAL, 1)
        if request.prepare_curl_callback is not None:
            request.prepare_curl_callback(curl)
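
The HEADERFUNCTION setup above is the classic use of `partial` as a callback adapter: pycurl invokes the function with a single header line, and `partial` bakes in the `headers` dict and the user's `header_callback`. A runnable sketch with a simplified stand-in for `_curl_header_callback`:

from functools import partial

def curl_header_callback(headers, header_callback, header_line):
    # pycurl would call this with header_line only; the first two
    # arguments were baked in when HEADERFUNCTION was set.
    key, _, value = header_line.partition(":")
    headers[key] = value.strip()
    if header_callback is not None:
        header_callback(header_line)

headers = {}
on_header = partial(curl_header_callback, headers, print)
on_header("Content-Type: text/html")  # as pycurl would
print(headers)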

Example 41

Project: ion
Source File: serializers.py
View license
    @transaction.atomic
    def fetch_activity_list_with_metadata(self, block):
        user = self.context.get("user", self.context["request"].user)

        favorited_activities = set(user.favorited_activity_set.values_list("id", flat=True))

        available_restricted_acts = EighthActivity.restricted_activities_available_to_user(user)

        activity_list = FallbackDict(functools.partial(self.get_activity, user, favorited_activities, available_restricted_acts))
        scheduled_activity_to_activity_map = FallbackDict(self.get_scheduled_activity)

        # Find all scheduled activities that don't correspond to
        # deleted activities
        scheduled_activities = (block.eighthscheduledactivity_set.exclude(activity__deleted=True).select_related("activity"))

        for scheduled_activity in scheduled_activities:
            # Avoid re-fetching scheduled_activity.
            activity_info = self.get_activity(user, favorited_activities, available_restricted_acts, None, scheduled_activity)
            activity = scheduled_activity.activity
            scheduled_activity_to_activity_map[scheduled_activity.id] = activity.id
            activity_list[activity.id] = activity_info

        # Find the number of students signed up for every activity
        # in this block
        activities_with_signups = (EighthSignup.objects.filter(scheduled_activity__block=block).exclude(
            scheduled_activity__activity__deleted=True).values_list("scheduled_activity__activity_id")
            .annotate(user_count=Count("scheduled_activity")))

        for activity, user_count in activities_with_signups:
            activity_list[activity]["roster"]["count"] = user_count

        sponsors_dict = (EighthSponsor.objects.values_list("id", "user_id", "first_name", "last_name", "show_full_name"))

        all_sponsors = dict((sponsor[0], {"user_id": sponsor[1],
                                          "name": sponsor[2] + " " + sponsor[3] if sponsor[4] else sponsor[3]}) for sponsor in sponsors_dict)

        activity_ids = scheduled_activities.values_list("activity__id")
        sponsorships = (EighthActivity.sponsors.through.objects.filter(eighthactivity_id__in=activity_ids).select_related("sponsors").values(
            "eighthactivity", "eighthsponsor"))

        scheduled_activity_ids = scheduled_activities.values_list("id", flat=True)
        overidden_sponsorships = (EighthScheduledActivity.sponsors.through.objects.filter(
            eighthscheduledactivity_id__in=scheduled_activity_ids).values("eighthscheduledactivity", "eighthsponsor"))

        for sponsorship in sponsorships:
            activity_id = int(sponsorship["eighthactivity"])
            sponsor_id = sponsorship["eighthsponsor"]
            sponsor = all_sponsors[sponsor_id]

            if sponsor["user_id"]:
                # We're not using User.get_user() here since we only want
                # a value from LDAP that is probably already cached.
                # This eliminates several hundred SQL queries on some
                # pages.
                dn = User.dn_from_id(sponsor["user_id"])
                if dn is not None:
                    name = User(dn=dn).last_name
                else:
                    name = None
            else:
                name = None

            if activity_id in activity_list:
                activity_list[activity_id]["sponsors"].append(sponsor["name"] or name)

        activities_sponsors_overidden = []
        for sponsorship in overidden_sponsorships:
            scheduled_activity_id = sponsorship["eighthscheduledactivity"]
            activity_id = scheduled_activity_to_activity_map[scheduled_activity_id]
            sponsor_id = sponsorship["eighthsponsor"]
            sponsor = all_sponsors[sponsor_id]

            if activity_id not in activities_sponsors_overidden:
                activities_sponsors_overidden.append(activity_id)
                del activity_list[activity_id]["sponsors"][:]

            if sponsor["user_id"]:
                # See a few lines up for why we're not using User.get_user()
                dn = User.dn_from_id(sponsor["user_id"])
                if dn is not None:
                    name = User(dn=dn).last_name
                else:
                    name = None
            else:
                name = None
            activity_list[activity_id]["sponsors"].append(sponsor["name"] or name)

        roomings = (EighthActivity.rooms.through.objects.filter(eighthactivity_id__in=activity_ids).select_related("eighthroom", "eighthactivity"))
        overidden_roomings = (EighthScheduledActivity.rooms.through.objects.filter(
            eighthscheduledactivity_id__in=scheduled_activity_ids).select_related("eighthroom", "eighthscheduledactivity"))

        for rooming in roomings:
            activity_id = rooming.eighthactivity.id
            activity_cap = rooming.eighthactivity.default_capacity
            room_name = rooming.eighthroom.name
            activity_list[activity_id]["rooms"].append(room_name)
            if activity_cap:
                # use activity default capacity instead of sum of activity rooms
                activity_list[activity_id]["roster"]["capacity"] = activity_cap
            else:
                activity_list[activity_id]["roster"]["capacity"] += rooming.eighthroom.capacity

        activities_rooms_overidden = []
        for rooming in overidden_roomings:
            scheduled_activity_id = rooming.eighthscheduledactivity.id

            activity_id = scheduled_activity_to_activity_map[scheduled_activity_id]
            if activity_id not in activities_rooms_overidden:
                activities_rooms_overidden.append(activity_id)
                del activity_list[activity_id]["rooms"][:]
                activity_list[activity_id]["roster"]["capacity"] = 0
            room_name = rooming.eighthroom.name
            activity_list[activity_id]["rooms"].append(room_name)
            activity_list[activity_id]["roster"]["capacity"] += rooming.eighthroom.capacity

        for scheduled_activity in scheduled_activities:
            if scheduled_activity.capacity is not None:
                capacity = scheduled_activity.capacity
                sched_act_id = scheduled_activity.activity.id
                activity_list[sched_act_id]["roster"]["capacity"] = capacity

        return activity_list
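
Here `functools.partial` binds the per-request context (`user`, the favorites set, the restricted set) into `self.get_activity` once, so the dict's fallback factory only needs the missing key. A runnable sketch with a guessed minimal `FallbackDict` (the real class is not shown in this excerpt, and the real factory takes two further arguments):

from functools import partial

class FallbackDict(dict):
    # Assumed behaviour: build a missing value with the supplied factory,
    # like collections.defaultdict but passing the key to the factory.
    def __init__(self, factory):
        super(FallbackDict, self).__init__()
        self.factory = factory

    def __missing__(self, key):
        value = self[key] = self.factory(key)
        return value

def get_activity(user, favorited, restricted, activity_id):
    return {"id": activity_id, "favorited": activity_id in favorited}

activity_list = FallbackDict(partial(get_activity, "user1", {3, 7}, set()))
print(activity_list[3])  # built on demand with the shared context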

Example 42

Project: deep_recommend_system
Source File: bucket_ops.py
View license
def bucket(tensors,
           which_bucket,
           batch_size,
           num_buckets,
           num_threads=1,
           capacity=32,
           shapes=None,
           dynamic_pad=False,
           allow_smaller_final_batch=False,
           keep_input=None,
           shared_name=None,
           name=None):
  """Lazy bucketing of input tensors according to `which_bucket`.

  The argument `tensors` can be a list or a dictionary of tensors.
  The value returned by the function will be of the same type
  as `tensors`.

  The tensors entering this function are put into the bucket given by
  `which_bucket`.  Each bucket has its own queue.  When a bucket contains
  `batch_size` elements, this minibatch is pushed onto a top queue.  The
  tensors returned from this function are the result of dequeuing the
  next minibatch from this top queue.

  This function is implemented using several queues. A `QueueRunner` for the
  queues is added to the current `Graph`'s `QUEUE_RUNNER` collection.

  As the returned tensors are the result of a dequeue operation, evaluating
  them will throw a `tf.errors.OutOfRangeError` when the input queue is
  exhausted.  If these tensors are feeding another input queue, its queue runner
  will catch this exception; however, if they are used in your main thread,
  you are responsible for catching it yourself.

  *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
  (i) the `shapes` argument is passed, or (ii) all of the tensors in
  `tensors` have fully-defined shapes.  A `ValueError` will be
  raised if neither of these conditions holds.

  If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
  tensors is known, but individual dimensions may have shape `None`.
  In this case, for each enqueue the dimensions with value `None`
  may have a variable length; upon dequeue, the output tensors will be padded
  on the right to the maximum shape of the tensors in the current minibatch.
  For numbers, this padding takes value 0.  For strings, this padding is
  the empty string.  See `PaddingFIFOQueue` for more info.

  If `allow_smaller_final_batch` is `True`, a smaller batch value than
  `batch_size` is returned when the queues are closed and there are not enough
  elements to fill the batch, otherwise the pending elements are discarded.
  In addition, all output tensors' static shapes, as accessed via the
  `get_shape()` method, will have a 0th `Dimension` value of `None`, and
  operations that depend on a fixed batch_size will fail.

  Args:
    tensors: The list or dictionary of tensors, representing a single element,
      to bucket.  Nested lists are not supported.
    which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
    batch_size: The new batch size pulled from the queue
      (python int or int32 scalar).
    num_buckets: A python integer, the number of buckets.
    num_threads: An integer.  The number of threads enqueuing `tensors`.
    capacity: An integer. The maximum number of minibatches in the top queue,
      and also the maximum number of elements within each bucket.
    shapes: (Optional) The shapes for each example.  Defaults to the
      inferred shapes for `tensors`.
    dynamic_pad: Boolean.  Allow variable dimensions in input shapes.
      The given dimensions are padded upon dequeue so that tensors within a
      batch have the same shapes.
    allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
      batches to be smaller if there are insufficient items left in the queues.
    keep_input: (Optional).  A `bool` scalar Tensor.  If provided, this tensor
      controls whether the input is added to the queue or not.  If it evaluates
      `True`, then `tensors` are added to the bucket; otherwise they are
      dropped.  This tensor essentially acts as a filtering mechanism.
      The default behavior is to assume `keep_input=True`.
    shared_name: (Optional). If set, the queues will be shared under the given
      name across multiple sessions.
    name: (Optional) A name for the operations.

  Returns:
    A tuple `(bucket, outputs)` where `bucket` is
    an `int32` scalar tensor and `outputs` is a list or
    dictionary of batched outputs corresponding to elements of `tensors`.
    Every step will receive a new bucket of outputs.

  Raises:
    ValueError: If the `shapes` are not specified, and cannot be
      inferred from the elements of `tensors`.
  """
  tensor_list = _as_tensor_list(tensors)
  with ops.name_scope(name, "bucket", tensor_list) as name:
    tensor_list = _validate_bucket(tensor_list)
    (tensor_list, sparse_info) = _store_sparse_tensors(
        tensor_list, enqueue_many=False)

    # Round-trip batch_size to a tensor, and possibly back
    batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int32, name="batch_size")
    static_batch_size = tensor_util.constant_value(batch_size)
    batch_size = (
        static_batch_size if static_batch_size is not None else batch_size)

    types = _dtypes([tensor_list])
    shapes = _shapes([tensor_list], shapes, enqueue_many=False)

    which_bucket = ops.convert_to_tensor(
        which_bucket, dtype=dtypes.int32, name="which_bucket")

    queue_creator = _which_queue(dynamic_pad)
    bucket_queues = []
    for i in range(num_buckets):
      shared_name_i = (
          "%s_%d" % (shared_name, i) if shared_name is not None else None)
      bucket_queues.append(
          queue_creator(capacity=capacity,
                        dtypes=types,
                        shapes=shapes,
                        shared_name=shared_name_i, name="bucket_queue_%d" % i))

    maybe_static_batch_size = (
        None if allow_smaller_final_batch else static_batch_size)

    bucket_shapes = [tensor_shape.vector(maybe_static_batch_size).concatenate(s)
                     for s in bucket_queues[0].shapes]
    # top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO
    # queues because if we use allow_smaller_final_batch, shapes will
    # contain Nones in their first entry; as a result, a regular
    # FIFOQueue would die when being passed shapes that are not fully defined.
    top_queue = data_flow_ops.PaddingFIFOQueue(
        capacity=capacity,
        dtypes=[dtypes.int32] + types,
        shapes=[tensor_shape.scalar()] + bucket_shapes,
        shared_name=shared_name, name="top_queue")

    def enqueue_which():
      def enqueue_single(i):
        return bucket_queues[i].enqueue(tensor_list)
      enqueues = [
          control_flow_ops.cond(
              math_ops.equal(which_bucket, i),
              functools.partial(enqueue_single, i),
              control_flow_ops.no_op)
          for i in range(num_buckets)]
      return control_flow_ops.group(*enqueues, name="group_enqueues")

    if keep_input is not None:
      # TODO(ebrevdo): Expand keep_input param to core training
      # methods, and pipe through to _store_sparse_tensors; so
      # that expensive serialization is guarded by keep_input.
      maybe_enqueue = control_flow_ops.cond(
          keep_input,
          enqueue_which,
          control_flow_ops.no_op)
    else:
      maybe_enqueue = enqueue_which()

    bucket_enqueue_ops = [maybe_enqueue] * num_threads

    if allow_smaller_final_batch:
      which_dequeue = lambda q: q.dequeue_up_to
    else:
      which_dequeue = lambda q: q.dequeue_many

    enqueues_to_top = [
        top_queue.enqueue(
            [constant_op.constant(i)] +
            which_dequeue(q)(batch_size, name="read_bucket_%d" % i),
            name="enqueue_from_bucket_%d" % i)
        for i, q in enumerate(bucket_queues)]

    for i, q in enumerate(bucket_queues):
      queue_runner.add_queue_runner(queue_runner.QueueRunner(
          q, [enqueues_to_top[i]],
          queue_closed_exception_types=(
              errors.OutOfRangeError, errors.CancelledError)))
    queue_runner.add_queue_runner(queue_runner.QueueRunner(
        top_queue, bucket_enqueue_ops,
        queue_closed_exception_types=(
            errors.OutOfRangeError, errors.CancelledError)))

    for q in bucket_queues:
      summary.scalar("bucket/%s/size" % q.name,
                     math_ops.cast(top_queue.size(), dtypes.float32))
    summary.scalar("bucket/%s/fraction_of_%d_full" % (top_queue.name, capacity),
                   math_ops.cast(top_queue.size(), dtypes.float32) *
                   (1. / capacity))

    dequeued = top_queue.dequeue(name="dequeue_top")
    which_bucket_dequeued = dequeued[0]
    dequeued = dequeued[1:]
    dequeued = _restore_sparse_tensors(dequeued, sparse_info)
    return (which_bucket_dequeued, _as_original_type(tensors, dequeued))
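
Inside `enqueue_which`, `functools.partial(enqueue_single, i)` does double duty: `cond` wants a zero-argument callable for each branch, and binding `i` eagerly sidesteps the late-binding closure bug a bare lambda would hit. A runnable illustration of that difference:

from functools import partial

def enqueue_single(i):
    return "enqueue into bucket %d" % i

# A lambda closes over the *variable* i and sees its final value in every
# branch; partial freezes the *value* of i at each iteration.
late = [lambda: enqueue_single(i) for i in range(3)]
bound = [partial(enqueue_single, i) for i in range(3)]

print([f() for f in late])   # bucket 2, three times
print([f() for f in bound])  # buckets 0, 1, 2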

Example 43

Project: deep_recommend_system
Source File: bucket_ops.py
View license
def bucket(tensors,
           which_bucket,
           batch_size,
           num_buckets,
           num_threads=1,
           capacity=32,
           shapes=None,
           dynamic_pad=False,
           allow_smaller_final_batch=False,
           keep_input=None,
           shared_name=None,
           name=None):
  """Lazy bucketing of input tensors according to `which_bucket`.

  The argument `tensors` can be a list or a dictionary of tensors.
  The value returned by the function will be of the same type
  as `tensors`.

  The tensors entering this function are put into the bucket given by
  `which_bucket`.  Each bucket has its own queue.  When a bucket contains
  `batch_size` elements, this minibatch is pushed onto a top queue.  The
  tensors returned from this function are a the result of dequeueing the
  next minibatch from this top queue.

  This function is implemented using several queues. A `QueueRunner` for the
  queues is added to the current `Graph`'s `QUEUE_RUNNER` collection.

  As the returned tensors are the result of of a dequeue operation, evaluating
  them will throw a `tf.errors.OutOfRangeError` when the input queue is
  exhausted.  If these tensors are feeding another input queue, its queue runner
  will catch this exception, however, if they are used in your main thread
  you are responsible for catching this yourself.

  *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
  (i) the `shapes` argument is passed, or (ii) all of the tensors in
  `tensors` must have fully-defined shapes. `ValueError` will be
  raised if neither of these conditions holds.

  If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
  tensors is known, but individual dimensions may have shape `None`.
  In this case, for each enqueue the dimensions with value `None`
  may have a variable length; upon dequeue, the output tensors will be padded
  on the right to the maximum shape of the tensors in the current minibatch.
  For numbers, this padding takes value 0.  For strings, this padding is
  the empty string.  See `PaddingFIFOQueue` for more info.

  If `allow_smaller_final_batch` is `True`, a smaller batch value than
  `batch_size` is returned when the queues are closed and there are not enough
  elements to fill the batch, otherwise the pending elements are discarded.
  In addition, all output tensors' static shapes, as accessed via the
  `get_shape()` method will have a 0th `Dimension` value of `None`, and
  operations that depend on fixed batch_size would fail.

  Args:
    tensors: The list or dictionary of tensors, representing a single element,
      to bucket.  Nested lists are not supported.
    which_bucket: An `int32` scalar Tensor taking a value in `[0, num_buckets)`.
    batch_size: The new batch size pulled from the queue
      (python int or int32 scalar).
    num_buckets: A python integer, the number of buckets.
    num_threads: An integer.  The number of threads enqueuing `tensors`.
    capacity: An integer. The maximum number of minibatches in the top queue,
      and also the maximum number of elements within each bucket.
    shapes: (Optional) The shapes for each example.  Defaults to the
      inferred shapes for `tensors`.
    dynamic_pad: Boolean.  Allow variable dimensions in input shapes.
      The given dimensions are padded upon dequeue so that tensors within a
      batch have the same shapes.
    allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
      batches to be smaller if there are insufficient items left in the queues.
    keep_input: (Optional).  A `bool` scalar Tensor.  If provided, this tensor
      controls whether the input is added to the queue or not.  If it evaluates
      `True`, then `tensors` are added to the bucket; otherwise they are
      dropped.  This tensor essentially acts as a filtering mechanism.
      The default behavior is to assume `keep_input=True`.
    shared_name: (Optional). If set, the queues will be shared under the given
      name across multiple sessions.
    name: (Optional) A name for the operations.

  Returns:
    A tuple `(bucket, outputs)` where `bucket` is
    a `int32` scalar tensor and `outputs` is a list or
    dictionary of batched outputs corresponding to elements of `tensors`.
    Every step will receive a new bucket of outputs.

  Raises:
    ValueError: If the `shapes` are not specified, and cannot be
      inferred from the elements of `tensors`.
  """
  tensor_list = _as_tensor_list(tensors)
  with ops.name_scope(name, "bucket", tensor_list) as name:
    tensor_list = _validate_bucket(tensor_list)
    (tensor_list, sparse_info) = _store_sparse_tensors(
        tensor_list, enqueue_many=False)

    # Round-trip batch_size to a tensor, and possibly back
    batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int32, name="batch_size")
    static_batch_size = tensor_util.constant_value(batch_size)
    batch_size = (
        static_batch_size if static_batch_size is not None else batch_size)

    types = _dtypes([tensor_list])
    shapes = _shapes([tensor_list], shapes, enqueue_many=False)

    which_bucket = ops.convert_to_tensor(
        which_bucket, dtype=dtypes.int32, name="which_bucket")

    queue_creator = _which_queue(dynamic_pad)
    bucket_queues = []
    for i in range(num_buckets):
      shared_name_i = (
          "%s_%d" % (shared_name, i) if shared_name is not None else None)
      bucket_queues.append(
          queue_creator(capacity=capacity,
                        dtypes=types,
                        shapes=shapes,
                        shared_name=shared_name_i, name="bucket_queue_%d" % i))

    maybe_static_batch_size = (
        None if allow_smaller_final_batch else static_batch_size)

    bucket_shapes = [tensor_shape.vector(maybe_static_batch_size).concatenate(s)
                     for s in bucket_queues[0].shapes]
    # top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO
    # queues because if we use allow_smaller_final_batch, shapes will
    # contain Nones in their first entry; as a result, a regular
    # FIFOQueue would die when being passed shapes that are not fully defined.
    top_queue = data_flow_ops.PaddingFIFOQueue(
        capacity=capacity,
        dtypes=[dtypes.int32] + types,
        shapes=[tensor_shape.scalar()] + bucket_shapes,
        shared_name=shared_name, name="top_queue")

    def enqueue_which():
      def enqueue_single(i):
        return bucket_queues[i].enqueue(tensor_list)
      enqueues = [
          control_flow_ops.cond(
              math_ops.equal(which_bucket, i),
              functools.partial(enqueue_single, i),
              control_flow_ops.no_op)
          for i in range(num_buckets)]
      return control_flow_ops.group(*enqueues, name="group_enqueues")

    if keep_input is not None:
      # TODO(ebrevdo): Expand keep_input param to core training
      # methods, and pipe through to _store_sparse_tensors; so
      # that expensive serialization is guarded by keep_input.
      maybe_enqueue = control_flow_ops.cond(
          keep_input,
          enqueue_which,
          control_flow_ops.no_op)
    else:
      maybe_enqueue = enqueue_which()

    bucket_enqueue_ops = [maybe_enqueue] * num_threads

    if allow_smaller_final_batch:
      which_dequeue = lambda q: q.dequeue_up_to
    else:
      which_dequeue = lambda q: q.dequeue_many

    enqueues_to_top = [
        top_queue.enqueue(
            [constant_op.constant(i)] +
            which_dequeue(q)(batch_size, name="read_bucket_%d" % i),
            name="enqueue_from_bucket_%d" % i)
        for i, q in enumerate(bucket_queues)]

    for i, q in enumerate(bucket_queues):
      queue_runner.add_queue_runner(queue_runner.QueueRunner(
          q, [enqueues_to_top[i]],
          queue_closed_exception_types=(
              errors.OutOfRangeError, errors.CancelledError)))
    queue_runner.add_queue_runner(queue_runner.QueueRunner(
        top_queue, bucket_enqueue_ops,
        queue_closed_exception_types=(
            errors.OutOfRangeError, errors.CancelledError)))

    for q in bucket_queues:
      summary.scalar("bucket/%s/size" % q.name,
                     math_ops.cast(top_queue.size(), dtypes.float32))
    summary.scalar("bucket/%s/fraction_of_%d_full" % (top_queue.name, capacity),
                   math_ops.cast(top_queue.size(), dtypes.float32) *
                   (1. / capacity))

    dequeued = top_queue.dequeue(name="dequeue_top")
    which_bucket_dequeued = dequeued[0]
    dequeued = dequeued[1:]
    dequeued = _restore_sparse_tensors(dequeued, sparse_info)
    return (which_bucket_dequeued, _as_original_type(tensors, dequeued))
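
A note on the functools.partial(enqueue_single, i) call above: control_flow_ops.cond expects zero-argument callables, and building them with partial also freezes the loop index at its current value. A plain "lambda: enqueue_single(i)" would close over the variable i itself and late-bind to its final value. A minimal, framework-free sketch of the difference (the helper names are illustrative, not TensorFlow's):

import functools

def make_callbacks_late_binding(n):
    # Every lambda closes over the same variable 'i', so by the time any
    # of them runs, 'i' holds its final loop value.
    return [lambda: i for i in range(n)]

def make_callbacks_with_partial(n):
    # partial evaluates its arguments immediately, freezing each index.
    return [functools.partial(lambda i: i, i) for i in range(n)]

print([f() for f in make_callbacks_late_binding(3)])   # [2, 2, 2]
print([f() for f in make_callbacks_with_partial(3)])   # [0, 1, 2]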

Example 44

Project: Fiona
Source File: collect.py
View license
@click.command(short_help="Collect a sequence of features.")
@cligj.precision_opt
@cligj.indent_opt
@cligj.compact_opt
@click.option('--record-buffered/--no-record-buffered', default=False,
              help="Economical buffering of writes at record, not collection "
              "(default), level.")
@click.option('--ignore-errors/--no-ignore-errors', default=False,
              help="log errors but do not stop serialization.")
@options.src_crs_opt
@click.option('--with-ld-context/--without-ld-context', default=False,
              help="add a JSON-LD context to JSON output.")
@click.option('--add-ld-context-item', multiple=True,
              help="map a term to a URI and add it to the output's JSON LD "
                   "context.")
@click.option('--parse/--no-parse', default=True,
              help="load and dump the geojson feature (default is True)")
@click.pass_context
def collect(ctx, precision, indent, compact, record_buffered, ignore_errors,
            src_crs, with_ld_context, add_ld_context_item, parse):
    """Make a GeoJSON feature collection from a sequence of GeoJSON
    features and print it."""
    verbosity = (ctx.obj and ctx.obj['verbosity']) or 2
    logger = logging.getLogger('fio')
    stdin = click.get_text_stream('stdin')
    sink = click.get_text_stream('stdout')

    dump_kwds = {'sort_keys': True}
    if indent:
        dump_kwds['indent'] = indent
    if compact:
        dump_kwds['separators'] = (',', ':')
    item_sep = compact and ',' or ', '

    if src_crs:
        if not parse:
            raise click.UsageError("Can't specify --src-crs with --no-parse")
        transformer = partial(transform_geom, src_crs, 'EPSG:4326',
                              antimeridian_cutting=True, precision=precision)
    else:
        transformer = lambda x: x

    first_line = next(stdin)

    # If parsing geojson
    if parse:
        # If input is RS-delimited JSON sequence.
        if first_line.startswith(u'\x1e'):
            def feature_text_gen():
                buffer = first_line.strip(u'\x1e')
                for line in stdin:
                    if line.startswith(u'\x1e'):
                        if buffer:
                            feat = json.loads(buffer)
                            feat['geometry'] = transformer(feat['geometry'])
                            yield json.dumps(feat, **dump_kwds)
                        buffer = line.strip(u'\x1e')
                    else:
                        buffer += line
                else:
                    feat = json.loads(buffer)
                    feat['geometry'] = transformer(feat['geometry'])
                    yield json.dumps(feat, **dump_kwds)
        else:
            def feature_text_gen():
                feat = json.loads(first_line)
                feat['geometry'] = transformer(feat['geometry'])
                yield json.dumps(feat, **dump_kwds)

                for line in stdin:
                    feat = json.loads(line)
                    feat['geometry'] = transformer(feat['geometry'])
                    yield json.dumps(feat, **dump_kwds)

    # If *not* parsing geojson
    else:
        # If input is RS-delimited JSON sequence.
        if first_line.startswith(u'\x1e'):
            def feature_text_gen():
                buffer = first_line.strip(u'\x1e')
                for line in stdin:
                    if line.startswith(u'\x1e'):
                        if buffer:
                            yield buffer
                        buffer = line.strip(u'\x1e')
                    else:
                        buffer += line
                else:
                    yield buffer
        else:
            def feature_text_gen():
                yield first_line
                for line in stdin:
                    yield line

    try:
        source = feature_text_gen()

        if record_buffered:
            # Buffer GeoJSON data at the feature level for smaller
            # memory footprint.
            indented = bool(indent)
            rec_indent = "\n" + " " * (2 * (indent or 0))

            collection = {
                'type': 'FeatureCollection',
                'features': []}
            if with_ld_context:
                collection['@context'] = helpers.make_ld_context(
                    add_ld_context_item)

            head, tail = json.dumps(collection, **dump_kwds).split('[]')

            sink.write(head)
            sink.write("[")

            # Try the first record.
            try:
                i = 0
                first = next(source)
                if with_ld_context:
                    first = helpers.id_record(first)
                if indented:
                    sink.write(rec_indent)
                sink.write(first.replace("\n", rec_indent))
            except StopIteration:
                pass
            except Exception as exc:
                # Ignoring errors is *not* the default.
                if ignore_errors:
                    logger.error(
                        "failed to serialize file record %d (%s), "
                        "continuing",
                        i, exc)
                else:
                    # Log error and close up the GeoJSON, leaving it
                    # more or less valid no matter what happens above.
                    logger.critical(
                        "failed to serialize file record %d (%s), "
                        "quiting",
                        i, exc)
                    sink.write("]")
                    sink.write(tail)
                    if indented:
                        sink.write("\n")
                    raise

            # Because trailing commas aren't valid in JSON arrays
            # we'll write the item separator before each of the
            # remaining features.
            for i, rec in enumerate(source, 1):
                try:
                    if with_ld_context:
                        rec = helpers.id_record(rec)
                    if indented:
                        sink.write(rec_indent)
                    sink.write(item_sep)
                    sink.write(rec.replace("\n", rec_indent))
                except Exception as exc:
                    if ignore_errors:
                        logger.error(
                            "failed to serialize file record %d (%s), "
                            "continuing",
                            i, exc)
                    else:
                        logger.critical(
                            "failed to serialize file record %d (%s), "
                            "quiting",
                            i, exc)
                        sink.write("]")
                        sink.write(tail)
                        if indented:
                            sink.write("\n")
                        raise

            # Close up the GeoJSON after writing all features.
            sink.write("]")
            sink.write(tail)
            if indented:
                sink.write("\n")

        else:
            # Buffer GeoJSON data at the collection level. The default.
            collection = {
                'type': 'FeatureCollection',
                'features': []}
            if with_ld_context:
                collection['@context'] = helpers.make_ld_context(
                    add_ld_context_item)

            head, tail = json.dumps(collection, **dump_kwds).split('[]')
            sink.write(head)
            sink.write("[")
            sink.write(",".join(source))
            sink.write("]")
            sink.write(tail)
            sink.write("\n")

    except Exception:
        logger.exception("Exception caught during processing")
        raise click.Abort()
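
The partial(transform_geom, src_crs, 'EPSG:4326', ...) line above is the heart of this example: it turns a multi-argument reprojection function into a one-argument geometry transformer, which is why it can be swapped interchangeably with the identity "lambda x: x" fallback. A minimal sketch with a stand-in for transform_geom (the real function lives in fiona.transform; the signature below is assumed for illustration):

from functools import partial

def transform_geom(src_crs, dst_crs, geom, precision=-1):
    # Stand-in: a real implementation would reproject the coordinates.
    return {"crs": dst_crs, "coordinates": geom["coordinates"],
            "precision": precision}

# Freeze the CRS pair and options once; callers only pass the geometry.
transformer = partial(transform_geom, "EPSG:3857", "EPSG:4326", precision=6)
print(transformer({"coordinates": (1113194.9, 0.0)}))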

Example 45

Project: Fiona
Source File: dump.py
View license
@click.command(short_help="Dump a dataset to GeoJSON.")
@click.argument('input', type=click.Path(), required=True)
@click.option('--layer', metavar="INDEX|NAME", callback=options.cb_layer,
              help="Print information about a specific layer.  The first "
                   "layer is used by default.  Layers use zero-based "
                   "numbering when accessed by index.")
@click.option('--encoding', help="Specify encoding of the input file.")
@cligj.precision_opt
@cligj.indent_opt
@cligj.compact_opt
@click.option('--record-buffered/--no-record-buffered', default=False,
              help="Economical buffering of writes at record, not collection "
                   "(default), level.")
@click.option('--ignore-errors/--no-ignore-errors', default=False,
              help="log errors but do not stop serialization.")
@click.option('--with-ld-context/--without-ld-context', default=False,
              help="add a JSON-LD context to JSON output.")
@click.option('--add-ld-context-item', multiple=True,
              help="map a term to a URI and add it to the output's JSON LD "
                   "context.")
@click.pass_context
def dump(ctx, input, encoding, precision, indent, compact, record_buffered,
         ignore_errors, with_ld_context, add_ld_context_item, layer):

    """Dump a dataset either as a GeoJSON feature collection (the default)
    or a sequence of GeoJSON features."""

    verbosity = (ctx.obj and ctx.obj['verbosity']) or 2
    logger = logging.getLogger('fio')
    sink = click.get_text_stream('stdout')

    dump_kwds = {'sort_keys': True}
    if indent:
        dump_kwds['indent'] = indent
    if compact:
        dump_kwds['separators'] = (',', ':')
    item_sep = compact and ',' or ', '

    open_kwds = {}
    if encoding:
        open_kwds['encoding'] = encoding
    if layer:
        open_kwds['layer'] = layer

    def transformer(crs, feat):
        tg = partial(transform_geom, crs, 'EPSG:4326',
                     antimeridian_cutting=True, precision=precision)
        feat['geometry'] = tg(feat['geometry'])
        return feat

    try:
        with fiona.drivers(CPL_DEBUG=verbosity > 2):
            with fiona.open(input, **open_kwds) as source:
                meta = source.meta
                meta['fields'] = dict(source.schema['properties'].items())

                if record_buffered:
                    # Buffer GeoJSON data at the feature level for smaller
                    # memory footprint.
                    indented = bool(indent)
                    rec_indent = "\n" + " " * (2 * (indent or 0))

                    collection = {
                        'type': 'FeatureCollection',
                        'fiona:schema': meta['schema'],
                        'fiona:crs': meta['crs'],
                        'features': []}
                    if with_ld_context:
                        collection['@context'] = helpers.make_ld_context(
                            add_ld_context_item)

                    head, tail = json.dumps(
                        collection, **dump_kwds).split('[]')

                    sink.write(head)
                    sink.write("[")

                    itr = iter(source)

                    # Try the first record.
                    try:
                        i = 0
                        first = next(itr)
                        first = transformer(source.crs, first)
                        if with_ld_context:
                            first = helpers.id_record(first)
                        if indented:
                            sink.write(rec_indent)
                        sink.write(json.dumps(
                            first, **dump_kwds).replace("\n", rec_indent))
                    except StopIteration:
                        pass
                    except Exception as exc:
                        # Ignoring errors is *not* the default.
                        if ignore_errors:
                            logger.error(
                                "failed to serialize file record %d (%s), "
                                "continuing",
                                i, exc)
                        else:
                            # Log error and close up the GeoJSON, leaving it
                            # more or less valid no matter what happens above.
                            logger.critical(
                                "failed to serialize file record %d (%s), "
                                "quiting",
                                i, exc)
                            sink.write("]")
                            sink.write(tail)
                            if indented:
                                sink.write("\n")
                            raise

                    # Because trailing commas aren't valid in JSON arrays
                    # we'll write the item separator before each of the
                    # remaining features.
                    for i, rec in enumerate(itr, 1):
                        rec = transformer(source.crs, rec)
                        try:
                            if with_ld_context:
                                rec = helpers.id_record(rec)
                            if indented:
                                sink.write(rec_indent)
                            sink.write(item_sep)
                            sink.write(json.dumps(
                                rec, **dump_kwds).replace("\n", rec_indent))
                        except Exception as exc:
                            if ignore_errors:
                                logger.error(
                                    "failed to serialize file record %d (%s), "
                                    "continuing",
                                    i, exc)
                            else:
                                logger.critical(
                                    "failed to serialize file record %d (%s), "
                                    "quiting",
                                    i, exc)
                                sink.write("]")
                                sink.write(tail)
                                if indented:
                                    sink.write("\n")
                                raise

                    # Close up the GeoJSON after writing all features.
                    sink.write("]")
                    sink.write(tail)
                    if indented:
                        sink.write("\n")

                else:
                    # Buffer GeoJSON data at the collection level. The default.
                    collection = {
                        'type': 'FeatureCollection',
                        'fiona:schema': meta['schema'],
                        'fiona:crs': meta['crs']}
                    if with_ld_context:
                        collection['@context'] = helpers.make_ld_context(
                            add_ld_context_item)
                        collection['features'] = [
                            helpers.id_record(transformer(rec))
                            for rec in source]
                    else:
                        collection['features'] = [
                            transformer(source.crs, rec) for rec in source]
                    json.dump(collection, sink, **dump_kwds)

    except Exception:
        logger.exception("Exception caught during processing")
        raise click.Abort()
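
Note that the transformer above rebuilds its partial on every feature. Hoisting the partial into a closure that is built once per dataset avoids that rework; a sketch under the same assumed transform_geom signature as in the previous example:

from functools import partial

def transform_geom(src_crs, dst_crs, geom, precision=-1):
    return geom  # stand-in: assume the real function reprojects the geometry

def make_transformer(crs, precision):
    # Build the partial a single time instead of once per record.
    tg = partial(transform_geom, crs, "EPSG:4326", precision=precision)

    def transformer(feat):
        feat["geometry"] = tg(feat["geometry"])
        return feat

    return transformer

transformer = make_transformer("EPSG:3857", precision=6)
print(transformer({"geometry": {"type": "Point", "coordinates": [0, 0]}}))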

Example 46

Project: RenderPipeline
Source File: main.py
View license
    def _get_widget_for_setting(self, setting_id, setting):
        """ Returns an appropriate widget to control the given setting """

        widget = QWidget()
        layout = QHBoxLayout()
        layout.setAlignment(Qt.AlignCenter)
        widget.setLayout(layout)

        if setting.type == "bool":
            box = QCheckBox()
            box.setChecked(Qt.Checked if setting.value else Qt.Unchecked)
            qt_connect(box, "stateChanged(int)",
                    partial(self._on_setting_bool_changed, setting_id))
            layout.addWidget(box)

        elif setting.type == "float" or setting.type == "int":

            if setting.type == "float":
                box = QDoubleSpinBox()

                if setting.maxval - setting.minval <= 2.0:
                    box.setDecimals(4)
            else:
                box = QSpinBox()
            box.setMinimum(setting.minval)
            box.setMaximum(setting.maxval)
            box.setValue(setting.value)

            box.setAlignment(Qt.AlignCenter)

            slider = QSlider()
            slider.setOrientation(Qt.Horizontal)

            if setting.type == "float":
                box.setSingleStep(abs(setting.maxval - setting.minval) / 100.0)
                slider.setMinimum(setting.minval * 100000.0)
                slider.setMaximum(setting.maxval * 100000.0)
                slider.setValue(setting.value * 100000.0)
            elif setting.type == "int":
                box.setSingleStep(max(1, (setting.maxval - setting.minval) / 32))
                slider.setMinimum(setting.minval)
                slider.setMaximum(setting.maxval)
                slider.setValue(setting.value)

            layout.addWidget(box)
            layout.addWidget(slider)

            qt_connect(slider, "valueChanged(int)",
                    partial(self._on_setting_slider_changed, setting_id, setting.type, [box]))

            value_type = "int" if setting.type == "int" else "double"

            qt_connect(box, "valueChanged(" + value_type + ")",
                    partial(self._on_setting_spinbox_changed, setting_id, setting.type, [slider]))

        elif setting.type == "enum":
            box = QComboBox()
            for value in setting.values:
                box.addItem(value)
            qt_connect(box, "currentIndexChanged(QString)",
                    partial(self._on_setting_enum_changed, setting_id))
            box.setCurrentIndex(setting.values.index(setting.value))
            box.setMinimumWidth(145)

            layout.addWidget(box)

        elif setting.type == "power_of_two":

            box = QComboBox()
            resolutions = [str(2**i) for i in range(1, 32) if 2**i >= setting.minval and 2**i <= setting.maxval]
            for value in resolutions:
                box.addItem(value)
            qt_connect(box, "currentIndexChanged(QString)",
                    partial(self._on_setting_power_of_two_changed, setting_id))
            box.setCurrentIndex(resolutions.index(str(setting.value)))
            box.setMinimumWidth(145)
            layout.addWidget(box)

        elif setting.type == "sample_sequence":

            box = QComboBox()
            for value in setting.sequences:
                box.addItem(value)
            qt_connect(box, "currentIndexChanged(QString)",
                    partial(self._on_setting_enum_changed, setting_id))
            box.setCurrentIndex(setting.sequences.index(str(setting.value)))
            box.setMinimumWidth(145)
            layout.addWidget(box)

        elif setting.type == "path":

            label = QLabel()
            display_file = setting.value.replace("\\", "/").split("/")[-1]

            desc_font = QFont()
            desc_font.setPointSize(7)
            desc_font.setFamily("Roboto")

            label.setText(display_file)
            label.setFont(desc_font)

            button = QPushButton()
            button.setText("Choose File ...")
            qt_connect(button, "clicked()", partial(
                self._choose_path, setting_id, setting, (label,)))

            layout.addWidget(label)
            layout.addWidget(button)

        else:
            print("ERROR: Unkown setting type:", setting.type)

        return widget
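
Every qt_connect(...) call above follows the same pattern: a Qt signal passes only its own payload (a checkbox state, a slider value), so partial pre-binds the setting_id that the handler also needs. A self-contained sketch with a fake signal standing in for Qt (the FakeSignal class is invented for illustration):

from functools import partial

class FakeSignal(object):
    # Minimal stand-in for a Qt signal: registers slots, then fires them.
    def __init__(self):
        self._slots = []

    def connect(self, slot):
        self._slots.append(slot)

    def emit(self, *args):
        for slot in self._slots:
            slot(*args)

def on_setting_changed(setting_id, value):
    print("setting %r changed to %r" % (setting_id, value))

state_changed = FakeSignal()
# The signal supplies only 'value'; partial pre-binds which setting it is.
state_changed.connect(partial(on_setting_changed, "bloom_strength"))
state_changed.emit(0.75)   # -> setting 'bloom_strength' changed to 0.75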

Example 47

Project: tornado
Source File: curl_httpclient.py
View license
    def _curl_setup_request(self, curl, request, buffer, headers):
        curl.setopt(pycurl.URL, native_str(request.url))

        # libcurl's magic "Expect: 100-continue" behavior causes delays
        # with servers that don't support it (which include, among others,
        # Google's OpenID endpoint).  Additionally, this behavior has
        # a bug in conjunction with the curl_multi_socket_action API
        # (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
        # which increases the delays.  It's more trouble than it's worth,
        # so just turn off the feature (yes, setting Expect: to an empty
        # value is the official way to disable this)
        if "Expect" not in request.headers:
            request.headers["Expect"] = ""

        # libcurl adds Pragma: no-cache by default; disable that too
        if "Pragma" not in request.headers:
            request.headers["Pragma"] = ""

        curl.setopt(pycurl.HTTPHEADER,
                    ["%s: %s" % (native_str(k), native_str(v))
                     for k, v in request.headers.get_all()])

        curl.setopt(pycurl.HEADERFUNCTION,
                    functools.partial(self._curl_header_callback,
                                      headers, request.header_callback))
        if request.streaming_callback:
            def write_function(chunk):
                self.io_loop.add_callback(request.streaming_callback, chunk)
        else:
            write_function = buffer.write
        if bytes is str:  # py2
            curl.setopt(pycurl.WRITEFUNCTION, write_function)
        else:  # py3
            # Upstream pycurl doesn't support py3, but ubuntu 12.10 includes
            # a fork/port.  That version has a bug in which it passes unicode
            # strings instead of bytes to the WRITEFUNCTION.  This means that
            # if you use a WRITEFUNCTION (which tornado always does), you cannot
            # download arbitrary binary data.  This needs to be fixed in the
            # ported pycurl package, but in the meantime this lambda will
            # make it work for downloading (utf8) text.
            curl.setopt(pycurl.WRITEFUNCTION, lambda s: write_function(utf8(s)))
        curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
        curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
        curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
        curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
        if request.user_agent:
            curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
        else:
            curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
        if request.network_interface:
            curl.setopt(pycurl.INTERFACE, request.network_interface)
        if request.decompress_response:
            curl.setopt(pycurl.ENCODING, "gzip,deflate")
        else:
            curl.setopt(pycurl.ENCODING, "none")
        if request.proxy_host and request.proxy_port:
            curl.setopt(pycurl.PROXY, request.proxy_host)
            curl.setopt(pycurl.PROXYPORT, request.proxy_port)
            if request.proxy_username:
                credentials = '%s:%s' % (request.proxy_username,
                                         request.proxy_password)
                curl.setopt(pycurl.PROXYUSERPWD, credentials)

            if (request.proxy_auth_mode is None or
                    request.proxy_auth_mode == "basic"):
                curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
            elif request.proxy_auth_mode == "digest":
                curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
            else:
                raise ValueError(
                    "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode)
        else:
            curl.setopt(pycurl.PROXY, '')
            curl.unsetopt(pycurl.PROXYUSERPWD)
        if request.validate_cert:
            curl.setopt(pycurl.SSL_VERIFYPEER, 1)
            curl.setopt(pycurl.SSL_VERIFYHOST, 2)
        else:
            curl.setopt(pycurl.SSL_VERIFYPEER, 0)
            curl.setopt(pycurl.SSL_VERIFYHOST, 0)
        if request.ca_certs is not None:
            curl.setopt(pycurl.CAINFO, request.ca_certs)
        else:
            # There is no way to restore pycurl.CAINFO to its default value
            # (Using unsetopt makes it reject all certificates).
            # I don't see any way to read the default value from python so it
            # can be restored later.  We'll have to just leave CAINFO untouched
            # if no ca_certs file was specified, and require that if any
            # request uses a custom ca_certs file, they all must.
            pass

        if request.allow_ipv6 is False:
            # Curl behaves reasonably when DNS resolution gives an ipv6 address
            # that we can't reach, so allow ipv6 unless the user asks to disable.
            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
        else:
            curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)

        # Set the request method through curl's irritating interface which makes
        # up names for almost every single method
        curl_options = {
            "GET": pycurl.HTTPGET,
            "POST": pycurl.POST,
            "PUT": pycurl.UPLOAD,
            "HEAD": pycurl.NOBODY,
        }
        custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
        for o in curl_options.values():
            curl.setopt(o, False)
        if request.method in curl_options:
            curl.unsetopt(pycurl.CUSTOMREQUEST)
            curl.setopt(curl_options[request.method], True)
        elif request.allow_nonstandard_methods or request.method in custom_methods:
            curl.setopt(pycurl.CUSTOMREQUEST, request.method)
        else:
            raise KeyError('unknown method ' + request.method)

        body_expected = request.method in ("POST", "PATCH", "PUT")
        body_present = request.body is not None
        if not request.allow_nonstandard_methods:
            # Some HTTP methods nearly always have bodies while others
            # almost never do. Fail in this case unless the user has
            # opted out of sanity checks with allow_nonstandard_methods.
            if ((body_expected and not body_present) or
                    (body_present and not body_expected)):
                raise ValueError(
                    'Body must %sbe None for method %s (unless '
                    'allow_nonstandard_methods is true)' %
                    ('not ' if body_expected else '', request.method))

        if body_expected or body_present:
            if request.method == "GET":
                # Even with `allow_nonstandard_methods` we disallow
                # GET with a body (because libcurl doesn't allow it
                # unless we use CUSTOMREQUEST). While the spec doesn't
                # forbid clients from sending a body, it arguably
                # disallows the server from doing anything with them.
                raise ValueError('Body must be None for GET request')
            request_buffer = BytesIO(utf8(request.body or ''))

            def ioctl(cmd):
                if cmd == curl.IOCMD_RESTARTREAD:
                    request_buffer.seek(0)
            curl.setopt(pycurl.READFUNCTION, request_buffer.read)
            curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
            if request.method == "POST":
                curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ''))
            else:
                curl.setopt(pycurl.UPLOAD, True)
                curl.setopt(pycurl.INFILESIZE, len(request.body or ''))

        if request.auth_username is not None:
            userpwd = "%s:%s" % (request.auth_username, request.auth_password or '')

            if request.auth_mode is None or request.auth_mode == "basic":
                curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
            elif request.auth_mode == "digest":
                curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
            else:
                raise ValueError("Unsupported auth_mode %s" % request.auth_mode)

            curl.setopt(pycurl.USERPWD, native_str(userpwd))
            curl_log.debug("%s %s (username: %r)", request.method, request.url,
                           request.auth_username)
        else:
            curl.unsetopt(pycurl.USERPWD)
            curl_log.debug("%s %s", request.method, request.url)

        if request.client_cert is not None:
            curl.setopt(pycurl.SSLCERT, request.client_cert)

        if request.client_key is not None:
            curl.setopt(pycurl.SSLKEY, request.client_key)

        if request.ssl_options is not None:
            raise ValueError("ssl_options not supported in curl_httpclient")

        if threading.activeCount() > 1:
            # libcurl/pycurl is not thread-safe by default.  When multiple threads
            # are used, signals should be disabled.  This has the side effect
            # of disabling DNS timeouts in some environments (when libcurl is
            # not linked against ares), so we don't do it when there is only one
            # thread.  Applications that use many short-lived threads may need
            # to set NOSIGNAL manually in a prepare_curl_callback since
            # there may not be any other threads running at the time we call
            # threading.activeCount.
            curl.setopt(pycurl.NOSIGNAL, 1)
        if request.prepare_curl_callback is not None:
            request.prepare_curl_callback(curl)
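
The HEADERFUNCTION hook is the interesting partial here: libcurl invokes it with a single header line, but _curl_header_callback also needs the shared headers structure and the user's header_callback. partial pre-binds those two so the result has the one-argument shape pycurl expects. A sketch of the same adaptation outside pycurl (the callback below is illustrative, not Tornado's actual method):

import functools

def echo(line):
    # Stand-in for a user-supplied header callback.
    print(line)

def header_callback(headers, user_callback, header_line):
    # libcurl would pass only 'header_line'; the accumulating list and
    # the optional user hook are pre-bound via partial.
    headers.append(header_line)
    if user_callback is not None:
        user_callback(header_line)

headers = []
on_header = functools.partial(header_callback, headers, echo)
on_header("Content-Type: text/html\r\n")  # simulate libcurl firing the hook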

Example 48

Project: bakthat
Source File: __init__.py
View license
@app.cmd(help="Backup a file or a directory, backup the current directory if no arg is provided.")
@app.cmd_arg('filename', type=str, default=os.getcwd(), nargs="?")
@app.cmd_arg('-d', '--destination', type=str, help="s3|glacier|swift", default=None)
@app.cmd_arg('--prompt', type=str, help="yes|no", default="yes")
@app.cmd_arg('-t', '--tags', type=str, help="space separated tags", default="")
@app.cmd_arg('-p', '--profile', type=str, default="default", help="profile name (default by default)")
@app.cmd_arg('-c', '--config', type=str, default=CONFIG_FILE, help="path to config file")
@app.cmd_arg('-k', '--key', type=str, default=None, help="Custom key for periodic backups (works only with BakManager.io hook.)")
@app.cmd_arg('--exclude-file', type=str, default=None)
@app.cmd_arg('--s3-reduced-redundancy', action="store_true")
def backup(filename=os.getcwd(), destination=None, profile="default", config=CONFIG_FILE, prompt="yes", tags=[], key=None, exclude_file=None, s3_reduced_redundancy=False, **kwargs):
    """Perform backup.

    :type filename: str
    :param filename: File/directory to backup.

    :type destination: str
    :param destination: s3|glacier|swift

    :type prompt: str
    :param prompt: Disable the password prompt (and thus encryption);
        only useful when using bakthat in command-line mode.

    :type tags: str or list
    :param tags: Tags, either as a space-separated string,
        or directly as a list of str (if calling from Python).

    :type password: str
    :keyword password: Password, empty string to disable encryption.

    :type conf: dict
    :keyword conf: Override/set AWS configuration.

    :type custom_filename: str
    :keyword custom_filename: Override the original filename (only in metadata)

    :rtype: dict
    :return: A dict containing the following keys: stored_filename, size, metadata, backend and filename.

    """
    storage_backend, destination, conf = _get_store_backend(config, destination, profile)
    backup_file_fmt = "{0}.{1}.tgz"

    session_id = str(uuid.uuid4())
    events.before_backup(session_id)

    # Check if compression is disabled on the configuration.
    if conf:
        compress = conf.get("compress", True)
    else:
        compress = config.get(profile).get("compress", True)

    if not compress:
        backup_file_fmt = "{0}.{1}"

    log.info("Backing up " + filename)

    if exclude_file and os.path.isfile(exclude_file):
        EXCLUDE_FILES.insert(0, exclude_file)

    _exclude = lambda filename: False
    if os.path.isdir(filename):
        join = functools.partial(os.path.join, filename)
        for efile in EXCLUDE_FILES:
            efile = join(efile)
            if os.path.isfile(efile):
                _exclude = _get_exclude(efile)
                log.info("Using {0} to exclude files.".format(efile))

    arcname = filename.strip('/').split('/')[-1]
    now = datetime.utcnow()
    date_component = now.strftime("%Y%m%d%H%M%S")
    stored_filename = backup_file_fmt.format(arcname, date_component)

    backup_date = int(now.strftime("%s"))
    backup_data = dict(filename=kwargs.get("custom_filename", arcname),
                       backup_date=backup_date,
                       last_updated=backup_date,
                       backend=destination,
                       is_deleted=False)

    # Useful only when using bakmanager.io hook
    backup_key = key

    password = kwargs.get("password", os.environ.get("BAKTHAT_PASSWORD"))
    if password is None and prompt.lower() != "no":
        password = getpass("Password (blank to disable encryption): ")
        if password:
            password2 = getpass("Password confirmation: ")
            if password != password2:
                log.error("Password confirmation doesn't match")
                return

    if not compress:
        log.info("Compression disabled")
        outname = filename
        with open(outname) as outfile:
            backup_data["size"] = os.fstat(outfile.fileno()).st_size
        bakthat_compression = False

    # Check if the file is not already compressed
    elif mimetypes.guess_type(arcname) == ('application/x-tar', 'gzip'):
        log.info("File already compressed")
        outname = filename

        # removing extension to reformat filename
        new_arcname = re.sub(r'(\.t(ar\.)?gz)', '', arcname)
        stored_filename = backup_file_fmt.format(new_arcname, date_component)

        with open(outname) as outfile:
            backup_data["size"] = os.fstat(outfile.fileno()).st_size

        bakthat_compression = False
    else:
        # If not we compress it
        log.info("Compressing...")

        with tempfile.NamedTemporaryFile(delete=False) as out:
            with closing(tarfile.open(fileobj=out, mode="w:gz")) as tar:
                tar.add(filename, arcname=arcname, exclude=_exclude)
            outname = out.name
            out.seek(0)
            backup_data["size"] = os.fstat(out.fileno()).st_size
        bakthat_compression = True

    bakthat_encryption = False
    if password:
        bakthat_encryption = True
        log.info("Encrypting...")
        encrypted_out = tempfile.NamedTemporaryFile(delete=False)
        encrypt_file(outname, encrypted_out.name, password)
        stored_filename += ".enc"

        # We only remove the file if the archive is created by bakthat
        if bakthat_compression:
            os.remove(outname)  # remove non-encrypted tmp file

        outname = encrypted_out.name

        encrypted_out.seek(0)
        backup_data["size"] = os.fstat(encrypted_out.fileno()).st_size

    # Handling tags metadata
    if isinstance(tags, list):
        tags = " ".join(tags)

    backup_data["tags"] = tags

    backup_data["metadata"] = dict(is_enc=bakthat_encryption,
                                   client=socket.gethostname())
    backup_data["stored_filename"] = stored_filename

    access_key = storage_backend.conf.get("access_key")
    container_key = storage_backend.conf.get(storage_backend.container_key)
    backup_data["backend_hash"] = hashlib.sha512(access_key + container_key).hexdigest()

    log.info("Uploading...")
    storage_backend.upload(stored_filename, outname, s3_reduced_redundancy=s3_reduced_redundancy)

    # We only remove the file if the archive is created by bakthat
    if bakthat_compression or bakthat_encryption:
        os.remove(outname)

    log.debug(backup_data)

    # Insert backup metadata in SQLite
    backup = Backups.create(**backup_data)

    BakSyncer(conf).sync_auto()

    # bakmanager.io hook, enable with the -k/--key parameter
    if backup_key:
        bakmanager_hook(conf, backup_data, backup_key)

    events.on_backup(session_id, backup)

    return backup
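
The one-liner worth remembering from this example is "join = functools.partial(os.path.join, filename)": it pre-binds the backup root so every exclude file resolves against it without repeating the base path. In isolation (the directory name below is made up):

import functools
import os.path

# Pre-bind the base directory; each call supplies only the file name.
join = functools.partial(os.path.join, "/var/backups/project")

for name in (".gitignore", ".bakthatexclude"):
    print(join(name))
# /var/backups/project/.gitignore
# /var/backups/project/.bakthatexclude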

Example 49

Project: hustle
Source File: pipeline.py
View license
    def __init__(self,
                 master,
                 wheres,
                 project=(),
                 order_by=(),
                 join=(),
                 full_join=False,
                 distinct=False,
                 desc=False,
                 limit=0,
                 partition=0,
                 nest=False,
                 wide=False,
                 pre_order_stage=(),
                 tag=None,
                 max_cores=0,
                 profile=False):
        from hustle.core.pipeworker import Worker

        super(SelectPipe, self).__init__(master=master, worker=Worker())
        if max_cores < 0:
            max_cores = 0
        if max_cores > 0:
            self.scheduler = {"max_cores": max_cores}
        self.profile = profile
        self.wheres = wheres
        self.order_by = self._resolve(order_by, project)
        partition = partition or _NPART
        binaries = [i for i, c in enumerate(project)
                    if isinstance(c, (Column, Aggregation)) and c.is_binary]
        # if nest is true, use output_schema to store the output table
        self.output_table = None

        # aggregation functions and their defaults
        efs, gees, ehches, dflts = zip(*[(c.f, c.g, c.h, c.default)
                                         if isinstance(c, Aggregation)
                                         else (dflt_f, dflt_gh, dflt_gh, dflt_default)
                                         for c in project])
        need_agg = False  # whether we need to perform aggregation
        all_agg = True    # whether all columns in select are aggregates
        for c in project:
            if isinstance(c, Aggregation):
                need_agg = True
            else:
                all_agg = False
        if all_agg:
            _agg_fn = _aggregate_fast
        else:
            _agg_fn = _aggregate

        # build the pipeline
        select_hash_cols = ()
        sort_range = _get_sort_range(0, project, self.order_by)

        join_stage = []
        if join or full_join:
            joinbins = [i + 2 for i in binaries]
            join_stage = [
                (GROUP_LABEL,
                 HustleStage('join',
                             sort=(1, 0),
                             binaries=joinbins,
                             process=partial(process_join,
                                             full_join=full_join,
                                             ffuncs=efs,
                                             ghfuncs=ehches,
                                             deffuncs=dflts,
                                             wide=wide,
                                             need_agg=need_agg,
                                             agg_fn=_agg_fn,
                                             label_fn=partial(_tuple_hash,
                                                              cols=sort_range,
                                                              p=partition))))]
            select_hash_cols = (1,)

        group_by_stage = []
        if need_agg:
            # If all columns in project are aggregations, use process_skip_group
            # to skip the internal groupby
            if all_agg:
                process_group_fn = process_skip_group
                group_by_range = []
            else:
                process_group_fn = process_group
                group_by_range = [i for i, c in enumerate(project)
                                  if isinstance(c, Column)]

            # build the pipeline
            group_by_stage = []
            if wide:
                group_by_stage = [
                    (GROUP_LABEL_NODE,
                     HustleStage('group-combine',
                                 sort=group_by_range,
                                 binaries=binaries,
                                 process=partial(process_group_fn,
                                                 ffuncs=efs,
                                                 ghfuncs=ehches,
                                                 deffuncs=dflts,
                                                 label_fn=partial(_tuple_hash,
                                                                  cols=group_by_range,
                                                                  p=partition))))]
            # A hack here that overrides the disco stage's default 'combine'
            # option. Hustle needs all inputs with the same label combined.
            group_by_stage.append((GROUP_LABEL,
                                   HustleStage('group-reduce',
                                               combine=True,
                                               input_sorted=wide,
                                               sort=group_by_range,
                                               binaries=binaries,
                                               process=partial(process_group_fn,
                                                               ffuncs=efs,
                                                               ghfuncs=gees,
                                                               deffuncs=dflts))))

        # process the order_by/distinct stage
        order_stage = []
        if self.order_by or distinct or limit:
            order_stage = [
                (GROUP_LABEL_NODE,
                 HustleStage('order-combine',
                             sort=sort_range,
                             binaries=binaries,
                             desc=desc,
                             process=partial(process_order,
                                             distinct=distinct,
                                             limit=limit or sys.maxint))),
                (GROUP_ALL,
                 HustleStage('order-reduce',
                             sort=sort_range,
                             desc=desc,
                             input_sorted=True,
                             combine_labels=True,
                             process=partial(process_order,
                                             distinct=distinct,
                                             limit=limit or sys.maxint))),
            ]

        if not select_hash_cols:
            select_hash_cols = sort_range

        key_names = self._get_key_names(project, join)

        restrict_distinct = False
        restrict_limit = 0
        # check whether we need to do a distinct/limit in the restrict-select stage
        if not (join or full_join) and not need_agg and not self.order_by and distinct:
            restrict_distinct = True
            restrict_limit = limit or 0
        # check whether we need to do a limit in the hustle_input stream
        input_stream_limit = 0
        if not (join or full_join) and not need_agg and not self.order_by and not distinct:
            input_stream_limit = limit or 0
        pipeline = [(SPLIT,
                     HustleStage('restrict-select',
                                 # combine=True,  # cannot set combine -- see #hack in restrict-select phase
                                 process=partial(process_restrict,
                                                 ffuncs=efs,
                                                 ghfuncs=ehches,
                                                 deffuncs=dflts,
                                                 wide=wide or join or full_join,
                                                 need_agg=need_agg,
                                                 agg_fn=_agg_fn,
                                                 distinct=restrict_distinct,
                                                 limit=restrict_limit or sys.maxint,
                                                 label_fn=partial(_tuple_hash,
                                                                  cols=select_hash_cols,
                                                                  p=partition)),
                                 input_chain=[task_input_stream,
                                              partial(hustle_input_stream,
                                                      wheres=wheres,
                                                      gen_where_index=join or full_join,
                                                      key_names=key_names,
                                                      limit=input_stream_limit or sys.maxint)]))
                    ] + join_stage + group_by_stage + list(pre_order_stage) + order_stage

        # determine the style of output (ie. if it is a Hustle Table),
        # and modify the last stage accordingly
        if nest:
            pipeline[-1][1].output_chain = \
                [partial(hustle_output_stream, result_table=self.get_result_schema(project, tag))]
        self.pipeline = pipeline
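
The pipeline above leans on nested partials: an inner partial(_tuple_hash, cols=..., p=...) fixes the hashing parameters, and the outer partial(process_*, ..., label_fn=...) turns each processing function into a fully configured stage. A toy version of that layering (all names below are invented, not Hustle's):

from functools import partial

def tuple_hash(record, cols, p):
    # Hash the selected columns of a record into one of p partitions.
    return hash(tuple(record[c] for c in cols)) % p

def process_stage(records, label_fn):
    # A toy stage: tag each record with the partition label_fn assigns.
    return [(label_fn(rec), rec) for rec in records]

# The inner partial freezes the hash parameters; the outer partial yields
# a one-argument stage ready to drop into a pipeline.
stage = partial(process_stage, label_fn=partial(tuple_hash, cols=(0,), p=4))
print(stage([("a", 1), ("b", 2)]))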

Example 50

Project: labelImg
Source File: labelImg.py
View license
    def __init__(self, filename=None):
        super(MainWindow, self).__init__()
        self.setWindowTitle(__appname__)
        # Save as Pascal voc xml
        self.defaultSaveDir = None
        self.usingPascalVocFormat = True
        if self.usingPascalVocFormat:
            LabelFile.suffix = '.xml'
        # For loading all images under a directory
        self.mImgList = []
        self.dirname = None
        self.labelHist = []
        self.lastOpenDir = None

        # Whether we need to save or not.
        self.dirty = False

        # Enable auto-saving when pressing next
        self.autoSaving = True
        self._noSelectionSlot = False
        self._beginner = True
        self.screencastViewer = "firefox"
        self.screencast = "https://youtu.be/p0nR2YsCY_U"

        self.loadPredefinedClasses()
        # Main widgets and related state.
        self.labelDialog = LabelDialog(parent=self, listItem=self.labelHist)
        self.labelList = QListWidget()
        self.itemsToShapes = {}
        self.shapesToItems = {}

        self.labelList.itemActivated.connect(self.labelSelectionChanged)
        self.labelList.itemSelectionChanged.connect(self.labelSelectionChanged)
        self.labelList.itemDoubleClicked.connect(self.editLabel)
        # Connect to itemChanged to detect checkbox changes.
        self.labelList.itemChanged.connect(self.labelItemChanged)

        listLayout = QVBoxLayout()
        listLayout.setContentsMargins(0, 0, 0, 0)
        listLayout.addWidget(self.labelList)
        self.editButton = QToolButton()
        self.editButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.labelListContainer = QWidget()
        self.labelListContainer.setLayout(listLayout)
        listLayout.addWidget(self.editButton)
        listLayout.addWidget(self.labelList)

        self.dock = QDockWidget(u'Box Labels', self)
        self.dock.setObjectName(u'Labels')
        self.dock.setWidget(self.labelListContainer)

        # Tzutalin 20160906 : Add file list and dock to move faster
        self.fileListWidget = QListWidget()
        self.fileListWidget.itemDoubleClicked.connect(self.fileitemDoubleClicked)
        filelistLayout = QVBoxLayout()
        filelistLayout.setContentsMargins(0, 0, 0, 0)
        filelistLayout.addWidget(self.fileListWidget)
        self.fileListContainer = QWidget()
        self.fileListContainer.setLayout(filelistLayout)
        self.filedock = QDockWidget(u'File List', self)
        self.filedock.setObjectName(u'Files')
        self.filedock.setWidget(self.fileListContainer)

        self.zoomWidget = ZoomWidget()
        self.colorDialog = ColorDialog(parent=self)

        self.canvas = Canvas()
        self.canvas.zoomRequest.connect(self.zoomRequest)

        scroll = QScrollArea()
        scroll.setWidget(self.canvas)
        scroll.setWidgetResizable(True)
        self.scrollBars = {
            Qt.Vertical: scroll.verticalScrollBar(),
            Qt.Horizontal: scroll.horizontalScrollBar()
            }
        self.canvas.scrollRequest.connect(self.scrollRequest)

        self.canvas.newShape.connect(self.newShape)
        self.canvas.shapeMoved.connect(self.setDirty)
        self.canvas.selectionChanged.connect(self.shapeSelectionChanged)
        self.canvas.drawingPolygon.connect(self.toggleDrawingSensitive)

        self.setCentralWidget(scroll)
        self.addDockWidget(Qt.RightDockWidgetArea, self.dock)
        # Tzutalin 20160906 : Add file list and dock to move faster
        self.addDockWidget(Qt.RightDockWidgetArea, self.filedock)
        self.dockFeatures = QDockWidget.DockWidgetClosable\
                          | QDockWidget.DockWidgetFloatable
        self.dock.setFeatures(self.dock.features() ^ self.dockFeatures)

        # Actions
        action = partial(newAction, self)
        quit = action('&Quit', self.close,
                'Ctrl+Q', 'quit', u'Quit application')
        open = action('&Open', self.openFile,
                'Ctrl+O', 'open', u'Open image or label file')

        opendir = action('&Open Dir', self.openDir,
                'Ctrl+u', 'open', u'Open Dir')

        changeSavedir = action('&Change default saved Annotation dir', self.changeSavedir,
                'Ctrl+r', 'open', u'Change default saved Annotation dir')

        openAnnotation = action('&Open Annotation', self.openAnnotation,
                'Ctrl+q', 'openAnnotation', u'Open Annotation')

        openNextImg = action('&Next Image', self.openNextImg,
                'n', 'next', u'Open Next')

        openPrevImg = action('&Prev Image', self.openPrevImg,
                'p', 'prev', u'Open Prev')

        save = action('&Save', self.saveFile,
                'Ctrl+S', 'save', u'Save labels to file', enabled=False)
        saveAs = action('&Save As', self.saveFileAs,
                'Ctrl+Shift+S', 'save-as', u'Save labels to a different file',
                enabled=False)
        close = action('&Close', self.closeFile,
                'Ctrl+W', 'close', u'Close current file')
        color1 = action('Box &Line Color', self.chooseColor1,
                'Ctrl+L', 'color_line', u'Choose Box line color')
        color2 = action('Box &Fill Color', self.chooseColor2,
                'Ctrl+Shift+L', 'color', u'Choose Box fill color')

        createMode = action('Create\nRectBox', self.setCreateMode,
                'Ctrl+N', 'new', u'Start drawing Boxes', enabled=False)
        editMode = action('&Edit\nRectBox', self.setEditMode,
                'Ctrl+J', 'edit', u'Move and edit Boxes', enabled=False)

        create = action('Create\nRectBox', self.createShape,
                'Ctrl+N', 'new', u'Draw a new Box', enabled=False)
        delete = action('Delete\nRectBox', self.deleteSelectedShape,
                'Delete', 'delete', u'Delete', enabled=False)
        copy = action('&Duplicate\nRectBox', self.copySelectedShape,
                'Ctrl+D', 'copy', u'Create a duplicate of the selected Box',
                enabled=False)

        advancedMode = action('&Advanced Mode', self.toggleAdvancedMode,
                'Ctrl+Shift+A', 'expert', u'Switch to advanced mode',
                checkable=True)

        hideAll = action('&Hide\nRectBox', partial(self.togglePolygons, False),
                'Ctrl+H', 'hide', u'Hide all Boxes',
                enabled=False)
        showAll = action('&Show\nRectBox', partial(self.togglePolygons, True),
                'Ctrl+A', 'hide', u'Show all Boxes',
                enabled=False)

        help = action('&Tutorial', self.tutorial, 'Ctrl+T', 'help',
                u'Show demos')

        zoom = QWidgetAction(self)
        zoom.setDefaultWidget(self.zoomWidget)
        self.zoomWidget.setWhatsThis(
            u"Zoom in or out of the image. Also accessible with"\
             " %s and %s from the canvas." % (fmtShortcut("Ctrl+[-+]"),
                 fmtShortcut("Ctrl+Wheel")))
        self.zoomWidget.setEnabled(False)

        zoomIn = action('Zoom &In', partial(self.addZoom, 10),
                'Ctrl++', 'zoom-in', u'Increase zoom level', enabled=False)
        zoomOut = action('&Zoom Out', partial(self.addZoom, -10),
                'Ctrl+-', 'zoom-out', u'Decrease zoom level', enabled=False)
        zoomOrg = action('&Original size', partial(self.setZoom, 100),
                'Ctrl+=', 'zoom', u'Zoom to original size', enabled=False)
        fitWindow = action('&Fit Window', self.setFitWindow,
                'Ctrl+F', 'fit-window', u'Zoom follows window size',
                checkable=True, enabled=False)
        fitWidth = action('Fit &Width', self.setFitWidth,
                'Ctrl+Shift+F', 'fit-width', u'Zoom follows window width',
                checkable=True, enabled=False)
        # Group zoom controls into a list for easier toggling.
        zoomActions = (self.zoomWidget, zoomIn, zoomOut, zoomOrg, fitWindow, fitWidth)
        self.zoomMode = self.MANUAL_ZOOM
        self.scalers = {
            self.FIT_WINDOW: self.scaleFitWindow,
            self.FIT_WIDTH: self.scaleFitWidth,
            # Set to one to scale to 100% when loading files.
            self.MANUAL_ZOOM: lambda: 1,
        }

        edit = action('&Edit Label', self.editLabel,
                'Ctrl+E', 'edit', u'Modify the label of the selected Box',
                enabled=False)
        self.editButton.setDefaultAction(edit)

        shapeLineColor = action('Shape &Line Color', self.chshapeLineColor,
                icon='color_line', tip=u'Change the line color for this specific shape',
                enabled=False)
        shapeFillColor = action('Shape &Fill Color', self.chshapeFillColor,
                icon='color', tip=u'Change the fill color for this specific shape',
                enabled=False)

        labels = self.dock.toggleViewAction()
        labels.setText('Show/Hide Label Panel')
        labels.setShortcut('Ctrl+Shift+L')

        # Label list context menu.
        labelMenu = QMenu()
        addActions(labelMenu, (edit, delete))
        self.labelList.setContextMenuPolicy(Qt.CustomContextMenu)
        self.labelList.customContextMenuRequested.connect(self.popLabelListMenu)

        # Store actions for further handling.
        self.actions = struct(save=save, saveAs=saveAs, open=open, close=close,
                lineColor=color1, fillColor=color2,
                create=create, delete=delete, edit=edit, copy=copy,
                createMode=createMode, editMode=editMode, advancedMode=advancedMode,
                shapeLineColor=shapeLineColor, shapeFillColor=shapeFillColor,
                zoom=zoom, zoomIn=zoomIn, zoomOut=zoomOut, zoomOrg=zoomOrg,
                fitWindow=fitWindow, fitWidth=fitWidth,
                zoomActions=zoomActions,
                fileMenuActions=(open, opendir, save, saveAs, close, quit),
                beginner=(), advanced=(),
                editMenu=(edit, copy, delete, None, color1, color2),
                beginnerContext=(create, edit, copy, delete),
                advancedContext=(createMode, editMode, edit, copy,
                    delete, shapeLineColor, shapeFillColor),
                onLoadActive=(close, create, createMode, editMode),
                onShapesPresent=(saveAs, hideAll, showAll))

        self.menus = struct(
                file=self.menu('&File'),
                edit=self.menu('&Edit'),
                view=self.menu('&View'),
                help=self.menu('&Help'),
                recentFiles=QMenu('Open &Recent'),
                labelList=labelMenu)

        addActions(self.menus.file,
                (open, opendir, changeSavedir, openAnnotation, self.menus.recentFiles, save, saveAs, close, None, quit))
        addActions(self.menus.help, (help,))
        addActions(self.menus.view, (
            labels, advancedMode, None,
            hideAll, showAll, None,
            zoomIn, zoomOut, zoomOrg, None,
            fitWindow, fitWidth))

        self.menus.file.aboutToShow.connect(self.updateFileMenu)

        # Custom context menu for the canvas widget:
        addActions(self.canvas.menus[0], self.actions.beginnerContext)
        addActions(self.canvas.menus[1], (
            action('&Copy here', self.copyShape),
            action('&Move here', self.moveShape)))

        self.tools = self.toolbar('Tools')
        self.actions.beginner = (
            open, opendir, openNextImg, openPrevImg, save, None, create, copy, delete, None,
            zoomIn, zoom, zoomOut, fitWindow, fitWidth)

        self.actions.advanced = (
            open, save, None,
            createMode, editMode, None,
            hideAll, showAll)

        self.statusBar().showMessage('%s started.' % __appname__)
        self.statusBar().show()

        # Application state.
        self.image = QImage()
        self.filename = filename
        self.recentFiles = []
        self.maxRecent = 7
        self.lineColor = None
        self.fillColor = None
        self.zoom_level = 100
        self.fit_window = False

        # XXX: Could be completely declarative.
        # Restore application settings.
        types = {
            'filename': QString,
            'recentFiles': QStringList,
            'window/size': QSize,
            'window/position': QPoint,
            'window/geometry': QByteArray,
            # Docks and toolbars:
            'window/state': QByteArray,
            'savedir': QString,
            'lastOpenDir': QString,
        }
        self.settings = settings = Settings(types)
        self.recentFiles = list(settings['recentFiles'])
        size = settings.get('window/size', QSize(600, 500))
        position = settings.get('window/position', QPoint(0, 0))
        self.resize(size)
        self.move(position)
        saveDir = settings.get('savedir', None)
        self.lastOpenDir = settings.get('lastOpenDir', None)
        if os.path.exists(unicode(saveDir)):
            self.defaultSaveDir = unicode(saveDir)
            self.statusBar().showMessage('%s started. Annotation will be saved to %s' %(__appname__, self.defaultSaveDir))
            self.statusBar().show()

        # or simply:
        # self.restoreGeometry(settings['window/geometry'])
        self.restoreState(settings['window/state'])
        self.lineColor = QColor(settings.get('line/color', Shape.line_color))
        self.fillColor = QColor(settings.get('fill/color', Shape.fill_color))
        Shape.line_color = self.lineColor
        Shape.fill_color = self.fillColor

        if settings.get('advanced', QVariant()).toBool():
            self.actions.advancedMode.setChecked(True)
            self.toggleAdvancedMode()

        # Populate the File menu dynamically.
        self.updateFileMenu()
        # Since loading the file may take some time, make sure it runs in the background.
        self.queueEvent(partial(self.loadFile, self.filename))

        # Callbacks:
        self.zoomWidget.valueChanged.connect(self.paintCanvas)

        self.populateModeActions()
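
Two partial idioms recur throughout this constructor: "action = partial(newAction, self)" pre-binds the parent widget so the dozens of action definitions that follow stay compact, and calls like partial(self.togglePolygons, False) or partial(self.addZoom, 10) specialize one handler into several menu entries. A minimal sketch of the first idiom (new_action below is a stand-in with an assumed signature, not labelImg's helper):

from functools import partial

def new_action(parent, text, slot, shortcut=None):
    # Stand-in: the real helper would build and return a QAction.
    return {"parent": parent, "text": text, "slot": slot,
            "shortcut": shortcut}

class Window(object):
    def close(self):
        print("closing")

win = Window()
# Bind the parent once; every later definition omits it.
action = partial(new_action, win)
quit_action = action("&Quit", win.close, shortcut="Ctrl+Q")
quit_action["slot"]()   # -> closing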