copy.copy

Here are examples of the Python API copy.copy, taken from open source projects.
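
Before the project examples, here is a minimal illustrative sketch (not taken from any of the projects listed below) of what copy.copy does: it returns a shallow copy, so the outer object is duplicated while any nested objects are shared between the original and the copy. For fully independent nested data, copy.deepcopy is the usual alternative.

import copy

original = {"tags": ["a", "b"], "count": 1}
clone = copy.copy(original)

clone["count"] = 2            # rebinding a top-level key only affects the copy
clone["tags"].append("c")     # the nested list is shared, so both objects see "c"

print(original["count"])      # 1
print(original["tags"])       # ['a', 'b', 'c']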

200 Examples

Example 1

Project: pgmapcss
Source File: stat.py
    def property_values(self, prop, pseudo_element=None, include_illegal_values=False, value_type=None, eval_true=True, max_prop_id=None, include_none=False, object_type=None, postprocess=True, warn_unresolvable=False, assignment_type='P'):
        """Returns set of all values used on this property in any statement.
        Returns boolean 'True' if the property is the result of an unresolvable eval
        expression.

        Parameters:
        pseudo_element: limit returned values to given pseudo_element (default: None which means all)
        include_illegal_values: If True all values as given in MapCSS are returned, if False the list is sanitized (see @values). (Forces include_none=True) Default: False
        include_none: If True, None is a possible value (if used). Default: False
        value_type: Only values with value_type will be returned. Default None (all)
        eval_true: Return 'True' for values which result of an unresolvable eval expression. Otherwise this value will be removed. Default: True.
        max_prop_id: evaluate only properties with an id <= max_prop_id
        object_type: return values only for given object type (e.g. 'canvas')
        postprocess: include values derived from postprocessing
        warn_unresolvable: warn, if the property might be unresolvable
        assignment_type: type of assignment, 'P' for property (default), 'T' for tag, ... (see parser/parse_properties.py for a full list)
        """
        # Don't need 'eval_true' and 'warn_unresolvable' in cache_id, will be handled specially
        cache_id = prop + '-' + repr(pseudo_element) + '-' + repr(include_illegal_values) + '-' + repr(value_type) + '-' + repr(max_prop_id) + '-' + repr(include_none) + '-' + repr(object_type) + '-' + repr(postprocess)

        # Check if values are already calculated in cache
        if cache_id in self.property_values_cache:
            values = self.property_values_cache[cache_id]

            if warn_unresolvable and True in values:
                if not 'unresolvable_properties' in self:
                    self['unresolvable_properties'] = set()
                self['unresolvable_properties'].add(prop)

            if not eval_true and True in values:
                values = copy.copy(values)
                values.remove(True)

            return values

        prop_type = pgmapcss.types.get(prop, self)

        # Go over all statements and their properties and collect their values. If
        # include_illegal_values==False, sanitize the list. Do not include eval
        # statements.
        values = {
            p['value'] if include_illegal_values else prop_type.stat_value(p)
            for v in self['statements']
            for p in v['properties']
            if object_type is None or v['selector']['type'] == object_type
            if pseudo_element == None or v['selector']['pseudo_element'] in ('*', pseudo_element)
            if p['assignment_type'] == assignment_type and p['key'] == prop
            if value_type == None or value_type == p['value_type']
            if p['value_type'] != 'eval'
            if max_prop_id is None or p['id'] <= max_prop_id
        }

        # resolve eval functions (as far as possible) - also sanitize list.
        if True:
            values = values.union({
                v1 if v1 == True or include_illegal_values else prop_type.stat_value({
                    'value_type': 'eval',
                    'value': v1
                })
                for v in self['statements']
                for p in v['properties']
                if pseudo_element == None or v['selector']['pseudo_element'] in ('*', pseudo_element)
                if p['assignment_type'] == assignment_type and p['key'] == prop
                if p['value_type'] == 'eval'
                if max_prop_id is None or p['id'] <= max_prop_id
                for v1 in pgmapcss.eval.possible_values(p['value'], p, self)[0]
            })

        if 'default_value' in self['defines'] and prop in self['defines']['default_value']:
            v = self['defines']['default_value'][prop]['value']
            if include_illegal_values or v is not None:
                values.add(v)

        if 'default_other' in self['defines'] and prop in self['defines']['default_other']:
            other = self['defines']['default_other'][prop]['value']
            if other:
                other = self.property_values(other, pseudo_element, include_illegal_values, value_type, eval_true, max_prop_id)
                values = values.union(other)

        if 'generated_properties' in self and prop in self['generated_properties']:
            gen = self['generated_properties'][prop]
            combinations = self.properties_combinations(gen[0], pseudo_element, include_illegal_values, value_type, eval_true, max_prop_id, include_none, warn_unresolvable, assignment_type)
            values = values.union({
                gen[1](combination, self)
                for combination in combinations
            })

        if postprocess and 'postprocess' in self['defines'] and prop in self['defines']['postprocess'] and max_prop_id is None:
            p = copy.copy(self['defines']['postprocess'][prop])
            p['id'] = self['max_prop_id'] + 1
            for pe in ([pseudo_element] if pseudo_element else self['pseudo_elements']):
                p['statement'] = { 'selector': { 'pseudo_element': pe }}
                v = pgmapcss.eval.possible_values(p['value'], p, self)[0]
                values = values.union(v)

        if postprocess:
            v = prop_type.stat_postprocess(values, pseudo_element=pseudo_element)
            if v:
                values = v

        if include_illegal_values:
            self.property_values_cache[cache_id] = values
            return values

        if True in values:
            values.remove(True)
            values = values.union(prop_type.stat_all_values())

        if not include_none:
            values = {
                v
                for v in values
                if v != None and v != ''
            }

        if warn_unresolvable and True in values:
            if not 'unresolvable_properties' in self:
                self['unresolvable_properties'] = set()
            self['unresolvable_properties'].add(prop)

        self.property_values_cache[cache_id] = values

        if not eval_true and True in values:
            values = copy.copy(values)
            values.remove(True)

        return values

Example 2

Project: Pyfa
Source File: pygauge.py
    def OnPaint(self, event):
        """
        Handles the ``wx.EVT_PAINT`` event for L{PyGauge}.

        :param `event`: a `wx.PaintEvent` event to be processed.
        """

        dc = wx.BufferedPaintDC(self)
        rect = self.GetClientRect()

        dc.SetBackground(wx.Brush(self.GetBackgroundColour()))
        dc.Clear()

        colour = self.GetBackgroundColour()

        dc.SetBrush(wx.Brush(colour))
        dc.SetPen(wx.Pen(colour))

        dc.DrawRectangleRect(rect)

        value = self._percentage
        if self._timer:
            if self._timer.IsRunning():
                value = self._animValue

        if self._border_colour:
            dc.SetPen(wx.Pen(self.GetBorderColour()))
            dc.DrawRectangleRect(rect)
            pad = 1 + self.GetBorderPadding()
            rect.Deflate(pad,pad)

        if self.GetBarGradient():

            if value > 100:
                w = rect.width
            else:
                w = rect.width * (float(value) / 100)
            r = copy.copy(rect)
            r.width = w

            if r.width > 0:
                # If we draw it with zero width, GTK throws errors. This way,
                # only draw it if the gauge will actually show something.
                # We stick other calculations in this block to avoid wasting
                # time on them if not needed. See GH issue #282

                pv = value
                xv=1
                transition = 0

                if pv <= 100:
                    xv = pv/100
                    transition = 0

                elif pv <=101:
                    xv = pv -100
                    transition = 1

                elif pv <= 103:
                    xv = (pv -101)/2
                    transition = 2

                elif pv <= 105:
                    xv = (pv -103)/2
                    transition = 3

                else:
                    pv = 106
                    xv = pv -100
                    transition = -1

                if transition != -1:
                    colorS,colorE = self.transitionsColors[transition]
                    color = colorUtils.CalculateTransitionColor(colorS, colorE, xv)
                else:
                    color = wx.Colour(191,48,48)

                if self.gradientEffect > 0:
                    gcolor = colorUtils.BrightenColor(color,  float(self.gradientEffect) / 100)
                    gMid = colorUtils.BrightenColor(color,  float(self.gradientEffect/2) / 100)
                else:
                    gcolor = colorUtils.DarkenColor(color,  float(-self.gradientEffect) / 100)
                    gMid = colorUtils.DarkenColor(color,  float(-self.gradientEffect/2) / 100)

                gBmp = drawUtils.DrawGradientBar(r.width, r.height, gMid, color, gcolor)
                dc.DrawBitmap(gBmp, r.left, r.top)

        else:
            colour=self.GetBarColour()
            dc.SetBrush(wx.Brush(colour))
            dc.SetPen(wx.Pen(colour))
            if value > 100:
                w = rect.width
            else:
                w = rect.width * (float(value) / 100)
            r = copy.copy(rect)
            r.width = w
            dc.DrawRectangleRect(r)

        dc.SetFont(self.font)

        r = copy.copy(rect)
        r.left +=1
        r.top +=1
        if self._range == 0.01 and self._value > 0:
            formatStr =  u'\u221e'
            dc.SetTextForeground(wx.Colour(80,80,80))
            dc.DrawLabel(formatStr, r, wx.ALIGN_CENTER)

            dc.SetTextForeground(wx.Colour(255,255,255))
            dc.DrawLabel(formatStr, rect, wx.ALIGN_CENTER)
        else:
            if self.GetBarGradient() and self._showRemaining:
                range = self._range if self._range > 0.01 else 0
                value = range - self._value
                if value < 0:
                    label = "over"
                    value = -value
                else:
                    label = "left"
                formatStr = "{0:." + str(self._fractionDigits) + "f} " + label

            else:
                formatStr = "{0:." + str(self._fractionDigits) + "f}%"

            dc.SetTextForeground(wx.Colour(80,80,80))
            dc.DrawLabel(formatStr.format(value), r, wx.ALIGN_CENTER)

            dc.SetTextForeground(wx.Colour(255,255,255))
            dc.DrawLabel(formatStr.format(value), rect, wx.ALIGN_CENTER)

Example 3

Project: sncosmo
Source File: fitting.py
def nest_lc(data, model, vparam_names, bounds, guess_amplitude_bound=False,
            minsnr=5., priors=None, ppfs=None, npoints=100, method='single',
            maxiter=None, maxcall=None, modelcov=False, rstate=None,
            verbose=False, **kwargs):
    """Run nested sampling algorithm to estimate model parameters and evidence.

    Parameters
    ----------
    data : `~astropy.table.Table` or `~numpy.ndarray` or `dict`
        Table of photometric data. Must include certain columns.
        See the "Photometric Data" section of the documentation for
        required columns.
    model : `~sncosmo.Model`
        The model to fit.
    vparam_names : list
        Model parameters to vary in the fit.
    bounds : `dict`
        Bounded range for each parameter. Bounds must be given for
        each parameter, with the exception of ``t0``: by default, the
        minimum bound is such that the latest phase of the model lines
        up with the earliest data point and the maximum bound is such
        that the earliest phase of the model lines up with the latest
        data point.
    guess_amplitude_bound : bool, optional
        If true, bounds for the model's amplitude parameter are determined
        automatically based on the data and do not need to be included in
        `bounds`. The lower limit is set to zero and the upper limit is 10
        times the amplitude "guess" (which is based on the highest-flux
        data point in any band). Default is False.
    minsnr : float, optional
        Minimum signal-to-noise ratio of data points to use when guessing
        amplitude bound. Default is 5.
    priors : `dict`, optional
        Prior probability distribution function for each parameter. The keys
        should be parameter names and the values should be callables that
        accept a float. If a parameter is not in the dictionary, the prior
        defaults to a flat distribution between the bounds.
    ppfs : `dict`, optional
        Prior percent point function (inverse of the cumulative distribution
        function) for each parameter. If a parameter is in this dictionary,
        the ppf takes precedence over a prior pdf specified in ``priors``.
    npoints : int, optional
        Number of active samples to use. Increasing this value increases
        the accuracy (due to denser sampling) and also the time
        to solution.
    method : {'classic', 'single', 'multi'}, optional
        Method used to select new points. Choices are 'classic',
        single-ellipsoidal ('single'), multi-ellipsoidal ('multi'). Default
        is 'single'.
    maxiter : int, optional
        Maximum number of iterations. Iteration may stop earlier if
        termination condition is reached. Default is no limit.
    maxcall : int, optional
        Maximum number of likelihood evaluations. Iteration may stop earlier
        if termination condition is reached. Default is no limit.
    modelcov : bool, optional
        Include model covariance when calculating chisq. Default is False.
    rstate : `~numpy.random.RandomState`, optional
        RandomState instance. If not given, the global random state of the
        ``numpy.random`` module will be used.
    verbose : bool, optional
        Print running evidence sum on a single line.

    Returns
    -------
    res : Result
        Attributes are:

        * ``niter``: total number of iterations
        * ``ncall``: total number of likelihood function calls
        * ``time``: time in seconds spent in iteration loop.
        * ``logz``: natural log of the Bayesian evidence Z.
        * ``logzerr``: estimate of uncertainty in logz (due to finite sampling)
        * ``h``: Bayesian information.
        * ``vparam_names``: list of parameter names varied.
        * ``samples``: 2-d `~numpy.ndarray`, shape is (nsamples, nparameters).
          Each row is the parameter values for a single sample. For example,
          ``samples[0, :]`` is the parameter values for the first sample.
        * ``logprior``: 1-d `~numpy.ndarray` (length=nsamples);
          log(prior volume) for each sample.
        * ``logl``: 1-d `~numpy.ndarray` (length=nsamples); log(likelihood)
          for each sample.
        * ``weights``: 1-d `~numpy.ndarray` (length=nsamples);
          Weight corresponding to each sample. The weight is proportional to
          the prior * likelihood for the sample.
        * ``parameters``: 1-d `~numpy.ndarray` of weighted-mean parameter
          values from samples (including fixed parameters). Order corresponds
          to ``model.param_names``.
        * ``covariance``: 2-d `~numpy.ndarray` of parameter covariance;
          indices correspond to order of ``vparam_names``. Calculated from
          ``samples`` and ``weights``.
        * ``errors``: OrderedDict of varied parameter uncertainties.
          Corresponds to square root of diagonal entries in covariance matrix.
        * ``ndof``: Number of degrees of freedom (len(data) -
          len(vparam_names)).
        * ``bounds``: Dictionary of bounds on varied parameters (including
          any automatically determined bounds).

    estimated_model : `~sncosmo.Model`
        A copy of the model with parameters set to the values in
        ``res.parameters``.
    """

    try:
        import nestle
    except ImportError:
        raise ImportError("nest_lc() requires the nestle package.")

    if "nobj" in kwargs:
        warn("The nobj keyword is deprecated and will be removed in a future "
             "sncosmo release. Use `npoints` instead.")
        npoints = kwargs.pop("nobj")

    # experimental parameters
    tied = kwargs.get("tied", None)

    data = standardize_data(data)
    model = copy.copy(model)
    bounds = copy.copy(bounds)  # need to copy this b/c we modify it below

    # Order vparam_names the same way it is ordered in the model:
    vparam_names = [s for s in model.param_names if s in vparam_names]

    # Drop data that the model doesn't cover.
    data = cut_bands(data, model, z_bounds=bounds.get('z', None))

    if guess_amplitude_bound:
        if model.param_names[2] not in vparam_names:
            raise ValueError("Amplitude bounds guessing enabled but "
                             "amplitude parameter {0!r} is not varied"
                             .format(model.param_names[2]))
        if model.param_names[2] in bounds:
            raise ValueError("cannot supply bounds for parameter {0!r}"
                             " when guess_amplitude_bound=True"
                             .format(model.param_names[2]))

        # If redshift is bounded, set model redshift to midpoint of bounds
        # when doing the guess.
        if 'z' in bounds:
            model.set(z=sum(bounds['z']) / 2.)
        _, amplitude = guess_t0_and_amplitude(data, model, minsnr)
        bounds[model.param_names[2]] = (0., 10. * amplitude)

    # Find t0 bounds to use, if not explicitly given
    if 't0' in vparam_names and 't0' not in bounds:
        bounds['t0'] = t0_bounds(data, model)

    if ppfs is None:
        ppfs = {}
    if tied is None:
        tied = {}

    # Convert bounds/priors combinations into ppfs
    if bounds is not None:
        for key, val in six.iteritems(bounds):
            if key in ppfs:
                continue  # ppfs take priority over bounds/priors
            a, b = val
            if priors is not None and key in priors:
                # solve ppf at discrete points and return interpolating
                # function
                x_samples = np.linspace(0., 1., 101)
                ppf_samples = ppf(priors[key], x_samples, a, b)
                f = Interp1D(0., 1., ppf_samples)
            else:
                f = Interp1D(0., 1., np.array([a, b]))
            ppfs[key] = f

    # NOTE: It is important that iparam_names is in the same order
    # every time, otherwise results will not be reproducible, even
    # with same random seed.  This is because iparam_names[i] is
    # matched to u[i] below and u will be in a reproducible order,
    # so iparam_names must also be.
    iparam_names = [key for key in vparam_names if key in ppfs]
    ppflist = [ppfs[key] for key in iparam_names]
    npdim = len(iparam_names)  # length of u
    ndim = len(vparam_names)  # length of v

    # Check that all param_names either have a direct prior or are tied.
    for name in vparam_names:
        if name in iparam_names:
            continue
        if name in tied:
            continue
        raise ValueError("Must supply ppf or bounds or tied for parameter '{}'"
                         .format(name))

    def prior_transform(u):
        d = {}
        for i in range(npdim):
            d[iparam_names[i]] = ppflist[i](u[i])
        v = np.empty(ndim, dtype=np.float)
        for i in range(ndim):
            key = vparam_names[i]
            if key in d:
                v[i] = d[key]
            else:
                v[i] = tied[key](d)
        return v

    # Indices of the model parameters in vparam_names
    idx = np.array([model.param_names.index(name) for name in vparam_names])

    def loglike(parameters):
        model.parameters[idx] = parameters
        return -0.5 * _chisq(data, model, modelcov=modelcov)

    t0 = time.time()
    res = nestle.sample(loglike, prior_transform, ndim, npdim=npdim,
                        npoints=npoints, method=method, maxiter=maxiter,
                        maxcall=maxcall, rstate=rstate,
                        callback=(nestle.print_progress if verbose else None))
    elapsed = time.time() - t0

    # estimate parameters and covariance from samples
    vparameters, cov = nestle.mean_and_cov(res.samples, res.weights)

    # update model parameters to estimated ones.
    model.set(**dict(zip(vparam_names, vparameters)))

    # `res` is a nestle.Result object. Collect result into a sncosmo.Result
    # object for consistency, and add more fields.
    res = Result(niter=res.niter,
                 ncall=res.ncall,
                 logz=res.logz,
                 logzerr=res.logzerr,
                 h=res.h,
                 samples=res.samples,
                 weights=res.weights,
                 logvol=res.logvol,
                 logl=res.logl,
                 vparam_names=copy.copy(vparam_names),
                 ndof=len(data) - len(vparam_names),
                 bounds=bounds,
                 time=elapsed,
                 parameters=model.parameters.copy(),
                 covariance=cov,
                 errors=odict(zip(vparam_names, np.sqrt(np.diagonal(cov)))),
                 param_dict=odict(zip(model.param_names, model.parameters)))

    # Deprecated result fields.
    depmsg = ("The `param_names` attribute is deprecated in sncosmo v1.0 "
              "and will be removed in a future release. "
              "Use `vparam_names` instead.")
    res.__dict__['deprecated']['param_names'] = (res.vparam_names, depmsg)

    depmsg = ("The `logprior` attribute is deprecated in sncosmo v1.2 "
              "and will be changed in a future release. "
              "Use `logvol` instead.")
    res.__dict__['deprecated']['logprior'] = (res.logvol, depmsg)

    return res, model

Example 4

Project: spectralDNS
Source File: generate_xdmf.py
def generate_xdmf(h5filename):
    f = h5py.File(h5filename)
    comps = f["3D"].keys()
    for i in ("checkpoint", "oldcheckpoint", "mesh"):
        try:
            popped = comps.remove(i)
        except:
            pass
            
    N = f.attrs["N"]
    L = f.attrs["L"]
    if len(f["/".join(("3D", comps[0]))]) > 0:
        xf3d = copy.copy(xdmffile)
        timesteps = f["/".join(("3D", comps[0]))].keys()
        tt = ""
        for i in timesteps:
            tt += "%s " %i
        
        xf3d += timeattr.format(tt, len(timesteps))
        
        dtype = f["/".join(("3D", comps[0]))].values()[0].dtype
        prec = 4 if dtype == float32 else 8

        for tstep in timesteps:
            xf3d += """
      <Grid GridType="Uniform">"""
            
            if "mesh" in f["3D"].keys():
                xf3d += channel.format(prec, N[0], N[1], N[2], h5filename)
                xf3d += """
        <Topology Dimensions="{0} {1} {2}" Type="3DRectMesh"/>""".format(*N)
            else:
                xf3d += isotropic.format(L[0]/N[0], L[1]/N[1], L[2]/N[2])
                xf3d += """
        <Topology Dimensions="{0} {1} {2}" Type="3DCoRectMesh"/>""".format(*N)
        
            prec = 4 if dtype == float32 else 8
            for comp in comps:
                xf3d += attribute3D.format(comp, N[0], N[1], N[2], h5filename, comp, tstep, prec)
            xf3d += """  
      </Grid>
"""
        xf3d += """    
    </Grid>
  </Domain>
</Xdmf>  
"""
        #f.attrs.create("xdmf_3d", xf3d)
        xf = open(h5filename[:-2]+"xdmf", "w")
        xf.write(xf3d)
        xf.close()    
    
    # Return if no 2D data
    if (len(f["/".join(("2D", comps[0]))]) == 0 
        and len(f["/".join(("2D/yz", comps[0]))]) == 0
        and len(f["/".join(("2D/xz", comps[0]))]) == 0
        and len(f["/".join(("2D/xy", comps[0]))]) == 0):   
        return
    
    if len(f["/".join(("2D", comps[0]))]) > 0:
        xf2d = copy.copy(xdmffile)
        timesteps = f["/".join(("2D", comps[0]))].keys()
        dtype = f["/".join(("2D", comps[0]))].values()[0].dtype
        tt = ""
        for i in timesteps:
            tt += "%s " %i
        
        xf2d += timeattr.format(tt, len(timesteps))

        for tstep in timesteps:
            xf2d += """
      <Grid GridType="Uniform">
        <Geometry Type="ORIGIN_DXDY">
          <DataItem DataType="UInt" Dimensions="2" Format="XML" Precision="4">0 0</DataItem>
          <DataItem DataType="Float" Dimensions="2" Format="XML" Precision="4">{0} {1}</DataItem>
        </Geometry>""".format(L[0]/N[0], L[1]/N[1])

            xf2d += """
        <Topology Dimensions="{0} {1}" Type="2DCoRectMesh"/>""".format(N[0], N[1])
            prec = 4 if dtype == float32 else 8
            for comp in comps:
                xf2d += attribute2D.format(comp, N[0], N[1], h5filename, comp, tstep, prec)
            xf2d += """  
      </Grid>
"""
        xf2d += """  
    </Grid>
  </Domain>
</Xdmf>"""
        xf2 = open(h5filename[:-3]+"_2D."+"xdmf", "w")
        xf2.write(xf2d)
        xf2.close()

    if len(f["/".join(("2D/yz", comps[0]))]) > 0:                
        xf2d = copy.copy(xdmffile)
        timesteps = f["/".join(("2D/yz", comps[0]))].keys()
        dtype = f["/".join(("2D/yz", comps[0]))].values()[0].dtype
        tt = ""
        for i in timesteps:
            tt += "%s " %i
        
        xf2d += timeattr.format(tt, len(timesteps))
        for tstep in timesteps:    
            xf2d += """
      <Grid GridType="Uniform">
        <Geometry Type="ORIGIN_DXDY">
          <DataItem DataType="UInt" Dimensions="2" Format="XML" Precision="4">0 0</DataItem>
          <DataItem DataType="Float" Dimensions="2" Format="XML" Precision="4">{0} {1}</DataItem>
        </Geometry>""".format(L[1]/N[1], L[2]/N[2])

            xf2d += """
        <Topology Dimensions="{0} {1}" Type="2DCoRectMesh"/>""".format(N[1], N[2])
            prec = 4 if dtype == float32 else 8
            if len(f["/".join(("2D/yz", comps[0]))]) > 0:
                for comp in f["2D/yz"]:
                    xf2d += attribute2Dslice.format(comp, N[1], N[2], h5filename, comp, tstep, prec, 'yz')
            xf2d += """  
      </Grid>
"""
        xf2d += """  
    </Grid>
  </Domain>
</Xdmf>"""
        xf2 = open(h5filename[:-3]+"_yz."+"xdmf", "w")
        xf2.write(xf2d)
        xf2.close()

    if len(f["/".join(("2D/xz", comps[0]))]) > 0:
        xf2d = copy.copy(xdmffile)
        timesteps = f["/".join(("2D/xz", comps[0]))].keys()
        dtype = f["/".join(("2D/xz", comps[0]))].values()[0].dtype
        tt = ""
        for i in timesteps:
            tt += "%s " %i
        
        xf2d += timeattr.format(tt, len(timesteps))
        for tstep in timesteps:    
            xf2d += """
      <Grid GridType="Uniform">
        <Geometry Type="ORIGIN_DXDY">
          <DataItem DataType="UInt" Dimensions="2" Format="XML" Precision="4">0 0</DataItem>
          <DataItem DataType="Float" Dimensions="2" Format="XML" Precision="4">{0} {1}</DataItem>
        </Geometry>""".format(L[0]/N[0], L[2]/N[2])

            xf2d += """
        <Topology Dimensions="{0} {1}" Type="2DCoRectMesh"/>""".format(N[0], N[2])
            prec = 4 if dtype == float32 else 8
            if len(f["/".join(("2D/xz", comps[0]))]) > 0:
                for comp in f["2D/xz"]:
                    xf2d += attribute2Dslice.format(comp, N[0], N[2], h5filename, comp, tstep, prec, 'xz')
            xf2d += """  
      </Grid>
"""
        xf2d += """  
    </Grid>
  </Domain>
</Xdmf>"""
        xf2 = open(h5filename[:-3]+"_xz."+"xdmf", "w")
        xf2.write(xf2d)
        xf2.close()

    if len(f["/".join(("2D/xy", comps[0]))]) > 0:
        xf2d = copy.copy(xdmffile)
        timesteps = f["/".join(("2D/xy", comps[0]))].keys()
        dtype = f["/".join(("2D/xy", comps[0]))].values()[0].dtype
        tt = ""
        for i in timesteps:
            tt += "%s " %i
        
        xf2d += timeattr.format(tt, len(timesteps))
        for tstep in timesteps:    
            xf2d += """
      <Grid GridType="Uniform">
        <Geometry Type="ORIGIN_DXDY">
          <DataItem DataType="UInt" Dimensions="2" Format="XML" Precision="4">0 0</DataItem>
          <DataItem DataType="Float" Dimensions="2" Format="XML" Precision="4">{0} {1}</DataItem>
        </Geometry>""".format(L[0]/N[0], L[1]/N[1])

            xf2d += """
        <Topology Dimensions="{0} {1}" Type="2DCoRectMesh"/>""".format(N[0], N[1])
            prec = 4 if dtype == float32 else 8
            if len(f["/".join(("2D/xy", comps[0]))]) > 0:
                for comp in f["2D/xy"]:
                    xf2d += attribute2Dslice.format(comp, N[0], N[1], h5filename, comp, tstep, prec, 'xy')
            xf2d += """  
      </Grid>
"""
        xf2d += """  
    </Grid>
  </Domain>
</Xdmf>"""
        xf2 = open(h5filename[:-3]+"_xy."+"xdmf", "w")
        xf2.write(xf2d)
        xf2.close()

Example 5

Project: isort
Source File: isort.py
    def _add_from_imports(self, from_modules, section, section_output, ignore_case):
        for module in from_modules:
            if module in self.remove_imports:
                continue

            import_start = "from {0} import ".format(module)
            from_imports = list(self.imports[section]['from'][module])
            from_imports = nsorted(from_imports, key=lambda key: self._module_key(key, self.config, True, ignore_case))
            if self.remove_imports:
                from_imports = [line for line in from_imports if not "{0}.{1}".format(module, line) in
                                self.remove_imports]

            for from_import in copy.copy(from_imports):
                submodule = module + "." + from_import
                import_as = self.as_map.get(submodule, False)
                if import_as:
                    import_definition = "{0} as {1}".format(from_import, import_as)
                    if self.config['combine_as_imports'] and not ("*" in from_imports and
                                                                    self.config['combine_star']):
                        from_imports[from_imports.index(from_import)] = import_definition
                    else:
                        import_statement = self._wrap(import_start + import_definition)
                        comments = self.comments['straight'].get(submodule)
                        import_statement = self._add_comments(comments, import_statement)
                        section_output.append(import_statement)
                        from_imports.remove(from_import)

            if from_imports:
                comments = self.comments['from'].pop(module, ())
                if "*" in from_imports and self.config['combine_star']:
                    import_statement = self._wrap(self._add_comments(comments, "{0}*".format(import_start)))
                elif self.config['force_single_line']:
                    import_statements = []
                    for from_import in from_imports:
                        single_import_line = self._add_comments(comments, import_start + from_import)
                        comment = self.comments['nested'].get(module, {}).pop(from_import, None)
                        if comment:
                            single_import_line += "{0} {1}".format(comments and ";" or "  #", comment)
                        import_statements.append(self._wrap(single_import_line))
                        comments = None
                    import_statement = "\n".join(import_statements)
                else:
                    star_import = False
                    if "*" in from_imports:
                        section_output.append(self._add_comments(comments, "{0}*".format(import_start)))
                        from_imports.remove('*')
                        star_import = True
                        comments = None

                    for from_import in copy.copy(from_imports):
                        comment = self.comments['nested'].get(module, {}).pop(from_import, None)
                        if comment:
                            single_import_line = self._add_comments(comments, import_start + from_import)
                            single_import_line += "{0} {1}".format(comments and ";" or "  #", comment)
                            above_comments = self.comments['above']['from'].pop(module, None)
                            if above_comments:
                                section_output.extend(above_comments)
                            section_output.append(self._wrap(single_import_line))
                            from_imports.remove(from_import)
                            comments = None

                    if star_import:
                        import_statement = import_start + (", ").join(from_imports)
                    else:
                        import_statement = self._add_comments(comments, import_start + (", ").join(from_imports))
                    if not from_imports:
                        import_statement = ""

                    do_multiline_reformat = False

                    if self.config.get('force_grid_wrap') and len(from_imports) > 1:
                        do_multiline_reformat = True

                    if len(import_statement) > self.config['line_length'] and len(from_imports) > 1:
                        do_multiline_reformat = True

                    # If the line is too long AND there are imports AND we are NOT using GRID or VERTICAL wrap modes
                    if (len(import_statement) > self.config['line_length'] and len(from_imports) > 0
                        and self.config.get('multi_line_output', 0) not in (1, 0)):
                        do_multiline_reformat = True

                    if do_multiline_reformat:
                        output_mode = settings.WrapModes._fields[self.config.get('multi_line_output',
                                                                                    0)].lower()
                        formatter = getattr(self, "_output_" + output_mode, self._output_grid)
                        dynamic_indent = " " * (len(import_start) + 1)
                        indent = self.config['indent']
                        line_length = self.config['wrap_length'] or self.config['line_length']
                        import_statement = formatter(import_start, copy.copy(from_imports),
                                                    dynamic_indent, indent, line_length, comments)
                        if self.config['balanced_wrapping']:
                            lines = import_statement.split("\n")
                            line_count = len(lines)
                            if len(lines) > 1:
                                minimum_length = min([len(line) for line in lines[:-1]])
                            else:
                                minimum_length = 0
                            new_import_statement = import_statement
                            while (len(lines[-1]) < minimum_length and
                                    len(lines) == line_count and line_length > 10):
                                import_statement = new_import_statement
                                line_length -= 1
                                new_import_statement = formatter(import_start, copy.copy(from_imports),
                                                                dynamic_indent, indent, line_length, comments)
                                lines = new_import_statement.split("\n")

                    if not do_multiline_reformat and len(import_statement) > self.config['line_length']:
                        import_statement = self._wrap(import_statement)

                if import_statement:
                    above_comments = self.comments['above']['from'].pop(module, None)
                    if above_comments:
                        section_output.extend(above_comments)
                    section_output.append(import_statement)

Example 6

def transpose_pitch_carrier_by_interval(pitch_carrier, interval):
    '''Transposes `pitch_carrier` by named `interval`.

    ::

        >>> chord = Chord("<c' e' g'>4")

    ::

        >>> pitchtools.transpose_pitch_carrier_by_interval(
        ...     chord, '+m2')
        Chord("<df' f' af'>4")

    Transpose `pitch_carrier` by numbered `interval`:

    ::

        >>> chord = Chord("<c' e' g'>4")

    ::

        >>> pitchtools.transpose_pitch_carrier_by_interval(chord, 1)
        Chord("<cs' f' af'>4")

    Returns non-pitch-carrying input unchanged:

    ::

        >>> rest = Rest('r4')

    ::

        >>> pitchtools.transpose_pitch_carrier_by_interval(rest, 1)
        Rest('r4')

    Returns `pitch_carrier`.
    '''
    from abjad.tools import pitchtools
    from abjad.tools import scoretools

    def _transpose_pitch_by_named_interval(pitch, mdi):
        pitch_number = pitch.pitch_number + mdi.semitones
        diatonic_pitch_class_number = \
            (pitch.diatonic_pitch_class_number + mdi.staff_spaces) % 7
        diatonic_pitch_class_name = \
            pitchtools.PitchClass._diatonic_pitch_class_number_to_diatonic_pitch_class_name[
                diatonic_pitch_class_number]
        named_pitch = pitchtools.NamedPitch(
            pitch_number, diatonic_pitch_class_name)
        return type(pitch)(named_pitch)

    def _transpose_pitch_carrier_by_named_interval(
        pitch_carrier, named_interval):
        mdi = pitchtools.NamedInterval(named_interval)
        if isinstance(pitch_carrier, pitchtools.Pitch):
            return _transpose_pitch_by_named_interval(
                pitch_carrier, mdi)
        elif isinstance(pitch_carrier, scoretools.Note):
            new_note = copy.copy(pitch_carrier)
            new_pitch = _transpose_pitch_by_named_interval(
                pitch_carrier.written_pitch, mdi)
            new_note.written_pitch = new_pitch
            return new_note
        elif isinstance(pitch_carrier, scoretools.Chord):
            new_chord = copy.copy(pitch_carrier)
            for new_nh, old_nh in \
                zip(new_chord.note_heads, pitch_carrier.note_heads):
                new_pitch = _transpose_pitch_by_named_interval(
                    old_nh.written_pitch, mdi)
                new_nh.written_pitch = new_pitch
            return new_chord
        else:
            return pitch_carrier

    def _transpose_pitch_carrier_by_numbered_interval(
        pitch_carrier, numbered_interval):
        mci = pitchtools.NumberedInterval(numbered_interval)
        if isinstance(pitch_carrier, pitchtools.Pitch):
            number = pitch_carrier.pitch_number + mci.semitones
            return type(pitch_carrier)(number)
        elif isinstance(pitch_carrier, numbers.Number):
            pitch_carrier = pitchtools.NumberedPitch(pitch_carrier)
            result = _transpose_pitch_carrier_by_numbered_interval(
                pitch_carrier, mci)
            return result.pitch_number
        elif isinstance(pitch_carrier, scoretools.Note):
            new_note = copy.copy(pitch_carrier)
            number = pitchtools.NumberedPitch(
                pitch_carrier.written_pitch).pitch_number
            number += mci.number
            new_pitch = pitchtools.NamedPitch(number)
            new_note.written_pitch = new_pitch
            return new_note
        elif isinstance(pitch_carrier, scoretools.Chord):
            new_chord = copy.copy(pitch_carrier)
            pairs = zip(new_chord.note_heads, pitch_carrier.note_heads)
            for new_nh, old_nh in pairs:
                number = \
                    pitchtools.NumberedPitch(old_nh.written_pitch).pitch_number
                number += mci.number
                new_pitch = pitchtools.NamedPitch(number)
                new_nh.written_pitch = new_pitch
            return new_chord
        else:
            return pitch_carrier


    diatonic_types = (pitchtools.NamedInterval, str)
    if isinstance(interval, diatonic_types):
        interval = \
            pitchtools.NamedInterval(interval)
        return _transpose_pitch_carrier_by_named_interval(
            pitch_carrier, interval)
    else:
        interval = \
            pitchtools.NumberedInterval(interval)
        return _transpose_pitch_carrier_by_numbered_interval(
            pitch_carrier, interval)

Example 7

Project: sftf
Source File: Transaction.py
	def createReply(self, code, reason, req=None, createEvent=True):
		"""Creates and returns a SIP reply with the given code and reason for
		the transaction instance.
		"""
		Log.logDebug("Transaction.createReply(): entered with code=\'" + str(code) + "\', reason=\'" + str(reason) + "\', createEvent=\'" + str(createEvent) + "\'", 5)
		if req is not None:
			if not req in self.message:
				raise SCException("Transaction", "createReply", "given request isn't in this transaction")
			if not isinstance(req, SipRequest):
				raise SCException("Transaction", "createReply", "message class != SipRequest")
		elif self.lastRequest is not None:
			req = self.lastRequest
		else:
			raise SCException("Transaction", "createReply", "transaction contains no request for reply creation")
		reply = SipReply()
		reply.code = int(code)
		reply.reason = reason
		reply.request = req
		reply.protocol = copy.copy(req.protocol)
		reply.version = copy.copy(req.version)
		# copy only the mandatory HFHs
		cp_hfh = Helper.get_rpl_hfh_dict(code)
		if cp_hfh is None:
			Log.logDebug("Transaction.createReply(): code not in reply HFH table, looking up generic code", 5)
			code_generic = (code / 100) * 100
			cp_hfh = Helper.get_rpl_hfh_dict(code_generic)
		if cp_hfh is None:
			Log.logDebug("Transaction.createReply(): unable to find generic reply in HFH table, copying all HFH", 2)
			reply.headerFields = copy.copy(req.headerFields)
			reply.parsedHeader = copy.deepcopy(req.parsedHeader)
		else:
			for i in cp_hfh:
				if req.hasHeaderField(i):
					reply.setHeaderValue(i, copy.copy(req.getHeaderValue(i)))
				else:
					Log.logDebug("Transaction.createReply(): request missing mandatory HFH: " + str(i), 2)
				i_m = Helper.getMappedHFH(i)
				if i_m is None:
					Log.logDebug("Transaction.createReply(): unable to match HFH: " + str(i), 3)
					i_m = i
				if req.hasParsedHeaderField(i_m):
					reply.setParsedHeaderValue(i_m, copy.deepcopy(req.getParsedHeaderValue(i_m)))
					if i == "Via":
						via = reply.getParsedHeaderValue("Via")
						regen = False
						if via.rport is not None:
							via.rport = req.event.srcAddress[1]
							regen = True
						if via.received is not None:
							via.received = req.event.srcAddress[0]
							regen = True
						if regen:
							reply.setHeaderValue("Via", via.create())
				else:
					Log.logDebug("Transaction.createReply(): request missing mandatory parsed HFH: "+  str(i), 2)
		if code/100 <= 2 and req.hasParsedHeaderField("Record-Route"):
			rr = copy.deepcopy(req.getParsedHeaderValue("Record-Route"))
			reply.setParsedHeaderValue("Record-Route", rr)
			reply.setHeaderValue("Record-Route", rr.create())
		if code/100 == 2 and req.method == "INVITE":
			if (self.dialog is not None):
				con = self.dialog.getLocalContact()
			else:
				con = Helper.createClassInstance("Contact")
				con.uri.protocol = "sip"
				con.uri.username = Config.SC_USER_NAME
				con.uri.host = Config.LOCAL_IP
				con.uri.port = Config.LOCAL_PORT
			reply.setParsedHeaderValue("Contact", con)
			reply.setHeaderValue("Contact", con.create())
			reply.body = Helper.createDummyBody()
		elif code/100 == 2 and req.method == "REGISTER":
			reply.body = []
			if req.hasParsedHeaderField("Contact"):
				reply.setParsedHeaderValue("Contact", copy.deepcopy(req.getParsedHeaderValue("Contact")))
				touri = req.getParsedHeaderValue("To").uri
				co = reply.getParsedHeaderValue("Contact")
				cco = co
				while cco is not None:
					if cco.expires is None:
						if req.hasParsedHeaderField("Expires"):
							cco.expires = req.getParsedHeaderValue("Expires").seconds
						else:
							cco.expires = int(Config.DEFAULT_EXPIRES)
					cco = cco.next
				Helper.usrlocAddContact(touri.create(), co)
				reply.setHeaderValue("Contact", co.create())
			else:
				if req.hasHeaderField("Contact"):
					Log.logDebug("Transaction.createReply(): missing parsed Contact using raw value", 3)
					reply.setHeaderValue("Contact", req.getHeaderValue("Contact"))
				else:
					co = Helper.usrlocGetContacts(req.getParsedHeaderValue("To").uri)
					if co is not None:
						reply.setParsedHeaderValue("Contact", co)
						reply.setHeaderValue("Contact", co.create())
		else:
			if (self.dialog is not None):
				con = self.dialog.getLocalContact()
			else:
				con = Helper.createClassInstance("Contact")
				con.uri.protocol = "sip"
				con.uri.username = Config.SC_USER_NAME
				con.uri.host = Config.LOCAL_IP
				con.uri.port = Config.LOCAL_PORT
				con.uri.params = ['transport=UDP']
			reply.setParsedHeaderValue("Contact", con)
			reply.setHeaderValue("Contact", con.create())
			reply.body = []
		cl = Helper.createClassInstance("Contentlength")
		cl.length = Helper.calculateBodyLength(reply.body)
		reply.setHeaderValue("Content-Length", cl.create())
		if (createEvent):
			reply.createEvent()
			reply.setEventAddresses(req.event.dstAddress, req.getReplyAddress())
		return reply

Example 8

Project: xhtml2pdf
Source File: parser.py
def pisaLoop(node, context, path=None, **kw):

    if path is None:
        path = []

    # Initialize KW
    if not kw:
        kw = {
            "margin-top": 0,
            "margin-bottom": 0,
            "margin-left": 0,
            "margin-right": 0,
        }
    else:
        kw = copy.copy(kw)

    #indent = len(path) * "  " # only used for debug print statements

    # TEXT
    if node.nodeType == Node.TEXT_NODE:
        # print indent, "#", repr(node.data) #, context.frag
        context.addFrag(node.data)

        # context.text.append(node.value)

    # ELEMENT
    elif node.nodeType == Node.ELEMENT_NODE:

        node.tagName = node.tagName.replace(":", "").lower()

        if node.tagName in ("style", "script"):
            return

        path = copy.copy(path) + [node.tagName]

        # Prepare attributes
        attr = pisaGetAttributes(context, node.tagName, node.attributes)
        #log.debug(indent + "<%s %s>" % (node.tagName, attr) + repr(node.attributes.items())) #, path

        # Calculate styles
        context.cssAttr = CSSCollect(node, context)
        context.cssAttr = mapNonStandardAttrs(context.cssAttr, node, attr)
        context.node = node

        # Block?
        PAGE_BREAK = 1
        PAGE_BREAK_RIGHT = 2
        PAGE_BREAK_LEFT = 3

        pageBreakAfter = False
        frameBreakAfter = False
        display = lower(context.cssAttr.get("display", "inline"))
        # print indent, node.tagName, display, context.cssAttr.get("background-color", None), attr
        isBlock = (display == "block")

        if isBlock:
            context.addPara()

            # Page break by CSS
            if "-pdf-next-page" in context.cssAttr:
                context.addStory(NextPageTemplate(str(context.cssAttr["-pdf-next-page"])))
            if "-pdf-page-break" in context.cssAttr:
                if str(context.cssAttr["-pdf-page-break"]).lower() == "before":
                    context.addStory(PageBreak())
            if "-pdf-frame-break" in context.cssAttr:
                if str(context.cssAttr["-pdf-frame-break"]).lower() == "before":
                    context.addStory(FrameBreak())
                if str(context.cssAttr["-pdf-frame-break"]).lower() == "after":
                    frameBreakAfter = True
            if "page-break-before" in context.cssAttr:
                if str(context.cssAttr["page-break-before"]).lower() == "always":
                    context.addStory(PageBreak())
                if str(context.cssAttr["page-break-before"]).lower() == "right":
                    context.addStory(PageBreak())
                    context.addStory(PmlRightPageBreak())
                if str(context.cssAttr["page-break-before"]).lower() == "left":
                    context.addStory(PageBreak())
                    context.addStory(PmlLeftPageBreak())
            if "page-break-after" in context.cssAttr:
                if str(context.cssAttr["page-break-after"]).lower() == "always":
                    pageBreakAfter = PAGE_BREAK
                if str(context.cssAttr["page-break-after"]).lower() == "right":
                    pageBreakAfter = PAGE_BREAK_RIGHT
                if str(context.cssAttr["page-break-after"]).lower() == "left":
                    pageBreakAfter = PAGE_BREAK_LEFT

        if display == "none":
            # print "none!"
            return

        # Translate CSS to frags

        # Save previous frag styles
        context.pushFrag()

        # Map styles to Reportlab fragment properties
        CSS2Frag(context, kw, isBlock)

        # EXTRAS
        if "-pdf-keep-with-next" in context.cssAttr:
            context.frag.keepWithNext = getBool(context.cssAttr["-pdf-keep-with-next"])
        if "-pdf-outline" in context.cssAttr:
            context.frag.outline = getBool(context.cssAttr["-pdf-outline"])
        if "-pdf-outline-level" in context.cssAttr:
            context.frag.outlineLevel = int(context.cssAttr["-pdf-outline-level"])
        if "-pdf-outline-open" in context.cssAttr:
            context.frag.outlineOpen = getBool(context.cssAttr["-pdf-outline-open"])
        if "-pdf-word-wrap" in context.cssAttr:
            context.frag.wordWrap = context.cssAttr["-pdf-word-wrap"]

        # handle keep-in-frame
        keepInFrameMode = None
        keepInFrameMaxWidth = 0
        keepInFrameMaxHeight = 0
        if "-pdf-keep-in-frame-mode" in context.cssAttr:
            value = str(context.cssAttr["-pdf-keep-in-frame-mode"]).strip().lower()
            if value in ("shrink", "error", "overflow", "truncate"):
                keepInFrameMode = value
            else:
                keepInFrameMode = "shrink"
            # Added because we need a default value.
        if "-pdf-keep-in-frame-max-width" in context.cssAttr:
            keepInFrameMaxWidth = getSize("".join(context.cssAttr["-pdf-keep-in-frame-max-width"]))
        if "-pdf-keep-in-frame-max-height" in context.cssAttr:
            keepInFrameMaxHeight = getSize("".join(context.cssAttr["-pdf-keep-in-frame-max-height"]))

        # ignore nested keep-in-frames, tables have their own KIF handling
        keepInFrame = keepInFrameMode is not None and context.keepInFrameIndex is None
        if keepInFrame:
            # keep track of the current story index, so we can wrap everything
            # added after this point in a KeepInFrame
            context.keepInFrameIndex = len(context.story)

        # BEGIN tag
        klass = globals().get("pisaTag%s" % node.tagName.replace(":", "").upper(), None)
        obj = None

        # Static block
        elementId = attr.get("id", None)
        staticFrame = context.frameStatic.get(elementId, None)
        if staticFrame:
            context.frag.insideStaticFrame += 1
            oldStory = context.swapStory()

        # Tag specific operations
        if klass is not None:
            obj = klass(node, attr)
            obj.start(context)

        # Visit child nodes
        context.fragBlock = fragBlock = copy.copy(context.frag)
        for nnode in node.childNodes:
            pisaLoop(nnode, context, path, **kw)
        context.fragBlock = fragBlock

        # END tag
        if obj:
            obj.end(context)

        # Block?
        if isBlock:
            context.addPara()

            # XXX Buggy!

            # Page break by CSS
            if pageBreakAfter:
                context.addStory(PageBreak())
                if pageBreakAfter == PAGE_BREAK_RIGHT:
                    context.addStory(PmlRightPageBreak())
                if pageBreakAfter == PAGE_BREAK_LEFT:
                    context.addStory(PmlLeftPageBreak())
            if frameBreakAfter:
                context.addStory(FrameBreak())

        if keepInFrame:
            # get all content added after start of -pdf-keep-in-frame and wrap
            # it in a KeepInFrame
            substory = context.story[context.keepInFrameIndex:]
            context.story = context.story[:context.keepInFrameIndex]
            context.story.append(
                KeepInFrame(
                    content=substory,
                    maxWidth=keepInFrameMaxWidth,
                    maxHeight=keepInFrameMaxHeight,
                    mode=keepInFrameMode))
            # mode wasn't being used; it is necessary for tables or images at end of page.
            context.keepInFrameIndex = None

        # Static block, END
        if staticFrame:
            context.addPara()
            for frame in staticFrame:
                frame.pisaStaticStory = context.story
            context.swapStory(oldStory)
            context.frag.insideStaticFrame -= 1

        # context.debug(1, indent, "</%s>" % (node.tagName))

        # Reset frag style
        context.pullFrag()

    # Unknown or not handled
    else:
        # context.debug(1, indent, "???", node, node.nodeType, repr(node))
        # Loop over children
        for node in node.childNodes:
            pisaLoop(node, context, path, **kw)

Example 9

Project: jasy
Source File: Writer.py
    def __process(self, apiData, classFilter=None, internals=False, privates=False, printErrors=True, highlightCode=True):
        
        knownClasses = set(list(apiData))


        #
        # Attaching Links to Source Code (Lines)
        # Building Documentation Summaries
        #

        
        Console.info("Adding Source Links...")

        for className in apiData:
            classApi = apiData[className]

            constructData = getattr(classApi, "construct", None)
            if constructData is not None:
                if "line" in constructData:
                    constructData["sourceLink"] = "source:%s~%s" % (className, constructData["line"])

            for section in ("properties", "events", "statics", "members"):
                sectionData = getattr(classApi, section, None)

                if sectionData is not None:
                    for name in sectionData:
                        if "line" in sectionData[name]:
                            sectionData[name]["sourceLink"] = "source:%s~%s" % (className, sectionData[name]["line"])



        #
        # Including Mixins / IncludedBy
        #

        Console.info("Resolving Mixins...")
        Console.indent()

        # Just used temporary to keep track of which classes are merged
        mergedClasses = set()

        def getApi(className):
            classApi = apiData[className]

            if className in mergedClasses:
                return classApi

            classIncludes = getattr(classApi, "includes", None)
            if classIncludes:
                for mixinName in classIncludes:
                    if not mixinName in apiData:
                        Console.error("Invalid mixin %s in class %s", mixinName, className)
                        continue
                        
                    mixinApi = apiData[mixinName]
                    if not hasattr(mixinApi, "includedBy"):
                        mixinApi.includedBy = set()

                    mixinApi.includedBy.add(className)
                    mergeMixin(className, mixinName, classApi, getApi(mixinName))

            mergedClasses.add(className)

            return classApi

        for className in apiData:
            apiData[className] = getApi(className)

        Console.outdent()



        #
        # Checking links
        #
        
        Console.info("Checking Links...")
        
        additionalTypes = ("Call", "Identifier", "Map", "Integer", "Node", "Element")
        
        def checkInternalLink(link, className):
            match = internalLinkParse.match(link)
            if not match:
                return 'Invalid link "#%s"' % link
                
            if match.group(3) is not None:
                className = match.group(3)
                
            if not className in knownClasses and not className in apiData:
                return 'Invalid class in link "#%s"' % link
                
            # Accept all section/item values for named classes,
            # as it might be pretty complicated to verify this here.
            if not className in apiData:
                return True
                
            classApi = apiData[className]
            sectionName = match.group(2)
            itemName = match.group(5)
            
            if itemName is None:
                return True
                
            if sectionName is not None:
                if not sectionName in linkMap:
                    return 'Invalid section in link "#%s"' % link
                    
                section = getattr(classApi, linkMap[sectionName], None)
                if section is None:
                    return 'Invalid section in link "#%s"' % link
                else:
                    if itemName in section:
                        return True
                        
                    return 'Invalid item in link "#%s"' % link
            
            for sectionName in ("statics", "members", "properties", "events"):
                section = getattr(classApi, sectionName, None)
                if section and itemName in section:
                    return True
                
            return 'Invalid item link "#%s"' % link


        def checkLinksInItem(item):
            
            # Process types
            if "type" in item:
                
                if item["type"] == "Function":

                    # Check param types
                    if "params" in item:
                        for paramName in item["params"]:
                            paramEntry = item["params"][paramName]
                            if "type" in paramEntry:
                                for paramTypeEntry in paramEntry["type"]:
                                    if not paramTypeEntry["name"] in knownClasses and not paramTypeEntry["name"] in additionalTypes and not ("builtin" in paramTypeEntry or "pseudo" in paramTypeEntry):
                                        item["errornous"] = True
                                        Console.error('Invalid param type "%s" in %s' % (paramTypeEntry["name"], className))

                                    if not "pseudo" in paramTypeEntry and paramTypeEntry["name"] in knownClasses:
                                        paramTypeEntry["linkable"] = True
                
                
                    # Check return types
                    if "returns" in item:
                        for returnTypeEntry in item["returns"]:
                            if not returnTypeEntry["name"] in knownClasses and not returnTypeEntry["name"] in additionalTypes and not ("builtin" in returnTypeEntry or "pseudo" in returnTypeEntry):
                                item["errornous"] = True
                                Console.error('Invalid return type "%s" in %s' % (returnTypeEntry["name"], className))
                            
                            if not "pseudo" in returnTypeEntry and returnTypeEntry["name"] in knownClasses:
                                returnTypeEntry["linkable"] = True
                            
                elif not item["type"] in builtinTypes and not item["type"] in pseudoTypes and not item["type"] in additionalTypes:
                    item["errornous"] = True
                    Console.error('Invalid type "%s" in %s' % (item["type"], className))
            
            
            # Process doc
            if "doc" in item:
                
                def processInternalLink(match):
                    linkUrl = match.group(2)

                    if linkUrl.startswith("#"):
                        linkCheck = checkInternalLink(linkUrl[1:], className)
                        if linkCheck is not True:
                            item["errornous"] = True

                            if sectionName:
                                Console.error("%s in %s:%s~%s" % (linkCheck, sectionName, className, name))
                            else:
                                Console.error("%s in %s" % (linkCheck, className))
            
                linkExtract.sub(processInternalLink, item["doc"])


        Console.indent()

        # Process APIs
        for className in apiData:
            classApi = apiData[className]
            
            sectionName = None
            constructData = getattr(classApi, "construct", None)
            if constructData is not None:
                checkLinksInItem(constructData)

            for sectionName in ("properties", "events", "statics", "members"):
                section = getattr(classApi, sectionName, None)

                if section is not None:
                    for name in section:
                        checkLinksInItem(section[name])

        Console.outdent()



        #
        # Filter Internals/Privates
        #
        
        Console.info("Filtering Items...")
        
        def isVisible(entry):
            if "visibility" in entry:
                visibility = entry["visibility"]
                if visibility == "private" and not privates:
                    return False
                if visibility == "internal" and not internals:
                    return False

            return True

        def filterInternalsPrivates(classApi, field):
            data = getattr(classApi, field, None)
            if data:
                for name in list(data):
                    if not isVisible(data[name]):
                        del data[name]

        for className in apiData:
            filterInternalsPrivates(apiData[className], "statics")
            filterInternalsPrivates(apiData[className], "members")



        #
        # Connecting Interfaces / ImplementedBy
        #
        
        Console.info("Connecting Interfaces...")
        Console.indent()
        
        for className in apiData:
            classApi = getApi(className)
            
            if not hasattr(classApi, "main"):
                continue
                
            classType = classApi.main["type"]
            if classType == "core.Class":
                
                classImplements = getattr(classApi, "implements", None)
                if classImplements:
                    
                    for interfaceName in classImplements:
                        interfaceApi = apiData[interfaceName]
                        implementedBy = getattr(interfaceApi, "implementedBy", None)
                        if not implementedBy:
                            implementedBy = interfaceApi.implementedBy = []
                            
                        implementedBy.append(className)
                        connectInterface(className, interfaceName, classApi, interfaceApi)
        
        Console.outdent()
        
        
        #
        # Merging Named Classes
        #
        
        Console.info("Merging Named Classes...")
        Console.indent()
        
        for className in list(apiData):
            classApi = apiData[className]
            destName = classApi.main["name"]
            
            if destName is not None and destName != className:

                Console.debug("Extending class %s with %s", destName, className)

                if destName in apiData:
                    destApi = apiData[destName]
                    destApi.main["from"].append(className)
                
                else:
                    destApi = apiData[destName] = Data.ApiData(destName, highlight=highlightCode)
                    destApi.main = {
                        "type" : "Extend",
                        "name" : destName,
                        "from" : [className]
                    }
                    
                # If there is a "main" tag found in the class use its API description
                if "tags" in classApi.main and classApi.main["tags"] is not None and "main" in classApi.main["tags"]:
                    if "doc" in classApi.main:
                        destApi.main["doc"] = classApi.main["doc"]
                
                classApi.main["extension"] = True
                    
                # Read existing data
                construct = getattr(classApi, "construct", None)
                statics = getattr(classApi, "statics", None)
                members = getattr(classApi, "members", None)

                if construct is not None:
                    if hasattr(destApi, "construct"):
                        Console.warn("Overriding constructor in extension %s by %s", destName, className)
                        
                    destApi.construct = copy.copy(construct)

                if statics is not None:
                    if not hasattr(destApi, "statics"):
                        destApi.statics = {}

                    for staticName in statics:
                        destApi.statics[staticName] = copy.copy(statics[staticName])
                        destApi.statics[staticName]["from"] = className
                        destApi.statics[staticName]["fromLink"] = "static:%s~%s" % (className, staticName)

                if members is not None:
                    if not hasattr(destApi, "members"):
                        destApi.members = {}
                        
                    for memberName in members:
                        destApi.members[memberName] = copy.copy(members[memberName])
                        destApi.members[memberName]["from"] = className
                        destApi.members[memberName]["fromLink"] = "member:%s~%s" % (className, memberName)

        Console.outdent()
        

        #
        # Connecting Uses / UsedBy
        #

        Console.info("Collecting Use Patterns...")

        # This matches all uses with the known classes and only keeps them if matched
        allClasses = set(list(apiData))
        for className in apiData:
            uses = apiData[className].uses

            # Rebuild use list
            cleanUses = set()
            for use in uses:
                if use != className and use in allClasses:
                    cleanUses.add(use)

                    useEntry = apiData[use]
                    if not hasattr(useEntry, "usedBy"):
                        useEntry.usedBy = set()

                    useEntry.usedBy.add(className)

            apiData[className].uses = cleanUses

        
        
        #
        # Collecting errors
        #
        
        Console.info("Collecting Errors...")
        Console.indent()
        
        for className in sorted(apiData):
            classApi = apiData[className]
            errors = []

            if isErrornous(classApi.main):
                errors.append({
                    "kind": "Main",
                    "name": None,
                    "line": 1
                })
            
            if hasattr(classApi, "construct"):
                if isErrornous(classApi.construct):
                    errors.append({
                        "kind": "Constructor",
                        "name": None,
                        "line": classApi.construct["line"]
                    })
            
            for section in ("statics", "members", "properties", "events"):
                items = getattr(classApi, section, {})
                for itemName in items:
                    item = items[itemName]
                    if isErrornous(item):
                        errors.append({
                            "kind": itemMap[section],
                            "name": itemName,
                            "line": item["line"]
                        })
                        
            if errors:
                if printErrors:
                    Console.warn("Found errors in %s", className)
                    
                errorsSorted = sorted(errors, key=lambda entry: entry["line"])
                
                if printErrors:
                    Console.indent()
                    for entry in errorsSorted:
                        if entry["name"]:
                            Console.warn("%s: %s (line %s)", entry["kind"], entry["name"], entry["line"])
                        else:
                            Console.warn("%s (line %s)", entry["kind"], entry["line"])
                
                    Console.outdent()
                    
                classApi.errors = errorsSorted
                
        Console.outdent()
        
        
        
        #
        # Building Search Index
        #

        Console.info("Building Search Index...")
        search = {}

        def addSearch(classApi, field):
            data = getattr(classApi, field, None)
            if data:
                for name in data:
                    if not name in search:
                        search[name] = set()

                    search[name].add(className)

        for className in apiData:

            classApi = apiData[className]

            addSearch(classApi, "statics")
            addSearch(classApi, "members")
            addSearch(classApi, "properties")
            addSearch(classApi, "events")
        
        
        
        #
        # Post Process (dict to sorted list)
        #
        
        Console.info("Post Processing Data...")
        
        for className in sorted(apiData):
            classApi = apiData[className]
            
            convertTags(classApi.main)
            
            construct = getattr(classApi, "construct", None)
            if construct:
                convertFunction(construct)
                convertTags(construct)

            for section in ("statics", "members", "properties", "events"):
                items = getattr(classApi, section, None)
                if items:
                    sortedList = []
                    for itemName in sorted(items):
                        item = items[itemName]
                        item["name"] = itemName
                        
                        if "type" in item and item["type"] == "Function":
                            convertFunction(item)
                                
                        convertTags(item)
                        sortedList.append(item)

                    setattr(classApi, section, sortedList)
        
        
        
        #
        # Collecting Package Docs
        #

        Console.info("Collecting Package Docs...")
        Console.indent()
        
        # Inject existing package docs into api data
        for project in self.__session.getProjects():
            docs = project.getDocs()
            
            for packageName in docs:
                if self.__isIncluded(packageName, classFilter):
                    Console.debug("Creating package documentation %s", packageName)
                    apiData[packageName] = docs[packageName].getApi()
        
        
        # Fill missing package docs
        for className in sorted(apiData):
            splits = className.split(".")
            packageName = splits[0]
            for split in splits[1:]:
                if not packageName in apiData:
                    Console.warn("Missing package documentation %s", packageName)
                    apiData[packageName] = Data.ApiData(packageName, highlight=highlightCode)
                    apiData[packageName].main = {
                        "type" : "Package",
                        "name" : packageName
                    }
                        
                packageName = "%s.%s" % (packageName, split)


        # Now register all classes in their parent namespace/package
        for className in sorted(apiData):
            splits = className.split(".")
            packageName = ".".join(splits[:-1])
            if packageName:
                package = apiData[packageName]
                # debug("- Registering class %s in parent %s", className, packageName)
                
                entry = {
                    "name" : splits[-1],
                    "link" : className,
                }
                
                classMain = apiData[className].main
                if "doc" in classMain and classMain["doc"]:
                    summary = Text.extractSummary(classMain["doc"])
                    if summary:
                        entry["summary"] = summary
                        
                if "type" in classMain and classMain["type"]:
                    entry["type"] = classMain["type"]
                
                if not hasattr(package, "content"):
                    package.content = [entry]
                else:
                    package.content.append(entry)
                    
        Console.outdent()



        #
        # Writing API Index
        #
        
        Console.debug("Building Index...")
        index = {}
        
        for className in sorted(apiData):
            
            classApi = apiData[className]
            mainInfo = classApi.main
            
            # Create structure for className
            current = index
            for split in className.split("."):
                if not split in current:
                    current[split] = {}
            
                current = current[split]
            
            # Store current type
            current["$type"] = mainInfo["type"]
            
            # Keep information on whether this entry has package content
            if hasattr(classApi, "content"):
                current["$content"] = True
        
        
        
        #
        # Return
        #
        
        return apiData, index, search
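
In the "Merging Named Classes" step above, each construct/statics/members entry is shallow-copied with copy.copy before the "from" and "fromLink" keys are added, so the extension class keeps its own, unannotated entries. A minimal sketch of that pattern (the class name and member dictionaries below are illustrative, not taken from jasy):

    import copy

    source_members = {"render": {"type": "Function", "line": 12}}
    dest_members = {}

    for name, entry in source_members.items():
        merged = copy.copy(entry)        # shallow copy: a new dict, values shared
        merged["from"] = "my.Extension"  # annotate the copy only
        dest_members[name] = merged

    assert "from" not in source_members["render"]   # original entry untouched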

Example 10

Project: merlin
Source File: dnn_synth_PROJECTION.py
View license
def main_function(cfg, in_dir, out_dir, token_xpath, index_attrib_name, synth_mode, cmp_dir, projection_end):
    ## TODO: token_xpath & index_attrib_name   should be in config
    
    # get a logger for this main function
    logger = logging.getLogger("main")
    
    # get another logger to handle plotting duties
    plotlogger = logging.getLogger("plotting")

    # later, we might do this via a handler that is created, attached and configured
    # but for now we need to do it manually
    plotlogger.set_plot_path(cfg.plot_dir)
    
    #### parameter setting########
    hidden_layers_sizes = cfg.hyper_params['hidden_layers_sizes']
    
    ####prepare environment    
    synth_utts_input = glob.glob(in_dir + '/*.utt')
    ###synth_utts_input = synth_utts_input[:10]   ### temp!!!!!

    if synth_mode == 'single_sentence_demo':
        synth_utts_input = synth_utts_input[:1]
        print 
        print 'mode: single_sentence_demo'
        print synth_utts_input
        print

    indexed_utt_dir = os.path.join(out_dir, 'utt') ## place to put test utts with tokens labelled with projection indices
    direcs = [out_dir, indexed_utt_dir]
    for direc in direcs:
        if not os.path.isdir(direc):
            os.mkdir(direc)
    

    ## was below -- see comment
    if synth_mode == 'single_sentence_demo':
        synth_utts_input = add_projection_indices_with_replicates(synth_utts_input, token_xpath, index_attrib_name, indexed_utt_dir, 100)
    else:
        add_projection_indices(synth_utts_input, token_xpath, index_attrib_name, indexed_utt_dir)




    file_id_list = []
    for fname in synth_utts_input:
        junk,name = os.path.split(fname)
        file_id_list.append(name.replace('.utt',''))


    data_dir = cfg.data_dir

    model_dir = os.path.join(cfg.work_dir, 'nnets_model')
    gen_dir   = os.path.join(out_dir, 'gen')    

    ###normalisation information
    norm_info_file = os.path.join(data_dir, 'norm_info' + cfg.combined_feature_name + '_' + str(cfg.cmp_dim) + '_' + cfg.output_feature_normalisation + '.dat')
    
    ### normalise input full context label
    if cfg.label_style == 'HTS':
        sys.exit('only ossian utts supported')        
    elif cfg.label_style == 'composed':
        suffix='composed'

    # the number can be removed
    binary_label_dir      = os.path.join(out_dir, 'lab_bin')
    nn_label_norm_dir     = os.path.join(out_dir, 'lab_bin_norm')

    binary_label_file_list   = prepare_file_path_list(file_id_list, binary_label_dir, cfg.lab_ext)
    nn_label_norm_file_list  = prepare_file_path_list(file_id_list, nn_label_norm_dir, cfg.lab_ext)

    ## need this to find normalisation info:
    if cfg.process_labels_in_work_dir:
        label_data_dir = cfg.work_dir
    else:
        label_data_dir = data_dir
    
    min_max_normaliser = None
    label_norm_file = 'label_norm_%s.dat' %(cfg.label_style)
    label_norm_file = os.path.join(label_data_dir, label_norm_file)
    
    if cfg.label_style == 'HTS':
        sys.exit('script not tested with HTS labels')


    ## always do this in synth:
    ## if cfg.NORMLAB and (cfg.label_style == 'composed'):  
    logger.info('add projection indices to tokens in test utts')

    ## add_projection_indices was here

    logger.info('preparing label data (input) using "composed" style labels')
    label_composer = LabelComposer()
    label_composer.load_label_configuration(cfg.label_config_file)

    logger.info('Loaded label configuration')

    lab_dim=label_composer.compute_label_dimension()
    logger.info('label dimension will be %d' % lab_dim)
    
    if cfg.precompile_xpaths:
        label_composer.precompile_xpaths()
    
    # there are now a set of parallel input label files (e.g, one set of HTS and another set of Ossian trees)
    # create all the lists of these, ready to pass to the label composer

    in_label_align_file_list = {}
    for label_style, label_style_required in label_composer.label_styles.iteritems():
        if label_style_required:
            logger.info('labels of style %s are required - constructing file paths for them' % label_style)
            if label_style == 'xpath':
                in_label_align_file_list['xpath'] = prepare_file_path_list(file_id_list, indexed_utt_dir, cfg.utt_ext, False)
            elif label_style == 'hts':
                logger.critical('script not tested with HTS labels')        
            else:
                logger.critical('unsupported label style %s specified in label configuration' % label_style)
                raise Exception
    
        # now iterate through the files, one at a time, constructing the labels for them 
        num_files=len(file_id_list)
        logger.info('the label styles required are %s' % label_composer.label_styles)
        
        for i in xrange(num_files):
            logger.info('making input label features for %4d of %4d' % (i+1,num_files))

            # iterate through the required label styles and open each corresponding label file

            # a dictionary of file descriptors, pointing at the required files
            required_labels={}
            
            for label_style, label_style_required in label_composer.label_styles.iteritems():
                
                # the files will be a parallel set of files for a single utterance
                # e.g., the XML tree and an HTS label file
                if label_style_required:
                    required_labels[label_style] = open(in_label_align_file_list[label_style][i] , 'r')
                    logger.debug(' opening label file %s' % in_label_align_file_list[label_style][i])

            logger.debug('label styles with open files: %s' % required_labels)
            label_composer.make_labels(required_labels,out_file_name=binary_label_file_list[i],fill_missing_values=cfg.fill_missing_values,iterate_over_frames=cfg.iterate_over_frames)
                
            # now close all opened files
            for fd in required_labels.itervalues():
                fd.close()
    
    # no silence removal for synthesis ...
    
    ## minmax norm:
    min_max_normaliser = MinMaxNormalisation(feature_dimension = lab_dim, min_value = 0.01, max_value = 0.99, exclude_columns=[cfg.index_to_project])

    (min_vector, max_vector) = retrieve_normalisation_values(label_norm_file)
    min_max_normaliser.min_vector = min_vector
    min_max_normaliser.max_vector = max_vector

    ###  apply precomputed and stored min-max to the whole dataset
    min_max_normaliser.normalise_data(binary_label_file_list, nn_label_norm_file_list)


### DEBUG
    if synth_mode == 'inferred':

        ## set up paths -- write CMP data to infer from in outdir:
        nn_cmp_dir = os.path.join(out_dir, 'nn' + cfg.combined_feature_name + '_' + str(cfg.cmp_dim))
        nn_cmp_norm_dir = os.path.join(out_dir, 'nn_norm'  + cfg.combined_feature_name + '_' + str(cfg.cmp_dim))

        in_file_list_dict = {}
        for feature_name in cfg.in_dir_dict.keys():
            in_direc = os.path.join(cmp_dir, feature_name)
            assert os.path.isdir(in_direc), in_direc
            in_file_list_dict[feature_name] = prepare_file_path_list(file_id_list, in_direc, cfg.file_extension_dict[feature_name], False)        
        
        nn_cmp_file_list         = prepare_file_path_list(file_id_list, nn_cmp_dir, cfg.cmp_ext)
        nn_cmp_norm_file_list    = prepare_file_path_list(file_id_list, nn_cmp_norm_dir, cfg.cmp_ext)



        ### make output acoustic data
        #    if cfg.MAKECMP:
        logger.info('creating acoustic (output) features')
        delta_win = [-0.5, 0.0, 0.5]
        acc_win = [1.0, -2.0, 1.0]
        
        acoustic_worker = AcousticComposition(delta_win = delta_win, acc_win = acc_win)
        acoustic_worker.prepare_nn_data(in_file_list_dict, nn_cmp_file_list, cfg.in_dimension_dict, cfg.out_dimension_dict)

        ## skip silence removal for inference -- need to match labels, which are
        ## not silence removed either


        
    ### retrieve acoustic normalisation information for normalising the features back
    var_dir   = os.path.join(data_dir, 'var')
    var_file_dict = {}
    for feature_name in cfg.out_dimension_dict.keys():
        var_file_dict[feature_name] = os.path.join(var_dir, feature_name + '_' + str(cfg.out_dimension_dict[feature_name]))
        
        
    ### normalise output acoustic data
#    if cfg.NORMCMP:


#### DEBUG
    if synth_mode == 'inferred':


        logger.info('normalising acoustic (output) features using method %s' % cfg.output_feature_normalisation)
        cmp_norm_info = None
        if cfg.output_feature_normalisation == 'MVN':
            normaliser = MeanVarianceNorm(feature_dimension=cfg.cmp_dim)

            (mean_vector,std_vector) = retrieve_normalisation_values(norm_info_file)
            normaliser.mean_vector = mean_vector
            normaliser.std_vector = std_vector

            ###  apply precomputed and stored mean and std to the whole dataset
            normaliser.feature_normalisation(nn_cmp_file_list, nn_cmp_norm_file_list)

        elif cfg.output_feature_normalisation == 'MINMAX':        
            sys.exit('not implemented')
            #            min_max_normaliser = MinMaxNormalisation(feature_dimension = cfg.cmp_dim)
            #            global_mean_vector = min_max_normaliser.compute_mean(nn_cmp_file_list[0:cfg.train_file_number])
            #            global_std_vector = min_max_normaliser.compute_std(nn_cmp_file_list[0:cfg.train_file_number], global_mean_vector)

            #            min_max_normaliser = MinMaxNormalisation(feature_dimension = cfg.cmp_dim, min_value = 0.01, max_value = 0.99)
            #            min_max_normaliser.find_min_max_values(nn_cmp_file_list[0:cfg.train_file_number])
            #            min_max_normaliser.normalise_data(nn_cmp_file_list, nn_cmp_norm_file_list)

            #            cmp_min_vector = min_max_normaliser.min_vector
            #            cmp_max_vector = min_max_normaliser.max_vector
            #            cmp_norm_info = numpy.concatenate((cmp_min_vector, cmp_max_vector), axis=0)

        else:
            logger.critical('Normalisation type %s is not supported!\n' %(cfg.output_feature_normalisation))
            raise
 

    combined_model_arch = str(len(hidden_layers_sizes))
    for hid_size in hidden_layers_sizes:
        combined_model_arch += '_' + str(hid_size)
    nnets_file_name = '%s/%s_%s_%d_%s_%d.%d.train.%d.model' \
                      %(model_dir, cfg.model_type, cfg.combined_feature_name, int(cfg.multistream_switch), 
                        combined_model_arch, lab_dim, cfg.cmp_dim, cfg.train_file_number)

    ### DNN model training
#    if cfg.TRAINDNN: always do this in synth






#### DEBUG
    inferred_weights = None ## default, for non-inferring synth methods
    if synth_mode == 'inferred':

        ## infer control values from TESTING data

        ## identical lists (our test data) for 'train' and 'valid' -- this is just to
        ##   keep the infer_projections_fn theano function happy -- operates on
        ##    validation set. 'Train' set shouldn't be used here.
        train_x_file_list = copy.copy(nn_label_norm_file_list)
        train_y_file_list = copy.copy(nn_cmp_norm_file_list)
        valid_x_file_list = copy.copy(nn_label_norm_file_list)
        valid_y_file_list = copy.copy(nn_cmp_norm_file_list)

        print 'FILELIST for inference:'
        print train_x_file_list 
        print 

        try:
            inferred_weights = infer_projections(train_xy_file_list = (train_x_file_list, train_y_file_list), \
                        valid_xy_file_list = (valid_x_file_list, valid_y_file_list), \
                        nnets_file_name = nnets_file_name, \
                        n_ins = lab_dim, n_outs = cfg.cmp_dim, ms_outs = cfg.multistream_outs, \
                        hyper_params = cfg.hyper_params, buffer_size = cfg.buffer_size, plot = cfg.plot)
           
        except KeyboardInterrupt:
            logger.critical('train_DNN interrupted via keyboard')
            # Could 'raise' the exception further, but that causes a deep traceback to be printed
            # which we don't care about for a keyboard interrupt. So, just bail out immediately
            sys.exit(1)
        except:
            logger.critical('train_DNN threw an exception')
            raise






    ## if cfg.DNNGEN:
    logger.info('generating from DNN')

    try:
        os.makedirs(gen_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # not an error - just means directory already exists
            pass
        else:
            logger.critical('Failed to create generation directory %s' % gen_dir)
            logger.critical(' OS error was: %s' % e.strerror)
            raise



    gen_file_list = prepare_file_path_list(file_id_list, gen_dir, cfg.cmp_ext)

    #print nn_label_norm_file_list  ## <-- this WAS mangled in inferred due to copying of file list to trainlist_x etc. which is then shuffled. Now use copy.copy
    #print gen_file_list

    weights_outfile = os.path.join(out_dir, 'projection_weights_for_synth.txt')  
    dnn_generation_PROJECTION(nn_label_norm_file_list, nnets_file_name, lab_dim, cfg.cmp_dim, gen_file_list, cfg=cfg, synth_mode=synth_mode, projection_end=projection_end, projection_weights_to_use=inferred_weights, save_weights_to_file=weights_outfile )
    
    logger.debug('denormalising generated output using method %s' % cfg.output_feature_normalisation)
    ## DNNGEN

    fid = open(norm_info_file, 'rb')
    cmp_min_max = numpy.fromfile(fid, dtype=numpy.float32)
    fid.close()
    cmp_min_max = cmp_min_max.reshape((2, -1))
    cmp_min_vector = cmp_min_max[0, ] 
    cmp_max_vector = cmp_min_max[1, ]

    if cfg.output_feature_normalisation == 'MVN':
        denormaliser = MeanVarianceNorm(feature_dimension = cfg.cmp_dim)
        denormaliser.feature_denormalisation(gen_file_list, gen_file_list, cmp_min_vector, cmp_max_vector)
        
    elif cfg.output_feature_normalisation == 'MINMAX':
        denormaliser = MinMaxNormalisation(cfg.cmp_dim, min_value = 0.01, max_value = 0.99, min_vector = cmp_min_vector, max_vector = cmp_max_vector)
        denormaliser.denormalise_data(gen_file_list, gen_file_list)
    else:
        logger.critical('denormalising method %s is not supported!\n' %(cfg.output_feature_normalisation))
        raise

    ##perform MLPG to smooth parameter trajectory
    ## lf0 is included, the output features must have vuv.
    generator = ParameterGeneration(gen_wav_features = cfg.gen_wav_features)
    generator.acoustic_decomposition(gen_file_list, cfg.cmp_dim, cfg.out_dimension_dict, cfg.file_extension_dict, var_file_dict)    

            ## osw: skip MLPG:
#            split_cmp(gen_file_list, ['mgc', 'lf0', 'bap'], cfg.cmp_dim, cfg.out_dimension_dict, cfg.file_extension_dict)    

    ## Variance scaling:
    scaled_dir = gen_dir + '_scaled'
    simple_scale_variance(gen_dir, scaled_dir, var_file_dict, cfg.out_dimension_dict, file_id_list, gv_weight=0.5)  ## gv_weight hardcoded

    ### generate wav ---- glottHMM only!!!
    #if cfg.GENWAV:
    logger.info('reconstructing waveform(s)')
    generate_wav_glottHMM(scaled_dir, file_id_list)   # generated speech
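
Before inferring projection weights, the script hands copy.copy'd label and cmp file lists to the training routine as its 'train' and 'valid' sets; as the comment near dnn_generation_PROJECTION notes, that routine shuffles the lists it receives, and the copies keep nn_label_norm_file_list in its original order for generation. A minimal sketch of that idea (shuffle_in_place and the file names are stand-ins, not merlin functions):

    import copy
    import random

    def shuffle_in_place(file_list):
        # stands in for training code that reorders the list it is given
        random.shuffle(file_list)

    file_list = ['utt_001.lab', 'utt_002.lab', 'utt_003.lab']
    shuffle_in_place(copy.copy(file_list))   # the callee shuffles its own copy

    print(file_list)   # original order preserved for the later generation step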

Example 11

Project: pythonVSCode
Source File: isort.py
View license
    def _add_from_imports(self, from_modules, section, section_output, ignore_case):
        for module in from_modules:
            if module in self.remove_imports:
                continue

            import_start = "from {0} import ".format(module)
            from_imports = list(self.imports[section]['from'][module])
            from_imports = nsorted(from_imports, key=lambda key: self._module_key(key, self.config, True, ignore_case))
            if self.remove_imports:
                from_imports = [line for line in from_imports if not "{0}.{1}".format(module, line) in
                                self.remove_imports]

            for from_import in copy.copy(from_imports):
                submodule = module + "." + from_import
                import_as = self.as_map.get(submodule, False)
                if import_as:
                    import_definition = "{0} as {1}".format(from_import, import_as)
                    if self.config['combine_as_imports'] and not ("*" in from_imports and
                                                                    self.config['combine_star']):
                        from_imports[from_imports.index(from_import)] = import_definition
                    else:
                        import_statement = self._wrap(import_start + import_definition)
                        comments = self.comments['straight'].get(submodule)
                        import_statement = self._add_comments(comments, import_statement)
                        section_output.append(import_statement)
                        from_imports.remove(from_import)

            if from_imports:
                comments = self.comments['from'].pop(module, ())
                if "*" in from_imports and self.config['combine_star']:
                    import_statement = self._wrap(self._add_comments(comments, "{0}*".format(import_start)))
                elif self.config['force_single_line']:
                    import_statements = []
                    for from_import in from_imports:
                        single_import_line = self._add_comments(comments, import_start + from_import)
                        comment = self.comments['nested'].get(module, {}).pop(from_import, None)
                        if comment:
                            single_import_line += "{0} {1}".format(comments and ";" or "  #", comment)
                        import_statements.append(self._wrap(single_import_line))
                        comments = None
                    import_statement = "\n".join(import_statements)
                else:
                    star_import = False
                    if "*" in from_imports:
                        section_output.append(self._add_comments(comments, "{0}*".format(import_start)))
                        from_imports.remove('*')
                        star_import = True
                        comments = None

                    for from_import in copy.copy(from_imports):
                        comment = self.comments['nested'].get(module, {}).pop(from_import, None)
                        if comment:
                            single_import_line = self._add_comments(comments, import_start + from_import)
                            single_import_line += "{0} {1}".format(comments and ";" or "  #", comment)
                            above_comments = self.comments['above']['from'].pop(module, None)
                            if above_comments:
                                section_output.extend(above_comments)
                            section_output.append(self._wrap(single_import_line))
                            from_imports.remove(from_import)
                            comments = None

                    if star_import:
                        import_statement = import_start + (", ").join(from_imports)
                    else:
                        import_statement = self._add_comments(comments, import_start + (", ").join(from_imports))
                    if not from_imports:
                        import_statement = ""

                    do_multiline_reformat = False

                    if self.config.get('force_grid_wrap') and len(from_imports) > 1:
                        do_multiline_reformat = True

                    if len(import_statement) > self.config['line_length'] and len(from_imports) > 1:
                        do_multiline_reformat = True

                    # If line too long AND have imports AND we are NOT using GRID or VERTICAL wrap modes
                    if (len(import_statement) > self.config['line_length'] and len(from_imports) > 0
                        and self.config.get('multi_line_output', 0) not in (1, 0)):
                        do_multiline_reformat = True

                    if do_multiline_reformat:
                        output_mode = settings.WrapModes._fields[self.config.get('multi_line_output',
                                                                                    0)].lower()
                        formatter = getattr(self, "_output_" + output_mode, self._output_grid)
                        dynamic_indent = " " * (len(import_start) + 1)
                        indent = self.config['indent']
                        line_length = self.config['wrap_length'] or self.config['line_length']
                        import_statement = formatter(import_start, copy.copy(from_imports),
                                                    dynamic_indent, indent, line_length, comments)
                        if self.config['balanced_wrapping']:
                            lines = import_statement.split("\n")
                            line_count = len(lines)
                            if len(lines) > 1:
                                minimum_length = min([len(line) for line in lines[:-1]])
                            else:
                                minimum_length = 0
                            new_import_statement = import_statement
                            while (len(lines[-1]) < minimum_length and
                                    len(lines) == line_count and line_length > 10):
                                import_statement = new_import_statement
                                line_length -= 1
                                new_import_statement = formatter(import_start, copy.copy(from_imports),
                                                                dynamic_indent, indent, line_length, comments)
                                lines = new_import_statement.split("\n")

                    if not do_multiline_reformat and len(import_statement) > self.config['line_length']:
                        import_statement = self._wrap(import_statement)

                if import_statement:
                    above_comments = self.comments['above']['from'].pop(module, None)
                    if above_comments:
                        section_output.extend(above_comments)
                    section_output.append(import_statement)
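
Both loops above iterate over copy.copy(from_imports) so that entries can be removed from the real from_imports list mid-loop without skipping or breaking iteration. A minimal sketch of that snapshot-and-mutate pattern (the import names are illustrative):

    import copy

    from_imports = ["path", "sep", "walk"]

    for name in copy.copy(from_imports):   # iterate over a snapshot
        if name.startswith("s"):
            from_imports.remove(name)      # safe: the loop is not reading this list

    print(from_imports)   # ['path', 'walk']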

Example 12

Project: PyGaze
Source File: libsmi.py
View license
	def wait_for_fixation_start(self):

		"""Returns starting time and position when a fixation is started;
		function assumes a 'fixation' has started when gaze position
		remains reasonably stable (i.e. when most deviant samples are
		within self.pxfixtresh) for five samples in a row (self.pxfixtresh
		is created in self.calibration, based on self.fixtresh, a property
		defined in self.__init__)
		
		arguments
		None
		
		returns
		time, gazepos	-- time is the starting time in milliseconds (from
					   expstart), gazepos is a (x,y) gaze position
					   tuple of the position from which the fixation
					   was initiated
		"""
		
		# # # # #
		# SMI method

		if self.eventdetection == 'native':
			
			# print warning, since SMI does not have a fixation start
			# detection built into their API (only ending)
			
			print("WARNING! 'native' event detection has been selected, \
				but SMI does not offer fixation START detection (only \
				fixation ENDING; PyGaze algorithm will be used")
			
			
		# # # # #
		# PyGaze method
		
		# function assumes a 'fixation' has started when gaze position
		# remains reasonably stable for self.fixtimetresh
		
		# get starting position
		spos = self.sample()
		while not self.is_valid_sample(spos):
			spos = self.sample()
		
		# get starting time
		t0 = clock.get_time()

		# wait for reasonably stable position
		moving = True
		while moving:
			# get new sample
			npos = self.sample()
			# check if sample is valid
			if self.is_valid_sample(npos):
				# check if new sample is too far from starting position
				if (npos[0]-spos[0])**2 + (npos[1]-spos[1])**2 > self.pxfixtresh**2: # Pythagoras
					# if not, reset starting position and time
					spos = copy.copy(npos)
					t0 = clock.get_time()
				# if new sample is close to starting sample
				else:
					# get timestamp
					t1 = clock.get_time()
					# check if fixation time threshold has been surpassed
					if t1 - t0 >= self.fixtimetresh:
						# return time and starting position
						return t1, spos


	def wait_for_saccade_end(self):

		"""Returns ending time, starting and end position when a saccade is
		ended; based on Dalmaijer et al. (2013) online saccade detection
		algorithm
		
		arguments
		None
		
		returns
		endtime, startpos, endpos	-- endtime in milliseconds (from 
							   expbegintime); startpos and endpos
							   are (x,y) gaze position tuples
		"""

		# # # # #
		# SMI method

		if self.eventdetection == 'native':
			
			# print warning, since SMI does not have a blink detection
			# built into their API
			
			print("WARNING! 'native' event detection has been selected, \
				but SMI does not offer saccade detection; PyGaze \
				algorithm will be used")

		# # # # #
		# PyGaze method
		
		# get starting position (no blinks)
		t0, spos = self.wait_for_saccade_start()
		# get valid sample
		prevpos = self.sample()
		while not self.is_valid_sample(prevpos):
			prevpos = self.sample()
		# get starting time, intersample distance, and velocity
		t1 = clock.get_time()
		s = ((prevpos[0]-spos[0])**2 + (prevpos[1]-spos[1])**2)**0.5 # = intersample distance = speed in px/sample
		v0 = s / (t1-t0)

		# run until velocity and acceleration go below threshold
		saccadic = True
		while saccadic:
			# get new sample
			newpos = self.sample()
			t1 = clock.get_time()
			if self.is_valid_sample(newpos) and newpos != prevpos:
				# calculate distance
				s = ((newpos[0]-prevpos[0])**2 + (newpos[1]-prevpos[1])**2)**0.5 # = speed in pixels/sample
				# calculate velocity
				v1 = s / (t1-t0)
				# calculate acceleration
				a = (v1-v0) / (t1-t0) # acceleration in pixels/sample**2 (actually is v1-v0 / t1-t0; but t1-t0 = 1 sample)
				# check if velocity and acceleration are below threshold
				if v1 < self.pxspdtresh and (a > -1*self.pxacctresh and a < 0):
					saccadic = False
					epos = newpos[:]
					etime = clock.get_time()
				# update previous values
				t0 = copy.copy(t1)
				v0 = copy.copy(v1)
			# update previous sample
			prevpos = newpos[:]

		return etime, spos, epos


	def wait_for_saccade_start(self):

		"""Returns starting time and starting position when a saccade is
		started; based on Dalmaijer et al. (2013) online saccade detection
		algorithm
		
		arguments
		None
		
		returns
		endtime, startpos	-- endtime in milliseconds (from expbegintime);
					   startpos is an (x,y) gaze position tuple
		"""

		# # # # #
		# SMI method

		if self.eventdetection == 'native':
			
			# print warning, since SMI does not have a blink detection
			# built into their API
			
			print("WARNING! 'native' event detection has been selected, \
				but SMI does not offer saccade detection; PyGaze \
				algorithm will be used")

		# # # # #
		# PyGaze method
		
		# get starting position (no blinks)
		newpos = self.sample()
		while not self.is_valid_sample(newpos):
			newpos = self.sample()
		# get starting time, position, intersampledistance, and velocity
		t0 = clock.get_time()
		prevpos = newpos[:]
		s = 0
		v0 = 0

		# get samples
		saccadic = False
		while not saccadic:
			# get new sample
			newpos = self.sample()
			t1 = clock.get_time()
			if self.is_valid_sample(newpos) and newpos != prevpos:
				# check if distance is larger than precision error
				sx = newpos[0]-prevpos[0]; sy = newpos[1]-prevpos[1]
				if (sx/self.pxdsttresh[0])**2 + (sy/self.pxdsttresh[1])**2 > self.weightdist: # weighted distance: (sx/tx)**2 + (sy/ty)**2 > 1 means movement larger than RMS noise
					# calculate distance
					s = ((sx)**2 + (sy)**2)**0.5 # intersampledistance = speed in pixels/ms
					# calculate velocity
					v1 = s / (t1-t0)
					# calculate acceleration
					a = (v1-v0) / (t1-t0) # acceleration in pixels/ms**2
					# check if either velocity or acceleration are above threshold values
					if v1 > self.pxspdtresh or a > self.pxacctresh:
						saccadic = True
						spos = prevpos[:]
						stime = clock.get_time()
					# update previous values
					t0 = copy.copy(t1)
					v0 = copy.copy(v1)

				# update previous sample
				prevpos = newpos[:]

		return stime, spos
	
	
	def is_valid_sample(self, gazepos):
		
		"""Checks if the sample provided is valid, based on SMI specific
		criteria (for internal use)
		
		arguments
		gazepos		--	a (x,y) gaze position tuple, as returned by
						self.sample()
		
		returns
		valid			--	a Boolean: True on a valid sample, False on
						an invalid sample
		"""
		
		# return False if a sample is invalid
		if gazepos == (-1,-1):
			return False
		# sometimes, on SMI devices, invalid samples can actually contain
		# numbers; these do need to be caught as well
		elif sum(gazepos) < 10 and 0.0 in gazepos:
			return False
		
		# in any other case, the sample is valid
		return True
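
In the event-detection loops above, copy.copy snapshots the latest sample, timestamp, and velocity (spos = copy.copy(npos), t0 = copy.copy(t1), v0 = copy.copy(v1)) before those variables are reused on the next pass. For tuples and floats the copy behaves like plain assignment, but it makes the snapshot explicit. A minimal sketch with made-up sample values:

    import copy

    t0, v0 = 0.0, 0.0
    samples = [(10.0, 1.2), (20.0, 3.4), (30.0, 0.8)]   # (time, velocity) pairs

    for t1, v1 in samples:
        # ... velocity/acceleration threshold checks would go here ...
        t0 = copy.copy(t1)   # keep the previous time for the next pass;
        v0 = copy.copy(v1)   # for immutable floats this is effectively assignment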

Example 13

Project: pythonect
Source File: eval.py
View license
def _run_next_virtual_nodes(graph, node, globals_, locals_, flags, pool, result):

    operator = graph.node[node].get('OPERATOR', None)

    return_value = []

    not_safe_to_iter = False

    is_head_result = True

    head_result = None

    # "Hello, world" or {...}

    if isinstance(result, (basestring, dict)) or not __isiter(result):

        not_safe_to_iter = True

    # [[1]]

    if isinstance(result, list) and len(result) == 1 and isinstance(result[0], list):

        result = result[0]

        not_safe_to_iter = True

    # More nodes ahead?

    if operator:

        if not_safe_to_iter:

            logging.debug('not_safe_to_iter is True for %s' % result)

            head_result = result

            tmp_globals = copy.copy(globals_)

            tmp_locals = copy.copy(locals_)

            tmp_globals['_'] = tmp_locals['_'] = head_result

            return_value = __resolve_and_merge_results(_run(graph, node, tmp_globals, tmp_locals, {}, None, True))

        else:

            # Originally this was implemented using result[0] and result[1:] but xrange() is not slice-able, thus, I have changed it to `for` with buffer for 1st result

            for res_value in result:

                logging.debug('Now at %s from %s' % (res_value, result))

                if is_head_result:

                    logging.debug('is_head_result is True for %s' % res_value)

                    is_head_result = False

                    head_result = res_value

                    tmp_globals = copy.copy(globals_)

                    tmp_locals = copy.copy(locals_)

                    tmp_globals['_'] = tmp_locals['_'] = head_result

                    return_value.insert(0, _run(graph, node, tmp_globals, tmp_locals, {}, None, True))

                    continue

                tmp_globals = copy.copy(globals_)

                tmp_locals = copy.copy(locals_)

                tmp_globals['_'] = tmp_locals['_'] = res_value

                # Synchronous

                if operator == '|':

                    return_value.append(pool.apply(_run, args=(graph, node, tmp_globals, tmp_locals, {}, None, True)))

                # Asynchronous

                if operator == '->':

                    return_value.append(pool.apply_async(_run, args=(graph, node, tmp_globals, tmp_locals, {}, None, True)))

            pool.close()

            pool.join()

            pool.terminate()

            logging.debug('return_value = %s' % return_value)

            return_value = __resolve_and_merge_results(return_value)

    # Loopback

    else:

        # AS IS

        if not_safe_to_iter:

            return_value = [result]

        # Iterate for all possible *return values*

        else:

            for res_value in result:

                return_value.append(res_value)

            # Unbox

            if len(return_value) == 1:

                return_value = return_value[0]

    return return_value
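
Each branch above gets its own shallow copy of the globals/locals dictionaries before '_' is rebound, so sibling branches never see each other's current value while still sharing the underlying objects. A minimal sketch (run_branch is a stand-in for the real _run call, and the namespace contents are illustrative):

    import copy

    def run_branch(namespace):
        return namespace["_"] * 2            # stand-in for evaluating the next node

    globals_ = {"_": None, "config": {"debug": True}}

    results = []
    for value in [1, 2, 3]:
        tmp_globals = copy.copy(globals_)    # new dict, values still shared
        tmp_globals["_"] = value             # rebinding '_' does not touch globals_
        results.append(run_branch(tmp_globals))

    print(results)        # [2, 4, 6]
    print(globals_["_"])  # None -- untouched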

Example 14

Project: volumina
Source File: volumeEditorWidget.py
View license
    def _initShortcuts(self):
        # TODO: Fix this dependency on ImageView/HUD internals
        mgr = ShortcutManager()
        ActionInfo = ShortcutManager.ActionInfo
        mgr.register("x", ActionInfo( "Navigation", 
                                      "Minimize/Maximize x-Window", 
                                      "Minimize/Maximize x-Window", 
                                      self.quadview.switchXMinMax,
                                      self.editor.imageViews[0].hud.buttons['maximize'],
                                      self.editor.imageViews[0].hud.buttons['maximize'] ) )

        mgr.register("y", ActionInfo( "Navigation", 
                                      "Minimize/Maximize y-Window", 
                                      "Minimize/Maximize y-Window", 
                                      self.quadview.switchYMinMax,
                                      self.editor.imageViews[1].hud.buttons['maximize'],
                                      self.editor.imageViews[1].hud.buttons['maximize'] ) )

        mgr.register("z", ActionInfo( "Navigation", 
                                      "Minimize/Maximize z-Window", 
                                      "Minimize/Maximize z-Window", 
                                      self.quadview.switchZMinMax,
                                      self.editor.imageViews[2].hud.buttons['maximize'],
                                      self.editor.imageViews[2].hud.buttons['maximize'] ) )

        for i, v in enumerate(self.editor.imageViews):
            mgr.register("+", ActionInfo( "Navigation", 
                                          "Zoom in", 
                                          "Zoom in", 
                                          v.zoomIn,
                                          v,
                                          None) )
            mgr.register("-", ActionInfo( "Navigation", 
                                          "Zoom out", 
                                          "Zoom out", 
                                          v.zoomOut,
                                          v,
                                          None) )

            mgr.register("c", ActionInfo( "Navigation", 
                                          "Center image", 
                                          "Center image", 
                                          v.centerImage,
                                          v,
                                          None) )

            mgr.register("h", ActionInfo( "Navigation", 
                                          "Toggle hud", 
                                          "Toggle hud", 
                                          v.toggleHud,
                                          v,
                                          None) )

            # FIXME: The nextChannel/previousChannel functions don't work right now.
            #self._shortcutHelper("q", "Navigation", "Switch to next channel",     v, self.editor.nextChannel,     Qt.WidgetShortcut))
            #self._shortcutHelper("a", "Navigation", "Switch to previous channel", v, self.editor.previousChannel, Qt.WidgetShortcut))
            
            def sliceDelta(axis, delta):
                newPos = copy.copy(self.editor.posModel.slicingPos)
                newPos[axis] += delta
                newPos[axis] = max(0, newPos[axis]) 
                newPos[axis] = min(self.editor.posModel.shape[axis]-1, newPos[axis]) 
                self.editor.posModel.slicingPos = newPos

            def jumpToFirstSlice(axis):
                newPos = copy.copy(self.editor.posModel.slicingPos)
                newPos[axis] = 0
                self.editor.posModel.slicingPos = newPos
                
            def jumpToLastSlice(axis):
                newPos = copy.copy(self.editor.posModel.slicingPos)
                newPos[axis] = self.editor.posModel.shape[axis]-1
                self.editor.posModel.slicingPos = newPos
            
            # TODO: Fix this dependency on ImageView/HUD internals
            mgr.register("Ctrl+Up", ActionInfo( "Navigation", 
                                                "Slice up", 
                                                "Slice up", 
                                                partial(sliceDelta, i, 1),
                                                v,
                                                v.hud.buttons['slice'].upLabel) )

            mgr.register("Ctrl+Down", ActionInfo( "Navigation", 
                                                  "Slice up", 
                                                  "Slice up", 
                                                  partial(sliceDelta, i, -1),
                                                  v,
                                                  v.hud.buttons['slice'].downLabel) )

#            self._shortcutHelper("p", "Navigation", "Slice up (alternate shortcut)",   v, partial(sliceDelta, i, 1),  Qt.WidgetShortcut)
#            self._shortcutHelper("o", "Navigation", "Slice down (alternate shortcut)", v, partial(sliceDelta, i, -1), Qt.WidgetShortcut)
            
            mgr.register("Ctrl+Shift+Up", ActionInfo( "Navigation", 
                                                      "10 slices up", 
                                                      "10 slices up", 
                                                      partial(sliceDelta, i, 10),
                                                      v,
                                                      None) )

            mgr.register("Ctrl+Shift+Down", ActionInfo( "Navigation", 
                                                        "10 slices down", 
                                                        "10 slices down", 
                                                        partial(sliceDelta, i, -10),
                                                        v,
                                                        None) )

            mgr.register("Shift+Up", ActionInfo( "Navigation", 
                                                 "Jump to first slice", 
                                                 "Jump to first slice", 
                                                 partial(jumpToFirstSlice, i),
                                                 v,
                                                 None) )

            mgr.register("Shift+Down", ActionInfo( "Navigation", 
                                                   "Jump to last slice", 
                                                   "Jump to last slice", 
                                                   partial(jumpToLastSlice, i),
                                                   v,
                                                   None) )
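
sliceDelta and the jump helpers copy posModel.slicingPos, edit the copy, and assign it back, so the slicingPos property setter receives a fresh object and can react to the change. A minimal sketch of that read-copy-write pattern (PosModel here is a stand-in class, not the volumina implementation):

    import copy

    class PosModel(object):
        def __init__(self):
            self._pos = [0, 0, 0]

        @property
        def slicingPos(self):
            return self._pos

        @slicingPos.setter
        def slicingPos(self, value):
            self._pos = value
            print("position changed to %s" % (value,))   # e.g. emit a signal here

    model = PosModel()
    newPos = copy.copy(model.slicingPos)   # work on a copy, not the live list
    newPos[2] += 1
    model.slicingPos = newPos              # setter fires with the updated copy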

Example 15

Project: web2py-appreport
Source File: pisa_parser.py
View license
def pisaLoop(node, c, path=[], **kw):

    # Initialize KW
    if not kw:
        kw = {
            "margin-top": 0,
            "margin-bottom": 0,
            "margin-left": 0,
            "margin-right": 0,
            }
    else:
        kw = copy.copy(kw)
        
    indent = len(path) * "  "

    # TEXT
    if node.nodeType == Node.TEXT_NODE:
        # print indent, "#", repr(node.data) #, c.frag
        c.addFrag(node.data)
        # c.text.append(node.value)
       
    # ELEMENT
    elif node.nodeType == Node.ELEMENT_NODE:  
        
        node.tagName = node.tagName.replace(":", "").lower()
        
        if node.tagName in ("style", "script"):
            return
        
        path = copy.copy(path) + [node.tagName]
        
        # Prepare attributes        
        attr = pisaGetAttributes(c, node.tagName, node.attributes)        
        # log.debug(indent + "<%s %s>" % (node.tagName, attr) + repr(node.attributes.items())) #, path
        
        # Calculate styles                
        c.cssAttr = CSSCollect(node, c)
        c.node = node

        # Block?    
        PAGE_BREAK = 1
        PAGE_BREAK_RIGHT = 2
        PAGE_BREAK_LEFT = 3

        pageBreakAfter = False
        frameBreakAfter = False
        display = c.cssAttr.get("display", "inline").lower()
        # print indent, node.tagName, display, c.cssAttr.get("background-color", None), attr
        isBlock = (display == "block")
        if isBlock:
            c.addPara()

            # Page break by CSS
            if c.cssAttr.has_key("-pdf-next-page"):                 
                c.addStory(NextPageTemplate(str(c.cssAttr["-pdf-next-page"])))
            if c.cssAttr.has_key("-pdf-page-break"):
                if str(c.cssAttr["-pdf-page-break"]).lower() == "before":
                    c.addStory(PageBreak()) 
            if c.cssAttr.has_key("-pdf-frame-break"): 
                if str(c.cssAttr["-pdf-frame-break"]).lower() == "before":
                    c.addStory(FrameBreak()) 
                if str(c.cssAttr["-pdf-frame-break"]).lower() == "after":
                    frameBreakAfter = True
            if c.cssAttr.has_key("page-break-before"):            
                if str(c.cssAttr["page-break-before"]).lower() == "always":
                    c.addStory(PageBreak()) 
                if str(c.cssAttr["page-break-before"]).lower() == "right":
                    c.addStory(PageBreak()) 
                    c.addStory(PmlRightPageBreak())
                if str(c.cssAttr["page-break-before"]).lower() == "left":
                    c.addStory(PageBreak()) 
                    c.addStory(PmlLeftPageBreak())
            if c.cssAttr.has_key("page-break-after"):            
                if str(c.cssAttr["page-break-after"]).lower() == "always":
                    pageBreakAfter = PAGE_BREAK
                if str(c.cssAttr["page-break-after"]).lower() == "right":
                    pageBreakAfter = PAGE_BREAK_RIGHT
                if str(c.cssAttr["page-break-after"]).lower() == "left":
                    pageBreakAfter = PAGE_BREAK_LEFT
            
        if display == "none":
            # print "none!"
            return
        
        # Translate CSS to frags 

        # Save previous frag styles
        c.pushFrag()
        
        # Map styles to Reportlab fragment properties
        CSS2Frag(c, kw, isBlock) 
                          
        # EXTRAS
        if c.cssAttr.has_key("-pdf-keep-with-next"):
            c.frag.keepWithNext = getBool(c.cssAttr["-pdf-keep-with-next"])
        if c.cssAttr.has_key("-pdf-outline"):
            c.frag.outline = getBool(c.cssAttr["-pdf-outline"])
        if c.cssAttr.has_key("-pdf-outline-level"):
            c.frag.outlineLevel = int(c.cssAttr["-pdf-outline-level"])
        if c.cssAttr.has_key("-pdf-outline-open"):
            c.frag.outlineOpen = getBool(c.cssAttr["-pdf-outline-open"])
        #if c.cssAttr.has_key("-pdf-keep-in-frame-max-width"):
        #    c.frag.keepInFrameMaxWidth = getSize("".join(c.cssAttr["-pdf-keep-in-frame-max-width"]))
        #if c.cssAttr.has_key("-pdf-keep-in-frame-max-height"):
        #    c.frag.keepInFrameMaxHeight = getSize("".join(c.cssAttr["-pdf-keep-in-frame-max-height"]))
        if c.cssAttr.has_key("-pdf-keep-in-frame-mode"):
            value = str(c.cssAttr["-pdf-keep-in-frame-mode"]).strip().lower()
            if value not in ("shrink", "error", "overflow", "shrink", "truncate"):
                value = None
            c.frag.keepInFrameMode = value
                
        # BEGIN tag
        klass = globals().get("pisaTag%s" % node.tagName.replace(":", "").upper(), None)
        obj = None      

        # Static block
        elementId = attr.get("id", None)             
        staticFrame = c.frameStatic.get(elementId, None)
        if staticFrame:
            c.frag.insideStaticFrame += 1
            oldStory = c.swapStory()
                  
        # Tag specific operations
        if klass is not None:        
            obj = klass(node, attr)
            obj.start(c)
            
        # Visit child nodes
        c.fragBlock = fragBlock = copy.copy(c.frag)        
        for nnode in node.childNodes:
            pisaLoop(nnode, c, path, **kw)        
        c.fragBlock = fragBlock
                            
        # END tag
        if obj:
            obj.end(c)

        # Block?
        if isBlock:
            c.addPara()

            # XXX Buggy!

            # Page break by CSS
            if pageBreakAfter:
                c.addStory(PageBreak()) 
                if pageBreakAfter == PAGE_BREAK_RIGHT:
                    c.addStory(PmlRightPageBreak())
                if pageBreakAfter == PAGE_BREAK_LEFT:
                    c.addStory(PmlLeftPageBreak())
            if frameBreakAfter:                
                c.addStory(FrameBreak()) 

        # Static block, END
        if staticFrame:
            c.addPara()
            for frame in staticFrame:
                frame.pisaStaticStory = c.story            
            c.swapStory(oldStory)
            c.frag.insideStaticFrame -= 1
            
        # c.debug(1, indent, "</%s>" % (node.tagName))
        
        # Reset frag style                   
        c.pullFrag()                                    

    # Unknown or not handled
    else:
        # c.debug(1, indent, "???", node, node.nodeType, repr(node))
        # Loop over children
        for node in node.childNodes:
            pisaLoop(node, c, path, **kw)
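
pisaLoop re-copies both kw and path on the way down, so margin tweaks and tag names accumulated for one subtree never leak back into the caller's dict or list. A stripped-down sketch of that copy-on-descend pattern (the node layout here is a plain dict, not pisa's DOM):

import copy

def walk(node, path=[], **kw):
    # Shallow-copy the inherited keyword dict and the path list so that
    # changes made for this subtree stay local to this branch.
    kw = copy.copy(kw) if kw else {"margin_top": 0}
    path = copy.copy(path) + [node["tag"]]
    kw["margin_top"] += node.get("extra_margin", 0)
    print(" > ".join(path), kw["margin_top"])
    for child in node.get("children", []):
        walk(child, path, **kw)

walk({"tag": "body",
      "children": [{"tag": "p", "extra_margin": 5},
                   {"tag": "div"}]})
# body 0
# body > p 5
# body > div 0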

Example 17

Project: nginx-amplify-agent
Source File: config.py
View license
    def __collect_data(self, subtree=None, ctx=None):
        """
        Searches needed data in config's tree

        :param subtree: dict with tree to parse
        :param ctx: dict with context
        """
        ctx = ctx if ctx is not None else {}
        subtree = subtree if subtree is not None else {}

        for key, value in subtree.iteritems():
            if key == 'error_log':
                error_logs = value if isinstance(value, list) else [value]
                for er_log_definition in error_logs:
                    if er_log_definition == 'off':
                        continue

                    split_er_log_definition = er_log_definition.split(' ')
                    log_name = split_er_log_definition[0]
                    log_level = split_er_log_definition[-1] \
                        if split_er_log_definition[-1] in ERROR_LOG_LEVELS else 'error'  # nginx default log level
                    log_name = re.sub('[\'"]', '', log_name)  # remove all ' and "

                    # if not syslog, assume it is a file...if not starts with '/' assume relative path
                    if not log_name.startswith('syslog') and not log_name.startswith('/'):
                        log_name = '%s/%s' % (self.prefix, log_name)

                    if log_name not in self.error_logs:
                        self.error_logs[log_name] = log_level
            elif key == 'access_log':
                access_logs = value if isinstance(value, list) else [value]
                for ac_log_definition in access_logs:
                    if ac_log_definition == 'off':
                        continue

                    parts = filter(len, ac_log_definition.split(' '))
                    log_format = None if len(parts) == 1 else parts[1]
                    log_name = parts[0]
                    log_name = re.sub('[\'"]', '', log_name)  # remove all ' and "

                    # if not syslog, assume it is a file...if not starts with '/' assume relative path
                    if not log_name.startswith('syslog') and not log_name.startswith('/'):
                        log_name = '%s/%s' % (self.prefix, log_name)

                    self.access_logs[log_name] = log_format
            elif key == 'log_format':
                for k, v in value.iteritems():
                    self.log_formats[k] = v
            elif key == 'server' and isinstance(value, list) and 'upstream' not in ctx:
                for server in value:

                    current_ctx = copy.copy(ctx)
                    if server.get('listen') is None:
                        # if no listens specified, then use default *:80 and *:8000
                        listen = ['80', '8000']
                    else:
                        listen = server.get('listen')
                    listen = listen if isinstance(listen, list) else [listen]

                    ctx['ip_port'] = []
                    for item in listen:
                        listen_first_part = item.split(' ')[0]
                        try:
                            addr, port = self.__parse_listen(listen_first_part)
                            if addr in ('*', '0.0.0.0'):
                                addr = '127.0.0.1'
                            elif addr == '[::]':
                                addr = '[::1]'
                            ctx['ip_port'].append((addr, port))
                        except Exception as e:
                            context.log.error('failed to parse bad ipv6 listen directive: %s' % listen_first_part)
                            context.log.debug('additional info:', exc_info=True)

                    if 'server_name' in server:
                        ctx['server_name'] = server.get('server_name')

                    self.__collect_data(subtree=server, ctx=ctx)
                    ctx = current_ctx
            elif key == 'upstream':
                for upstream, upstream_info in value.iteritems():
                    current_ctx = copy.copy(ctx)
                    ctx['upstream'] = upstream
                    self.__collect_data(subtree=upstream_info, ctx=ctx)
                    ctx = current_ctx
            elif key == 'location':
                for location, location_info in value.iteritems():
                    current_ctx = copy.copy(ctx)
                    ctx['location'] = location
                    self.__collect_data(subtree=location_info, ctx=ctx)
                    ctx = current_ctx
            elif key == 'stub_status' and ctx and 'ip_port' in ctx:
                for url in self.__status_url(ctx):
                    if url not in self.stub_status_urls:
                        self.stub_status_urls.append(url)
            elif key == 'status' and ctx and 'ip_port' in ctx:
                # use different url builders for external and internal urls
                for url in self.__status_url(ctx, server_preferred=True):
                    if url not in self.plus_status_external_urls:
                        self.plus_status_external_urls.append(url)

                # for internal (agent) usage local ip address is a better choice,
                # because the external url might not be accessible from a host
                for url in self.__status_url(ctx, server_preferred=False):
                    if url not in self.plus_status_internal_urls:
                        self.plus_status_internal_urls.append(url)
            elif isinstance(value, dict):
                self.__collect_data(subtree=value, ctx=ctx)
            elif isinstance(value, list):
                for next_subtree in value:
                    if isinstance(next_subtree, dict):
                        self.__collect_data(subtree=next_subtree, ctx=ctx)
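
For server, upstream and location blocks, __collect_data saves a shallow copy of ctx, enriches the working context for the nested subtree, recurses, and then restores the saved copy. Note that copy.copy is shallow, so nested containers such as ctx['ip_port'] remain shared between the saved and working dicts. A compact sketch of that save/extend/restore dance over a toy config tree:

import copy

def collect(subtree, ctx=None, found=None):
    # Save a shallow copy of the surrounding context, extend it for the
    # nested block, recurse, then restore the saved copy afterwards.
    ctx = ctx if ctx is not None else {}
    found = found if found is not None else []
    for key, value in subtree.items():
        if key == "location" and isinstance(value, dict):
            for location, body in value.items():
                saved_ctx = copy.copy(ctx)
                ctx["location"] = location
                collect(body, ctx, found)
                ctx = saved_ctx
        elif key == "stub_status":
            found.append(dict(ctx))
    return found

tree = {"location": {"/status": {"stub_status": True}, "/app": {}}}
print(collect(tree))   # [{'location': '/status'}]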

Example 19

Project: dl4mt-cdec
Source File: char_biscale.py
View license
def gen_sample(tparams, f_init, f_next, x, options, trng=None,
               k=1, maxlen=500, stochastic=True, argmax=False):

    # k is the beam size we have
    if k > 1:
        assert not stochastic, \
            'Beam search does not support stochastic sampling'

    sample = []
    sample_score = []
    if stochastic:
        sample_score = 0

    live_k = 1
    dead_k = 0

    hyp_samples = [[]] * live_k
    hyp_scores = numpy.zeros(live_k).astype('float32')
    hyp_states = []

    # get initial state of decoder rnn and encoder context
    ret = f_init(x)
    next_state_char, next_state_word, ctx0 = ret[0], ret[1], ret[2]
    next_bound_char = numpy.zeros((1, options['dec_dim'])).astype('float32')
    next_bound_word = numpy.zeros((1, options['dec_dim'])).astype('float32')
    next_w = -1 * numpy.ones((1,)).astype('int64')  # bos indicator

    for ii in xrange(maxlen):
        ctx = numpy.tile(ctx0, [live_k, 1])
        inps = [next_w, ctx, next_state_char, next_state_word, next_bound_char, next_bound_word]
        ret = f_next(*inps)
        next_p, next_w, next_state_char, next_state_word, next_bound_char, next_bound_word = ret[0], ret[1], ret[2], ret[3], ret[4], ret[5]

        if stochastic:
            if argmax:
                nw = next_p[0].argmax()
            else:
                nw = next_w[0]
            sample.append(nw)
            sample_score += next_p[0, nw]
            if nw == 0:
                break
        else:
            cand_scores = hyp_scores[:, None] - numpy.log(next_p)
            cand_flat = cand_scores.flatten()
            ranks_flat = cand_flat.argsort()[:(k-dead_k)]

            voc_size = next_p.shape[1]
            trans_indices = ranks_flat / voc_size
            word_indices = ranks_flat % voc_size
            costs = cand_flat[ranks_flat]

            new_hyp_samples = []
            new_hyp_scores = numpy.zeros(k-dead_k).astype('float32')
            new_hyp_states_char = []
            new_hyp_states_word = []
            new_hyp_bounds_char = []
            new_hyp_bounds_word = []

            for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
                new_hyp_samples.append(hyp_samples[ti]+[wi])
                new_hyp_scores[idx] = copy.copy(costs[idx])
                new_hyp_states_char.append(copy.copy(next_state_char[ti]))
                new_hyp_states_word.append(copy.copy(next_state_word[ti]))
                new_hyp_bounds_char.append(copy.copy(next_bound_char[ti]))
                new_hyp_bounds_word.append(copy.copy(next_bound_word[ti]))

            # check the finished samples
            new_live_k = 0
            hyp_samples = []
            hyp_scores = []
            hyp_states_char = []
            hyp_states_word = []
            hyp_bounds_char = []
            hyp_bounds_word = []

            for idx in xrange(len(new_hyp_samples)):
                if new_hyp_samples[idx][-1] == 0:
                    sample.append(new_hyp_samples[idx])
                    sample_score.append(new_hyp_scores[idx])
                    dead_k += 1
                else:
                    new_live_k += 1
                    hyp_samples.append(new_hyp_samples[idx])
                    hyp_scores.append(new_hyp_scores[idx])
                    hyp_states_char.append(new_hyp_states_char[idx])
                    hyp_states_word.append(new_hyp_states_word[idx])
                    hyp_bounds_char.append(new_hyp_bounds_char[idx])
                    hyp_bounds_word.append(new_hyp_bounds_word[idx])
            hyp_scores = numpy.array(hyp_scores)
            live_k = new_live_k

            if new_live_k < 1:
                break
            if dead_k >= k:
                break

            next_w = numpy.array([w[-1] for w in hyp_samples])
            next_state_char = numpy.array(hyp_states_char)
            next_state_word = numpy.array(hyp_states_word)
            next_bound_char = numpy.array(hyp_bounds_char)
            next_bound_word = numpy.array(hyp_bounds_word)

    if not stochastic:
        # dump every remaining one
        if live_k > 0:
            for idx in xrange(live_k):
                sample.append(hyp_samples[idx])
                sample_score.append(hyp_scores[idx])

    return sample, sample_score
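
Inside the beam-search loop, copy.copy is applied to rows of the NumPy state matrices (next_state_char[ti] and friends). NumPy arrays implement __copy__, so copy.copy returns an independent array rather than a view, and the per-hypothesis snapshots survive when the next_state_* arrays are rebuilt on the following iteration. A tiny demonstration of the view-versus-copy difference:

import copy
import numpy

states = numpy.arange(6, dtype="float32").reshape(2, 3)

view = states[0]                  # a view: shares memory with `states`
snapshot = copy.copy(states[0])   # independent copy via ndarray.__copy__

states[0, 0] = 99.0
print(view[0])        # 99.0 -- the view tracks the original buffer
print(snapshot[0])    # 0.0  -- the copy kept the old value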

Example 20

Project: dl4mt-multi
Source File: sampling.py
View license
def gen_sample(f_init, f_next, x, src_selector, trg_selector, k=1,
               maxlen=30, stochastic=True, argmax=False, eos_idx=0,
               cond_init_trg=False, ignore_unk=False, minlen=1, unk_idx=1,
               f_next_state=None, return_alphas=False):
    if k > 1:
        assert not stochastic, \
            'Beam search does not support stochastic sampling'

    sample = []
    sample_score = []
    sample_decalphas = []
    if stochastic:
        sample_score = 0

    live_k = 1
    dead_k = 0

    hyp_samples = [[]] * live_k
    hyp_decalphas = []
    hyp_scores = numpy.zeros(live_k).astype('float32')
    hyp_states = []

    # multi-source
    inp_xs = [x]
    init_inps = inp_xs

    ret = f_init(*init_inps)
    next_state, ctx0 = ret[0], ret[1]
    next_w = -1 * numpy.ones((1,)).astype('int64')

    for ii in range(maxlen):
        ctx = numpy.tile(ctx0, [live_k, 1])

        prev_w = copy.copy(next_w)
        prev_state = copy.copy(next_state)
        inps = [next_w, ctx, next_state]

        ret = f_next(*inps)
        next_p, next_w, next_state = ret[0], ret[1], ret[2]

        if return_alphas:
            next_decalpha = ret.pop(0)

        if stochastic:
            if argmax:
                nw = next_p[0].argmax()
            else:
                nw = next_w[0]
            sample.append(nw)
            sample_score -= numpy.log(next_p[0, nw])
            if nw == eos_idx:
                break
        else:
            log_probs = numpy.log(next_p)

            # Adjust log probs according to search restrictions
            if ignore_unk:
                log_probs[:, unk_idx] = -numpy.inf
            if ii < minlen:
                log_probs[:, eos_idx] = -numpy.inf

            cand_scores = hyp_scores[:, None] - numpy.log(next_p)
            cand_flat = cand_scores.flatten()
            ranks_flat = cand_flat.argsort()[:(k-dead_k)]

            voc_size = next_p.shape[1]
            trans_indices = ranks_flat / voc_size
            word_indices = ranks_flat % voc_size
            costs = cand_flat[ranks_flat]

            new_hyp_samples = []
            new_hyp_scores = numpy.zeros(k-dead_k).astype('float32')
            new_hyp_states = []
            new_hyp_decalphas = []

            for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
                new_hyp_samples.append(hyp_samples[ti]+[wi])
                new_hyp_scores[idx] = copy.copy(costs[idx])
                new_hyp_states.append(copy.copy(next_state[ti]))

                if return_alphas:
                    tmp_decalphas = []
                    if ii > 0:
                        tmp_decalphas = copy.copy(hyp_decalphas[ti])
                    tmp_decalphas.append(next_decalpha[ti])
                    new_hyp_decalphas.append(tmp_decalphas)

            # check the finished samples
            new_live_k = 0
            hyp_samples = []
            hyp_scores = []
            hyp_states = []
            hyp_decalphas = []

            for idx in range(len(new_hyp_samples)):
                if new_hyp_samples[idx][-1] == eos_idx:
                    sample.append(new_hyp_samples[idx])
                    sample_score.append(new_hyp_scores[idx])
                    if return_alphas:
                        sample_decalphas.append(new_hyp_decalphas[idx])
                    dead_k += 1
                else:
                    new_live_k += 1
                    hyp_samples.append(new_hyp_samples[idx])
                    hyp_scores.append(new_hyp_scores[idx])
                    hyp_states.append(new_hyp_states[idx])
                    if return_alphas:
                        hyp_decalphas.append(new_hyp_decalphas[idx])
            hyp_scores = numpy.array(hyp_scores)
            live_k = new_live_k

            if new_live_k < 1:
                break
            if dead_k >= k:
                break

            next_w = numpy.array([w[-1] for w in hyp_samples])
            next_state = numpy.array(hyp_states)

    if not stochastic:
        # dump every remaining one
        if live_k > 0:
            for idx in range(live_k):
                sample.append(hyp_samples[idx])
                sample_score.append(hyp_scores[idx])
                if return_alphas:
                    sample_decalphas.append(hyp_decalphas[idx])

    if not return_alphas:
        return numpy.array(sample), numpy.array(sample_score)
    return numpy.array(sample), numpy.array(sample_score), \
        numpy.array(sample_decalphas)
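
Example 20 uses the same row-copying trick and additionally shallow-copies each hypothesis' attention history (copy.copy(hyp_decalphas[ti])) before appending to it. Without the copy, two hypotheses branching from the same parent would append to one shared list. A minimal illustration of that branching pattern:

import copy

parent_history = [0.1, 0.2]

# Two child hypotheses branch from the same parent: each gets its own
# shallow copy of the history before extending it.
child_a = copy.copy(parent_history)
child_b = copy.copy(parent_history)
child_a.append(0.7)
child_b.append(0.3)

print(parent_history)   # [0.1, 0.2] -- untouched
print(child_a)          # [0.1, 0.2, 0.7]
print(child_b)          # [0.1, 0.2, 0.3]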

Example 21

Project: pyregion
Source File: mpl_helper.py
View license
def as_mpl_artists(shape_list,
                   properties_func=None,
                   text_offset=5.0, origin=1):
    """
    Converts a region list to a list of patches and a list of artists.


    Optional Keywords:
    [ text_offset ] - If there is text associated with the regions, add
    some vertical offset (in pixels) to the text so that it doesn't overlap
    with the regions.

    Often, the regions files implicitly assume the lower-left corner
    of the image as a coordinate (1,1). However, the python convention
    is that the array index starts from 0. By default (origin = 1),
    coordinates of the returned mpl artists are shifted by
    (1, 1). If you do not want this shift, set origin=0.
    """

    patch_list = []
    artist_list = []

    if properties_func is None:
        properties_func = properties_func_default

    # properties for continued(? multiline?) regions
    saved_attrs = None

    for shape in shape_list:

        patches = []

        if saved_attrs is None:
            _attrs = [], {}
        else:
            _attrs = copy.copy(saved_attrs[0]), copy.copy(saved_attrs[1])

        kwargs = properties_func(shape, _attrs)

        if shape.name == "composite":
            saved_attrs = shape.attr
            continue

        if saved_attrs is None and shape.continued:
            saved_attrs = shape.attr
        #         elif (shape.name in shape.attr[1]):
        #             if (shape.attr[1][shape.name] != "ignore"):
        #                 saved_attrs = shape.attr

        if not shape.continued:
            saved_attrs = None

        # text associated with the shape
        txt = shape.attr[1].get("text")

        if shape.name == "polygon":
            xy = np.array(shape.coord_list)
            xy.shape = -1, 2

            # -1 for change origin to 0,0
            patches = [mpatches.Polygon(xy - origin, closed=True, **kwargs)]

        elif shape.name == "rotbox" or shape.name == "box":
            xc, yc, w, h, rot = shape.coord_list
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin
            _box = np.array([[-w / 2., -h / 2.],
                             [-w / 2., h / 2.],
                             [w / 2., h / 2.],
                             [w / 2., -h / 2.]])
            box = _box + [xc, yc]
            rotbox = rotated_polygon(box, xc, yc, rot)
            patches = [mpatches.Polygon(rotbox, closed=True, **kwargs)]

        elif shape.name == "ellipse":
            xc, yc = shape.coord_list[:2]
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin
            angle = shape.coord_list[-1]

            maj_list, min_list = shape.coord_list[2:-1:2], shape.coord_list[3:-1:2]

            patches = [mpatches.Ellipse((xc, yc), 2 * maj, 2 * min,
                                        angle=angle, **kwargs)
                       for maj, min in zip(maj_list, min_list)]

        elif shape.name == "annulus":
            xc, yc = shape.coord_list[:2]
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin
            r_list = shape.coord_list[2:]

            patches = [mpatches.Ellipse((xc, yc), 2 * r, 2 * r, **kwargs) for r in r_list]

        elif shape.name == "circle":
            xc, yc, major = shape.coord_list
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin
            patches = [mpatches.Ellipse((xc, yc), 2 * major, 2 * major, angle=0, **kwargs)]

        elif shape.name == "panda":
            xc, yc, a1, a2, an, r1, r2, rn = shape.coord_list
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin
            patches = [mpatches.Arc((xc, yc), rr * 2, rr * 2, angle=0,
                                    theta1=a1, theta2=a2, **kwargs)
                       for rr in np.linspace(r1, r2, rn + 1)]

            for aa in np.linspace(a1, a2, an + 1):
                xx = np.array([r1, r2]) * np.cos(aa / 180. * np.pi) + xc
                yy = np.array([r1, r2]) * np.sin(aa / 180. * np.pi) + yc
                p = Path(np.transpose([xx, yy]))
                patches.append(mpatches.PathPatch(p, **kwargs))

        elif shape.name == "pie":
            xc, yc, r1, r2, a1, a2 = shape.coord_list
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin

            patches = [mpatches.Arc((xc, yc), rr * 2, rr * 2, angle=0,
                                    theta1=a1, theta2=a2, **kwargs)
                       for rr in [r1, r2]]

            for aa in [a1, a2]:
                xx = np.array([r1, r2]) * np.cos(aa / 180. * np.pi) + xc
                yy = np.array([r1, r2]) * np.sin(aa / 180. * np.pi) + yc
                p = Path(np.transpose([xx, yy]))
                patches.append(mpatches.PathPatch(p, **kwargs))

        elif shape.name == "epanda":
            xc, yc, a1, a2, an, r11, r12, r21, r22, rn, angle = shape.coord_list
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin

            # mpl takes angle a1, a2 as angle as in circle before
            # transformation to ellipse.

            x1, y1 = cos(a1 / 180. * pi), sin(a1 / 180. * pi) * r11 / r12
            x2, y2 = cos(a2 / 180. * pi), sin(a2 / 180. * pi) * r11 / r12

            a1, a2 = atan2(y1, x1) / pi * 180., atan2(y2, x2) / pi * 180.

            patches = [mpatches.Arc((xc, yc), rr1 * 2, rr2 * 2,
                                    angle=angle, theta1=a1, theta2=a2,
                                    **kwargs)
                       for rr1, rr2 in zip(np.linspace(r11, r21, rn + 1),
                                           np.linspace(r12, r22, rn + 1))]

            for aa in np.linspace(a1, a2, an + 1):
                xx = np.array([r11, r21]) * np.cos(aa / 180. * np.pi)
                yy = np.array([r11, r21]) * np.sin(aa / 180. * np.pi)
                p = Path(np.transpose([xx, yy]))
                tr = Affine2D().scale(1, r12 / r11).rotate_deg(angle).translate(xc, yc)
                p2 = tr.transform_path(p)
                patches.append(mpatches.PathPatch(p2, **kwargs))

        elif shape.name == "text":
            xc, yc = shape.coord_list[:2]
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin

            if txt:
                _t = _get_text(txt, xc, yc, 0, 0, **kwargs)
                artist_list.append(_t)

        elif shape.name == "point":
            xc, yc = shape.coord_list[:2]
            # -1 for change origin to 0,0
            xc, yc = xc - origin, yc - origin
            artist_list.append(Line2D([xc], [yc],
                                      **kwargs))

            if txt:
                textshape = copy.copy(shape)
                textshape.name = "text"
                textkwargs = properties_func(textshape, _attrs)
                _t = _get_text(txt, xc, yc, 0, text_offset,
                               va="bottom",
                               **textkwargs)
                artist_list.append(_t)

        elif shape.name in ["line", "vector"]:
            if shape.name == "line":
                x1, y1, x2, y2 = shape.coord_list[:4]
                # -1 for change origin to 0,0
                x1, y1, x2, y2 = x1 - origin, y1 - origin, x2 - origin, y2 - origin

                a1, a2 = shape.attr[1].get("line", "0 0").strip().split()[:2]

                arrowstyle = "-"
                if int(a1):
                    arrowstyle = "<" + arrowstyle
                if int(a2):
                    arrowstyle = arrowstyle + ">"

            else:  # shape.name == "vector"
                x1, y1, l, a = shape.coord_list[:4]
                # -1 for change origin to 0,0
                x1, y1 = x1 - origin, y1 - origin
                x2, y2 = x1 + l * np.cos(a / 180. * np.pi), y1 + l * np.sin(a / 180. * np.pi)
                v1 = int(shape.attr[1].get("vector", "0").strip())

                if v1:
                    arrowstyle = "->"
                else:
                    arrowstyle = "-"

            patches = [mpatches.FancyArrowPatch(posA=(x1, y1),
                                                posB=(x2, y2),
                                                arrowstyle=arrowstyle,
                                                arrow_transmuter=None,
                                                connectionstyle="arc3",
                                                patchA=None, patchB=None,
                                                shrinkA=0, shrinkB=0,
                                                connector=None,
                                                **kwargs)]

        else:
            warnings.warn("'as_mpl_artists' does not know how to convert {0} "
                          "to mpl artist".format(shape.name))

        patch_list.extend(patches)

        if txt and patches:
            # the text associated with a shape uses different
            # matplotlib keywords than the shape itself for, e.g.,
            # color
            textshape = copy.copy(shape)
            textshape.name = "text"
            textkwargs = properties_func(textshape, _attrs)

            # calculate the text position
            _bb = [p.get_window_extent() for p in patches]

            # this is to work around backward-incompatible change made
            # in matplotlib 1.2. This change is later reverted so only
            # some versions are affected. With affected version of
            # matplotlib, get_window_extent method calls get_transform
            # method which sets the _transformSet to True, which is
            # not desired.
            for p in patches:
                p._transformSet = False

            _bbox = Bbox.union(_bb)
            x0, y0, x1, y1 = _bbox.extents
            xc = .5 * (x0 + x1)

            _t = _get_text(txt, xc, y1, 0, text_offset,
                           va="bottom",
                           **textkwargs)
            artist_list.append(_t)

    return patch_list, artist_list
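
Here copy.copy duplicates the top-level shape object so its name can be switched to "text" and re-styled without disturbing the shape that is still being drawn; because the copy is shallow, the attribute data is shared, which is fine since it is only read. A small sketch with a hypothetical Shape class (pyregion's real Shape has more fields):

import copy

class Shape(object):
    def __init__(self, name, attr):
        self.name = name
        self.attr = attr

shape = Shape("point", {"text": "label", "color": "green"})

textshape = copy.copy(shape)    # new Shape object, attr dict is shared
textshape.name = "text"

print(shape.name)                       # point
print(textshape.name)                   # text
print(textshape.attr is shape.attr)     # True -- shallow copy shares attr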

Example 22

Project: asciimatics
Source File: basics.py
View license
def demo(screen):
    scenes = []
    centre = (screen.width // 2, screen.height // 2)
    podium = (8, 5)

    # Scene 1.
    path = Path()
    path.jump_to(-20, centre[1])
    path.move_straight_to(centre[0], centre[1], 10)
    path.wait(30)
    path.move_straight_to(podium[0], podium[1], 10)
    path.wait(100)

    effects = [
        Arrow(screen, path, colour=Screen.COLOUR_GREEN),
        _speak(screen, "WELCOME TO ASCIIMATICS", centre, 30),
        _speak(screen, "My name is Aristotle Arrow.", podium, 110),
        _speak(screen,
               "I'm here to help you learn ASCIImatics.", podium, 180),
    ]
    scenes.append(Scene(effects))

    # Scene 2.
    path = Path()
    path.jump_to(podium[0], podium[1])

    effects = [
        Arrow(screen, path, colour=Screen.COLOUR_GREEN),
        _speak(screen, "Let's start with the Screen...", podium, 10),
        _speak(screen, "This is your Screen object.", podium, 80),
        Print(screen,
              Box(screen.width, screen.height, uni=screen.unicode_aware),
              0, 0, start_frame=90),
        _speak(screen, "It lets you play a Scene like this one I'm in.",
               podium, 150),
        _speak(screen, "A Scene contains one or more Effects.", podium, 220),
        _speak(screen, "Like me - I'm a Sprite!", podium, 290),
        _speak(screen, "Or these Stars.", podium, 360),
        _speak(screen, "As you can see, the Screen handles them both at once.",
               podium, 430),
        _speak(screen, "It can handle as many Effects as you like.",
               podium, 500),
        _speak(screen, "Please press <SPACE> now.", podium, 570),
        Stars(screen, (screen.width + screen.height) // 2, start_frame=360)
    ]
    scenes.append(Scene(effects, -1))

    # Scene 3.
    path = Path()
    path.jump_to(podium[0], podium[1])

    effects = [
        Arrow(screen, path, colour=Screen.COLOUR_GREEN),
        _speak(screen, "This is a new Scene.", podium, 10),
        _speak(screen, "The Screen stops all Effects and clears itself between "
                       "Scenes.",
               podium, 70),
        _speak(screen, "That's why you can't see the Stars now.", podium, 130),
        _speak(screen, "(Though you can override that if you need to.)", podium,
               200),
        _speak(screen, "Please press <SPACE> now.", podium, 270),
    ]
    scenes.append(Scene(effects, -1))

    # Scene 4.
    path = Path()
    path.jump_to(podium[0], podium[1])

    effects = [
        Arrow(screen, path, colour=Screen.COLOUR_GREEN),
        _speak(screen, "So, how do you design your animation?", podium, 10),
        _speak(screen, "1) Decide on your cinematic flow of Scenes.", podium,
               80),
        _speak(screen, "2) Create the Effects in each Scene.", podium, 150),
        _speak(screen, "3) Pass the Scenes to the Screen to play.", podium,
               220),
        _speak(screen, "It really is that easy!", podium, 290),
        _speak(screen, "Just look at this sample code.", podium, 360),
        _speak(screen, "Please press <SPACE> now.", podium, 430),
    ]
    scenes.append(Scene(effects, -1))

    # Scene 5.
    path = Path()
    path.jump_to(podium[0], podium[1])

    effects = [
        Arrow(screen, path, colour=Screen.COLOUR_GREEN),
        _speak(screen, "There are various effects you can use.  For "
                       "example...",
               podium, 10),
        Cycle(screen,
              FigletText("Colour cycling"),
              centre[1] - 5,
              start_frame=100),
        Cycle(screen,
              FigletText("using Figlet"),
              centre[1] + 1,
              start_frame=100),
        _speak(screen, "Look in the effects module for more...",
               podium, 290),
        _speak(screen, "Please press <SPACE> now.", podium, 360),
    ]
    scenes.append(Scene(effects, -1))

    # Scene 6.
    path = Path()
    path.jump_to(podium[0], podium[1])
    curve_path = []
    for i in range(0, 11):
        curve_path.append(
            (centre[0] + (screen.width / 4 * math.sin(i * math.pi / 5)),
             centre[1] - (screen.height / 4 * math.cos(i * math.pi / 5))))
    path2 = Path()
    path2.jump_to(centre[0], centre[1] - screen.height // 4)
    path2.move_round_to(curve_path, 60)

    effects = [
        Arrow(screen, path, colour=Screen.COLOUR_GREEN),
        _speak(screen, "Sprites (like me) are also an Effect.", podium, 10),
        _speak(screen, "We take a pre-defined Path to follow.", podium, 80),
        _speak(screen, "Like this one...", podium, 150),
        Plot(screen, path2, colour=Screen.COLOUR_BLUE, start_frame=160,
             stop_frame=300),
        _speak(screen, "My friend Sam will now follow it...", podium, 320),
        Sam(screen, copy.copy(path2), start_frame=380),
        _speak(screen, "Please press <SPACE> now.", podium, 420),
    ]
    scenes.append(Scene(effects, -1))

    # Scene 7.
    path = Path()
    path.jump_to(podium[0], podium[1])
    path.wait(60)
    path.move_straight_to(-5, podium[1], 20)
    path.wait(300)

    effects = [
        Arrow(screen, path, colour=Screen.COLOUR_GREEN),
        _speak(screen, "Goodbye!", podium, 10),
        Cycle(screen,
              FigletText("THE END!"),
              centre[1] - 4,
              start_frame=100),
        Print(screen, SpeechBubble("Press X to exit"), centre[1] + 6,
              start_frame=150)
    ]
    scenes.append(Scene(effects, 500))

    screen.play(scenes, stop_on_resize=True)
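
Sam is handed copy.copy(path2) so the sprite and the earlier Plot effect do not share one Path object's internal playback state. A minimal sketch of why consumers want their own copy of a stateful path, using a hypothetical Walker class instead of asciimatics' Path:

import copy

class Walker(object):
    def __init__(self, steps):
        self.steps = steps
        self.index = 0          # per-consumer cursor

    def next_pos(self):
        pos = self.steps[self.index % len(self.steps)]
        self.index += 1
        return pos

route = Walker([(0, 0), (1, 1), (2, 2)])

# Each consumer advances its own cursor; the step list itself is shared.
a, b = route, copy.copy(route)
a.next_pos(); a.next_pos()
print(a.index, b.index)   # 2 0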

Example 23

Project: dispy
Source File: httpd.py
View license
        def do_POST(self):
            form = cgi.FieldStorage(fp=self.rfile, headers=self.headers,
                                    environ={'REQUEST_METHOD': 'POST'})
            if self.path == '/node_jobs':
                ip_addr = None
                for item in form.list:
                    if item.name == 'host':
                        # if it looks like IP address, skip resolving
                        if re.match('^\d+[\.\d]+$', item.value):
                            ip_addr = item.value
                        else:
                            try:
                                ip_addr = socket.gethostbyname(item.value)
                            except:
                                ip_addr = item.value
                        break
                self._dispy_ctx._cluster_lock.acquire()
                cluster_infos = [(name, cluster_info) for name, cluster_info in
                                 self._dispy_ctx._clusters.items()]
                self._dispy_ctx._cluster_lock.release()
                jobs = []
                node = None
                for name, cluster_info in cluster_infos:
                    cluster_node = cluster_info.status.get(ip_addr, None)
                    if not cluster_node:
                        continue
                    if node:
                        node.jobs_done += cluster_node.jobs_done
                        node.cpu_time += cluster_node.cpu_time
                        node.update_time = max(node.update_time, cluster_node.update_time)
                    else:
                        node = copy.copy(cluster_node)
                    cluster_jobs = cluster_info.cluster.node_jobs(ip_addr)
                    # args and kwargs are sent as strings in Python,
                    # so an object's __str__ or __repr__ is used if provided;
                    # TODO: check job is in _dispy_ctx's jobs?
                    jobs.extend([{'uid': id(job), 'job_id': str(job.id),
                                  'args': ', '.join(str(arg) for arg in job.args),
                                  'kwargs': ', '.join('%s=%s' % (key, val)
                                                      for key, val in job.kwargs.items()),
                                  'sched_time_ms': int(1000 * job.start_time),
                                  'cluster': name}
                                 for job in cluster_jobs])
                self.send_response(200)
                self.send_header('Content-Type', 'application/json; charset=utf-8')
                self.end_headers()
                if node and node.avail_info:
                    node.avail_info = node.avail_info.__dict__
                self.wfile.write(json.dumps({'node': node.__dict__, 'jobs': jobs}).encode())
                return
            elif self.path == '/cancel_jobs':
                uids = []
                for item in form.list:
                    if item.name == 'uid':
                        try:
                            uids.append(int(item.value))
                        except ValueError:
                            logger.debug('Cancel job uid "%s" is invalid', item.value)

                self._dispy_ctx._cluster_lock.acquire()
                cluster_jobs = [(cluster_info.cluster, cluster_info.jobs.get(uid, None))
                                for cluster_info in self._dispy_ctx._clusters.values()
                                for uid in uids]
                self._dispy_ctx._cluster_lock.release()
                cancelled = []
                for cluster, job in cluster_jobs:
                    if not job:
                        continue
                    if cluster.cancel(job) == 0:
                        cancelled.append(id(job))
                self.send_response(200)
                self.send_header('Content-Type', 'application/json; charset=utf-8')
                self.end_headers()
                self.wfile.write(json.dumps(cancelled).encode())
                return
            elif self.path == '/add_node':
                node = {'host': '', 'port': None, 'cpus': 0, 'cluster': None}
                node_id = None
                cluster = None
                for item in form.list:
                    if item.name == 'host':
                        node['host'] = item.value
                    elif item.name == 'cluster':
                        node['cluster'] = item.value
                    elif item.name == 'port':
                        node['port'] = item.value
                    elif item.name == 'cpus':
                        try:
                            node['cpus'] = int(item.value)
                        except:
                            pass
                    elif item.name == 'id':
                        node_id = item.value
                if node['host']:
                    self._dispy_ctx._cluster_lock.acquire()
                    clusters = [cluster_info.cluster for name, cluster_info in
                                self._dispy_ctx._clusters.items()
                                if name == node['cluster'] or not node['cluster']]
                    self._dispy_ctx._cluster_lock.release()
                    for cluster in clusters:
                        cluster.allocate_node(node)
                    self.send_response(200)
                    self.send_header('Content-Type', 'text/html')
                    self.end_headers()
                    node['id'] = node_id
                    self.wfile.write(json.dumps(node).encode())
                    return
            elif self.path == '/set_poll_sec':
                for item in form.list:
                    if item.name != 'timeout':
                        continue
                    try:
                        timeout = int(item.value)
                        if timeout < 1:
                            timeout = 0
                    except:
                        logger.warning('HTTP client %s: invalid timeout "%s" ignored',
                                       self.client_address[0], item.value)
                        timeout = 0
                    self._dispy_ctx._poll_sec = timeout
                    self.send_response(200)
                    self.send_header('Content-Type', 'text/html')
                    self.end_headers()
                    return
            elif self.path == '/set_cpus':
                node_cpus = {}
                for item in form.list:
                    self._dispy_ctx._cluster_lock.acquire()
                    for cluster_info in self._dispy_ctx._clusters.values():
                        node = cluster_info.status.get(item.name, None)
                        if node:
                            node_cpus[item.name] = cluster_info.cluster.set_node_cpus(
                                item.name, item.value)
                            if node_cpus[item.name] >= 0:
                                break
                    self._dispy_ctx._cluster_lock.release()

                self.send_response(200)
                self.send_header('Content-Type', 'application/json; charset=utf-8')
                self.end_headers()
                self.wfile.write(json.dumps(node_cpus).encode())
                return

            logger.debug('Bad POST request from %s: %s', self.client_address[0], self.path)
            self.send_error(400)
            return
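
In the /node_jobs handler, the first matching node is shallow-copied before its counters are accumulated across clusters, so the per-cluster status objects stay untouched. A small sketch of the same aggregate-into-a-copy idea (NodeStatus is an illustrative stand-in, not dispy's real class):

import copy

class NodeStatus(object):
    def __init__(self, jobs_done, cpu_time):
        self.jobs_done = jobs_done
        self.cpu_time = cpu_time

per_cluster = [NodeStatus(3, 1.5), NodeStatus(7, 4.0)]

total = None
for status in per_cluster:
    if total is None:
        total = copy.copy(status)      # aggregate into a copy, not the original
    else:
        total.jobs_done += status.jobs_done
        total.cpu_time += status.cpu_time

print(total.jobs_done, total.cpu_time)   # 10 5.5
print(per_cluster[0].jobs_done)          # 3 -- original untouched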

Example 24

Project: dispy
Source File: httpd.py
View license
        def do_POST(self):
            form = cgi.FieldStorage(fp=self.rfile, headers=self.headers,
                                    environ={'REQUEST_METHOD': 'POST'})
            if self.path == '/node_jobs':
                ip_addr = None
                for item in form.list:
                    if item.name == 'host':
                        # if it looks like IP address, skip resolving
                        if re.match('^\d+[\.\d]+$', item.value):
                            ip_addr = item.value
                        else:
                            try:
                                ip_addr = socket.gethostbyname(item.value)
                            except:
                                ip_addr = item.value
                        break
                self._dispy_ctx._cluster_lock.acquire()
                cluster_infos = [(name, cluster_info) for name, cluster_info in
                                 self._dispy_ctx._clusters.items()]
                self._dispy_ctx._cluster_lock.release()
                jobs = []
                node = None
                for name, cluster_info in cluster_infos:
                    cluster_node = cluster_info.status.get(ip_addr, None)
                    if not cluster_node:
                        continue
                    if node:
                        node.jobs_done += cluster_node.jobs_done
                        node.cpu_time += cluster_node.cpu_time
                        node.update_time = max(node.update_time, cluster_node.update_time)
                    else:
                        node = copy.copy(cluster_node)
                    cluster_jobs = cluster_info.cluster.node_jobs(ip_addr)
                    # args and kwargs are sent as strings in Python,
                    # so an object's __str__ or __repr__ is used if provided;
                    # TODO: check job is in _dispy_ctx's jobs?
                    jobs.extend([{'uid': id(job), 'job_id': str(job.id),
                                  'args': ', '.join(str(arg) for arg in job.args),
                                  'kwargs': ', '.join('%s=%s' % (key, val)
                                                      for key, val in job.kwargs.items()),
                                  'sched_time_ms': int(1000 * job.start_time),
                                  'cluster': name}
                                 for job in cluster_jobs])
                self.send_response(200)
                self.send_header('Content-Type', 'application/json; charset=utf-8')
                self.end_headers()
                if node and node.avail_info:
                    node.avail_info = node.avail_info.__dict__
                self.wfile.write(json.dumps({'node': node.__dict__, 'jobs': jobs}).encode())
                return
            elif self.path == '/cancel_jobs':
                uids = []
                for item in form.list:
                    if item.name == 'uid':
                        try:
                            uids.append(int(item.value))
                        except ValueError:
                            logger.debug('Cancel job uid "%s" is invalid', item.value)

                self._dispy_ctx._cluster_lock.acquire()
                cluster_jobs = [(cluster_info.cluster, cluster_info.jobs.get(uid, None))
                                for cluster_info in self._dispy_ctx._clusters.values()
                                for uid in uids]
                self._dispy_ctx._cluster_lock.release()
                cancelled = []
                for cluster, job in cluster_jobs:
                    if not job:
                        continue
                    if cluster.cancel(job) == 0:
                        cancelled.append(id(job))
                self.send_response(200)
                self.send_header('Content-Type', 'application/json; charset=utf-8')
                self.end_headers()
                self.wfile.write(json.dumps(cancelled).encode())
                return
            elif self.path == '/add_node':
                node = {'host': '', 'port': None, 'cpus': 0, 'cluster': None}
                node_id = None
                cluster = None
                for item in form.list:
                    if item.name == 'host':
                        node['host'] = item.value
                    elif item.name == 'cluster':
                        node['cluster'] = item.value
                    elif item.name == 'port':
                        node['port'] = item.value
                    elif item.name == 'cpus':
                        try:
                            node['cpus'] = int(item.value)
                        except:
                            pass
                    elif item.name == 'id':
                        node_id = item.value
                if node['host']:
                    self._dispy_ctx._cluster_lock.acquire()
                    clusters = [cluster_info.cluster for name, cluster_info in
                                self._dispy_ctx._clusters.items()
                                if name == node['cluster'] or not node['cluster']]
                    self._dispy_ctx._cluster_lock.release()
                    for cluster in clusters:
                        cluster.allocate_node(node)
                    self.send_response(200)
                    self.send_header('Content-Type', 'text/html')
                    self.end_headers()
                    node['id'] = node_id
                    self.wfile.write(json.dumps(node).encode())
                    return
            elif self.path == '/set_poll_sec':
                for item in form.list:
                    if item.name != 'timeout':
                        continue
                    try:
                        timeout = int(item.value)
                        if timeout < 1:
                            timeout = 0
                    except:
                        logger.warning('HTTP client %s: invalid timeout "%s" ignored',
                                       self.client_address[0], item.value)
                        timeout = 0
                    self._dispy_ctx._poll_sec = timeout
                    self.send_response(200)
                    self.send_header('Content-Type', 'text/html')
                    self.end_headers()
                    return
            elif self.path == '/set_cpus':
                node_cpus = {}
                for item in form.list:
                    self._dispy_ctx._cluster_lock.acquire()
                    for cluster_info in self._dispy_ctx._clusters.values():
                        node = cluster_info.status.get(item.name, None)
                        if node:
                            node_cpus[item.name] = cluster_info.cluster.set_node_cpus(
                                item.name, item.value)
                            if node_cpus[item.name] >= 0:
                                break
                    self._dispy_ctx._cluster_lock.release()

                self.send_response(200)
                self.send_header('Content-Type', 'application/json; charset=utf-8')
                self.end_headers()
                self.wfile.write(json.dumps(node_cpus).encode())
                return

            logger.debug('Bad POST request from %s: %s', self.client_address[0], self.path)
            self.send_error(400)
            return
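
In the /node_jobs handler above, `node = copy.copy(cluster_node)` takes a shallow copy of the first matching per-cluster status record, so the job counts and CPU times of the remaining clusters can be summed into it without mutating the live status objects the clusters keep. A minimal sketch of that aggregation pattern, using an illustrative stand-in class rather than dispy's actual node type:

    import copy

    class NodeStatus:
        """Illustrative stand-in for a per-cluster node status record."""
        def __init__(self, jobs_done, cpu_time, update_time):
            self.jobs_done = jobs_done
            self.cpu_time = cpu_time
            self.update_time = update_time

    per_cluster = [NodeStatus(3, 1.5, 100.0), NodeStatus(5, 2.0, 120.0)]

    node = None
    for status in per_cluster:
        if node is None:
            node = copy.copy(status)   # shallow copy: aggregate without aliasing
        else:
            node.jobs_done += status.jobs_done
            node.cpu_time += status.cpu_time
            node.update_time = max(node.update_time, status.update_time)

    print(node.jobs_done, per_cluster[0].jobs_done)   # 8 3 -- original untouched

Because the summed fields are plain numbers, a shallow copy is enough here.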

Example 25

Project: Windows-Agent
Source File: basic.py
View license
def collect():
    logging.debug('enter basic collect')
    push_interval = 60
    zh_decode = "gbk"

    time_now = int(time.time())
    payload = []
    data = {"endpoint": g.HOSTNAME, "metric": "", "timestamp": time_now,
            "step": push_interval, "value": "", "counterType": "", "tags": ""}

    cpu_status = psutil.cpu_times_percent()
    mem_status = psutil.virtual_memory()
    swap_status = psutil.swap_memory()
    disk_io_status = psutil.disk_io_counters(perdisk=True)
    net_io_status = psutil.net_io_counters(pernic=True)

    # agent alive
    data["metric"] = "agent.alive"
    data["value"] = 1
    data["counterType"] = "GAUGE"
    payload.append(copy.copy(data))

    logging.debug(cpu_status)
    data["metric"] = "cpu.user"
    data["value"] = cpu_status.user
    data["counterType"] = "GAUGE"
    payload.append(copy.copy(data))

    data["metric"] = "cpu.system"
    data["value"] = cpu_status.system
    payload.append(copy.copy(data))

    data["metric"] = "cpu.idle"
    data["value"] = cpu_status.idle
    payload.append(copy.copy(data))

    data["metric"] = "mem.memused.percent"
    data["value"] = mem_status.percent
    payload.append(copy.copy(data))

    data["metric"] = "mem.swapused.percent"
    data["value"] = swap_status.percent
    payload.append(copy.copy(data))

    disk_status = psutil.disk_partitions()
    for disk in disk_status:
        if 'cdrom' in disk.opts or disk.fstype == '':
            continue
        disk_info = psutil.disk_usage(disk.mountpoint)

        data["metric"] = "df.used.percent"
        data["value"] = disk_info.percent
        data["tags"] = "disk=" + disk.device.split(":")[0]
        payload.append(copy.copy(data))

        data["metric"] = "df.byte.total"
        data["value"] = disk_info.total
        payload.append(copy.copy(data))

        data["metric"] = "df.byte.used"
        data["value"] = disk_info.used
        payload.append(copy.copy(data))

        data["metric"] = "df.byte.free"
        data["value"] = disk_info.free
        payload.append(copy.copy(data))

    for key in disk_io_status:
        data["metric"] = "disk.io.read_count"
        data["value"] = disk_io_status[key].read_count
        data["tags"] = "device=" + key
        data["counterType"] = "COUNTER"
        payload.append(copy.copy(data))

        data["metric"] = "disk.io.write_count"
        data["value"] = disk_io_status[key].write_count
        payload.append(copy.copy(data))

        data["metric"] = "disk.io.read_bytes"
        data["value"] = disk_io_status[key].read_bytes
        payload.append(copy.copy(data))

        data["metric"] = "disk.io.write_bytes"
        data["value"] = disk_io_status[key].write_bytes
        payload.append(copy.copy(data))

        data["metric"] = "disk.io.read_time"
        data["value"] = disk_io_status[key].read_time
        payload.append(copy.copy(data))

        data["metric"] = "disk.io.write_time"
        data["value"] = disk_io_status[key].write_time
        payload.append(copy.copy(data))

    for key in net_io_status:
        if is_interface_ignore(key):
            continue

        data["metric"] = "net.if.in.mbits"
        data["value"] = net_io_status[key].bytes_recv * 8 / 100000
        data["tags"] = "interface=" + key.decode(zh_decode)
        payload.append(copy.copy(data))

        data["metric"] = "net.if.out.mbits"
        data["value"] = net_io_status[key].bytes_sent * 8 / 100000
        payload.append(copy.copy(data))

        data["metric"] = "net.if.in.packets"
        data["value"] = net_io_status[key].packets_recv
        payload.append(copy.copy(data))

        data["metric"] = "net.if.out.packets"
        data["value"] = net_io_status[key].packets_sent
        payload.append(copy.copy(data))

        data["metric"] = "net.if.in.error"
        data["value"] = net_io_status[key].errin
        payload.append(copy.copy(data))

        data["metric"] = "net.if.out.error"
        data["value"] = net_io_status[key].errout
        payload.append(copy.copy(data))

        data["metric"] = "net.if.in.drop"
        data["value"] = net_io_status[key].dropin
        payload.append(copy.copy(data))

        data["metric"] = "net.if.out.drop"
        data["value"] = net_io_status[key].dropout
        payload.append(copy.copy(data))
        logging.debug(payload)

        payload = filter(lambda x: x['metric'] not in g.IGNORE, payload)

    try:
        result = send_data_to_transfer(payload)
    except Exception as e:
        logging.error(e)
    else:
        logging.info(result)
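
The copy.copy calls are what make this collector work: one template dict `data` is reused for every metric, and `payload.append(copy.copy(data))` snapshots its current contents each time. Appending `data` itself would leave every payload entry aliasing the same dict, so all entries would end up showing the last metric written. A minimal sketch of the pattern:

    import copy

    data = {"metric": "", "value": None}
    payload = []

    for metric, value in [("cpu.user", 12.5), ("cpu.idle", 80.0)]:
        data["metric"] = metric
        data["value"] = value
        payload.append(copy.copy(data))   # snapshot the template dict
        # payload.append(data) would leave both entries aliasing one dict

    print(payload)
    # [{'metric': 'cpu.user', 'value': 12.5}, {'metric': 'cpu.idle', 'value': 80.0}]

A shallow copy suffices because every value in the dict is a scalar or a string.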

Example 26

Project: auto-sklearn
Source File: classification.py
View license
    @classmethod
    def get_hyperparameter_search_space(cls, include=None, exclude=None,
                                        dataset_properties=None):
        """Create the hyperparameter configuration space.

        Parameters
        ----------
        include : dict (optional, default=None)

        Returns
        -------
        """
        cs = ConfigurationSpace()

        if dataset_properties is None or not isinstance(dataset_properties, dict):
            dataset_properties = dict()
        if not 'target_type' in dataset_properties:
            dataset_properties['target_type'] = 'classification'
        if dataset_properties['target_type'] != 'classification':
            dataset_properties['target_type'] = 'classification'

        pipeline = cls._get_pipeline()
        cs = cls._get_hyperparameter_search_space(cs, dataset_properties,
                                                  exclude, include, pipeline)

        classifiers = cs.get_hyperparameter('classifier:__choice__').choices
        preprocessors = cs.get_hyperparameter('preprocessor:__choice__').choices
        available_classifiers = pipeline[-1][1].get_available_components(
            dataset_properties)
        available_preprocessors = pipeline[-2][1].get_available_components(
            dataset_properties)

        possible_default_classifier = copy.copy(list(
            available_classifiers.keys()))
        default = cs.get_hyperparameter('classifier:__choice__').default
        del possible_default_classifier[possible_default_classifier.index(default)]

        # A classifier which can handle sparse data after the densifier is
        # forbidden for memory issues
        for key in classifiers:
            if SPARSE in available_classifiers[key].get_properties()['input']:
                if 'densifier' in preprocessors:
                    while True:
                        try:
                            cs.add_forbidden_clause(
                                ForbiddenAndConjunction(
                                    ForbiddenEqualsClause(
                                        cs.get_hyperparameter(
                                            'classifier:__choice__'), key),
                                    ForbiddenEqualsClause(
                                        cs.get_hyperparameter(
                                            'preprocessor:__choice__'), 'densifier')
                                ))
                            # Success
                            break
                        except ValueError:
                            # Change the default and try again
                            try:
                                default = possible_default_classifier.pop()
                            except IndexError:
                                raise ValueError("Cannot find a legal default configuration.")
                            cs.get_hyperparameter(
                                'classifier:__choice__').default = default

        # Combinations of non-linear models with feature learning,
        # which would take too long:
        classifiers_ = ["adaboost", "decision_tree", "extra_trees",
                        "gradient_boosting", "k_nearest_neighbors",
                        "libsvm_svc", "random_forest", "gaussian_nb",
                        "decision_tree", "xgradient_boosting"]
        feature_learning = ["kitchen_sinks", "nystroem_sampler"]

        for c, f in product(classifiers_, feature_learning):
            if c not in classifiers:
                continue
            if f not in preprocessors:
                continue
            while True:
                try:
                    cs.add_forbidden_clause(ForbiddenAndConjunction(
                        ForbiddenEqualsClause(cs.get_hyperparameter(
                            "classifier:__choice__"), c),
                        ForbiddenEqualsClause(cs.get_hyperparameter(
                            "preprocessor:__choice__"), f)))
                    break
                except KeyError:
                    break
                except ValueError as e:
                    # Change the default and try again
                    try:
                        default = possible_default_classifier.pop()
                    except IndexError:
                        raise ValueError(
                            "Cannot find a legal default configuration.")
                    cs.get_hyperparameter(
                        'classifier:__choice__').default = default

        # Won't work: multinomial NB etc. cannot be combined with feature
        # learning, PCA and other preprocessors that may produce negative values
        classifiers_ = ["multinomial_nb"]
        preproc_with_negative_X = ["kitchen_sinks", "pca", "truncatedSVD",
                                   "fast_ica", "kernel_pca", "nystroem_sampler"]

        for c, f in product(classifiers_, preproc_with_negative_X):
            if c not in classifiers:
                continue
            if f not in preprocessors:
                continue
            while True:
                try:
                    cs.add_forbidden_clause(ForbiddenAndConjunction(
                        ForbiddenEqualsClause(cs.get_hyperparameter(
                            "preprocessor:__choice__"), f),
                        ForbiddenEqualsClause(cs.get_hyperparameter(
                            "classifier:__choice__"), c)))
                    break
                except KeyError:
                    break
                except ValueError:
                    # Change the default and try again
                    try:
                        default = possible_default_classifier.pop()
                    except IndexError:
                        raise ValueError(
                            "Cannot find a legal default configuration.")
                    cs.get_hyperparameter(
                        'classifier:__choice__').default = default

        return cs
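
Here `copy.copy(list(available_classifiers.keys()))` builds a mutable working list of fallback default classifiers that the retry loops can pop from without touching the component mapping (strictly, `list(...)` already returns a fresh list, so the extra copy.copy is defensive). A rough sketch of that fallback loop, with illustrative names:

    import copy

    available_classifiers = {"random_forest": None, "libsvm_svc": None, "adaboost": None}

    # Working list of fallback defaults; popping from it must not disturb
    # the mapping of available components.
    possible_defaults = copy.copy(list(available_classifiers.keys()))

    default = "random_forest"
    del possible_defaults[possible_defaults.index(default)]

    while possible_defaults:
        default = possible_defaults.pop()   # try the next fallback default
        print("trying default:", default)

    print(sorted(available_classifiers))    # the mapping itself is unchanged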

Example 27

Project: pc-bot
Source File: thunder.py
View license
    def handler_user_votes(self, user, channel, message):
        """Record a user's vote."""

        # parse out the vote into individual tokens, separated by commas,
        # spaces, or both -- make this into a purely comma-separated vote
        message = re.sub(r'/[\s]+/', ' ', message)
        message = message.replace(', ', ',').replace(' ', ',')
        vote = message.split(',')

        # Copy the user's former vote, if any.
        # We will modify `answer` instead of writing his vote directly
        # to `self.current_votes`, so that if there's an error, we don't save
        # only half the vote somehow.
        answer = set()
        if user in self.current_votes:
            answer = self.current_votes[user]

        # Ensure that every sub-piece of this vote is individually valid
        # I currently understand:
        #   - integers on the talk_id list, optionally prefixed with [+-]
        #   - string "all"
        #   - string "none"
        invalid_pieces = []
        invalid_talk_ids = []
        for piece in vote:
            # I understand integers if they are on the talk_id list,
            # including if they are prefixed with [+-]
            if re.match(r'^[+-]?[\d]+$', piece):
                talk_id = int(piece.replace('-', '').replace('+', ''))
                if talk_id not in self.current_group.talk_ids:
                    invalid_talk_ids.append(talk_id)
                continue

            # I understand "all" and "none"
            if piece == 'all' or piece == 'none':
                continue

            # I have no idea what this is
            invalid_pieces.append(piece)

        # Sanity check: if I have any invalid tokens or talk_ids that aren't
        # in the talk_id list, fail out now.
        if len(invalid_pieces) or len(invalid_talk_ids):
            if len(invalid_pieces) > 3:
                self.msg(channel, '%s: I do not believe that was intended '
                                  'to be a vote.' % user)
            elif len(invalid_pieces):
                self.msg(channel, '{user}: I do not understand {tok}.'.format(
                    user=user,
                    tok=self._english_list(
                        ['"{0}"'.format(i) for i in invalid_pieces],
                        conjunction='or',
                    ),
                ))
            if len(invalid_talk_ids):
                self.msg(channel, '{user}: You voted for {talks}, which '
                                  '{to_be_verb} not part of this group. Your '
                                  'vote has not been recorded.'.format(
                    talks=self._english_list(
                        ['#{0}'.format(i) for i in invalid_talk_ids],
                    ),
                    to_be_verb='is' if len(invalid_talk_ids) == 1 else 'are',
                    user=user,
                ))
            return

        # The simple case is that this is a "plain" vote -- a list of
        # integers with no specials (e.g. "none") and no modifiers (+/-).
        #
        # This is straightforward: the vote becomes, in its entirety, the
        # user's vote, and anything previously recorded for the user is
        # simply dropped.
        if reduce(lambda x, y: bool(x) and bool(y),
                               [re.match(r'^[\d]+$', i) for i in vote]):
            self.current_votes[user] = set([int(i) for i in vote])
            return

        # Sanity check: non-plain votes should not have *any* plain elements;
        # therefore, if there are any, we should error out now.
        if reduce(lambda x, y: bool(x) or bool(y),
                  [re.match(r'^[\d]+$', i) for i in vote]):
            # Use examples from the actual group to minimize confusion
            examples = list(self.current_group.talk_ids)[0:2]
            while len(examples) < 2:
                examples.append(randint(1, 100))  # just in case

            # Spit out the error.
            # Since this is long, send as much of it as possible to PMs
            self.msg(channel, '{0}: I cannot process this vote. See your '
                              'private messages for details.'.format(user))
            self.msg(user, 'I cannot process this vote. I understand two '
                           'voting paradigms:')
            self.msg(user, '1. An absolute list of talks '
                           '(e.g. `{0}, {1}`)'.format(*examples))
            self.msg(user, '2. Two special keywords ("all", "none"), and the '
                           'addition/removal of talks from those keywords or '
                           'from your prior vote (e.g. `all -{1}` or '
                            '`+{0}`).'.format(*examples))
            self.msg(user, 'Your vote mixes these two paradigms together, and '
                           "I don't know how to process that, so as a picky "
                           'robot, I am cowardly giving up.')
            return

        # Sanity check: exclusive modifier votes only make sense if either
        #   1. "all" or "none" is included in the vote -or-
        #   2. the user has voted already
        # If neither of these cases obtains, error out.
        if vote[0] not in ('all', 'none') and user not in self.current_votes:
            self.msg(channel, '{0}: You can only modify your prior vote if '
                              'you have already voted; you have '
                              'not.'.format(user))
            return

        # Sanity check (last one, for now): "all" or "none" only make sense 
        # at the *beginning* of a vote; don't take them at the end.
        if 'all' in vote[1:] or 'none' in vote[1:]:
            self.msg(channel, '{0}: If using "all" or "none" in a complex '
                              'vote, please use them exclusively at the '
                              'beginning.'.format(user))
            return

        # Okay, this is a valid vote with modifiers; parse it from left
        # to right and process each of the modifiers.
        for piece in vote:
            # First, is this "all" or "none"? these are the simplest
            # cases -- either a full set or no set.
            if piece == 'all':
                answer = copy(self.current_group.talk_ids)
            if piece == 'none':
                answer = set()

            # Add or remove votes with operators from the set.
            if piece.startswith('+'):
                talk_id = int(piece[1:])
                answer.add(talk_id)
            if piece.startswith('-'):
                talk_id = int(piece[1:])
                answer.remove(talk_id)

        # Okay, we processed a valid vote without error; set it.
        self.current_votes[user] = answer
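
In the modifier-handling loop, `answer = copy(self.current_group.talk_ids)` (copy.copy imported unqualified) restarts the user's answer from a copy of the group's set of talk ids, so the later `+N`/`-N` modifiers edit only that user's vote and never the group's own set. A minimal sketch:

    from copy import copy

    group_talk_ids = {10, 11, 12}

    # vote "all -11": start from a copy of the group's ids, then strip one talk
    answer = copy(group_talk_ids)
    answer.remove(11)

    print(sorted(answer))           # [10, 12]
    print(sorted(group_talk_ids))   # [10, 11, 12] -- the group's set is untouched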

Example 28

Project: pyomo
Source File: process_data.py
View license
def _process_param(cmd, _model, _data, _default, index=None, param=None, ncolumns=None):
    """
    Called by _process_data to process data for a Parameter declaration
    """
    #print 'PARAM',cmd,index, ncolumns
    generate_debug_messages = __debug__ and logger.isEnabledFor(logging.DEBUG)
    if generate_debug_messages:
        logger.debug("DEBUG: _process_param(start) %s",cmd)
    #
    # Process parameters
    #
    dflt = None
    singledef = True
    cmd = cmd[1:]
    if cmd[0] == ":":
        singledef = False
        cmd = cmd[1:]
    #print "SINGLEDEF", singledef
    if singledef:
        pname = cmd[0]
        cmd = cmd[1:]
        if len(cmd) >= 2 and cmd[0] == "default":
            dflt = _data_eval(cmd[1])[0]
            cmd = cmd[2:]
        if dflt != None:
            _default[pname] = dflt
        if cmd[0] == ":=":
            cmd = cmd[1:]
        transpose = False
        if cmd[0] == "(tr)":
            transpose = True
            if cmd[1] == ":":
                cmd = cmd[1:]
            else:
                cmd[0] = ":"
        if cmd[0] != ":":
            #print "HERE YYY", pname, transpose, _model, ncolumns
            if generate_debug_messages:
                logger.debug("DEBUG: _process_param (singledef without :...:=) %s",cmd)
            cmd = _apply_templates(cmd)
            #print 'cmd',cmd
            if not transpose:
                if pname not in _data:
                    _data[pname] = {}
                if not ncolumns is None:
                    finaldata = _process_data_list(pname, ncolumns-1, _data_eval(cmd))
                elif not _model is None:
                    _param = getattr(_model, pname)
                    finaldata = _process_data_list(pname, _param.dim(), _data_eval(cmd))
                else:
                    finaldata = _process_data_list(pname, 1, _data_eval(cmd))
                for key in finaldata:
                    _data[pname][key]=finaldata[key]
            else:
                tmp = ["param", pname, ":="]
                i=1
                while i < len(cmd):
                    i0 = i
                    while cmd[i] != ":=":
                        i=i+1
                    ncol = i - i0 + 1
                    lcmd = i
                    while lcmd < len(cmd) and cmd[lcmd] != ":":
                        lcmd += 1
                    j0 = i0 - 1
                    for j in range(1,ncol):
                        ii = 1 + i
                        kk = ii + j
                        while kk < lcmd:
                            if cmd[kk] != ".":
                            #if 1>0:
                                tmp.append(copy.copy(cmd[j+j0]))
                                tmp.append(copy.copy(cmd[ii]))
                                tmp.append(copy.copy(cmd[kk]))
                            ii = ii + ncol
                            kk = kk + ncol
                    i = lcmd + 1
                _process_param(tmp, _model, _data, _default, index=index, param=param, ncolumns=ncolumns)
        else:
            tmp = ["param", pname, ":="]
            if param is None:
                param = [ pname ]
            i=1
            if generate_debug_messages:
                logger.debug("DEBUG: _process_param (singledef with :...:=) %s",cmd)
            while i < len(cmd):
                i0 = i
                while i<len(cmd) and cmd[i] != ":=":
                    i=i+1
                if i==len(cmd):
                    raise ValueError("ERROR: Trouble on line "+str(Lineno)+" of file "+Filename)
                ncol = i - i0 + 1
                lcmd = i
                while lcmd < len(cmd) and cmd[lcmd] != ":":
                    lcmd += 1
                j0 = i0 - 1
                for j in range(1,ncol):
                    ii = 1 + i
                    kk = ii + j
                    while kk < lcmd:
                        if cmd[kk] != ".":
                            if transpose:
                                tmp.append(copy.copy(cmd[j+j0]))
                                tmp.append(copy.copy(cmd[ii]))
                            else:
                                tmp.append(copy.copy(cmd[ii]))
                                tmp.append(copy.copy(cmd[j+j0]))
                            tmp.append(copy.copy(cmd[kk]))
                        ii = ii + ncol
                        kk = kk + ncol
                i = lcmd + 1
                _process_param(tmp, _model, _data, _default, index=index, param=param[0], ncolumns=3)

    else:
        if generate_debug_messages:
            logger.debug("DEBUG: _process_param (cmd[0]=='param:') %s",cmd)
        i=0
        nsets=0
        while i<len(cmd):
            if cmd[i] == ':=':
                i = -1
                break
            if cmd[i] == ":":
                nsets = i
                break
            i += 1
        nparams=0
        _i = i+1
        while i<len(cmd):
            if cmd[i] == ':=':
                nparams = i-_i
                break
            i += 1
        if i==len(cmd):
            raise ValueError("Trouble on data file line "+str(Lineno)+" of file "+Filename)
        if generate_debug_messages:
            logger.debug("NSets %d",nsets)
        Lcmd = len(cmd)
        j=0
        d = 1
        #print "HERE", nsets, nparams
        #
        # Process sets first
        #
        while j<nsets:
            # NOTE: I'm pretty sure that nsets is always equal to 1
            sname = cmd[j]
            if not ncolumns is None:
                d = ncolumns-nparams
            elif _model is None:
                d = 1
            else:
                index = getattr(_model, sname)
                d = index.dimen
            #print "SET",sname,d,_model#,getattr(_model,sname).dimen, type(index)
            #d = getattr(_model,sname).dimen
            np = i-1
            if generate_debug_messages:
                logger.debug("I %d, J %d, SName %s, d %d",i,j,sname,d)
            dnp = d + np - 1
            #k0 = i + d - 2
            ii = i + j + 1
            tmp = [ "set", cmd[j], ":=" ]
            while ii < Lcmd:
                if d > 1:
                    _tmp = []
                    for dd in range(0,d):
                        _tmp.append(copy.copy(cmd[ii+dd]))
                    tmp.append(tuple(_tmp))
                else:
                    for dd in range(0,d):
                        tmp.append(copy.copy(cmd[ii+dd]))
                ii += dnp
            _process_set(tmp, _model, _data)
            j += 1
        if nsets > 0:
            j += 1
        #
        # Process parameters second
        #
        #print "HERE", cmd
        #print "HERE", param
        #print "JI",j,i # XXX
        jstart = j
        if param is None:
            param = []
            _j = j
            while _j < i:
                param.append( cmd[_j] )
                _j += 1
        while j < i:
            #print "HERE", i, j, jstart, cmd[j]
            pname = param[j-jstart]
            if generate_debug_messages:
                logger.debug("I %d, J %d, Pname %s",i,j,pname)
            if not ncolumns is None:
                d = ncolumns - nparams
            elif _model is None:
                d = 1
            else:
                d = getattr(_model, param[j-jstart]).dim()
            if nsets > 0:
                np = i-1
                dnp = d+np-1
                ii = i + 1
                kk = i + d + j-1
            else:
                np = i
                dnp = d + np
                ii = i + 1
                kk = np + 1 + d + nsets + j
            #print cmd[ii], d, np, dnp, ii, kk
            tmp = [ "param", pname, ":=" ]
            if generate_debug_messages:
                logger.debug('dnp %d\nnp %d', dnp, np)
            while kk < Lcmd:
                if generate_debug_messages:
                    logger.debug("kk %d, ii %d",kk,ii)
                iid = ii + d
                while ii < iid:
                    tmp.append(copy.copy(cmd[ii]))
                    ii += 1
                ii += dnp-d
                tmp.append(copy.copy(cmd[kk]))
                kk += dnp
            #print "TMP", tmp, ncolumns-nparams+1
            if not ncolumns is None:
                nc = ncolumns-nparams+1
            else:
                nc = None
            _process_param(tmp, _model, _data, _default, index=index, param=param[j-jstart], ncolumns=nc)
            j += 1
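
The copy.copy calls in this parser are applied to individual tokens (`cmd[...]`) while a transposed or multi-column `param` section is rewritten into a flat `param name :=` command. Those tokens are strings, and copy.copy of an immutable object simply returns the object itself, so the copies are cheap and purely defensive. A small illustration:

    import copy

    token = "1.5"                    # a token parsed from the data file
    tmp = ["param", "p", ":="]
    tmp.append(copy.copy(token))

    print(tmp[-1] is token)          # True: copy.copy returns immutables unchanged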

Example 29

Project: mpop
Source File: scene.py
View license
    def project(self, dest_area, channels=None, precompute=False, mode=None,
                radius=None, nprocs=1):
        """Make a copy of the current snapshot projected onto the
        *dest_area*. Available areas are defined in the region configuration
        file (ACPG). *channels* tells which channels are to be projected, and
        if None, all channels are projected and copied over to the return
        snapshot.

        If *precompute* is set to true, the projecting data is saved on disk
        for reusage. *mode* sets the mode to project in: 'quick' which works
        between cartographic projections, and, as its denomination indicates,
        is quick (but lower quality), and 'nearest' which uses nearest
        neighbour for best projection. A *mode* set to None uses 'quick' when
        possible, 'nearest' otherwise.

        *radius* defines the radius of influence for neighbour search in
        'nearest' mode (in metres). Setting it to None, or omitting it will
        fall back to default values (5 times the channel resolution) or 10,000m
        if the resolution is not available.

        Note: channels have to be loaded to be projected, otherwise an
        exception is raised.
        """

        if not is_pyresample_loaded:
            # Not much point in proceeding then
            return self

        _channels = set([])

        if channels is None:
            for chn in self.loaded_channels():
                _channels |= set([chn])

        elif isinstance(channels, (list, tuple, set)):
            for chn in channels:
                try:
                    _channels |= set([self[chn]])
                except KeyError:
                    LOG.warning("Channel " + str(chn) + " not found,"
                                "thus not projected.")
        else:
            raise TypeError("Channels must be a list/"
                            "tuple/set of channel keys!")

        res = copy.copy(self)

        if isinstance(dest_area, str):
            dest_area = mpop.projector.get_area_def(dest_area)

        res.area = dest_area
        res.channels = []

        if not _channels <= self.loaded_channels():
            LOG.warning("Cannot project nonloaded channels: %s.",
                        _channels - self.loaded_channels())
            LOG.info("Will project the other channels though.")
            _channels = _channels and self.loaded_channels()

        cov = {}

        for chn in sorted(_channels, key=lambda x: x.resolution, reverse=True):
            if chn.area is None:
                if self.area is None:
                    area_name = ("swath_" + self.fullname + "_" +
                                 str(self.time_slot) + "_"
                                 + str(chn.shape))
                    chn.area = area_name
                else:
                    if is_pyresample_loaded:
                        try:
                            chn.area = AreaDefinition(
                                self.area.area_id + str(chn.shape),
                                self.area.name,
                                self.area.proj_id,
                                self.area.proj_dict,
                                chn.shape[1],
                                chn.shape[0],
                                self.area.area_extent,
                                self.area.nprocs)

                        except AttributeError:
                            try:
                                dummy = self.area.lons
                                dummy = self.area.lats
                                chn.area = self.area
                                area_name = ("swath_" + self.fullname + "_" +
                                             str(self.time_slot) + "_"
                                             + str(chn.shape))
                                chn.area.area_id = area_name
                            except AttributeError:
                                chn.area = self.area + str(chn.shape)
                    else:
                        chn.area = self.area + str(chn.shape)
            else:  # chn.area is not None
                # if (is_pyresample_loaded and
                #     (not hasattr(chn.area, "area_id") or
                #      not chn.area.area_id)):
                #     area_name = ("swath_" + self.fullname + "_" +
                #                  str(self.time_slot) + "_"
                #                  + str(chn.shape) + "_"
                #                  + str(chn.name))
                #     chn.area.area_id = area_name

                # This leaks memory !
                #LOG.debug("chn.area = " + str(chn.area))
                #LOG.debug("type(chn.area) = " + str(type(chn.area)))
                if is_pyresample_loaded:
                    area_name = ("swath_" + self.fullname + "_" +
                                 str(self.time_slot) + "_" +
                                 str(chn.shape) + "_" +
                                 str(chn.name))
                    LOG.debug("pyresample is loaded... area-name = " +
                              str(area_name))
                    if hasattr(chn.area, "area_id") and not chn.area.area_id:
                        LOG.debug("chn.area has area_id attribute...")
                        chn.area.area_id = area_name
                    elif not hasattr(chn.area, "area_id") and not isinstance(chn.area, str):
                        setattr(chn.area, 'area_id', area_name)

            if isinstance(chn.area, str):
                area_id = chn.area
            else:
                area_id = chn.area_id or chn.area.area_id

            if area_id not in cov:
                if radius is None:
                    if chn.resolution > 0:
                        radius = 5 * chn.resolution
                    else:
                        radius = 10000
                cov[area_id] = mpop.projector.Projector(chn.area,
                                                        dest_area,
                                                        mode=mode,
                                                        radius=radius,
                                                        nprocs=nprocs)
                if precompute:
                    try:
                        cov[area_id].save()
                    except IOError:
                        LOG.exception("Could not save projection.")

            try:
                res.channels.append(chn.project(cov[area_id]))
            except NotLoadedError:
                LOG.warning("Channel " + str(chn.name) + " not loaded, "
                            "thus not projected.")

        # Compose with image object
        try:
            if res._CompositerClass is not None:
                # Pass weak ref to compositor to allow garbage collection
                res.image = res._CompositerClass(weakref.proxy(res))
        except AttributeError:
            pass

        return res
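
`res = copy.copy(self)` gives the projected snapshot its own scene object cheaply: the attributes that must differ (`area`, `channels`) are then rebound on the copy rather than mutated in place, which is what keeps the shallow copy safe while everything else stays shared with the original. A minimal sketch of that copy-then-rebind idiom, with an illustrative class rather than mpop's Scene:

    import copy

    class Scene:
        """Illustrative stand-in for a satellite scene, not mpop's class."""
        def __init__(self, area, channels):
            self.area = area
            self.channels = channels

        def project(self, dest_area):
            res = copy.copy(self)    # shallow copy of the scene object
            res.area = dest_area     # rebind rather than mutate shared attributes
            res.channels = []        # fresh list so appends stay local to the copy
            for chn in self.channels:
                res.channels.append(chn + "@" + dest_area)
            return res

    src = Scene("globe", ["vis", "ir"])
    dst = src.project("euro")
    print(src.area, src.channels)    # globe ['vis', 'ir']
    print(dst.area, dst.channels)    # euro ['vis@euro', 'ir@euro']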

Example 30

Project: qutip
Source File: heom.py
View license
    def configure(self, H_sys, coup_op, coup_strength, temperature,
                     N_cut, N_exp, cut_freq, planck=None, boltzmann=None,
                     renorm=None, bnd_cut_approx=None,
                     options=None, progress_bar=None, stats=None):
        """
        Calls configure from :class:`HEOMSolver` and sets any attributes
        that are specific to this subclass
        """
        start_config = timeit.default_timer()

        HEOMSolver.configure(self, H_sys, coup_op, coup_strength,
                    temperature, N_cut, N_exp,
                    planck=planck, boltzmann=boltzmann,
                    options=options, progress_bar=progress_bar, stats=stats)
        self.cut_freq = cut_freq
        if renorm is not None: self.renorm = renorm
        if bnd_cut_approx is not None: self.bnd_cut_approx = bnd_cut_approx

        # Load local values for optional parameters
        # Constants and Hamiltonian.
        hbar = self.planck
        options = self.options
        progress_bar = self.progress_bar
        stats = self.stats


        if stats:
            ss_conf = stats.sections.get('config')
            if ss_conf is None:
                ss_conf = stats.add_section('config')

        c, nu = self._calc_matsubara_params()

        if renorm:
            norm_plus, norm_minus = self._calc_renorm_factors()
            if stats:
                stats.add_message('options', 'renormalisation', ss_conf)
        # Dimensions set by system
        sup_dim = H_sys.dims[0][0]**2
        unit_sys = qeye(H_sys.dims[0])

        # Use shorthands (mainly as in referenced PRL)
        lam0 = self.coup_strength
        gam = self.cut_freq
        N_c = self.N_cut
        N_m = self.N_exp
        Q = coup_op # Q as shorthand for coupling operator
        beta = 1.0/(self.boltzmann*self.temperature)

        # Ntot is the total number of ancillary elements in the hierarchy
        # Ntot = factorial(N_c + N_m) / (factorial(N_c)*factorial(N_m))
        # Turns out to be the same as nstates from state_number_enumerate
        N_he, he2idx, idx2he = enr_state_dictionaries([N_c + 1]*N_m , N_c)

        unit_helems = sp.identity(N_he, format='csr')
        if self.bnd_cut_approx:
            # the Tanimura boundary cut off operator
            if stats:
                stats.add_message('options', 'boundary cutoff approx', ss_conf)
            op = -2*spre(Q)*spost(Q.dag()) + spre(Q.dag()*Q) + spost(Q.dag()*Q)

            approx_factr = ((2*lam0 / (beta*gam*hbar)) - 1j*lam0) / hbar
            for k in range(N_m):
                approx_factr -= (c[k] / nu[k])
            L_bnd = -approx_factr*op.data
            L_helems = sp.kron(unit_helems, L_bnd)
        else:
            L_helems = sp.csr_matrix((N_he*sup_dim, N_he*sup_dim),
                                     dtype=complex)

        # Build the hierarchy element interaction matrix
        if stats: start_helem_constr = timeit.default_timer()

        unit_sup = spre(unit_sys).data
        spreQ = spre(Q).data
        spostQ = spost(Q).data
        commQ = (spre(Q) - spost(Q)).data
        N_he_interact = 0

        for he_idx in range(N_he):
            he_state = list(idx2he[he_idx])
            n_excite = sum(he_state)

            # The diagonal elements for the hierarchy operator
            # coeff for diagonal elements
            sum_n_m_freq = 0.0
            for k in range(N_m):
                sum_n_m_freq += he_state[k]*nu[k]

            op = -sum_n_m_freq*unit_sup
            L_he = _pad_csr(op, N_he, N_he, he_idx, he_idx)
            L_helems += L_he

            # Add the neighbour interactions
            he_state_neigh = copy(he_state)
            for k in range(N_m):

                n_k = he_state[k]
                if n_k >= 1:
                    # find the hierarchy element index of the neighbour before
                    # this element, for this Matsubara term
                    he_state_neigh[k] = n_k - 1
                    he_idx_neigh = he2idx[tuple(he_state_neigh)]

                    op = c[k]*spreQ - np.conj(c[k])*spostQ
                    if renorm:
                        op = -1j*norm_minus[n_k, k]*op
                    else:
                        op = -1j*n_k*op

                    L_he = _pad_csr(op, N_he, N_he, he_idx, he_idx_neigh)
                    L_helems += L_he
                    N_he_interact += 1

                    he_state_neigh[k] = n_k

                if n_excite <= N_c - 1:
                    # find the hierarchy element index of the neighbour after
                    # this element, for this Matsubara term
                    he_state_neigh[k] = n_k + 1
                    he_idx_neigh = he2idx[tuple(he_state_neigh)]

                    op = commQ
                    if renorm:
                        op = -1j*norm_plus[n_k, k]*op
                    else:
                        op = -1j*op

                    L_he = _pad_csr(op, N_he, N_he, he_idx, he_idx_neigh)
                    L_helems += L_he
                    N_he_interact += 1

                    he_state_neigh[k] = n_k

        if stats:
            stats.add_timing('hierarchy contruct',
                             timeit.default_timer() - start_helem_constr,
                            ss_conf)
            stats.add_count('Num hierarchy elements', N_he, ss_conf)
            stats.add_count('Num he interactions', N_he_interact, ss_conf)

        # Setup Liouvillian
        if stats: start_louvillian = timeit.default_timer()
        H_he = sp.kron(unit_helems, liouvillian(H_sys).data)

        L_helems += H_he

        if stats:
            stats.add_timing('Liouvillian contruct',
                             timeit.default_timer() - start_louvillian,
                            ss_conf)

        if stats: start_integ_conf = timeit.default_timer()

        r = scipy.integrate.ode(cy_ode_rhs)

        r.set_f_params(L_helems.data, L_helems.indices, L_helems.indptr)
        r.set_integrator('zvode', method=options.method, order=options.order,
                         atol=options.atol, rtol=options.rtol,
                         nsteps=options.nsteps, first_step=options.first_step,
                         min_step=options.min_step, max_step=options.max_step)

        if stats:
            time_now = timeit.default_timer()
            stats.add_timing('Liouvillian contruct',
                             time_now - start_integ_conf,
                            ss_conf)
            if ss_conf.total_time is None:
                ss_conf.total_time = time_now - start_config
            else:
                ss_conf.total_time += time_now - start_config

        self._ode = r
        self._N_he = N_he
        self._sup_dim = sup_dim
        self._configured = True
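
In the hierarchy loop, `he_state_neigh = copy(he_state)` shallow-copies the list of per-exponential excitation numbers so one entry can be bumped up or down to look up a neighbouring hierarchy element and then restored, without ever modifying `he_state` itself. For a flat list of ints, copy.copy is equivalent to `list(he_state)` or `he_state[:]`. A minimal sketch:

    from copy import copy

    he_state = [0, 2, 1]                 # excitation numbers per exponential term

    he_state_neigh = copy(he_state)      # shallow copy; plain ints need no deep copy
    for k, n_k in enumerate(he_state):
        he_state_neigh[k] = n_k + 1      # neighbour "above" in the hierarchy
        print(tuple(he_state_neigh))     # key that would be looked up in he2idx
        he_state_neigh[k] = n_k          # restore before the next term

    print(he_state)                      # [0, 2, 1] -- untouched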

Example 31

Project: bep
Source File: run.py
View license
def main(): # needs to be done as a main func for setuptools to work correctly in creating an executable
    # for the approach i am taking here using nested subparsers:
    # https://mail.python.org/pipermail/python-list/2010-August/585617.html

    # nargs options:
    # (default): by not specifying nargs at all, you just get a string of 1 item
    # = N   where N is some specified number of args
    # = '?' makes a string of one item, and if no args are given, then default is used.
    # = '*' makes a list of all args passed after command and if no args given, then default is used.
    # = '+' makes list of all args passed after command, but requires at least one arg

    top_parser = argparse.ArgumentParser(description=name.upper(),
                            formatter_class=argparse.RawDescriptionHelpFormatter,
                            #formatter_class=argparse.RawTextHelpFormatter,
                            #add_help=False,
                            epilog=usage.epilog_use)

    #################################
    ### this goes at the top level
    top_parser.add_argument('--version', action='version', version='%(prog)s {}'.format(__version__))
    top_parser.add_argument('-l', '--language', nargs='?', default='python', help=usage.lang_use)

    group = top_parser.add_mutually_exclusive_group()
    group.add_argument("-v", "--verbose", action="store_true", help=usage.verbose_use)
    group.add_argument("-q", "--quiet", action="store_true", help=usage.quiet_use)
    #################################


    def check_for_all_error(cmd_arg):
        if cmd_arg in ['all', 'All', 'ALL', '--All', '--ALL']:
            raise SystemExit("\nError: Did you mean to specify --all instead?")


    # If --all is passed in:
    # Skip stuff below if '--all' is specified w/ one of these accepted cmds
    # (this is some seriously hacky brute force shit!)
    build_up_subparsers = True
    additional_args = []
    cmds_that_accept_all_arg = ['update', 'remove', 'turn_off']
    for cmd in cmds_that_accept_all_arg:
        if cmd in sys.argv:
            for i in sys.argv:  # test for misspecified '--all' command
                check_for_all_error(i)
            if '--all' in sys.argv:
                #print(sys.argv)
                build_up_subparsers = False
                                                                            # TODO add help page for all
                top_parser.add_argument('--all', action='store_true', help=usage.all_use) #metavar="arg")
                args = top_parser.parse_known_args()
                args, additional_args = args
                if len(additional_args) > 1:    # this makes it so that it could only be len(additional_args)==1
                    error_all_arg = "--all can only be called with one of the following args:\n\t"
                    error_all_arg = error_all_arg + '{update, remove, turn_off}'
                    top_parser.error(error_all_arg)
                #else:
                    #additional_args = additional_args[0]


    # To display how to run a command:
    # look at all pkgs and check that passed in package name is one that's already installed
    everything_already_installed = utils.all_pkgs_and_branches_for_all_pkg_types_already_installed(installed_pkgs_dir)
    any_of_this_pkg_already_installed = lambda pkg_to_process: utils.lang_and_pkg_type_and_pkg_and_branches_tuple(
                                                                        pkg_to_process, everything_already_installed)
    cmds_that_can_display_how_to = cmds_that_accept_all_arg + ['turn_on']
    for cmd in cmds_that_can_display_how_to:    # everything except install i think
        if (cmd in sys.argv) and ('--all' not in sys.argv):
            if ('-h' not in sys.argv) and ('--help' not in sys.argv):
                args = top_parser.parse_known_args()
                args, additional_args = args
                if len(additional_args) == 2:
                    additional_args_copy = copy.copy(additional_args)
                    additional_args_copy.remove(cmd) # 2 things in here, one equal to cmd, the other is what we want to see if it's already installed
                    potential_pkg_to_proc = additional_args_copy[0]

                    #print any_of_this_pkg_already_installed(potential_pkg_to_proc)
                    if any_of_this_pkg_already_installed(potential_pkg_to_proc):
                        # should i make a function call out of this instead of relying on the command to be handled below?
                        print(" **** This is how to {} {} ****".format(cmd, potential_pkg_to_proc))
                        build_up_subparsers = False
                    elif potential_pkg_to_proc not in possible_choices:   # else if the other arg/package name passed in is not a pkg_already_installed (& not one of the next possible cmd options)
                        #print an error say that whatever is passed in cannot be updated/turned_on/etc
                        #b/c it's not currently installed.
                        error_msg = "cannot {} {}: not a currently installed package.\n".format(cmd, potential_pkg_to_proc)
                        error_msg = error_msg + "[Execute `{} list` to see installed packages.]".format(name)
                        top_parser.error(error_msg)
                    #else:   # want this instead b/c otherwise the above hides the help pages
                        #additional_args = []     # set back to empty to avoid the flag at the end of argparse stuff
                #else:
                    #error_msg = "An already installed package name must be passed in with {}".format(cmd)
                    #top_parser.error(error_msg)
                else:
                    additional_args = []     # set back to empty to avoid the flag at the end of argparse stuff


    if build_up_subparsers:
        top_subparser = top_parser.add_subparsers(title='Commands',
                                        description='[ These are the commands that can be passed to %(prog)s ]',
                                        #help=usage.subparser_use)
                                        help='[ Command specific help info ]')
        ### create parser for the "list" command
        # maybe make it so that it can list all branches installed for a specific pkg,
        parser_list = top_subparser.add_parser('list', help=usage.list_use)
        parser_list.add_argument('list_arg', action="store_true", help=usage.list_sub_use) #metavar="arg")


        class CheckIfCanBeInstalled(argparse.Action):
            ''' makes sure a repo to install has both a user_name and repo_name:
                    eg. ipython/ipython
                or is an actual path to a repo on the local filesystem'''

            def __call__(self, parser, namespace, arg_value, option_string=None):
                pkg_type = parser.prog.split(' ')[-1]
                if utils.check_if_valid_pkg_to_install(arg_value, pkg_type):
                    setattr(namespace, self.dest, arg_value)
                else:
                    if pkg_type == 'local':
                        error_msg = "\n\tIs not a path that exists on local filesystem."
                        raise parser.error(arg_value + error_msg)
                    else:
                        error_msg = '\nneed to make sure a username and repo_name are specified, like so:\n\tusername/repo_name'
                        raise parser.error(arg_value + error_msg)


        ##################################################
        cmd_help = vars(usage.cmd_help)
        for cmd in ['install', 'update', 'remove', 'turn_off', 'turn_on']:
            if cmd == 'install':
                install_parser = top_subparser.add_parser(cmd, help=usage.install_use.format(packages_file),
                                                          formatter_class=argparse.RawTextHelpFormatter)
                install_parser.set_defaults(top_subparser=cmd)
                install_subparser = install_parser.add_subparsers(dest='pkg_type', help=usage.install_sub_use.format(packages_file))
                for c in repo_choices:
                    pkg_type_to_install = install_subparser.add_parser(c)
                    # pkg_type_to_install.set_defaults(pkg_type_to_install=c) # is the same as 'pkg_type' dest above

                    pkg_type_to_install.add_argument('pkg_to_install',   # like ipython/ipython
                                                     action=CheckIfCanBeInstalled)   # actions here to make sure it's legit

                    # local repos don't get to have a branch specified; a branch would need to be checked out first, then installed.
                    #if c != 'local':
                        #pkg_type_to_install.add_argument('-b', '--branch', dest='branch', default=None)#, action=CheckBranch)    # the branch bit is filled out below

                    if c == 'github':
                        pkg_type_to_install.add_argument('repo_type', default='git', nargs='?')

                    elif c == 'bitbucket':
                        pkg_type_to_install.add_argument('repo_type', choices=['git', 'hg'])

                    # elif c == 'local':    # just get the type of repo from the local filesystem so it doesn't have to be specified
                        # pkg_type_to_install.add_argument('repo_type', choices=['git', 'hg', 'bzr'])

                    #elif c == 'remote':    # TODO not implemented but would be specified like so
                        #pkg_type_to_install.add_argument('repo_type', choices=['git', 'hg', 'bzr'])

                    pkg_type_to_install.add_argument('-b', '--branch', dest='branch', default=None)#, action=CheckBranch)    # the branch bit is filled out below

                for c in other_choices:
                    if c == 'packages':
                        pkg_type_to_install = install_subparser.add_parser(c, help=usage.packages_file_use.format(packages_file))

                    #elif c == 'stable': # TODO not implemented
                        #pkg_type_to_install = install_subparser.add_parser(c)
                        #pkg_type_to_install.add_argument('pkg_to_install')  # like ipython
                        ##pkg_type_to_install.add_argument('--pversion')      # TODO like 1.2.1 (add this in later to install different version of a stable pkg)

                # NOTE this seems like a better way to go in the future:
                # install_parser.set_defaults(func=run_install)
                # then run_install would be defined to run the install process (rather than having the conditionals below)
                # def run_install(args):
                #   install_arg = args.install_arg  # would be a list of pkgs or a string of the packages file
                #   ...process the install_arg to decide what to install
                #   ...then do the install
                ##################################################
            else:
                subparser_parser = top_subparser.add_parser(cmd, help=cmd_help['{}_use'.format(cmd)],
                                                            formatter_class=argparse.RawTextHelpFormatter)
                subparser_parser.set_defaults(top_subparser=cmd)

                ### didn't work, not sure why yet
                #all_dest = '{}_ALL'.format(cmd)
                #subparser_parser.add_argument('--all',
                                                ##help=usage.remove_sub_use.format(name=name),    # FIXME not sure why this wouldn't work
                                                ##action=CheckIfALL, action='store_true')

                #cur_args = vars(top_parser.parse_args())
                #print(cur_args)
                #if 'all' in cur_args:
                    #if cur_args['all']:
                        #break
                this_cmds_help = cmd_help['{}_sub_use'.format(cmd)].format(name=name)
                subparsers_subparser = subparser_parser.add_subparsers(dest='pkg_type', help=this_cmds_help)

                for c in repo_choices:
                    pkg_type_to_proc = subparsers_subparser.add_parser(c)
                    pkg_type_to_proc.add_argument('pkg_to_{}'.format(cmd))   # like ipython
                    pkg_type_to_proc.add_argument('-b', '--branch', dest='branch', default=None)  # needs to be specified in script (for installs though, it uses the default branch name if not specified)

                #for c in other_choices: #TODO
                    ##if c == 'packages':    # packages args only used for installs
                        ##pkg_type_to_proc = subparsers_subparser.add_parser(c)
                    #if c == 'stable':
                        #pkg_type_to_proc = subparsers_subparser.add_parser(c)
                        #pkg_type_to_proc.add_argument('pkg_to_{}'.format(cmd))  # like ipython
                        #pkg_type_to_proc.add_argument('--pversion', help='package version')      # like 1.2.1 (default should be the newest, but can specify older ones)
            ##################################################



        args = top_parser.parse_args()

        # handle branches here
        if ('top_subparser' in args) and (args.top_subparser == 'install'):
            if ('branch' in args) and (args.branch == None):
                if args.pkg_type == 'local':    # for local, grab the currently checked out branch from the repo and set that as the branch to install
                    branch, repo_type = utils.get_checked_out_local_branch(args.pkg_to_install)
                    args.repo_type = repo_type
                else:
                    branch = utils.get_default_branch(args.repo_type)
                args.branch = branch
            elif ('branch' in args) and (args.branch != None):
                if args.pkg_type == 'local':    # for local, don't allow branch to be specified; just use currently checked out branch
                    error_msg = "for `local` packages a branch cannot be specified;\n"
                    error_msg = error_msg + "check out the desired branch from the repo itself, then install."
                    raise top_parser.error(error_msg)
        elif ('top_subparser' in args) and (args.top_subparser != 'install'):
            if ('branch' in args) and (args.branch == None):
                error_msg = 'need to make sure a branch is specified;\n'
                error_msg = error_msg + "[Execute `{} list` to see installed packages and branches.]".format(name)
                raise top_parser.error(error_msg)


    class noise(object):
        verbose = args.verbose
        quiet = args.quiet


    """
    # REMOVE LATER...this just shows what we're dealing with here
    print('##########################################################')
    print(args)
    if additional_args:
        print(additional_args)
    print('##########################################################')
    #raise SystemExit
    """

    #--------------------------------------------------------------------------------------------------------------

    if noise.quiet:
        print('-'*60)



    #######################################################################################################################
    #### install pkg(s)
    kwargs = dict(packages_file=packages_file, packages_file_path=packages_file_path,
                 noise=noise, install_dirs=install_dirs, installed_pkgs_dir=installed_pkgs_dir)

    if ('top_subparser' in args) and (args.top_subparser == 'install'):
        any_pkgs_processed = install.install_cmd(args, **kwargs)
    #######################################################################################################################



    #######################################################################################################################
    #### if nothing is installed, then don't continue on to other commands (since they only process currently installed stuff)
    everything_already_installed = utils.all_pkgs_and_branches_for_all_pkg_types_already_installed(installed_pkgs_dir)
    if not everything_already_installed:
        raise SystemExit('\nNo packages installed.')
    #######################################################################################################################



    #######################################################################################################################
    #### list installed pkg(s) (by each package type)
    elif 'list_arg' in args:
        list_packages.list_cmd(everything_already_installed, noise)
    #######################################################################################################################



    #######################################################################################################################
    # for everything else (update, remove, turn_on/off)
    #elif args:
    #elif ((('top_subparser' in args) and (args.top_subparser in ['update', 'remove', 'turn_on', 'turn_off'])) or
         #(('update' in additional_args) or ('remove' in additional_args) or ('turn_off' in additional_args) or
          #('turn_on' in additional_args))):
    else:   # FIXME not sure this is as good as it could be by just using else instead of something more specific

        actions_to_take = {}
        #top_level_any_pkgs_processed = False
        for lang_dir_name, pkg_type_dict in everything_already_installed.items():
            for pkg_type, pkgs_and_branches in pkg_type_dict.items():
                any_pkgs_processed = False
                #if pkgs_and_branches:  # don't think i need this

                pkgs_status = utils.pkgs_and_branches_for_pkg_type_status(pkgs_and_branches)
                pkgs_and_branches_on = pkgs_status['pkg_branches_on']
                pkgs_and_branches_off = pkgs_status['pkg_branches_off']

                kwargs = dict(lang_dir_name=lang_dir_name, pkg_type=pkg_type, noise=noise, install_dirs=install_dirs,
                            pkgs_and_branches_on=pkgs_and_branches_on, pkgs_and_branches_off=pkgs_and_branches_off,
                            additional_args=additional_args, everything_already_installed=everything_already_installed)


                if ('pkg_to_update' in args) or ('update' in additional_args):
                    any_pkgs_processed = update_packages.update_cmd(args, **kwargs)

                elif ('pkg_to_remove' in args) or ('remove' in additional_args):
                    any_pkgs_processed = remove_packages.remove_cmd(args, **kwargs)

                elif ('pkg_to_turn_off' in args) or ('turn_off' in additional_args):
                    any_pkgs_processed = turn_off.turn_off_cmd(args, **kwargs)

                elif ('pkg_to_turn_on' in args) or ('turn_on' in additional_args):
                    any_pkgs_processed = turn_on.turn_on_cmd(args, **kwargs)


                if any_pkgs_processed:
                    #top_level_any_pkgs_processed = True #+= 1
                    if type(any_pkgs_processed) == dict:    # it will be a dict when a pkg didn't actually get processed, but has commands to get processed
                        actions_to_take.update(any_pkgs_processed)

        #if not top_level_any_pkgs_processed: # NOTE KEEP for now, but i don't think this will ever get hit?
            #utils.when_not_quiet_mode('\n[ No action performed ]'.format(pkg_type), noise.quiet)


        if actions_to_take:

            if len(actions_to_take) == 1:
                alert, cmd = actions_to_take.items()[0]
                option = '\n* {}\n{}\n'.format(alert, cmd)
                print(option)

                if not (cmd.startswith('****') and cmd.endswith('****')):

                    print('-'*60)
                    msg = "The above version is installed, would you like to run the\ncommand [y/N]? "
                    response = raw_input(msg)
                    if response:
                        response = response.lower()
                        if response in ['y', 'yes']:
                            utils.cmd_output(cmd)
                        elif response in ['n', 'no']:
                            print("\nBye then.")
                        else:
                            raise SystemExit("\nError: {}: not valid input".format(response))
                    else:
                        print("\nOk, bye then.")


            elif len(actions_to_take) > 1:

                actions_to_take_with_num_keys = {}  # takes the alert, cmd (key, val) pairs from actions_to_take and stores each as an (alert, cmd) value tuple keyed by a number
                for num, alert_key in enumerate(actions_to_take, start=1): # actions_to_take is a dict with alert, cmd (key, val) pairs
                    actions_to_take_with_num_keys[num] = (alert_key, actions_to_take[alert_key])
                actions_to_take_with_num_keys = OrderedDict(sorted(actions_to_take_with_num_keys.items(), key=lambda t: t[0]))  # sorted by key (which are nums)

                for num_key, alert_and_cmd_tuple_val in actions_to_take_with_num_keys.items():
                    if num_key == 1:
                        print('')
                    alert, cmd =  alert_and_cmd_tuple_val
                    option = '{}. {}\n{}\n'.format(num_key, alert, cmd)
                    print(option)

                print('-'*60)
                msg = "The versions above are installed.  If you'd like to run the command\n"
                msg = msg + "for an item, enter the number (if not, then just hit enter to exit). "
                response = raw_input(msg)
                if response:
                    try:
                        response = int(response)
                    except ValueError:
                        raise SystemExit("\nError: invalid response: {}".format(response))
                    if response in range(1, len(actions_to_take_with_num_keys)+1):
                        #print response # now run the command
                        # Could either 1. open a subprocess and run from the command line -- easy way
                        # or 2. try to pass back into the command that got us here -- better way

                        # Number 2 would involve something like this with updating the kwargs:
                        #kwargs = dict(lang_dir_name=lang_dir_name, pkg_type=pkg_type, noise=noise, install_dirs=install_dirs,
                                    #pkgs_and_branches_on=pkgs_and_branches_on, pkgs_and_branches_off=pkgs_and_branches_off,
                                    #additional_args=additional_args, everything_already_installed=everything_already_installed)
                        #actions.update_action(args, **kwargs)

                        # Doing number 1 above, just to get it working, though 2 would probably be better in long run.
                        cmd = actions_to_take_with_num_keys[response][1]    # this gets the command from the alert, cmd tuple
                        if (cmd.startswith('****') and cmd.endswith('****')):
                            print("\nNo command to process,\n{}".format(cmd))
                        else:
                            utils.cmd_output(cmd)
                    else:
                        raise SystemExit("\nError: invalid response: {}".format(response))
                else:
                    print("\nOk, bye then.")

Example 32

Project: rst2pdf
Source File: pdfbuilder.py
View license
    def assemble_doctree(self, docname, title, author, appendices):

        # FIXME: use the new inline_all_trees from Sphinx.
        # check how the LaTeX builder does it.

        self.docnames = set([docname])
        self.info(darkgreen(docname) + " ", nonl=1)
        def process_tree(docname, tree):
            tree = tree.deepcopy()
            for toctreenode in tree.traverse(addnodes.toctree):
                newnodes = []
                includefiles = list(map(str, toctreenode['includefiles']))
                for includefile in includefiles:
                    try:
                        self.info(darkgreen(includefile) + " ", nonl=1)
                        subtree = process_tree(includefile,
                        self.env.get_doctree(includefile))
                        self.docnames.add(includefile)
                    except Exception:
                        self.warn('%s: toctree contains ref to nonexisting file %r'\
                                                     % (docname, includefile))
                    else:
                        sof = addnodes.start_of_file(docname=includefile)
                        sof.children = subtree.children
                        newnodes.append(sof)
                toctreenode.parent.replace(toctreenode, newnodes)
            return tree

        tree = self.env.get_doctree(docname)
        tree = process_tree(docname, tree)

        self.docutils_languages = {}
        if self.config.language:
            self.docutils_languages[self.config.language] = \
                get_language_available(self.config.language)[2]

        if self.opts.get('pdf_use_index',self.config.pdf_use_index):
            # Add index at the end of the document

            # This is a hack. create_index creates an index from
            # ALL the documents data, not just this one.
            # So, we preserve a copy, use just what we need, then
            # restore it.
            #from pudb import set_trace; set_trace()
            t=copy(self.env.indexentries)
            try:
                self.env.indexentries={docname:self.env.indexentries[docname+'-gen']}
            except KeyError:
                self.env.indexentries={}
                for dname in self.docnames:
                    self.env.indexentries[dname]=t.get(dname,[])
            genindex = self.env.create_index(self)
            self.env.indexentries=t
            # EOH (End Of Hack)

            if genindex: # No point in creating empty indexes
                index_nodes=genindex_nodes(genindex)
                tree.append(nodes.raw(text='OddPageBreak twoColumn', format='pdf'))
                tree.append(index_nodes)

        # This is stolen from the HTML builder's prepare_writing function
        self.domain_indices = []
        # html_domain_indices can be False/True or a list of index names
        indices_config = self.config.pdf_domain_indices
        if indices_config and hasattr(self.env, 'domains'):
            for domain in self.env.domains.values():
                for indexcls in domain.indices:
                    indexname = '%s-%s' % (domain.name, indexcls.name)
                    if isinstance(indices_config, list):
                        if indexname not in indices_config:
                            continue
                    # deprecated config value
                    if indexname == 'py-modindex' and \
                           not self.config.pdf_use_modindex:
                        continue
                    content, collapse = indexcls(domain).generate()
                    if content:
                        self.domain_indices.append(
                            (indexname, indexcls, content, collapse))

        # self.domain_indices contains a list of indices to generate, like
        # this:
        # [('py-modindex',
        #    <class 'sphinx.domains.python.PythonModuleIndex'>,
        #   [(u'p', [[u'parrot', 0, 'test', u'module-parrot', 'Unix, Windows',
        #   '', 'Analyze and reanimate dead parrots.']])], True)]

        # Now this in the HTML builder is passed onto write_domain_indices.
        # We handle it right here

        for indexname, indexcls, content, collapse in self.domain_indices:
            indexcontext = dict(
                indextitle = indexcls.localname,
                content = content,
                collapse_index = collapse,
            )
            # In HTML this is handled with a Jinja template, domainindex.html
            # We have to generate docutils stuff right here in the same way.
            self.info(' ' + indexname, nonl=1)
            print()

            output=['DUMMY','=====','',
                    '.. _modindex:\n\n']
            t=indexcls.localname
            t+='\n'+'='*len(t)+'\n'
            output.append(t)

            for letter, entries in content:
                output.append('.. cssclass:: heading4\n\n%s\n\n'%letter)
                for (name, grouptype, page, anchor,
                    extra, qualifier, description) in entries:
                    if qualifier:
                        q = '[%s]'%qualifier
                    else:
                        q = ''

                    if extra:
                        e = '(%s)'%extra
                    else:
                        e = ''
                    output.append ('`%s <#%s>`_ %s %s'%(name, anchor, e, q))
                    output.append('    %s'%description)
                output.append('')

            dt = docutils.core.publish_doctree('\n'.join(output))[1:]
            dt.insert(0,nodes.raw(text='OddPageBreak twoColumn', format='pdf'))
            tree.extend(dt)


        if appendices:
            tree.append(nodes.raw(text='OddPageBreak %s'%self.page_template, format='pdf'))
            self.info()
            self.info('adding appendixes...', nonl=1)
            for docname in appendices:
                self.info(darkgreen(docname) + " ", nonl=1)
                appendix = self.env.get_doctree(docname)
                appendix['docname'] = docname
                tree.append(appendix)
            self.info('done')

        self.info()
        self.info("resolving references...")
        #print tree
        #print '--------------'
        self.env.resolve_references(tree, docname, self)
        #print tree

        for pendingnode in tree.traverse(addnodes.pending_xref):
            # This needs work, need to keep track of all targets
            # so I don't replace and create hanging refs, which
            # crash
            if pendingnode.get('reftarget',None) == 'genindex'\
                and self.config.pdf_use_index:
                pendingnode.replace_self(nodes.reference(text=pendingnode.astext(),
                    refuri=pendingnode['reftarget']))
            # FIXME: probably need to handle dangling links to domain-specific indexes
            else:
                # FIXME: This is from the LaTeX builder and I still don't understand it
                # well, and doesn't seem to work

                # resolve :ref:s to distant tex files -- we can't add a cross-reference,
                # but append the document name
                docname = pendingnode['refdocname']
                sectname = pendingnode['refsectname']
                newnodes = [nodes.emphasis(sectname, sectname)]
                for subdir, title in self.titles:
                    if docname.startswith(subdir):
                        newnodes.append(nodes.Text(_(' (in '), _(' (in ')))
                        newnodes.append(nodes.emphasis(title, title))
                        newnodes.append(nodes.Text(')', ')'))
                        break
                else:
                    pass
                pendingnode.replace_self(newnodes)
            #else:
                #pass
        return tree
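
The index handling above is a save/mutate/restore pattern: copy.copy takes a shallow snapshot of self.env.indexentries, the dict is temporarily replaced with a trimmed version so create_index only sees the current document, and the snapshot is put back afterwards. A stripped-down illustration of the same idea (the Env class and the data here are invented):

from copy import copy

class Env(object):
    def __init__(self):
        self.indexentries = {'doc-a': ['term1'], 'doc-b': ['term2']}

env = Env()
saved = copy(env.indexentries)                    # shallow snapshot of the mapping
try:
    env.indexentries = {'doc-a': saved['doc-a']}  # temporarily narrow to one document
    # ... build the per-document index from the narrowed data ...
finally:
    env.indexentries = saved                      # restore the full mapping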

Example 33

Project: rst2pdf
Source File: styles.py
View license
    def __init__(self, flist, font_path=None, style_path=None, def_dpi=300):
        log.info('Using stylesheets: %s' % ','.join(flist))
        # find base path
        if hasattr(sys, 'frozen'):
            self.PATH = abspath(dirname(sys.executable))
        else:
            self.PATH = abspath(dirname(__file__))

        # flist is a list of stylesheet filenames.
        # They will be loaded and merged in order.
        # but the two default stylesheets will always
        # be loaded first
        flist = [join(self.PATH, 'styles', 'styles.style'),
                join(self.PATH, 'styles', 'default.style')] + flist

        self.def_dpi=def_dpi
        if font_path is None:
            font_path=[]
        font_path+=['.', os.path.join(self.PATH, 'fonts')]
        self.FontSearchPath = list(map(os.path.expanduser, font_path))

        if style_path is None:
            style_path=[]
        style_path+=['.', os.path.join(self.PATH, 'styles'),
                      '~/.rst2pdf/styles']
        self.StyleSearchPath = list(map(os.path.expanduser, style_path))
        self.FontSearchPath=list(set(self.FontSearchPath))
        self.StyleSearchPath=list(set(self.StyleSearchPath))

        log.info('FontPath:%s'%self.FontSearchPath)
        log.info('StylePath:%s'%self.StyleSearchPath)

        findfonts.flist = self.FontSearchPath
        # Page width, height
        self.pw = 0
        self.ph = 0

        # Page size [w,h]
        self.ps = None

        # Margins (top,bottom,left,right,gutter)
        self.tm = 0
        self.bm = 0
        self.lm = 0
        self.rm = 0
        self.gm = 0

        #text width
        self.tw = 0

        # Default emsize, later it will be the fontSize of the base style
        self.emsize=10

        self.languages = []

        ssdata = self.readSheets(flist)

        # Get pageSetup data from all stylessheets in order:
        self.ps = pagesizes.A4
        self.page={}
        for data, ssname in ssdata:
            page = data.get('pageSetup', {})
            if page:
                self.page.update(page)
                pgs=page.get('size', None)
                if pgs: # A standard size
                    pgs=pgs.upper()
                    if pgs in pagesizes.__dict__:
                        self.ps = list(pagesizes.__dict__[pgs])
                        self.psname = pgs
                        if 'width' in self.page: del(self.page['width'])
                        if 'height' in self.page: del(self.page['height'])
                    elif pgs.endswith('-LANDSCAPE'):
                        self.psname = pgs.split('-')[0]
                        self.ps = list(pagesizes.landscape(pagesizes.__dict__[self.psname]))
                        if 'width' in self.page: del(self.page['width'])
                        if 'height' in self.page: del(self.page['height'])
                    else:
                        log.critical('Unknown page size %s in stylesheet %s'%\
                            (page['size'], ssname))
                        continue
                else: #A custom size
                    if 'size'in self.page:
                        del(self.page['size'])
                    # The sizes are expressed in some unit.
                    # For example, 2cm is 2 centimeters, and we need
                    # to do 2*cm (cm comes from reportlab.lib.units)
                    if 'width' in page:
                        self.ps[0] = self.adjustUnits(page['width'])
                    if 'height' in page:
                        self.ps[1] = self.adjustUnits(page['height'])
                self.pw, self.ph = self.ps
                if 'margin-left' in page:
                    self.lm = self.adjustUnits(page['margin-left'])
                if 'margin-right' in page:
                    self.rm = self.adjustUnits(page['margin-right'])
                if 'margin-top' in page:
                    self.tm = self.adjustUnits(page['margin-top'])
                if 'margin-bottom' in page:
                    self.bm = self.adjustUnits(page['margin-bottom'])
                if 'margin-gutter' in page:
                    self.gm = self.adjustUnits(page['margin-gutter'])
                if 'spacing-header' in page:
                    self.ts = self.adjustUnits(page['spacing-header'])
                if 'spacing-footer' in page:
                    self.bs = self.adjustUnits(page['spacing-footer'])
                if 'firstTemplate' in page:
                    self.firstTemplate = page['firstTemplate']

                # tw is the text width.
                # We need it to calculate header-footer height
                # and compress literal blocks.
                self.tw = self.pw - self.lm - self.rm - self.gm

        # Get page templates from all stylesheets
        self.pageTemplates = {}
        for data, ssname in ssdata:
            templates = data.get('pageTemplates', {})
            # templates is a dictionary of pageTemplates
            for key in templates:
                template = templates[key]
                # template is a dict.
                # template['frames'] is a list of frames
                if key in self.pageTemplates:
                    self.pageTemplates[key].update(template)
                else:
                    self.pageTemplates[key] = template

        # Get font aliases from all stylesheets in order
        self.fontsAlias = {}
        for data, ssname in ssdata:
            self.fontsAlias.update(data.get('fontsAlias', {}))

        embedded_fontnames = []
        self.embedded = []
        # Embed all fonts indicated in all stylesheets
        for data, ssname in ssdata:
            embedded = data.get('embeddedFonts', [])

            for font in embedded:
                try:
                    # Just a font name, try to embed it
                    if isinstance(font, str):
                        # See if we can find the font
                        fname, pos = findfonts.guessFont(font)
                        if font in embedded_fontnames:
                            pass
                        else:
                            fontList = findfonts.autoEmbed(font)
                            if fontList:
                                embedded_fontnames.append(font)
                        if not fontList:
                            if (fname, pos) in embedded_fontnames:
                                fontList = None
                            else:
                                fontList = findfonts.autoEmbed(fname)
                        if fontList is not None:
                            self.embedded += fontList
                            # Maybe the font we got is not called
                            # the same as the one we gave
                            # so check that out
                            suff = ["", "-Oblique", "-Bold", "-BoldOblique"]
                            if not fontList[0].startswith(font):
                                # We need to create font aliases, and use them
                                for fname, aliasname in zip(
                                        fontList,
                                        [font + suffix for suffix in suff]):
                                    self.fontsAlias[aliasname] = fname
                        continue

                    # Each "font" is a list of four files, which will be
                    # used for regular / bold / italic / bold+italic
                    # versions of the font.
                    # If your font doesn't have one of them, just repeat
                    # the regular font.

                    # Example, using the Tuffy font from
                    # http://tulrich.com/fonts/
                    # "embeddedFonts" : [
                    #                    ["Tuffy.ttf",
                    #                     "Tuffy_Bold.ttf",
                    #                     "Tuffy_Italic.ttf",
                    #                     "Tuffy_Bold_Italic.ttf"]
                    #                   ],

                    # The fonts will be registered with the file name,
                    # minus the extension.

                    if font[0].lower().endswith('.ttf'): # A True Type font
                        for variant in font:
                            location=self.findFont(variant)
                            pdfmetrics.registerFont(
                                TTFont(str(variant.split('.')[0]),
                                location))
                            log.info('Registering font: %s from %s'%\
                                (str(variant.split('.')[0]),location))
                            self.embedded.append(str(variant.split('.')[0]))

                        # And map them all together
                        regular, bold, italic, bolditalic = [
                            variant.split('.')[0] for variant in font]
                        addMapping(regular, 0, 0, regular)
                        addMapping(regular, 0, 1, italic)
                        addMapping(regular, 1, 0, bold)
                        addMapping(regular, 1, 1, bolditalic)
                    else: # A Type 1 font
                        # For type 1 fonts we require
                        # [FontName,regular,italic,bold,bolditalic]
                        # where each variant is a (pfbfile,afmfile) pair.
                        # For example, for the URW palladio from TeX:
                        # ["Palatino",("uplr8a.pfb","uplr8a.afm"),
                        #             ("uplri8a.pfb","uplri8a.afm"),
                        #             ("uplb8a.pfb","uplb8a.afm"),
                        #             ("uplbi8a.pfb","uplbi8a.afm")]
                        faceName = font[0]
                        regular = pdfmetrics.EmbeddedType1Face(*font[1])
                        italic = pdfmetrics.EmbeddedType1Face(*font[2])
                        bold = pdfmetrics.EmbeddedType1Face(*font[3])
                        bolditalic = pdfmetrics.EmbeddedType1Face(*font[4])

                except Exception:
                    _, e, _ = sys.exc_info()
                    try:
                        if isinstance(font, list):
                            fname = font[0]
                        else:
                            fname = font
                        log.error("Error processing font %s: %s",
                            os.path.splitext(fname)[0], str(e))
                        log.error("Registering %s as Helvetica alias", fname)
                        self.fontsAlias[fname] = 'Helvetica'
                    except Exception:
                        _, e, _ = sys.exc_info()
                        log.critical("Error processing font %s: %s",
                            fname, str(e))
                        continue

        # Go though all styles in all stylesheets and find all fontNames.
        # Then decide what to do with them
        for data, ssname in ssdata:
            for [skey, style] in self.stylepairs(data):
                for key in style:
                    if key == 'fontName' or key.endswith('FontName'):
                        # It's an alias, replace it
                        if style[key] in self.fontsAlias:
                            style[key] = self.fontsAlias[style[key]]
                        # Embedded already, nothing to do
                        if style[key] in self.embedded:
                            continue
                        # Standard font, nothing to do
                        if style[key] in (
                                    "Courier",
                                    "Courier-Bold",
                                    "Courier-BoldOblique",
                                    "Courier-Oblique",
                                    "Helvetica",
                                    "Helvetica-Bold",
                                    "Helvetica-BoldOblique",
                                    "Helvetica-Oblique",
                                    "Symbol",
                                    "Times-Bold",
                                    "Times-BoldItalic",
                                    "Times-Italic",
                                    "Times-Roman",
                                    "ZapfDingbats"):
                            continue
                        # Now we need to do something
                        # See if we can find the font
                        fname, pos = findfonts.guessFont(style[key])

                        if style[key] in embedded_fontnames:
                            pass
                        else:
                            fontList = findfonts.autoEmbed(style[key])
                            if fontList:
                                embedded_fontnames.append(style[key])
                        if not fontList:
                            if (fname, pos) in embedded_fontnames:
                                fontList = None
                            else:
                                fontList = findfonts.autoEmbed(fname)
                            if fontList:
                                embedded_fontnames.append((fname, pos))
                        if fontList:
                            self.embedded += fontList
                            # Maybe the font we got is not called
                            # the same as the one we gave so check that out
                            suff = ["", "-Bold", "-Oblique", "-BoldOblique"]
                            if not fontList[0].startswith(style[key]):
                                # We need to create font aliases, and use them
                                basefname=style[key].split('-')[0]
                                for fname, aliasname in zip(
                                        fontList,
                                        [basefname + suffix for
                                        suffix in suff]):
                                    self.fontsAlias[aliasname] = fname
                                style[key] = self.fontsAlias[basefname +\
                                             suff[pos]]
                        else:
                            log.error("Unknown font: \"%s\","
                                      "replacing with Helvetica", style[key])
                            style[key] = "Helvetica"

        #log.info('FontList: %s'%self.embedded)
        #log.info('FontAlias: %s'%self.fontsAlias)
        # Get styles from all stylesheets in order
        self.stylesheet = {}
        self.styles = []
        self.linkColor = 'navy'
        # FIXME: linkColor should probably not be a global
        #        style, and tocColor should probably not
        #        be a special case, but for now I'm going
        #        with the flow...
        self.tocColor = None
        for data, ssname in ssdata:
            self.linkColor = data.get('linkColor') or self.linkColor
            self.tocColor = data.get('tocColor') or self.tocColor
            for [skey, style] in self.stylepairs(data):
                sdict = {}
                # FIXME: this is done completely backwards
                for key in style:
                    # Handle color references by name
                    if key == 'color' or key.endswith('Color') and style[key]:
                        style[key] = formatColor(style[key])

                    # Yet another workaround for the unicode bug in
                    # reportlab's toColor
                    elif key == 'commands':
                        style[key]=validateCommands(style[key])
                        #for command in style[key]:
                            #c=command[0].upper()
                            #if c=='ROWBACKGROUNDS':
                                #command[3]=[str(c) for c in command[3]]
                            #elif c in ['BOX','INNERGRID'] or c.startswith('LINE'):
                                #command[4]=str(command[4])

                    # Handle alignment constants
                    elif key == 'alignment':
                        style[key] = dict(TA_LEFT=0,
                                          LEFT=0,
                                          TA_CENTER=1,
                                          CENTER=1,
                                          TA_CENTRE=1,
                                          CENTRE=1,
                                          TA_RIGHT=2,
                                          RIGHT=2,
                                          TA_JUSTIFY=4,
                                          JUSTIFY=4,
                                          DECIMAL=8, )[style[key].upper()]

                    elif key == 'language':
                        if not style[key] in self.languages:
                            self.languages.append(style[key])

                    # Make keys str instead of unicode (required by reportlab)
                    sdict[str(key)] = style[key]
                    sdict['name'] = skey
                # If the style already exists, update it
                if skey in self.stylesheet:
                    self.stylesheet[skey].update(sdict)
                else: # New style
                    self.stylesheet[skey] = sdict
                    self.styles.append(sdict)

        # If the stylesheet has a style name docutils won't reach
        # make a copy with a sanitized name.
        # This may make name collisions possible but that should be
        # rare (who would have custom_name and custom-name in the
        # same stylesheet? ;-)
        # Issue 339

        styles2=[]
        for s in self.styles:
            if not re.match("^[a-z](-?[a-z0-9]+)*$", s['name']):
                s2 = copy(s)
                s2['name'] = docutils.nodes.make_id(s['name'])
                log.warning('%s is an invalid docutils class name, adding alias %s'%(s['name'], s2['name']))
                styles2.append(s2)
        self.styles.extend(styles2)

        # And create the reportlab stylesheet
        self.StyleSheet = StyleSheet1()
        # Patch to make the code compatible with reportlab from SVN 2.4+ and
        # 2.4
        if not hasattr(self.StyleSheet, 'has_key'):
            self.StyleSheet.__class__.has_key = lambda s, k : k in s
        for s in self.styles:
            if 'parent' in s:
                if s['parent'] is None:
                    if s['name'] != 'base':
                        s['parent'] = self.StyleSheet['base']
                    else:
                        del(s['parent'])
                else:
                    s['parent'] = self.StyleSheet[s['parent']]
            else:
                if s['name'] != 'base':
                    s['parent'] = self.StyleSheet['base']

            # If the style has no bulletFontName but it has a fontName, set it
            if ('bulletFontName' not in s) and ('fontName' in s):
                s['bulletFontName'] = s['fontName']

            hasFS = True
            # Adjust fontsize units
            if 'fontSize' not in s:
                s['fontSize'] = s['parent'].fontSize
                s['trueFontSize']=None
                hasFS = False
            elif 'parent' in s:
                # This means you can set the fontSize to
                # "2cm" or to "150%" which will be calculated
                # relative to the parent style
                s['fontSize'] = self.adjustUnits(s['fontSize'],
                                    s['parent'].fontSize)
                s['trueFontSize']=s['fontSize']
            else:
                # If s has no parent, it's base, which has
                # an explicit point size by default and %
                # makes no sense, but guess it as % of 10pt
                s['fontSize'] = self.adjustUnits(s['fontSize'], 10)

            # If the leading is not set, but the size is, set it
            if 'leading' not in s and hasFS:
                s['leading'] = 1.2*s['fontSize']

            # If the bullet font size is not set, set it as fontSize
            if ('bulletFontSize' not in s) and ('fontSize' in s):
                s['bulletFontSize'] = s['fontSize']

            # If the borderPadding is a list and wordaxe <=0.3.2,
            # convert it to an integer. Workaround for Issue
            if 'borderPadding' in s and ((HAS_WORDAXE and \
                    wordaxe_version <='wordaxe 0.3.2') or
                    reportlab.Version < "2.3" )\
                    and isinstance(s['borderPadding'], list):
                log.warning('Using a borderPadding list in '\
                    'style %s with wordaxe <= 0.3.2 or Reportlab < 2.3. That is not '\
                    'supported, so it will probably look wrong'%s['name'])
                s['borderPadding']=s['borderPadding'][0]

            self.StyleSheet.add(ParagraphStyle(**s))


        self.emsize=self['base'].fontSize
        # Make stdFont the basefont, for Issue 65
        reportlab.rl_config.canvas_basefontname = self['base'].fontName
        # Make stdFont the default font for table cell styles (Issue 65)
        reportlab.platypus.tables.CellStyle.fontname=self['base'].fontName
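
One detail worth calling out from the style-aliasing block above (s2 = copy(s) near the end): copy.copy of a dict duplicates the top-level mapping, so rebinding s2['name'] leaves the original style untouched, but any nested mutable values are still shared between the two dicts. A small sketch of that behaviour (the style dict is made up):

from copy import copy

style = {'name': 'Custom_Name', 'fontSize': 10, 'commands': [('GRID', 0.5)]}

alias = copy(style)
alias['name'] = 'custom-name'          # top-level key: only the alias changes
print(style['name'])                   # Custom_Name
print(alias['name'])                   # custom-name

alias['commands'].append(('BOX', 1))   # nested list is shared by both dicts
print(style['commands'])               # [('GRID', 0.5), ('BOX', 1)]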

Example 34

Project: eden
Source File: appadmin.py
View license
def ccache():
    form = FORM(
        P(TAG.BUTTON(
            T("Clear CACHE?"), _type="submit", _name="yes", _value="yes")),
        P(TAG.BUTTON(
            T("Clear RAM"), _type="submit", _name="ram", _value="ram")),
        P(TAG.BUTTON(
            T("Clear DISK"), _type="submit", _name="disk", _value="disk")),
    )

    if form.accepts(request.vars, session):
        clear_ram = False
        clear_disk = False
        session.flash = ""
        if request.vars.yes:
            clear_ram = clear_disk = True
        if request.vars.ram:
            clear_ram = True
        if request.vars.disk:
            clear_disk = True

        if clear_ram:
            cache.ram.clear()
            session.flash += T("Ram Cleared")
        if clear_disk:
            cache.disk.clear()
            session.flash += T("Disk Cleared")

        redirect(URL(r=request))

    try:
        from guppy import hpy
        hp = hpy()
    except ImportError:
        hp = False

    import shelve
    import os
    import copy
    import time
    import math
    from gluon import portalocker

    ram = {
        'entries': 0,
        'bytes': 0,
        'objects': 0,
        'hits': 0,
        'misses': 0,
        'ratio': 0,
        'oldest': time.time(),
        'keys': []
    }
    disk = copy.copy(ram)
    total = copy.copy(ram)
    disk['keys'] = []
    total['keys'] = []

    def GetInHMS(seconds):
        hours = math.floor(seconds / 3600)
        seconds -= hours * 3600
        minutes = math.floor(seconds / 60)
        seconds -= minutes * 60
        seconds = math.floor(seconds)

        return (hours, minutes, seconds)

    for key, value in cache.ram.storage.items():
        if isinstance(value, dict):
            ram['hits'] = value['hit_total'] - value['misses']
            ram['misses'] = value['misses']
            try:
                ram['ratio'] = ram['hits'] * 100 / value['hit_total']
            except (KeyError, ZeroDivisionError):
                ram['ratio'] = 0
        else:
            if hp:
                ram['bytes'] += hp.iso(value[1]).size
                ram['objects'] += hp.iso(value[1]).count
            ram['entries'] += 1
            if value[0] < ram['oldest']:
                ram['oldest'] = value[0]
            ram['keys'].append((key, GetInHMS(time.time() - value[0])))
    folder = os.path.join(request.folder,'cache')
    if not os.path.exists(folder):
        os.mkdir(folder)
    locker = open(os.path.join(folder, 'cache.lock'), 'a')
    portalocker.lock(locker, portalocker.LOCK_EX)
    disk_storage = shelve.open(
        os.path.join(folder, 'cache.shelve'))
    try:
        for key, value in disk_storage.items():
            if isinstance(value, dict):
                disk['hits'] = value['hit_total'] - value['misses']
                disk['misses'] = value['misses']
                try:
                    disk['ratio'] = disk['hits'] * 100 / value['hit_total']
                except (KeyError, ZeroDivisionError):
                    disk['ratio'] = 0
            else:
                if hp:
                    disk['bytes'] += hp.iso(value[1]).size
                    disk['objects'] += hp.iso(value[1]).count
                disk['entries'] += 1
                if value[0] < disk['oldest']:
                    disk['oldest'] = value[0]
                disk['keys'].append((key, GetInHMS(time.time() - value[0])))

    finally:
        portalocker.unlock(locker)
        locker.close()
        disk_storage.close()

    total['entries'] = ram['entries'] + disk['entries']
    total['bytes'] = ram['bytes'] + disk['bytes']
    total['objects'] = ram['objects'] + disk['objects']
    total['hits'] = ram['hits'] + disk['hits']
    total['misses'] = ram['misses'] + disk['misses']
    total['keys'] = ram['keys'] + disk['keys']
    try:
        total['ratio'] = total['hits'] * 100 / (total['hits'] +
                                                total['misses'])
    except (KeyError, ZeroDivisionError):
        total['ratio'] = 0

    if disk['oldest'] < ram['oldest']:
        total['oldest'] = disk['oldest']
    else:
        total['oldest'] = ram['oldest']

    ram['oldest'] = GetInHMS(time.time() - ram['oldest'])
    disk['oldest'] = GetInHMS(time.time() - disk['oldest'])
    total['oldest'] = GetInHMS(time.time() - total['oldest'])

    def key_table(keys):
        return TABLE(
            TR(TD(B(T('Key'))), TD(B(T('Time in Cache (h:m:s)')))),
            *[TR(TD(k[0]), TD('%02d:%02d:%02d' % k[1])) for k in keys],
            **dict(_class='cache-keys',
                   _style="border-collapse: separate; border-spacing: .5em;"))

    ram['keys'] = key_table(ram['keys'])
    disk['keys'] = key_table(disk['keys'])
    total['keys'] = key_table(total['keys'])

    return dict(form=form, total=total,
                ram=ram, disk=disk, object_stats=hp != False)
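
Note why the example above follows disk = copy.copy(ram) with disk['keys'] = []: the copy is shallow, so without that reassignment ram and disk would keep appending into one shared 'keys' list. A minimal demonstration of the difference:

import copy

ram = {'entries': 0, 'keys': []}

disk = copy.copy(ram)         # shallow copy: 'keys' still points at the same list
disk['keys'].append('x')
print(ram['keys'])            # ['x'] -- the original sees the append

disk = copy.copy(ram)
disk['keys'] = []             # rebind to a fresh list, as the appadmin code does
disk['keys'].append('y')
print(ram['keys'])            # ['x'] -- no longer shared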

Example 35

Project: scalarizr
Source File: __init__.py
View license
    def _src_generator(self):
        '''
        Compress, split, yield out
        '''
        if self._up:
            # Tranzit volume size is chunk for each worker
            # and Ext filesystem overhead

            # if the upload is multiparted, the manifest won't be used
            self.manifest = Manifest()
            # supposedly, manifest's destination path; assumes that dst
            # generator yields self.dst.next()+transfer_id
            self.manifest.cloudfs_path = os.path.join(self.dst.next(),
                                                      self.transfer_id, self.manifest_path)
            self.manifest["description"] = self.description
            if self.tags:
                self.manifest["tags"] = self.tags

            def delete_uploaded_chunk(src, dst, retry, chunk_num):
                os.remove(src)
            self._transfer.on(transfer_complete=delete_uploaded_chunk)

            for src in self.src:
                LOG.debug('src: %s, type: %s', src, type(src))
                fileinfo = {
                        "name": '',
                        "streamer": None,
                        "compressor": None,
                        "chunks": [],
                }
                self.manifest["files"].append(fileinfo)  # moved here from the bottom
                prefix = self._tranzit_vol.mpoint
                stream = None
                cmd = tar = gzip = None

                if hasattr(src, 'read'):
                    stream = src
                    if hasattr(stream, 'name'):
                        # os.pipe stream has name '<fdopen>'
                        name = os.path.basename(stream.name).strip('<>')  # ? can stream name end with '/'
                    else:
                        name = 'stream-%s' % hash(stream)
                    fileinfo["name"] = name
                    prefix = os.path.join(prefix, name) + '.'
                elif self.streamer and isinstance(src, basestring) and os.path.isdir(src):
                    name = os.path.basename(src.rstrip('/'))
                    fileinfo["name"] = name

                    if self.streamer == "tar":
                        fileinfo["streamer"] = "tar"
                        prefix = os.path.join(prefix, name) + '.tar.'

                        if src.endswith('/'):  # tar dir content
                            tar_cmdargs = ['/bin/tar', 'cp', '-C', src, '.']
                        else:
                            parent, target = os.path.split(src)
                            tar_cmdargs = ['/bin/tar', 'cp', '-C', parent, target]

                        LOG.debug("LargeTransfer src_generator TAR POPEN")
                        tar = cmd = subprocess.Popen(
                                                        tar_cmdargs,
                                                        stdout=subprocess.PIPE,
                                                        stderr=subprocess.PIPE,
                                                        close_fds=True)
                        LOG.debug("LargeTransfer src_generator AFTER TAR")
                    elif hasattr(self.streamer, "popen"):
                        fileinfo["streamer"] = str(self.streamer)
                        prefix = os.path.join(prefix, name) + '.'

                        LOG.debug("LargeTransfer src_generator custom streamer POPEN")
                        # TODO: self.streamer.args += src
                        tar = cmd = self.streamer.popen(stdin=None)
                        LOG.debug("LargeTransfer src_generator after custom streamer POPEN")
                    stream = tar.stdout
                elif isinstance(src, basestring) and os.path.isfile(src):
                    name = os.path.basename(src)
                    fileinfo["name"] = name
                    prefix = os.path.join(prefix, name) + '.'

                    stream = open(src)
                else:
                    raise ValueError('Unsupported src: %s' % src)

                if self.compressor == "gzip":
                    fileinfo["compressor"] = "gzip"
                    prefix += 'gz.'
                    LOG.debug("LargeTransfer src_generator GZIP POPEN")
                    gzip = cmd = subprocess.Popen(
                                            [self._gzip_bin(), '-5'],
                                            stdin=stream,
                                            stdout=subprocess.PIPE,
                                            stderr=subprocess.PIPE,
                                            close_fds=True)
                    LOG.debug("LargeTransfer src_generator AFTER GZIP")
                    if tar:
                        # Allow tar to receive SIGPIPE if gzip exits.
                        tar.stdout.close()
                    stream = gzip.stdout
                # custom compressor
                elif hasattr(self.compressor, "popen"):
                    fileinfo["compressor"] = str(self.compressor)
                    LOG.debug("LargeTransfer src_generator custom compressor POPEN")
                    cmd = self.compressor.popen(stdin=stream)
                    LOG.debug("LargeTransfer src_generator after custom compressor POPEN")
                    if tar:
                        tar.stdout.close()
                    stream = cmd.stdout

                for filename, md5sum, size in self._split(stream, prefix):
                    fileinfo["chunks"].append((os.path.basename(filename), md5sum, size))
                    LOG.debug("LargeTransfer src_generator yield %s", filename)
                    yield filename
                if cmd:
                    out, err = cmd.communicate()
                    if err:
                        LOG.debug("LargeTransfer src_generator cmd pipe stderr: %s", err)

            # send manifest to file transfer
            if not self.multipart:
                LOG.debug("Manifest: %s", self.manifest.data)
                manifest_f = os.path.join(self._tranzit_vol.mpoint, self.manifest_path)
                self.manifest.write(manifest_f)
                LOG.debug("LargeTransfer yield %s", manifest_f)
                yield manifest_f

        elif not self._up:
            def on_transfer_error(*args):
                LOG.debug("transfer_error event, shutting down")
                self.kill()
            self._transfer.on(transfer_error=on_transfer_error)

            # The first yielded object will be the manifest, so
            # catch_manifest is a listener that's supposed to trigger only
            # once and unsubscribe itself.
            def wait_manifest(src, dst, retry, chunk_num):
                self._transfer.un('transfer_complete', wait_manifest)
                self._manifest_ready.set()
            self._transfer.on(transfer_complete=wait_manifest)


            manifest_path = self.src
            yield manifest_path

            # ? except EventInterrupt: save exc and return
            self._manifest_ready.wait()

            # we should have the manifest on the tmpfs by now
            manifest_local = os.path.join(self._tranzit_vol.mpoint,
                                          os.path.basename(manifest_path))
            manifest = Manifest(manifest_local)
            os.remove(manifest_local)
            remote_path = os.path.dirname(manifest_path)

            # add ready and done events to each chunk without breaking the
            # chunk order
            with self._chunks_events_access:
                if not self._killed:
                    self.files = copy(manifest["files"])
                    for file_ in self.files:
                        file_["chunks"] = OrderedDict([(
                                chunk[0], {
                                        "md5sum": chunk[1],
                                        "size": chunk[2] if len(chunk) > 2 else None,
                                        "downloaded": InterruptibleEvent(),
                                        "processed": InterruptibleEvent()
                                }
                        ) for chunk in file_["chunks"]])
                        # chunk is [basename, md5sum, size]

            # launch restorer
            if self._restorer is None:
                LOG.debug("STARTING RESTORER")
                self._restorer = threading.Thread(target=self._dl_restorer)
                self._restorer.start()

            def wait_chunk(src, dst, retry, chunk_num):
                chunk_name = os.path.basename(src)
                for file_ in self.files:
                    if chunk_name in file_["chunks"]:
                        chunk = file_["chunks"][chunk_name]
                chunk["downloaded"].set()
                chunk["processed"].wait()
                os.remove(os.path.join(dst, chunk_name))
            self._transfer.on(transfer_complete=wait_chunk)

            for file_ in self.files:
                for chunk in file_["chunks"]:
                    yield os.path.join(remote_path, chunk)
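
The download branch above takes a shallow copy of the manifest's file list before rewriting each entry's "chunks" field around per-chunk events. A minimal standalone sketch of what that shallow copy does, with hypothetical data in place of the real Manifest object:

import copy

# Hypothetical manifest data: one file split into chunks of (name, md5, size).
manifest = {"files": [{"name": "backup.tar.gz",
                       "chunks": [("chunk.000", "d41d8cd98f00b204e9800998ecf8427e", 1024)]}]}

# copy.copy gives a new outer list, while the entry dicts stay shared.
files = copy.copy(manifest["files"])
assert files is not manifest["files"]
assert files[0] is manifest["files"][0]

# Rebuilding the "chunks" field on an entry is therefore visible through both.
files[0]["chunks"] = {name: {"md5sum": md5, "size": size}
                      for name, md5, size in files[0]["chunks"]}
assert "chunk.000" in manifest["files"][0]["chunks"]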

Example 36

Project: schematics
Source File: test_conversion.py
View license
@pytest.mark.parametrize('variant', (None, 'noerrors'))
@pytest.mark.parametrize('partial', (True, False))
@pytest.mark.parametrize('import_, two_pass, input_instance, input_init, init',
                       [( True,    False,    False,          None,       True),
                        ( True,    False,    False,          None,       False),
                        ( True,    False,    True,           False,      True),
                        ( True,    False,    True,           False,      False),
                        ( True,    False,    True,           True,       True),
                        ( True,    False,    True,           True,       False),
                        ( True,    True,     False,          None,       True),
                        ( True,    True,     False,          None,       False),
                        ( True,    True,     True,           False,      True),
                        ( True,    True,     True,           False,      False),
                        ( True,    True,     True,           True,       True),
                        ( True,    True,     True,           True,       False),
                        ( False,   None,     True,           False,      True),
                        ( False,   None,     True,           False,      False),
                        ( False,   None,     True,           True,       True),
                        ( False,   None,     True,           True,       False)])
def test_conversion_with_validation(input, import_, two_pass, input_instance, input_init, init,
                                    partial, variant):

    init_to_none = input_init or init

    if variant == 'noerrors':

        orig_input = copy(input)

        if input_instance:
            assert input.modelfield is orig_input.modelfield

        if import_:
            if two_pass:
                m = M(input, init=init)
                m.validate(partial=partial)
            else:
                m = M(input, init=init, partial=partial, validate=True)
        else:
            input.validate(init_values=init, partial=partial)
            m = input

        assert input == orig_input

        if input_instance:
            if import_:
                assert m.modelfield is not input.modelfield
                assert m._data['modelfield'] is not input._data['modelfield']
                assert m.modelfield.listfield is not input.modelfield.listfield
            else:
                assert m.modelfield is input.modelfield
                assert m._data['modelfield'] is input._data['modelfield']
                assert m.modelfield.listfield is input.modelfield.listfield

        return

    if init_to_none:
        partial_data = {
            'intfield': 1,
            'reqfield': u'foo',
            'matrixfield': None,
            'modelfield': {
                'intfield': None,
                'reqfield': u'bar',
                'matrixfield': None,
                'modelfield': {
                    'reqfield': None,
                    'listfield': None,
                    'modelfield': M({
                        'intfield': 0,
                        'reqfield': u'foo',
                        'listfield': None})}}}
    else:
        partial_data = {
            'intfield': 1,
            'reqfield': u'foo',
            'modelfield': {
                'reqfield': u'bar',
                'modelfield': {
                    'listfield': None,
                    'modelfield': M({
                        'intfield': 0,
                        'reqfield': u'foo',
                        'listfield': None}, init=False)}}}

    with pytest.raises(DataError) as excinfo:
        if import_:
            if two_pass:
                m = M(input, init=init)
                m.validate(partial=partial)
            else:
                M(input, init=init, partial=partial, validate=True)
        else:
            input.validate(init_values=init, partial=partial)

    errors = excinfo.value.errors

    err_list = errors.pop('listfield')
    assert type(err_list) is ValidationError
    assert len(err_list) == 1

    err_list = errors['modelfield'].pop('listfield')
    assert type(err_list) is ValidationError
    assert len(err_list) == 2

    err_list = errors['modelfield']['modelfield'].pop('intfield')
    assert len(err_list) == 1

    if not partial:
        err_list = errors['modelfield']['modelfield'].pop('reqfield')
        assert len(err_list) == 1
        if init_to_none:
            partial_data['modelfield']['modelfield'].pop('reqfield')

    err_dict = errors['modelfield']['modelfield'].pop('matrixfield')
    sub_err_dict = err_dict.pop(1)
    assert list((k, type(v)) for k, v in sub_err_dict.items()) \
        == [(2, ValidationError), (3, ValidationError)]
    assert err_dict == {}

    assert errors['modelfield'].pop('modelfield') == {}
    assert errors.pop('modelfield') == {}
    assert errors == {}

    assert excinfo.value.partial_data == partial_data
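
Here copy(input) takes a shallow snapshot of the input so the test can later assert that conversion left it untouched, while nested field objects remain shared. A small sketch of the same idea, using a plain dict as a hypothetical stand-in for the model instance:

import copy

# Hypothetical data standing in for the model instance under test.
original = {"intfield": 1, "modelfield": {"reqfield": "bar"}}

snapshot = copy.copy(original)              # new outer dict, shared nested objects
assert snapshot == original
assert snapshot["modelfield"] is original["modelfield"]

# Replacing a top-level entry on one side is caught by comparing the two...
original["intfield"] = 2
assert snapshot != original
# ...but in-place edits to the shared nested dict show up through both.
original["modelfield"]["reqfield"] = "baz"
assert snapshot["modelfield"]["reqfield"] == "baz"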

Example 37

Project: scipy
Source File: odepack.py
View license
def odeint(func, y0, t, args=(), Dfun=None, col_deriv=0, full_output=0,
           ml=None, mu=None, rtol=None, atol=None, tcrit=None, h0=0.0,
           hmax=0.0, hmin=0.0, ixpr=0, mxstep=0, mxhnil=0, mxordn=12,
           mxords=5, printmessg=0):
    """
    Integrate a system of ordinary differential equations.

    Solve a system of ordinary differential equations using lsoda from the
    FORTRAN library odepack.

    Solves the initial value problem for stiff or non-stiff systems
    of first order ode-s::

        dy/dt = func(y, t0, ...)

    where y can be a vector.

    *Note*: The first two arguments of ``func(y, t0, ...)`` are in the
    opposite order of the arguments in the system definition function used
    by the `scipy.integrate.ode` class.

    Parameters
    ----------
    func : callable(y, t0, ...)
        Computes the derivative of y at t0.
    y0 : array
        Initial condition on y (can be a vector).
    t : array
        A sequence of time points for which to solve for y.  The initial
        value point should be the first element of this sequence.
    args : tuple, optional
        Extra arguments to pass to function.
    Dfun : callable(y, t0, ...)
        Gradient (Jacobian) of `func`.
    col_deriv : bool, optional
        True if `Dfun` defines derivatives down columns (faster),
        otherwise `Dfun` should define derivatives across rows.
    full_output : bool, optional
        True if to return a dictionary of optional outputs as the second output
    printmessg : bool, optional
        Whether to print the convergence message

    Returns
    -------
    y : array, shape (len(t), len(y0))
        Array containing the value of y for each desired time in t,
        with the initial value `y0` in the first row.
    infodict : dict, only returned if full_output == True
        Dictionary containing additional output information

        =======  ============================================================
        key      meaning
        =======  ============================================================
        'hu'     vector of step sizes successfully used for each time step.
        'tcur'   vector with the value of t reached for each time step.
                 (will always be at least as large as the input times).
        'tolsf'  vector of tolerance scale factors, greater than 1.0,
                 computed when a request for too much accuracy was detected.
        'tsw'    value of t at the time of the last method switch
                 (given for each time step)
        'nst'    cumulative number of time steps
        'nfe'    cumulative number of function evaluations for each time step
        'nje'    cumulative number of jacobian evaluations for each time step
        'nqu'    a vector of method orders for each successful step.
        'imxer'  index of the component of largest magnitude in the
                 weighted local error vector (e / ewt) on an error return, -1
                 otherwise.
        'lenrw'  the length of the double work array required.
        'leniw'  the length of integer work array required.
        'mused'  a vector of method indicators for each successful time step:
                 1: adams (nonstiff), 2: bdf (stiff)
        =======  ============================================================

    Other Parameters
    ----------------
    ml, mu : int, optional
        If either of these are not None or non-negative, then the
        Jacobian is assumed to be banded.  These give the number of
        lower and upper non-zero diagonals in this banded matrix.
        For the banded case, `Dfun` should return a matrix whose
        rows contain the non-zero bands (starting with the lowest diagonal).
        Thus, the return matrix `jac` from `Dfun` should have shape
        ``(ml + mu + 1, len(y0))`` when ``ml >=0`` or ``mu >=0``.
        The data in `jac` must be stored such that ``jac[i - j + mu, j]``
        holds the derivative of the `i`th equation with respect to the `j`th
        state variable.  If `col_deriv` is True, the transpose of this
        `jac` must be returned.
    rtol, atol : float, optional
        The input parameters `rtol` and `atol` determine the error
        control performed by the solver.  The solver will control the
        vector, e, of estimated local errors in y, according to an
        inequality of the form ``max-norm of (e / ewt) <= 1``,
        where ewt is a vector of positive error weights computed as
        ``ewt = rtol * abs(y) + atol``.
        rtol and atol can be either vectors the same length as y or scalars.
        Defaults to 1.49012e-8.
    tcrit : ndarray, optional
        Vector of critical points (e.g. singularities) where integration
        care should be taken.
    h0 : float, (0: solver-determined), optional
        The step size to be attempted on the first step.
    hmax : float, (0: solver-determined), optional
        The maximum absolute step size allowed.
    hmin : float, (0: solver-determined), optional
        The minimum absolute step size allowed.
    ixpr : bool, optional
        Whether to generate extra printing at method switches.
    mxstep : int, (0: solver-determined), optional
        Maximum number of (internally defined) steps allowed for each
        integration point in t.
    mxhnil : int, (0: solver-determined), optional
        Maximum number of messages printed.
    mxordn : int, (0: solver-determined), optional
        Maximum order to be allowed for the non-stiff (Adams) method.
    mxords : int, (0: solver-determined), optional
        Maximum order to be allowed for the stiff (BDF) method.

    See Also
    --------
    ode : a more object-oriented integrator based on VODE.
    quad : for finding the area under a curve.

    Examples
    --------
    The second order differential equation for the angle `theta` of a
    pendulum acted on by gravity with friction can be written::

        theta''(t) + b*theta'(t) + c*sin(theta(t)) = 0

    where `b` and `c` are positive constants, and a prime (') denotes a
    derivative.  To solve this equation with `odeint`, we must first convert
    it to a system of first order equations.  By defining the angular
    velocity ``omega(t) = theta'(t)``, we obtain the system::

        theta'(t) = omega(t)
        omega'(t) = -b*omega(t) - c*sin(theta(t))

    Let `y` be the vector [`theta`, `omega`].  We implement this system
    in python as:

    >>> def pend(y, t, b, c):
    ...     theta, omega = y
    ...     dydt = [omega, -b*omega - c*np.sin(theta)]
    ...     return dydt
    ...
    
    We assume the constants are `b` = 0.25 and `c` = 5.0:

    >>> b = 0.25
    >>> c = 5.0

    For initial conditions, we assume the pendulum is nearly vertical
    with `theta(0)` = `pi` - 0.1, and it is initially at rest, so
    `omega(0)` = 0.  Then the vector of initial conditions is

    >>> y0 = [np.pi - 0.1, 0.0]

    We generate a solution at 101 evenly spaced samples in the interval
    0 <= `t` <= 10.  So our array of times is:

    >>> t = np.linspace(0, 10, 101)

    Call `odeint` to generate the solution.  To pass the parameters
    `b` and `c` to `pend`, we give them to `odeint` using the `args`
    argument.

    >>> from scipy.integrate import odeint
    >>> sol = odeint(pend, y0, t, args=(b, c))

    The solution is an array with shape (101, 2).  The first column
    is `theta(t)`, and the second is `omega(t)`.  The following code
    plots both components.

    >>> import matplotlib.pyplot as plt
    >>> plt.plot(t, sol[:, 0], 'b', label='theta(t)')
    >>> plt.plot(t, sol[:, 1], 'g', label='omega(t)')
    >>> plt.legend(loc='best')
    >>> plt.xlabel('t')
    >>> plt.grid()
    >>> plt.show()
    """

    if ml is None:
        ml = -1  # changed to zero inside function call
    if mu is None:
        mu = -1  # changed to zero inside function call
    t = copy(t)
    y0 = copy(y0)
    output = _odepack.odeint(func, y0, t, args, Dfun, col_deriv, ml, mu,
                             full_output, rtol, atol, tcrit, h0, hmax, hmin,
                             ixpr, mxstep, mxhnil, mxordn, mxords)
    if output[-1] < 0:
        warning_msg = _msgs[output[-1]] + " Run with full_output = 1 to get quantitative information."
        warnings.warn(warning_msg, ODEintWarning)
    elif printmessg:
        warning_msg = _msgs[output[-1]]
        warnings.warn(warning_msg, ODEintWarning)

    if full_output:
        output[1]['message'] = _msgs[output[-1]]

    output = output[:-1]
    if len(output) == 1:
        return output[0]
    else:
        return output
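
odeint copies t and y0 so the underlying solver can work on them without altering the caller's sequences. A minimal sketch of that defensive-copy pattern, with plain lists as hypothetical stand-ins for the arrays:

import copy

def integrate(y0, t):
    # Defensive copies: the routine works on local objects, so whatever it
    # does in place never reaches the caller's sequences.
    y0 = copy.copy(y0)
    t = copy.copy(t)
    y0[0] += 1.0          # stand-in for the real in-place work
    return y0, t

start = [3.141, 0.0]
times = [0.0, 0.1, 0.2]
integrate(start, times)
assert start == [3.141, 0.0] and times == [0.0, 0.1, 0.2]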

Example 38

Project: iris
Source File: cartography.py
View license
def project(cube, target_proj, nx=None, ny=None):
    """
    Nearest neighbour regrid to a specified target projection.

    Return a new cube that is the result of projecting a cube with 1 or 2
    dimensional latitude-longitude coordinates from its coordinate system into
    a specified projection e.g. Robinson or Polar Stereographic.
    This function is intended to be used in cases where the cube's coordinates
    prevent one from directly visualising the data, e.g. when the longitude
    and latitude are two dimensional and do not make up a regular grid.

    Args:
        * cube
            An instance of :class:`iris.cube.Cube`.
        * target_proj
            An instance of the Cartopy Projection class, or an instance of
            :class:`iris.coord_systems.CoordSystem` from which a projection
            will be obtained.
    Kwargs:
        * nx
            Desired number of sample points in the x direction for a domain
            covering the globe.
        * ny
            Desired number of sample points in the y direction for a domain
            covering the globe.

    Returns:
        An instance of :class:`iris.cube.Cube` and a list describing the
        extent of the projection.

    .. note::

        This function assumes global data and will if necessary extrapolate
        beyond the geographical extent of the source cube using a nearest
        neighbour approach. nx and ny then include those points which are
        outside of the target projection.

    .. note::

        Masked arrays are handled by passing their masked status to the
        resulting nearest neighbour values.  If masked, the value in the
        resulting cube is set to 0.

    .. warning::

        This function uses a nearest neighbour approach rather than any form
        of linear/non-linear interpolation to determine the data value of each
        cell in the resulting cube. Consequently it may have an adverse effect
        on the statistics of the data e.g. the mean and standard deviation
        will not be preserved.

    """
    try:
        lat_coord, lon_coord = _get_lat_lon_coords(cube)
    except IndexError:
        raise ValueError('Cannot get latitude/longitude '
                         'coordinates from cube {!r}.'.format(cube.name()))

    if lat_coord.coord_system != lon_coord.coord_system:
        raise ValueError('latitude and longitude coords appear to have '
                         'different coordinate systems.')

    if lon_coord.units != 'degrees':
        lon_coord = lon_coord.copy()
        lon_coord.convert_units('degrees')
    if lat_coord.units != 'degrees':
        lat_coord = lat_coord.copy()
        lat_coord.convert_units('degrees')

    # Determine source coordinate system
    if lat_coord.coord_system is None:
        # Assume WGS84 latlon if unspecified
        warnings.warn('Coordinate system of latitude and longitude '
                      'coordinates is not specified. Assuming WGS84 Geodetic.')
        orig_cs = iris.coord_systems.GeogCS(semi_major_axis=6378137.0,
                                            inverse_flattening=298.257223563)
    else:
        orig_cs = lat_coord.coord_system

    # Convert to cartopy crs
    source_cs = orig_cs.as_cartopy_crs()

    # Obtain coordinate arrays (ignoring bounds) and convert to 2d
    # if not already.
    source_x = lon_coord.points
    source_y = lat_coord.points
    if source_x.ndim != 2 or source_y.ndim != 2:
        source_x, source_y = np.meshgrid(source_x, source_y)

    # Calculate target grid
    target_cs = None
    if isinstance(target_proj, iris.coord_systems.CoordSystem):
        target_cs = target_proj
        target_proj = target_proj.as_cartopy_projection()

    # Resolution of new grid
    if nx is None:
        nx = source_x.shape[1]
    if ny is None:
        ny = source_x.shape[0]

    target_x, target_y, extent = cartopy.img_transform.mesh_projection(
        target_proj, nx, ny)

    # Determine dimension mappings - expect either 1d or 2d
    if lat_coord.ndim != lon_coord.ndim:
        raise ValueError("The latitude and longitude coordinates have "
                         "different dimensionality.")

    latlon_ndim = lat_coord.ndim
    lon_dims = cube.coord_dims(lon_coord)
    lat_dims = cube.coord_dims(lat_coord)

    if latlon_ndim == 1:
        xdim = lon_dims[0]
        ydim = lat_dims[0]
    elif latlon_ndim == 2:
        if lon_dims != lat_dims:
            raise ValueError("The 2d latitude and longitude coordinates "
                             "correspond to different dimensions.")
        # If coords are 2d assume that grid is ordered such that x corresponds
        # to the last dimension (shortest stride).
        xdim = lon_dims[1]
        ydim = lon_dims[0]
    else:
        raise ValueError('Expected the latitude and longitude coordinates '
                         'to have 1 or 2 dimensions, got {} and '
                         '{}.'.format(lat_coord.ndim, lon_coord.ndim))

    # Create array to store regridded data
    new_shape = list(cube.shape)
    new_shape[xdim] = nx
    new_shape[ydim] = ny
    new_data = ma.zeros(new_shape, cube.data.dtype)

    # Create iterators to step through cube data in lat long slices
    new_shape[xdim] = 1
    new_shape[ydim] = 1
    index_it = np.ndindex(*new_shape)
    if lat_coord.ndim == 1 and lon_coord.ndim == 1:
        slice_it = cube.slices([lat_coord, lon_coord])
    elif lat_coord.ndim == 2 and lon_coord.ndim == 2:
        slice_it = cube.slices(lat_coord)
    else:
        raise ValueError('Expected the latitude and longitude coordinates '
                         'to have 1 or 2 dimensions, got {} and '
                         '{}.'.format(lat_coord.ndim, lon_coord.ndim))

#    # Mask out points outside of extent in source_cs - disabled until
#    # a way to specify global/limited extent is agreed upon and code
#    # is generalised to handle -180 to +180, 0 to 360 and >360 longitudes.
#    source_desired_xy = source_cs.transform_points(target_proj,
#                                                   target_x.flatten(),
#                                                   target_y.flatten())
#    if np.any(source_x < 0.0) and np.any(source_x > 180.0):
#        raise ValueError('Unable to handle range of longitude.')
#    # This does not work in all cases e.g. lon > 360
#    if np.any(source_x > 180.0):
#        source_desired_x = (source_desired_xy[:, 0].reshape(ny, nx) +
#                            360.0) % 360.0
#    else:
#        source_desired_x = source_desired_xy[:, 0].reshape(ny, nx)
#    source_desired_y = source_desired_xy[:, 1].reshape(ny, nx)
#    outof_extent_points = ((source_desired_x < source_x.min()) |
#                           (source_desired_x > source_x.max()) |
#                           (source_desired_y < source_y.min()) |
#                           (source_desired_y > source_y.max()))
#    # Make array a mask by default (rather than a single bool) to allow mask
#    # to be assigned to slices.
#    new_data.mask = np.zeros(new_shape)

    # Step through cube data, regrid onto desired projection and insert results
    # in new_data array
    for index, ll_slice in zip(index_it, slice_it):
        # Regrid source data onto target grid
        index = list(index)
        index[xdim] = slice(None, None)
        index[ydim] = slice(None, None)
        new_data[index] = cartopy.img_transform.regrid(ll_slice.data,
                                                       source_x, source_y,
                                                       source_cs,
                                                       target_proj,
                                                       target_x, target_y)

#    # Mask out points beyond extent
#    new_data[index].mask[outof_extent_points] = True

    # Remove mask if it is unnecessary
    if not np.any(new_data.mask):
        new_data = new_data.data

    # Create new cube
    new_cube = iris.cube.Cube(new_data)

    # Add new grid coords
    x_coord = iris.coords.DimCoord(target_x[0, :], 'projection_x_coordinate',
                                   units='m',
                                   coord_system=copy.copy(target_cs))
    y_coord = iris.coords.DimCoord(target_y[:, 0], 'projection_y_coordinate',
                                   units='m',
                                   coord_system=copy.copy(target_cs))

    new_cube.add_dim_coord(x_coord, xdim)
    new_cube.add_dim_coord(y_coord, ydim)

    # Add resampled lat/lon in original coord system
    source_desired_xy = source_cs.transform_points(target_proj,
                                                   target_x.flatten(),
                                                   target_y.flatten())
    new_lon_points = source_desired_xy[:, 0].reshape(ny, nx)
    new_lat_points = source_desired_xy[:, 1].reshape(ny, nx)
    new_lon_coord = iris.coords.AuxCoord(new_lon_points,
                                         standard_name='longitude',
                                         units='degrees',
                                         coord_system=orig_cs)
    new_lat_coord = iris.coords.AuxCoord(new_lat_points,
                                         standard_name='latitude',
                                         units='degrees',
                                         coord_system=orig_cs)
    new_cube.add_aux_coord(new_lon_coord, [ydim, xdim])
    new_cube.add_aux_coord(new_lat_coord, [ydim, xdim])

    coords_to_ignore = set()
    coords_to_ignore.update(cube.coords(contains_dimension=xdim))
    coords_to_ignore.update(cube.coords(contains_dimension=ydim))
    for coord in cube.dim_coords:
        if coord not in coords_to_ignore:
            new_cube.add_dim_coord(coord.copy(), cube.coord_dims(coord))
    for coord in cube.aux_coords:
        if coord not in coords_to_ignore:
            new_cube.add_aux_coord(coord.copy(), cube.coord_dims(coord))
    discarded_coords = coords_to_ignore.difference([lat_coord, lon_coord])
    if discarded_coords:
        warnings.warn('Discarding coordinates that share dimensions with '
                      '{} and {}: {}'.format(lat_coord.name(),
                                             lon_coord.name(),
                                             [coord.name() for
                                              coord in discarded_coords]))

    # TODO handle derived coords/aux_factories

    # Copy metadata across
    new_cube.metadata = cube.metadata

    return new_cube, extent
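
Each new dimension coordinate above receives its own copy.copy(target_cs), so the x and y coordinates do not share a single mutable coordinate-system object, and copying None is harmless when no iris coord system was supplied. A short sketch under those assumptions, with a hypothetical CoordSystem class:

import copy

class CoordSystem(object):
    """Hypothetical stand-in for an iris coordinate system object."""
    def __init__(self, name):
        self.name = name

target_cs = CoordSystem("Robinson")

# Give each new coordinate its own shallow copy of the coordinate system.
x_cs = copy.copy(target_cs)
y_cs = copy.copy(target_cs)
assert x_cs is not target_cs and y_cs is not x_cs
assert x_cs.name == y_cs.name == "Robinson"

# When no coord system was supplied, target_cs stays None and copying None
# simply returns None, so the same call works in both branches.
assert copy.copy(None) is None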

Example 39

Project: iSDX
Source File: tnode.py
View license
def cmd_thread(conn):
    global generation
    data = conn.recv(1024)
    
    if len(data) == 0:
        conn.sendall(host + ':XX ERROR: No data\n')
        conn.close()
        return;
    
    tokens = data.split()
    tokens = shlex.split(data)
    n = len(tokens)
    if n == 0:
        conn.sendall(host + ':XX ERROR: Null data\n')
        conn.close()
        return;
    
    cmd = tokens[0]
    
    if cmd == 'quit':
        conn.sendall(host + ':XX OK: EXITING\n')
        conn.close()
        os._exit(1)

    if cmd == 'dump' and n == 1:
        while not outq.empty():
            conn.sendall(outq.get())
        conn.close()
        return;
    
    if cmd == 'exec':
        tokens.pop(0)
        try:
            p = subprocess.Popen(tokens, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            out, err = p.communicate()
        except:
            out = 'Command Failed\n'
            err = ''
        conn.sendall(out)
        conn.sendall(err)
        # DEBUG conn.sendall(str(tokens))
        conn.close()
        return
    
    if cmd == 'listener' and n == 3:
        addr = tokens[1]
        port = tokens[2]
        r = create_listener(addr, int(port))
        if len(r) > 0:
            conn.sendall(host +':00' + ' ' + r + '\n')
        conn.close()
        return;
    
    if cmd == 'test' and n == 5:
        rand = tokens[1]
        baddr = tokens[2]
        daddr = tokens[3]
        dport = tokens[4]
        
        m = rand + ' bind:' + baddr + ' dst:' + daddr + ':' + str(dport)
    
        try:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind((baddr, 0))
            s.settimeout(connection_timeout)
            s.connect((daddr, int(dport)))
            s.sendall(rand) #must be 10 characters
            for _ in range(2000):
                s.sendall(buf)
            s.shutdown(1)
            #time.sleep(1)  # seems to be needed on windows or we get a violent server exception
            s.close()
        except Exception, e:
            conn.sendall(host + ':XX ERROR: ' + 'TEST ' + m + ' ' + repr(e) + '\n')
            conn.close()
            return
        
        conn.sendall(host + ':XX OK: ' + 'TEST ' + m + ' TRANSFER COMPLETE\n')     
        conn.close()
        return;
    
    if cmd == 'result' and n == 2:
        rid = tokens[1]
        lock.acquire()
        c = completed.get(rid)
        p = pending.get(rid)
        if c is None and p is None:
            lock.release()
            msg = host + ':00 UNKNOWN ' + rid + '\n'
        elif p is not None:
            lock.release()
            msg = p
        else:
            completed.pop(rid)
            lock.release()
            msg = c
        conn.sendall(msg)
        conn.close()
        return
    
    if cmd == 'reset':
        generation += 1
        conn.sendall(host + ':XX OK: ' + 'RESET new generation = ' + str(generation) + '\n')     
        conn.close()
        return;
    
    if cmd == 'result' and n == 1:
        lock.acquire()
        c = copy.copy(completed)
        p = copy.copy(pending)
        lock.release()
        for i in c:
            conn.sendall(c[i])
        for i in p:
            conn.sendall(p[i])
        conn.close()
        return
    
    if cmd == 'announce' or cmd == 'withdraw' and n > 1:
        if cmd == 'withdraw':
            no = 'no '
        else:
            no = ''
        
        try:
            # first find the BGP ASN
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', 2605))
            s.sendall('sdnip\nenable\nshow ip bgp summary\nquit\nquit\nquit\n')
            data = ''
            while True:
                chunk = s.recv(4096)
                if chunk is None or len(chunk) == 0:
                    break
                data += chunk
            s.close()            
            asn = 'UNKNOWN'
            for l in data.split('\n'):
                if 'local AS number' in l:
                    asn = l.split()[-1]
                    break
            #conn.sendall('\n*****   ' + asn + '   *****\n')
                   
            # now send the command to announce the route
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', 2605))
            nets = 'sdnip\n'  # password in sdnip.py
            nets += 'enable\n'  # reach privileged bgp commands
            nets += 'configure terminal\n' # beginning of config
            nets += 'router bgp ' + asn + '\n' # bgp config
            for i in range(1, n):
                nets += no + 'network ' + tokens[i] + '\n' # "no" for withdraw
            nets += 'quit\nquit\nquit\n' # unwind the commands (or connection won't terminate
            s.sendall(nets)
            
            while True:
                data = s.recv(4096)
                if data is None or len(data) == 0:
                    break
                conn.sendall(data)
            s.close()
            conn.close()
        except Exception, e:
            conn.sendall(host + ':XX ERROR: ' + 'ANNOUNCE/WITHDRAW ' + repr(e) + '\n')
            conn.close()
            return
        return
    
    if cmd == 'bgp' and n == 1:
        try:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', 2605))
            seq = 'sdnip\n'  # password in sdnip.py
            seq += 'enable\n'  # reach privileged bgp commands
            seq += 'show ip bgp\n' # dump bgp routes
            seq += 'quit\nquit\n' # unwind the commands (or connection won't terminate
            s.sendall(seq)
            
            while True:
                data = s.recv(4096)
                if data is None or len(data) == 0:
                    break
                conn.sendall(data)
            s.close()
            conn.close()
        except Exception, e:
            conn.sendall(host + ':XX ERROR: ' + 'BGP ' + repr(e) + '\n')
            conn.close()
            return
        return
    
    if cmd == 'router':
        del tokens[0]
        all = ""
        for arg in tokens:
            all += arg + ' '
        try:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.connect(('127.0.0.1', 2605))
            seq = 'sdnip\n'  # password in sdnip.py
            seq += 'enable\n'  # reach privileged bgp commands
            seq += all + '\n' 
            seq += 'quit\nquit\n' # unwind the commands (or connection won't terminate
            s.sendall(seq)
            
            while True:
                data = s.recv(4096)
                if data is None or len(data) == 0:
                    break
                conn.sendall(data)
            s.close()
            conn.close()
        except Exception, e:
            conn.sendall(host + ':XX ERROR: ' + 'ROUTER ' + repr(e) + '\n')
            conn.close()
            return
        return
    
    if cmd == 'echo':
        del tokens[0]
        all = ""
        for arg in tokens:
            all += "<" + arg + "> "
        conn.sendall(all + '\n')
        conn.close()
    
    conn.sendall(host + ':XX ERROR: Bad command: ' + data)
    conn.close()
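
The one-argument 'result' command above copies the shared completed and pending dicts while holding the lock, then iterates the snapshots after releasing it, keeping the slow socket writes out of the critical section. A standalone Python 3 sketch of that pattern with hypothetical data:

import copy
import threading

lock = threading.Lock()
completed = {"req-1": "host:00 OK result\n"}
pending = {"req-2": "host:00 PENDING\n"}

with lock:
    # Snapshot the shared dicts while holding the lock...
    c = copy.copy(completed)
    p = copy.copy(pending)

# ...then do the slow I/O over the snapshots outside the critical section,
# so concurrent updates to completed/pending cannot disturb the iteration.
for rid in c:
    print(c[rid], end="")
for rid in p:
    print(p[rid], end="")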

Example 40

Project: sfepy
Source File: terms.py
View license
    def check_shapes(self, *args, **kwargs):
        """
        Check term argument shapes at run-time.
        """
        from sfepy.base.base import output
        from sfepy.mechanics.tensors import dim2sym

        dim = self.region.dim
        sym = dim2sym(dim)

        def _parse_scalar_shape(sh):
            if isinstance(sh, basestr):
                if sh == 'D':
                    return dim

                elif sh == 'D2':
                    return dim**2

                elif sh == 'S':
                    return sym

                elif sh == 'N': # General number.
                    return nm.inf

                else:
                    return int(sh)

            else:
                return sh

        def _parse_tuple_shape(sh):
            if isinstance(sh, basestr):
                return tuple((_parse_scalar_shape(ii.strip())
                              for ii in sh.split(',')))

            else:
                return (int(sh),)

        arg_kinds = get_arg_kinds(self.ats)

        arg_shapes_list = self.arg_shapes
        if not isinstance(arg_shapes_list, list):
            arg_shapes_list = [arg_shapes_list]

        # Loop allowed shapes until a match is found, else error.
        allowed_shapes = []
        prev_shapes = {}
        for _arg_shapes in arg_shapes_list:
            # Unset shapes are taken from the previous iteration.
            arg_shapes = copy(prev_shapes)
            arg_shapes.update(_arg_shapes)
            prev_shapes = arg_shapes

            allowed_shapes.append(arg_shapes)

            n_ok = 0
            for ii, arg_kind in enumerate(arg_kinds):
                if arg_kind in ('user', 'ts'):
                    n_ok += 1
                    continue

                arg = args[ii]

                if self.mode is not None:
                    extended_ats = self.ats[ii] + ('/%s' % self.mode)

                else:
                    extended_ats = self.ats[ii]

                try:
                    sh = arg_shapes[self.ats[ii]]

                except KeyError:
                    sh = arg_shapes[extended_ats]

                if arg_kind.endswith('variable'):
                    n_el, n_qp, _dim, n_en, n_c = self.get_data_shape(arg)
                    shape = _parse_scalar_shape(sh[0] if isinstance(sh, tuple)
                                                else sh)
                    if nm.isinf(shape):
                        n_ok += 1

                    else:
                        n_ok += shape == n_c

                elif arg_kind.endswith('material'):
                    if arg is None: # Switched-off opt_material.
                        n_ok += sh is None
                        continue

                    if sh is None:
                        continue

                    prefix = ''
                    if isinstance(sh, basestr):
                        aux = sh.split(':')
                        if len(aux) == 2:
                            prefix, sh = aux

                    shape = _parse_tuple_shape(sh)
                    ls = len(shape)

                    aarg = nm.array(arg, ndmin=1)

                    # Substitute general dimension 'N' with the actual value.
                    iinfs = nm.where(nm.isinf(shape))[0]
                    if len(iinfs):
                        shape = list(shape)
                        for iinf in iinfs:
                            shape[iinf] = aarg.shape[-ls+iinf]
                        shape = tuple(shape)

                    if (ls > 1) or (shape[0] > 1):
                        # Array.
                        n_ok += shape == aarg.shape[-ls:]

                    elif (ls == 1) and (shape[0] == 1):
                        # Scalar constant.
                        from numbers import Number
                        n_ok += isinstance(arg, Number)

                else:
                    n_ok += 1

            if n_ok == len(arg_kinds):
                break

        else:
            term_str = '%s.%d.%s(%s)' % (self.name, self.integral.order,
                                         self.region.name, self.arg_str)
            output('allowed argument shapes for term "%s":' % term_str)
            output(allowed_shapes)
            raise ValueError('wrong arguments shapes for "%s" term! (see above)'
                             % term_str)

Example 41

Project: ReproWeb
Source File: datastructures.py
View license
    def test_basic_interface(self):
        md = self.storage_class()
        assert isinstance(md, dict)

        mapping = [('a', 1), ('b', 2), ('a', 2), ('d', 3),
                   ('a', 1), ('a', 3), ('d', 4), ('c', 3)]
        md = self.storage_class(mapping)

        # simple getitem gives the first value
        assert md['a'] == 1
        assert md['c'] == 3
        with self.assert_raises(KeyError):
            md['e']
        assert md.get('a') == 1

        # list getitem
        assert md.getlist('a') == [1, 2, 1, 3]
        assert md.getlist('d') == [3, 4]
        # do not raise if key not found
        assert md.getlist('x') == []

        # simple setitem overwrites all values
        md['a'] = 42
        assert md.getlist('a') == [42]

        # list setitem
        md.setlist('a', [1, 2, 3])
        assert md['a'] == 1
        assert md.getlist('a') == [1, 2, 3]

        # verify that it does not change original lists
        l1 = [1, 2, 3]
        md.setlist('a', l1)
        del l1[:]
        assert md['a'] == 1

        # setdefault, setlistdefault
        assert md.setdefault('u', 23) == 23
        assert md.getlist('u') == [23]
        del md['u']

        md.setlist('u', [-1, -2])

        # delitem
        del md['u']
        with self.assert_raises(KeyError):
            md['u']
        del md['d']
        assert md.getlist('d') == []

        # keys, values, items, lists
        assert list(sorted(md.keys())) == ['a', 'b', 'c']
        assert list(sorted(md.iterkeys())) == ['a', 'b', 'c']

        assert list(sorted(md.values())) == [1, 2, 3]
        assert list(sorted(md.itervalues())) == [1, 2, 3]

        assert list(sorted(md.items())) == [('a', 1), ('b', 2), ('c', 3)]
        assert list(sorted(md.items(multi=True))) == \
               [('a', 1), ('a', 2), ('a', 3), ('b', 2), ('c', 3)]
        assert list(sorted(md.iteritems())) == [('a', 1), ('b', 2), ('c', 3)]
        assert list(sorted(md.iteritems(multi=True))) == \
               [('a', 1), ('a', 2), ('a', 3), ('b', 2), ('c', 3)]

        assert list(sorted(md.lists())) == [('a', [1, 2, 3]), ('b', [2]), ('c', [3])]
        assert list(sorted(md.iterlists())) == [('a', [1, 2, 3]), ('b', [2]), ('c', [3])]

        # copy method
        c = md.copy()
        assert c['a'] == 1
        assert c.getlist('a') == [1, 2, 3]

        # copy method 2
        c = copy(md)
        assert c['a'] == 1
        assert c.getlist('a') == [1, 2, 3]

        # update with a multidict
        od = self.storage_class([('a', 4), ('a', 5), ('y', 0)])
        md.update(od)
        assert md.getlist('a') == [1, 2, 3, 4, 5]
        assert md.getlist('y') == [0]

        # update with a regular dict
        md = c
        od = {'a': 4, 'y': 0}
        md.update(od)
        assert md.getlist('a') == [1, 2, 3, 4]
        assert md.getlist('y') == [0]

        # pop, poplist, popitem, popitemlist
        assert md.pop('y') == 0
        assert 'y' not in md
        assert md.poplist('a') == [1, 2, 3, 4]
        assert 'a' not in md
        assert md.poplist('missing') == []

        # remaining: b=2, c=3
        popped = md.popitem()
        assert popped in [('b', 2), ('c', 3)]
        popped = md.popitemlist()
        assert popped in [('b', [2]), ('c', [3])]

        # type conversion
        md = self.storage_class({'a': '4', 'b': ['2', '3']})
        assert md.get('a', type=int) == 4
        assert md.getlist('b', type=int) == [2, 3]

        # repr
        md = self.storage_class([('a', 1), ('a', 2), ('b', 3)])
        assert "('a', 1)" in repr(md)
        assert "('a', 2)" in repr(md)
        assert "('b', 3)" in repr(md)

        # add and getlist
        md.add('c', '42')
        md.add('c', '23')
        assert md.getlist('c') == ['42', '23']
        md.add('c', 'blah')
        assert md.getlist('c', type=int) == [42, 23]

        # setdefault
        md = self.storage_class()
        md.setdefault('x', []).append(42)
        md.setdefault('x', []).append(23)
        assert md['x'] == [42, 23]

        # to dict
        md = self.storage_class()
        md['foo'] = 42
        md.add('bar', 1)
        md.add('bar', 2)
        assert md.to_dict() == {'foo': 42, 'bar': 1}
        assert md.to_dict(flat=False) == {'foo': [42], 'bar': [1, 2]}

        # popitem from empty dict
        with self.assert_raises(KeyError):
            self.storage_class().popitem()

        with self.assert_raises(KeyError):
            self.storage_class().popitemlist()

        # key errors are of a special type
        with self.assert_raises(BadRequestKeyError):
            self.storage_class()[42]

        # setlist works
        md = self.storage_class()
        md['foo'] = 42
        md.setlist('foo', [1, 2])
        assert md.getlist('foo') == [1, 2]
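
The "copy method 2" check above exercises the container's copy protocol: copy(md) is expected to behave like md.copy() and preserve every value stored per key. A minimal sketch of how a class can hook into copy.copy via __copy__, using a hypothetical stand-in container:

import copy

class MultiValueBag(object):
    """Hypothetical stand-in for a MultiDict-style container."""
    def __init__(self, lists=None):
        self._lists = {k: list(v) for k, v in (lists or {}).items()}

    def __copy__(self):
        # copy.copy() dispatches here, so copy(bag) behaves like bag.copy().
        return MultiValueBag(self._lists)

    def getlist(self, key):
        return self._lists.get(key, [])

md = MultiValueBag({'a': [1, 2, 3]})
c = copy.copy(md)
assert c.getlist('a') == [1, 2, 3]
assert c._lists['a'] is not md._lists['a']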

Example 42

View license
def from_NS_output_to_chains(folder):
    """
    Translate the output of MultiNest into readable output for Monte Python

    This routine will be called by the module :mod:`analyze`.

    If mode separation has been performed (i.e., multimodal=True), it creates
    'mode_#' subfolders containing a chain file with the corresponding samples
    and a 'log.param' file in which the starting point is the best fit of the
    nested sampling, and the same for the sigma. The minimum and maximum value
    are cropped to the extent of the modes in the case of the parameters used
    for the mode separation, and preserved in the rest.

    The mono-modal case is treated as a special case of the multi-modal one.

    """
    chain_name = [a for a in folder.split(os.path.sep) if a][-2]
    base_name = os.path.join(folder, chain_name)

    # Read the arguments of the NS run
    # This file is intended to be machine generated: no "#" ignored or tests
    # done
    NS_arguments = {}
    with open(base_name+name_arguments, 'r') as afile:
        for line in afile:
            arg   = line.split('=')[0].strip()
            value = line.split('=')[1].strip()
            arg_type = (NS_user_arguments[arg]['type']
                        if arg in NS_user_arguments else
                        NS_auto_arguments[arg]['type'])
            value = arg_type(value)
            if arg == 'clustering_params':
                value = [a.strip() for a in value.split()]
            NS_arguments[arg] = value
    multimodal = NS_arguments.get('multimodal')
    # Read parameters order
    NS_param_names = np.loadtxt(base_name+name_paramnames, dtype='str').tolist()
    # In multimodal case, if there were no clustering params specified, ALL are
    if multimodal and not NS_arguments.get('clustering_params'):
        NS_arguments['clustering_params'] = NS_param_names

    # Extract the necessary information from the log.param file
    # Including line numbers of the parameters
    with open(os.path.join(folder, '..', name_logparam), 'r') as log_file:
        log_lines = log_file.readlines()
    # Number of the lines to be changed
    param_names = []
    param_lines = {}
    param_data  = {}
    pre, pos = 'data.parameters[', ']'
    for i, line in enumerate(log_lines):
        if pre in line:
            if line.strip()[0] == '#':
                continue
            param_name = line.split('=')[0][line.find(pre)+len(pre):
                                            line.find(pos)]
            param_name = param_name.replace('"','').replace("'",'').strip()
            param_names.append(param_name)
            param_data[param_name] = [a.strip() for a in
                                      line.split('=')[1].strip('[]').split(',')]
            param_lines[param_name] = i

    # Create the mapping from NS ordering to log.param ordering
    columns_reorder = [NS_param_names.index(param) for param in param_names]

    # Open the 'stats.dat' file to see what happened and retrieve some info
    stats_file = open(base_name+name_stats, 'r')
    lines = stats_file.readlines()
    stats_file.close()
    # Mode-separated info
    i = 0
    n_modes = 0
    stats_mode_lines = {0: []}
    for line in lines:
        if 'Nested Sampling Global Log-Evidence' in line:
            global_logZ, global_logZ_err = [float(a.strip()) for a in
                                            line.split(':')[1].split('+/-')]
        if 'Total Modes Found' in line:
            n_modes = int(line.split(':')[1].strip())
        if line[:4] == 'Mode':
            i += 1
            stats_mode_lines[i] = []
        # This stores the info of each mode i>1 in stats_mode_lines[i] and in
        # i=0 the lines previous to the modes, in the multi-modal case or the
        # info of the only mode, in the mono-modal case
        stats_mode_lines[i].append(line)
    assert n_modes == max(stats_mode_lines.keys()), (
        'Something is wrong... (strange error n.1)')

    # Prepare the accepted-points file -- modes are separated by 2 line breaks
    accepted_name = base_name + (name_post_sep if multimodal else name_post)
    with open(accepted_name, 'r') as accepted_file:
        mode_lines = [a for a in ''.join(accepted_file.readlines()).split('\n\n')
                      if a != '']
    if multimodal:
        mode_lines = [[]] + mode_lines
    assert len(mode_lines) == 1+n_modes, 'Something is wrong... (strange error n.2)'

# TODO: prepare total and rejected chain

    # Process each mode:
    ini = 1 if multimodal else 0
    for i in range(ini, 1+n_modes):
        # Create subfolder
        if multimodal:
            mode_subfolder = 'mode_'+str(i).zfill(len(str(n_modes)))
        else:
            mode_subfolder = ''
        mode_subfolder = os.path.join(folder, '..', mode_subfolder)
        if not os.path.exists(mode_subfolder):
            os.makedirs(mode_subfolder)

        # Add ACCEPTED points
        mode_data = np.array(mode_lines[i].split(), dtype='float64')
        columns = 2+NS_arguments['n_params']
        mode_data = mode_data.reshape([mode_data.shape[0]/columns, columns])
        # Rearrange: sample-prob | -2*loglik | params (clustering first)
        #       ---> sample-prob |   -loglik | params (log.param order)
        mode_data[:, 1]  = mode_data[:, 1] / 2.
        mode_data[:, 2:] = mode_data[:, [2+j for j in columns_reorder]]
        np.savetxt(os.path.join(mode_subfolder, name_chain_acc),
                   mode_data, fmt='%.6e')

        # If we are not in the multimodal case, we are done!
        if not multimodal:
            break
        # In the multimodal case, we want to write a log.param for each mod
        this_log_lines = copy(log_lines)

        # Get the necessary info of the parameters:
        #  -- max_posterior (MAP), sigma  <---  stats.dat file
        for j, line in enumerate(stats_mode_lines[i]):
            if 'Sigma' in line:
                line_sigma = j+1
            if 'MAP' in line:
                line_MAP = j+2
        MAPs   = {}
        sigmas = {}
        for j, param in enumerate(NS_param_names):
            n, MAP = stats_mode_lines[i][line_MAP+j].split()
            assert int(n) == j+1,  'Something is wrong... (strange error n.3)'
            MAPs[param] = MAP
            n, mean, sigma = stats_mode_lines[i][line_sigma+j].split()
            assert int(n) == j+1,  'Something is wrong... (strange error n.4)'
            sigmas[param] = sigma
        #  -- minimum rectangle containing the mode (only clustering params)
        mins = {}
        maxs = {}
        for param in NS_arguments['clustering_params']:
            # Notice that in the next line we use param_names and not
            # NS_param_names: the chain lines have already been reordered
            values = mode_data[:, 2+param_names.index(param)]
            mins[param] = min(values)
            maxs[param] = max(values)
        # Create the log.param file
        for param in param_names:
            if param in NS_arguments['clustering_params']:
                mini, maxi = '%.6e'%mins[param], '%.6e'%maxs[param]
            else:
                mini, maxi = param_data[param][1], param_data[param][2]
            scaling = param_data[param][4]
            ptype   = param_data[param][5]
            line = pre+"'"+param+"'"+pos
            values = [MAPs[param], mini, maxi, sigmas[param], scaling, ptype]
            line += ' = [' + ', '.join(values) + ']\n'
            this_log_lines[param_lines[param]] = line
        # Write it!
        with open(os.path.join(mode_subfolder, 'log.param'), 'w') as log_file:
            log_file.writelines(this_log_lines)
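
For every mode, this_log_lines = copy(log_lines) starts from a fresh shallow copy of the template lines, so rewriting parameter lines for one mode never leaks into the next mode's log.param. A small sketch of that per-iteration copy with hypothetical lines:

import copy

log_lines = ["data.parameters['omega_b'] = [0.022, None, None, 0.002, 1, 'cosmo']\n",
             "data.N = 10\n"]

best_fits = {1: '0.0221', 2: '0.0225'}
for mode, value in best_fits.items():
    # Fresh shallow copy per mode: rewriting lines here must not leak into
    # the template that the next mode starts from.
    this_log_lines = copy.copy(log_lines)
    this_log_lines[0] = ("data.parameters['omega_b'] = [%s, None, None, 0.002, 1, 'cosmo']\n"
                         % value)

assert log_lines[0].startswith("data.parameters['omega_b'] = [0.022,")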

Example 43

Project: sncosmo
Source File: fitting.py
View license
def fit_lc(data, model, vparam_names, bounds=None, method='minuit',
           guess_amplitude=True, guess_t0=True, guess_z=True,
           minsnr=5., modelcov=False, verbose=False, maxcall=10000,
           **kwargs):
    """Fit model parameters to data by minimizing chi^2.

    This function defines a chi^2 to minimize, makes initial guesses for
    t0 and amplitude, then runs a minimizer.

    Parameters
    ----------
    data : `~astropy.table.Table` or `~numpy.ndarray` or `dict`
        Table of photometric data. Must include certain columns.
        See the "Photometric Data" section of the documentation for
        required columns.
    model : `~sncosmo.Model`
        The model to fit.
    vparam_names : list
        Model parameters to vary in the fit.
    bounds : `dict`, optional
        Bounded range for each parameter. Keys should be parameter
        names, values are tuples. If a bound is not given for some
        parameter, the parameter is unbounded. The exception is
        ``t0``: by default, the minimum bound is such that the latest
        phase of the model lines up with the earliest data point and
        the maximum bound is such that the earliest phase of the model
        lines up with the latest data point.
    guess_amplitude : bool, optional
        Whether or not to guess the amplitude from the data. If false, the
        current model amplitude is taken as the initial value. Only has an
        effect when fitting amplitude. Default is True.
    guess_t0 : bool, optional
        Whether or not to guess t0. Only has an effect when fitting t0.
        Default is True.
    guess_z : bool, optional
        Whether or not to guess z (redshift). Only has an effect when fitting
        redshift. Default is True.
    minsnr : float, optional
        When guessing amplitude and t0, only use data with signal-to-noise
        ratio (flux / fluxerr) greater than this value. Default is 5.
    method : {'minuit'}, optional
        Minimization method to use. Currently there is only one choice.
    modelcov : bool, optional
        Include model covariance when calculating chisq. Default is False.
    verbose : bool, optional
        Print messages during fitting.

    Returns
    -------
    res : Result
        The optimization result represented as a ``Result`` object, which is
        a `dict` subclass with attribute access. Therefore, ``res.keys()``
        provides a list of the attributes. Attributes are:

        - ``success``: boolean describing whether fit succeeded.
        - ``message``: string with more information about exit status.
        - ``ncall``: number of function evaluations.
        - ``chisq``: minimum chi^2 value.
        - ``ndof``: number of degrees of freedom
          (len(data) - len(vparam_names)).
        - ``param_names``: same as ``model.param_names``.
        - ``parameters``: 1-d `~numpy.ndarray` of best-fit values
          (including fixed parameters) corresponding to ``param_names``.
        - ``vparam_names``: list of varied parameter names.
        - ``covariance``: 2-d `~numpy.ndarray` of parameter covariance;
          indices correspond to order of ``vparam_names``.
        - ``errors``: OrderedDict of varied parameter uncertainties.
          Corresponds to square root of diagonal entries in covariance matrix.

    fitmodel : `~sncosmo.Model`
        A copy of the model with parameters set to best-fit values.

    Notes
    -----

    **t0 guess:** If ``t0`` is being fit and ``guess_t0=True``, the
    function will guess the initial starting point for ``t0`` based on
    the data. The guess is made as follows:

    * Evaluate the time and value of peak flux for the model in each band
      given the current model parameters.
    * Determine the data point with maximum flux in each band, for points
      with signal-to-noise ratio > ``minsnr`` (default is 5). If no points
      meet this criteria, the band is ignored (for the purpose of guessing
      only).
    * For each band, compare model's peak flux to the peak data point. Choose
      the band with the highest ratio of data / model.
    * Set ``t0`` so that the model's time of peak in the chosen band
      corresponds to the peak data point in this band.

    **amplitude guess:** If amplitude (assumed to be the first model parameter)
    is being fit and ``guess_amplitude=True``, the function will guess the
    initial starting point for the amplitude based on the data.

    **redshift guess:** If redshift (``z``) is being fit and ``guess_z=True``,
    the function will set the initial value of ``z`` to the average of the
    bounds on ``z``.

    Examples
    --------

    The `~sncosmo.flatten_result` function can be used to make the result
    a dictionary suitable for appending as rows of a table:

    >>> from astropy.table import Table               # doctest: +SKIP
    >>> table_rows = []                               # doctest: +SKIP
    >>> for sn in sne:                                # doctest: +SKIP
    ...     res, fitmodel = sncosmo.fit_lc(           # doctest: +SKIP
    ...          sn, model, ['t0', 'x0', 'x1', 'c'])  # doctest: +SKIP
    ...     table_rows.append(flatten_result(res))    # doctest: +SKIP
    >>> t = Table(table_rows)                         # doctest: +SKIP

    """

    # Standardize and normalize data.
    data = standardize_data(data)
    data = normalize_data(data)

    # Make a copy of the model so we can modify it with impunity.
    model = copy.copy(model)

    # Check that vparam_names isn't empty and contains only parameters
    # known to the model.
    if len(vparam_names) == 0:
        raise ValueError("no parameters supplied")
    for s in vparam_names:
        if s not in model.param_names:
            raise ValueError("Parameter not in model: " + repr(s))

    # Order vparam_names the same way it is ordered in the model:
    vparam_names = [s for s in model.param_names if s in vparam_names]

    # initialize bounds
    if bounds is None:
        bounds = {}

    # Check that 'z' is bounded (if it is going to be fit).
    if 'z' in vparam_names:
        if 'z' not in bounds or None in bounds['z']:
            raise ValueError('z must be bounded if fit.')
        if guess_z:
            model.set(z=sum(bounds['z']) / 2.)
        if model.get('z') < bounds['z'][0] or model.get('z') > bounds['z'][1]:
            raise ValueError('z out of range.')

    # Cut bands that are not allowed by the wavelength range of the model.
    data = cut_bands(data, model, z_bounds=bounds.get('z', None))

    # Unique set of bands in data
    bands = set(data['band'].tolist())

    # Find t0 bounds to use, if not explicitly given
    if 't0' in vparam_names and 't0' not in bounds:
        bounds['t0'] = t0_bounds(data, model)

    # Note that in the parameter guessing below, we assume that the source
    # amplitude is the 3rd parameter of the Model (1st parameter of the Source)

    # Turn off guessing if we're not fitting the parameter.
    if model.param_names[2] not in vparam_names:
        guess_amplitude = False
    if 't0' not in vparam_names:
        guess_t0 = False

    # Make guesses for t0 and amplitude.
    # (For now, we assume it is the 3rd parameter of the model.)
    if (guess_amplitude or guess_t0):
        t0, amplitude = guess_t0_and_amplitude(data, model, minsnr)
        if guess_amplitude:
            model.parameters[2] = amplitude
        if guess_t0:
            model.set(t0=t0)

    # count degrees of freedom
    ndof = len(data) - len(vparam_names)

    if method == 'minuit':
        try:
            import iminuit
        except ImportError:
            raise ValueError("Minimization method 'minuit' requires the "
                             "iminuit package")

        # The iminuit minimizer expects the function signature to have an
        # argument for each parameter.
        def fitchisq(*parameters):
            model.parameters = parameters
            return _chisq(data, model, modelcov=modelcov)

        # Set up keyword arguments to pass to Minuit initializer.
        kwargs = {}
        for name in model.param_names:
            kwargs[name] = model.get(name)  # Starting point.

            # Fix parameters not being varied in the fit.
            if name not in vparam_names:
                kwargs['fix_' + name] = True
                kwargs['error_' + name] = 0.
                continue

            # Bounds
            if name in bounds:
                if None in bounds[name]:
                    raise ValueError('one-sided bounds not allowed for '
                                     'minuit minimizer')
                kwargs['limit_' + name] = bounds[name]

            # Initial step size
            if name in bounds:
                step = 0.02 * (bounds[name][1] - bounds[name][0])
            elif model.get(name) != 0.:
                step = 0.1 * model.get(name)
            else:
                step = 1.
            kwargs['error_' + name] = step

        if verbose:
            print("Initial parameters:")
            for name in vparam_names:
                print(name, kwargs[name], 'step=', kwargs['error_' + name],
                      end=" ")
                if 'limit_' + name in kwargs:
                    print('bounds=', kwargs['limit_' + name], end=" ")
                print()

        m = iminuit.Minuit(fitchisq, errordef=1.,
                           forced_parameters=model.param_names,
                           print_level=(1 if verbose else 0),
                           throw_nan=True, **kwargs)
        d, l = m.migrad(ncall=maxcall)

        # Build a message.
        message = []
        if d.has_reached_call_limit:
            message.append('Reached call limit.')
        if d.hesse_failed:
            message.append('Hesse Failed.')
        if not d.has_covariance:
            message.append('No covariance.')
        elif not d.has_accurate_covar:  # iminuit docs wrong
            message.append('Covariance may not be accurate.')
        if not d.has_posdef_covar:  # iminuit docs wrong
            message.append('Covariance not positive definite.')
        if d.has_made_posdef_covar:
            message.append('Covariance forced positive definite.')
        if not d.has_valid_parameters:
            message.append('Parameter(s) value and/or error invalid.')
        if len(message) == 0:
            message.append('Minimization exited successfully.')
        # iminuit: m.np_matrix() doesn't work

        # numpy array of best-fit values (including fixed parameters).
        parameters = np.array([m.values[name] for name in model.param_names])
        model.parameters = parameters  # set model parameters to best fit.

        # Covariance matrix (only varied parameters) as numpy array.
        if m.covariance is None:
            covariance = None
        else:
            covariance = np.array([
                [m.covariance[(n1, n2)] for n1 in vparam_names]
                for n2 in vparam_names])

        # OrderedDict of errors
        if m.errors is None:
            errors = None
        else:
            errors = odict([(name, m.errors[name]) for name in vparam_names])

        # Compile results
        res = Result(success=d.is_valid,
                     message=' '.join(message),
                     ncall=d.nfcn,
                     chisq=d.fval,
                     ndof=ndof,
                     param_names=model.param_names,
                     parameters=parameters,
                     vparam_names=vparam_names,
                     covariance=covariance,
                     errors=errors)

        # TODO remove cov_names in a future release.
        depmsg = ("The `cov_names` attribute is deprecated in sncosmo v1.0 "
                  "and will be removed in v1.1. Use `vparam_names` instead.")
        res.__dict__['deprecated']['cov_names'] = (vparam_names, depmsg)

    else:
        raise ValueError("unknown method {0:r}".format(method))

    # TODO remove this in a future release.
    if "flatten" in kwargs:
        warn("The `flatten` keyword is deprecated in sncosmo v1.0 "
             "and will be removed in v1.1. Use the flatten_result() "
             "function instead.")
        if kwargs["flatten"]:
            res = flatten_result(res)
    return res, model
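
The listing above makes a shallow copy of the input model (``model = copy.copy(model)``) before overwriting its parameters, so the caller's object is not touched. Below is a minimal sketch, using a hypothetical toy class rather than the sncosmo ``Model``, of what a plain shallow copy does and does not isolate:

import copy

class ToyModel:
    """Hypothetical stand-in for a parameterized model (not the sncosmo Model)."""
    def __init__(self, parameters):
        self.parameters = list(parameters)

original = ToyModel([0.0, 1.0, 2.0])

work = copy.copy(original)
work.parameters = [9.0, 9.0, 9.0]   # rebinding the attribute leaves the original alone
print(original.parameters)          # -> [0.0, 1.0, 2.0]

work2 = copy.copy(original)
work2.parameters[0] = 5.0           # in-place mutation goes through the shared list
print(original.parameters)          # -> [5.0, 1.0, 2.0]

A class can define ``__copy__`` to duplicate mutable attributes as well, which is presumably how the model above can be modified "with impunity".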

Example 44

Project: sncosmo
Source File: fitting.py
View license
def mcmc_lc(data, model, vparam_names, bounds=None, priors=None,
            guess_amplitude=True, guess_t0=True, guess_z=True,
            minsnr=5., modelcov=False, nwalkers=10, nburn=200,
            nsamples=1000, sampler='ensemble', ntemps=4, thin=1,
            a=2.0):
    """Run an MCMC chain to get model parameter samples.

    This is a convenience function around `emcee.EnsembleSampler` and
    `emcee.PTSampler`. It defines the likelihood function and makes a
    heuristic guess at a good set of starting points for the
    walkers. It then runs the sampler, starting with a burn-in run.

    If you're not getting good results, you might want to try
    increasing the burn-in, increasing the walkers, or specifying a
    better starting position.  To get a better starting position, you
    could first run `~sncosmo.fit_lc`, then run this function with all
    ``guess_[name]`` keyword arguments set to False, so that the
    current model parameters are used as the starting point.

    Parameters
    ----------
    data : `~astropy.table.Table` or `~numpy.ndarray` or `dict`
        Table of photometric data. Must include certain columns.
        See the "Photometric Data" section of the documentation for
        required columns.
    model : `~sncosmo.Model`
        The model to fit.
    vparam_names : iterable
        Model parameters to vary.
    bounds : `dict`, optional
        Bounded range for each parameter. Keys should be parameter
        names, values are tuples. If a bound is not given for some
        parameter, the parameter is unbounded. The exception is
        ``t0``: by default, the minimum bound is such that the latest
        phase of the model lines up with the earliest data point and
        the maximum bound is such that the earliest phase of the model
        lines up with the latest data point.
    priors : `dict`, optional
        Prior probability functions. Keys are parameter names, values are
        functions that return probability given the parameter value.
        The default prior is a flat distribution.
    guess_amplitude : bool, optional
        Whether or not to guess the amplitude from the data. If false, the
        current model amplitude is taken as the initial value. Only has an
        effect when fitting amplitude. Default is True.
    guess_t0 : bool, optional
        Whether or not to guess t0. Only has an effect when fitting t0.
        Default is True.
    guess_z : bool, optional
        Whether or not to guess z (redshift). Only has an effect when fitting
        redshift. Default is True.
    minsnr : float, optional
        When guessing amplitude and t0, only use data with signal-to-noise
        ratio (flux / fluxerr) greater than this value. Default is 5.
    modelcov : bool, optional
        Include model covariance when calculating chisq. Default is False.
    nwalkers : int, optional
        Number of walkers in the sampler.
    nburn : int, optional
        Number of samples in burn-in phase.
    nsamples : int, optional
        Number of samples in production run.
    sampler : str, optional
        The kind of sampler to use. Currently 'ensemble' for
        `emcee.EnsembleSampler` and 'pt' for `emcee.PTSampler` are
        supported.
    ntemps : int, optional
        If `sampler == 'pt'`, the number of temperatures to use for the
        parallel tempered sampler.
    thin : int, optional
        Factor by which to thin samples in production run. Output samples
        array will have (nsamples/thin) samples.
    a : float, optional
        Proposal scale parameter passed to the sampler.

    Returns
    -------
    res : Result
        Has the following attributes:

        * ``param_names``: All parameter names of model, including fixed.
        * ``parameters``: Model parameters, with varied parameters set to
          mean value in samples.
        * ``vparam_names``: Names of parameters varied. Order of parameters
          matches order of samples.
        * ``samples``: 2-d array with shape ``(N, len(vparam_names))``.
          Order of parameters in each row matches order in
          ``res.vparam_names``.
        * ``covariance``: 2-d array giving covariance, measured from samples.
          Order corresponds to ``res.vparam_names``.
        * ``errors``: dictionary giving square root of diagonal of covariance
          matrix for varied parameters. Useful for ``plot_lc``.
        * ``mean_acceptance_fraction``: mean acceptance fraction for all
          walkers in the sampler.

    est_model : `~sncosmo.Model`
        Copy of input model with varied parameters set to mean value in
        samples.

    """

    try:
        import emcee
    except ImportError:
        raise ImportError("mcmc_lc() requires the emcee package.")

    # Standardize and normalize data.
    data = standardize_data(data)
    data = normalize_data(data)

    # Make a copy of the model so we can modify it with impunity.
    model = copy.copy(model)

    if bounds is None:
        bounds = {}
    if priors is None:
        priors = {}

    # Check that vparam_names isn't empty, check for unknown parameters.
    if len(vparam_names) == 0:
        raise ValueError("no parameters supplied")
    for names in (vparam_names, bounds, priors):
        for name in names:
            if name not in model.param_names:
                raise ValueError("Parameter not in model: " + repr(name))

    # Order vparam_names the same way it is ordered in the model:
    vparam_names = [s for s in model.param_names if s in vparam_names]
    ndim = len(vparam_names)

    # Check that 'z' is bounded (if it is going to be fit).
    if 'z' in vparam_names:
        if 'z' not in bounds or None in bounds['z']:
            raise ValueError('z must be bounded if allowed to vary.')
        if guess_z:
            model.set(z=sum(bounds['z']) / 2.)
        if model.get('z') < bounds['z'][0] or model.get('z') > bounds['z'][1]:
            raise ValueError('z out of range.')

    # Cut bands that are not allowed by the wavelength range of the model.
    data = cut_bands(data, model, z_bounds=bounds.get('z', None))

    # Find t0 bounds to use, if not explicitly given
    if 't0' in vparam_names and 't0' not in bounds:
        bounds['t0'] = t0_bounds(data, model)

    # Note that in the parameter guessing below, we assume that the source
    # amplitude is the 3rd parameter of the Model (1st parameter of the Source)

    # Turn off guessing if we're not fitting the parameter.
    if model.param_names[2] not in vparam_names:
        guess_amplitude = False
    if 't0' not in vparam_names:
        guess_t0 = False

    # Make guesses for t0 and amplitude.
    # (we assume amplitude is the 3rd parameter of the model.)
    if guess_amplitude or guess_t0:
        t0, amplitude = guess_t0_and_amplitude(data, model, minsnr)
        if guess_amplitude:
            model.parameters[2] = amplitude
        if guess_t0:
            model.set(t0=t0)

    # Indices used in probability function.
    # modelidx: Indices of model parameters corresponding to vparam_names.
    # idxbounds: tuples of (varied parameter index, low bound, high bound).
    # idxpriors: tuples of (varied parameter index, function).
    modelidx = np.array([model.param_names.index(k) for k in vparam_names])
    idxbounds = [(vparam_names.index(k), bounds[k][0], bounds[k][1])
                 for k in bounds]
    idxpriors = [(vparam_names.index(k), priors[k]) for k in priors]

    # Posterior function.
    def lnlike(parameters):
        for i, low, high in idxbounds:
            if not low < parameters[i] < high:
                return -np.inf

        model.parameters[modelidx] = parameters
        logp = -0.5 * _chisq(data, model, modelcov=modelcov)
        return logp

    def lnprior(parameters):
        logp = 0
        for i, func in idxpriors:
            logp += math.log(func(parameters[i]))
        return logp

    def lnprob(parameters):
        return lnprior(parameters) + lnlike(parameters)

    # Heuristic determination of walker initial positions: distribute
    # walkers uniformly over parameter space. If no bounds are
    # supplied for a given parameter, use a heuristically determined
    # scale.

    if sampler == 'pt':
        pos = np.empty((ndim, nwalkers, ntemps))
        for i, name in enumerate(vparam_names):
            if name in bounds:
                pos[i] = np.random.uniform(low=bounds[name][0],
                                           high=bounds[name][1],
                                           size=(nwalkers, ntemps))
            else:
                ctr = model.get(name)
                scale = np.abs(ctr)
                pos[i] = np.random.uniform(low=ctr-scale, high=ctr+scale,
                                           size=(nwalkers, ntemps))
        pos = np.swapaxes(pos, 0, 2)
        sampler = emcee.PTSampler(ntemps, nwalkers, ndim, lnlike, lnprob, a=a)

    # Heuristic determination of walker initial positions: distribute
    # walkers in a symmetric gaussian ball, with heuristically
    # determined scale.

    elif sampler == 'ensemble':
        ctr = model.parameters[modelidx]
        scale = np.ones(ndim)
        for i, name in enumerate(vparam_names):
            if name in bounds:
                scale[i] = 0.0001 * (bounds[name][1] - bounds[name][0])
            elif model.get(name) != 0.:
                scale[i] = 0.01 * model.get(name)
            else:
                scale[i] = 0.1
        pos = ctr + scale * np.random.normal(size=(nwalkers, ndim))
        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, a=a)

    else:
        raise ValueError('Invalid sampler type. Currently "pt" '
                         'and "ensemble" are supported.')

    # Run the sampler.
    pos, prob, state = sampler.run_mcmc(pos, nburn)  # burn-in
    sampler.reset()
    sampler.run_mcmc(pos, nsamples, thin=thin)  # production run
    samples = sampler.flatchain.reshape(-1, ndim)

    # Summary statistics.
    vparameters = np.mean(samples, axis=0)
    cov = np.cov(samples, rowvar=0)
    model.set(**dict(zip(vparam_names, vparameters)))
    errors = odict(zip(vparam_names, np.sqrt(np.diagonal(cov))))
    mean_acceptance_fraction = np.mean(sampler.acceptance_fraction)

    res = Result(param_names=copy.copy(model.param_names),
                 parameters=model.parameters.copy(),
                 vparam_names=vparam_names,
                 samples=samples,
                 covariance=cov,
                 errors=errors,
                 mean_acceptance_fraction=mean_acceptance_fraction)

    return res, model
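
The docstring above recommends first running `~sncosmo.fit_lc` and then calling this function with every ``guess_[name]`` argument set to False, so the fitted parameters become the starting point. A short sketch of that workflow, assuming a photometric table ``data`` and a configured ``model`` as in the earlier doctest on this page:

import sncosmo

# Sketch only: `data` and `model` are assumed to exist as in the earlier examples.
res, fitted = sncosmo.fit_lc(data, model, ['t0', 'x0', 'x1', 'c'])
mcmc_res, mcmc_model = sncosmo.mcmc_lc(data, fitted, ['t0', 'x0', 'x1', 'c'],
                                       guess_t0=False, guess_amplitude=False,
                                       guess_z=False, nwalkers=20, nsamples=2000)
print(mcmc_res.errors)   # square roots of the covariance diagonal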

Example 45

Project: opticspy
Source File: axis3d.py
View license
    def draw(self, renderer):
        self.label._transform = self.axes.transData
        renderer.open_group('axis3d')

        # code from XAxis
        majorTicks = self.get_major_ticks()
        majorLocs = self.major.locator()

        info = self._axinfo
        index = info['i']

        # filter locations here so that no extra grid lines are drawn
        locmin, locmax = self.get_view_interval()
        if locmin > locmax:
            locmin, locmax = locmax, locmin

        # Rudimentary clipping
        majorLocs = [loc for loc in majorLocs if
                     locmin <= loc <= locmax]
        self.major.formatter.set_locs(majorLocs)
        majorLabels = [self.major.formatter(val, i)
                       for i, val in enumerate(majorLocs)]

        mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)

        # Determine grid lines
        minmax = np.where(highs, maxs, mins)

        # Draw main axis line
        juggled = info['juggled']
        edgep1 = minmax.copy()
        edgep1[juggled[0]] = get_flip_min_max(edgep1, juggled[0], mins, maxs)

        edgep2 = edgep1.copy()
        edgep2[juggled[1]] = get_flip_min_max(edgep2, juggled[1], mins, maxs)
        pep = proj3d.proj_trans_points([edgep1, edgep2], renderer.M)
        centpt = proj3d.proj_transform(centers[0], centers[1], centers[2], renderer.M)
        self.line.set_data((pep[0][0], pep[0][1]), (pep[1][0], pep[1][1]))
        self.line.draw(renderer)

        # Grid points where the planes meet
        xyz0 = []
        for val in majorLocs:
            coord = minmax.copy()
            coord[index] = val
            xyz0.append(coord)

        # Draw labels
        peparray = np.asanyarray(pep)
        # The transAxes transform is used because the Text object
        # rotates the text relative to the display coordinate system.
        # Therefore, if we want the labels to remain parallel to the
        # axis regardless of the aspect ratio, we need to convert the
        # edge points of the plane to display coordinates and calculate
        # an angle from that.
        # TODO: Maybe Text objects should handle this themselves?
        dx, dy = (self.axes.transAxes.transform(peparray[0:2, 1]) -
                  self.axes.transAxes.transform(peparray[0:2, 0]))

        lxyz = 0.5*(edgep1 + edgep2)

        labeldeltas = info['label']['space_factor'] * deltas
        axmask = [True, True, True]
        axmask[index] = False
        lxyz = move_from_center(lxyz, centers, labeldeltas, axmask)
        tlx, tly, tlz = proj3d.proj_transform(lxyz[0], lxyz[1], lxyz[2], \
                renderer.M)
        self.label.set_position((tlx, tly))
        if self.get_rotate_label(self.label.get_text()):
            angle = art3d.norm_text_angle(math.degrees(math.atan2(dy, dx)))
            self.label.set_rotation(angle)
        self.label.set_va(info['label']['va'])
        self.label.set_ha(info['label']['ha'])
        self.label.draw(renderer)


        # Draw Offset text

        # Which of the two edge points do we want to
        # use for locating the offset text?
        if juggled[2] == 2 :
            outeredgep = edgep1
            outerindex = 0
        else :
            outeredgep = edgep2
            outerindex = 1

        pos = copy.copy(outeredgep)
        pos = move_from_center(pos, centers, labeldeltas, axmask)
        olx, oly, olz = proj3d.proj_transform(pos[0], pos[1], pos[2], renderer.M)
        self.offsetText.set_text( self.major.formatter.get_offset() )
        self.offsetText.set_position( (olx, oly) )
        angle = art3d.norm_text_angle(math.degrees(math.atan2(dy, dx)))
        self.offsetText.set_rotation(angle)
        # Must set rotation mode to "anchor" so that
        # the alignment point is used as the "fulcrum" for rotation.
        self.offsetText.set_rotation_mode('anchor')

        #-----------------------------------------------------------------------
        # Note: the following statement for determining the proper alignment of
        #       the offset text. This was determined entirely by trial-and-error
        #       and should not be in any way considered as "the way".  There are
        #       still some edge cases where alignment is not quite right, but
        #       this seems to be more of a geometry issue (in other words, I
        #       might be using the wrong reference points).
        #
        #   (TT, FF, TF, FT) are the shorthand for the tuple of
        #     (centpt[info['tickdir']] <= peparray[info['tickdir'], outerindex],
        #      centpt[index] <= peparray[index, outerindex])
        #
        #   Three-letters (e.g., TFT, FTT) are short-hand for the array
        #    of bools from the variable 'highs'.
        # ---------------------------------------------------------------------
        if centpt[info['tickdir']] > peparray[info['tickdir'], outerindex] :
            # if FT and if highs has an even number of Trues
            if (centpt[index] <= peparray[index, outerindex]
                and ((len(highs.nonzero()[0]) % 2) == 0)) :
                # Usually, this means align right, except for the FTT case,
                # in which offset for axis 1 and 2 are aligned left.
                if highs.tolist() == [False, True, True] and index in (1, 2) :
                    align = 'left'
                else :
                    align = 'right'
            else :
                # The FF case
                align = 'left'
        else :
            # if TF and if highs has an even number of Trues
            if (centpt[index] > peparray[index, outerindex]
                and ((len(highs.nonzero()[0]) % 2) == 0)) :
                # Usually mean align left, except if it is axis 2
                if index == 2 :
                    align = 'right'
                else :
                    align = 'left'
            else :
                # The TT case
                align = 'right'

        self.offsetText.set_va('center')
        self.offsetText.set_ha(align)
        self.offsetText.draw(renderer)

        # Draw grid lines
        if len(xyz0) > 0:
            # Grid points at end of one plane
            xyz1 = copy.deepcopy(xyz0)
            newindex = (index + 1) % 3
            newval = get_flip_min_max(xyz1[0], newindex, mins, maxs)
            for i in range(len(majorLocs)):
                xyz1[i][newindex] = newval

            # Grid points at end of the other plane
            xyz2 = copy.deepcopy(xyz0)
            newindex = (index + 2) %  3
            newval = get_flip_min_max(xyz2[0], newindex, mins, maxs)
            for i in range(len(majorLocs)):
                xyz2[i][newindex] = newval

            lines = list(zip(xyz1, xyz0, xyz2))
            if self.axes._draw_grid:
                self.gridlines.set_segments(lines)
                self.gridlines.set_color([info['grid']['color']] * len(lines))
                self.gridlines.draw(renderer, project=True)

        # Draw ticks
        tickdir = info['tickdir']
        tickdelta = deltas[tickdir]
        if highs[tickdir]:
            ticksign = 1
        else:
            ticksign = -1

        for tick, loc, label in zip(majorTicks, majorLocs, majorLabels):
            if tick is None:
                continue

            # Get tick line positions
            pos = copy.copy(edgep1)
            pos[index] = loc
            pos[tickdir] = edgep1[tickdir] + info['tick']['outward_factor'] * \
                                             ticksign * tickdelta
            x1, y1, z1 = proj3d.proj_transform(pos[0], pos[1], pos[2], \
                    renderer.M)
            pos[tickdir] = edgep1[tickdir] - info['tick']['inward_factor'] * \
                                             ticksign * tickdelta
            x2, y2, z2 = proj3d.proj_transform(pos[0], pos[1], pos[2], \
                    renderer.M)

            # Get position of label
            labeldeltas = [info['ticklabel']['space_factor'] * x for
                           x in deltas]
            axmask = [True, True, True]
            axmask[index] = False
            pos[tickdir] = edgep1[tickdir]
            pos = move_from_center(pos, centers, labeldeltas, axmask)
            lx, ly, lz = proj3d.proj_transform(pos[0], pos[1], pos[2], \
                    renderer.M)

            tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
            tick.set_label1(label)
            tick.set_label2(label)
            tick.draw(renderer)

        renderer.close_group('axis3d')
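
This listing uses `copy.copy` for single coordinate arrays (``edgep1``, ``outeredgep``) and `copy.deepcopy` for the list of grid points ``xyz0``. A small sketch of why that distinction matters for NumPy containers:

import copy
import numpy as np

point = np.array([0.0, 1.0, 2.0])
pos = copy.copy(point)        # ndarray defines __copy__, so the data is duplicated
pos[0] = 99.0
print(point[0])               # -> 0.0

grid = [np.zeros(3) for _ in range(2)]
shallow = copy.copy(grid)     # new list, but the inner arrays are shared
deep = copy.deepcopy(grid)    # new list and new arrays
shallow[0][0] = 7.0
print(grid[0][0])             # -> 7.0, shared through the shallow copy
deep[1][1] = 5.0
print(grid[1][1])             # -> 0.0, the deep copy is independent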

Example 46

Project: timestring
Source File: Date.py
View license
    def __init__(self, date, offset=None, start_of_week=None, tz=None, verbose=False):
        if isinstance(date, Date):
            self.date = copy(date.date)
            return

        # The original request
        self._original = date
        if tz:
            tz = pytz.timezone(str(tz))

        if date == 'infinity':
            self.date = 'infinity'

        elif date == 'now':
            self.date = datetime.now()

        elif type(date) in (str, unicode) and re.match(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+-\d{2}", date):
            self.date = datetime.strptime(date[:-3], "%Y-%m-%d %H:%M:%S.%f") - timedelta(hours=int(date[-3:]))

        else:
            # Determine the starting date.
            if type(date) in (str, unicode):
                """The date is a string and needs to be converted into a <dict> for processesing
                """
                _date = date.lower()
                res = TIMESTRING_RE.search(_date.strip())
                if res:
                    date = res.groupdict()
                    if verbose:
                        print("Matches:\n", ''.join(["\t%s: %s\n" % (k, v) for k, v in date.items() if v]))
                else:
                    raise TimestringInvalid('Invalid date string >> %s' % date)

                date = dict((k, v if type(v) is str else v) for k, v in date.items() if v)
                #print(_date, dict(map(lambda a: (a, date.get(a)), filter(lambda a: date.get(a), date))))

            if isinstance(date, dict):
                # Initial date.
                new_date = datetime(*time.localtime()[:3])
                if tz and tz.zone != "UTC":
                    #
                    # The purpose here is to adjust what day it is based on the timezone
                    #
                    ts = datetime.now()
                    # Daylight saving time starts on the second Sunday in March and reverts to standard time on the first Sunday in November
                    # Monday is 0 and Sunday is 6.
                    # 14 days - dst_start.weekday()
                    dst_start = datetime(ts.year, 3, 1, 2, 0, 0) + timedelta(13 - datetime(ts.year, 3, 1).weekday())
                    dst_end = datetime(ts.year, 11, 1, 2, 0, 0) + timedelta(6 - datetime(ts.year, 11, 1).weekday())

                    ts = ts + tz.utcoffset(new_date, is_dst=(dst_start < ts < dst_end))
                    new_date = datetime(ts.year, ts.month, ts.day)

                if date.get('unixtime'):
                    new_date = datetime.fromtimestamp(int(date.get('unixtime')))

                # !number of (days|...) (ago)?
                elif date.get('num') and (date.get('delta') or date.get('delta_2')):
                    if date.get('num', '').find('couple') > -1:
                        i = 2 * int(1 if date.get('ago', True) or date.get('ref') == 'last' else -1)
                    else:
                        i = int(text2num(date.get('num', 'one'))) * int(1 if date.get('ago') or (date.get('ref', '') or '') == 'last' else -1)

                    delta = (date.get('delta') or date.get('delta_2')).lower()
                    if delta.startswith('y'):
                        try:
                            new_date = new_date.replace(year=(new_date.year - i))
                        # day is out of range for month
                        except ValueError:
                            new_date = new_date - timedelta(days=(365*i))
                    elif delta.startswith('month'):
                        try:
                            new_date = new_date.replace(month=(new_date.month - i))
                        # day is out of range for month
                        except ValueError:
                            new_date = new_date - timedelta(days=(30*i))

                    elif delta.startswith('q'):
                        '''
                        This section is not working...
                        Most likely need a generator that will take me to the right quarter.
                        '''
                        q1, q2, q3, q4 = datetime(new_date.year, 1, 1), datetime(new_date.year, 4, 1), datetime(new_date.year, 7, 1), datetime(new_date.year, 10, 1)
                        if q1 <= new_date < q2:
                            # We are in Q1
                            if i == -1:
                                new_date = datetime(new_date.year-1, 10, 1)
                            else:
                                new_date = q2
                        elif q2 <= new_date < q3:
                            # We are in Q2
                            pass
                        elif q3 <= new_date < q4:
                            # We are in Q3
                            pass
                        else:
                            # We are in Q4
                            pass
                        new_date = new_date - timedelta(days=(91*i))

                    elif delta.startswith('w'):
                        new_date = new_date - timedelta(days=(i * 7))

                    else:
                        new_date = new_date - timedelta(**{('days' if delta.startswith('d') else 'hours' if delta.startswith('h') else 'minutes' if delta.startswith('m') else 'seconds'): i})

                # !dow
                if [date.get(key) for key in ('day', 'day_2', 'day_3') if date.get(key)]:
                    dow = max([date.get(key) for key in ('day', 'day_2', 'day_3') if date.get(key)])
                    iso = dict(monday=1, tuesday=2, wednesday=3, thursday=4, friday=5, saturday=6, sunday=7, mon=1, tue=2, tues=2, wed=3, wedn=3, thu=4, thur=4, fri=5, sat=6, sun=7).get(dow)
                    if iso:
                        # determine which direction
                        if date.get('ref') not in ('this', 'next'):
                            days = iso - new_date.isoweekday() - (7 if iso >= new_date.isoweekday() else 0)
                        else:
                            days = iso - new_date.isoweekday() + (7 if iso < new_date.isoweekday() else 0)

                        new_date = new_date + timedelta(days=days)

                    elif dow == 'yesterday':
                        new_date = new_date - timedelta(days=1)
                    elif dow == 'tomorrow':
                        new_date = new_date + timedelta(days=1)

                # !year
                year = [int(CLEAN_NUMBER.sub('', date[key])) for key in ('year', 'year_2', 'year_3', 'year_4', 'year_5', 'year_6') if date.get(key)]
                if year:
                    year = max(year)
                    if len(str(year)) != 4:
                        year += 2000 if year <= 40 else 1900
                    new_date = new_date.replace(year=year)

                # !month
                month = [date.get(key) for key in ('month', 'month_1', 'month_2', 'month_3', 'month_4') if date.get(key)]
                if month:
                    new_date = new_date.replace(day=1)
                    new_date = new_date.replace(month=int(max(month)) if re.match('^\d+$', max(month)) else dict(january=1, february=2, march=3, april=4, june=6, july=7, august=8, september=9, october=10, november=11, december=12, jan=1, feb=2, mar=3, apr=4, may=5, jun=6, jul=7, aug=8, sep=9, sept=9, oct=10, nov=11, dec=12).get(max(month),  new_date.month))

                # !day
                day = [date.get(key) for key in ('date', 'date_2', 'date_3') if date.get(key)]
                if day:
                    new_date = new_date.replace(day=int(max(day)))

                # !daytime
                if date.get('daytime'):
                    if date['daytime'].find('this time') >= 1:
                        new_date = new_date.replace(hour=datetime(*time.localtime()[:5]).hour,
                                                    minute=datetime(*time.localtime()[:5]).minute)
                    else:
                        new_date = new_date.replace(hour=dict(morning=9, noon=12, afternoon=15, evening=18, night=21, nighttime=21, midnight=24).get(date.get('daytime'), 12))
                    # No offset because the hour was set.
                    offset = False

                # !hour
                hour = [date.get(key) for key in ('hour', 'hour_2', 'hour_3') if date.get(key)]
                if hour:
                    new_date = new_date.replace(hour=int(max(hour)))
                    am = [date.get(key) for key in ('am', 'am_1') if date.get(key)]
                    if am and max(am) in ('p', 'pm'):
                        h = int(max(hour))
                        if h < 12:
                            new_date = new_date.replace(hour=h+12)
                    # No offset because the hour was set.
                    offset = False

                    #minute
                    minute = [date.get(key) for key in ('minute', 'minute_2') if date.get(key)]
                    if minute:
                        new_date = new_date.replace(minute=int(max(minute)))

                    #second
                    seconds = date.get('seconds', 0)
                    if seconds:
                        new_date = new_date.replace(second=int(seconds))

                self.date = new_date

            elif type(date) in (int, long, float) and re.match('^\d{10}$', str(date)):
                self.date = datetime.fromtimestamp(int(date))

            elif isinstance(date, datetime):
                self.date = date

            elif date is None:
                self.date = datetime.now()

            else:
                # Set to the current date Y, M, D, H0, M0, S0
                self.date = datetime(*time.localtime()[:3])

            if tz:
                self.date = self.date.replace(tzinfo=tz)

            # end if type(date) is types.DictType: and self.date.hour == 0:
            if offset and isinstance(offset, dict):
                self.date = self.date.replace(**offset)
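
When a ``Date`` is built from another ``Date``, the listing copies the wrapped value (``self.date = copy(date.date)``) instead of aliasing it. A hypothetical sketch of the same defensive-copy pattern in ``__init__`` (a toy wrapper, not the timestring API):

from copy import copy

class Holder:
    """Hypothetical wrapper, not the timestring Date class."""
    def __init__(self, value):
        if isinstance(value, Holder):
            # copy the wrapped value so in-place changes to one instance
            # cannot leak into the other
            self.value = copy(value.value)
        else:
            self.value = value

a = Holder([2015, 3, 8])
b = Holder(a)
b.value.append(12)
print(a.value)   # -> [2015, 3, 8]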

Example 47

Project: timestring
Source File: Range.py
View license
    def __init__(self, start, end=None, offset=None, start_of_week=0, tz=None, verbose=False):
        """`start` can be type <class timestring.Date> or <type str>
        """
        self._dates = []
        pgoffset = None

        if start is None:
            raise TimestringInvalid("Range object requires a start valie")

        if not isinstance(start, (Date, datetime)):
            start = str(start)
        if end and not isinstance(end, (Date, datetime)):
            end = str(end)

        if start and end:
            """start and end provided
            """
            self._dates = (Date(start, tz=tz), Date(end, tz=tz))

        elif start == 'infinity':
            # end was not provided
            self._dates = (Date('infinity'), Date('infinity'))

        elif re.search(r'(\s(and|to)\s)', start):
            """Both sides where provided in the start
            """
            start = re.sub('^(between|from)\s', '', start.lower())
            # Both arguments found in start variable
            r = tuple(re.split(r'(\s(and|to)\s)', start.strip()))
            self._dates = (Date(r[0], tz=tz), Date(r[-1], tz=tz))

        elif re.match(r"(\[|\()((\"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+|\-)\d{2}\")|infinity),((\"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(\.\d+)?(\+|\-)\d{2}\")|infinity)(\]|\))", start):
            """postgresql tsrange and tstzranges support
            """
            start, end = tuple(re.sub('[^\w\s\-\:\.\+\,]', '', start).split(','))
            self._dates = (Date(start), Date(end))

        else:
            now = datetime.now()
            # no tz info but offset provided, we are UTC so convert

            if re.search(r"(\+|\-)\d{2}$", start):
                # postgresql tsrange and tstzranges
                pgoffset = re.search(r"(\+|\-)\d{2}$", start).group() + " hours"

            # tz info provided
            if tz:
                now = now.replace(tzinfo=pytz.timezone(str(tz)))

            # Parse
            res = TIMESTRING_RE.search(start)
            if res:
                group = res.groupdict()
                if verbose:
                    print(dict(map(lambda a: (a, group.get(a)), filter(lambda a: group.get(a), group))))
                if (group.get('delta') or group.get('delta_2')) is not None:
                    delta = (group.get('delta') or group.get('delta_2')).lower()

                    # always start w/ today
                    start = Date("today", offset=offset, tz=tz)

                    # make delta
                    di = "%s %s" % (str(int(group['num'] or 1)), delta)

                    # this           [   x  ]
                    if group['ref'] == 'this':

                        if delta.startswith('y'):
                            start = Date(datetime(now.year, 1, 1), offset=offset, tz=tz)

                        # month
                        elif delta.startswith('month'):
                            start = Date(datetime(now.year, now.month, 1), offset=offset, tz=tz)

                        # week
                        elif delta.startswith('w'):
                            start = Date("today", offset=offset, tz=tz) - (str(Date("today", tz=tz).date.weekday())+' days')

                        # day
                        elif delta.startswith('d'):
                            start = Date("today", offset=offset, tz=tz)

                        # hour
                        elif delta.startswith('h'):
                            start = Date("today", offset=dict(hour=now.hour+1), tz=tz)

                        # minute, second
                        elif delta.startswith('m') or delta.startswith('s'):
                            start = Date("now", tz=tz)

                        else:
                            raise TimestringInvalid("Not a valid time reference")

                        end = start + di

                    #next          x [      ]
                    elif group['ref'] == 'next':
                        if int(group['num'] or 1) > 1:
                            di = "%s %s" % (str(int(group['num'] or 1) - 1), delta)
                        end = start + di

                    # ago             [     ] x
                    elif group.get('ago') or group['ref'] == 'last' and int(group['num'] or 1) == 1:
                        #if group['ref'] == 'last' and int(group['num'] or 1) == 1:
                        #    start = start - ('1 ' + delta)
                        end = start - di

                    # last & no ref   [    x]
                    else:
                        # need to include today with this reference
                        if not (delta.startswith('h') or delta.startswith('m') or delta.startswith('s')):
                            start = Range('today', offset=offset, tz=tz).end
                        end = start - di                    

                elif group.get('month_1'):
                    # a single month of this year
                    start = Date(start, offset=offset, tz=tz)
                    start = start.replace(day=1)
                    end = start + '1 month'

                elif group.get('year_5'):
                    # a whole year
                    start = Date(start, offset=offset, tz=tz)
                    start = start.replace(day=1, month=1)
                    end = start + '1 year'

                else:
                    # after all else, we set the end to + 1 day
                    start = Date(start, offset=offset, tz=tz)
                    end = start + '1 day'

            else:
                raise TimestringInvalid("Invalid timestring request")


            if end is None:
                # no end provided, so assume 24 hours
                end = start + '24 hours'

            if start > end:
                # flip them if this is so
                start, end = copy(end), copy(start)
            
            if pgoffset:
                start = start - pgoffset
                if end != 'infinity':
                    end = end - pgoffset

            self._dates = (start, end)

        if self._dates[0] > self._dates[1]:
            self._dates = (self._dates[0], self._dates[1] + '1 day')
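
Near the end of this listing, out-of-order bounds are swapped with ``start, end = copy(end), copy(start)``, which also gives the range freshly copied objects. A generic sketch, with plain lists standing in for date-like values, of how copying during the swap keeps later in-place adjustments from leaking back to earlier references:

from copy import copy

lo = [2021, 1, 5]
hi = [2021, 1, 10]
start, end = hi, lo              # suppose the bounds arrived reversed

if start > end:
    start, end = copy(end), copy(start)   # swap into fresh objects

start[2] -= 3                    # later in-place adjustment of the range start
print(lo)                        # -> [2021, 1, 5]; the earlier reference is untouched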

Example 48

Project: stopstalk-deployment
Source File: appadmin.py
View license
def ccache():
    if is_gae:
        form = FORM(
            P(TAG.BUTTON(T("Clear CACHE?"), _type="submit", _name="yes", _value="yes")))
    else:
        cache.ram.initialize()
        cache.disk.initialize()

        form = FORM(
            P(TAG.BUTTON(
                T("Clear CACHE?"), _type="submit", _name="yes", _value="yes")),
            P(TAG.BUTTON(
                T("Clear RAM"), _type="submit", _name="ram", _value="ram")),
            P(TAG.BUTTON(
                T("Clear DISK"), _type="submit", _name="disk", _value="disk")),
        )

    if form.accepts(request.vars, session):
        session.flash = ""
        if is_gae:
            if request.vars.yes:
                cache.ram.clear()
                session.flash += T("Cache Cleared")
        else:
            clear_ram = False
            clear_disk = False
            if request.vars.yes:
                clear_ram = clear_disk = True
            if request.vars.ram:
                clear_ram = True
            if request.vars.disk:
                clear_disk = True
            if clear_ram:
                cache.ram.clear()
                session.flash += T("Ram Cleared")
            if clear_disk:
                cache.disk.clear()
                session.flash += T("Disk Cleared")
        redirect(URL(r=request))

    try:
        from guppy import hpy
        hp = hpy()
    except ImportError:
        hp = False

    import shelve
    import os
    import copy
    import time
    import math
    from gluon import portalocker

    ram = {
        'entries': 0,
        'bytes': 0,
        'objects': 0,
        'hits': 0,
        'misses': 0,
        'ratio': 0,
        'oldest': time.time(),
        'keys': []
    }

    disk = copy.copy(ram)
    total = copy.copy(ram)
    disk['keys'] = []
    total['keys'] = []

    def GetInHMS(seconds):
        hours = math.floor(seconds / 3600)
        seconds -= hours * 3600
        minutes = math.floor(seconds / 60)
        seconds -= minutes * 60
        seconds = math.floor(seconds)

        return (hours, minutes, seconds)

    if is_gae:
        gae_stats = cache.ram.client.get_stats()
        try:
            gae_stats['ratio'] = ((gae_stats['hits'] * 100) /
                (gae_stats['hits'] + gae_stats['misses']))
        except ZeroDivisionError:
            gae_stats['ratio'] = T("?")
        gae_stats['oldest'] = GetInHMS(time.time() - gae_stats['oldest_item_age'])
        total.update(gae_stats)
    else:
        for key, value in cache.ram.storage.iteritems():
            if isinstance(value, dict):
                ram['hits'] = value['hit_total'] - value['misses']
                ram['misses'] = value['misses']
                try:
                    ram['ratio'] = ram['hits'] * 100 / value['hit_total']
                except (KeyError, ZeroDivisionError):
                    ram['ratio'] = 0
            else:
                if hp:
                    ram['bytes'] += hp.iso(value[1]).size
                    ram['objects'] += hp.iso(value[1]).count
                ram['entries'] += 1
                if value[0] < ram['oldest']:
                    ram['oldest'] = value[0]
                ram['keys'].append((key, GetInHMS(time.time() - value[0])))

        for key in cache.disk.storage:
            value = cache.disk.storage[key]
            if isinstance(value, dict):
                disk['hits'] = value['hit_total'] - value['misses']
                disk['misses'] = value['misses']
                try:
                    disk['ratio'] = disk['hits'] * 100 / value['hit_total']
                except (KeyError, ZeroDivisionError):
                    disk['ratio'] = 0
            else:
                if hp:
                    disk['bytes'] += hp.iso(value[1]).size
                    disk['objects'] += hp.iso(value[1]).count
                disk['entries'] += 1
                if value[0] < disk['oldest']:
                    disk['oldest'] = value[0]
                disk['keys'].append((key, GetInHMS(time.time() - value[0])))

        ram_keys = ram.keys() # ['hits', 'objects', 'ratio', 'entries', 'keys', 'oldest', 'bytes', 'misses']
        ram_keys.remove('ratio')
        ram_keys.remove('oldest')
        for key in ram_keys:
            total[key] = ram[key] + disk[key]

        try:
            total['ratio'] = total['hits'] * 100 / (total['hits'] +
                                                total['misses'])
        except (KeyError, ZeroDivisionError):
            total['ratio'] = 0

        if disk['oldest'] < ram['oldest']:
            total['oldest'] = disk['oldest']
        else:
            total['oldest'] = ram['oldest']

        ram['oldest'] = GetInHMS(time.time() - ram['oldest'])
        disk['oldest'] = GetInHMS(time.time() - disk['oldest'])
        total['oldest'] = GetInHMS(time.time() - total['oldest'])

    def key_table(keys):
        return TABLE(
            TR(TD(B(T('Key'))), TD(B(T('Time in Cache (h:m:s)')))),
            *[TR(TD(k[0]), TD('%02d:%02d:%02d' % k[1])) for k in keys],
            **dict(_class='cache-keys',
                   _style="border-collapse: separate; border-spacing: .5em;"))

    if not is_gae:
        ram['keys'] = key_table(ram['keys'])
        disk['keys'] = key_table(disk['keys'])
        total['keys'] = key_table(total['keys'])

    return dict(form=form, total=total,
                ram=ram, disk=disk, object_stats=hp != False)
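
The cache report builds ``disk`` and ``total`` as shallow copies of the ``ram`` template and then immediately rebinds their ``'keys'`` entries. The sketch below shows why that reset matters: `copy.copy` on a dict duplicates the mapping but not the values, so a list-valued entry stays shared until it is replaced:

import copy

ram = {'entries': 0, 'hits': 0, 'misses': 0, 'keys': []}

disk = copy.copy(ram)      # shallow copy: same 'keys' list object in both dicts
disk['entries'] = 5        # rebinding a key only changes `disk`
disk['keys'] = []          # give `disk` its own list, as the code above does

disk['keys'].append(('cache:item', (0, 0, 1)))
print(ram['keys'])         # -> [], because the shared list was replaced first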

Example 49

Project: feed9
Source File: appadmin.py
View license
def ccache():
    form = FORM(
        P(TAG.BUTTON("Clear CACHE?", _type="submit", _name="yes", _value="yes")),
        P(TAG.BUTTON("Clear RAM", _type="submit", _name="ram", _value="ram")),
        P(TAG.BUTTON("Clear DISK", _type="submit", _name="disk", _value="disk")),
    )

    if form.accepts(request.vars, session):
        clear_ram = False
        clear_disk = False
        session.flash = ""
        if request.vars.yes:
            clear_ram = clear_disk = True
        if request.vars.ram:
            clear_ram = True
        if request.vars.disk:
            clear_disk = True

        if clear_ram:
            cache.ram.clear()
            session.flash += "Ram Cleared "
        if clear_disk:
            cache.disk.clear()
            session.flash += "Disk Cleared"

        redirect(URL(r=request))

    try:
        from guppy import hpy; hp=hpy()
    except ImportError:
        hp = False

    import shelve, os, copy, time, math
    from gluon import portalocker

    ram = {
        'entries': 0,
        'bytes': 0,
        'objects': 0,
        'hits': 0,
        'misses': 0,
        'ratio': 0,
        'oldest': time.time(),
        'keys': []
    }
    disk = copy.copy(ram)
    total = copy.copy(ram)
    disk['keys'] = []
    total['keys'] = []

    def GetInHMS(seconds):
        hours = math.floor(seconds / 3600)
        seconds -= hours * 3600
        minutes = math.floor(seconds / 60)
        seconds -= minutes * 60
        seconds = math.floor(seconds)

        return (hours, minutes, seconds)

    for key, value in cache.ram.storage.items():
        if isinstance(value, dict):
            ram['hits'] = value['hit_total'] - value['misses']
            ram['misses'] = value['misses']
            try:
                ram['ratio'] = ram['hits'] * 100 / value['hit_total']
            except (KeyError, ZeroDivisionError):
                ram['ratio'] = 0
        else:
            if hp:
                ram['bytes'] += hp.iso(value[1]).size
                ram['objects'] += hp.iso(value[1]).count
            ram['entries'] += 1
            if value[0] < ram['oldest']:
                ram['oldest'] = value[0]
            ram['keys'].append((key, GetInHMS(time.time() - value[0])))

    locker = open(os.path.join(request.folder,
                                        'cache/cache.lock'), 'a')
    portalocker.lock(locker, portalocker.LOCK_EX)
    disk_storage = shelve.open(os.path.join(request.folder, 'cache/cache.shelve'))
    try:
        for key, value in disk_storage.items():
            if isinstance(value, dict):
                disk['hits'] = value['hit_total'] - value['misses']
                disk['misses'] = value['misses']
                try:
                    disk['ratio'] = disk['hits'] * 100 / value['hit_total']
                except (KeyError, ZeroDivisionError):
                    disk['ratio'] = 0
            else:
                if hp:
                    disk['bytes'] += hp.iso(value[1]).size
                    disk['objects'] += hp.iso(value[1]).count
                disk['entries'] += 1
                if value[0] < disk['oldest']:
                    disk['oldest'] = value[0]
                disk['keys'].append((key, GetInHMS(time.time() - value[0])))

    finally:
        portalocker.unlock(locker)
        locker.close()
        disk_storage.close()

    total['entries'] = ram['entries'] + disk['entries']
    total['bytes'] = ram['bytes'] + disk['bytes']
    total['objects'] = ram['objects'] + disk['objects']
    total['hits'] = ram['hits'] + disk['hits']
    total['misses'] = ram['misses'] + disk['misses']
    total['keys'] = ram['keys'] + disk['keys']
    try:
        total['ratio'] = total['hits'] * 100 / (total['hits'] + total['misses'])
    except (KeyError, ZeroDivisionError):
        total['ratio'] = 0

    if disk['oldest'] < ram['oldest']:
        total['oldest'] = disk['oldest']
    else:
        total['oldest'] = ram['oldest']

    ram['oldest'] = GetInHMS(time.time() - ram['oldest'])
    disk['oldest'] = GetInHMS(time.time() - disk['oldest'])
    total['oldest'] = GetInHMS(time.time() - total['oldest'])

    def key_table(keys):
        return TABLE(
            TR(TD(B('Key')), TD(B('Time in Cache (h:m:s)'))),
            *[TR(TD(k[0]), TD('%02d:%02d:%02d' % k[1])) for k in keys],
            **dict(_class='cache-keys',
                   _style="border-collapse: separate; border-spacing: .5em;"))

    ram['keys'] = key_table(ram['keys'])
    disk['keys'] = key_table(disk['keys'])
    total['keys'] = key_table(total['keys'])

    return dict(form=form, total=total,
                ram=ram, disk=disk, object_stats=hp != False)
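
This listing is nearly identical to the previous one and uses the same shallow-copy-plus-reset idiom. As a sketch of the alternative, `copy.deepcopy` would duplicate the nested list as well, so no manual ``disk['keys'] = []`` reset would be needed (at the cost of recursively copying the whole structure):

import copy

template = {'entries': 0, 'hits': 0, 'misses': 0, 'keys': []}

disk = copy.deepcopy(template)   # the nested list is duplicated too
disk['keys'].append(('cache:item', (0, 0, 1)))
print(template['keys'])          # -> []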

Example 50

Project: pyscf
Source File: rks.py
View license
def hess_elec(hess_mf, mo_energy=None, mo_coeff=None, mo_occ=None,
              atmlst=None, max_memory=4000, verbose=None):
    if isinstance(verbose, logger.Logger):
        log = verbose
    else:
        log = logger.Logger(hess_mf.stdout, hess_mf.verbose)

    time0 = (time.clock(), time.time())

    mf = hess_mf._scf
    mol = hess_mf.mol
    if mo_energy is None: mo_energy = mf.mo_energy
    if mo_occ is None:    mo_occ = mf.mo_occ
    if mo_coeff is None:  mo_coeff = mf.mo_coeff
    if atmlst is None: atmlst = range(mol.natm)

    nao, nmo = mo_coeff.shape
    nocc = int(mo_occ.sum()) // 2
    mocc = mo_coeff[:,:nocc]
    dm0 = mf.make_rdm1(mo_coeff, mo_occ)

    ni = copy.copy(mf._numint)
    if USE_XCFUN:
        try:
            ni.libxc = dft.xcfun
            xctype = ni._xc_type(mf.xc)
        except (ImportError, KeyError, NotImplementedError):
            ni.libxc = dft.libxc
            xctype = ni._xc_type(mf.xc)
    else:
        xctype = ni._xc_type(mf.xc)
    grids = mf.grids
    hyb = ni.libxc.hybrid_coeff(mf.xc)
    max_memory = 4000

    h1aos = hess_mf.make_h1(mo_coeff, mo_occ, hess_mf.chkfile, atmlst, log)
    t1 = log.timer('making H1', *time0)
    def fx(mo1):
        # *2 for alpha + beta
        dm1 = numpy.einsum('xai,pa,qi->xpq', mo1, mo_coeff, mocc*2)
        dm1 = dm1 + dm1.transpose(0,2,1)
        vindxc = _contract_xc_kernel(mf, mf.xc, dm1, max_memory)
        if abs(hyb) > 1e-10:
            vj, vk = mf.get_jk(mol, dm1)
            veff = vj - hyb * .5 * vk + vindxc
        else:
            vj = mf.get_j(mol, dm1)
            veff = vj + vindxc
        v1 = numpy.einsum('xpq,pa,qi->xai', veff, mo_coeff, mocc)
        return v1.reshape(v1.shape[0],-1)
    mo1s, e1s = hess_mf.solve_mo1(mo_energy, mo_coeff, mo_occ, h1aos,
                                  fx, atmlst, max_memory, log)
    t1 = log.timer('solving MO1', *t1)

    tmpf = tempfile.NamedTemporaryFile()
    with h5py.File(tmpf.name, 'w') as f:
        for i0, ia in enumerate(atmlst):
            mol.set_rinv_origin(mol.atom_coord(ia))
            f['rinv2aa/%d'%ia] = (mol.atom_charge(ia) *
                                  mol.intor('cint1e_ipiprinv_sph', comp=9))
            f['rinv2ab/%d'%ia] = (mol.atom_charge(ia) *
                                  mol.intor('cint1e_iprinvip_sph', comp=9))

    h1aa =(mol.intor('cint1e_ipipkin_sph', comp=9) +
           mol.intor('cint1e_ipipnuc_sph', comp=9))
    h1ab =(mol.intor('cint1e_ipkinip_sph', comp=9) +
           mol.intor('cint1e_ipnucip_sph', comp=9))
    s1aa = mol.intor('cint1e_ipipovlp_sph', comp=9)
    s1ab = mol.intor('cint1e_ipovlpip_sph', comp=9)
    s1a =-mol.intor('cint1e_ipovlp_sph', comp=3)

    # Energy weighted density matrix
    dme0 = numpy.einsum('pi,qi,i->pq', mocc, mocc, mo_energy[:nocc]) * 2

    if abs(hyb) > 1e-10:
        vj1, vk1 = _vhf.direct_mapdm('cint2e_ipip1_sph', 's2kl',
                                     ('lk->s1ij', 'jk->s1il'), dm0, 9,
                                     mol._atm, mol._bas, mol._env)
        veff1ii = vj1 - hyb * .5 * vk1
    else:
        vj1 = _vhf.direct_mapdm('cint2e_ipip1_sph', 's2kl', 'lk->s1ij', dm0, 9,
                                mol._atm, mol._bas, mol._env)
        veff1ii = vj1.copy()
    vj1[:] = 0
    if xctype == 'LDA':
        ao_deriv = 2
        for ao, mask, weight, coords \
                in ni.block_loop(mol, grids, nao, ao_deriv, max_memory, ni.non0tab):
            rho = ni.eval_rho2(mol, ao[0], mo_coeff, mo_occ, mask, 'LDA')
            vxc = ni.eval_xc(mf.xc, rho, 0, deriv=1)[1]
            vrho = vxc[0]
            aow = numpy.einsum('pi,p->pi', ao[0], weight*vrho)
            for i in range(6):
                vj1[i] += pyscf.lib.dot(ao[i+4].T, aow)
            aow = aow1 = None
    elif xctype == 'GGA':
        ao_deriv = 3
        for ao, mask, weight, coords \
                in ni.block_loop(mol, grids, nao, ao_deriv, max_memory, ni.non0tab):
            rho = ni.eval_rho2(mol, ao[:4], mo_coeff, mo_occ, mask, 'GGA')
            vxc = ni.eval_xc(mf.xc, rho, 0, deriv=1)[1]
            vrho, vgamma = vxc[:2]
            wv = numpy.empty_like(rho)
            wv[0]  = weight * vrho
            wv[1:] = rho[1:] * (weight * vgamma * 2)
            aow = numpy.einsum('npi,np->pi', ao[:4], wv)
            for i in range(6):
                vj1[i] += pyscf.lib.dot(ao[i+4].T, aow)
            aow = numpy.einsum('npi,np->pi', ao[[XXX,XXY,XXZ]], wv[1:4])
            vj1[0] += pyscf.lib.dot(aow.T, ao[0])
            aow = numpy.einsum('npi,np->pi', ao[[XXY,XYY,XYZ]], wv[1:4])
            vj1[1] += pyscf.lib.dot(aow.T, ao[0])
            aow = numpy.einsum('npi,np->pi', ao[[XXZ,XYZ,XZZ]], wv[1:4])
            vj1[2] += pyscf.lib.dot(aow.T, ao[0])
            aow = numpy.einsum('npi,np->pi', ao[[XYY,YYY,YYZ]], wv[1:4])
            vj1[3] += pyscf.lib.dot(aow.T, ao[0])
            aow = numpy.einsum('npi,np->pi', ao[[XYZ,YYZ,YZZ]], wv[1:4])
            vj1[4] += pyscf.lib.dot(aow.T, ao[0])
            aow = numpy.einsum('npi,np->pi', ao[[XZZ,YZZ,ZZZ]], wv[1:4])
            vj1[5] += pyscf.lib.dot(aow.T, ao[0])
            rho = vxc = vrho = vgamma = wv = aow = None
    else:
        raise NotImplementedError('meta-GGA')
    veff1ii += vj1[[0,1,2,1,3,4,2,4,5]]
    vj1 = vk1 = None

    t1 = log.timer('contracting cint2e_ipip1_sph', *t1)

    offsetdic = mol.offset_nr_by_atom()
    frinv = h5py.File(tmpf.name, 'r')
    rinv2aa = frinv['rinv2aa']
    rinv2ab = frinv['rinv2ab']

    de2 = numpy.zeros((mol.natm,mol.natm,3,3))
    for i0, ia in enumerate(atmlst):
        shl0, shl1, p0, p1 = offsetdic[ia]

        h_2 = rinv2ab[str(ia)] + rinv2aa[str(ia)].value.transpose(0,2,1)
        h_2[:,p0:p1] += h1ab[:,p0:p1]
        s1ao = numpy.zeros((3,nao,nao))
        s1ao[:,p0:p1] += s1a[:,p0:p1]
        s1ao[:,:,p0:p1] += s1a[:,p0:p1].transpose(0,2,1)
        s1oo = numpy.einsum('xpq,pi,qj->xij', s1ao, mocc, mocc)

        shls_slice = (shl0, shl1) + (0, mol.nbas)*3
        if abs(hyb) > 1e-10:
            vj1, vk1, vk2 = _vhf.direct_bindm('cint2e_ip1ip2_sph', 's1',
                                              ('ji->s1kl', 'li->s1kj', 'lj->s1ki'),
                                              (dm0[:,p0:p1], dm0[:,p0:p1], dm0), 9,
                                              mol._atm, mol._bas, mol._env,
                                              shls_slice=shls_slice)
            veff2 = vj1 * 2 - hyb * .5 * vk1
            veff2[:,:,p0:p1] -= hyb * .5 * vk2
            t1 = log.timer('contracting cint2e_ip1ip2_sph for atom %d'%ia, *t1)

            vj1, vk1 = _vhf.direct_bindm('cint2e_ipvip1_sph', 's2kl',
                                         ('lk->s1ij', 'li->s1kj'),
                                         (dm0, dm0[:,p0:p1]), 9,
                                         mol._atm, mol._bas, mol._env,
                                         shls_slice=shls_slice)
            veff2[:,:,p0:p1] += vj1.transpose(0,2,1)
            veff2 -= hyb * .5 * vk1.transpose(0,2,1)
            vk1 = vk2 = None  # keep vj1; it is zeroed and reused as a (9, p1-p0, nao) buffer in the DFT blocks below
            t1 = log.timer('contracting cint2e_ipvip1_sph for atom %d'%ia, *t1)
        else:
            vj1 = _vhf.direct_bindm('cint2e_ip1ip2_sph', 's1',
                                    'ji->s1kl', dm0[:,p0:p1], 9,
                                    mol._atm, mol._bas, mol._env,
                                    shls_slice=shls_slice)
            veff2 = vj1 * 2
            t1 = log.timer('contracting cint2e_ip1ip2_sph for atom %d'%ia, *t1)

            vj1 = _vhf.direct_bindm('cint2e_ipvip1_sph', 's2kl',
                                    'lk->s1ij', dm0, 9,
                                    mol._atm, mol._bas, mol._env,
                                    shls_slice=shls_slice)
            veff2[:,:,p0:p1] += vj1.transpose(0,2,1)
            t1 = log.timer('contracting cint2e_ipvip1_sph for atom %d'%ia, *t1)

        if xctype == 'LDA':
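            # second-order XC response: fxc contracted with the density derivative of atom ia,
            # plus first-order vxc terms restricted to its AO block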
            ao_deriv = 1
            vj1[:] = 0
            for ao, mask, weight, coords \
                    in ni.block_loop(mol, grids, nao, ao_deriv, max_memory, ni.non0tab):
                rho = ni.eval_rho2(mol, ao[0], mo_coeff, mo_occ, mask, 'LDA')
                vxc, fxc = ni.eval_xc(mf.xc, rho, 0, deriv=2)[1:3]
                vrho = vxc[0]
                frr = fxc[0]
                half = pyscf.lib.dot(ao[0], dm0[:,p0:p1].copy())
                rho1 = numpy.einsum('xpi,pi->xp', ao[1:,:,p0:p1], half)
                aow = numpy.einsum('pi,xp->xpi', ao[0], weight*frr*rho1)
                veff2[0] += pyscf.lib.dot(ao[1].T, aow[0]) * 2
                veff2[1] += pyscf.lib.dot(ao[1].T, aow[1]) * 2
                veff2[2] += pyscf.lib.dot(ao[1].T, aow[2]) * 2
                veff2[3] += pyscf.lib.dot(ao[2].T, aow[0]) * 2
                veff2[4] += pyscf.lib.dot(ao[2].T, aow[1]) * 2
                veff2[5] += pyscf.lib.dot(ao[2].T, aow[2]) * 2
                veff2[6] += pyscf.lib.dot(ao[3].T, aow[0]) * 2
                veff2[7] += pyscf.lib.dot(ao[3].T, aow[1]) * 2
                veff2[8] += pyscf.lib.dot(ao[3].T, aow[2]) * 2
                aow = numpy.einsum('xpi,p->xpi', ao[1:,:,p0:p1], weight*vrho)
                vj1[0] += pyscf.lib.dot(aow[0].T, ao[1])
                vj1[1] += pyscf.lib.dot(aow[0].T, ao[2])
                vj1[2] += pyscf.lib.dot(aow[0].T, ao[3])
                vj1[3] += pyscf.lib.dot(aow[1].T, ao[1])
                vj1[4] += pyscf.lib.dot(aow[1].T, ao[2])
                vj1[5] += pyscf.lib.dot(aow[1].T, ao[3])
                vj1[6] += pyscf.lib.dot(aow[2].T, ao[1])
                vj1[7] += pyscf.lib.dot(aow[2].T, ao[2])
                vj1[8] += pyscf.lib.dot(aow[2].T, ao[3])
                half = aow = None

            veff2[:,:,p0:p1] += vj1.transpose(0,2,1)

        elif xctype == 'GGA':
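            # get_wv builds the weighted first-order GGA response: frr/frg/fgg contracted with
            # the perturbed density rho1, plus the vgamma term acting on the perturbed gradient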
            def get_wv(rho, rho1, weight, vxc, fxc):
                vgamma = vxc[1]
                frr, frg, fgg = fxc[:3]
                ngrid = weight.size
                sigma1 = numpy.einsum('xi,xi->i', rho[1:], rho1[1:])
                wv = numpy.empty((4,ngrid))
                wv[0]  = frr * rho1[0]
                wv[0] += frg * sigma1 * 2
                wv[1:]  = (fgg * sigma1 * 4 + frg * rho1[0] * 2) * rho[1:]
                wv[1:] += vgamma * rho1[1:] * 2
                wv *= weight
                return wv
            ao_deriv = 2
            vj1[:] = 0
            for ao, mask, weight, coords \
                    in ni.block_loop(mol, grids, nao, ao_deriv, max_memory, ni.non0tab):
                rho = ni.eval_rho2(mol, ao[:4], mo_coeff, mo_occ, mask, 'GGA')
                vxc, fxc = ni.eval_xc(mf.xc, rho, 0, deriv=2)[1:3]
                vrho, vgamma = vxc[:2]
                # (d_X \nabla_x mu) nu DM_{mu,nu}
                half = pyscf.lib.dot(ao[0], dm0[:,p0:p1].copy())
                rho1X = numpy.einsum('xpi,pi->xp', ao[[1,XX,XY,XZ],:,p0:p1], half)
                rho1Y = numpy.einsum('xpi,pi->xp', ao[[2,YX,YY,YZ],:,p0:p1], half)
                rho1Z = numpy.einsum('xpi,pi->xp', ao[[3,ZX,ZY,ZZ],:,p0:p1], half)
                # (d_X mu) (\nabla_x nu) DM_{mu,nu}
                half = pyscf.lib.dot(ao[1], dm0[:,p0:p1].copy())
                rho1X[1] += numpy.einsum('pi,pi->p', ao[1,:,p0:p1], half)
                rho1Y[1] += numpy.einsum('pi,pi->p', ao[2,:,p0:p1], half)
                rho1Z[1] += numpy.einsum('pi,pi->p', ao[3,:,p0:p1], half)
                half = pyscf.lib.dot(ao[2], dm0[:,p0:p1].copy())
                rho1X[2] += numpy.einsum('pi,pi->p', ao[1,:,p0:p1], half)
                rho1Y[2] += numpy.einsum('pi,pi->p', ao[2,:,p0:p1], half)
                rho1Z[2] += numpy.einsum('pi,pi->p', ao[3,:,p0:p1], half)
                half = pyscf.lib.dot(ao[3], dm0[:,p0:p1].copy())
                rho1X[3] += numpy.einsum('pi,pi->p', ao[1,:,p0:p1], half)
                rho1Y[3] += numpy.einsum('pi,pi->p', ao[2,:,p0:p1], half)
                rho1Z[3] += numpy.einsum('pi,pi->p', ao[3,:,p0:p1], half)

                wv = get_wv(rho, rho1X, weight, vxc, fxc) * 2  # ~ vj1*2
                aow = numpy.einsum('npi,np->pi', ao[[1,XX,XY,XZ]], wv)  # dX
                veff2[0] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[[2,YX,YY,YZ]], wv)  # dY
                veff2[3] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[[3,ZX,ZY,ZZ]], wv)  # dZ
                veff2[6] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[1:4], wv[1:4])
                veff2[0] += pyscf.lib.dot(ao[1].T, aow)
                veff2[3] += pyscf.lib.dot(ao[2].T, aow)
                veff2[6] += pyscf.lib.dot(ao[3].T, aow)
                wv = get_wv(rho, rho1Y, weight, vxc, fxc) * 2
                aow = numpy.einsum('npi,np->pi', ao[[1,XX,XY,XZ]], wv)
                veff2[1] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[[2,YX,YY,YZ]], wv)
                veff2[4] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[[3,ZX,ZY,ZZ]], wv)
                veff2[7] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[1:4], wv[1:4])
                veff2[1] += pyscf.lib.dot(ao[1].T, aow)
                veff2[4] += pyscf.lib.dot(ao[2].T, aow)
                veff2[7] += pyscf.lib.dot(ao[3].T, aow)
                wv = get_wv(rho, rho1Z, weight, vxc, fxc) * 2
                aow = numpy.einsum('npi,np->pi', ao[[1,XX,XY,XZ]], wv)
                veff2[2] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[[2,YX,YY,YZ]], wv)
                veff2[5] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[[3,ZX,ZY,ZZ]], wv)
                veff2[8] += pyscf.lib.dot(aow.T, ao[0])
                aow = numpy.einsum('npi,np->pi', ao[1:4], wv[1:4])
                veff2[2] += pyscf.lib.dot(ao[1].T, aow)
                veff2[5] += pyscf.lib.dot(ao[2].T, aow)
                veff2[8] += pyscf.lib.dot(ao[3].T, aow)

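                # unperturbed vrho/vgamma weights contracted with derivative-AO pairs whose rows
                # are restricted to atom ia's block (both orderings accumulated into vj1)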
                wv = numpy.empty_like(rho)
                wv[0]  = weight * vrho * .5
                wv[1:] = rho[1:] * (weight * vgamma * 2)
                aowx = numpy.einsum('npi,np->pi', ao[[1,XX,XY,XZ]], wv)
                aowy = numpy.einsum('npi,np->pi', ao[[2,YX,YY,YZ]], wv)
                aowz = numpy.einsum('npi,np->pi', ao[[3,ZX,ZY,ZZ]], wv)
                ao1 = aowx[:,p0:p1].T.copy()
                ao2 = aowy[:,p0:p1].T.copy()
                ao3 = aowz[:,p0:p1].T.copy()
                vj1[0] += pyscf.lib.dot(ao1, ao[1])
                vj1[1] += pyscf.lib.dot(ao1, ao[2])
                vj1[2] += pyscf.lib.dot(ao1, ao[3])
                vj1[3] += pyscf.lib.dot(ao2, ao[1])
                vj1[4] += pyscf.lib.dot(ao2, ao[2])
                vj1[5] += pyscf.lib.dot(ao2, ao[3])
                vj1[6] += pyscf.lib.dot(ao3, ao[1])
                vj1[7] += pyscf.lib.dot(ao3, ao[2])
                vj1[8] += pyscf.lib.dot(ao3, ao[3])
                ao1 = ao[1,:,p0:p1].T.copy()
                ao2 = ao[2,:,p0:p1].T.copy()
                ao3 = ao[3,:,p0:p1].T.copy()
                vj1[0] += pyscf.lib.dot(ao1, aowx)
                vj1[1] += pyscf.lib.dot(ao1, aowy)
                vj1[2] += pyscf.lib.dot(ao1, aowz)
                vj1[3] += pyscf.lib.dot(ao2, aowx)
                vj1[4] += pyscf.lib.dot(ao2, aowy)
                vj1[5] += pyscf.lib.dot(ao2, aowz)
                vj1[6] += pyscf.lib.dot(ao3, aowx)
                vj1[7] += pyscf.lib.dot(ao3, aowy)
                vj1[8] += pyscf.lib.dot(ao3, aowz)

            veff2[:,:,p0:p1] += vj1.transpose(0,2,1)

        else:
            raise NotImplementedError('meta-GGA')

        for j0, ja in enumerate(atmlst):
            q0, q1 = offsetdic[ja][2:]
# *2 for double occupancy, *2 for +c.c.
            mo1  = pyscf.lib.chkfile.load(hess_mf.chkfile, 'scf_mo1/%d'%ja)
            h1ao = pyscf.lib.chkfile.load(hess_mf.chkfile, 'scf_h1ao/%d'%ia)
            dm1 = numpy.einsum('ypi,qi->ypq', mo1, mocc)
            de  = numpy.einsum('xpq,ypq->xy', h1ao, dm1) * 4
            dm1 = numpy.einsum('ypi,qi,i->ypq', mo1, mocc, mo_energy[:nocc])
            de -= numpy.einsum('xpq,ypq->xy', s1ao, dm1) * 4
            de -= numpy.einsum('xpq,ypq->xy', s1oo, e1s[j0]) * 2

            de = de.reshape(-1)
            v2aa = rinv2aa[str(ja)].value
            v2ab = rinv2ab[str(ja)].value
            de += numpy.einsum('xpq,pq->x', v2aa[:,p0:p1], dm0[p0:p1])*2
            de += numpy.einsum('xpq,pq->x', v2ab[:,p0:p1], dm0[p0:p1])*2
            de += numpy.einsum('xpq,pq->x', h_2[:,:,q0:q1], dm0[:,q0:q1])*2
            de += numpy.einsum('xpq,pq->x', veff2[:,q0:q1], dm0[q0:q1])*2
            de -= numpy.einsum('xpq,pq->x', s1ab[:,p0:p1,q0:q1], dme0[p0:p1,q0:q1])*2

            if ia == ja:
                de += numpy.einsum('xpq,pq->x', h1aa[:,p0:p1], dm0[p0:p1])*2
                de -= numpy.einsum('xpq,pq->x', v2aa, dm0)*2
                de -= numpy.einsum('xpq,pq->x', v2ab, dm0)*2
                de += numpy.einsum('xpq,pq->x', veff1ii[:,p0:p1], dm0[p0:p1])*2
                de -= numpy.einsum('xpq,pq->x', s1aa[:,p0:p1], dme0[p0:p1])*2

            de2[i0,j0] = de.reshape(3,3)

    frinv.close()
    log.timer('RHF hessian', *time0)
    return de2
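
Note on the index trick used above: veff1ii += vj1[[0,1,2,1,3,4,2,4,5]] (and the matching veff2 updates) packs the six unique Cartesian second-derivative components xx, xy, xz, yy, yz, zz into the full row-major 3x3 layout. A minimal standalone sketch of that fancy-indexing pattern, using made-up data rather than anything from the pyscf source:

import numpy

# six unique components of a symmetric second-derivative tensor,
# ordered xx, xy, xz, yy, yz, zz (in the hessian code each entry is an (nao, nao) block)
unique = numpy.arange(6)

# expand to the full row-major 3x3 ordering xx, xy, xz, yx, yy, yz, zx, zy, zz
full = unique[[0, 1, 2, 1, 3, 4, 2, 4, 5]]

print(full.reshape(3, 3))
# [[0 1 2]
#  [1 3 4]
#  [2 4 5]]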