jax.numpy.argmax

Here are examples of the Python API jax.numpy.argmax, taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.

81 Examples
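
Before the project examples, a minimal sketch of the call itself (standard jax.numpy behavior; the arrays here are purely illustrative):

import jax.numpy as jnp

x = jnp.array([[1, 5, 2],
               [7, 0, 7]])

jnp.argmax(x)           # index into the flattened array -> 3 (the first 7)
jnp.argmax(x, axis=0)   # per column -> [1, 0, 1]
jnp.argmax(x, axis=-1)  # per row -> [1, 0]; ties return the first index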

3 Source : lookahead_mnist.py
with Apache License 2.0
from deepmind

def accuracy(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  predictions = jnp.argmax(logits, axis=-1)
  return jnp.mean(jnp.argmax(labels, axis=-1) == predictions)


def model_accuracy(
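
The pattern above recurs throughout this page: argmax both the logits and the one-hot labels along the class axis, then compare. A toy check with made-up values:

import jax.numpy as jnp

logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])  # model outputs
labels = jnp.array([[0.0, 1.0], [0.0, 1.0]])  # one-hot targets

predictions = jnp.argmax(logits, axis=-1)     # [1, 0]
targets = jnp.argmax(labels, axis=-1)         # [1, 1]
jnp.mean(targets == predictions)              # 0.5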

3 Source : mlp_agent.py
with Apache License 2.0
from google

def select_action(network_def: Any,
                  network_params: np.ndarray,
                  state: Any) -> int:
  """Select an action greedily from network."""
  return jnp.argmax(network_def.apply(network_params, state))


@functools.partial(jax.jit, static_argnums=(0, 7))

3 Source : model_utils.py
with Apache License 2.0
from google

def compute_depth_index(weights, depth_threshold=0.5):
  """Compute the sample index of the median depth accumulation."""
  opaqueness_mask = compute_opaqueness_mask(weights, depth_threshold)
  return jnp.argmax(opaqueness_mask, axis=-1)


def compute_depth_map(weights, z_vals, depth_threshold=0.5):
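
compute_depth_index leans on the fact that jnp.argmax returns the index of the first occurrence of the maximum, so applied to a boolean mask it finds the first True entry. In isolation (the mask values are illustrative):

import jax.numpy as jnp

opaqueness_mask = jnp.array([False, False, True, True])
jnp.argmax(opaqueness_mask)  # -> 2, the first index where the mask is True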

3 Source : differentially_private_sgd.py
with Apache License 2.0
from google

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)


def clipped_grad(params, l2_norm_clip, single_example_batch):

3 Source : mnist_classifier.py
with Apache License 2.0
from google

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(

3 Source : mnist_classifier_fromscratch.py
with Apache License 2.0
from google

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)


if __name__ == "__main__":

3 Source : resnet50.py
with Apache License 2.0
from google

  def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=-1)
    predicted_class = jnp.argmax(predict_fun(params, inputs), axis=-1)
    return jnp.mean(predicted_class == target_class)

  def synth_batches():

3 Source : mnist_lib.py
with Apache License 2.0
from google

  def accuracy(predict: Callable, params, dataset):

    @jax.jit
    def _per_batch(inputs, labels):
      target_class = jnp.argmax(labels, axis=1)
      predicted_class = jnp.argmax(predict(params, inputs), axis=1)
      return jnp.mean(predicted_class == target_class)

    batched = [
      _per_batch(inputs, labels) for inputs, labels in tfds.as_numpy(dataset)
    ]
    return jnp.mean(jnp.stack(batched))

  @staticmethod

3 Source : imagenet_resnet50_train.py
with Apache License 2.0
from google

    def evaluate_batch(self, images, labels):
        logits = self.model(images, training=False)
        num_correct = jn.count_nonzero(jn.equal(jn.argmax(logits, axis=1), labels))
        return num_correct

    def run_eval(self):

3 Source : mnist_infl_mem.py
with Apache License 2.0
from google-research

def batch_correctness(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return predicted_class == target_class


def load_mnist():

3 Source : evaluation.py
with Apache License 2.0
from google-research

def accuracy(logits, targets, weights=None):
  """Sequence accuracy averaged over sequence length."""
  if logits.ndim != targets.ndim + 1:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (str(logits.shape), str(targets.shape)))
  acc = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  normalizing_factor = np.prod(logits.shape[:-1])
  if weights is not None:
    acc = acc * weights * 100
    normalizing_factor = weights.sum()
  return acc.sum(), normalizing_factor


def sequence_accuracy(prediction, target):
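
accuracy returns an unnormalized sum together with its normalizer rather than a ready-made mean, which lets callers aggregate across batches or hosts before dividing. A hedged usage sketch, where `batches` is a hypothetical iterable of (logits, targets) pairs:

total, norm = 0.0, 0.0
for logits, targets in batches:  # hypothetical data source
    batch_total, batch_norm = accuracy(logits, targets)
    total += batch_total
    norm += batch_norm
mean_accuracy = total / norm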

3 Source : mention_based_entity_qa_task.py
with Apache License 2.0
from google-research

def get_predictions_max(attention_weights: Array, memory_entity_ids: Array,
                        weights: Array) -> Array:
  """Predict entity ID based on a memory with the largest attention weight."""
  # `stop_gradient` is a safety check so the model doesn't keep activations
  # around, which could be expensive.
  attention_weights = jax.lax.stop_gradient(attention_weights)
  memory_entity_ids = jax.lax.stop_gradient(memory_entity_ids)
  weights = jax.lax.stop_gradient(weights)
  memory_with_largest_attn = jnp.argmax(attention_weights, axis=1)
  predictions = jnp.take_along_axis(
      memory_entity_ids, jnp.expand_dims(memory_with_largest_attn, 1), axis=1)
  predictions = jnp.squeeze(predictions, axis=1)
  predictions = predictions * weights
  return predictions


def get_predictions_sum(attention_weights: Array, memory_entity_ids: Array,
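
The argmax-plus-take_along_axis idiom above (it reappears in the evaluation functions further down this page) gathers, for each row, the value sitting at that row's argmax position. A small illustration:

import jax.numpy as jnp

attention_weights = jnp.array([[0.1, 0.7, 0.2]])
memory_entity_ids = jnp.array([[11, 22, 33]])

best = jnp.argmax(attention_weights, axis=1)           # [1]
jnp.take_along_axis(memory_entity_ids,
                    jnp.expand_dims(best, 1), axis=1)  # [[22]]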

3 Source : metrics.py
with Apache License 2.0
from google-research

def pointer_accuracy_fn(pointer_logits: dt.BatchedTocopoLogits,
                        target_data: dt.BatchedTrainTocopoTargetData):
  """Computes average 0-1 accuracy of pointer predictions at timestep 1."""
  one_hot_argmax_predictions = jax.nn.one_hot(
      jnp.argmax(pointer_logits, axis=1), pointer_logits.shape[1])

  # The pointer should always be at timestep 1 in copy-paste data. Slice out
  # the BV array of those pointer targets.
  is_pointer_targets = target_data.is_target_pointer[:, 1, :]

  num_correct = jnp.sum(one_hot_argmax_predictions * is_pointer_targets)
  num_attempts = is_pointer_targets.shape[0]
  return num_correct, num_attempts


@flax_dataclass

3 Source : ensemble.py
with Apache License 2.0
from google-research

def label_pred_ensemble_softmax(logits: Array, ensemble_size: int) -> Array:
  """Function to select the predicted labels for the ensemble softmax CE.

  Args:
    logits: 2D tensor of shape [ensemble size * batch size, #classes]. It is
      assumed that the batches for each ensemble member are stacked with a
      jnp.repeat(..., ensemble_size) pattern and can thus be recovered by
      an appropriate slicing ...[::ensemble_size].
    ensemble_size: The size of the ensemble.
  Returns:
    The class labels to predict.
  """
  log_p = _get_log_ensemble_softmax(logits, ensemble_size)  # Shape: (B, C).
  return jnp.argmax(log_p, axis=1)  # Shape: (B,).
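
The jnp.repeat stacking convention described in the docstring means each example's copies sit in adjacent rows, so member i of the ensemble is recovered with a strided slice. A toy illustration of that layout (not the project's code):

import jax.numpy as jnp

batch = jnp.array([[1.0], [2.0]])  # batch of two examples
ensemble_size = 3
stacked = jnp.repeat(batch, ensemble_size, axis=0)
# stacked rows: 1, 1, 1, 2, 2, 2 -- copies of each example are adjacent
member_0 = stacked[0::ensemble_size]  # rows 1 and 2: one copy per example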

3 Source : trainer_test.py
with Apache License 2.0
from google-research

  def test(self, name, kwargs, expected_shape):
    train_loss_fn, eval_loss_fn, label_pred_fn = trainer.get_loss_fn(
        name, **kwargs)
    logits = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
    labels = jax.random.uniform(jax.random.PRNGKey(0), (8, 16))

    train_loss = train_loss_fn(logits, labels)
    eval_loss = eval_loss_fn(logits, labels)

    self.assertEqual(train_loss.shape, expected_shape)
    self.assertEqual(eval_loss.shape, expected_shape)
    # The train and eval losses are the same in the standard case (no ensemble).
    self.assertSequenceEqual(list(train_loss), list(eval_loss))
    self.assertSequenceEqual(list(label_pred_fn(logits)),
                             list(jnp.argmax(logits, 1)))

  @parameterized.named_parameters(

3 Source : metrics.py
with MIT License
from gortizji

def compute_accuracy_metrics(logits, labels):
    loss = softmax_cross_entropy_loss(logits, labels)
    if len(labels.shape) > 1:
        accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
    else:
        accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        "loss": loss,
        "accuracy": accuracy,
    }
    return metrics


def compute_binary_accuracy_metrics(logits, labels):

3 Source : run_t5_mlm_flax.py
with Apache License 2.0
from gsarti

    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))

3 Source : ffjord_mnist.py
with MIT License
from jacobjinkelly

def _acc_fn(logits, labels):
    """
    Classification accuracy of the model.
    """
    predicted_class = jnp.argmax(logits, axis=1)
    return jnp.mean(predicted_class == labels)


def standard_normal_logprob(z):

3 Source : mnist.py
with MIT License
from jacobjinkelly

def _acc_fn(logits, labels):
    """
    Classification accuracy of the model.
    """
    predicted_class = jnp.argmax(logits, axis=1)
    return jnp.mean(predicted_class == labels)


def _loss_fn(logits, labels):

3 Source : mnist_classifier.py
with Apache License 2.0
from juliuskunze

def accuracy(inputs, targets):
    target_class = np.argmax(targets, axis=1)
    predicted_class = np.argmax(predict(inputs), axis=1)
    return np.mean(predicted_class == target_class)


def main():

3 Source : dqn.py
with MIT License
from ku2482

    def _forward(
        self,
        params: hk.Params,
        state: np.ndarray,
    ) -> jnp.ndarray:
        return jnp.argmax(self.net.apply(params, state), axis=1)

    def update(self, writer=None):

3 Source : fqf.py
with MIT License
from ku2482

    def _forward_from_feature(
        self,
        params_cum_p: hk.Params,
        params: hk.Params,
        feature: np.ndarray,
    ) -> jnp.ndarray:
        cum_p, cum_p_prime = self.cum_p_net.apply(params_cum_p, feature)
        quantile_s = self.net["quantile"].apply(params["quantile"], feature, cum_p_prime)
        q_s = ((cum_p[:, 1:, None] - cum_p[:, :-1, None]) * quantile_s).sum(axis=1)
        return jnp.argmax(q_s, axis=1)

    def update(self, writer=None):

3 Source : iqn.py
with MIT License
from ku2482

    def _forward(
        self,
        params: hk.Params,
        state: np.ndarray,
        key: jnp.ndarray,
    ) -> jnp.ndarray:
        cum_p = jax.random.uniform(key, (state.shape[0], self.num_quantiles_eval))
        return jnp.argmax(self.net.apply(params, state, cum_p).mean(axis=1), axis=1)

    @partial(jax.jit, static_argnums=0)

3 Source : qrdqn.py
with MIT License
from ku2482

    def _forward(
        self,
        params: hk.Params,
        state: np.ndarray,
    ) -> jnp.ndarray:
        return jnp.argmax(self.net.apply(params, state).mean(axis=1), axis=1)

    @partial(jax.jit, static_argnums=0)

3 Source : sac_discrete.py
with MIT License
from ku2482

    def _select_action(
        self,
        params_actor: hk.Params,
        state: np.ndarray,
    ) -> jnp.ndarray:
        pi_s, _ = self.actor.apply(params_actor, state)
        return jnp.argmax(pi_s, axis=1)

    @partial(jax.jit, static_argnums=0)

3 Source : mnist_adahessian.py
with Apache License 2.0
from nestordemeure

def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(

3 Source : main.py
with MIT License
from niklasschmitz

def choose_action(env, observation, pred_Q, params_Q_eval, eps):
    rand = onp.random.random()
    actions = pred_Q(params_Q_eval, observation)
    if rand < 1 - eps:
        action = np.argmax(actions[0])
    else:
        action = env.action_space.sample()
    return action


def stack_frames(frames):

3 Source : numpy_ops.py
with GNU General Public License v3.0
from PKU-NIP-Lab

def argmax(x, axis=None):
  x = _remove_jaxarray(x)
  r = jnp.argmax(x, axis=axis)
  return r if axis is None else JaxArray(r)


def argmin(x, axis=None):

3 Source : sgd_flax_demo.py
with MIT License
from probml

def callback_fn(**kwargs):
  # Train accuracy from fresh predictions on the training set.
  logprobs = model.apply(kwargs["belief_state"].params, kwargs["X_train"])
  y_pred = jnp.argmax(logprobs, axis=-1)
  train_acc = jnp.mean(y_pred == kwargs["Y_train"])

  # Test accuracy from the precomputed test-set predictions.
  y_pred = jnp.argmax(kwargs["preds"][0], axis=-1)
  test_acc = jnp.mean(y_pred == kwargs["Y_test"])
  print("Loss: ", kwargs["info"].loss)
  print(f"Train Accuracy: {train_acc}, Test Accuracy: {test_acc}")

def main():

3 Source : ops.py
with Apache License 2.0
from pyro-ppl

def _argmax(x, axis, keepdims):
    if keepdims:
        return np.expand_dims(np.argmax(x, axis), axis)
    return np.argmax(x, axis)


@ops.argmin.register(array)
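
This wrapper emulates keepdims by re-inserting the reduced axis with expand_dims, so the result can broadcast against the input. Behavior in isolation (assuming np is the module's jax.numpy alias):

x = np.arange(6).reshape(2, 3)      # [[0, 1, 2], [3, 4, 5]]
_argmax(x, axis=1, keepdims=True)   # shape (2, 1) -> [[2], [2]]
_argmax(x, axis=1, keepdims=False)  # shape (2,)   -> [2, 2]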

3 Source : submit_to_leaderboard.py
with MIT License
from rowanz

def pred_step(state: train_state.TrainState, batch):
    logits_from_audio, logits_from_text = state.apply_fn({'params': state.params}, batch)

    out = {'logprobs_audio': jax.nn.log_softmax(logits_from_audio, axis=-1),
            'preds_audio': jnp.argmax(logits_from_audio, -1),
            'logprobs_text': jax.nn.log_softmax(logits_from_text, axis=-1),
            'preds_text': jnp.argmax(logits_from_text, -1),
            }
    softmax_joint = jax.nn.softmax(logits_from_audio, axis=-1) + jax.nn.softmax(logits_from_text, axis=-1)
    out['preds_joint'] = jnp.argmax(softmax_joint, -1)
    return out
p_pred_step = jax.pmap(pred_step, axis_name='batch', donate_argnums=(1,))

3 Source : tvqa_finetune.py
with MIT License
from rowanz

def pred_step(state: train_state.TrainState, batch):
    logits_from_audio, logits_from_text = state.apply_fn({'params': state.params}, batch)

    out = {'logprobs_audio': jax.nn.log_softmax(logits_from_audio, axis=-1),
            'preds_audio': jnp.argmax(logits_from_audio, -1),
            'logprobs_text': jax.nn.log_softmax(logits_from_text, axis=-1),
            'preds_text': jnp.argmax(logits_from_text, -1),
            }
    softmax_joint = jax.nn.softmax(logits_from_audio, axis=-1) + jax.nn.softmax(logits_from_text, axis=-1)
    out['preds_joint'] = jnp.argmax(softmax_joint, -1)
    return out


p_pred_step = jax.pmap(pred_step, axis_name='batch', donate_argnums=(1,))

3 Source : qa_qar_joint_finetune.py
with MIT License
from rowanz

def train_loss_fn(state, params, batch):
    logits = state.apply_fn({'params': params}, batch)
    log_p = jax.nn.log_softmax(logits, axis=-1)
    labels_oh = jax.nn.one_hot(batch['labels'], dtype=log_p.dtype, num_classes=log_p.shape[-1])

    loss = -jnp.mean(jnp.sum(labels_oh * log_p, axis=-1))
    is_right = (jnp.argmax(log_p, -1) == batch['labels']).astype(jnp.float32).mean()
    return loss, {'is_right': is_right, 'loss': loss}


p_train_step = jax.pmap(functools.partial(finetune_train_step, loss_fn=train_loss_fn, tx_fns=tx_fns),

3 Source : training_utilities.py
with BSD 2-Clause "Simplified" License
from SarthakYadav

def compute_metrics(logits, labels, mode, cost_fn):
    loss = cost_fn(logits, labels)
    metrics = {
        'loss': loss
    }
    if mode == TrainingMode.MULTICLASS:
        accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
        # labels is now a (batch, num_classes) onehot-encoded array
        metrics['accuracy'] = accuracy

    metrics = lax.pmean(metrics, axis_name='batch')
    return metrics


def create_input_iter(ds, devices=None):

3 Source : haiku_run.py
with Apache License 2.0
from Scitator

    def _eval_batch(
        self, params: hk.Params, state: hk.State, rng: jnp.ndarray, batch: Batch
    ) -> jnp.ndarray:
        logits, _ = self.forward.apply(params, state, rng, batch)
        y = batch["labels"]
        return jnp.sum(jnp.argmax(logits, axis=-1) == y)

    def _update_fn(

3 Source : jax_run.py
with Apache License 2.0
from Scitator

def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)


def loss(params, images, targets):

3 Source : penguin_utils_flax_experimental.py
with Apache License 2.0
from tensorflow

def _compute_metrics(logits: _LogitBatch, labels: _LabelBatch):
  # assumes that the logits use log_softmax activations.
  loss = _categorical_cross_entropy_loss(logits, labels)
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels[..., 0])
  return {
      'loss': loss,
      'accuracy': accuracy,
  }


def _mean_epoch_metrics(

3 Source : utils.py
with MIT License
from vballoli

def error_rate_metric(logits,
                      one_hot_labels):
  """Returns the error rate between some predictions and some labels.
  Args:
    logits: Output of the model.
    one_hot_labels: One-hot encoded labels. Dimensions should match the logits.
  Returns:
    The error rate (1 - accuracy), averaged over the first dimension (samples).
  """
  return jnp.mean(jnp.argmax(logits, -1) != jnp.argmax(one_hot_labels, -1))


def tensorflow_to_numpy(xs):

3 Source : train.py
with Apache License 2.0
from yzhwang

def eval_batch(
    params: hk.Params,
    state: hk.State,
    batch: dataset.Batch,
) -> jnp.ndarray:
  """Evaluates a batch."""
  logits, _ = forward.apply(params, state, None, batch, is_training=False)
  predicted_label = jnp.argmax(logits, axis=-1)
  correct = jnp.sum(jnp.equal(predicted_label, batch['labels']))
  return correct.astype(jnp.float32)


def evaluate(

0 Source : accuracy.py
with MIT License
from cgarciae

    def update(
        self,
        target: jnp.ndarray,
        preds: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ):
        """
        Accumulates metric statistics. `target` and `preds` should have the same shape.

        Arguments:
            target: Ground truth values. shape = `[batch_size, d0, .. dN]`.
            preds: The predicted values. shape = `[batch_size, d0, .. dN]`.
            sample_weight: Optional `sample_weight` acts as a
                coefficient for the metric. If a scalar is provided, then the metric is
                simply scaled by the given value. If `sample_weight` is a tensor of size
                `[batch_size]`, then the metric for each sample of the batch is rescaled
                by the corresponding element in the `sample_weight` vector. If the shape
                of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
                to this shape), then each metric element of `preds` is scaled by the
                corresponding value of `sample_weight`. (Note on `dN-1`: all metric
                functions reduce by 1 dimension, usually the last axis (-1)).
        Returns:
            Array with the cumulative accuracy.
        """

        if self.argmax_preds:
            preds = jnp.argmax(preds, axis=-1)

        if self.argmax_labels:
            target = jnp.argmax(target, axis=-1)

        values = accuracy(target=target, preds=preds)

        super().update(
            values=values,
            sample_weight=sample_weight,
        )

0 Source : wcpg.py
with MIT License
from ChrisWaites

    def step(self, action, rng):
        action_index = int(np.argmax(action))

        dynamics = self.dynamics[self.state_index][action_index]

        temp, rng = random.split(rng)
        next_state_index = int(random.choice(temp, 7, p=dynamics))
        self.state_index = next_state_index
        next_state = self.states[next_state_index]

        mean = 1. if next_state_index == 0 else 2.
        std = 1. if next_state_index == 0 else 2.

        temp, rng = random.split(rng)
        reward = std * random.normal(temp) + mean

        done = (next_state_index >= 5)

        return next_state, reward, done, {}

    def reset(self):

0 Source : wcpg.py
with MIT License
from ChrisWaites

def is_terminal(next_state):
    return np.argmax(next_state, 1) >= 5


def critic_loss(critic_params, fixed_critic_params, fixed_actor_params, env_dynamics, batch):

0 Source : decoders.py
with Apache License 2.0
from deepmind

def postprocess(spec: _Spec, preds: Dict[str, _Array]) -> Dict[str, _DataPoint]:
  """Postprocesses decoder output."""
  result = {}
  for name in preds.keys():
    _, loc, t = spec[name]
    data = preds[name]
    if t == _Type.SCALAR:
      pass
    elif t == _Type.MASK:
      data = (data > 0.0) * 1.0
    elif t in [_Type.MASK_ONE, _Type.CATEGORICAL]:
      cat_size = data.shape[-1]
      best = jnp.argmax(data, -1)
      data = hk.one_hot(best, cat_size)
    elif t == _Type.POINTER:
      data = jnp.argmax(data, -1)
    else:
      raise ValueError("Invalid type")
    result[name] = probing.DataPoint(
        name=name, location=loc, type_=t, data=data)

  return result


def decode_fts(
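
For MASK_ONE and CATEGORICAL outputs, the decoder hard-decodes by taking an argmax and re-expanding to one-hot. The round trip in isolation, written here with jax.nn.one_hot (equivalent to the hk.one_hot call in the snippet):

import jax
import jax.numpy as jnp

data = jnp.array([[0.2, 0.5, 0.3]])
best = jnp.argmax(data, -1)           # [1]
jax.nn.one_hot(best, data.shape[-1])  # [[0., 1., 0.]]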

0 Source : network.py
with Apache License 2.0
from deepmind

def sample_from_logits(
    logits: jnp.ndarray,
    legal_action_mask: jnp.ndarray,
    temperature: jnp.ndarray,):
  """Sample from logits respecting a legal actions mask."""
  deterministic_logits = jnp.where(
      jax.nn.one_hot(
          jnp.argmax(logits, axis=-1),
          num_classes=action_utils.MAX_ACTION_INDEX,
          dtype=jnp.bool_), 0,
      jnp.finfo(jnp.float32).min)
  stochastic_logits = jnp.where(legal_action_mask,
                                logits / temperature,
                                jnp.finfo(jnp.float32).min)

  logits_for_sampling = jnp.where(
      jnp.equal(temperature, 0.0),
      deterministic_logits,
      stochastic_logits)

  # Sample an action for the current province and update the state so that
  # following orders can be conditioned on this decision.
  key = hk.next_rng_key()
  return jax.random.categorical(
      key, logits_for_sampling, axis=-1)


class RelationalOrderDecoderState(NamedTuple):

0 Source : zacharys_karate_club.py
with Apache License 2.0
from deepmind

def optimize_club(num_steps: int):
  """Solves the karte club problem by optimizing the assignments of students."""
  network = hk.without_apply_rng(hk.transform(network_definition))
  zacharys_karate_club = get_zacharys_karate_club()
  labels = get_ground_truth_assignments_for_zacharys_karate_club()
  params = network.init(jax.random.PRNGKey(42), zacharys_karate_club)

  @jax.jit
  def prediction_loss(params):
    decoded_nodes = network.apply(params, zacharys_karate_club)
    # We interpret the decoded nodes as a pair of logits for each node.
    log_prob = jax.nn.log_softmax(decoded_nodes)
    # The only two assignments we know a-priori are those of Mr. Hi (Node 0)
    # and John A (Node 33).
    return -(log_prob[0, 0] + log_prob[33, 1])

  opt_init, opt_update = optax.adam(1e-2)
  opt_state = opt_init(params)

  @jax.jit
  def update(params, opt_state):
    g = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(g, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params):
    decoded_nodes = network.apply(params, zacharys_karate_club)
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == labels)

  for step in range(num_steps):
    logging.info("step %r accuracy %r", step, accuracy(params).item())
    params, opt_state = update(params, opt_state)


def main(_):

0 Source : train.py
with Apache License 2.0
from deepmind

def compute_loss(params, graph, label, net):
  """Computes loss."""
  pred_graph = net.apply(params, graph)
  preds = jax.nn.log_softmax(pred_graph.globals)
  targets = jax.nn.one_hot(label, 2)

  # Since we have an extra 'dummy' graph in our batch due to padding, we want
  # to mask out any loss associated with the dummy graph.
  # Since we padded with `pad_with_graphs` we can recover the mask by using
  # get_graph_padding_mask.
  mask = jraph.get_graph_padding_mask(pred_graph)

  # Cross entropy loss.
  loss = -jnp.mean(preds * targets * mask[:, None])

  # Accuracy taking into account the mask.
  accuracy = jnp.sum(
      (jnp.argmax(pred_graph.globals, axis=1) == label) * mask)/jnp.sum(mask)
  return loss, accuracy


def train(data_path, master_csv_path, split_path, batch_size,

0 Source : train_flax.py
with Apache License 2.0
from deepmind

def compute_loss(params, graph, label, net):
  """Computes loss."""
  pred_graph = net.apply(params, graph)
  preds = jax.nn.log_softmax(pred_graph.globals)
  targets = jax.nn.one_hot(label, 2)

  # Since we have an extra 'dummy' graph in our batch due to padding, we want
  # to mask out any loss associated with the dummy graph.
  # Since we padded with `pad_with_graphs` we can recover the mask by using
  # get_graph_padding_mask.
  mask = jraph.get_graph_padding_mask(pred_graph)

  # Cross entropy loss.
  loss = -jnp.mean(preds * targets * mask[:, None])

  # Accuracy taking into account the mask.
  accuracy = jnp.sum(
      (jnp.argmax(pred_graph.globals, axis=1) == label) * mask)/jnp.sum(mask)
  return loss, accuracy


def train_step(optimizer, graph, label, net):

0 Source : action_selection.py
with Apache License 2.0
from deepmind

def gumbel_muzero_interior_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
  """Selects the action with a deterministic action selection.

  The action is selected based on the visit counts to produce visitation
  frequencies similar to softmax(prior_logits + qvalues).

  Args:
    rng_key: random number generator state.
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to take an action.
    depth: the scalar depth of the current node. The root has depth zero.
    qtransform: function to obtain completed Q-values for a node.

  Returns:
    action: the action selected from the given node.
  """
  del rng_key, depth
  chex.assert_shape([node_index], ())
  visit_counts = tree.children_visits[node_index]
  prior_logits = tree.children_prior_logits[node_index]
  chex.assert_equal_shape([visit_counts, prior_logits])
  completed_qvalues = qtransform(tree, node_index)

  # The `prior_logits + completed_qvalues` provide an improved policy,
  # because the missing qvalues are replaced by v_{prior_logits}(node).
  to_argmax = _prepare_argmax_input(
      probs=jax.nn.softmax(prior_logits + completed_qvalues),
      visit_counts=visit_counts)

  chex.assert_rank(to_argmax, 1)
  return jnp.argmax(to_argmax, axis=-1)


def masked_argmax(

0 Source : random.py
with Apache License 2.0
from google

def categorical(key: KeyArray,
                logits: RealArray,
                axis: int = -1,
                shape: Optional[Sequence[int]] = None) -> jnp.ndarray:
  """Sample random values from categorical distributions.

  Args:
    key: a PRNG key used as the random key.
    logits: Unnormalized log probabilities of the categorical distribution(s) to sample from,
      so that `softmax(logits, axis)` gives the corresponding probabilities.
    axis: Axis along which logits belong to the same categorical distribution.
    shape: Optional, a tuple of nonnegative integers representing the result shape.
      Must be broadcast-compatible with ``np.delete(logits.shape, axis)``.
      The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``.

  Returns:
    A random array with int dtype and shape given by ``shape`` if ``shape``
    is not None, or else ``np.delete(logits.shape, axis)``.
  """
  key, _ = _check_prng_key(key)

  if axis >= 0:
    axis -= len(logits.shape)

  batch_shape = tuple(np.delete(logits.shape, axis))
  if shape is None:
    shape = batch_shape
  else:
    shape = tuple(shape)
    _check_shape("categorical", shape, batch_shape)

  sample_shape = shape[:len(shape)-len(batch_shape)]
  return jnp.argmax(
      gumbel(key, sample_shape + logits.shape, logits.dtype) +
      lax.expand_dims(logits, tuple(range(len(sample_shape)))),
      axis=axis)


def laplace(key: KeyArray,
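
This is the Gumbel-max trick: adding i.i.d. Gumbel noise to the logits and taking an argmax draws an exact sample from softmax(logits). A self-contained sketch of the same idea:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([0.0, 1.0, 3.0])
g = jax.random.gumbel(key, logits.shape)
sample = jnp.argmax(logits + g)  # distributed as Categorical(softmax(logits))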

0 Source : deep_ensemble.py
with Apache License 2.0
from google

def main(config, output_dir):

  seed = config.get('seed', 0)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)
  tf.io.gfile.makedirs(output_dir)

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)
  write_note('Initializing...')

  batch_size = config.batch_size
  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d devices per host (%d devices total), that\'s a %d per-device '
      'batch size.', batch_size, jax.process_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note('Initializing val dataset(s)...')

  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_utils.get_num_examples(
        dataset,
        split=split,
        process_batch_size=local_batch_size_eval,
        drop_remainder=False,
        data_dir=data_dir)
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info('Running validation for %d steps for %s, %s', val_steps,
                 dataset, split)

    if isinstance(pp_eval, str):
      pp_eval = preprocess_spec.parse(
          spec=pp_eval, available_ops=preprocess_utils.all_ops())

    val_ds = input_utils.get_data(
        dataset=dataset,
        split=split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=pp_eval,
        cache=config.get('val_cache', 'batched'),
        num_epochs=1,
        repeat_after_batching=True,
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        drop_remainder=False,
        data_dir=data_dir)

    return val_ds

  val_ds_splits = {
      'val':
          _get_val_split(
              config.dataset,
              split=config.val_split,
              pp_eval=config.pp_eval,
              data_dir=config.get('data_dir'))
  }

  if config.get('test_split'):
    val_ds_splits.update({
        'test':
            _get_val_split(
                config.dataset,
                split=config.test_split,
                pp_eval=config.pp_eval,
                data_dir=config.get('data_dir'))
    })

  if config.get('eval_on_cifar_10h'):
    cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
        config.get('data_dir', None))
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_cifar_10h, available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
    val_ds_splits['cifar_10h'] = _get_val_split(
        'cifar10',
        split=config.get('cifar_10h_split') or 'test',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))
  elif config.get('eval_on_imagenet_real'):
    imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn()
    preprocess_fn = preprocess_spec.parse(
        spec=config.pp_eval_imagenet_real,
        available_ops=preprocess_utils.all_ops())
    pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
    val_ds_splits['imagenet_real'] = _get_val_split(
        'imagenet2012_real',
        split=config.get('imagenet_real_split') or 'validation',
        pp_eval=pp_eval,
        data_dir=config.get('data_dir'))

  ood_ds = {}
  if config.get('ood_datasets') and config.get('ood_methods'):
    if config.get('ood_methods'):  # config.ood_methods is not an empty list
      logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
      ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
          config.dataset,
          config.ood_datasets,
          config.ood_split,
          config.pp_eval,
          config.pp_eval_ood,
          config.ood_methods,
          config.train_split,
          config.get('data_dir'),
          _get_val_split,
      )

  write_note('Initializing model...')
  logging.info('config.model = %s', config.model)
  model = ub.models.vision_transformer(
      num_classes=config.num_classes, **config.model)

  ensemble_pred_fn = functools.partial(ensemble_prediction_fn, model.apply)

  @functools.partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, images, labels, mask):
    # params is a dict of the form:
    #   {'model_1': params_model_1, 'model_2': params_model_2, ...}
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    loss_as_str = config.get('loss', 'sigmoid_xent')
    ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str)

    label_indices = config.get('label_indices')
    logging.info('!!! mask %s, label_indices %s', mask, label_indices)
    if label_indices:
      ens_logits = ens_logits[:, label_indices]

    # Note that logits and labels are usually of the shape [batch, num_classes].
    # But for OOD data, when num_classes_ood > num_classes_ind, we need to
    # adjust labels to labels[:, :config.num_classes] to match the shape of
    # logits. That is just to avoid a shape mismatch. The output losses do not
    # have any meaning for OOD data, because OOD examples do not belong to any
    # IND class.
    losses = getattr(train_utils, loss_as_str)(
        logits=ens_logits,
        labels=labels[:, :(len(label_indices) if label_indices
                           else config.num_classes)], reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name='batch')

    top1_idx = jnp.argmax(ens_logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
    n = jax.lax.psum(mask, axis_name='batch')

    metric_args = jax.lax.all_gather([ens_logits, labels, ens_prelogits, mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args

  @functools.partial(jax.pmap, axis_name='batch')
  def cifar_10h_evaluation_fn(params, images, labels, mask):
    loss_as_str = config.get('loss', 'softmax_xent')
    ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str)
    label_indices = config.get('label_indices')
    if label_indices:
      ens_logits = ens_logits[:, label_indices]

    losses = getattr(train_utils, loss_as_str)(
        logits=ens_logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')

    top1_idx = jnp.argmax(ens_logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    one_hot_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

    top1_correct = jnp.take_along_axis(
        one_hot_labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = jax.lax.psum(one_hot_labels, axis_name='batch')

    metric_args = jax.lax.all_gather([ens_logits, labels, ens_prelogits, mask],
                                     axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Setup function for computing representation.
  @functools.partial(jax.pmap, axis_name='batch')
  def representation_fn(params, images, labels, mask):
    # Return shape [batch_size, representation_size * ensemble_size]. During
    # few-shot eval, a single linear regressor is applied over all dimensions.
    representation = []
    for p in params.values():
      _, outputs = model.apply({'params': flax.core.freeze(p)},
                               images,
                               train=False)
      representation += [outputs[config.fewshot.representation_layer]]
    representation = jnp.concatenate(representation, axis=1)
    representation = jax.lax.all_gather(representation, 'batch')
    labels = jax.lax.all_gather(labels, 'batch')
    mask = jax.lax.all_gather(mask, 'batch')
    return representation, labels, mask

  write_note('Load checkpoints...')
  ensemble_params = load_checkpoints(config)

  write_note('Replicating...')
  ensemble_params = flax.jax_utils.replicate(ensemble_params)

  if jax.process_index() == 0:
    writer.write_hparams(dict(config))

  write_note('Initializing few-shotters...')
  fewshotter = None
  if 'fewshot' in config and fewshot is not None:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get('batch_size') or batch_size_eval)

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
  fewshot_results = {'dummy': {(0, 1): -jnp.inf}}
  step = 1

  # Report validation performance.
  write_note('Evaluating on the validation set...')
  for val_name, val_ds in val_ds_splits.items():
    # Sets up evaluation metrics.
    ece_num_bins = config.get('ece_num_bins', 15)
    auc_num_bins = config.get('auc_num_bins', 1000)
    ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
    calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)
    oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.005,
                                                   num_bins=auc_num_bins)
    oc_auc_1 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.01,
                                                 num_bins=auc_num_bins)
    oc_auc_2 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.02,
                                                 num_bins=auc_num_bins)
    oc_auc_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.05,
                                                 num_bins=auc_num_bins)
    label_diversity = tf.keras.metrics.Mean()
    sample_diversity = tf.keras.metrics.Mean()
    ged = tf.keras.metrics.Mean()

    # Runs evaluation loop.
    val_iter = input_utils.start_input_pipeline(
        val_ds, config.get('prefetch_to_device', 1))
    ncorrect, loss, nseen = 0, 0, 0
    for batch in val_iter:
      if val_name == 'cifar_10h':
        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
            cifar_10h_evaluation_fn(ensemble_params, batch['image'],
                                    batch['labels'], batch['mask']))
      else:
        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
            evaluation_fn(ensemble_params, batch['image'],
                          batch['labels'], batch['mask']))
      # All results are a replicated array shaped as follows:
      # (local_devices, per_device_batch_size, elem_shape...)
      # with each local device's entry being identical as they got psum'd.
      # So let's just take the first one to the host as numpy.
      ncorrect += np.sum(np.array(batch_ncorrect[0]))
      loss += np.sum(np.array(batch_losses[0]))
      nseen += np.sum(np.array(batch_n[0]))
      if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
        # Here we parse batch_metric_args to compute uncertainty metrics.
        # (e.g., ECE or Calibration AUC).
        logits, labels, _, masks = batch_metric_args
        masks = np.array(masks[0], dtype=bool)
        logits = np.array(logits[0])
        probs = jax.nn.softmax(logits)
        # From one-hot to integer labels, as required by ECE.
        int_labels = np.argmax(np.array(labels[0]), axis=-1)
        int_preds = np.argmax(logits, axis=-1)
        confidence = np.max(probs, axis=-1)
        for p, c, l, d, m, label in zip(probs, confidence, int_labels,
                                        int_preds, masks, labels[0]):
          ece.add_batch(p[m, :], label=l[m])
          calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
          # TODO(jereliu): Extend to support soft multi-class probabilities.
          oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])
          oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m])
          oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m])
          oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])

          if val_name == 'cifar_10h' or val_name == 'imagenet_real':
            batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                label[m], p[m, :], config.num_classes)
            label_diversity.update_state(batch_label_diversity)
            sample_diversity.update_state(batch_sample_diversity)
            ged.update_state(batch_ged)

    val_loss[val_name] = loss / nseen  # Keep for reproducibility tests.
    val_measurements = {
        f'{val_name}_prec@1': ncorrect / nseen,
        f'{val_name}_loss': val_loss[val_name],
    }
    if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
      val_measurements[f'{val_name}_ece'] = ece.result()['ece']
      val_measurements[f'{val_name}_calib_auc'] = calib_auc.result()[
          'calibration_auc']
      val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result()[
          'collaborative_auc']
      val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result()[
          'collaborative_auc']
      val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result()[
          'collaborative_auc']
      val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result()[
          'collaborative_auc']
    writer.write_scalars(step, val_measurements)

    if val_name == 'cifar_10h' or val_name == 'imagenet_real':
      cifar_10h_measurements = {
          f'{val_name}_label_diversity': label_diversity.result(),
          f'{val_name}_sample_diversity': sample_diversity.result(),
          f'{val_name}_ged': ged.result(),
      }
      writer.write_scalars(step, cifar_10h_measurements)

  # OOD eval
  # Entries in the ood_ds dict include:
  # (ind_dataset, ood_dataset1, ood_dataset2, ...).
  # OOD metrics are computed using ind_dataset paired with each of the
  # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
  # is also included in the ood_ds.
  if ood_ds and config.ood_methods:
    ood_measurements = ood_utils.eval_ood_metrics(
        ood_ds,
        ood_ds_names,
        config.ood_methods,
        evaluation_fn,
        ensemble_params,
        n_prefetch=config.get('prefetch_to_device', 1))
    writer.write_scalars(step, ood_measurements)

  if 'fewshot' in config and fewshotter is not None:
    # Compute few-shot on-the-fly evaluation.
    write_note('Few-shot evaluation...')
    # Keep `results` to return for reproducibility tests.
    fewshot_results, best_l2 = fewshotter.run_all(ensemble_params,
                                                  config.fewshot.datasets)

    # TODO(dusenberrymw): Remove this once fewshot.py is updated.
    def make_writer_measure_fn(step):
      def writer_measure(name, value):
        writer.write_scalars(step, {name: value})
      return writer_measure

    fewshotter.walk_results(
        make_writer_measure_fn(step), fewshot_results, best_l2)

  write_note('Done!')
  pool.close()
  pool.join()
  writer.close()

  # Return final training loss, validation loss, and fewshot results for
  # reproducibility test cases.
  return val_loss, fewshot_results

if __name__ == '__main__':
