Here are examples of the Python API jax.numpy.square taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.
112 Examples
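Before the examples, a minimal sketch (not taken from any of the projects below) of what jnp.square itself does: it is the elementwise square, equivalent to x * x.

import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])
print(jnp.square(x))                       # [4.   0.25 9.  ]
print(jnp.allclose(jnp.square(x), x * x))  # True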
3
Source : mean_squared_error_test.py
with MIT License
from cgarciae
def test_function():
rng = jax.random.PRNGKey(42)
target = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
preds = jax.random.uniform(rng, shape=(2, 3))
loss = tx.losses.mean_squared_error(target, preds)
assert loss.shape == (2,)
assert jnp.array_equal(loss, jnp.mean(jnp.square(target - preds), axis=-1))
if __name__ == "__main__":
3
Source : mean_square_error.py
with MIT License
from cgarciae
def _mean_square_error(preds: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
"""Calculates values required to update/compute Mean Square Error. Cast preds to have the same type as target.
Args:
preds: Predicted tensor
target: Ground truth tensor
Returns:
jnp.ndarray values needed to update Mean Square Error
"""
target = target.astype(preds.dtype)
return jnp.square(preds - target)
class MeanSquareError(Mean):
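A quick usage sketch for the helper above (the sample values are made up): it returns the elementwise squared errors, which a running-mean metric can then average.

import jax.numpy as jnp

preds = jnp.array([[0.1, 0.9], [0.4, 0.6]])
target = jnp.array([[0.0, 1.0], [1.0, 0.0]])
values = _mean_square_error(preds, target)  # elementwise squared errors, shape (2, 2)
print(jnp.mean(values))                     # scalar MSE over the batch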
3
Source : clip_sample.py
with MIT License
from crowsonkb
def norm2(x):
"""Normalizes a batch of vectors to the unit sphere."""
return x / jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True))
def spherical_dist_loss(x, y):
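A quick check of norm2 above, assuming the snippet is in scope (sample values are made up): after normalization every row has unit squared norm.

import jax.numpy as jnp

v = jnp.array([[3.0, 4.0], [0.0, 2.0]])
u = norm2(v)
print(jnp.sum(jnp.square(u), axis=-1))  # [1. 1.]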
3
Source : helpers.py
with Apache License 2.0
from deepmind
def l2_normalize(
x: jnp.ndarray,
axis: Optional[int] = None,
epsilon: float = 1e-12,
) -> jnp.ndarray:
"""l2 normalize a tensor on an axis with numerical stability."""
square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
return x * x_inv_norm
def l2_weight_regularizer(params):
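A usage sketch for l2_normalize above, assuming the snippet is in scope. The jnp.maximum floor at epsilon is what keeps the all-zero input finite instead of producing NaN.

import jax.numpy as jnp

print(l2_normalize(jnp.array([3.0, 4.0])))  # [0.6 0.8]
print(l2_normalize(jnp.zeros(2)))           # [0. 0.] rather than [nan nan]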
3
Source : activation_transform.py
with Apache License 2.0
from deepmind
def _get_jax_activation_function(name):
"""Get activation function by name in JAX."""
if name == "bentid":
return lambda x: (jnp.sqrt(jnp.square(x) + 1.) - 1.) / 2. + x
elif name == "softsign":
return jax.nn.soft_sign
elif hasattr(jax.lax, name):
return getattr(jax.lax, name)
elif hasattr(jax.nn, name):
return getattr(jax.nn, name)
else:
raise ValueError(f"Unrecognized activation function name '{name}'.")
def get_transformed_activations(*args, **kwargs):
3
Source : problem.py
with Apache License 2.0
from deepmind
def make_vae_sdp_verif_instance(params, data_x, bounds):
"""Make SdpDualVerifInstance for VAE reconstruction error spec."""
elided_params = params[:-1]
elided_bounds = bounds[:-1]
dual_shapes, dual_types = get_dual_shapes_and_types(elided_bounds)
def recon_loss(x_final):
x_hat = utils.predict_cnn(params[-1:], x_final).reshape(1, -1)
return jnp.sum(jnp.square(data_x.reshape(x_hat.shape) - x_hat))
def make_inner_lagrangian(dual_vars):
return make_relu_network_lagrangian(
dual_vars, elided_params, elided_bounds, recon_loss)
return utils.SdpDualVerifInstance(
make_inner_lagrangian=make_inner_lagrangian,
bounds=elided_bounds,
dual_shapes=dual_shapes,
dual_types=dual_types)
def make_vae_semantic_spec_params(x, vae_params, classifier_params):
3
Source : r3.py
with Apache License 2.0
from dptech-corp
def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray:
"""Computes norm of vectors 'v'.
Args:
v: vectors to be normalized.
epsilon: small regularizer added to squared norm before taking square root.
Returns:
norm of 'v'
"""
return jnp.sqrt(jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z) + epsilon)
def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs:
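The epsilon sits inside the sqrt so the norm stays differentiable at zero-length vectors; a plain sqrt of a squared norm has NaN gradients there. A sketch with ordinary arrays in place of the Vecs struct:

import jax
import jax.numpy as jnp

naive = lambda v: jnp.sqrt(jnp.sum(jnp.square(v)))
robust = lambda v: jnp.sqrt(jnp.sum(jnp.square(v)) + 1e-8)
zero = jnp.zeros(3)
print(jax.grad(naive)(zero))   # [nan nan nan]
print(jax.grad(robust)(zero))  # [0. 0. 0.]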
3
Source : api_test.py
with Apache License 2.0
from google
def test_kwargs(self):
# from https://github.com/google/jax/issues/1938
@api.custom_jvp
def my_fun(x, y, c=1.):
return c * (x + y)
def my_jvp(primals, tangents):
x, y, c = primals
t_x, t_y, t_c = tangents
return my_fun(x, y, c), t_c
my_fun.defjvp(my_jvp)
f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
f(10., 5.) # doesn't crash
api.jvp(f, (10., 5.), (1., 1.)) # doesn't crash
def test_initial_style(self):
3
Source : api_test.py
with Apache License 2.0
from google
def test_kwargs(self):
# from https://github.com/google/jax/issues/1938
@api.custom_vjp
def my_fun(x, y, c=1.):
return c * (x + y)
my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None),
lambda _, g: (g, g, g))
f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
f(10., 5.) # doesn't crash
api.grad(f)(10., 5.) # doesn't crash
def test_initial_style(self):
3
Source : name_stack_test.py
with Apache License 2.0
from google
def test_jvp_should_transform_stacks(self):
def f(x):
with extend_name_stack('bar'):
with extend_name_stack('baz'):
return jnp.square(x)
g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack),
'foo/jvp(bar)/jvp(baz)')
def test_jvp_should_apply_to_call_jaxpr(self):
3
Source : name_stack_test.py
with Apache License 2.0
from google
def test_jvp_should_apply_to_call_jaxpr(self):
@jax.jit
def f(x):
with extend_name_stack('bar'):
with extend_name_stack('baz'):
return jnp.square(x)
g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr
self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
self.assertEqual(
str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack),
'bar/baz')
hlo_text = _get_hlo(g)(1., 1.)
self.assertIn('foo/jvp(jit(f))/jvp(bar)', hlo_text)
def test_grad_should_add_jvp_and_transpose_to_name_stack(self):
3
Source : rnn_mlp_lopt.py
with Apache License 2.0
from google
def next_state(self, state: _LossNormalizerState,
loss: jnp.ndarray) -> _LossNormalizerState:
new_mean = self.decay * state.mean + (1.0 - self.decay) * loss
new_var = self.decay * state.var + (
1.0 - self.decay) * jnp.square(new_mean - loss)
new_updates = state.updates + 1
return _LossNormalizerState(mean=new_mean, var=new_var, updates=new_updates)
def weight_loss(self, state: _LossNormalizerState,
3
Source : rnn_mlp_lopt.py
with Apache License 2.0
from google
def _avg_square_mean(tree: Any) -> jnp.ndarray:
return sum([jnp.mean(jnp.square(x)) for x in jax.tree_leaves(tree)]) / len(
jax.tree_leaves(tree))
def _clip_log_abs(value: jnp.ndarray) -> jnp.ndarray:
3
Source : image_mlp.py
with Apache License 2.0
from google
def loss(self, params: Params, key: PRNGKey, data: Any) -> jnp.ndarray:
num_classes = self.datasets.extra_info["num_classes"]
logits = self._mod.apply(params, key, data["image"])
labels = jax.nn.one_hot(data["label"], num_classes)
return jnp.mean(jnp.square(logits - labels))
def normalizer(self, loss):
3
Source : quadratics.py
with Apache License 2.0
from google
def loss(self, params, rng, _):
a = params["a"]
b = params["b"]
return jnp.sum(jnp.square(a + b))
def init(self, key):
3
Source : quadratics.py
with Apache License 2.0
from google
def task_fn(self, task_params: TaskParams) -> base.Task:
dim = self._dim
class _Task(base.Task):
def loss(self, params, rng, _):
return jnp.sum(jnp.square(task_params - params))
def init(self, key) -> Params:
return jax.random.normal(key, shape=(dim,))
return _Task()
@datasets_base.dataset_lru_cache
3
Source : tree_utils.py
with Apache License 2.0
from google
def tree_norm(val):
sum_squared = sum(map(lambda x: jnp.sum(jnp.square(x)), jax.tree_leaves(val)))
return jnp.sqrt(sum_squared)
@jax.jit
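A usage sketch for tree_norm above, assuming the snippet runs in its original environment: it computes the global l2 norm over every leaf of a pytree (the example values are made up).

import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.arange(3.0)}
print(tree_norm(params))  # sqrt(4 + 5) = 3.0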
3
Source : prompt.py
with Apache License 2.0
from google-research
def l2_normalize(x, axis=None, epsilon=1e-12):
"""l2 normalizes a tensor on an axis with numerical stability."""
square_sum = jnp.sum(jnp.square(x), axis=axis, keepdims=True)
x_inv_norm = jax.lax.rsqrt(jnp.maximum(square_sum, epsilon))
return x * x_inv_norm
def prefix_prompt(prompt: Array, x_embed: Array) -> Array:
3
Source : wayward.py
with Apache License 2.0
from google-research
def squared_l2(x: Array) -> float:
"""Calculate the squared l2 norms of a sequence of arrays.
Note:
    We use the squared l2 norm because things like the ranking will be the same
    without needing to do the expensive sqrt.
Args:
x: The sequence of arrays to calculate the norm of. [T, H]
Returns:
The norm over the hidden dimension of the sequence of arrays. [T]
"""
return jnp.sum(jnp.square(x), axis=1)
def squared_l2_distance(x: Array, y: Array) -> Array:
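Because sqrt is monotonic, sorting by squared norms gives the same order as sorting by true norms; a quick check (sample values made up):

import jax.numpy as jnp

x = jnp.array([[1.0, 2.0], [0.5, 0.5], [3.0, 0.0]])
sq = jnp.sum(jnp.square(x), axis=1)
print(jnp.argsort(sq))            # [1 0 2]
print(jnp.argsort(jnp.sqrt(sq)))  # [1 0 2], the same ranking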
3
Source : trainer.py
with Apache License 2.0
from google-research
def _make_rms_metrics(name, tree):
"""Calculates the root-mean-square metric for a pytree."""
return {
f"{name}/{k}": metrics_lib.AveragePerStep.from_model_output(
jnp.sqrt(jnp.mean(jnp.square(v))))
for k, v in utils.flatten_dict_string_keys(tree).items()
}
@staticmethod
3
Source : ffjord_mnist.py
with MIT License
from jacobjinkelly
def _weight_fn(params):
flat_params, _ = ravel_pytree(params)
return 0.5 * jnp.sum(jnp.square(flat_params))
def loss_fn(forward, params, images, key):
3
Source : latent_ode.py
with MIT License
from jacobjinkelly
def _weight_fn(params):
flat_params, _ = ravel_pytree(params)
return 0.5 * jnp.sum(jnp.square(flat_params))
def loss_fn(forward, params, batch, kl_coef):
3
Source : ode.py
with MIT License
from jacobjinkelly
def error_ratio_tol(error_estimate, error_tolerance):
err_ratio = error_estimate / error_tolerance
# return np.square(np.max(np.abs(err_ratio))) # (square since optimal_step_size expects squared norm)
return np.mean(np.square(err_ratio))
def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
3
Source : mnist.py
with MIT License
from jacobjinkelly
def _weight_fn(params):
flat_params, _ = ravel_pytree(params)
return 0.5 * jnp.sum(jnp.square(flat_params))
def loss_fn(forward, params, images, labels, key):
3
Source : nlds_smoother.py
with Apache License 2.0
from Joshuaalbert
def clip_covariance_diag(self, cov, lo, hi):
"""
Clips the standard-deviation on the diagonal of cov.
Args:
cov: [B, M, M]
lo: float, standard-dev low value
hi: float, standard-dev high value
Returns:
    [B, M, M] covariance with clipped standard devs.
"""
variance = batched_diag(cov)
clipped_variance = jnp.clip(variance, jnp.square(lo), jnp.square(hi))
add_amount = clipped_variance - variance
return cov + batched_diag(add_amount)
def __call__(self, Y, Sigma, mu0, Gamma0, Omega, *control_params, maxiter=None, tol=1e-5, momentum=0.,
3
Source : dqn.py
with MIT License
from ku2482
def _calculate_loss_and_abs_td(
self,
q: jnp.ndarray,
target: jnp.ndarray,
weight: np.ndarray,
) -> jnp.ndarray:
td = target - q
if self.loss_type == "l2":
loss = jnp.mean(jnp.square(td) * weight)
elif self.loss_type == "huber":
loss = jnp.mean(huber(td) * weight)
return loss, jax.lax.stop_gradient(jnp.abs(td))
@partial(jax.jit, static_argnums=0)
3
Source : ppo.py
with MIT License
from ku2482
def _loss_critic(
self,
params_critic: hk.Params,
state: np.ndarray,
target: np.ndarray,
) -> jnp.ndarray:
return jnp.square(target - self.critic.apply(params_critic, state)).mean(), None
@partial(jax.jit, static_argnums=0)
3
Source : distribution.py
with MIT License
from ku2482
def gaussian_log_prob(
log_std: jnp.ndarray,
noise: jnp.ndarray,
) -> jnp.ndarray:
"""
Calculate log probabilities of gaussian distributions.
"""
return -0.5 * (jnp.square(noise) + 2 * log_std + jnp.log(2 * math.pi))
@jax.jit
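A sanity check for gaussian_log_prob above, assuming the snippet (and its math import) is in scope: with noise taken as the standardized residual (x - mean) / std, it should agree with jax.scipy.stats.norm.logpdf.

import jax.numpy as jnp
from jax.scipy.stats import norm

mean, std = 1.0, 0.5
noise = jnp.array([-1.0, 0.0, 2.0])
x = mean + std * noise
print(jnp.allclose(gaussian_log_prob(jnp.log(std), noise),
                   norm.logpdf(x, loc=mean, scale=std)))  # True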
3
Source : distribution.py
with MIT License
from ku2482
def gaussian_and_tanh_log_prob(
log_std: jnp.ndarray,
noise: jnp.ndarray,
action: jnp.ndarray,
) -> jnp.ndarray:
"""
Calculate log probabilities of gaussian distributions and tanh transformation.
"""
return gaussian_log_prob(log_std, noise) - jnp.log(nn.relu(1.0 - jnp.square(action)) + 1e-6)
@jax.jit
3
Source : distribution.py
with MIT License
from ku2482
def calculate_kl_divergence(
p_mean: np.ndarray,
p_std: np.ndarray,
q_mean: np.ndarray,
q_std: np.ndarray,
) -> jnp.ndarray:
"""
Calculate KL Divergence between gaussian distributions.
"""
var_ratio = jnp.square(p_std / q_std)
t1 = jnp.square((p_mean - q_mean) / q_std)
return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))
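A sanity check for calculate_kl_divergence above, assuming the snippet is in scope: the KL divergence of a distribution from itself is zero, and the general case matches the closed form for two Gaussians.

print(calculate_kl_divergence(0.0, 1.0, 0.0, 1.0))  # 0.0
print(calculate_kl_divergence(0.0, 1.0, 1.0, 2.0))  # 0.5*(0.25 + 0.25 - 1 - log(0.25)) ≈ 0.4431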
3
Source : loss.py
with MIT License
from ku2482
def quantile_loss(
td: jnp.ndarray,
cum_p: jnp.ndarray,
weight: jnp.ndarray,
loss_type: str,
) -> jnp.ndarray:
"""
Calculate quantile loss.
"""
if loss_type == "l2":
element_wise_loss = jnp.square(td)
elif loss_type == "huber":
element_wise_loss = huber(td)
else:
        raise NotImplementedError
element_wise_loss *= jax.lax.stop_gradient(jnp.abs(cum_p[..., None] - (td < 0)))
batch_loss = element_wise_loss.sum(axis=1).mean(axis=1, keepdims=True)
return (batch_loss * weight).mean()
3
Source : basics.py
with MIT License
from NTT123
def loss_fn(model: Linear, x: jnp.ndarray, y: jnp.ndarray):
model, y_hat = pax.purecall(model, x)
loss = jnp.mean(jnp.square(y_hat - y))
return loss, model
@jax.jit
3
Source : lazy_module.py
with MIT License
from NTT123
def loss_fn(model, x: jnp.ndarray, y: jnp.ndarray):
model, y_hat = forward(model, x)
loss = jnp.mean(jnp.square(y_hat - y))
return loss, model
@jax.jit
3
Source : test_utils.py
with MIT License
from NTT123
def test_util_update_fn():
def loss_fn(model: pax.Linear, x, target):
y = model(x)
loss = jnp.mean(jnp.square(y - target))
return loss, (loss, model)
net = pax.Linear(2, 1)
opt = opax.adamw(learning_rate=1e-1)(net.parameters())
update_fn = jax.jit(pax.utils.build_update_fn(loss_fn, scan_mode=True))
x = np.random.normal(size=(32, 2))
y = np.random.normal(size=(32, 1))
print()
for step in range(3):
(net, opt), loss = update_fn((net, opt), x, y)
print(f"step {step} loss {loss:.3f}")
def test_Rng_Seq():
3
Source : numpy_ops.py
with GNU General Public License v3.0
from PKU-NIP-Lab
def square(x):
x = _remove_jaxarray(x)
return JaxArray(jnp.square(x))
def fabs(x):
3
Source : sparse_regression.py
with Apache License 2.0
from pyro-ppl
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
eta1sq = jnp.square(eta1)
eta2sq = jnp.square(eta2)
k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(X, Z))
k2 = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))
k3 = (eta1sq - eta2sq) * dot(X, Z)
k4 = jnp.square(c) - 0.5 * eta2sq
if X.shape == Z.shape:
k4 += jitter * jnp.eye(X.shape[0])
return k1 + k2 + k3 + k4
# Most of the model code is concerned with constructing the sparsity inducing prior.
def model(X, Y, hypers):
3
Source : kl.py
with Apache License 2.0
from pyro-ppl
def kl_divergence(p, q):
var_ratio = jnp.square(p.scale / q.scale)
t1 = jnp.square((p.loc - q.loc) / q.scale)
return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))
@dispatch(Beta, Beta)
3
Source : autoguide.py
with Apache License 2.0
from pyro-ppl
def quantiles(self, params, quantiles):
loc = params[f"{self.prefix}_loc"]
cov_factor = params[f"{self.prefix}_cov_factor"]
scale = params[f"{self.prefix}_scale"]
scale = scale * jnp.sqrt(jnp.square(cov_factor).sum(-1) + 1)
quantiles = jnp.array(quantiles)[..., None]
latent = dist.Normal(loc, scale).icdf(quantiles)
return self._unpack_and_constrain(latent, params)
class AutoLaplaceApproximation(AutoContinuous):
3
Source : toy_examples.py
with MIT License
from SamDuffield
def likelihood_potential(self,
x: jnp.ndarray,
random_key: jnp.ndarray = None) -> Union[float, jnp.ndarray]:
x_diff = self.precision_mul(x - self.mean, self.precision_sqrt.T)
return 0.5 * jnp.square(x_diff).sum(axis=-1)
def __setattr__(self,
3
Source : toy_examples.py
with MIT License
from SamDuffield
def component_potential(self,
x: jnp.ndarray,
component_index: int) -> Union[float, jnp.ndarray]:
return 0.5 * jnp.sum(jnp.square((x - self.means[component_index]) @
self.precision_sqrts[component_index].T), axis=-1) \
- jnp.log(self.weights[component_index]
* self.precision_dets[component_index]
/ jnp.power(2 * jnp.pi, self.dim * 0.5))
def component_dens(self,
3
Source : haiku_run.py
with Apache License 2.0
from Scitator
def _loss_fn(
self, params: hk.Params, state: hk.State, rng: jnp.ndarray, batch: Batch
) -> Tuple[jnp.ndarray, hk.State]:
logits, state = self.forward.apply(params, state, rng, batch)
labels = jax.nn.one_hot(batch["labels"], 10)
l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
softmax_xent /= labels.shape[0]
return (softmax_xent + 1e-4 * l2_loss, state)
def _eval_batch(
3
Source : normalizations.py
with Apache License 2.0
from tensorflow
def fprop(self, inputs: JTensor) -> JTensor:
"""Apply RMS norm to inputs.
Args:
inputs: The inputs JTensor. Shaped [..., input_dims].
Returns:
RMS normalized input.
"""
theta = self.local_theta()
var = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
normed_inputs = inputs * jax.lax.rsqrt(var + self.params.epsilon)
scale = theta.scale if self.params.direct_scale else 1 + theta.scale
normed_inputs *= scale
return normed_inputs
class GroupNorm(base_layer.BaseLayer):
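Framework plumbing aside, the core of the RMS norm above fits in a few lines; a standalone sketch (the function name and scale value here are made up):

import jax
import jax.numpy as jnp

def rms_norm(x, scale, epsilon=1e-6):
    # Scale each row to unit root-mean-square, then apply the learned scale.
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + epsilon) * scale

x = jnp.array([[1.0, 2.0, 2.0]])   # rms = sqrt((1 + 4 + 4) / 3) = sqrt(3)
print(rms_norm(x, scale=1.0))      # [[0.577 1.155 1.155]], unit rms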
3
Source : test_layers.py
with Apache License 2.0
from tensorflow
def compute_loss(self, predictions, input_batch):
targets = input_batch.targets
error = predictions - targets
loss = jnp.mean(jnp.square(error))
per_example_out = NestedMap(predictions=predictions)
return NestedMap(
loss=(loss, jnp.array(1.0, loss.dtype))), per_example_out
class TestBatchNormalizationModel(base_model.BaseModel):
3
Source : test_layers.py
with Apache License 2.0
from tensorflow
def compute_loss(self, predictions: JTensor,
input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:
targets = input_batch.targets
error = predictions - targets
loss = jnp.mean(jnp.square(error))
per_example_out = NestedMap(predictions=predictions)
return NestedMap(
loss=(loss, jnp.array(1.0, loss.dtype))), per_example_out
class TestSpmdModel(base_model.BaseModel):
3
Source : optimizers.py
with Apache License 2.0
from tensorflow
def reduce_rms(array: JTensor) -> JTensor:
"""Computes the RMS of `array` (in a numerically stable way).
Args:
array: Input array.
Returns:
The root mean square of the input array as a scalar array.
"""
sq = jnp.square(array)
sq_mean = reduce_mean(sq)
return jnp.sqrt(sq_mean)
@dataclasses.dataclass(frozen=True)
3
Source : utils.py
with MIT License
from vballoli
def global_norm(updates):
"""Returns the l2 norm of the input.
Args:
updates: A pytree of ndarrays representing the gradient.
"""
return jnp.sqrt(
sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(updates)]))
def clip_by_global_norm(updates):
3
Source : folding_multimer.py
with MIT License
from Zuricho
def squared_difference(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Computes Squared difference between two arrays."""
return jnp.square(x - y)
def make_backbone_affine(
0
Source : losses.py
with MIT License
from andreArtelt
def l2(x, x_orig):
return npx.sum(npx.square(x - x_orig))
def lmad(x, x_orig, mad):
0
Source : layer.py
with MIT License
from andreArtelt
def normal_distribution(x, mean, variance):
return npx.exp(-.5 * npx.square(x - mean) / variance) / npx.sqrt(2. * npx.pi * variance)
def log_normal_distribution(x, mean, variance):
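A quick check of the density above against jax.scipy.stats, writing jnp in place of the snippet's npx alias (the inputs are made up):

import jax.numpy as jnp
from jax.scipy.stats import norm

x, mean, variance = 0.5, 0.0, 4.0
ours = jnp.exp(-.5 * jnp.square(x - mean) / variance) / jnp.sqrt(2. * jnp.pi * variance)
print(jnp.allclose(ours, norm.pdf(x, loc=mean, scale=jnp.sqrt(variance))))  # True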
0
Source : layer.py
with MIT License
from andreArtelt
def log_normal_distribution(x, mean, variance):
    return -.5 * npx.square(x - mean) / variance - .5 * npx.log(2. * npx.pi * variance)
def log_multivariate_normal(x, mean, sigma_inv, k):