jax.numpy.square

Here are examples of the Python API `jax.numpy.square` taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.

112 examples

3 Source : mean_squared_error_test.py
with MIT License
from cgarciae

def test_function():
    """Check tx.losses.mean_squared_error against a hand-written reference.

    Draws integer targets in {0, 1} and uniform predictions, both of shape
    (2, 3), then asserts the loss is reduced over the last axis to shape (2,)
    and is exactly equal to mean((target - preds) ** 2, axis=-1).
    """
    key = jax.random.PRNGKey(42)

    # The same key feeds both draws, as in the original test.
    target = jax.random.randint(key, shape=(2, 3), minval=0, maxval=2)
    preds = jax.random.uniform(key, shape=(2, 3))

    computed = tx.losses.mean_squared_error(target, preds)
    assert computed.shape == (2,)

    reference = jnp.mean(jnp.square(target - preds), axis=-1)
    assert jnp.array_equal(computed, reference)


if __name__ == "__main__":

3 Source : mean_square_error.py
with MIT License
from cgarciae

def _mean_square_error(preds: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    """Calculates values required to update/compute Mean Square Error. Cast preds to have the same type as target.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor

    Returns:
        jnp.ndarray values needed to update Mean Square Error
    """

    target = target.astype(preds.dtype)
    return jnp.square(preds - target)


class MeanSquareError(Mean):

3 Source : clip_sample.py
with MIT License
from crowsonkb

def norm2(x):
    """Scale each vector along the last axis to unit Euclidean length.

    No epsilon is applied, so an all-zero vector produces NaNs (0 / 0).
    """
    length = jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True))
    return x / length


def spherical_dist_loss(x, y):

3 Source : helpers.py
with Apache License 2.0
from deepmind

def l2_normalize(
    x: jnp.ndarray,
    axis: Optional[int] = None,
    epsilon: float = 1e-12,
) -> jnp.ndarray:
  """Divide `x` by its L2 norm over `axis`, guarding against tiny norms.

  The squared norm is clamped from below at `epsilon` before the reciprocal
  square root, so an all-zero input maps to zeros rather than NaNs. With
  axis=None the norm is taken over every element of `x`.
  """
  clamped_sq_norm = jnp.maximum(
      jnp.sum(jnp.square(x), axis=axis, keepdims=True), epsilon)
  return x * jax.lax.rsqrt(clamped_sq_norm)


def l2_weight_regularizer(params):

3 Source : activation_transform.py
with Apache License 2.0
from deepmind

def _get_jax_activation_function(name):
  """Get activation function by name in JAX."""

  if name == "bentid":
    return lambda x: (jnp.sqrt(jnp.square(x) + 1.) - 1.) / 2. + x
  elif name == "softsign":
    return jax.nn.soft_sign
  elif hasattr(jax.lax, name):
    return getattr(jax.lax, name)
  elif hasattr(jax.nn, name):
    return getattr(jax.nn, name)
  else:
    raise ValueError(f"Unrecognized activation function name '{name}'.")


def get_transformed_activations(*args, **kwargs):

3 Source : problem.py
with Apache License 2.0
from deepmind

def make_vae_sdp_verif_instance(params, data_x, bounds):
  """Make SdpDualVerifInstance for VAE reconstruction error spec.

  The last layer (`params[-1:]`) is not part of the verified network. Instead
  it is applied inside the objective, which is the summed squared difference
  between `data_x` and the output of `utils.predict_cnn` on that layer. The
  remaining layers and their bounds (everything but the last entry) form the
  ReLU network whose Lagrangian is built.
  """
  elided_params, elided_bounds = params[:-1], bounds[:-1]
  dual_shapes, dual_types = get_dual_shapes_and_types(elided_bounds)

  def recon_loss(x_final):
    reconstruction = utils.predict_cnn(params[-1:], x_final).reshape(1, -1)
    residual = data_x.reshape(reconstruction.shape) - reconstruction
    return jnp.sum(jnp.square(residual))

  def make_inner_lagrangian(dual_vars):
    return make_relu_network_lagrangian(
        dual_vars, elided_params, elided_bounds, recon_loss)

  return utils.SdpDualVerifInstance(
      make_inner_lagrangian=make_inner_lagrangian,
      bounds=elided_bounds,
      dual_shapes=dual_shapes,
      dual_types=dual_types)


def make_vae_semantic_spec_params(x, vae_params, classifier_params):

3 Source : r3.py
with Apache License 2.0
from dptech-corp

def vecs_robust_norm(v: Vecs, epsilon: float = 1e-8) -> jnp.ndarray:
  """Computes the regularized Euclidean norm of vectors 'v'.

  Returns sqrt(v.x**2 + v.y**2 + v.z**2 + epsilon) element-wise, so the
  result is never below sqrt(epsilon). Presumably this keeps the sqrt (and
  its gradient) finite at zero length; confirm.

  Args:
    v: vectors whose norm is taken; only the x, y and z components are read.
    epsilon: small regularizer added to the squared norm before the sqrt.
  Returns:
    norm of 'v', in the broadcast shape of its components.
  """
  squared_norm = jnp.square(v.x) + jnp.square(v.y) + jnp.square(v.z)
  return jnp.sqrt(squared_norm + epsilon)


def vecs_sub(v1: Vecs, v2: Vecs) -> Vecs:

3 Source : api_test.py
with Apache License 2.0
from google

  def test_kwargs(self):
    """Regression test: calling a custom_jvp function with a keyword argument.

    `my_fun` is invoked with `c` passed by keyword. The test checks only that
    a plain call and `api.jvp` run without raising; no values are asserted.
    """
    # from https://github.com/google/jax/issues/1938
    @api.custom_jvp
    def my_fun(x, y, c=1.):
      return c * (x + y)
    def my_jvp(primals, tangents):
      # The keyword argument `c` arrives here as a third primal/tangent.
      x, y, c = primals
      t_x, t_y, t_c = tangents
      # NOTE(review): the returned tangent (t_c) is not the true derivative of
      # c * (x + y). That is fine for a "doesn't crash" test.
      return my_fun(x, y, c), t_c
    my_fun.defjvp(my_jvp)
    f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
    f(10., 5.)  # doesn't crash
    api.jvp(f, (10., 5.), (1., 1.))  # doesn't crash

  def test_initial_style(self):

3 Source : api_test.py
with Apache License 2.0
from google

  def test_kwargs(self):
    """Regression test: calling a custom_vjp function with a keyword argument.

    The test checks only that a plain call and `api.grad` of a function using
    `my_fun(..., c=2.)` run without raising; no values are asserted.
    """
    # from https://github.com/google/jax/issues/1938
    @api.custom_vjp
    def my_fun(x, y, c=1.):
      return c * (x + y)
    # defvjp(fwd, bwd): fwd returns (output, None); bwd returns one cotangent
    # (g) for each of x, y and c.
    # NOTE(review): fwd calls my_fun(c, y, c) rather than my_fun(x, y, c).
    # That is harmless for a "doesn't crash" test but looks like a typo.
    my_fun.defvjp(lambda x, y, c=1.: (my_fun(c, y, c), None),
                  lambda _, g: (g, g, g))
    f = lambda x, y: jnp.square(my_fun(x, y, c=2.)).sum()
    f(10., 5.)  # doesn't crash
    api.grad(f)(10., 5.)  # doesn't crash

  def test_initial_style(self):

3 Source : name_stack_test.py
with Apache License 2.0
from google

  def test_jvp_should_transform_stacks(self):
    """Under jax.jvp, name-stack scopes opened inside `f` render as jvp(...).

    The first equation of the traced jaxpr must carry the name stack
    'foo/jvp(bar)/jvp(baz)'. The outer 'foo' scope, opened outside the
    differentiated function, is left as-is.
    """
    def f(x):
      with extend_name_stack('bar'):
        with extend_name_stack('baz'):
          return jnp.square(x)
    # 'foo' wraps the whole jvp call, so it sits outside the transformed scopes.
    g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
    jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr
    self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack),
                     'foo/jvp(bar)/jvp(baz)')

  def test_jvp_should_apply_to_call_jaxpr(self):

3 Source : name_stack_test.py
with Apache License 2.0
from google

  def test_jvp_should_apply_to_call_jaxpr(self):
    """Under jax.jvp of a jitted function, inner scopes stay in the call jaxpr.

    At the outer level, the first equation's name stack is just 'foo'. The
    'bar/baz' scopes appear on the first equation of the nested
    `call_jaxpr`. The lowered HLO text must contain 'foo/jvp(jit(f))/jvp(bar)'.
    """
    # NOTE: the name `f` is load-bearing; it appears as jit(f) in the HLO text
    # asserted below.
    @jax.jit
    def f(x):
      with extend_name_stack('bar'):
        with extend_name_stack('baz'):
          return jnp.square(x)
    g = extend_name_stack('foo')(lambda x, t: jax.jvp(f, (x,), (t,)))
    jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr
    self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo')
    self.assertEqual(
        str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack),
        'bar/baz')

    hlo_text = _get_hlo(g)(1., 1.)
    self.assertIn('foo/jvp(jit(f))/jvp(bar)', hlo_text)

  def test_grad_should_add_jvp_and_transpose_to_name_stack(self):

3 Source : rnn_mlp_lopt.py
with Apache License 2.0
from google

  def next_state(self, state: _LossNormalizerState,
                 loss: jnp.ndarray) -> _LossNormalizerState:
    """Fold one observed loss into the running normalizer statistics.

    Both the mean and the variance are exponential moving averages with
    weight `self.decay` on the previous value. The variance term uses the
    deviation of `loss` from the *updated* mean. `updates` counts calls.
    """
    keep = self.decay
    blend = 1.0 - keep
    mean = keep * state.mean + blend * loss
    var = keep * state.var + blend * jnp.square(mean - loss)
    return _LossNormalizerState(
        mean=mean, var=var, updates=state.updates + 1)

  def weight_loss(self, state: _LossNormalizerState,

3 Source : rnn_mlp_lopt.py
with Apache License 2.0
from google

def _avg_square_mean(tree: Any) -> jnp.ndarray:
  return sum([jnp.mean(jnp.square(x)) for x in jax.tree_leaves(tree)]) / len(
      jax.tree_leaves(tree))


def _clip_log_abs(value: jnp.ndarray) -> jnp.ndarray:

3 Source : image_mlp.py
with Apache License 2.0
from google

  def loss(self, params: Params, key: PRNGKey, data: Any) -> jnp.ndarray:
    """Mean squared error between the model output and one-hot labels.

    The model is applied to `data["image"]`. Its output is compared with
    `data["label"]` one-hot encoded to
    `self.datasets.extra_info["num_classes"]` classes.
    """
    n_classes = self.datasets.extra_info["num_classes"]
    predictions = self._mod.apply(params, key, data["image"])
    targets = jax.nn.one_hot(data["label"], n_classes)
    residual = predictions - targets
    return jnp.mean(jnp.square(residual))

  def normalizer(self, loss):

3 Source : quadratics.py
with Apache License 2.0
from google

  def loss(self, params, rng, _):
    """Return sum((params["a"] + params["b"]) ** 2); rng and the batch are unused."""
    combined = params["a"] + params["b"]
    return jnp.square(combined).sum()

  def init(self, key):

3 Source : quadratics.py
with Apache License 2.0
from google

  def task_fn(self, task_params: TaskParams) -> base.Task:
    """Build a quadratic task whose minimum is at `task_params`.

    The returned task's loss is sum((task_params - params) ** 2). Its `init`
    draws standard-normal parameters of shape (self._dim,).
    """
    num_dims = self._dim

    class _Task(base.Task):

      def loss(self, params, rng, _):
        offset = task_params - params
        return jnp.sum(jnp.square(offset))

      def init(self, key) -> Params:
        return jax.random.normal(key, shape=(num_dims,))

    return _Task()


@datasets_base.dataset_lru_cache

3 Source : tree_utils.py
with Apache License 2.0
from google

def tree_norm(val):
  """Global L2 norm over all leaves of pytree `val`.

  Returns sqrt(sum over leaves of sum(leaf ** 2)). A pytree with no leaves
  gives 0.0.
  """
  # jax.tree_util.tree_leaves replaces the deprecated top-level alias
  # jax.tree_leaves (removed in recent JAX releases).
  leaves = jax.tree_util.tree_leaves(val)
  sum_squared = sum(jnp.sum(jnp.square(x)) for x in leaves)
  return jnp.sqrt(sum_squared)


@jax.jit

3 Source : prompt.py
with Apache License 2.0
from google-research

def l2_normalize(x, axis=None, epsilon=1e-12):
  """Scale `x` to unit L2 norm along `axis` (all elements when axis=None).

  The squared norm is floored at `epsilon` before the reciprocal square
  root, so a zero input returns zeros instead of NaNs.
  """
  sq_norm = jnp.square(x).sum(axis=axis, keepdims=True)
  inv_norm = jax.lax.rsqrt(jnp.maximum(sq_norm, epsilon))
  return x * inv_norm


def prefix_prompt(prompt: Array, x_embed: Array) -> Array:

3 Source : wayward.py
with Apache License 2.0
from google-research

def squared_l2(x: Array) -> Array:
  """Calculate the squared l2 norms of a sequence of arrays.

  Note:
    We use the squared l2 norm because things like the ranking will be the
    same without needing to do the expensive sqrt.

  Args:
    x: The sequence of arrays to calculate the norm of. [T, H]

  Returns:
    An array of squared norms over the hidden dimension (axis 1), one per
    element of the sequence. [T]
  """
  # Return annotation corrected from `float`: the result is a length-T array.
  return jnp.sum(jnp.square(x), axis=1)


def squared_l2_distance(x: Array, y: Array) -> Array:

3 Source : trainer.py
with Apache License 2.0
from google-research

  def _make_rms_metrics(name, tree):
    """Build one root-mean-square metric per leaf of a nested dict.

    Keys are f"{name}/{flattened_key}". Each value is
    metrics_lib.AveragePerStep.from_model_output(sqrt(mean(leaf ** 2))).
    """
    metrics = {}
    for key, value in utils.flatten_dict_string_keys(tree).items():
      rms = jnp.sqrt(jnp.mean(jnp.square(value)))
      metrics[f"{name}/{key}"] = (
          metrics_lib.AveragePerStep.from_model_output(rms))
    return metrics

  @staticmethod

3 Source : ffjord_mnist.py
with MIT License
from jacobjinkelly

def _weight_fn(params):
    flat_params, _ = ravel_pytree(params)
    return 0.5 * jnp.sum(jnp.square(flat_params))


def loss_fn(forward, params, images, key):

3 Source : latent_ode.py
with MIT License
from jacobjinkelly

def _weight_fn(params):
    flat_params, _ = ravel_pytree(params)
    return 0.5 * jnp.sum(jnp.square(flat_params))


def loss_fn(forward, params, batch, kl_coef):

3 Source : ode.py
with MIT License
from jacobjinkelly

def error_ratio_tol(error_estimate, error_tolerance):
  """Mean of the squared element-wise ratio error_estimate / error_tolerance.

  The value is left squared because optimal_step_size expects a squared
  norm. An infinity-norm variant, square(max(abs(ratio))), was considered
  upstream and left disabled.
  """
  ratio = error_estimate / error_tolerance
  squared = np.square(ratio)
  return np.mean(squared)

def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,

3 Source : mnist.py
with MIT License
from jacobjinkelly

def _weight_fn(params):
    flat_params, _ = ravel_pytree(params)
    return 0.5 * jnp.sum(jnp.square(flat_params))


def loss_fn(forward, params, images, labels, key):

3 Source : nlds_smoother.py
with Apache License 2.0
from Joshuaalbert

    def clip_covariance_diag(self, cov, lo, hi):
        """
        Clip the standard deviations encoded on the diagonal of cov.

        The diagonal variances are clipped to [lo**2, hi**2]. The difference
        is then added back onto the diagonal, so off-diagonal entries are
        untouched.

        Args:
            cov: [B, M, M]
            lo: float, lower bound on the standard deviation
            hi: float, upper bound on the standard deviation

        Returns:
            [B, M, M] covariance with clipped standard deviations.
        """
        # NOTE(review): batched_diag appears to both extract a [B, M] diagonal
        # and build a [B, M, M] diagonal matrix, depending on input rank.
        # Confirm against its definition.
        var = batched_diag(cov)
        var_lo, var_hi = jnp.square(lo), jnp.square(hi)
        correction = jnp.clip(var, var_lo, var_hi) - var
        return cov + batched_diag(correction)

    def __call__(self, Y, Sigma, mu0, Gamma0, Omega, *control_params, maxiter=None, tol=1e-5, momentum=0.,

3 Source : dqn.py
with MIT License
from ku2482

    def _calculate_loss_and_abs_td(
        self,
        q: jnp.ndarray,
        target: jnp.ndarray,
        weight: np.ndarray,
    ) -> jnp.ndarray:
        td = target - q
        if self.loss_type == "l2":
            loss = jnp.mean(jnp.square(td) * weight)
        elif self.loss_type == "huber":
            loss = jnp.mean(huber(td) * weight)
        return loss, jax.lax.stop_gradient(jnp.abs(td))

    @partial(jax.jit, static_argnums=0)

3 Source : ppo.py
with MIT License
from ku2482

    def _loss_critic(
        self,
        params_critic: hk.Params,
        state: np.ndarray,
        target: np.ndarray,
    ) -> jnp.ndarray:
        """Mean squared error between `target` and the critic's prediction.

        Returns a pair (mse, None) despite the return annotation. Presumably
        the None fills an auxiliary-output slot expected by the caller;
        confirm.
        """
        predicted = self.critic.apply(params_critic, state)
        mse = jnp.mean(jnp.square(target - predicted))
        return mse, None

    @partial(jax.jit, static_argnums=0)

3 Source : distribution.py
with MIT License
from ku2482

def gaussian_log_prob(
    log_std: jnp.ndarray,
    noise: jnp.ndarray,
) -> jnp.ndarray:
    """
    Element-wise Gaussian log-density expressed via the standardized noise.

    Returns -0.5 * (noise**2 + 2 * log_std + log(2 * pi)).
    """
    log_two_pi = jnp.log(2 * math.pi)
    return -0.5 * (jnp.square(noise) + 2 * log_std + log_two_pi)


@jax.jit

3 Source : distribution.py
with MIT License
from ku2482

def gaussian_and_tanh_log_prob(
    log_std: jnp.ndarray,
    noise: jnp.ndarray,
    action: jnp.ndarray,
) -> jnp.ndarray:
    """
    Log-probability of a Gaussian sample after a tanh squashing.

    Subtracts log(relu(1 - action**2) + 1e-6) from gaussian_log_prob. This is
    presumably the log-Jacobian of action = tanh(u). The relu clamps
    negative values and the 1e-6 keeps the log finite as |action| -> 1.
    """
    squash_log_det = jnp.log(nn.relu(1.0 - jnp.square(action)) + 1e-6)
    return gaussian_log_prob(log_std, noise) - squash_log_det


@jax.jit

3 Source : distribution.py
with MIT License
from ku2482

def calculate_kl_divergence(
    p_mean: np.ndarray,
    p_std: np.ndarray,
    q_mean: np.ndarray,
    q_std: np.ndarray,
) -> jnp.ndarray:
    """
    Element-wise KL(p || q) between diagonal Gaussians N(p_mean, p_std^2) and N(q_mean, q_std^2).

    Uses 0.5 * (r + d - 1 - log r), with r = (p_std / q_std)^2 and
    d = ((p_mean - q_mean) / q_std)^2. No reduction over dimensions is applied.
    """
    std_ratio_sq = jnp.square(p_std / q_std)
    mean_shift_sq = jnp.square((p_mean - q_mean) / q_std)
    return 0.5 * (std_ratio_sq + mean_shift_sq - 1 - jnp.log(std_ratio_sq))

3 Source : loss.py
with MIT License
from ku2482

def quantile_loss(
    td: jnp.ndarray,
    cum_p: jnp.ndarray,
    weight: jnp.ndarray,
    loss_type: str,
) -> jnp.ndarray:
    """
    Calculate quantile loss.

    The element-wise TD loss (squared error for "l2", `huber` for "huber") is
    scaled by |cum_p - 1[td < 0]| with gradients stopped. It is then summed
    over axis 1, averaged over the next axis, multiplied by `weight`, and
    averaged.

    Args:
        td: TD errors; must be at least 3-D (sum over axis 1, then mean over
            axis 1 of the result). Presumably (batch, N, M); confirm.
        cum_p: Cumulative probabilities; `cum_p[..., None]` must broadcast
            against `td`.
        weight: Broadcast against the per-sample loss (shape (B, 1) for 3-D td).
        loss_type: "l2" or "huber".

    Raises:
        NotImplementedError: if `loss_type` is neither "l2" nor "huber".
    """
    if loss_type == "l2":
        element_wise_loss = jnp.square(td)
    elif loss_type == "huber":
        element_wise_loss = huber(td)
    else:
        # Must actually raise. A bare `NotImplementedError` expression falls
        # through to an UnboundLocalError on `element_wise_loss` below.
        raise NotImplementedError(f"Unsupported loss_type: {loss_type!r}")
    quantile_weight = jax.lax.stop_gradient(jnp.abs(cum_p[..., None] - (td < 0)))
    element_wise_loss = element_wise_loss * quantile_weight
    batch_loss = element_wise_loss.sum(axis=1).mean(axis=1, keepdims=True)
    return (batch_loss * weight).mean()

3 Source : basics.py
with MIT License
from NTT123

def loss_fn(model: Linear, x: jnp.ndarray, y: jnp.ndarray):
    """Mean squared error of the model on (x, y), plus the model returned by pax.purecall.

    Returns (loss, model) so the caller gets the possibly-updated model back.
    """
    updated_model, prediction = pax.purecall(model, x)
    mse = jnp.mean(jnp.square(prediction - y))
    return mse, updated_model


@jax.jit

3 Source : lazy_module.py
with MIT License
from NTT123

def loss_fn(model, x: jnp.ndarray, y: jnp.ndarray):
    """Run `forward(model, x)` and score it against `y` with mean squared error.

    Returns (loss, model), where model is the one returned by `forward`.
    """
    new_model, y_pred = forward(model, x)
    return jnp.mean(jnp.square(y_pred - y)), new_model


@jax.jit

3 Source : test_utils.py
with MIT License
from NTT123

def test_util_update_fn():
    """Smoke test: pax.utils.build_update_fn drives AdamW updates of a Linear layer.

    Runs three jitted update steps on random data and prints the loss after
    the last step only. Nothing is asserted.
    """

    def loss_fn(model: pax.Linear, x, target):
        prediction = model(x)
        mse = jnp.mean(jnp.square(prediction - target))
        return mse, (mse, model)

    linear = pax.Linear(2, 1)
    optimizer = opax.adamw(learning_rate=1e-1)(linear.parameters())
    update_fn = jax.jit(pax.utils.build_update_fn(loss_fn, scan_mode=True))
    inputs = np.random.normal(size=(32, 2))
    targets = np.random.normal(size=(32, 1))
    print()
    for step in range(3):
        (linear, optimizer), loss = update_fn((linear, optimizer), inputs, targets)
    # Reported once, after the loop.
    print(f"step {step}  loss {loss:.3f}")


def test_Rng_Seq():

3 Source : numpy_ops.py
with GNU General Public License v3.0
from PKU-NIP-Lab

def square(x):
  """Element-wise square that returns a JaxArray wrapper.

  `x` is first passed through `_remove_jaxarray` (presumably unwrapping a
  JaxArray to its raw array; confirm). It is then squared with jnp.square and
  wrapped in JaxArray.
  """
  raw = _remove_jaxarray(x)
  squared = jnp.square(raw)
  return JaxArray(squared)


def fabs(x):

3 Source : sparse_regression.py
with Apache License 2.0
from pyro-ppl

def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
    """Kernel matrix between X and Z built from the file's `dot` helper.

    Sums four terms:
      0.5 * eta2^2 * (1 + dot(X, Z))^2
      - 0.5 * eta2^2 * dot(X^2, Z^2)
      + (eta1^2 - eta2^2) * dot(X, Z)
      + c^2 - 0.5 * eta2^2
    When X and Z have identical shapes, `jitter` times the identity is also
    added. Presumably this conditions the Gram matrix when X is Z; confirm.
    """
    eta1sq, eta2sq = jnp.square(eta1), jnp.square(eta2)
    xz = dot(X, Z)
    quadratic = 0.5 * eta2sq * jnp.square(1.0 + xz)
    squared_cross = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))
    linear = (eta1sq - eta2sq) * xz
    constant = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        constant += jitter * jnp.eye(X.shape[0])
    return quadratic + squared_cross + linear + constant


# Most of the model code is concerned with constructing the sparsity inducing prior.
def model(X, Y, hypers):

3 Source : kl.py
with Apache License 2.0
from pyro-ppl

def kl_divergence(p, q):
    """Element-wise KL(p || q) for two distributions exposing `loc` and `scale`.

    Uses the closed form for normals: 0.5 * (r + d - 1 - log r), where
    r = (p.scale / q.scale)^2 and d = ((p.loc - q.loc) / q.scale)^2.
    """
    scale_ratio_sq = jnp.square(p.scale / q.scale)
    shift_sq = jnp.square((p.loc - q.loc) / q.scale)
    return 0.5 * (scale_ratio_sq + shift_sq - 1 - jnp.log(scale_ratio_sq))


@dispatch(Beta, Beta)

3 Source : autoguide.py
with Apache License 2.0
from pyro-ppl

    def quantiles(self, params, quantiles):
        """Per-dimension quantiles of the guide, mapped back to constrained space.

        The latent scale is inflated to scale * sqrt(sum(cov_factor**2, -1) + 1).
        This is presumably the marginal std of the low-rank-plus-diagonal
        Gaussian; confirm. Normal.icdf is evaluated at each requested quantile,
        and the result is unpacked/constrained.
        """
        loc = params[f"{self.prefix}_loc"]
        cov_factor = params[f"{self.prefix}_cov_factor"]
        base_scale = params[f"{self.prefix}_scale"]
        marginal_scale = base_scale * jnp.sqrt(jnp.square(cov_factor).sum(-1) + 1)
        # Trailing axis lets each quantile broadcast across the latent dims.
        probs = jnp.array(quantiles)[..., None]
        latent = dist.Normal(loc, marginal_scale).icdf(probs)
        return self._unpack_and_constrain(latent, params)


class AutoLaplaceApproximation(AutoContinuous):

3 Source : toy_examples.py
with MIT License
from SamDuffield

    def likelihood_potential(self,
                             x: jnp.ndarray,
                             random_key: jnp.ndarray = None) -> Union[float, jnp.ndarray]:
        """Half the squared whitened distance of x from self.mean, per row.

        The residual is whitened with
        self.precision_mul(x - mean, precision_sqrt.T), squared, and summed
        over the last axis. `random_key` is unused.
        """
        centred = x - self.mean
        whitened = self.precision_mul(centred, self.precision_sqrt.T)
        return 0.5 * jnp.sum(jnp.square(whitened), axis=-1)

    def __setattr__(self,

3 Source : toy_examples.py
with MIT License
from SamDuffield

    def component_potential(self,
                            component_index: int if False else None,
                            x: jnp.ndarray = None) -> Union[float, jnp.ndarray]:
        raise NotImplementedError

    def component_dens(self,

3 Source : haiku_run.py
with Apache License 2.0
from Scitator

    def _loss_fn(
        self, params: hk.Params, state: hk.State, rng: jnp.ndarray, batch: Batch
    ) -> Tuple[jnp.ndarray, hk.State]:
        """Softmax cross-entropy plus a small L2 penalty on every parameter.

        Returns:
            (loss, state): `loss` is the cross-entropy against 10-class one-hot
            labels, divided by labels.shape[0], plus 1e-4 * 0.5 * sum(p ** 2)
            over all parameter leaves. `state` is the state returned by
            `self.forward.apply`.
        """
        logits, state = self.forward.apply(params, state, rng, batch)
        # NOTE(review): the class count (10) is hard-coded here.
        labels = jax.nn.one_hot(batch["labels"], 10)

        # jax.tree_util.tree_leaves replaces the deprecated top-level alias
        # jax.tree_leaves (removed in recent JAX releases).
        l2_loss = 0.5 * sum(
            jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
        softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
        # Normalize by labels.shape[0] (presumably the batch size).
        softmax_xent /= labels.shape[0]

        return (softmax_xent + 1e-4 * l2_loss, state)

    def _eval_batch(

3 Source : normalizations.py
with Apache License 2.0
from tensorflow

  def fprop(self, inputs: JTensor) -> JTensor:
    """Apply RMS norm to inputs.

    Each vector along the last axis is divided by
    sqrt(mean(inputs**2) + epsilon) and then multiplied by the learned scale.
    The scale is used directly when `direct_scale` is set, otherwise as
    1 + scale.

    Args:
      inputs: The inputs JTensor. Shaped [..., input_dims].

    Returns:
      RMS normalized input.
    """
    theta = self.local_theta()
    p = self.params
    mean_square = jnp.mean(jnp.square(inputs), axis=-1, keepdims=True)
    scale = theta.scale if p.direct_scale else 1 + theta.scale
    return inputs * jax.lax.rsqrt(mean_square + p.epsilon) * scale


class GroupNorm(base_layer.BaseLayer):

3 Source : test_layers.py
with Apache License 2.0
from tensorflow

  def compute_loss(self, predictions, input_batch):
    """Mean squared error of `predictions` against `input_batch.targets`.

    Returns (NestedMap(loss=(mse, 1.0)), NestedMap(predictions=...)). The
    1.0 is created with the loss dtype and is presumably the loss
    weight/denominator expected by the framework; confirm.
    """
    residual = predictions - input_batch.targets
    mse = jnp.mean(jnp.square(residual))
    per_example = NestedMap(predictions=predictions)
    metrics = NestedMap(loss=(mse, jnp.array(1.0, mse.dtype)))
    return metrics, per_example


class TestBatchNormalizationModel(base_model.BaseModel):

3 Source : test_layers.py
with Apache License 2.0
from tensorflow

  def compute_loss(self, predictions: JTensor,
                   input_batch: NestedMap) -> Tuple[NestedMap, NestedMap]:
    """Score `predictions` against `input_batch.targets` with mean squared error.

    Returns the loss paired with a 1.0 of the same dtype (presumably the
    loss weight; confirm), plus a per-example map holding the predictions.
    """
    mse = jnp.mean(jnp.square(predictions - input_batch.targets))
    per_example = NestedMap(predictions=predictions)
    unit_weight = jnp.array(1.0, mse.dtype)
    return NestedMap(loss=(mse, unit_weight)), per_example


class TestSpmdModel(base_model.BaseModel):

3 Source : optimizers.py
with Apache License 2.0
from tensorflow

def reduce_rms(array: JTensor) -> JTensor:
  """Computes the root mean square of `array` as a scalar array.

  Evaluates sqrt(reduce_mean(array ** 2)).
  NOTE(review): no rescaling against overflow is done here. Any numerical
  stability comes from `reduce_mean` (defined elsewhere); confirm.

  Args:
    array: Input array.

  Returns:
    The root mean square of the input array as a scalar array.
  """
  return jnp.sqrt(reduce_mean(jnp.square(array)))


@dataclasses.dataclass(frozen=True)

3 Source : utils.py
with MIT License
from vballoli

def global_norm(updates):
  """Returns the global l2 norm of all leaves of the input.

  Args:
    updates: A pytree of ndarrays representing the gradient.

  Returns:
    sqrt(sum over leaves of sum(leaf ** 2)); 0.0 for a pytree with no leaves.
  """
  # jax.tree_util.tree_leaves replaces the deprecated top-level alias
  # jax.tree_leaves (removed in recent JAX releases).
  leaves = jax.tree_util.tree_leaves(updates)
  return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def clip_by_global_norm(updates):

3 Source : folding_multimer.py
with MIT License
from Zuricho

def squared_difference(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Element-wise (x - y) ** 2, with the usual broadcasting."""
  delta = x - y
  return jnp.square(delta)


def make_backbone_affine(

0 Source : losses.py
with MIT License
from andreArtelt

def l2(x, x_orig):
    """Squared L2 distance between `x` and `x_orig`, summed over all elements."""
    delta = x - x_orig
    return npx.sum(npx.square(delta))

def lmad(x, x_orig, mad):

0 Source : layer.py
with MIT License
from andreArtelt

def normal_distribution(x, mean, variance):
    """Density of N(mean, variance) evaluated element-wise at x."""
    exponent = -.5 * npx.square(x - mean) / variance
    normalizer = npx.sqrt(2. * npx.pi * variance)
    return npx.exp(exponent) / normalizer

def log_normal_distribution(x, mean, variance):

0 Source : layer.py
with MIT License
from andreArtelt

def log_normal_distribution(x, mean, variance):
    """Log-density of the univariate normal N(mean, variance), element-wise at x.

    log N(x; mu, s2) = -0.5 * (x - mu)**2 / s2 - 0.5 * log(2 * pi * s2)

    `variance` must be positive.
    """
    # The normalizing term is -0.5 * log(2*pi*variance). The previous
    # expression -0.5 * (2 + pi + variance) was not the log of the normal's
    # normalizing constant.
    return -.5 * npx.square(x - mean) / variance - .5 * npx.log(2. * npx.pi * variance)

def log_multivariate_normal(x, mean, sigma_inv, k):

See More Examples