jax.numpy.clip

Here are examples of the Python API jax.numpy.clip, collected from open source projects.

63 Examples
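Before the project examples, a minimal sketch of the basic call pattern, assuming only standard jax.numpy behaviour: jnp.clip mirrors numpy.clip, limiting array values to a range, and either bound may be None to clip on one side only. The keyword names for the bounds have varied across JAX releases (a_min/a_max in older versions, min/max in newer ones), so the sketch passes them positionally; the values in the comments are illustrative.

import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.3, 1.7])

# Clip to [-1, 1], the pattern most of the examples below use.
jnp.clip(x, -1.0, 1.0)    # [-1. , -0.5,  0.3,  1. ]

# One-sided clipping: pass None for the bound that should stay open.
jnp.clip(x, 0.0, None)    # [ 0. ,  0. ,  0.3,  1.7]

# clip is traceable and differentiable, so it composes with jit and grad;
# the gradient is zero wherever a value was clipped.
jax.grad(lambda v: jnp.sum(jnp.clip(v, -1.0, 1.0) ** 2))(x)    # [ 0. , -1. ,  0.6,  0. ]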

3 Source : other.py
with MIT License
from bmazoure

def action_noise(action, amount, discrete):
    if amount == 0:
        return action
    if discrete:
        probs = amount / action.shape[-1] + (1 - amount) * action
        return dists.OneHotDist(probs=probs).sample()
    else:
        return jnp.clip(tfd.Normal(action, amount).sample(), -1, 1)


class RollingNorm():

3 Source : utils.py
with Apache License 2.0
from deepmind

def linf_project_fn(epsilon: float, bounds: Tuple[float, float]) -> ProjectFn:

  def project_fn(x, origin_x):
    dx = jnp.clip(x - origin_x, -epsilon, epsilon)
    return jnp.clip(origin_x + dx, bounds[0], bounds[1])

  return project_fn


def bounded_initialize_fn(

3 Source : linear_bound_utils.py
with Apache License 2.0
from deepmind

  def project_params(self, relax_params: Nest[Tensor]) -> Nest[Tensor]:
    lower_relax_params, upper_relax_params = relax_params
    return (
        jnp.clip(lower_relax_params, 0., 1.),
        jnp.clip(upper_relax_params, 0., 1.))


_parameterized_posbilinear_relaxer = functools.partial(

3 Source : quaternion.py
with Apache License 2.0
from google

def safe_acos(t, eps=1e-7):
  """A safe version of arccos which avoids evaluating at -1 or 1."""
  return jnp.arccos(jnp.clip(t, -1.0 + eps, 1.0 - eps))


def im(q):

3 Source : utils.py
with Apache License 2.0
from google

def clip_gradients(grad, grad_max_val=0.0, grad_max_norm=0.0, eps=1e-7):
  """Gradient clipping."""
  # Clip the gradient by value.
  if grad_max_val > 0:
    clip_fn = lambda z: jnp.clip(z, -grad_max_val, grad_max_val)
    grad = jax.tree_util.tree_map(clip_fn, grad)

  # Clip the (possibly value-clipped) gradient by norm.
  if grad_max_norm > 0:
    grad_norm = safe_sqrt(
        jax.tree_util.tree_reduce(
            lambda x, y: x + jnp.sum(y**2), grad, initializer=0))
    mult = jnp.minimum(1, grad_max_norm / (eps + grad_norm))
    grad = jax.tree_util.tree_map(lambda z: mult * z, grad)

  return grad


def matmul(a, b):

3 Source : rnn_mlp_lopt.py
with Apache License 2.0
from google

  def _normalize(self, state: _DynamicGradientClipperState,
                 grads: opt_base.Params) -> opt_base.Params:
    t, snd = state.iteration, state.value
    clip_amount = (snd / (1 - self.alpha**t)) * self.clip_mult
    summary.summary("dynamic_grad_clip", clip_amount)

    return jax.tree_map(lambda g: jnp.clip(g, -clip_amount, clip_amount), grads)

  def next_state_and_normalize(

3 Source : image_mlp.py
with Apache License 2.0
from google

  def normalizer(self, loss):
    num_classes = self.datasets.extra_info["num_classes"]
    maxval = 1.5 * onp.log(num_classes)
    loss = jnp.clip(loss, 0, maxval)
    return jnp.nan_to_num(loss, nan=maxval, posinf=maxval, neginf=maxval)


@gin.configurable

3 Source : gpt2.py
with Apache License 2.0
from google

def logit(x):
    x = np.clip(x, 1e-5, 1 - 1e-5)
    return np.log(x / (1 - x))


# Normalization layer used in the transformer
class Norm(objax.module.Module):

3 Source : train_utils.py
with Apache License 2.0
from google

def softmax_xent(*, logits, labels, reduction=True, kl=False):
  """Computes a softmax cross-entropy (Categorical NLL) loss over examples."""
  log_p = jax.nn.log_softmax(logits)
  nll = -jnp.sum(labels * log_p, axis=-1)
  if kl:
    nll += jnp.sum(labels * jnp.log(jnp.clip(labels, 1e-8)), axis=-1)
  return jnp.mean(nll) if reduction else nll


def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps):

3 Source : generation_flax_logits_process.py
with Apache License 2.0
from huggingface

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:

        # create boolean flag to decide if min length penalty should be applied
        apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

        scores = jnp.where(
            apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores
        )

        return scores

3 Source : utils.py
with MIT License
from Hwhitetooth

def scalar_to_two_hot(x: chex.Array, num_bins: int):
    """A categorical representation of real values. Ref: https://www.nature.com/articles/s41586-020-03051-4.pdf."""
    max_val = (num_bins - 1) // 2
    x = jnp.clip(x, -max_val, max_val)
    x_low = jnp.floor(x).astype(jnp.int32)
    x_high = jnp.ceil(x).astype(jnp.int32)
    p_high = x - x_low
    p_low = 1. - p_high
    idx_low = x_low + max_val
    idx_high = x_high + max_val
    cat_low = jax.nn.one_hot(idx_low, num_bins) * p_low[..., None]
    cat_high = jax.nn.one_hot(idx_high, num_bins) * p_high[..., None]
    return cat_low + cat_high


def logits_to_scalar(logits: chex.Array):

3 Source : optim.py
with MIT License
from ku2482

def clip_gradient(
    grad: Any,
    max_value: float,
) -> Any:
    """
    Clip gradients.
    """
    return jax.tree_map(lambda g: jnp.clip(g, -max_value, max_value), grad)


@jax.jit

3 Source : preprocess.py
with MIT License
from ku2482

def add_noise(
    x: jnp.ndarray,
    key: jnp.ndarray,
    std: float,
    out_min: float = -np.inf,
    out_max: float = np.inf,
    noise_min: float = -np.inf,
    noise_max: float = np.inf,
) -> jnp.ndarray:
    """
    Add noise to actions.
    """
    noise = jnp.clip(jax.random.normal(key, x.shape), noise_min, noise_max)
    return jnp.clip(x + noise * std, out_min, out_max)


@jax.jit

3 Source : denoise_tv_iso_pgm.py
with BSD 3-Clause "New" or "Revised" License
from lanl

    def __call__(self, x: Union[JaxArray, BlockArray]) -> float:

        xint = self.y - self.lmbda * self.A(x)
        return -1.0 * self.functional(xint - jnp.clip(xint, 0.0, 1.0)) + self.functional(xint)


"""

3 Source : _balloon_lung.py
with Apache License 2.0
from MinRegret

def PropValve(x):
    y = 3.0 * x
    flow_new = 1.0 * (jnp.tanh(0.03 * (y - 130)) + 1.0)
    flow_new = jnp.clip(flow_new, 0.0, 1.72)
    return flow_new


def Solenoid(x):

3 Source : _learned_lung.py
with Apache License 2.0
from MinRegret

    def step(self, action):
        self.state = self.dynamics(self.state, action)
        self.pressure = (self.state['normalized_pressures'][-1] * self.pressure_std) + self.pressure_mean
        self.pressure = jnp.clip(self.pressure, 0.0, 100.0)

        self.target = self.waveform.at(self.time)
        reward = -jnp.abs(self.target - self.pressure)

        self.time += self.dt

        return self.observation, reward, False, {}

3 Source : model.py
with MIT License
from NTT123

    def p_mean_variance(self, x, t, clip_denoised: bool):
        x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn(x, t))

        if clip_denoised:
            x_recon = jnp.clip(x_recon, -1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, x, t, rng_key, clip_denoised=True, repeat_noise=False):

3 Source : numpy_ops.py
with GNU General Public License v3.0
from PKU-NIP-Lab

def clip(a, a_min=None, a_max=None):
  a = _remove_jaxarray(a)
  a_min = _remove_jaxarray(a_min)
  a_max = _remove_jaxarray(a_max)
  return JaxArray(jnp.clip(a, a_min, a_max))


def angle(z, deg=False):

3 Source : acquisitions.py
with Apache License 2.0
from PredictiveIntelligenceLab

def EI(mean, std, best):
    # from https://people.orie.cornell.edu/pfrazier/Presentations/2011.11.INFORMS.Tutorial.pdf
    delta = -(mean - best)
    deltap = -(mean - best)
    deltap = np.clip(deltap, a_min=0.)
    Z = delta/std
    EI = deltap - np.abs(deltap)*norm.cdf(-Z) + std*norm.pdf(Z)
    return -EI[0]

@jit

3 Source : acquisitions.py
with Apache License 2.0
from PredictiveIntelligenceLab

def EIC(mean, std, best):
    # Constrained expected improvement
    delta = -(mean[0,:] - best)
    deltap = -(mean[0,:] - best)
    deltap = np.clip(deltap, a_min=0.)
    Z = delta/std[0,:]
    EI = deltap - np.abs(deltap)*norm.cdf(-Z) + std*norm.pdf(Z)
    constraints = np.prod(norm.cdf(mean[1:,:]/std[1:,:]), axis = 0)
    return -EI[0]*constraints[0]

@jit

3 Source : ops.py
with Apache License 2.0
from pyro-ppl

def _max(x, y):
    return np.clip(x, a_min=y, a_max=None)


# TODO: replace (int, float) by numbers.Number
@ops.min.register((int, float), array)

3 Source : discrete.py
with Apache License 2.0
from pyro-ppl

    def log_prob(self, value):
        log_factorial_n = gammaln(self.total_count + 1)
        log_factorial_k = gammaln(value + 1)
        log_factorial_nmk = gammaln(self.total_count - value + 1)
        normalize_term = (
            self.total_count * jnp.clip(self.logits, 0)
            + xlog1py(self.total_count, jnp.exp(-jnp.abs(self.logits)))
            - log_factorial_n
        )
        return (
            value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
        )

    @lazy_property

3 Source : flows.py
with Apache License 2.0
from pyro-ppl

def _clamp_preserve_gradients(x, min, max):
    return x + lax.stop_gradient(jnp.clip(x, a_min=min, a_max=max) - x)


# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py
class InverseAutoregressiveTransform(Transform):

3 Source : util.py
with Apache License 2.0
from pyro-ppl

def binary_cross_entropy_with_logits(x, y):
    # compute -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x))
    # Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
    return jnp.clip(x, 0) + jnp.log1p(jnp.exp(-jnp.abs(x))) - x * y


def _reshape(x, shape):

3 Source : optim.py
with Apache License 2.0
from pyro-ppl

    def update(self, g, state):
        i, opt_state = state
        # clip norm
        g = tree_map(
            lambda g_: jnp.clip(g_, a_min=-self.clip_norm, a_max=self.clip_norm), g
        )
        opt_state = self.update_fn(i, g, opt_state)
        return i + 1, opt_state


@_add_doc(optimizers.adagrad)

3 Source : generation_flax_logits_process.py
with Apache License 2.0
from UKPLab

    def __call__(
        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
    ) -> jax_xla.DeviceArray:

        # create boolean flag to decide if min length penalty should be applied
        apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

        scores = jnp.where(
            apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores
        )

        return scores

0 Source : optim.py
with MIT License
from alonfnt

def suggest_next(
    key: Array,
    params: GParameters,
    x: Array,
    y: Array,
    bounds: Array,
    dtypes: DataTypes,
    acq: Callable,
    n_seed: int = 1000,
    lr: float = 0.1,
    n_epochs: int = 150,
) -> Tuple[Array, Array]:
    """
    Suggests the new point to sample by optimizing the acquisition function.

    Parameters:
    -----------
    key: The pseudo-random generator key used for jax random functions.
    params: Hyperparameters of the Gaussian Process Regressor.
    x: Sampled points.
    y: Sampled targets.
    bounds: Array of (2, dim) shape with the lower and upper bounds of the
            variables.
    y_max: The current maximum value of the target values Y.
    dtypes: The type of non-real variables in the target function.
    n_seed (optional): the number of points to probe and minimize until
            finding the one that maximizes the acquisition functions.
    lr (optional): The step size of the gradient descent.
    n_epochs (optional): The number of steps done on the descent to minimize
            the seeds.


    Returns:
    --------
    A tuple with the parameters that maximize the acquisition function and a
    jax PRNGKey to be used in the next sampling.
    """

    key1, key2 = random.split(key, 2)
    dim = bounds.shape[0]

    domain = random.uniform(
        key1, shape=(n_seed, dim), minval=bounds[:, 0], maxval=bounds[:, 1]
    )

    _acq = partial(acq, params=params, x=x, y=y, dtypes=dtypes)

    J = jacobian(lambda x: _acq(x.reshape(-1, dim)).reshape())
    HS = vmap(lambda x: x + lr * J(x))

    domain = lax.fori_loop(0, n_epochs, lambda _, d: HS(d), domain)
    domain = jnp.clip(
        domain.reshape(-1, dim), a_min=bounds[:, 0], a_max=bounds[:, 1]
    )
    domain = replace_nan_values(domain)
    domain = round_integers(domain, dtypes)

    ys = _acq(domain)
    next_X = domain[ys.argmax()]
    return next_X, key2


@partial(jit, static_argnums=(1, 2))

0 Source : dists.py
with MIT License
from bmazoure

    def sample(self, *args, **kwargs):
        event = super().sample(*args, **kwargs)
        if self._clip:
            clipped = jnp.clip(event, a_min=self.low + self._clip,
                                      a_max=self.high - self._clip)
            event = event - jax.lax.stop_gradient(event) + jax.lax.stop_gradient(clipped)
        if self._mult:
            event *= self._mult
        return event

# class TanhNormalDist():
#     def __init__(self, dist):
#         self.dist = dist
    
#     def entropy(self, *args, **kwargs):
#         import ipdb;ipdb.set_trace()
#         return self.dist.entropy(*args, **kwargs)
    
#     def mode(self, *args, **kwargs):
#         return self.dist.mode(*args, **kwargs)

#     @property
#     def name(self):
#         return "tanh_normal"

#     @property
#     def dtype(self):
#         return jnp.float32

0 Source : crossentropy.py
with MIT License
from cgarciae

def crossentropy(
    target: jnp.ndarray,
    preds: jnp.ndarray,
    *,
    binary: bool = False,
    from_logits: bool = True,
    label_smoothing: tp.Optional[float] = None,
    check_bounds: bool = True,
) -> jnp.ndarray:

    n_classes = preds.shape[-1]

    if target.ndim == preds.ndim - 1:
        if target.shape != preds.shape[:-1]:
            raise ValueError(
                f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
            )
        target = jax.nn.one_hot(target, n_classes)
    else:
        if target.ndim != preds.ndim:
            raise ValueError(
                f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
            )

    if label_smoothing is not None:
        target = optax.smooth_labels(target, label_smoothing)

    if from_logits:
        if binary:
            loss = optax.sigmoid_binary_cross_entropy(preds, target).mean(axis=-1)
        else:
            loss = optax.softmax_cross_entropy(preds, target)
    else:
        preds = jnp.clip(preds, types.EPSILON, 1.0 - types.EPSILON)

        if binary:
            loss = target * jnp.log(preds)  # + types.EPSILON)
            loss += (1 - target) * jnp.log(1 - preds)  # + types.EPSILON)
            loss = -loss.mean(axis=-1)
        else:
            loss = -(target * jnp.log(preds)).sum(axis=-1)

    # TODO: implement check_bounds
    # if check_bounds:
    #     # set NaN where target is negative or larger/equal to the number of preds channels
    #     loss = jnp.where(target < 0, jnp.nan, loss)
    #     loss = jnp.where(target >= n_classes, jnp.nan, loss)

    return loss


class Crossentropy(Loss):

0 Source : augmentations.py
with Apache License 2.0
from deepmind

def _color_transform_single_image(image, rng, brightness, contrast, saturation,
                                  hue, to_grayscale_prob, color_jitter_prob,
                                  apply_prob, shuffle):
  """Applies color jittering to a single image."""
  apply_rng, transform_rng = jax.random.split(rng)
  perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
      transform_rng, 7)

  # Whether the transform should be applied at all.
  should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
  # Whether to apply grayscale transform.
  should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
  # Whether to apply color jittering.
  should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob

  # Decorator to conditionally apply fn based on an index.
  def _make_cond(fn, idx):

    def identity_fn(unused_rng, x):
      return x

    def cond_fn(args, i):
      def clip(args):
        return jax.tree_map(lambda arg: jnp.clip(arg, 0., 1.), args)
      out = jax.lax.cond(should_apply & should_apply_color & (i == idx), args,
                         lambda a: clip(fn(*a)), args,
                         lambda a: identity_fn(*a))
      return jax.lax.stop_gradient(out)

    return cond_fn

  random_brightness = functools.partial(
      pix.random_brightness, max_delta=brightness)
  random_contrast = functools.partial(
      pix.random_contrast, lower=1-contrast, upper=1+contrast)
  random_hue = functools.partial(pix.random_hue, max_delta=hue)
  random_saturation = functools.partial(
      pix.random_saturation, lower=1-saturation, upper=1+saturation)
  to_grayscale = functools.partial(pix.rgb_to_grayscale, keep_dims=True)

  random_brightness_cond = _make_cond(random_brightness, idx=0)
  random_contrast_cond = _make_cond(random_contrast, idx=1)
  random_saturation_cond = _make_cond(random_saturation, idx=2)
  random_hue_cond = _make_cond(random_hue, idx=3)

  def _color_jitter(x):
    if shuffle:
      order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
    else:
      order = range(4)
    for idx in order:
      if brightness > 0:
        x = random_brightness_cond((b_rng, x), idx)
      if contrast > 0:
        x = random_contrast_cond((c_rng, x), idx)
      if saturation > 0:
        x = random_saturation_cond((s_rng, x), idx)
      if hue > 0:
        x = random_hue_cond((h_rng, x), idx)
    return x

  out_apply = _color_jitter(image)
  out_apply = jax.lax.cond(should_apply & should_apply_gs, out_apply,
                           to_grayscale, out_apply, lambda x: x)
  return jnp.clip(out_apply, 0., 1.)


def random_flip(images, rng):

0 Source : attacks.py
with Apache License 2.0
from deepmind

def adversarial_attack(
    params: ModelParams,
    data_spec: DataSpec,
    spec_type: verify_utils.SpecType,
    key: PRNGKey,
    num_steps: int,
    learning_rate: float,
    num_samples: int = 1,
) -> float:
  """Adversarial attack on uncertainty spec (with parameter sampling)."""
  l = jnp.clip(data_spec.input-data_spec.epsilon,
               data_spec.input_bounds[0], data_spec.input_bounds[1])
  u = jnp.clip(data_spec.input+data_spec.epsilon,
               data_spec.input_bounds[0], data_spec.input_bounds[1])
  projection_fn = lambda x: jnp.clip(x, l, u)

  forward_fn = make_forward(params, num_samples)

  def max_objective_fn_uncertainty(x, prng_key):
    logits = jnp.reshape(forward_fn(x, prng_key), [-1])
    return logits[data_spec.target_label]

  def max_objective_fn_adversarial(x, prng_key):
    logits = jnp.reshape(forward_fn(x, prng_key), [-1])
    return logits[data_spec.target_label] - logits[data_spec.true_label]

  def max_objective_fn_adversarial_softmax(x, prng_key):
    logits = jnp.reshape(forward_fn(x, prng_key), [-1])
    probs = jax.nn.softmax(logits, axis=-1)
    return probs[data_spec.target_label] - probs[data_spec.true_label]

  if (spec_type in (verify_utils.SpecType.UNCERTAINTY,
                    verify_utils.SpecType.PROBABILITY_THRESHOLD)):
    max_objective_fn = max_objective_fn_uncertainty
  elif spec_type == verify_utils.SpecType.ADVERSARIAL:
    max_objective_fn = max_objective_fn_adversarial
  elif spec_type == verify_utils.SpecType.ADVERSARIAL_SOFTMAX:
    max_objective_fn = max_objective_fn_adversarial_softmax
  else:
    raise ValueError('Unsupported spec.')

  return _run_attack(
      max_objective_fn=max_objective_fn,
      projection_fn=projection_fn,
      x_init=data_spec.input,
      prng_key=key,
      num_steps=num_steps,
      learning_rate=learning_rate)

0 Source : linear_bound_utils.py
with Apache License 2.0
from deepmind

  def project_params(self, relax_params: Nest[Tensor]) -> Nest[Tensor]:
    return jax.tree_map(lambda x: jnp.clip(x, 0., 1.), relax_params)


def eltwise_linfun_from_coeff(slope: Tensor, offset: Tensor) -> LinFun:

0 Source : linear_bound_utils.py
with Apache License 2.0
from deepmind

  def project_params(self, relax_params: Nest[Tensor]) -> Nest[Tensor]:
    return jnp.clip(relax_params, 0., 1.)


def _parameterized_relu_relaxer(

0 Source : clipping.py
with Apache License 2.0
from deepmind

def clip(max_delta: chex.Numeric) -> base.GradientTransformation:
  """Clips updates element-wise, to be in ``[-max_delta, +max_delta]``.

  Args:
    max_delta: The maximum absolute value for each element in the update.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return ClipState()

  def update_fn(updates, state, params=None):
    del params
    updates = jax.tree_map(lambda g: jnp.clip(g, -max_delta, max_delta),
                           updates)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)


def clip_by_block_rms(threshold: float) -> base.GradientTransformation:

0 Source : schedule.py
with Apache License 2.0
from deepmind

def polynomial_schedule(
    init_value: chex.Scalar,
    end_value: chex.Scalar,
    power: chex.Scalar,
    transition_steps: int,
    transition_begin: int = 0
) -> base.Schedule:
  """Constructs a schedule with polynomial transition from init to end value.

  Args:
    init_value: initial value for the scalar to be annealed.
    end_value: end value of the scalar to be annealed.
    power: the power of the polynomial used to transition from init to end.
    transition_steps: number of steps over which annealing takes place,
      the scalar starts changing at `transition_begin` steps and completes
      the transition by `transition_begin + transition_steps` steps.
      If `transition_steps <= 0`, then the entire annealing process is disabled
      and the value is held fixed at `init_value`.
    transition_begin: must be positive. After how many steps to start annealing
      (before this many steps the scalar value is held fixed at `init_value`).

  Returns:
    schedule: A function that maps step counts to values.
  """
  if transition_steps <= 0:
    logging.info(
        'A polynomial schedule was set with a non-positive `transition_steps` '
        'value; this results in a constant schedule with value `init_value`.')
    return lambda count: init_value

  if transition_begin < 0:
    logging.info(
        'An exponential schedule was set with a negative `transition_begin` '
        'value; this will result in `transition_begin` falling back to `0`.')
    transition_begin = 0

  def schedule(count):
    count = jnp.clip(count - transition_begin, 0, transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * (frac**power) + end_value
  return schedule


# Alias polynomial schedule to linear schedule for convenience.
def linear_schedule(

0 Source : interprenet.py
with Apache License 2.0
from FINRAOS

def clip(x, eps=2 ** -16):
    return jax.numpy.clip(x, eps, 1 - eps)


@public.add

0 Source : api_test.py
with Apache License 2.0
from google

  def test_clip_gradient(self):
    # https://github.com/google/jax/issues/2784
    @api.custom_vjp
    def _clip_gradient(lo, hi, x):
      return x  # identity function when not differentiating

    def clip_gradient_fwd(lo, hi, x):
      return x, (lo, hi,)

    def clip_gradient_bwd(res, g):
      lo, hi = res
      return (None, None, jnp.clip(g, lo, hi),)

    _clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

    def clip_gradient(x):
      lo = -0.1
      hi = x + 0.1
      return _clip_gradient(lo, hi, x)

    g = jax.grad(clip_gradient)(0.1)  # doesn't crash
    self.assertAllClose(g, jnp.array(0.2))

  def test_nestable_vjp(self):

0 Source : base.py
with Apache License 2.0
from google

  def update(self, opt_state, grad, *args, **kwargs):
    grad = jax.tree_map(lambda x: jnp.clip(x, -self.grad_clip, self.grad_clip),
                        grad)
    return self.opt.update(opt_state, grad, *args, **kwargs)

0 Source : image_mlp_ae.py
with Apache License 2.0
from google

def _make_task(hk_fn: LossFN, datasets: datasets_base.Datasets) -> base.Task:
  """Make a Task subclass for the haiku loss and datasets."""
  init_net, apply_net = hk.transform(hk_fn)

  class _Task(base.Task):
    """Annonomous task object with corresponding loss and datasets."""

    def __init__(self):
      self.datasets = datasets

    def init(self, key: PRNGKey) -> base.Params:
      batch = next(datasets.train)
      return init_net(key, batch)

    def loss(self, params, key, data):
      return apply_net(params, key, data)

    def normalizer(self, loss):
      return jnp.clip(loss, .0, 1.)

  return _Task()


@gin.configurable

0 Source : train_utils.py
with Apache License 2.0
from google

def create_learning_rate_schedule(total_steps,
                                  base=0.,
                                  decay_type="linear",
                                  warmup_steps=0,
                                  linear_end=1e-5):
  """Creates a learning rate schedule.

  Currently only warmup + {linear,cosine} but will be a proper mini-language
  like preprocessing one in the future.

  Args:
    total_steps: The total number of steps to run.
    base: The starting learning-rate (without warmup).
    decay_type: 'linear' or 'cosine'.
    warmup_steps: how many steps to warm up for.
    linear_end: Minimum learning rate.

  Returns:
    A function learning_rate(step): float -> {"learning_rate": float}.
  """

  def step_fn(step):
    """Step to learning rate function."""
    lr = base

    progress = (step - warmup_steps) / float(total_steps - warmup_steps)
    progress = jnp.clip(progress, 0.0, 1.0)
    if decay_type == "linear":
      lr = linear_end + (lr - linear_end) * (1.0 - progress)
    elif decay_type == "cosine":
      lr = lr * 0.5 * (1. + jnp.cos(jnp.pi * progress))
    else:
      raise ValueError(f"Unknown lr type {decay_type}")

    if warmup_steps:
      lr = lr * jnp.minimum(1., step / warmup_steps)

    return jnp.asarray(lr, dtype=jnp.float32)

  return step_fn


def get_weight_decay_fn(

0 Source : img_log_utils.py
with Apache License 2.0
from google-research

  def append_images(
      self,
      image_key: Optional[str],
      rgb: f32["h w 3"],
      rgb_ground_truth: f32["h w 3"],
      semantic_logits: Optional[f32["h w c"]],
      semantic_ground_truth: Optional[i32["h w 1"]],
  ):
    """Append a set of images to the log.

    Args:
      image_key: String identifier for this frame. Format is,
        ${SCENE_NAME}_rgba_${IMAGE_NAME}. Use None if this information is
        missing.
      rgb: RGB prediction image. Values must be in [0, 1].
      rgb_ground_truth: RGB ground truth image. Values must be in [0, 1].
      semantic_logits: Optional semantic predictions. Contains per-class
        logits.
      semantic_ground_truth: Optional semantic prediction ground truth.
        Contains integer ID of the ground truth semantic class. Values must be
        in {0...255}.
    """
    self._image_keys.append(image_key)
    self._rgb.append(jax.device_get(jnp.clip(rgb, 0, 1)))
    self._rgb_ground_truth.append(
        jax.device_get(jnp.clip(rgb_ground_truth, 0, 1)))

    if semantic_logits is not None and semantic_logits.shape[-1]:
      # Semantic prediction.
      semantic = (jnp.argmax(semantic_logits, axis=-1)
                  .reshape(semantic_logits.shape[0:-1] + (1,)))
      self._semantic.append(jax.device_get(semantic))

      # Semantic ground truth.
      self._semantic_ground_truth.append(jax.device_get(semantic_ground_truth))

  @property

0 Source : test_modeling_flax_bart.py
with Apache License 2.0
from huggingface

    def prepare_config_and_inputs(self):
        input_ids = jnp.clip(ids_tensor([self.batch_size, self.seq_length], self.vocab_size), 3, self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = BartConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            initializer_range=self.initializer_range,
            use_cache=False,
        )

        return config, input_ids, attention_mask

    def prepare_config_and_inputs_for_common(self):

0 Source : agents.py
with MIT License
from Hwhitetooth

    def mcts(self, rng_key: chex.PRNGKey, params: Params, root: AgentOutput, is_eval: bool):
        num_actions = self._action_space.n
        max_search_depth = self._max_search_depth
        c1 = self._mcts_c1
        c2 = self._mcts_c2
        discount_factor = self._discount_factor

        def simulate(rng_key: chex.PRNGKey, tree: Tree):
            # First compute the minimum and the maximum action-value in the current tree.
            # Note that these statistics are hard to maintain incrementally because they are non-monotonic.
            is_valid = jnp.clip(tree.visit_count, 0, 1)
            action_value = tree.action_value
            q_min = jnp.min(jnp.where(is_valid, action_value, jnp.full_like(action_value, jnp.inf)))
            q_max = jnp.max(jnp.where(is_valid, action_value, jnp.full_like(action_value, -jnp.inf)))
            q_min = jax.lax.select(is_valid.sum() == 0, 0., q_min)
            q_max = jax.lax.select(is_valid.sum() == 0, 0., q_max)

            def _select_action(rng_key: chex.PRNGKey, t, q_mean):
                # Assign an estimated value to the unvisited nodes.
                # See Eq. (8) in https://arxiv.org/pdf/2111.00210.pdf
                # and https://github.com/YeWR/EfficientZero/blob/main/core/ctree/cnode.cpp#L96.
                q = action_value[t]
                q = jax.lax.select(tree.visit_count[t] > 0, q, jnp.full_like(q, q_mean))
                # Normalize the action-values of the current node so that they are in [0, 1].
                # This is required for the pUCT rule.
                # See Eq. (5) in https://www.nature.com/articles/s41586-020-03051-4.pdf
                q = (q - q_min) / jnp.maximum(q_max - q_min, self._q_normalize_epsilon)
                p = tree.prob[t]
                n = tree.visit_count[t]
                # The action scores are computed by the pUCT rule.
                # See Eq. (2) in https://www.nature.com/articles/s41586-020-03051-4.pdf.
                score = q + p * jnp.sqrt(n.sum()) / (1 + n) * (c1 + jnp.log((n.sum() + c2 + 1) / c2))
                best_actions = score >= score.max() - self._child_select_epsilon
                tie_breaking_prob = best_actions / best_actions.sum()
                return jax.random.choice(rng_key, num_actions, p=tie_breaking_prob)

            def _cond(loop_state):
                rng_key, p, a, q_mean = loop_state
                return jnp.logical_and(tree.depth[p] + 1 < max_search_depth, tree.visit_count[p, a] > 0)

            def _body(loop_state):
                rng_key, p, a, q_mean = loop_state
                p = tree.child[p, a]
                is_valid_child = jnp.clip(tree.visit_count[p], 0, 1)
                q_mean = (q_mean + jnp.sum(tree.action_value[p] * is_valid_child)) / (jnp.sum(is_valid_child) + 1)
                rng_key, sub_key = jax.random.split(rng_key)
                a = _select_action(sub_key, p, q_mean)
                return rng_key, p, a, q_mean

            is_valid_child = jnp.clip(tree.visit_count[0], 0, 1)
            q_mean = jnp.sum(tree.action_value[0] * is_valid_child) / jnp.maximum(jnp.sum(is_valid_child), 1)
            rng_key, sub_key = jax.random.split(rng_key)
            a = _select_action(sub_key, 0, q_mean)
            _, p, a, _ = jax.lax.while_loop(
                _cond,
                _body,
                (rng_key, 0, a, q_mean),
            )
            return p, a

        def expand(tree: Tree, p, a, c):
            p_state = tree.state[p]
            model_out = self.model_step(params, p_state, a)
            tree = tree._replace(
                state=tree.state.at[c].set(model_out.state),
                logits=tree.logits.at[c].set(model_out.logits),
                prob=tree.prob.at[c].set(jax.nn.softmax(model_out.logits)),
                reward_logits=tree.reward_logits.at[c].set(model_out.reward_logits),
                reward=tree.reward.at[c].set(model_out.reward),
                value_logits=tree.value_logits.at[c].set(model_out.value_logits),
                value=tree.value.at[c].set(model_out.value),
                depth=tree.depth.at[c].set(tree.depth[p] + 1),
                parent=tree.parent.at[c].set(p),
                parent_action=tree.parent_action.at[c].set(a),
                child=tree.child.at[p, a].set(c),
            )
            return tree

        def backup(tree: Tree, c):
            def _update(tree, c, g):
                g = tree.reward[c] + discount_factor * g
                p = tree.parent[c]
                a = tree.parent_action[c]
                new_n = tree.visit_count[p, a] + 1
                new_q = (tree.action_value[p, a] * tree.visit_count[p, a] + g) / new_n
                tree = tree._replace(
                    visit_count=tree.visit_count.at[p, a].add(1),
                    action_value=tree.action_value.at[p, a].set(new_q),
                )
                return tree, p, g

            tree, _, _ = jax.lax.while_loop(
                lambda t: t[1] > 0,
                lambda t: _update(t[0], t[1], t[2]),
                (tree, c, tree.value[c]),
            )
            return tree

        def body_fn(sim, loop_state):
            rng_key, tree = loop_state
            rng_key, simulate_key = jax.random.split(rng_key)
            p, a = simulate(simulate_key, tree)
            c = sim + 1
            tree = expand(tree, p, a, c)
            tree = backup(tree, c)
            return rng_key, tree

        rng_key, init_key = jax.random.split(rng_key)
        tree = self.init_tree(init_key, root, is_eval)
        rng_key, tree = jax.lax.fori_loop(
            0, self._num_simulations, body_fn, (rng_key, tree))

        return tree

    def act_prob(self, visit_count: chex.Array, temperature: float):

0 Source : policy.py
with MIT License
from ikostrikov

    def __call__(self,
                 observations: jnp.ndarray,
                 temperature: float = 1.0,
                 training: bool = False) -> tfd.Distribution:
        outputs = MLP(self.hidden_dims,
                      activate_final=True,
                      dropout_rate=self.dropout_rate)(observations,
                                                      training=training)

        means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)

        if self.state_dependent_std:
            log_stds = nn.Dense(self.action_dim,
                                kernel_init=default_init(
                                    self.log_std_scale))(outputs)
        else:
            log_stds = self.param('log_stds', nn.initializers.zeros,
                                  (self.action_dim, ))

        log_std_min = self.log_std_min or LOG_STD_MIN
        log_std_max = self.log_std_max or LOG_STD_MAX
        log_stds = jnp.clip(log_stds, log_std_min, log_std_max)

        if not self.tanh_squash_distribution:
            means = nn.tanh(means)

        base_dist = tfd.MultivariateNormalDiag(loc=means,
                                               scale_diag=jnp.exp(log_stds) *
                                               temperature)
        if self.tanh_squash_distribution:
            return tfd.TransformedDistribution(distribution=base_dist,
                                               bijector=tfb.Tanh())
        else:
            return base_dist


@functools.partial(jax.jit, static_argnames=('actor_def', 'distribution'))

0 Source : jax_backend.py
with Apache License 2.0
from LinjianMa

    def clip(tensor, a_min=None, a_max=None, inplace=False):
        return np.clip(tensor, a_min, a_max)

    @staticmethod

0 Source : _pendulum.py
with Apache License 2.0
from MinRegret

    def __init__(self, reward_fn=None, seed=0, horizon=50):
        # self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None

        self.state_size = 2
        self.action_size = 1
        self.action_dim = 1 # redundant with action_size but needed by ILQR
        
        self.H = horizon

        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize
        self.nsamples = 0

        self.random = Random(seed)

        self.reset()
        
        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            th, thdot = state
            g = 10.0
            m = 1.0
            ell = 1.0
            dt = self.dt

            # Do not limit the control signals
            action = jnp.clip(action, -self.max_torque, self.max_torque)

            newthdot = (
                thdot + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 / (m * ell ** 2) * action) * dt
            )
            newth = th + newthdot * dt
            newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

            return jnp.reshape(jnp.array([newth, newthdot]), (2,))
        
        @jax.jit
        def c(x, u):
            # return np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2))
            return angle_normalize(x[0])**2 + .1*(u[0]**2)
        
        self.reward_fn = reward_fn or c
        self.dynamics = _dynamics
        self.f, self.f_x, self.f_u = (
                _dynamics,
                jax.jacfwd(_dynamics, argnums=0),
                jax.jacfwd(_dynamics, argnums=1),
            )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
                c,
                jax.grad(c, argnums=0),
                jax.grad(c, argnums=1),
                jax.hessian(c, argnums=0),
                jax.hessian(c, argnums=1),
            )

    def reset(self):

0 Source : _balloon_lung.py
with Apache License 2.0
from MinRegret

    def dynamics(self, state, action):
        """
        state: (volume, pressure)
        action: (u_in, u_out)
        """
        volume, pressure = state['volume'], state['pressure']
        u_in, u_out = action

        flow = jnp.clip(PropValve(u_in) * self.R, 0.0, 2.0)
        flow -= jax.lax.cond(
            pressure > self.peep_valve,
            lambda x: jnp.clip(Solenoid(u_out), 0.0, 2.0) * 0.05 * pressure,
            lambda x: 0.0,
            flow,
        )

        volume += flow * self.dt
        volume += jax.lax.cond(
            self.leak,
            lambda x: (self.dt / (5.0 + self.dt) * (self.min_volume - volume)),
            lambda x: 0.0,
            0.0,
        )

        r = (3.0 * volume / (4.0 * jnp.pi)) ** (1.0 / 3.0)
        pressure = self.P0 + self.PC * (1.0 - (self.r0 / r) ** 6.0) / (self.r0 ** 2.0 * r)
        # pressure = flow * self.R + volume / self.C + self.peep_valve

        return {'volume': volume, 'pressure': pressure}

    def step(self, action):

0 Source : utils.py
with MIT License
from nikikilbertus

def interp_regular_1d(x: np.ndarray,
                      xmin: float,
                      xmax: float,
                      yp: np.ndarray) -> np.ndarray:
  """One-dimensional linear interpolation.

  Returns the one-dimensional piecewise linear interpolation of the data points
  (xp, yp) evaluated at x. We extrapolate with the constants xmin and xmax
  outside the range [xmin, xmax].

  Args:
    x: The x-coordinates at which to evaluate the interpolated values.
    xmin: The lower bound of the regular input x-coordinate grid.
    xmax: The upper bound of the regular input x-coordinate grid.
    yp: The y coordinates of the data points.

  Returns:
    y: The interpolated values, same shape as x.
  """
  ny = len(yp)
  fractional_idx = (x - xmin) / (xmax - xmin)
  x_idx_unclipped = fractional_idx * (ny - 1)
  x_idx = np.clip(x_idx_unclipped, 0, ny - 1)
  idx_below = np.floor(x_idx)
  idx_above = np.minimum(idx_below + 1, ny - 1)
  idx_below = np.maximum(idx_above - 1, 0)
  y_ref_below = yp[idx_below.astype(np.int32)]
  y_ref_above = yp[idx_above.astype(np.int32)]
  t = x_idx - idx_below
  y = t * y_ref_above + (1 - t) * y_ref_below
  return y


interp1d = jit(vmap(interp_regular_1d, in_axes=(None, None, None, 0)))

0 Source : text2mel.py
with MIT License
from NTT123

def text2mel(
    text: str, lexicon_fn=FLAGS.data_dir / "lexicon.txt", silence_duration: float = -1.0
):
    tokens = text2tokens(text, lexicon_fn)
    durations = predict_duration(tokens)
    durations = jnp.where(
        np.array(tokens)[None, :] == FLAGS.sp_index,
        jnp.clip(durations, a_min=silence_duration, a_max=None),
        durations,
    )
    durations = jnp.where(
        np.array(tokens)[None, :] == FLAGS.word_end_index, 0.0, durations
    )
    mels = predict_mel(tokens, durations)
    if tokens[-1] == FLAGS.sp_index:
        end_silence = durations[0, -1].item()
        silence_frame = int(end_silence * FLAGS.sample_rate / (FLAGS.n_fft // 4))
        mels = mels[:, : (mels.shape[1] - silence_frame)]
    return mels


if __name__ == "__main__":

0 Source : utils.py
with Apache License 2.0
from PredictiveIntelligenceLab

def fit_kernel_density(X, xi, weights = None, bw=None):

    X, weights = onp.array(X), onp.array(weights)
    X = X.flatten()
    if bw is None:
        try:
            sc = gaussian_kde(X, weights=weights)
            bw = onp.sqrt(sc.covariance).flatten()[0]
        except:
            bw = 1.0
        if bw < 1e-8:
            bw = 1.0


    kde_pdf_x, kde_pdf_y = FFTKDE(bw=bw).fit(X, weights).evaluate()

    # Define the interpolation function
    interp1d_fun = interp1d(kde_pdf_x,
                            kde_pdf_y,
                            kind = 'linear',
                            fill_value = 'extrapolate')

    # Evaluate the weights on the input data
    pdf = interp1d_fun(xi)
    return np.clip(pdf, a_min=0.0) + 1e-8

def init_NN(Q):
