jax.numpy.mean

Here are examples of the Python API jax.numpy.mean taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.

193 Examples
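
For quick reference, here is a minimal sketch of how jax.numpy.mean is typically called; the array values and shapes below are illustrative only.

import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)

jnp.mean(x)                         # mean over all 12 elements -> 5.5
jnp.mean(x, axis=0)                 # column means, shape (4,)
jnp.mean(x, axis=1, keepdims=True)  # row means, shape (3, 1)
jnp.mean(x, axis=(0, 1))            # reduce over both axes -> 5.5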

3 Source : cnn_flax.py
with MIT License
from cgarciae

def loss_fn(
    params: tx.FlaxModule, model: tx.FlaxModule, x: jnp.ndarray, y: jnp.ndarray
) -> tp.Tuple[jnp.ndarray, tp.Tuple[tx.FlaxModule, jnp.ndarray]]:
    model = model.merge(params)
    preds = model(x)

    loss = jnp.mean(
        optax.softmax_cross_entropy(
            preds,
            jax.nn.one_hot(y, 10),
        )
    )

    acc_batch = preds.argmax(axis=1) == y

    return loss, (model, acc_batch)


@jax.jit

3 Source : cnn_model_distributed.py
with MIT License
from cgarciae

    def loss_fn(
        params: Module,
        module: Module,
        x: jnp.ndarray,
        y: jnp.ndarray,
    ) -> tp.Tuple[jnp.ndarray, tp.Tuple[Module, jnp.ndarray]]:
        module = module.merge(params)
        preds = module(x)

        loss = jnp.mean(
            optax.softmax_cross_entropy(
                preds,
                jax.nn.one_hot(y, 10),
            )
        )

        return loss, (module, preds)

    @partial(jax.pmap, axis_name="device")

3 Source : vae.py
with MIT License
from cgarciae

def loss_fn(params: VAE, model: VAE, x: np.ndarray) -> tp.Tuple[jnp.ndarray, VAE]:
    model = model.merge(params)
    x_pred = model(x)

    crossentropy_loss = jnp.mean(optax.sigmoid_binary_cross_entropy(x_pred, x))
    aux_losses = jax.tree_leaves(model.filter(tx.LossLog))

    loss = crossentropy_loss + sum(aux_losses, 0.0)

    return loss, model


@jax.jit

3 Source : squeeze_and_excite_layer.py
with Apache License 2.0
from DarshanDeshpande

    def __call__(self, inputs):
        batch, _, _, channels = inputs.shape
        global_avg_pool = jnp.mean(inputs, axis=[1, 2], keepdims=False)
        dense1 = nn.Dense(channels // self.reduction, use_bias=False)(global_avg_pool)
        dense1 = nn.relu(dense1)
        dense2 = nn.Dense(channels, use_bias=False)(dense1)
        dense2 = nn.sigmoid(dense2)
        expand_dims = jnp.reshape(dense2, (batch, 1, 1, channels))
        shape_broadcasted = jnp.broadcast_to(expand_dims, inputs.shape)
        return inputs * shape_broadcasted

3 Source : conv_mixer.py
with Apache License 2.0
from DarshanDeshpande

    def __call__(self, inputs, deterministic=None):
        extract_patches = nn.Conv(
            self.features, (self.patch_size, self.patch_size), self.patch_size
        )(inputs)
        x = nn.gelu(extract_patches)
        x = nn.BatchNorm(deterministic)(x)
        for _ in range(self.num_mixer_layers):
            x = ConvMixerLayer(self.features, self.filter_size)(x, deterministic)

        if self.attach_head:
            x = jnp.mean(x, [1, 2])
            x = nn.Dense(self.num_classes)(x)
            x = nn.softmax(x)

        return x


@register_model

3 Source : losses.py
with Apache License 2.0
from deepmind

def manual_cross_entropy(labels, logits, weight):
  ce = - weight * jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
  return jnp.mean(ce)


def byol_nce_detcon(pred1, pred2, target1, target2,

3 Source : utils_test.py
with Apache License 2.0
from deepmind

  def test_segment_normalize(self, nan_data):
    norm = lambda x: (x - jnp.mean(x)) * jax.lax.rsqrt(jnp.var(x))
    data = jnp.arange(8, dtype=jnp.float32)
    segment_ids = jnp.array([0, 0, 0, 1, 1, 2, 2, 2])
    expected_out = jnp.concatenate(
        [norm(jnp.arange(3, dtype=jnp.float32)),
         norm(jnp.arange(3, 5, dtype=jnp.float32)),
         norm(jnp.arange(5, 8, dtype=jnp.float32))])
    if nan_data:
      data = data.at[0].set(jnp.nan)
      expected_out = expected_out.at[:3].set(jnp.nan)
    result = utils.segment_normalize(data, segment_ids, 3)
    np.testing.assert_allclose(result, expected_out)

  @parameterized.parameters((False, False),

3 Source : lookahead_mnist.py
with Apache License 2.0
from deepmind

def categorical_crossentropy(logits: jnp.ndarray,
                             labels: jnp.ndarray) -> jnp.ndarray:
  losses = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=1)
  return jnp.mean(losses)


def accuracy(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:

3 Source : loss_test.py
with Apache License 2.0
from deepmind

  def testSigmoidCrossEntropy(self, preds, labels, expected):
    tested = jnp.mean(loss.sigmoid_binary_cross_entropy(preds, labels))
    np.testing.assert_allclose(tested, expected, rtol=1e-6, atol=1e-6)


class CosineDistanceTest(parameterized.TestCase):

3 Source : kernels_test.py
with Apache License 2.0
from deepmind

  def test_rff_approximate_hsic_xx(self, dim, c):
    num_features = 512
    rng = random.PRNGKey(42)
    rng1, rng2, rng_x = random.split(rng, 3)
    amp, amp_probs = kernels.imq_amplitude_frequency_and_probs(dim)
    rff_kwargs = {'amp': amp, 'amp_probs': amp_probs}
    x = random.uniform(rng_x, [1, dim])
    k = inverse_multiquadratic(x, x, c)
    k = k - jnp.mean(k, axis=1, keepdims=True)
    hsic_xx = (k * k).mean()
    hsic_xx_approx = kernels.rff_approximate_hsic_xx([x], num_features,
                                                     rng1, rng2, c, rff_kwargs)
    self.assertEqual(hsic_xx_approx, hsic_xx)

if __name__ == '__main__':

3 Source : graph.py
with MIT License
from ericmjl

def mseloss(p, model, Fs, As, y):
    yhat = model(p, Fs, As)
    return np.mean(_mse_loss(y, yhat))


def model(params, Fs, As):

3 Source : rnn.py
with MIT License
from ericmjl

def mseloss(p, model, x, y):
    yhat = model(p, x)
    return np.mean(_mse_loss(y, yhat))


dloss = grad(mseloss)

3 Source : vae.py
with MIT License
from ericmjl

def vae_loss(flat_params, unflattener, model, x, y):
    # Make predictions
    params = unflattener(flat_params)
    y_hat = model(params, x)

    # Define KL-divergence loss
    z_mean, z_log_var = encoder(params, x)

    # Loss is sum of cross-entropy loss and KL loss.
    return np.mean(
        cross_entropy_loss(y, y_hat) + kl_divergence(z_mean, z_log_var)
    )

3 Source : api_test.py
with Apache License 2.0
from google

  def test_staging_out_multi_replica(self):
    def f(x):
      return api.pmap(jnp.mean)(x)
    xla_comp = api.xla_computation(f)
    xla_comp(jnp.arange(8)).as_hlo_text()  # doesn't crash

  def test_xla_computation_instantiate_constant_outputs(self):

3 Source : no_dependency_learned_optimizer.py
with Apache License 2.0
from google

def meta_loss(meta_params, key, sequence_of_batches):

  def step(opt_state, batch):
    loss, grads = value_grad_fn(opt_state[0], batch)
    opt_state = lopt.update_inner_opt_state(meta_params, opt_state, grads)
    return opt_state, loss

  params = task.init(key)
  opt_state = lopt.initial_inner_opt_state(meta_params, params)
  # Iterate N times where N is the number of batches in sequence_of_batches
  opt_state, losses = jax.lax.scan(step, opt_state, sequence_of_batches)

  return jnp.mean(losses)


key = jax.random.PRNGKey(0)

3 Source : no_dependency_learned_optimizer.py
with Apache License 2.0
from google

def vectorized_meta_loss(meta_params, key, sequence_of_batches):
  vec_loss = jax.vmap(
      meta_loss, in_axes=(None, 0, 0))(meta_params, key, sequence_of_batches)
  return jnp.mean(vec_loss)


vec_meta_loss_grad = jax.jit(jax.value_and_grad(vectorized_meta_loss))

3 Source : no_dependency_learned_optimizer.py
with Apache License 2.0
from google

def vec_antithetic_es_estimate(meta_params, keys, vec_seq_batches):
  """Compute a ES estimated gradient along multiple directions."""
  losses, grads = jax.vmap(
      antithetic_es_estimate, in_axes=(None, 0, 0))(meta_params, keys,
                                                    vec_seq_batches)
  return jnp.mean(losses), [jnp.mean(g, axis=0) for g in grads]


keys = jax.random.split(key, 8)

3 Source : no_dependency_learned_optimizer.py
with Apache License 2.0
from google

def vec_short_segment_unroll(meta_params, keys, inner_opt_states, on_iterations,
                             vec_seq_of_batches):
  losses, inner_opt_states, on_iterations = jax.vmap(
      short_segment_unroll,
      in_axes=(None, 0, 0, 0, 0))(meta_params, keys, inner_opt_states,
                                  on_iterations, vec_seq_of_batches)
  return jnp.mean(losses), (inner_opt_states, on_iterations)


vec_short_segment_grad = jax.jit(

3 Source : summary_tutorial.py
with Apache License 2.0
from google

def forward(params):
  to_look_at = jnp.mean(params) * 2.
  return params


def loss(parameters):

3 Source : summary_tutorial.py
with Apache License 2.0
from google

def forward(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at)
  return params


@jax.jit

3 Source : summary_tutorial.py
with Apache License 2.0
from google

def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at)
  return params * 2


@jax.jit

3 Source : summary_tutorial.py
with Apache License 2.0
from google

def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at, aggregation="sample")
  return params * 2


@jax.jit

3 Source : summary_tutorial.py
with Apache License 2.0
from google

def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary(
      "to_look_at", jnp.arange(10) * to_look_at, aggregation="collect")
  return params * 2


@jax.jit

3 Source : summary_tutorial.py
with Apache License 2.0
from google

def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", jnp.arange(10) * to_look_at)
  return params * 2


@functools.partial(jax.jit, static_argnames="with_summary")

3 Source : summary_tutorial.py
with Apache License 2.0
from google

def lossb(p):
  to_look_at = jnp.mean(123.)
  return p * 2, to_look_at


def loss(parameters):

3 Source : summary_tutorial.py
with Apache License 2.0
from google

def loss(parameters):
  loss = jnp.mean(parameters**2)
  to_look_at = jnp.mean(123.)
  hcb.id_print(to_look_at, name="to_look_at")
  return loss


value_grad_fn = jax.jit(jax.value_and_grad(loss))

3 Source : summary.py
with Apache License 2.0
from google

def tree_scalar_mean(prefix, values):
  for li, l in enumerate(jax.tree_leaves(values)):
    summary(prefix + "/" + str(li), jnp.mean(l))


def tree_step(prefix, values):

3 Source : summary.py
with Apache License 2.0
from google

def summarize_inner_params(params: Any):
  for k, v in _nested_to_names(params):
    summary(k + "/mean", jnp.mean(v))
    summary(k + "/mean_abs", jnp.mean(jnp.abs(v)))


class SummaryWriterBase:

3 Source : image_mlp.py
with Apache License 2.0
from google

  def loss(self, params: Params, key: PRNGKey, data: Any) -> jnp.ndarray:
    num_classes = self.datasets.extra_info["num_classes"]
    logits = self._mod.apply(params, key, data["image"])
    labels = jax.nn.one_hot(data["label"], num_classes)
    vec_loss = base.softmax_cross_entropy(logits=logits, labels=labels)
    return jnp.mean(vec_loss)

  def normalizer(self, loss):

3 Source : image_mlp.py
with Apache License 2.0
from google

  def loss_with_state(self, params: Params, state: ModelState, key: PRNGKey,
                      data: Any) -> Tuple[jnp.ndarray, ModelState]:
    num_classes = self.datasets.extra_info["num_classes"]
    logits, state = self._mod.apply(params, state, key, data["image"])
    labels = jax.nn.one_hot(data["label"], num_classes)
    vec_loss = base.softmax_cross_entropy(logits=logits, labels=labels)
    return jnp.mean(vec_loss), state

  def loss_with_state_and_aux(

3 Source : resnet.py
with Apache License 2.0
from google

  def _hk_forward(self, batch):
    args = [
        'blocks_per_group', 'use_projection', 'channels_per_group',
        'initial_conv_kernel_size', 'initial_conv_stride', 'max_pool',
        'resnet_v2'
    ]
    num_classes = self.datasets.extra_info['num_classes']
    mod = resnet.ResNet(
        num_classes=num_classes, **{k: self._cfg[k] for k in args})
    logits = mod(batch['image'], is_training=True)
    loss = base.softmax_cross_entropy(
        logits=logits, labels=jax.nn.one_hot(batch['label'], num_classes))
    return jnp.mean(loss)

  def init_with_state(self, key: chex.PRNGKey) -> base.Params:

3 Source : resnet.py
with Apache License 2.0
from google

  def __call__(self, inputs, is_training, test_local_stats=False):
    out = inputs
    out = self.initial_conv(out)
    if not self.resnet_v2:
      out = self.initial_batchnorm(out, is_training, test_local_stats)
      out = self.act_fn(out)

    if self.max_pool:
      out = hk.max_pool(
          out, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME")

    for block_group in self.block_groups:
      out = block_group(out, is_training, test_local_stats)

    if self.resnet_v2:
      out = self.final_batchnorm(out, is_training, test_local_stats)
      out = self.act_fn(out)
    out = jnp.mean(out, axis=[1, 2])
    return self.logits(out)

3 Source : plot_test.py
with Apache License 2.0
from google

  def test_plot_model_fit_plot_called_with_scaler(self):
    target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    target = target_scaler.fit_transform(jnp.ones(50))
    mmm = lightweight_mmm.LightweightMMM()
    mmm.fit(
        media=jnp.ones((50, 3)),
        target=target,
        total_costs=jnp.repeat(50, 3),
        number_warmup=5,
        number_samples=5,
        number_chains=1)

    plot.plot_model_fit(media_mix_model=mmm, target_scaler=target_scaler)

    self.assertTrue(self.mock_plt_plot.called)

  def test_plot_model_fit_plot_called_without_scaler(self):

3 Source : gpt2.py
with Apache License 2.0
from google

    def __call__(self, x):
        # mean across the axis
        u = np.mean(x, axis=self.axis, keepdims=True)
        # standard deviation
        s = np.mean((x - u) ** 2, axis=self.axis, keepdims=True)
        # rescaled values
        x = (x - u) * objax.functional.rsqrt(s + self.epsilon)
        x = x * self.g.value + self.b.value
        return x


def split_states(x, n):

3 Source : train_utils.py
with Apache License 2.0
from google

def sigmoid_xent(*, logits, labels, reduction=True):
  """Computes a sigmoid cross-entropy (Bernoulli NLL) loss over examples."""
  log_p = jax.nn.log_sigmoid(logits)
  log_not_p = jax.nn.log_sigmoid(-logits)
  nll = -jnp.sum(labels * log_p + (1. - labels) * log_not_p, axis=-1)
  return jnp.mean(nll) if reduction else nll


def softmax_xent(*, logits, labels, reduction=True, kl=False):

3 Source : bit_resnet.py
with Apache License 2.0
from google

def weight_standardize(w: jnp.ndarray,
                       axis: Union[Sequence[int], int],
                       eps: float):
  """Standardize (mean=0, var=1) a weight."""
  w = w - jnp.mean(w, axis=axis, keepdims=True)
  w = w / jnp.sqrt(jnp.mean(jnp.square(w), axis=axis, keepdims=True) + eps)
  return w


class IdentityLayer(nn.Module):

3 Source : loss.py
with MIT License
from gortizji

def softmax_cross_entropy_loss(logits, labels):
    if len(labels.shape) <= 1:
        num_classes = logits.shape[-1]
        soft_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    else:
        soft_labels = labels
    return jnp.mean(optax.softmax_cross_entropy(logits, soft_labels))


def binary_cross_entropy_loss_with_logits(logits, labels):

3 Source : modeling_flax_beit.py
with Apache License 2.0
from huggingface

    def __call__(self, hidden_states):
        if self.config.use_mean_pooling:
            # Mean pool the final hidden states of the patch tokens
            patch_tokens = hidden_states[:, 1:, :]
            pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1))
        else:
            # Pool by simply taking the final hidden state of the [CLS] token
            pooled_output = hidden_states[:, 0]

        return pooled_output


class FlaxBeitModule(nn.Module):

3 Source : utils.py
with MIT License
from jamesvuc

def confidence_bands(y, sample_axis=-1):
    """ Computes confidence bands for samples y.

    Args:
        y: array of samples.

    Returns:
        mean and standard deviation along the axes not equal to sample_axis.
    """
    m = jnp.mean(y, axis=sample_axis)
    s = jnp.std(y, axis=sample_axis)
    return m - s, m + s

3 Source : mpi_wrapper_t.py
with MIT License
from markusschmitt

    def test_mean(self):
        
        data=jnp.array(np.arange(720*4*global_defs.device_count()).reshape((global_defs.device_count()*720,4)))
        myNumSamples = mpi.distribute_sampling(global_defs.device_count()*720)

        myData=data[mpi.rank*myNumSamples:(mpi.rank+1)*myNumSamples].reshape(get_shape((-1,4)))

        self.assertTrue(jnp.sum(mpi.global_mean(myData) - jnp.mean(data, axis=0)) < 1e-10)

    def test_var(self):

3 Source : test_mpi_stats.py
with Apache License 2.0
from netket

def test_mean(axis, arr, arr_loc):
    arr_mean = jnp.mean(arr, axis=axis)

    nk_mean = nk.stats.mean(arr_loc, axis=axis)
    np.testing.assert_allclose(nk_mean, arr_mean)

    nk_mean = nk.stats.mean(arr_loc, keepdims=True, axis=axis)
    assert nk_mean.ndim == arr.ndim
    np.testing.assert_allclose(nk_mean, arr_mean.reshape(1, -1))


@pytest.mark.parametrize("axis", [None, 0])

3 Source : test_mpi_stats.py
with Apache License 2.0
from netket

def test_subtract_mean(axis, arr, arr_loc, _mpi_rank):
    arr_sub = arr - jnp.mean(arr, axis=axis)

    nk_sub = nk.stats.subtract_mean(arr_loc, axis=axis)
    np.testing.assert_allclose(
        nk_sub, arr_sub[100 * _mpi_rank : 100 * (_mpi_rank + 1), :]
    )


@pytest.mark.parametrize("axis", [None, 0])

3 Source : data.py
with MIT License
from nikikilbertus

def whiten(
  inputs: Dict[Text, np.ndarray]
) -> Dict[Text, Union[float, np.ndarray, None]]:
  """Whiten each input."""
  res = {}
  for k, v in inputs.items():
    if v is not None:
      mu = np.mean(v, 0)
      std = np.maximum(np.std(v, 0), 1e-7)
      res[k + "_mu"] = mu
      res[k + "_std"] = std
      res[k] = (v - mu) / std
    else:
      res[k] = v
  return res


def whiten_with_mu_std(val: np.ndarray, mu: float, std: float) -> np.ndarray:

3 Source : plotting.py
with MIT License
from nikikilbertus

def plot_hist_at_z(y: np.ndarray, bin_ids: np.ndarray, idx: int) -> plt.Figure:
  fig = plt.figure()
  plt.hist(y[bin_ids == idx], bins=30)
  plt.xlabel('y')
  mean = np.mean(y[bin_ids == idx])
  var = np.var(y[bin_ids == idx])
  plt.title(
    f"mean {mean:.2f} and variance {var:.2f} for z bin {idx}")
  return fig


@empty_fig_on_failure

3 Source : run.py
with MIT License
from nikikilbertus

def get_phi(y: np.ndarray) -> np.ndarray:
  """The phis for the constraints."""
  return np.array([np.mean(y, axis=-1), np.var(y, axis=-1)]).T


@jit

3 Source : char_rnn.py
with MIT License
from NTT123

def update_fn(model, optimizer, multi_batch: jnp.ndarray):
    (model, optimizer), losses = pax.scan(update_step, (model, optimizer), multi_batch)
    return model, optimizer, jnp.mean(losses)


net = LM(vocab_size=vocab_size, hidden_dim=hidden_dim)

3 Source : train.py
with MIT License
from NTT123

def loss_fn(model: WaveGRU, inputs):
    logmel, wav = inputs
    input_wav = wav[:, :-1]
    target_wav = wav[:, 1:]
    model, logits = pax.purecall(model, (logmel, input_wav))
    log_pr = jax.nn.log_softmax(logits, axis=-1)
    target_wave = jax.nn.one_hot(target_wav, num_classes=logits.shape[-1])
    log_pr = jnp.sum(log_pr * target_wave, axis=-1)
    loss = -jnp.mean(log_pr)
    return loss, (loss, model)


def generate_test_sample(step, test_logmel, wave_gru, length, sample_rate, mu):

3 Source : __init__.py
with GNU General Public License v3.0
from PKU-NIP-Lab

def mean_absolute_error(x, y, axis=None):
  r"""Computes the mean absolute error between x and y.

  Args:
      x: a tensor of shape (d0, .. dN-1).
      y: a tensor of shape (d0, .. dN-1).
      axis: the axis or axes along which to average; use `None` to return a scalar value.

  Returns:
      tensor containing the mean absolute error, reduced over `axis`.
  """
  r = ops.abs(x - y)
  return jn.mean(ops.as_device_array(r), axis=axis)


def mean_squared_error(predicts, targets, axis=None):

3 Source : __init__.py
with GNU General Public License v3.0
from PKU-NIP-Lab

def mean_squared_error(predicts, targets, axis=None):
  r"""Computes the mean squared error between x and y.

  Args:
      predicts: a tensor of shape (d0, .. dN-1).
      targets: a tensor of shape (d0, .. dN-1).
      axis: the axis or axes along which to average; use `None` to return a scalar value.

  Returns:
      tensor containing the mean squared error, reduced over `axis`.
  """
  r = (predicts - targets) ** 2
  return jn.mean(ops.as_device_array(r), axis=axis)


def mean_squared_log_error(y_true, y_pred, axis=None):

3 Source : __init__.py
with GNU General Public License v3.0
from PKU-NIP-Lab

def mean_squared_log_error(y_true, y_pred, axis=None):
  r"""Computes the mean squared logarithmic error between y_true and y_pred.

  Args:
      y_true: a tensor of shape (d0, .. dN-1).
      y_pred: a tensor of shape (d0, .. dN-1).
      axis: the axis or axes along which to average; use `None` to return a scalar value.

  Returns:
      tensor containing the mean squared logarithmic error, reduced over `axis`.
  """
  r = (ops.log1p(y_true) - ops.log1p(y_pred)) ** 2
  return jn.mean(ops.as_device_array(r), axis=axis)


def huber_loss(predicts, targets, delta: float = 1.0):
