Here are examples of the Python API `jax.numpy.clip`, taken from open-source projects. By voting up you can indicate which examples are most useful and appropriate.
63 Examples
3
Source : other.py
with MIT License
from bmazoure
with MIT License
from bmazoure
def action_noise(action, amount, discrete):
    """Perturb ``action`` with exploration noise of strength ``amount``.

    With ``amount == 0`` the action is returned untouched.  For discrete
    actions the input is mixed with a uniform distribution over the last
    axis and a one-hot sample is drawn from the mixture.  Otherwise a Normal
    sample centred on the action (scale ``amount``) is drawn and clipped to
    [-1, 1].
    """
    if amount == 0:
        return action
    if not discrete:
        noisy = tfd.Normal(action, amount).sample()
        return jnp.clip(noisy, -1, 1)
    # Blend the given probabilities with a uniform distribution.
    mixed = amount / action.shape[-1] + (1 - amount) * action
    return dists.OneHotDist(probs=mixed).sample()
class RollingNorm():
3
Source : utils.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def linf_project_fn(epsilon: float, bounds: Tuple[float, float]) -> ProjectFn:
    """Build a projection onto an L-infinity ball intersected with a box.

    Args:
        epsilon: Maximum absolute per-element deviation from the origin point.
        bounds: ``(lower, upper)`` limits applied to the projected point.

    Returns:
        A function ``project_fn(x, origin_x)`` returning the projected ``x``.
    """
    def project_fn(x, origin_x):
        # First limit the perturbation, then keep the result inside the box.
        delta = x - origin_x
        delta = jnp.clip(delta, -epsilon, epsilon)
        moved = origin_x + delta
        return jnp.clip(moved, bounds[0], bounds[1])

    return project_fn
def bounded_initialize_fn(
3
Source : linear_bound_utils.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def project_params(self, relax_params: Nest[Tensor]) -> Nest[Tensor]:
    """Clip the lower and upper relaxation parameters to [0, 1].

    ``relax_params`` is unpacked as a ``(lower, upper)`` pair and a pair of
    clipped values is returned in the same order.
    """
    lower, upper = relax_params
    clipped_lower = jnp.clip(lower, 0., 1.)
    clipped_upper = jnp.clip(upper, 0., 1.)
    return (clipped_lower, clipped_upper)
_parameterized_posbilinear_relaxer = functools.partial(
3
Source : quaternion.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def safe_acos(t, eps=1e-7):
    """Return arccos of ``t`` after clipping it to [-1 + eps, 1 - eps]."""
    lowest = -1.0 + eps
    highest = 1.0 - eps
    return jnp.arccos(jnp.clip(t, lowest, highest))
def im(q):
3
Source : utils.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def clip_gradients(grad, grad_max_val=0.0, grad_max_norm=0.0, eps=1e-7):
    """Clip a gradient pytree by value and/or by global norm.

    Args:
        grad: Pytree of gradient arrays.
        grad_max_val: If positive, every element is clipped to
            ``[-grad_max_val, grad_max_val]``.
        grad_max_norm: If positive, the whole tree is rescaled so that its
            norm does not exceed this value.
        eps: Added to the norm in the denominator of the rescaling factor.

    Returns:
        The (possibly) clipped gradient pytree.
    """
    # Stage 1: element-wise value clipping.
    if grad_max_val > 0:
        def _bound(leaf):
            return jnp.clip(leaf, -grad_max_val, grad_max_val)

        grad = jax.tree_util.tree_map(_bound, grad)

    # Stage 2: norm clipping of the (possibly value-clipped) gradient.
    if grad_max_norm > 0:
        def _accumulate_squares(total, leaf):
            return total + jnp.sum(leaf**2)

        squared_sum = jax.tree_util.tree_reduce(
            _accumulate_squares, grad, initializer=0)
        grad_norm = safe_sqrt(squared_sum)
        scale = jnp.minimum(1, grad_max_norm / (eps + grad_norm))
        grad = jax.tree_util.tree_map(lambda leaf: scale * leaf, grad)

    return grad
def matmul(a, b):
3
Source : rnn_mlp_lopt.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def _normalize(self, state: _DynamicGradientClipperState,
               grads: opt_base.Params) -> opt_base.Params:
    """Clip every gradient element to ``[-clip_amount, clip_amount]``.

    ``clip_amount`` is ``state.value / (1 - self.alpha**state.iteration)``
    scaled by ``self.clip_mult``; it is also written to the summary under
    the name ``"dynamic_grad_clip"``.

    Args:
        state: Clipper state providing ``iteration`` and ``value``.
        grads: Pytree of gradients.

    Returns:
        Pytree with the same structure as ``grads``, clipped element-wise.
    """
    t, snd = state.iteration, state.value
    # NOTE(review): the 1 - alpha**t divisor looks like a bias correction of
    # a running average -- confirm against the state update.
    clip_amount = (snd / (1 - self.alpha**t)) * self.clip_mult
    summary.summary("dynamic_grad_clip", clip_amount)
    # jax.tree_map is deprecated/removed in recent JAX releases;
    # jax.tree_util.tree_map is available in old and new releases alike.
    return jax.tree_util.tree_map(
        lambda g: jnp.clip(g, -clip_amount, clip_amount), grads)
def next_state_and_normalize(
3
Source : image_mlp.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def normalizer(self, loss):
    """Bound ``loss`` to ``[0, 1.5 * log(num_classes)]``.

    The number of classes is read from ``self.datasets.extra_info``.  NaN
    and infinite entries are replaced by the upper bound.
    """
    n_classes = self.datasets.extra_info["num_classes"]
    ceiling = 1.5 * onp.log(n_classes)
    bounded = jnp.clip(loss, 0, ceiling)
    return jnp.nan_to_num(
        bounded, nan=ceiling, posinf=ceiling, neginf=ceiling)
@gin.configurable
3
Source : gpt2.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def logit(x):
    """Log-odds of ``x`` after clipping it to [1e-5, 1 - 1e-5]."""
    p = np.clip(x, 1e-5, 1 - 1e-5)
    odds = p / (1 - p)
    return np.log(odds)
# Normalization layer used in the transformer
class Norm(objax.module.Module):
3
Source : train_utils.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def softmax_xent(*, logits, labels, reduction=True, kl=False):
    """Softmax cross-entropy between ``logits`` and ``labels``.

    Args:
        logits: Unnormalised scores; softmax is taken over the last axis.
        labels: Target distribution with the same shape as ``logits``.
        reduction: If true, return the mean over all examples; otherwise
            return the per-example losses.
        kl: If true, add ``sum(labels * log(labels))`` so the result is a
            KL divergence rather than a cross-entropy.

    Returns:
        A scalar when ``reduction`` is true, else one loss per example.
    """
    log_probs = jax.nn.log_softmax(logits)
    loss = -jnp.sum(labels * log_probs, axis=-1)
    if kl:
        # Labels are floored at 1e-8 so log(0) is never evaluated.
        safe_labels = jnp.clip(labels, 1e-8)
        loss = loss + jnp.sum(labels * jnp.log(safe_labels), axis=-1)
    if reduction:
        return jnp.mean(loss)
    return loss
def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps):
3
Source : generation_flax_logits_process.py
with Apache License 2.0
from huggingface
with Apache License 2.0
from huggingface
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
    """Set the EOS score to -inf while the minimum length is not exceeded.

    Args:
        input_ids: Unused by this processor.
        scores: Score array; column ``self.eos_token_id`` of its last axis is
            the one that gets masked.
        cur_len: Current generation length.

    Returns:
        ``scores`` with the EOS column set to ``-inf`` when
        ``cur_len - self.min_length <= 0``, otherwise ``scores`` unchanged.
    """
    # 0/1 flag built with clip (rather than a Python branch): 1 while
    # cur_len - min_length <= 0, else 0.
    apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

    # jax.ops.index_update / jax.ops.index were removed from JAX; the
    # .at[...].set(...) indexed-update API is the supported replacement.
    masked_scores = scores.at[:, self.eos_token_id].set(-float("inf"))
    scores = jnp.where(apply_penalty, masked_scores, scores)

    return scores
3
Source : utils.py
with MIT License
from Hwhitetooth
with MIT License
from Hwhitetooth
def scalar_to_two_hot(x: chex.Array, num_bins: int):
    """Encode real values as a "two-hot" categorical vector.

    Ref: https://www.nature.com/articles/s41586-020-03051-4.pdf.

    Each value is clipped to ``[-m, m]`` with ``m = (num_bins - 1) // 2`` and
    its mass is split between the bins of its floor and ceiling, weighted by
    the distance to each.  The output has a trailing axis of size
    ``num_bins``.
    """
    half_range = (num_bins - 1) // 2
    x = jnp.clip(x, -half_range, half_range)

    floor_val = jnp.floor(x).astype(jnp.int32)
    ceil_val = jnp.ceil(x).astype(jnp.int32)

    # The fractional part goes to the upper bin, the remainder to the lower.
    weight_high = x - floor_val
    weight_low = 1. - weight_high

    # Shift by half_range so bin indices start at zero.
    low_part = jax.nn.one_hot(floor_val + half_range, num_bins)
    high_part = jax.nn.one_hot(ceil_val + half_range, num_bins)
    return low_part * weight_low[..., None] + high_part * weight_high[..., None]
def logits_to_scalar(logits: chex.Array):
3
Source : optim.py
with MIT License
from ku2482
with MIT License
from ku2482
def clip_gradient(
    grad: Any,
    max_value: float,
) -> Any:
    """
    Clip gradients element-wise.

    Args:
        grad: Pytree of gradient arrays.
        max_value: Every element is clipped to ``[-max_value, max_value]``.

    Returns:
        A pytree with the same structure as ``grad``.
    """
    # jax.tree_map is deprecated/removed in recent JAX releases;
    # jax.tree_util.tree_map works on old and new releases alike.
    return jax.tree_util.tree_map(
        lambda g: jnp.clip(g, -max_value, max_value), grad)
@jax.jit
3
Source : preprocess.py
with MIT License
from ku2482
with MIT License
from ku2482
def add_noise(
    x: jnp.ndarray,
    key: jnp.ndarray,
    std: float,
    out_min: float = -np.inf,
    out_max: float = np.inf,
    noise_min: float = -np.inf,
    noise_max: float = np.inf,
) -> jnp.ndarray:
    """
    Add clipped Gaussian noise to ``x`` and clip the result.

    Standard-normal noise with the shape of ``x`` is clipped to
    ``[noise_min, noise_max]``, scaled by ``std`` and added to ``x``.  The
    sum is then clipped to ``[out_min, out_max]``.
    """
    raw_noise = jax.random.normal(key, x.shape)
    bounded_noise = jnp.clip(raw_noise, noise_min, noise_max)
    perturbed = x + bounded_noise * std
    return jnp.clip(perturbed, out_min, out_max)
@jax.jit
3
Source : denoise_tv_iso_pgm.py
with BSD 3-Clause "New" or "Revised" License
from lanl
with BSD 3-Clause "New" or "Revised" License
from lanl
def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
    """Evaluate ``functional(xint) - functional(xint - clip(xint, 0, 1))``.

    ``xint`` is ``self.y - self.lmbda * self.A(x)``.
    """
    xint = self.y - self.lmbda * self.A(x)
    # Part of xint lying outside the [0, 1] box.
    residual = xint - jnp.clip(xint, 0.0, 1.0)
    return -1.0 * self.functional(residual) + self.functional(xint)
"""
3
Source : _balloon_lung.py
with Apache License 2.0
from MinRegret
with Apache License 2.0
from MinRegret
def PropValve(x):
    """Map the input ``x`` to a flow value in [0, 1.72].

    The flow is ``tanh(0.03 * (3x - 130)) + 1`` clipped to the range above.
    """
    scaled = 3.0 * x
    flow = 1.0 * (jnp.tanh(0.03 * (scaled - 130)) + 1.0)
    return jnp.clip(flow, 0.0, 1.72)
def Solenoid(x):
3
Source : _learned_lung.py
with Apache License 2.0
from MinRegret
with Apache License 2.0
from MinRegret
def step(self, action):
    """Advance the simulator by one step.

    Updates ``self.state`` through ``self.dynamics``, derives the new
    pressure (de-normalised and clipped to [0, 100]), looks up the target at
    the current time and advances ``self.time`` by ``self.dt``.

    Returns:
        ``(observation, reward, done, info)`` where the reward is the
        negative absolute pressure error, ``done`` is always ``False`` and
        ``info`` is an empty dict.
    """
    self.state = self.dynamics(self.state, action)

    # De-normalise the most recent predicted pressure, then bound it.
    latest = self.state['normalized_pressures'][-1]
    raw_pressure = (latest * self.pressure_std) + self.pressure_mean
    self.pressure = jnp.clip(raw_pressure, 0.0, 100.0)

    self.target = self.waveform.at(self.time)
    reward = -jnp.abs(self.target - self.pressure)

    self.time += self.dt
    return self.observation, reward, False, {}
3
Source : model.py
with MIT License
from NTT123
with MIT License
from NTT123
def p_mean_variance(self, x, t, clip_denoised: bool):
    """Return the posterior mean, variance and log-variance for step ``t``.

    The start sample is reconstructed from the noise predicted by
    ``self.denoise_fn``, optionally clipped to [-1, 1], and passed to
    ``self.q_posterior`` together with ``x`` and ``t``.
    """
    predicted_noise = self.denoise_fn(x, t)
    x_recon = self.predict_start_from_noise(x, t=t, noise=predicted_noise)

    if clip_denoised:
        x_recon = jnp.clip(x_recon, -1.0, 1.0)

    mean, variance, log_variance = self.q_posterior(
        x_start=x_recon, x_t=x, t=t
    )
    return mean, variance, log_variance
def p_sample(self, x, t, rng_key, clip_denoised=True, repeat_noise=False):
3
Source : numpy_ops.py
with GNU General Public License v3.0
from PKU-NIP-Lab
with GNU General Public License v3.0
from PKU-NIP-Lab
def clip(a, a_min=None, a_max=None):
    """Clip ``a`` to ``[a_min, a_max]`` and wrap the result in a ``JaxArray``.

    All three arguments are passed through ``_remove_jaxarray`` before being
    handed to ``jnp.clip``.
    """
    raw, lower, upper = (_remove_jaxarray(v) for v in (a, a_min, a_max))
    return JaxArray(jnp.clip(raw, lower, upper))
def angle(z, deg=False):
3
Source : acquisitions.py
with Apache License 2.0
from PredictiveIntelligenceLab
with Apache License 2.0
from PredictiveIntelligenceLab
def EI(mean, std, best):
    """Negated expected improvement of the first candidate.

    Improvement is measured as ``best - mean``.  The first element of the
    expected-improvement array is returned with its sign flipped.

    Args:
        mean: Predictive means.
        std: Predictive standard deviations.
        best: Incumbent value the candidates are compared against.
    """
    # from https://people.orie.cornell.edu/pfrazier/Presentations/2011.11.INFORMS.Tutorial.pdf
    delta = -(mean - best)
    # Positive part of the improvement.  The bound is passed positionally
    # because the a_min/a_max keywords of jnp.clip are deprecated in newer
    # JAX releases.
    deltap = np.clip(delta, 0.)
    Z = delta/std
    ei = deltap - np.abs(deltap)*norm.cdf(-Z) + std*norm.pdf(Z)
    return -ei[0]
@jit
3
Source : acquisitions.py
with Apache License 2.0
from PredictiveIntelligenceLab
with Apache License 2.0
from PredictiveIntelligenceLab
def EIC(mean, std, best):
    """Negated constrained expected improvement.

    Row 0 of ``mean``/``std`` is used for the expected-improvement term; the
    remaining rows enter through the product of ``norm.cdf(mean / std)``
    taken over rows.  The result is ``-EI[0] * constraints[0]``.
    """
    # Constrained expected improvement
    delta = -(mean[0,:] - best)
    # Positive part of the improvement.  The bound is passed positionally
    # because the a_min/a_max keywords of jnp.clip are deprecated in newer
    # JAX releases.
    deltap = np.clip(delta, 0.)
    Z = delta/std[0,:]
    # `std` (all rows) is kept here as in the original; only row 0 of the
    # broadcast result is used below.
    ei = deltap - np.abs(deltap)*norm.cdf(-Z) + std*norm.pdf(Z)
    constraints = np.prod(norm.cdf(mean[1:,:]/std[1:,:]), axis = 0)
    return -ei[0]*constraints[0]
@jit
3
Source : ops.py
with Apache License 2.0
from pyro-ppl
with Apache License 2.0
from pyro-ppl
def _max(x, y):
    """Return ``x`` bounded from below by ``y`` (no upper bound)."""
    # Bounds are passed positionally: the a_min/a_max keywords of jnp.clip
    # are deprecated in newer JAX releases.
    return np.clip(x, y, None)
# TODO: replace (int, float) by numbers.Number
@ops.min.register((int, float), array)
3
Source : discrete.py
with Apache License 2.0
from pyro-ppl
with Apache License 2.0
from pyro-ppl
def log_prob(self, value):
    """Log-probability of ``value`` successes out of ``self.total_count``.

    The distribution is parameterised by ``self.logits``.
    """
    n = self.total_count
    log_n_factorial = gammaln(n + 1)
    log_k_factorial = gammaln(value + 1)
    log_nmk_factorial = gammaln(n - value + 1)

    # max(logits, 0) + log1p(exp(-|logits|)) is a stable softplus(logits),
    # so the first two terms amount to n * softplus(logits).
    normalize_term = (
        n * jnp.clip(self.logits, 0)
        + xlog1py(n, jnp.exp(-jnp.abs(self.logits)))
        - log_n_factorial
    )
    log_weight = value * self.logits
    return log_weight - log_k_factorial - log_nmk_factorial - normalize_term
@lazy_property
3
Source : flows.py
with Apache License 2.0
from pyro-ppl
with Apache License 2.0
from pyro-ppl
def _clamp_preserve_gradients(x, min, max):
    """Clamp ``x`` to ``[min, max]`` while keeping an identity gradient.

    The value returned equals ``clip(x, min, max)``, but the clipping
    correction sits under ``stop_gradient`` so the derivative w.r.t. ``x``
    is 1 everywhere.
    """
    # Bounds are passed positionally: the a_min/a_max keywords of jnp.clip
    # are deprecated in newer JAX releases.
    return x + lax.stop_gradient(jnp.clip(x, min, max) - x)
# adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py
class InverseAutoregressiveTransform(Transform):
3
Source : util.py
with Apache License 2.0
from pyro-ppl
with Apache License 2.0
from pyro-ppl
def binary_cross_entropy_with_logits(x, y):
    """Binary cross-entropy between logits ``x`` and targets ``y``.

    Computes -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x)) in the
    rearranged form max(x, 0) - x * y + log(1 + exp(-|x|)).
    Ref: https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
    """
    positive_part = jnp.clip(x, 0)
    log_term = jnp.log1p(jnp.exp(-jnp.abs(x)))
    return positive_part + log_term - x * y
def _reshape(x, shape):
3
Source : optim.py
with Apache License 2.0
from pyro-ppl
with Apache License 2.0
from pyro-ppl
def update(self, g, state):
    """Clip the gradients and apply one step of the wrapped update rule.

    Args:
        g: Pytree of gradients.
        state: ``(step, opt_state)`` pair.

    Returns:
        ``(step + 1, new_opt_state)``.
    """
    i, opt_state = state
    # Despite the `clip_norm` name, this clips every element to
    # [-clip_norm, clip_norm]; no norm is computed.  Bounds are passed
    # positionally because the a_min/a_max keywords of jnp.clip are
    # deprecated in newer JAX releases.
    bound = self.clip_norm
    g = tree_map(lambda g_: jnp.clip(g_, -bound, bound), g)
    opt_state = self.update_fn(i, g, opt_state)
    return i + 1, opt_state
@_add_doc(optimizers.adagrad)
3
Source : generation_flax_logits_process.py
with Apache License 2.0
from UKPLab
with Apache License 2.0
from UKPLab
def __call__(
    self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
    """Set the EOS score to -inf while the minimum length is not exceeded.

    Args:
        input_ids: Unused by this processor.
        scores: Score array; column ``self.eos_token_id`` of its last axis is
            the one that gets masked.
        cur_len: Current generation length.

    Returns:
        ``scores`` with the EOS column set to ``-inf`` when
        ``cur_len - self.min_length <= 0``, otherwise ``scores`` unchanged.
    """
    # 0/1 flag built with clip (rather than a Python branch): 1 while
    # cur_len - min_length <= 0, else 0.
    apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

    # jax.ops.index_update / jax.ops.index were removed from JAX; the
    # .at[...].set(...) indexed-update API is the supported replacement.
    masked_scores = scores.at[:, self.eos_token_id].set(-float("inf"))
    scores = jnp.where(apply_penalty, masked_scores, scores)

    return scores
0
Source : optim.py
with MIT License
from alonfnt
with MIT License
from alonfnt
def suggest_next(
    key: Array,
    params: GParameters,
    x: Array,
    y: Array,
    bounds: Array,
    dtypes: DataTypes,
    acq: Callable,
    n_seed: int = 1000,
    lr: float = 0.1,
    n_epochs: int = 150,
) -> Tuple[Array, Array]:
    """
    Suggests the new point to sample by optimizing the acquisition function.

    Parameters:
    -----------
    key: The pseudo-random generator key used for jax random functions.
    params: Hyperparameters of the Gaussian Process Regressor.
    x: Sampled points.
    y: Sampled targets.
    bounds: Array of (dim, 2) shape; column 0 holds the lower and column 1
        the upper bound of each variable.
    dtypes: The type of non-real variables in the target function.
    acq: Acquisition function, called as
        `acq(points, params=params, x=x, y=y, dtypes=dtypes)`.
    n_seed (optional): the number of random seed points that are optimized;
        the best of them is returned.
    lr (optional): The step size of the gradient ascent.
    n_epochs (optional): The number of ascent steps applied to the seeds.

    Returns:
    --------
    A tuple with the point that maximizes the acquisition function among the
    optimized seeds and a jax PRNGKey to be used in the next sampling.
    """
    key1, key2 = random.split(key, 2)
    dim = bounds.shape[0]
    lower, upper = bounds[:, 0], bounds[:, 1]

    # Seed points drawn uniformly inside the bounds.
    domain = random.uniform(
        key1, shape=(n_seed, dim), minval=lower, maxval=upper
    )

    _acq = partial(acq, params=params, x=x, y=y, dtypes=dtypes)
    J = jacobian(lambda x: _acq(x.reshape(-1, dim)).reshape())
    # One ascent step (x + lr * gradient), vectorised over all seeds.
    HS = vmap(lambda x: x + lr * J(x))

    domain = lax.fori_loop(0, n_epochs, lambda _, d: HS(d), domain)
    # The ascent is unconstrained, so pull the seeds back inside the bounds.
    # Bounds are passed positionally because the a_min/a_max keywords of
    # jnp.clip are deprecated in newer JAX releases.
    domain = jnp.clip(domain.reshape(-1, dim), lower, upper)
    domain = replace_nan_values(domain)
    domain = round_integers(domain, dtypes)

    ys = _acq(domain)
    next_X = domain[ys.argmax()]
    return next_X, key2
@partial(jit, static_argnums=(1, 2))
0
Source : dists.py
with MIT License
from bmazoure
with MIT License
from bmazoure
def sample(self, *args, **kwargs):
    """Draw a sample from the parent distribution, then clip and scale it.

    When ``self._clip`` is truthy the value is clipped to
    ``[self.low + self._clip, self.high - self._clip]`` while the gradient
    of the unclipped sample is kept.  When ``self._mult`` is truthy the
    result is multiplied by it.
    """
    event = super().sample(*args, **kwargs)
    if self._clip:
        # Bounds are passed positionally: the a_min/a_max keywords of
        # jnp.clip are deprecated in newer JAX releases.
        clipped = jnp.clip(event, self.low + self._clip,
                           self.high - self._clip)
        # Forward value is `clipped`; the gradient is that of `event`.
        event = event - jax.lax.stop_gradient(event) + jax.lax.stop_gradient(clipped)
    if self._mult:
        event *= self._mult
    return event
# class TanhNormalDist():
# def __init__(self, dist):
# self.dist = dist
# def entropy(self, *args, **kwargs):
# import ipdb;ipdb.set_trace()
# return self.dist.entropy(*args, **kwargs)
# def mode(self, *args, **kwargs):
# return self.dist.mode(*args, **kwargs)
# @property
# def name(self):
# return "tanh_normal"
# @property
# def dtype(self):
# return jnp.float32
0
Source : crossentropy.py
with MIT License
from cgarciae
with MIT License
from cgarciae
def crossentropy(
    target: jnp.ndarray,
    preds: jnp.ndarray,
    *,
    binary: bool = False,
    from_logits: bool = True,
    label_smoothing: tp.Optional[float] = None,
    check_bounds: bool = True,
) -> jnp.ndarray:
    """Cross-entropy between ``target`` and ``preds``.

    Args:
        target: Class indices (one dimension fewer than ``preds``) or a
            dense target with the same number of dimensions as ``preds``.
        preds: Predictions; the last axis indexes the classes.
        binary: Use the binary (per-class sigmoid) formulation.
        from_logits: Whether ``preds`` holds logits rather than
            probabilities.
        label_smoothing: Optional smoothing factor applied to the dense
            target.
        check_bounds: Currently unused (bounds checking is not implemented).

    Returns:
        The loss with the class axis reduced.

    Raises:
        ValueError: If the shapes of ``target`` and ``preds`` do not match.
    """
    n_classes = preds.shape[-1]

    labels_are_indices = target.ndim == preds.ndim - 1
    if labels_are_indices:
        if target.shape != preds.shape[:-1]:
            raise ValueError(
                f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
            )
        target = jax.nn.one_hot(target, n_classes)
    elif target.ndim != preds.ndim:
        raise ValueError(
            f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
        )

    if label_smoothing is not None:
        target = optax.smooth_labels(target, label_smoothing)

    if from_logits:
        if binary:
            return optax.sigmoid_binary_cross_entropy(preds, target).mean(axis=-1)
        return optax.softmax_cross_entropy(preds, target)

    # Probabilities: keep them away from 0 and 1 before taking logs.
    preds = jnp.clip(preds, types.EPSILON, 1.0 - types.EPSILON)
    if binary:
        log_likelihood = target * jnp.log(preds)
        log_likelihood += (1 - target) * jnp.log(1 - preds)
        return -log_likelihood.mean(axis=-1)
    return -(target * jnp.log(preds)).sum(axis=-1)
class Crossentropy(Loss):
0
Source : augmentations.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def _color_transform_single_image(image, rng, brightness, contrast, saturation,
                                  hue, to_grayscale_prob, color_jitter_prob,
                                  apply_prob, shuffle):
    """Applies color jittering to a single image.

    Args:
        image: Image to transform.
        rng: PRNG key; split internally into one key per random decision.
        brightness: Max delta for the brightness jitter (skipped if <= 0).
        contrast: Contrast factor is drawn from [1 - contrast, 1 + contrast]
            (skipped if <= 0).
        saturation: Saturation factor is drawn from
            [1 - saturation, 1 + saturation] (skipped if <= 0).
        hue: Max delta for the hue jitter (skipped if <= 0).
        to_grayscale_prob: Probability of converting to grayscale.
        color_jitter_prob: Probability of applying the color jitter.
        apply_prob: Probability of applying any transform at all.
        shuffle: Whether the four jitter ops are applied in random order.

    Returns:
        The transformed image, clipped to [0, 1].
    """
    apply_rng, transform_rng = jax.random.split(rng)
    perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
        transform_rng, 7)

    # Whether the transform should be applied at all.
    should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
    # Whether to apply grayscale transform.
    should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
    # Whether to apply color jittering.
    should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob

    # Decorator to conditionally apply fn based on an index.
    def _make_cond(fn, idx):

        def identity_fn(unused_rng, x):
            return x

        def cond_fn(args, i):
            def clip(args):
                # jax.tree_map is deprecated/removed in recent JAX releases.
                return jax.tree_util.tree_map(
                    lambda arg: jnp.clip(arg, 0., 1.), args)
            # `args` is the (rng, image) pair.  Uses the
            # cond(pred, true_fun, false_fun, operand) form; the legacy
            # 5-argument form is not accepted by current JAX.
            out = jax.lax.cond(should_apply & should_apply_color & (i == idx),
                               lambda a: clip(fn(*a)),
                               lambda a: identity_fn(*a),
                               args)
            return jax.lax.stop_gradient(out)

        return cond_fn

    random_brightness = functools.partial(
        pix.random_brightness, max_delta=brightness)
    random_contrast = functools.partial(
        pix.random_contrast, lower=1-contrast, upper=1+contrast)
    random_hue = functools.partial(pix.random_hue, max_delta=hue)
    random_saturation = functools.partial(
        pix.random_saturation, lower=1-saturation, upper=1+saturation)
    to_grayscale = functools.partial(pix.rgb_to_grayscale, keep_dims=True)

    random_brightness_cond = _make_cond(random_brightness, idx=0)
    random_contrast_cond = _make_cond(random_contrast, idx=1)
    random_saturation_cond = _make_cond(random_saturation, idx=2)
    random_hue_cond = _make_cond(random_hue, idx=3)

    def _color_jitter(x):
        if shuffle:
            order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
        else:
            order = range(4)
        # Every op is visited at each position of `order`; it only fires at
        # the position whose value equals its own index.
        for idx in order:
            if brightness > 0:
                x = random_brightness_cond((b_rng, x), idx)
            if contrast > 0:
                x = random_contrast_cond((c_rng, x), idx)
            if saturation > 0:
                x = random_saturation_cond((s_rng, x), idx)
            if hue > 0:
                x = random_hue_cond((h_rng, x), idx)
        return x

    out_apply = _color_jitter(image)
    out_apply = jax.lax.cond(should_apply & should_apply_gs,
                             to_grayscale,
                             lambda x: x,
                             out_apply)
    return jnp.clip(out_apply, 0., 1.)
def random_flip(images, rng):
0
Source : attacks.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def adversarial_attack(
    params: ModelParams,
    data_spec: DataSpec,
    spec_type: verify_utils.SpecType,
    key: PRNGKey,
    num_steps: int,
    learning_rate: float,
    num_samples: int = 1,
) -> float:
    """Adversarial attack on uncertainty spec (with parameter sampling).

    The attack starts at ``data_spec.input`` and is projected onto the
    epsilon-box around it, intersected with ``data_spec.input_bounds``.  The
    objective to maximise depends on ``spec_type``.

    Raises:
        ValueError: If ``spec_type`` is not one of the supported kinds.
    """
    lo_bound = data_spec.input_bounds[0]
    hi_bound = data_spec.input_bounds[1]
    box_lower = jnp.clip(data_spec.input - data_spec.epsilon, lo_bound, hi_bound)
    box_upper = jnp.clip(data_spec.input + data_spec.epsilon, lo_bound, hi_bound)

    def projection_fn(x):
        return jnp.clip(x, box_lower, box_upper)

    forward_fn = make_forward(params, num_samples)

    def _flat_logits(x, prng_key):
        return jnp.reshape(forward_fn(x, prng_key), [-1])

    def max_objective_fn_uncertainty(x, prng_key):
        return _flat_logits(x, prng_key)[data_spec.target_label]

    def max_objective_fn_adversarial(x, prng_key):
        logits = _flat_logits(x, prng_key)
        return logits[data_spec.target_label] - logits[data_spec.true_label]

    def max_objective_fn_adversarial_softmax(x, prng_key):
        probs = jax.nn.softmax(_flat_logits(x, prng_key), axis=-1)
        return probs[data_spec.target_label] - probs[data_spec.true_label]

    uncertainty_specs = (verify_utils.SpecType.UNCERTAINTY,
                         verify_utils.SpecType.PROBABILITY_THRESHOLD)
    if spec_type in uncertainty_specs:
        max_objective_fn = max_objective_fn_uncertainty
    elif spec_type == verify_utils.SpecType.ADVERSARIAL:
        max_objective_fn = max_objective_fn_adversarial
    elif spec_type == verify_utils.SpecType.ADVERSARIAL_SOFTMAX:
        max_objective_fn = max_objective_fn_adversarial_softmax
    else:
        raise ValueError('Unsupported spec.')

    return _run_attack(
        max_objective_fn=max_objective_fn,
        projection_fn=projection_fn,
        x_init=data_spec.input,
        prng_key=key,
        num_steps=num_steps,
        learning_rate=learning_rate)
0
Source : linear_bound_utils.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def project_params(self, relax_params: Nest[Tensor]) -> Nest[Tensor]:
    """Clip every leaf of the ``relax_params`` pytree to [0, 1]."""
    # jax.tree_map is deprecated/removed in recent JAX releases;
    # jax.tree_util.tree_map works on old and new releases alike.
    return jax.tree_util.tree_map(lambda x: jnp.clip(x, 0., 1.), relax_params)
def eltwise_linfun_from_coeff(slope: Tensor, offset: Tensor) -> LinFun:
0
Source : linear_bound_utils.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def project_params(self, relax_params: Nest[Tensor]) -> Nest[Tensor]:
    """Clip ``relax_params`` to [0, 1]."""
    lowest, highest = 0., 1.
    return jnp.clip(relax_params, lowest, highest)
def _parameterized_relu_relaxer(
0
Source : clipping.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def clip(max_delta: chex.Numeric) -> base.GradientTransformation:
    """Clips updates element-wise, to be in ``[-max_delta, +max_delta]``.

    Args:
        max_delta: The maximum absolute value for each element in the update.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        # The transformation keeps no per-parameter state.
        del params
        return ClipState()

    def update_fn(updates, state, params=None):
        del params
        # jax.tree_map is deprecated/removed in recent JAX releases;
        # jax.tree_util.tree_map works on old and new releases alike.
        updates = jax.tree_util.tree_map(
            lambda g: jnp.clip(g, -max_delta, max_delta), updates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)
def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
0
Source : schedule.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def polynomial_schedule(
    init_value: chex.Scalar,
    end_value: chex.Scalar,
    power: chex.Scalar,
    transition_steps: int,
    transition_begin: int = 0
) -> base.Schedule:
    """Constructs a schedule with polynomial transition from init to end value.

    Args:
        init_value: initial value for the scalar to be annealed.
        end_value: end value of the scalar to be annealed.
        power: the power of the polynomial used to transition from init to end.
        transition_steps: number of steps over which annealing takes place,
            the scalar starts changing at `transition_begin` steps and completes
            the transition by `transition_begin + transition_steps` steps.
            If `transition_steps <= 0`, then the entire annealing process is
            disabled and the value is held fixed at `init_value`.
        transition_begin: must be non-negative (negative values fall back to
            0). After how many steps to start annealing (before this many steps
            the scalar value is held fixed at `init_value`).

    Returns:
        schedule: A function that maps step counts to values.
    """
    if transition_steps <= 0:
        logging.info(
            'A polynomial schedule was set with a non-positive `transition_steps` '
            'value; this results in a constant schedule with value `init_value`.')
        return lambda count: init_value

    if transition_begin < 0:
        logging.info(
            'A polynomial schedule was set with a negative `transition_begin` '
            'value; this will result in `transition_begin` falling back to `0`.')
        transition_begin = 0

    def schedule(count):
        # Steps elapsed since the start of the transition, limited to its
        # length, so the value is constant before and after the transition.
        count = jnp.clip(count - transition_begin, 0, transition_steps)
        frac = 1 - count / transition_steps
        return (init_value - end_value) * (frac**power) + end_value

    return schedule
# Alias polynomial schedule to linear schedule for convenience.
def linear_schedule(
0
Source : interprenet.py
with Apache License 2.0
from FINRAOS
with Apache License 2.0
from FINRAOS
def clip(x, eps=2 ** -16):
    """Clip ``x`` to the interval ``[eps, 1 - eps]``."""
    upper = 1 - eps
    return jax.numpy.clip(x, eps, upper)
@public.add
0
Source : api_test.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def test_clip_gradient(self):
    """Gradient clipping via custom_vjp with a traced clipping bound."""
    # https://github.com/google/jax/issues/2784
    @api.custom_vjp
    def _clip_gradient(lo, hi, x):
        return x  # identity function when not differentiating

    def _forward(lo, hi, x):
        # Save the bounds as residuals for the backward pass.
        return x, (lo, hi)

    def _backward(res, g):
        lo, hi = res
        # No cotangent for the bounds; the incoming one is clipped.
        return (None, None, jnp.clip(g, lo, hi))

    _clip_gradient.defvjp(_forward, _backward)

    def clip_gradient(x):
        lower = -0.1
        upper = x + 0.1
        return _clip_gradient(lower, upper, x)

    g = jax.grad(clip_gradient)(0.1)  # doesn't crash
    self.assertAllClose(g, jnp.array(0.2))
def test_nestable_vjp(self):
0
Source : base.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def update(self, opt_state, grad, *args, **kwargs):
    """Clip ``grad`` element-wise, then delegate to the wrapped optimizer.

    Every leaf of ``grad`` is clipped to
    ``[-self.grad_clip, self.grad_clip]``.  All other arguments are
    forwarded unchanged to ``self.opt.update``, whose result is returned.
    """
    # jax.tree_map is deprecated/removed in recent JAX releases;
    # jax.tree_util.tree_map works on old and new releases alike.
    grad = jax.tree_util.tree_map(
        lambda x: jnp.clip(x, -self.grad_clip, self.grad_clip), grad)
    return self.opt.update(opt_state, grad, *args, **kwargs)
0
Source : image_mlp_ae.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def _make_task(hk_fn: LossFN, datasets: datasets_base.Datasets) -> base.Task:
    """Make a Task subclass for the haiku loss and datasets.

    ``hk_fn`` is transformed with ``hk.transform``.  The returned task
    initialises parameters from the first training batch and evaluates the
    loss by applying the transformed function.
    """
    init_net, apply_net = hk.transform(hk_fn)

    class _Task(base.Task):
        """Anonymous task object with corresponding loss and datasets."""

        def __init__(self):
            self.datasets = datasets

        def init(self, key: PRNGKey) -> base.Params:
            first_batch = next(datasets.train)
            return init_net(key, first_batch)

        def loss(self, params, key, data):
            return apply_net(params, key, data)

        def normalizer(self, loss):
            # Losses are reported on a [0, 1] scale.
            lowest, highest = .0, 1.
            return jnp.clip(loss, lowest, highest)

    return _Task()
@gin.configurable
0
Source : train_utils.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def create_learning_rate_schedule(total_steps,
                                  base=0.,
                                  decay_type="linear",
                                  warmup_steps=0,
                                  linear_end=1e-5):
    """Creates a learning rate schedule.

    Currently only warmup + {linear,cosine} but will be a proper mini-language
    like preprocessing one in the future.

    Args:
        total_steps: The total number of steps to run.
        base: The starting learning-rate (without warmup).
        decay_type: 'linear' or 'cosine'.
        warmup_steps: how many steps to warm up for.
        linear_end: Minimum learning rate (used by the linear decay).

    Returns:
        A function mapping a step to the learning rate as a float32 array.
        An unknown ``decay_type`` raises ``ValueError`` when that function
        is called.
    """

    def step_fn(step):
        """Step to learning rate function."""
        # Fraction of the post-warmup phase completed, limited to [0, 1].
        decay_span = float(total_steps - warmup_steps)
        progress = jnp.clip((step - warmup_steps) / decay_span, 0.0, 1.0)

        if decay_type == "linear":
            lr = linear_end + (base - linear_end) * (1.0 - progress)
        elif decay_type == "cosine":
            lr = base * 0.5 * (1. + jnp.cos(jnp.pi * progress))
        else:
            raise ValueError(f"Unknown lr type {decay_type}")

        if warmup_steps:
            # Linear ramp from 0 to the full rate over the warmup steps.
            lr = lr * jnp.minimum(1., step / warmup_steps)

        return jnp.asarray(lr, dtype=jnp.float32)

    return step_fn
def get_weight_decay_fn(
0
Source : img_log_utils.py
with Apache License 2.0
from google-research
with Apache License 2.0
from google-research
def append_images(
    self,
    image_key: Optional[str],
    rgb: f32["h w 3"],
    rgb_ground_truth: f32["h w 3"],
    semantic_logits: Optional[f32["h w c"]],
    semantic_ground_truth: Optional[i32["h w 1"]],
):
    """Append a set of images to the log.

    Args:
        image_key: String identifier for this frame. Format is,
            ${SCENE_NAME}_rgba_${IMAGE_NAME}. Use None if this information is
            missing.
        rgb: RGB prediction image. Values must be in [0, 1].
        rgb_ground_truth: RGB ground truth image. Values must be in [0, 1].
        semantic_logits: Optional semantic predictions. Contains per-class
            logits.
        semantic_ground_truth: Optional semantic prediction ground truth.
            Contains integer ID of the ground truth semantic class. Values
            must be in {0...255}.
    """
    self._image_keys.append(image_key)

    # Both RGB images are clipped to [0, 1] and fetched to the host.
    for store, image in ((self._rgb, rgb),
                         (self._rgb_ground_truth, rgb_ground_truth)):
        store.append(jax.device_get(jnp.clip(image, 0, 1)))

    # Semantics are logged only if logits are given and have >= 1 class.
    has_semantics = (semantic_logits is not None
                     and semantic_logits.shape[-1])
    if not has_semantics:
        return

    # Semantic prediction: arg-max class id with a trailing axis of size 1.
    class_ids = jnp.argmax(semantic_logits, axis=-1)
    class_ids = class_ids.reshape(semantic_logits.shape[0:-1] + (1,))
    self._semantic.append(jax.device_get(class_ids))

    # Semantic ground truth.
    self._semantic_ground_truth.append(jax.device_get(semantic_ground_truth))
@property
0
Source : test_modeling_flax_bart.py
with Apache License 2.0
from huggingface
with Apache License 2.0
from huggingface
def prepare_config_and_inputs(self):
    """Build a ``BartConfig`` plus random inputs for the model tests.

    Returns:
        ``(config, input_ids, attention_mask)``; the mask is ``None`` unless
        ``self.use_attention_mask`` is set.
    """
    batch_shape = [self.batch_size, self.seq_length]

    # Random ids, with the lower end clipped at 3.
    raw_ids = ids_tensor(batch_shape, self.vocab_size)
    input_ids = jnp.clip(raw_ids, 3, self.vocab_size)

    attention_mask = None
    if self.use_attention_mask:
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])

    # Encoder and decoder share the same size settings.
    n_layers = self.num_hidden_layers
    n_heads = self.num_attention_heads
    ffn_dim = self.intermediate_size
    config = BartConfig(
        vocab_size=self.vocab_size,
        d_model=self.hidden_size,
        encoder_layers=n_layers,
        decoder_layers=n_layers,
        encoder_attention_heads=n_heads,
        decoder_attention_heads=n_heads,
        encoder_ffn_dim=ffn_dim,
        decoder_ffn_dim=ffn_dim,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        eos_token_id=self.eos_token_id,
        bos_token_id=self.bos_token_id,
        pad_token_id=self.pad_token_id,
        initializer_range=self.initializer_range,
        use_cache=False,
    )

    return config, input_ids, attention_mask
def prepare_config_and_inputs_for_common(self):
0
Source : agents.py
with MIT License
from Hwhitetooth
with MIT License
from Hwhitetooth
def mcts(self, rng_key: chex.PRNGKey, params: Params, root: AgentOutput, is_eval: bool):
    """Build a search tree from `root` by running `self._num_simulations` simulations.

    Each simulation picks a (node, action) pair with `simulate`, writes the
    model's prediction for that pair into a new node with `expand`, and then
    propagates the new node's value back towards node 0 with `backup`.

    Args:
        rng_key: PRNG key; split for tree initialisation and every simulation.
        params: Model parameters, passed to `self.model_step`.
        root: Agent output handed to `self.init_tree`.
        is_eval: Forwarded to `self.init_tree`.

    Returns:
        The tree after all simulations.
    """
    num_actions = self._action_space.n
    max_search_depth = self._max_search_depth
    c1 = self._mcts_c1
    c2 = self._mcts_c2
    discount_factor = self._discount_factor

    def simulate(rng_key: chex.PRNGKey, tree: Tree):
        # Walks down from node 0 and returns the (node, action) pair at which
        # the descent stopped.
        # First compute the minimum and the maximum action-value in the current tree.
        # Note that these statistics are hard to maintain incrementally because they are non-monotonic.
        # 0/1 mask of the (node, action) entries that have been visited.
        is_valid = jnp.clip(tree.visit_count, 0, 1)
        action_value = tree.action_value
        q_min = jnp.min(jnp.where(is_valid, action_value, jnp.full_like(action_value, jnp.inf)))
        q_max = jnp.max(jnp.where(is_valid, action_value, jnp.full_like(action_value, -jnp.inf)))
        # With no visited entry at all, fall back to 0 for both statistics.
        q_min = jax.lax.select(is_valid.sum() == 0, 0., q_min)
        q_max = jax.lax.select(is_valid.sum() == 0, 0., q_max)

        def _select_action(rng_key: chex.PRNGKey, t, q_mean):
            # Assign an estimated value to the unvisited nodes.
            # See Eq. (8) in https://arxiv.org/pdf/2111.00210.pdf
            # and https://github.com/YeWR/EfficientZero/blob/main/core/ctree/cnode.cpp#L96.
            q = action_value[t]
            q = jax.lax.select(tree.visit_count[t] > 0, q, jnp.full_like(q, q_mean))
            # Normalize the action-values of the current node so that they are in [0, 1].
            # This is required for the pUCT rule.
            # See Eq. (5) in https://www.nature.com/articles/s41586-020-03051-4.pdf
            q = (q - q_min) / jnp.maximum(q_max - q_min, self._q_normalize_epsilon)
            p = tree.prob[t]
            n = tree.visit_count[t]
            # The action scores are computed by the pUCT rule.
            # See Eq. (2) in https://www.nature.com/articles/s41586-020-03051-4.pdf.
            score = q + p * jnp.sqrt(n.sum()) / (1 + n) * (c1 + jnp.log((n.sum() + c2 + 1) / c2))
            # Actions within `_child_select_epsilon` of the best score are
            # treated as ties and sampled uniformly.
            best_actions = score >= score.max() - self._child_select_epsilon
            tie_breaking_prob = best_actions / best_actions.sum()
            return jax.random.choice(rng_key, num_actions, p=tie_breaking_prob)

        def _cond(loop_state):
            # Keep descending while below the depth limit and the selected
            # edge has already been visited.
            rng_key, p, a, q_mean = loop_state
            return jnp.logical_and(tree.depth[p] + 1 < max_search_depth, tree.visit_count[p, a] > 0)

        def _body(loop_state):
            # Move to the child along the selected edge and select its action.
            rng_key, p, a, q_mean = loop_state
            p = tree.child[p, a]
            is_valid_child = jnp.clip(tree.visit_count[p], 0, 1)
            q_mean = (q_mean + jnp.sum(tree.action_value[p] * is_valid_child)) / (jnp.sum(is_valid_child) + 1)
            rng_key, sub_key = jax.random.split(rng_key)
            a = _select_action(sub_key, p, q_mean)
            return rng_key, p, a, q_mean

        # Start at node 0: mean value over its visited actions (0 if none).
        is_valid_child = jnp.clip(tree.visit_count[0], 0, 1)
        q_mean = jnp.sum(tree.action_value[0] * is_valid_child) / jnp.maximum(jnp.sum(is_valid_child), 1)
        rng_key, sub_key = jax.random.split(rng_key)
        a = _select_action(sub_key, 0, q_mean)
        _, p, a, _ = jax.lax.while_loop(
            _cond,
            _body,
            (rng_key, 0, a, q_mean),
        )
        return p, a

    def expand(tree: Tree, p, a, c):
        # Run the model from node p with action a and store the outputs in
        # node c, linking c as the child of (p, a).
        p_state = tree.state[p]
        model_out = self.model_step(params, p_state, a)
        tree = tree._replace(
            state=tree.state.at[c].set(model_out.state),
            logits=tree.logits.at[c].set(model_out.logits),
            prob=tree.prob.at[c].set(jax.nn.softmax(model_out.logits)),
            reward_logits=tree.reward_logits.at[c].set(model_out.reward_logits),
            reward=tree.reward.at[c].set(model_out.reward),
            value_logits=tree.value_logits.at[c].set(model_out.value_logits),
            value=tree.value.at[c].set(model_out.value),
            depth=tree.depth.at[c].set(tree.depth[p] + 1),
            parent=tree.parent.at[c].set(p),
            parent_action=tree.parent_action.at[c].set(a),
            child=tree.child.at[p, a].set(c),
        )
        return tree

    def backup(tree: Tree, c):
        # Propagate the value of node c up the parent links until node 0.
        def _update(tree, c, g):
            # Discounted return seen from the parent of c.
            g = tree.reward[c] + discount_factor * g
            p = tree.parent[c]
            a = tree.parent_action[c]
            # Incremental mean of the returns observed through edge (p, a).
            new_n = tree.visit_count[p, a] + 1
            new_q = (tree.action_value[p, a] * tree.visit_count[p, a] + g) / new_n
            tree = tree._replace(
                visit_count=tree.visit_count.at[p, a].add(1),
                action_value=tree.action_value.at[p, a].set(new_q),
            )
            return tree, p, g

        # Loop state is (tree, current node, return); stops at node 0.
        tree, _, _ = jax.lax.while_loop(
            lambda t: t[1] > 0,
            lambda t: _update(t[0], t[1], t[2]),
            (tree, c, tree.value[c]),
        )
        return tree

    def body_fn(sim, loop_state):
        rng_key, tree = loop_state
        rng_key, simulate_key = jax.random.split(rng_key)
        p, a = simulate(simulate_key, tree)
        # Simulation `sim` stores its new node at index sim + 1.
        c = sim + 1
        tree = expand(tree, p, a, c)
        tree = backup(tree, c)
        return rng_key, tree

    rng_key, init_key = jax.random.split(rng_key)
    tree = self.init_tree(init_key, root, is_eval)
    rng_key, tree = jax.lax.fori_loop(
        0, self._num_simulations, body_fn, (rng_key, tree))
    return tree
def act_prob(self, visit_count: chex.Array, temperature: float):
0
Source : policy.py
with MIT License
from ikostrikov
with MIT License
from ikostrikov
def __call__(self,
             observations: jnp.ndarray,
             temperature: float = 1.0,
             training: bool = False) -> tfd.Distribution:
    """Build the action distribution for a batch of observations.

    Args:
        observations: inputs fed to the MLP trunk.
        temperature: multiplier applied to the standard deviations.
        training: forwarded to the MLP (controls dropout).

    Returns:
        A diagonal multivariate normal over actions, wrapped in a ``Tanh``
        bijector when ``self.tanh_squash_distribution`` is set.
    """
    outputs = MLP(self.hidden_dims,
                  activate_final=True,
                  dropout_rate=self.dropout_rate)(observations,
                                                  training=training)

    means = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)

    if self.state_dependent_std:
        log_stds = nn.Dense(self.action_dim,
                            kernel_init=default_init(
                                self.log_std_scale))(outputs)
    else:
        log_stds = self.param('log_stds', nn.initializers.zeros,
                              (self.action_dim, ))

    # Use explicit None checks: with `x or DEFAULT`, a configured bound of
    # 0.0 is falsy and would silently be replaced by the module default.
    log_std_min = (LOG_STD_MIN
                   if self.log_std_min is None else self.log_std_min)
    log_std_max = (LOG_STD_MAX
                   if self.log_std_max is None else self.log_std_max)
    log_stds = jnp.clip(log_stds, log_std_min, log_std_max)

    # Without the tanh bijector the means themselves are squashed.
    if not self.tanh_squash_distribution:
        means = nn.tanh(means)

    base_dist = tfd.MultivariateNormalDiag(loc=means,
                                           scale_diag=jnp.exp(log_stds) *
                                           temperature)
    if self.tanh_squash_distribution:
        return tfd.TransformedDistribution(distribution=base_dist,
                                           bijector=tfb.Tanh())
    return base_dist
@functools.partial(jax.jit, static_argnames=('actor_def', 'distribution'))
0
Source : jax_backend.py
with Apache License 2.0
from LinjianMa
with Apache License 2.0
from LinjianMa
def clip(tensor, a_min=None, a_max=None, inplace=False):
    """Limit the entries of ``tensor`` to the range ``[a_min, a_max]``.

    ``inplace`` is accepted but not used; the result of ``np.clip`` is
    returned.
    """
    lower, upper = a_min, a_max
    clipped = np.clip(tensor, lower, upper)
    return clipped
@staticmethod
0
Source : _pendulum.py
with Apache License 2.0
from MinRegret
with Apache License 2.0
from MinRegret
# Set up the pendulum environment: fixed sizes and time step, a Random(seed)
# instance, an initial reset(), and the dynamics / cost closures together with
# their JAX derivatives.
#   reward_fn: optional cost function; the local `c` below is used when falsy.
#   seed:      passed to Random(seed).
#   horizon:   stored as self.H.
def __init__(self, reward_fn=None, seed=0, horizon=50):
# self.reward_fn = reward_fn or default_reward_fn
self.dt = 0.05
self.viewer = None
self.state_size = 2
self.action_size = 1
self.action_dim = 1 # redundant with action_size but needed by ILQR
self.H = horizon
self.n, self.m = 2, 1
self.angle_normalize = angle_normalize
self.nsamples = 0
self.random = Random(seed)
# reset() is defined elsewhere in the class and runs before the closures below
# are created. NOTE(review): presumably it initializes the state -- confirm.
self.reset()
# @jax.jit
# One step of the dynamics: maps (state, action) to the next state of shape (2,).
def _dynamics(state, action):
# Python-level call counter; it is also incremented when _dynamics is invoked
# through the jacfwd wrappers created below.
self.nsamples += 1
th, thdot = state
g = 10.0
m = 1.0
ell = 1.0
dt = self.dt
# The control signal IS limited here: it is clipped to [-max_torque, max_torque].
# NOTE(review): max_torque / max_speed are not assigned in this method --
# presumably class attributes or set in reset(); confirm.
action = jnp.clip(action, -self.max_torque, self.max_torque)
newthdot = (
thdot + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 / (m * ell ** 2) * action) * dt
)
# The angle is advanced with the not-yet-clipped velocity; the velocity is
# clipped to [-max_speed, max_speed] only afterwards.
newth = th + newthdot * dt
newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)
return jnp.reshape(jnp.array([newth, newthdot]), (2,))
# Default cost: squared normalized angle plus 0.1 times the squared first
# action component.
@jax.jit
def c(x, u):
# return np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2))
return angle_normalize(x[0])**2 + .1*(u[0]**2)
self.reward_fn = reward_fn or c
self.dynamics = _dynamics
# Dynamics and its forward-mode Jacobians w.r.t. state (argnums=0) and
# action (argnums=1).
self.f, self.f_x, self.f_u = (
_dynamics,
jax.jacfwd(_dynamics, argnums=0),
jax.jacfwd(_dynamics, argnums=1),
)
# Cost, its gradients and Hessians w.r.t. state and action. These always use
# the local `c`, even when a custom reward_fn was supplied.
self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
c,
jax.grad(c, argnums=0),
jax.grad(c, argnums=1),
jax.hessian(c, argnums=0),
jax.hessian(c, argnums=1),
)
def reset(self):
0
Source : _balloon_lung.py
with Apache License 2.0
from MinRegret
with Apache License 2.0
from MinRegret
def dynamics(self, state, action):
    """Advance the lung model by one time step.

    Args:
        state: mapping with ``'volume'`` and ``'pressure'`` entries.
        action: pair ``(u_in, u_out)``.

    Returns:
        A dict with the updated ``'volume'`` and ``'pressure'``.
    """
    volume = state['volume']
    pressure = state['pressure']
    u_in, u_out = action

    def _zero(_):
        return 0.0

    def _outflow(_):
        return jnp.clip(Solenoid(u_out), 0.0, 2.0) * 0.05 * pressure

    # Inflow through the proportional valve, reduced by the outflow branch
    # whenever the pressure exceeds the PEEP valve threshold.
    flow = jnp.clip(PropValve(u_in) * self.R, 0.0, 2.0)
    flow -= jax.lax.cond(pressure > self.peep_valve, _outflow, _zero, flow)

    volume += flow * self.dt

    def _leak(_):
        # Reads ``volume`` after the flow update above.
        return self.dt / (5.0 + self.dt) * (self.min_volume - volume)

    volume += jax.lax.cond(self.leak, _leak, _zero, 0.0)

    # Radius of a sphere with the current volume, then the pressure formula.
    radius = (3.0 * volume / (4.0 * jnp.pi)) ** (1.0 / 3.0)
    pressure = self.P0 + self.PC * (1.0 - (self.r0 / radius) ** 6.0) / (
        self.r0 ** 2.0 * radius)
    # pressure = flow * self.R + volume / self.C + self.peep_valve
    return {'volume': volume, 'pressure': pressure}
def step(self, action):
0
Source : utils.py
with MIT License
from nikikilbertus
with MIT License
from nikikilbertus
def interp_regular_1d(x: np.ndarray,
                      xmin: float,
                      xmax: float,
                      yp: np.ndarray) -> np.ndarray:
    """Piecewise-linear interpolation on a regular 1-D grid.

    The values ``yp`` are taken to sit on ``len(yp)`` equally spaced grid
    points spanning ``[xmin, xmax]``.  Query points outside that range are
    clamped to it, so they evaluate to the first / last entry of ``yp``.

    Args:
        x: The x-coordinates at which to evaluate the interpolant.
        xmin: The lower bound of the regular input x-coordinate grid.
        xmax: The upper bound of the regular input x-coordinate grid.
        yp: The y coordinates of the data points.

    Returns:
        The interpolated values, same shape as ``x``.
    """
    last = len(yp) - 1

    # Fractional grid index of every query point, clamped to the grid.
    position = np.clip((x - xmin) / (xmax - xmin) * last, 0, last)

    # Neighbouring grid indices; the lower one is derived from the upper one
    # so that they stay distinct at the right edge whenever the grid allows.
    upper = np.minimum(np.floor(position) + 1, last)
    lower = np.maximum(upper - 1, 0)

    y_lower = yp[lower.astype(np.int32)]
    y_upper = yp[upper.astype(np.int32)]

    weight = position - lower
    return weight * y_upper + (1 - weight) * y_lower
# Batched, JIT-compiled variant of interp_regular_1d: vmap maps over axis 0 of
# the fourth argument (yp) while x, xmin and xmax are shared across the batch
# (in_axes=None for the first three arguments).
interp1d = jit(vmap(interp_regular_1d, in_axes=(None, None, None, 0)))
0
Source : text2mel.py
with MIT License
from NTT123
with MIT License
from NTT123
def text2mel(
    text: str, lexicon_fn=FLAGS.data_dir / "lexicon.txt", silence_duration: float = -1.0
):
    """Convert ``text`` to mel frames via predicted token durations.

    Durations of tokens equal to ``FLAGS.sp_index`` are raised to at least
    ``silence_duration``; tokens equal to ``FLAGS.word_end_index`` get a
    duration of zero.  When the final token is ``FLAGS.sp_index``, the frames
    corresponding to its duration are trimmed from the end of the output.
    """
    tokens = text2tokens(text, lexicon_fn)
    durations = predict_duration(tokens)
    token_row = np.array(tokens)[None, :]

    # Enforce the lower bound on silence tokens only.
    floored = jnp.clip(durations, silence_duration, None)
    durations = jnp.where(token_row == FLAGS.sp_index, floored, durations)
    durations = jnp.where(token_row == FLAGS.word_end_index, 0.0, durations)

    mels = predict_mel(tokens, durations)
    if tokens[-1] != FLAGS.sp_index:
        return mels

    # Drop the frames produced by the trailing silence token.
    end_silence = durations[0, -1].item()
    silence_frame = int(end_silence * FLAGS.sample_rate / (FLAGS.n_fft // 4))
    return mels[:, : (mels.shape[1] - silence_frame)]
if __name__ == "__main__":
0
Source : utils.py
with Apache License 2.0
from PredictiveIntelligenceLab
with Apache License 2.0
from PredictiveIntelligenceLab
def fit_kernel_density(X, xi, weights = None, bw=None):
    """Fit a kernel density estimate to ``X`` and evaluate it at ``xi``.

    Args:
        X: samples; flattened to one dimension before fitting.
        xi: points at which the estimated density is evaluated.
        weights: optional per-sample weights; ``None`` means unweighted.
        bw: bandwidth passed to ``FFTKDE``.  When ``None`` it is derived from
            ``gaussian_kde``'s covariance, falling back to 1.0 if that fails
            or yields a value below 1e-8.

    Returns:
        The density at ``xi``, clipped below at 0 and offset by 1e-8.
    """
    X = onp.array(X).flatten()
    # Only convert weights that were actually supplied: onp.array(None) is a
    # 0-d object array rather than None, which makes gaussian_kde raise and
    # silently forces the fallback bandwidth below.
    if weights is not None:
        weights = onp.array(weights)

    if bw is None:
        try:
            sc = gaussian_kde(X, weights=weights)
            bw = onp.sqrt(sc.covariance).flatten()[0]
        except Exception:
            # Best-effort estimate: keep the fallback, but do not swallow
            # KeyboardInterrupt / SystemExit as a bare ``except`` would.
            bw = 1.0
        if bw < 1e-8:
            bw = 1.0

    kde_pdf_x, kde_pdf_y = FFTKDE(bw=bw).fit(X, weights).evaluate()

    # Define the interpolation function
    interp1d_fun = interp1d(kde_pdf_x,
                            kde_pdf_y,
                            kind = 'linear',
                            fill_value = 'extrapolate')

    # Evaluate the density on the input data
    pdf = interp1d_fun(xi)
    return np.clip(pdf, a_min=0.0) + 1e-8
def init_NN(Q):
See More Examples