Here are examples of the Python API jax.numpy.argmax taken from open-source projects. By voting up you can indicate which examples are most useful and appropriate.
81 Examples
3
Source : lookahead_mnist.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def accuracy(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Return the fraction of rows whose argmax over the last axis agrees.

    Both ``logits`` and ``labels`` are reduced with ``argmax`` over their last
    axis, and the element-wise agreement of the two index arrays is averaged.
    """
    true_class = jnp.argmax(labels, axis=-1)
    guessed_class = jnp.argmax(logits, axis=-1)
    hits = true_class == guessed_class
    return jnp.mean(hits)
def model_accuracy(
3
Source : mlp_agent.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def select_action(network_def: Any,
                  network_params: np.ndarray,
                  state: Any) -> int:
    """Select an action greedily from network.

    Applies ``network_def`` with ``network_params`` to ``state`` and returns the
    flat index (``jnp.argmax`` with no axis) of the largest output. The value is
    a JAX scalar array, not a Python ``int``, despite the annotation.
    """
    action_values = network_def.apply(network_params, state)
    greedy_index = jnp.argmax(action_values)
    return greedy_index
@functools.partial(jax.jit, static_argnums=(0, 7))
3
Source : model_utils.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def compute_depth_index(weights, depth_threshold=0.5):
    """Compute the sample index of the median depth accumulation.

    Builds the opaqueness mask with ``compute_opaqueness_mask`` and returns,
    along the last axis, the index of the first maximal mask entry.
    """
    mask = compute_opaqueness_mask(weights, depth_threshold)
    first_opaque_sample = jnp.argmax(mask, axis=-1)
    return first_opaque_sample
def compute_depth_map(weights, z_vals, depth_threshold=0.5):
3
Source : differentially_private_sgd.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def accuracy(params, batch):
    """Fraction of examples whose predicted class equals the target class.

    ``batch`` is an ``(inputs, targets)`` pair. Both ``targets`` and
    ``predict(params, inputs)`` are reduced with ``argmax`` over axis 1.
    """
    inputs, targets = batch
    expected = jnp.argmax(targets, axis=1)
    scores = predict(params, inputs)
    return jnp.mean(jnp.argmax(scores, axis=1) == expected)
def clipped_grad(params, l2_norm_clip, single_example_batch):
3
Source : mnist_classifier.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def accuracy(params, batch):
    """Mean top-1 agreement between ``predict`` outputs and one-hot targets.

    Args:
      params: model parameters forwarded to ``predict``.
      batch: an ``(inputs, targets)`` pair; class indices come from
        ``argmax`` over axis 1 of both targets and predictions.
    """
    inputs, targets = batch
    target_ids = jnp.argmax(targets, axis=1)
    predicted_ids = jnp.argmax(predict(params, inputs), axis=1)
    correct = predicted_ids == target_ids
    return correct.mean()
init_random_params, predict = stax.serial(
3
Source : mnist_classifier_fromscratch.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def accuracy(params, batch):
    """Return the classification accuracy of ``predict`` on ``batch``.

    The batch is unpacked as ``(inputs, targets)``; labels and predictions
    are compared after ``argmax`` over axis 1.
    """
    inputs, targets = batch
    label_idx = jnp.argmax(targets, axis=1)
    guess_idx = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(guess_idx == label_idx)
if __name__ == "__main__":
3
Source : resnet50.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def accuracy(params, batch):
    """Top-1 accuracy of ``predict_fun`` on an ``(inputs, targets)`` batch.

    Classes are the ``argmax`` over the last axis of both the targets and the
    network outputs.
    """
    inputs, targets = batch
    wanted = jnp.argmax(targets, axis=-1)
    outputs = predict_fun(params, inputs)
    got = jnp.argmax(outputs, axis=-1)
    return jnp.mean(got == wanted)
def synth_batches():
3
Source : mnist_lib.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def accuracy(predict: Callable, params, dataset):
    """Compute the classification accuracy of ``predict`` over a dataset.

    Each element of ``tfds.as_numpy(dataset)`` is an ``(inputs, labels)`` batch.
    Class ids are the ``argmax`` over axis 1 of the labels and of
    ``predict(params, inputs)``. Correct predictions and example counts are
    summed over all batches before dividing. A smaller final batch is
    therefore weighted by its real size, instead of counting as much as a
    full batch (as a mean of per-batch means would).

    Args:
      predict: function ``(params, inputs) -> scores``.
      params: parameters passed to ``predict``.
      dataset: a TFDS/``tf.data`` dataset of ``(inputs, labels)`` batches.

    Returns:
      A JAX scalar with the fraction of correctly classified examples.

    Raises:
      ValueError: if ``dataset`` yields no examples.
    """

    @jax.jit
    def _per_batch(params, inputs, labels):
        # Take params as an argument so they are traced inputs rather than
        # constants baked into the compiled computation.
        target_class = jnp.argmax(labels, axis=1)
        predicted_class = jnp.argmax(predict(params, inputs), axis=1)
        return jnp.sum(predicted_class == target_class)

    num_correct = 0
    num_examples = 0
    for inputs, labels in tfds.as_numpy(dataset):
        num_correct += _per_batch(params, inputs, labels)
        num_examples += labels.shape[0]
    if num_examples == 0:
        raise ValueError("accuracy() received an empty dataset")
    return num_correct / num_examples
@staticmethod
3
Source : imagenet_resnet50_train.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def evaluate_batch(self, images, labels):
    """Count the images in the batch that the model classifies correctly.

    Runs ``self.model`` with ``training=False``, takes the ``argmax`` over axis
    1 of the logits, and counts positions equal to ``labels``. The labels are
    compared element-wise, so they are presumably integer class ids.
    """
    predicted = jn.argmax(self.model(images, training=False), axis=1)
    matches = jn.equal(predicted, labels)
    return jn.count_nonzero(matches)
def run_eval(self):
3
Source : mnist_infl_mem.py
with Apache License 2.0
from google-research
with Apache License 2.0
from google-research
def batch_correctness(params, batch):
    """Per-example boolean: does the predicted class match the target class?

    ``batch`` is ``(inputs, targets)``. Both ``targets`` and
    ``predict(params, inputs)`` are reduced with ``argmax`` over axis 1. The
    result is not averaged.
    """
    inputs, targets = batch
    target_ids = jnp.argmax(targets, axis=1)
    logits = predict(params, inputs)
    return jnp.argmax(logits, axis=1) == target_ids
def load_mnist():
3
Source : evaluation.py
with Apache License 2.0
from google-research
with Apache License 2.0
from google-research
def accuracy(logits, targets, weights=None):
    """Sequence accuracy averaged over sequence length.

    Compares ``argmax(logits, axis=-1)`` with ``targets`` and returns a
    ``(correct_sum, normalizing_factor)`` pair for the caller to divide.

    Args:
      logits: predictions with exactly one more dimension than ``targets``.
      targets: class ids, compared element-wise with the argmax of logits.
      weights: optional per-position weights (presumably the same shape as
        ``targets``; TODO confirm).

    Returns:
      The return value depends on whether ``weights`` is given:
      - Without ``weights``: (number of correct positions, number of
        positions).
      - With ``weights``: (100 * weighted number of correct positions,
        sum of weights).

      NOTE(review): only the weighted branch is scaled by 100. After
      division, one branch is a percentage and the other a fraction; confirm
      that callers expect this asymmetry.

    Raises:
      ValueError: if ``logits.ndim != targets.ndim + 1``.
    """
    if logits.ndim != targets.ndim + 1:
        raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                         (str(logits.shape), str(targets.shape)))
    correct = jnp.equal(jnp.argmax(logits, axis=-1), targets)
    if weights is None:
        return correct.sum(), np.prod(logits.shape[:-1])
    weighted = correct * weights * 100
    return weighted.sum(), weights.sum()
def sequence_accuracy(prediction, target):
3
Source : mention_based_entity_qa_task.py
with Apache License 2.0
from google-research
with Apache License 2.0
from google-research
def get_predictions_max(attention_weights: Array, memory_entity_ids: Array,
                        weights: Array) -> Array:
    """Predict entity ID based on a memory with the largest attention weight.

    For every row, take the column of ``attention_weights`` with the highest
    weight (axis 1), look up the entity id stored at that column of
    ``memory_entity_ids``, and multiply it by ``weights``. Rows with weight 0
    therefore predict 0.
    """
    # `stop_gradient` is a safety check so the model doesn't keep activations
    # around, which could be expensive.
    attention_weights, memory_entity_ids, weights = (
        jax.lax.stop_gradient(x)
        for x in (attention_weights, memory_entity_ids, weights))
    best_memory = jnp.argmax(attention_weights, axis=1)[:, None]
    chosen_ids = jnp.take_along_axis(memory_entity_ids, best_memory, axis=1)
    return chosen_ids[:, 0] * weights
def get_predictions_sum(attention_weights: Array, memory_entity_ids: Array,
3
Source : metrics.py
with Apache License 2.0
from google-research
with Apache License 2.0
from google-research
def pointer_accuracy_fn(pointer_logits: dt.BatchedTocopoLogits,
                        target_data: dt.BatchedTrainTocopoTargetData):
    """Count correct pointer predictions at timestep 1.

    Despite the name, nothing is divided here. The function returns
    ``(num_correct, num_attempts)``:

    - ``num_correct`` counts argmax predictions (over axis 1 of
      ``pointer_logits``) that land on a position flagged in
      ``target_data.is_target_pointer[:, 1, :]``.
    - ``num_attempts`` is the leading (batch) dimension of that slice.
    """
    num_positions = pointer_logits.shape[1]
    predicted_position = jnp.argmax(pointer_logits, axis=1)
    one_hot_argmax_predictions = jax.nn.one_hot(predicted_position, num_positions)
    # The pointer should always be at timestep 1 in copy-paste data. Slice out
    # the BV array of those pointer targets.
    is_pointer_targets = target_data.is_target_pointer[:, 1, :]
    num_correct = jnp.sum(one_hot_argmax_predictions * is_pointer_targets)
    return num_correct, is_pointer_targets.shape[0]
@flax_dataclass
3
Source : ensemble.py
with Apache License 2.0
from google-research
with Apache License 2.0
from google-research
def label_pred_ensemble_softmax(logits: Array, ensemble_size: int) -> Array:
    """Function to select the predicted labels for the ensemble softmax CE.

    Args:
      logits: 2D tensor of shape [ensemble size * batch size, #classes]. The
        batches for each ensemble member are assumed to be stacked with a
        jnp.repeat(..., ensemble_size) pattern, so they can be recovered by
        the slicing ...[::ensemble_size].
      ensemble_size: The size of the ensemble.

    Returns:
      The class labels to predict, shape (B,): the argmax over classes of the
      ensemble log-softmax.
    """
    ensemble_log_probs = _get_log_ensemble_softmax(logits, ensemble_size)  # (B, C)
    predicted_labels = jnp.argmax(ensemble_log_probs, axis=1)  # (B,)
    return predicted_labels
3
Source : trainer_test.py
with Apache License 2.0
from google-research
with Apache License 2.0
from google-research
def test(self, name, kwargs, expected_shape):
    """Check loss shapes and label predictions for a non-ensemble loss."""
    train_loss_fn, eval_loss_fn, label_pred_fn = trainer.get_loss_fn(
        name, **kwargs)
    key = jax.random.PRNGKey(0)
    logits = jax.random.normal(key, (8, 16))
    labels = jax.random.uniform(key, (8, 16))
    train_loss = train_loss_fn(logits, labels)
    eval_loss = eval_loss_fn(logits, labels)
    for loss in (train_loss, eval_loss):
        self.assertEqual(loss.shape, expected_shape)
    # The train and eval losses are the same in the standard case (no ensemble).
    self.assertSequenceEqual(list(train_loss), list(eval_loss))
    predicted_labels = list(label_pred_fn(logits))
    self.assertSequenceEqual(predicted_labels, list(jnp.argmax(logits, 1)))
@parameterized.named_parameters(
3
Source : metrics.py
with MIT License
from gortizji
with MIT License
from gortizji
def compute_accuracy_metrics(logits, labels):
    """Return a dict with softmax cross-entropy ``loss`` and top-1 ``accuracy``.

    ``labels`` may be either of two forms:

    - rank > 1 (one-hot or score rows), reduced with ``argmax`` over the
      last axis;
    - rank <= 1, used directly as class ids.
    """
    loss = softmax_cross_entropy_loss(logits, labels)
    target = jnp.argmax(labels, -1) if len(labels.shape) > 1 else labels
    return {
        "loss": loss,
        "accuracy": jnp.mean(jnp.argmax(logits, -1) == target),
    }
def compute_binary_accuracy_metrics(logits, labels):
3
Source : run_t5_mlm_flax.py
with Apache License 2.0
from gsarti
with Apache License 2.0
from gsarti
def eval_step(params, batch):
    """Evaluate one per-device batch and return pmean-averaged metrics.

    ``batch`` must contain a ``"labels"`` entry. Every other entry is passed to
    ``model`` as a keyword argument. The input dict is not mutated. The loss
    (softmax cross-entropy against one-hot labels) and token accuracy are
    averaged locally, then across the ``"batch"`` pmap axis.
    """
    labels = batch["labels"]
    # Build a new dict instead of ``batch.pop`` so a caller's dict is never
    # mutated when this function is invoked outside of pmap.
    model_inputs = {k: v for k, v in batch.items() if k != "labels"}
    logits = model(**model_inputs, params=params, train=False)[0]

    # compute loss
    loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

    # compute accuracy
    accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

    # summarize metrics
    metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics


# The caller keeps using the parameters after every evaluation, so they must
# not be donated: donation allows XLA to invalidate the input buffer, and when
# it cannot reuse it JAX just emits a "donated buffers were not usable" warning.
p_eval_step = jax.pmap(eval_step, "batch")
3
Source : ffjord_mnist.py
with MIT License
from jacobjinkelly
with MIT License
from jacobjinkelly
def _acc_fn(logits, labels):
"""
Classification accuracy of the model.
"""
predicted_class = jnp.argmax(logits, axis=1)
return jnp.mean(predicted_class == labels)
def standard_normal_logprob(z):
3
Source : mnist.py
with MIT License
from jacobjinkelly
with MIT License
from jacobjinkelly
def _acc_fn(logits, labels):
"""
Classification accuracy of the model.
"""
predicted_class = jnp.argmax(logits, axis=1)
return jnp.mean(predicted_class == labels)
def _loss_fn(logits, labels):
3
Source : mnist_classifier.py
with Apache License 2.0
from juliuskunze
with Apache License 2.0
from juliuskunze
def accuracy(inputs, targets):
    """Top-1 accuracy of ``predict(inputs)`` against one-hot ``targets`` (axis 1)."""
    scores = predict(inputs)
    correct = np.argmax(scores, axis=1) == np.argmax(targets, axis=1)
    return np.mean(correct)
def main():
3
Source : dqn.py
with MIT License
from ku2482
with MIT License
from ku2482
def _forward(
    self,
    params: hk.Params,
    state: np.ndarray,
) -> jnp.ndarray:
    """Return greedy actions: argmax over axis 1 of the network's Q-values."""
    q_values = self.net.apply(params, state)
    greedy_actions = jnp.argmax(q_values, axis=1)
    return greedy_actions
def update(self, writer=None):
3
Source : fqf.py
with MIT License
from ku2482
with MIT License
from ku2482
def _forward_from_feature(
    self,
    params_cum_p: hk.Params,
    params: hk.Params,
    feature: np.ndarray,
) -> jnp.ndarray:
    """Greedy actions from fraction-weighted quantile values of a feature.

    The steps are:

    1. ``self.cum_p_net`` yields ``cum_p`` and ``cum_p_prime``.
    2. Quantile values are evaluated at ``cum_p_prime``.
    3. They are weighted by the widths of consecutive ``cum_p`` intervals and
       summed over axis 1 to form Q-values.
    4. The argmax over axis 1 is returned.
    """
    cum_p, cum_p_prime = self.cum_p_net.apply(params_cum_p, feature)
    quantile_s = self.net["quantile"].apply(params["quantile"], feature, cum_p_prime)
    upper = cum_p[:, 1:, None]
    lower = cum_p[:, :-1, None]
    q_s = ((upper - lower) * quantile_s).sum(axis=1)
    return jnp.argmax(q_s, axis=1)
def update(self, writer=None):
3
Source : iqn.py
with MIT License
from ku2482
with MIT License
from ku2482
def _forward(
    self,
    params: hk.Params,
    state: np.ndarray,
    key: jnp.ndarray,
) -> jnp.ndarray:
    """Greedy actions from the mean of randomly sampled quantile values.

    Samples ``self.num_quantiles_eval`` uniform fractions per state using
    ``key``, averages the network's quantile outputs over axis 1, and returns
    the argmax over axis 1.
    """
    batch_size = state.shape[0]
    taus = jax.random.uniform(key, (batch_size, self.num_quantiles_eval))
    quantile_values = self.net.apply(params, state, taus)
    expected_q = quantile_values.mean(axis=1)
    return jnp.argmax(expected_q, axis=1)
@partial(jax.jit, static_argnums=0)
3
Source : qrdqn.py
with MIT License
from ku2482
with MIT License
from ku2482
def _forward(
    self,
    params: hk.Params,
    state: np.ndarray,
) -> jnp.ndarray:
    """Greedy actions: average network outputs over axis 1, then argmax over axis 1."""
    quantiles = self.net.apply(params, state)
    mean_q = quantiles.mean(axis=1)
    return jnp.argmax(mean_q, axis=1)
@partial(jax.jit, static_argnums=0)
3
Source : sac_discrete.py
with MIT License
from ku2482
with MIT License
from ku2482
def _select_action(
    self,
    params_actor: hk.Params,
    state: np.ndarray,
) -> jnp.ndarray:
    """Deterministic action: argmax over axis 1 of the actor's first output."""
    actor_outputs = self.actor.apply(params_actor, state)
    pi_s = actor_outputs[0]
    return jnp.argmax(pi_s, axis=1)
@partial(jax.jit, static_argnums=0)
3
Source : mnist_adahessian.py
with Apache License 2.0
from nestordemeure
with Apache License 2.0
from nestordemeure
def accuracy(params, batch):
    """Accuracy of ``predict`` on an ``(inputs, targets)`` batch (argmax over axis 1)."""
    inputs, targets = batch
    expected_ids = jnp.argmax(targets, axis=1)
    predicted_ids = jnp.argmax(predict(params, inputs), axis=1)
    agreement = predicted_ids == expected_ids
    return jnp.mean(agreement)
init_random_params, predict = stax.serial(
3
Source : main.py
with MIT License
from niklasschmitz
with MIT License
from niklasschmitz
def choose_action(env, observation, pred_Q, params_Q_eval, eps):
    """Epsilon-greedy action selection.

    A uniform number is drawn from NumPy's global RNG and
    ``pred_Q(params_Q_eval, observation)`` is always evaluated. With
    probability ``1 - eps`` the argmax of the first row of the predicted
    Q-values is returned; otherwise ``env.action_space.sample()`` is returned.
    """
    draw = onp.random.random()
    q_values = pred_Q(params_Q_eval, observation)
    if draw < 1 - eps:
        return np.argmax(q_values[0])
    return env.action_space.sample()
def stack_frames(frames):
3
Source : numpy_ops.py
with GNU General Public License v3.0
from PKU-NIP-Lab
with GNU General Public License v3.0
from PKU-NIP-Lab
def argmax(x, axis=None):
    """Indices of the maximum of ``x`` (unwrapped from ``JaxArray`` if needed).

    With ``axis=None`` the raw ``jnp.argmax`` result is returned. Otherwise it is
    wrapped in ``JaxArray``.
    """
    indices = jnp.argmax(_remove_jaxarray(x), axis=axis)
    if axis is None:
        return indices
    return JaxArray(indices)
def argmin(x, axis=None):
3
Source : sgd_flax_demo.py
with MIT License
from probml
with MIT License
from probml
def callback_fn(**kwargs):
    """Print the loss and the train/test accuracy for a training step.

    Reads these keyword arguments:
      belief_state: object with ``.params`` for ``model.apply``.
      X_train, Y_train: training inputs and integer labels.
      preds: sequence whose first element holds the outputs compared against
        ``Y_test`` (presumably test-set log-probs; verify with the caller).
      Y_test: test labels.
      info: object with a ``.loss`` attribute.
    """
    # Train accuracy: run the model on the training inputs.
    logprobs = model.apply(kwargs["belief_state"].params, kwargs["X_train"])
    y_pred = jnp.argmax(logprobs, axis=-1)
    train_acc = jnp.mean(y_pred == kwargs["Y_train"])

    # Test accuracy comes from the precomputed kwargs["preds"][0]; no second
    # forward pass is needed (one on X_train here would just be discarded).
    y_pred = jnp.argmax(kwargs["preds"][0], axis=-1)
    test_acc = jnp.mean(y_pred == kwargs["Y_test"])

    print("Loss: ", kwargs["info"].loss)
    print(f"Train Accuracy: {train_acc}, Test Accuracy: {test_acc}")
def main():
3
Source : ops.py
with Apache License 2.0
from pyro-ppl
with Apache License 2.0
from pyro-ppl
def _argmax(x, axis, keepdims):
if keepdims:
return np.expand_dims(np.argmax(x, axis), axis)
return np.argmax(x, axis)
@ops.argmin.register(array)
3
Source : submit_to_leaderboard.py
with MIT License
from rowanz
with MIT License
from rowanz
def pred_step(state: train_state.TrainState, batch):
    """Return log-probs and argmax predictions for the audio, text and joint heads.

    ``preds_joint`` is the argmax of the sum of the two softmax distributions.
    """
    variables = {'params': state.params}
    logits_from_audio, logits_from_text = state.apply_fn(variables, batch)
    joint_probs = (jax.nn.softmax(logits_from_audio, axis=-1)
                   + jax.nn.softmax(logits_from_text, axis=-1))
    return {
        'logprobs_audio': jax.nn.log_softmax(logits_from_audio, axis=-1),
        'preds_audio': jnp.argmax(logits_from_audio, -1),
        'logprobs_text': jax.nn.log_softmax(logits_from_text, axis=-1),
        'preds_text': jnp.argmax(logits_from_text, -1),
        'preds_joint': jnp.argmax(joint_probs, -1),
    }


p_pred_step = jax.pmap(pred_step, axis_name='batch', donate_argnums=(1,))
3
Source : tvqa_finetune.py
with MIT License
from rowanz
with MIT License
from rowanz
def pred_step(state: train_state.TrainState, batch):
    """Per-head log-softmax and argmax predictions, plus a joint prediction.

    The joint prediction sums the audio and text softmax probabilities before
    taking the argmax over the last axis.
    """
    audio_logits, text_logits = state.apply_fn({'params': state.params}, batch)
    out = {}
    out['logprobs_audio'] = jax.nn.log_softmax(audio_logits, axis=-1)
    out['preds_audio'] = jnp.argmax(audio_logits, -1)
    out['logprobs_text'] = jax.nn.log_softmax(text_logits, axis=-1)
    out['preds_text'] = jnp.argmax(text_logits, -1)
    summed = jax.nn.softmax(audio_logits, axis=-1) + jax.nn.softmax(text_logits, axis=-1)
    out['preds_joint'] = jnp.argmax(summed, -1)
    return out


p_pred_step = jax.pmap(pred_step, axis_name='batch', donate_argnums=(1,))
3
Source : qa_qar_joint_finetune.py
with MIT License
from rowanz
with MIT License
from rowanz
def train_loss_fn(state, params, batch):
    """Mean cross-entropy against ``batch['labels']`` plus top-1 hit rate.

    Returns:
      ``(loss, {'is_right': fraction of argmax hits, 'loss': loss})``.
    """
    logits = state.apply_fn({'params': params}, batch)
    labels = batch['labels']
    log_p = jax.nn.log_softmax(logits, axis=-1)
    num_classes = log_p.shape[-1]
    labels_oh = jax.nn.one_hot(labels, dtype=log_p.dtype, num_classes=num_classes)
    log_likelihood = jnp.sum(labels_oh * log_p, axis=-1)
    loss = -jnp.mean(log_likelihood)
    hits = jnp.argmax(log_p, -1) == labels
    is_right = hits.astype(jnp.float32).mean()
    return loss, {'is_right': is_right, 'loss': loss}
p_train_step = jax.pmap(functools.partial(finetune_train_step, loss_fn=train_loss_fn, tx_fns=tx_fns),
3
Source : training_utilities.py
with BSD 2-Clause "Simplified" License
from SarthakYadav
with BSD 2-Clause "Simplified" License
from SarthakYadav
def compute_metrics(logits, labels, mode, cost_fn):
    """Loss (and, in multiclass mode, top-1 accuracy) averaged over ``'batch'``.

    ``cost_fn(logits, labels)`` gives the loss. When ``mode`` is
    ``TrainingMode.MULTICLASS``, accuracy compares the argmax over the last axis
    of logits and labels. All metrics are ``lax.pmean``-ed over the ``'batch'``
    axis name, so this must run where that axis name is bound (e.g. in pmap).
    """
    metrics = {'loss': cost_fn(logits, labels)}
    if mode == TrainingMode.MULTICLASS:
        # labels is a (batch, num_classes) onehot-encoded array here.
        hits = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
        metrics['accuracy'] = jnp.mean(hits)
    return lax.pmean(metrics, axis_name='batch')
def create_input_iter(ds, devices=None):
3
Source : haiku_run.py
with Apache License 2.0
from Scitator
with Apache License 2.0
from Scitator
def _eval_batch(
    self, params: hk.Params, state: hk.State, rng: jnp.ndarray, batch: Batch
) -> jnp.ndarray:
    """Count examples whose argmax prediction equals ``batch["labels"]``."""
    logits, _unused_state = self.forward.apply(params, state, rng, batch)
    predicted = jnp.argmax(logits, axis=-1)
    return jnp.sum(predicted == batch["labels"])
def _update_fn(
3
Source : jax_run.py
with Apache License 2.0
from Scitator
with Apache License 2.0
from Scitator
def accuracy(params, images, targets):
    """Mean agreement of the argmax (axis 1) of predictions and one-hot targets."""
    expected = jnp.argmax(targets, axis=1)
    scores = batched_predict(params, images)
    return jnp.mean(jnp.argmax(scores, axis=1) == expected)
def loss(params, images, targets):
3
Source : penguin_utils_flax_experimental.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def _compute_metrics(logits: _LogitBatch, labels: _LabelBatch):
    """Cross-entropy loss and accuracy for one batch.

    Accuracy compares the argmax of ``logits`` over its last axis with
    ``labels[..., 0]``, the first entry of the labels' trailing axis.
    """
    # assumes that the logits use log_softmax activations.
    loss = _categorical_cross_entropy_loss(logits, labels)
    predicted = jnp.argmax(logits, axis=-1)
    return {
        'loss': loss,
        'accuracy': jnp.mean(predicted == labels[..., 0]),
    }
def _mean_epoch_metrics(
3
Source : utils.py
with MIT License
from vballoli
with MIT License
from vballoli
def error_rate_metric(logits,
                      one_hot_labels):
    """Returns the error rate between some predictions and some labels.

    Args:
      logits: Output of the model.
      one_hot_labels: One-hot encoded labels. Dimensions should match the logits.

    Returns:
      The error rate (1 - accuracy), averaged over every position that remains
      after the argmax over the last axis (the samples, for 2-D inputs).
    """
    predicted = jnp.argmax(logits, -1)
    expected = jnp.argmax(one_hot_labels, -1)
    mistakes = predicted != expected
    return jnp.mean(mistakes)
def tensorflow_to_numpy(xs):
3
Source : train.py
with Apache License 2.0
from yzhwang
with Apache License 2.0
from yzhwang
def eval_batch(
    params: hk.Params,
    state: hk.State,
    batch: dataset.Batch,
) -> jnp.ndarray:
    """Evaluates a batch.

    Runs ``forward.apply`` with no RNG and ``is_training=False``. Returns the
    number of examples whose argmax prediction equals ``batch['labels']``, as
    float32.
    """
    logits, _unused_state = forward.apply(params, state, None, batch, is_training=False)
    hits = jnp.equal(jnp.argmax(logits, axis=-1), batch['labels'])
    return hits.sum().astype(jnp.float32)
def evaluate(
0
Source : accuracy.py
with MIT License
from cgarciae
with MIT License
from cgarciae
def update(
    self,
    target: jnp.ndarray,
    preds: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
):
    """
    Accumulates metric statistics. `target` and `preds` should have the same shape.

    When ``self.argmax_preds`` / ``self.argmax_labels`` are set, ``preds`` /
    ``target`` are first reduced with ``argmax`` over the last axis. The
    element-wise ``accuracy`` values are then passed to the parent class's
    ``update``.

    Arguments:
        target: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        preds: The predicted values. shape = `[batch_size, d0, .. dN]`.
        sample_weight: Optional coefficient for the metric.
            - A scalar scales the metric by that value.
            - A tensor of size `[batch_size]` rescales each sample's metric by
              the corresponding element.
            - A tensor of shape `[batch_size, d0, .. dN-1]` (or broadcastable
              to it) scales each metric element of `preds` by the matching
              value.
            (Note on `dN-1`: all metric functions reduce by 1 dimension,
            usually the last axis (-1)).

    Returns:
        None as written: the result of the parent ``update`` is not returned.
        NOTE(review): an earlier docstring said this returns the cumulative
        accuracy; confirm against the base class.
    """
    if self.argmax_preds:
        preds = jnp.argmax(preds, axis=-1)
    if self.argmax_labels:
        target = jnp.argmax(target, axis=-1)
    super().update(
        values=accuracy(target=target, preds=preds),
        sample_weight=sample_weight,
    )
0
Source : wcpg.py
with MIT License
from ChrisWaites
with MIT License
from ChrisWaites
def step(self, action, rng):
    """Advance the environment by one transition.

    The action index is ``argmax(action)``. The next state index is drawn from
    ``self.dynamics[current][action]`` over 7 states, and ``self.state_index``
    is updated. The reward is Gaussian: N(1, 1) in state 0, otherwise
    ``2 * N(0, 1) + 2``.

    Returns:
      ``(next_state, reward, done, {})`` where ``done`` is True for state
      indices >= 5.
    """
    action_index = int(np.argmax(action))
    transition_probs = self.dynamics[self.state_index][action_index]
    choice_key, rng = random.split(rng)
    next_state_index = int(random.choice(choice_key, 7, p=transition_probs))
    self.state_index = next_state_index
    next_state = self.states[next_state_index]
    if next_state_index == 0:
        mean, std = 1., 1.
    else:
        mean, std = 2., 2.
    noise_key, rng = random.split(rng)
    reward = std * random.normal(noise_key) + mean
    done = (next_state_index >= 5)
    return next_state, reward, done, {}
def reset(self):
0
Source : wcpg.py
with MIT License
from ChrisWaites
with MIT License
from ChrisWaites
def is_terminal(next_state):
    """Per row: True when the argmax over axis 1 is an index of 5 or more."""
    state_ids = np.argmax(next_state, 1)
    return state_ids >= 5
def critic_loss(critic_params, fixed_critic_params, fixed_actor_params, env_dynamics, batch):
0
Source : decoders.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def postprocess(spec: _Spec, preds: Dict[str, _Array]) -> Dict[str, _DataPoint]:
    """Postprocesses decoder output.

    Each prediction is decoded according to its type in ``spec``:

    - SCALAR values are kept as-is.
    - MASK values are thresholded at 0 into 0/1.
    - MASK_ONE and CATEGORICAL values become one-hot vectors of their argmax.
    - POINTER values become the argmax index.

    Raises:
      ValueError: for any other type.
    """
    result = {}
    for name, data in preds.items():
        _, loc, t = spec[name]
        if t == _Type.SCALAR:
            decoded = data
        elif t == _Type.MASK:
            decoded = (data > 0.0) * 1.0
        elif t in [_Type.MASK_ONE, _Type.CATEGORICAL]:
            decoded = hk.one_hot(jnp.argmax(data, -1), data.shape[-1])
        elif t == _Type.POINTER:
            decoded = jnp.argmax(data, -1)
        else:
            raise ValueError("Invalid type")
        result[name] = probing.DataPoint(
            name=name, location=loc, type_=t, data=decoded)
    return result
def decode_fts(
0
Source : network.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def sample_from_logits(
    logits: jnp.ndarray,
    legal_action_mask: jnp.ndarray,
    temperature: jnp.ndarray,):
    """Sample from logits respecting a legal actions mask.

    Where ``temperature == 0`` the greedy (argmax) legal action is chosen.
    Elsewhere an action is drawn from ``softmax(logits / temperature)``
    restricted to legal actions. Illegal entries receive the smallest float32
    value in both branches.

    Args:
      logits: unnormalized action scores. The last axis indexes actions; the
        greedy branch one-hot-encodes with ``action_utils.MAX_ACTION_INDEX``
        classes.
      legal_action_mask: boolean mask of legal actions (presumably the same
        shape as ``logits``; TODO confirm with callers).
      temperature: sampling temperature, broadcast against ``logits``.

    Returns:
      Sampled action indices, with the last axis of ``logits`` removed.
    """
    min_logit = jnp.finfo(jnp.float32).min
    # Mask before the argmax so the greedy branch honours legal_action_mask
    # too. Otherwise temperature 0 could pick an illegal action unless the
    # caller had already masked the logits.
    legal_logits = jnp.where(legal_action_mask, logits, min_logit)
    deterministic_logits = jnp.where(
        jax.nn.one_hot(
            jnp.argmax(legal_logits, axis=-1),
            num_classes=action_utils.MAX_ACTION_INDEX,
            dtype=jnp.bool_), 0,
        min_logit)
    stochastic_logits = jnp.where(legal_action_mask,
                                  logits / temperature,
                                  min_logit)
    logits_for_sampling = jnp.where(
        jnp.equal(temperature, 0.0),
        deterministic_logits,
        stochastic_logits)
    # Sample an action for the current province and update the state so that
    # following orders can be conditioned on this decision.
    key = hk.next_rng_key()
    return jax.random.categorical(
        key, logits_for_sampling, axis=-1)
class RelationalOrderDecoderState(NamedTuple):
0
Source : zacharys_karate_club.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def optimize_club(num_steps: int):
    """Solves the karate club problem by optimizing the assignments of students.

    Trains the network with Adam (learning rate 1e-2) for ``num_steps`` steps.
    Only the two known assignments are supervised. Accuracy against the
    ground-truth assignments is logged before each update.
    """
    network = hk.without_apply_rng(hk.transform(network_definition))
    zacharys_karate_club = get_zacharys_karate_club()
    labels = get_ground_truth_assignments_for_zacharys_karate_club()
    params = network.init(jax.random.PRNGKey(42), zacharys_karate_club)

    @jax.jit
    def prediction_loss(params):
        # The decoded nodes are interpreted as a pair of logits for each node.
        log_prob = jax.nn.log_softmax(network.apply(params, zacharys_karate_club))
        # The only two assignments known a priori are those of Mr. Hi (Node 0)
        # and John A (Node 33).
        return -(log_prob[0, 0] + log_prob[33, 1])

    opt_init, opt_update = optax.adam(1e-2)
    opt_state = opt_init(params)

    @jax.jit
    def update(params, opt_state):
        grads = jax.grad(prediction_loss)(params)
        updates, new_opt_state = opt_update(grads, opt_state)
        return optax.apply_updates(params, updates), new_opt_state

    @jax.jit
    def accuracy(params):
        node_logits = network.apply(params, zacharys_karate_club)
        return jnp.mean(jnp.argmax(node_logits, axis=1) == labels)

    for step in range(num_steps):
        logging.info("step %r accuracy %r", step, accuracy(params).item())
        params, opt_state = update(params, opt_state)
def main(_):
0
Source : train.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def compute_loss(params, graph, label, net):
    """Computes loss.

    Returns ``(loss, accuracy)`` for a padded batch of graphs with 2 classes.
    Padding graphs are masked out via ``jraph.get_graph_padding_mask``.

    NOTE(review): ``jnp.mean`` divides by every entry of the masked array,
    including the padding graph's, so the loss is scaled by
    1 / (2 * num_graphs_including_padding) rather than averaged over real
    graphs. The accuracy, by contrast, divides by ``sum(mask)``.
    """
    pred_graph = net.apply(params, graph)
    logits = pred_graph.globals
    log_probs = jax.nn.log_softmax(logits)
    one_hot_targets = jax.nn.one_hot(label, 2)
    # The batch carries an extra 'dummy' graph from `pad_with_graphs`;
    # `get_graph_padding_mask` marks the real graphs so the dummy adds zero
    # to the loss sum and is excluded from accuracy.
    mask = jraph.get_graph_padding_mask(pred_graph)
    # Cross entropy loss.
    loss = -jnp.mean(log_probs * one_hot_targets * mask[:, None])
    # Accuracy taking into account the mask.
    num_correct = jnp.sum((jnp.argmax(logits, axis=1) == label) * mask)
    accuracy = num_correct / jnp.sum(mask)
    return loss, accuracy
def train(data_path, master_csv_path, split_path, batch_size,
0
Source : train_flax.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def compute_loss(params, graph, label, net):
    """Computes loss.

    Args:
      params: network parameters for ``net.apply``.
      graph: padded batch of graphs (padded with ``pad_with_graphs``).
      label: class id per graph (2 classes).
      net: module whose ``apply`` returns a graph with per-graph ``globals``.

    Returns:
      ``(loss, accuracy)``, both excluding the padding graph's contribution to
      the sums. NOTE(review): the loss mean still counts the padding rows in
      its denominator.
    """
    pred_graph = net.apply(params, graph)
    mask = jraph.get_graph_padding_mask(pred_graph)
    real_rows = mask[:, None]
    # Cross entropy loss.
    loss = -jnp.mean(
        jax.nn.log_softmax(pred_graph.globals) * jax.nn.one_hot(label, 2) * real_rows)
    # Accuracy taking into account the mask.
    correct = jnp.argmax(pred_graph.globals, axis=1) == label
    accuracy = jnp.sum(correct * mask) / jnp.sum(mask)
    return loss, accuracy
def train_step(optimizer, graph, label, net):
0
Source : action_selection.py
with Apache License 2.0
from deepmind
with Apache License 2.0
from deepmind
def gumbel_muzero_interior_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
    """Selects the action with a deterministic action selection.

    The action is selected based on the visit counts to produce visitation
    frequencies similar to softmax(prior_logits + qvalues).

    Args:
      rng_key: random number generator state (unused).
      tree: _unbatched_ MCTS tree state.
      node_index: scalar index of the node from which to take an action.
      depth: the scalar depth of the current node (unused). The root has depth
        zero.
      qtransform: function to obtain completed Q-values for a node.

    Returns:
      action: the action selected from the given node.
    """
    del rng_key, depth
    chex.assert_shape([node_index], ())
    visit_counts = tree.children_visits[node_index]
    prior_logits = tree.children_prior_logits[node_index]
    chex.assert_equal_shape([visit_counts, prior_logits])
    # `prior_logits + completed_qvalues` is an improved policy, because the
    # missing qvalues are replaced by v_{prior_logits}(node).
    improved_policy = jax.nn.softmax(prior_logits + qtransform(tree, node_index))
    to_argmax = _prepare_argmax_input(
        probs=improved_policy, visit_counts=visit_counts)
    chex.assert_rank(to_argmax, 1)
    return jnp.argmax(to_argmax, axis=-1)
def masked_argmax(
0
Source : random.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def categorical(key: KeyArray,
                logits: RealArray,
                axis: int = -1,
                shape: Optional[Sequence[int]] = None) -> jnp.ndarray:
    """Sample random values from categorical distributions.

    Args:
      key: a PRNG key used as the random key.
      logits: Unnormalized log probabilities of the categorical distribution(s)
        to sample from, so that `softmax(logits, axis)` gives the corresponding
        probabilities.
      axis: Axis along which logits belong to the same categorical distribution.
      shape: Optional, a tuple of nonnegative integers representing the result
        shape. Must be broadcast-compatible with
        ``np.delete(logits.shape, axis)``. The default (None) produces a result
        shape equal to ``np.delete(logits.shape, axis)``.

    Returns:
      A random array with int dtype and shape given by ``shape`` if ``shape``
      is not None, or else ``np.delete(logits.shape, axis)``.
    """
    key, _ = _check_prng_key(key)
    # Make the axis negative so it still addresses the same logits dimension
    # after sample dimensions are prepended below.
    if axis >= 0:
        axis -= len(logits.shape)
    batch_shape = tuple(np.delete(logits.shape, axis))
    if shape is None:
        shape = batch_shape
    else:
        shape = tuple(shape)
        _check_shape("categorical", shape, batch_shape)
    num_sample_dims = len(shape) - len(batch_shape)
    sample_shape = shape[:num_sample_dims]
    # Gumbel-max trick: argmax(logits + Gumbel noise) is a categorical sample.
    noise = gumbel(key, sample_shape + logits.shape, logits.dtype)
    expanded_logits = lax.expand_dims(logits, tuple(range(len(sample_shape))))
    return jnp.argmax(noise + expanded_logits, axis=axis)
def laplace(key: KeyArray,
0
Source : deep_ensemble.py
with Apache License 2.0
from google
with Apache License 2.0
from google
def main(config, output_dir):
seed = config.get('seed', 0)
tf.random.set_seed(seed)
if config.get('data_dir'):
logging.info('data_dir=%s', config.data_dir)
logging.info('Output dir: %s', output_dir)
tf.io.gfile.makedirs(output_dir)
# Create an asynchronous multi-metric writer.
writer = metric_writers.create_default_writer(
output_dir, just_logging=jax.process_index() > 0)
# The pool is used to perform misc operations such as logging in async way.
pool = multiprocessing.pool.ThreadPool()
def write_note(note):
if jax.process_index() == 0:
logging.info('NOTE: %s', note)
write_note('Initializing...')
batch_size = config.batch_size
batch_size_eval = config.get('batch_size_eval', batch_size)
if (batch_size % jax.device_count() != 0 or
batch_size_eval % jax.device_count() != 0):
raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
f'be divisible by device number ({jax.device_count()})')
local_batch_size = batch_size // jax.process_count()
local_batch_size_eval = batch_size_eval // jax.process_count()
logging.info(
'Global batch size %d on %d hosts results in %d local batch size. '
'With %d devices per host (%d devices total), that\'s a %d per-device '
'batch size.', batch_size, jax.process_count(), local_batch_size,
jax.local_device_count(), jax.device_count(),
local_batch_size // jax.local_device_count())
write_note('Initializing val dataset(s)...')
def _get_val_split(dataset, split, pp_eval, data_dir=None):
# We do ceil rounding such that we include the last incomplete batch.
nval_img = input_utils.get_num_examples(
dataset,
split=split,
process_batch_size=local_batch_size_eval,
drop_remainder=False,
data_dir=data_dir)
val_steps = int(np.ceil(nval_img / batch_size_eval))
logging.info('Running validation for %d steps for %s, %s', val_steps,
dataset, split)
if isinstance(pp_eval, str):
pp_eval = preprocess_spec.parse(
spec=pp_eval, available_ops=preprocess_utils.all_ops())
val_ds = input_utils.get_data(
dataset=dataset,
split=split,
rng=None,
process_batch_size=local_batch_size_eval,
preprocess_fn=pp_eval,
cache=config.get('val_cache', 'batched'),
num_epochs=1,
repeat_after_batching=True,
shuffle=False,
prefetch_size=config.get('prefetch_to_host', 2),
drop_remainder=False,
data_dir=data_dir)
return val_ds
val_ds_splits = {
'val':
_get_val_split(
config.dataset,
split=config.val_split,
pp_eval=config.pp_eval,
data_dir=config.get('data_dir'))
}
if config.get('test_split'):
val_ds_splits.update({
'test':
_get_val_split(
config.dataset,
split=config.test_split,
pp_eval=config.pp_eval,
data_dir=config.get('data_dir'))
})
if config.get('eval_on_cifar_10h'):
cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
config.get('data_dir', None))
preprocess_fn = preprocess_spec.parse(
spec=config.pp_eval_cifar_10h, available_ops=preprocess_utils.all_ops())
pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
val_ds_splits['cifar_10h'] = _get_val_split(
'cifar10',
split=config.get('cifar_10h_split') or 'test',
pp_eval=pp_eval,
data_dir=config.get('data_dir'))
elif config.get('eval_on_imagenet_real'):
imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn()
preprocess_fn = preprocess_spec.parse(
spec=config.pp_eval_imagenet_real,
available_ops=preprocess_utils.all_ops())
pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
val_ds_splits['imagenet_real'] = _get_val_split(
'imagenet2012_real',
split=config.get('imagenet_real_split') or 'validation',
pp_eval=pp_eval,
data_dir=config.get('data_dir'))
ood_ds = {}
if config.get('ood_datasets') and config.get('ood_methods'):
if config.get('ood_methods'): # config.ood_methods is not a empty list
logging.info('loading OOD dataset = %s', config.get('ood_datasets'))
ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
config.dataset,
config.ood_datasets,
config.ood_split,
config.pp_eval,
config.pp_eval_ood,
config.ood_methods,
config.train_split,
config.get('data_dir'),
_get_val_split,
)
write_note('Initializing model...')
logging.info('config.model = %s', config.model)
model = ub.models.vision_transformer(
num_classes=config.num_classes, **config.model)
ensemble_pred_fn = functools.partial(ensemble_prediction_fn, model.apply)
@functools.partial(jax.pmap, axis_name='batch')
def evaluation_fn(params, images, labels, mask):
  """Computes masked top-1 and loss statistics for one per-device eval batch.

  Args:
    params: dict of ensemble member parameters, of the form
      {'model_1': params_model_1, 'model_2': params_model_2, ...}; it is
      forwarded unchanged to `ensemble_pred_fn`.
    images: per-device image batch.
    labels: per-device label batch, indexed as [batch, classes].
    mask: per-device example mask. It is additionally zeroed for rows whose
      labels are all zero.

  Returns:
    Tuple `(ncorrect, loss, n, metric_args)`. `ncorrect`, `loss` and `n` are
    mask-weighted and psum'd over the 'batch' axis. `metric_args` is the
    all-gathered `[ens_logits, labels, ens_prelogits, mask]`, using the
    updated mask.
  """
  # Ignore the entries with all zero labels for evaluation.
  mask *= labels.max(axis=1)
  loss_as_str = config.get('loss', 'sigmoid_xent')
  ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str)
  # Optionally keep only a subset of the logit columns.
  label_indices = config.get('label_indices')
  if label_indices:
    ens_logits = ens_logits[:, label_indices]
  # Note that logits and labels are usually of the shape [batch,num_classes].
  # But for OOD data, when num_classes_ood > num_classes_ind, we need to
  # adjust labels to labels[:, :config.num_classes] to match the shape of
  # logits. That is just to avoid shape mismatch. The output losses does not
  # have any meaning for OOD data, because OOD not belong to any IND class.
  num_loss_classes = (
      len(label_indices) if label_indices else config.num_classes)
  losses = getattr(train_utils, loss_as_str)(
      logits=ens_logits,
      labels=labels[:, :num_loss_classes],
      reduction=False)
  loss = jax.lax.psum(losses * mask, axis_name='batch')
  top1_idx = jnp.argmax(ens_logits, axis=1)
  # Extracts the label at the highest logit index for each image.
  top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
  ncorrect = jax.lax.psum(top1_correct * mask, axis_name='batch')
  n = jax.lax.psum(mask, axis_name='batch')
  metric_args = jax.lax.all_gather([ens_logits, labels, ens_prelogits, mask],
                                   axis_name='batch')
  return ncorrect, loss, n, metric_args
@functools.partial(jax.pmap, axis_name='batch')
def cifar_10h_evaluation_fn(params, images, labels, mask):
  """Computes masked top-1 and loss statistics for one CIFAR-10H batch.

  Args:
    params: ensemble member parameters, forwarded to `ensemble_pred_fn`.
    images: per-device image batch.
    labels: per-device label batch, indexed as [batch, classes]. The argmax
      over axis 1 is used as the hard label for top-1 accuracy.
    mask: per-device example mask. Examples with mask 0 (presumably pipeline
      padding; confirm against input_utils) are excluded from `ncorrect`,
      `loss` and `n`.

  Returns:
    Tuple `(ncorrect, loss, n, metric_args)`. `ncorrect`, `loss` and `n` are
    mask-weighted and psum'd over the 'batch' axis; the caller reduces each
    with `np.sum`. `metric_args` is the all-gathered
    `[ens_logits, labels, ens_prelogits, mask]`, using the original mask.
  """
  # The default loss here is softmax_xent, unlike evaluation_fn's
  # sigmoid_xent default; this is kept as-is.
  loss_as_str = config.get('loss', 'softmax_xent')
  ens_logits, ens_prelogits = ensemble_pred_fn(params, images, loss_as_str)
  label_indices = config.get('label_indices')
  if label_indices:
    ens_logits = ens_logits[:, label_indices]
  # Float weights so that masked examples contribute nothing to the sums.
  # Unlike evaluation_fn, the mask is not multiplied by labels.max(axis=1):
  # with soft labels that would down-weight valid examples.
  valid = mask.astype(jnp.float32)
  losses = getattr(train_utils, loss_as_str)(
      logits=ens_logits, labels=labels, reduction=False)
  loss = jax.lax.psum(losses * valid, axis_name='batch')
  top1_idx = jnp.argmax(ens_logits, axis=1)
  # An example is correct when the highest-logit index matches the argmax of
  # its label row. Comparing indices avoids hard-coding the number of classes.
  top1_correct = (top1_idx == jnp.argmax(labels, axis=1)).astype(jnp.float32)
  ncorrect = jax.lax.psum(top1_correct * valid, axis_name='batch')
  # Number of valid (unmasked) examples, summed across devices.
  n = jax.lax.psum(valid, axis_name='batch')
  metric_args = jax.lax.all_gather([ens_logits, labels, ens_prelogits, mask],
                                   axis_name='batch')
  return ncorrect, loss, n, metric_args
# Pmapped helper that produces the few-shot representations.
@functools.partial(jax.pmap, axis_name='batch')
def representation_fn(params, images, labels, mask):
  """Returns all-gathered (representation, labels, mask) for one batch.

  Each ensemble member's output at `config.fewshot.representation_layer` is
  concatenated along axis 1. During few-shot eval, a single linear regressor
  is applied over all of these dimensions.
  """

  def _member_features(member_params):
    # model.apply yields a pair; only the second element (the outputs dict)
    # is used here.
    _, member_outputs = model.apply(
        {'params': flax.core.freeze(member_params)}, images, train=False)
    return member_outputs[config.fewshot.representation_layer]

  features = jnp.concatenate(
      [_member_features(member_params) for member_params in params.values()],
      axis=1)
  gathered_features = jax.lax.all_gather(features, axis_name='batch')
  gathered_labels = jax.lax.all_gather(labels, axis_name='batch')
  gathered_mask = jax.lax.all_gather(mask, axis_name='batch')
  return gathered_features, gathered_labels, gathered_mask
write_note('Load checkpoints...')
ensemble_params = load_checkpoints(config)
write_note('Replicating...')
ensemble_params = flax.jax_utils.replicate(ensemble_params)
if jax.process_index() == 0:
writer.write_hparams(dict(config))
write_note('Initializing few-shotters...')
fewshotter = None
if 'fewshot' in config and fewshot is not None:
fewshotter = fewshot.FewShotEvaluator(
representation_fn, config.fewshot,
config.fewshot.get('batch_size') or batch_size_eval)
# Note: we return the train loss, val loss, and fewshot best l2s for use in
# reproducibility unit tests.
val_loss = {val_name: -jnp.inf for val_name, _ in val_ds_splits.items()}
fewshot_results = {'dummy': {(0, 1): -jnp.inf}}
step = 1
# Report validation performance.
write_note('Evaluating on the validation set...')
for val_name, val_ds in val_ds_splits.items():
# Sets up evaluation metrics.
ece_num_bins = config.get('ece_num_bins', 15)
auc_num_bins = config.get('auc_num_bins', 1000)
ece = rm.metrics.ExpectedCalibrationError(num_bins=ece_num_bins)
calib_auc = rm.metrics.CalibrationAUC(correct_pred_as_pos_label=False)
oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.005,
num_bins=auc_num_bins)
oc_auc_1 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.01,
num_bins=auc_num_bins)
oc_auc_2 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.02,
num_bins=auc_num_bins)
oc_auc_5 = rm.metrics.OracleCollaborativeAUC(oracle_fraction=0.05,
num_bins=auc_num_bins)
label_diversity = tf.keras.metrics.Mean()
sample_diversity = tf.keras.metrics.Mean()
ged = tf.keras.metrics.Mean()
# Runs evaluation loop.
val_iter = input_utils.start_input_pipeline(
val_ds, config.get('prefetch_to_device', 1))
ncorrect, loss, nseen = 0, 0, 0
for batch in val_iter:
if val_name == 'cifar_10h':
batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
cifar_10h_evaluation_fn(ensemble_params, batch['image'],
batch['labels'], batch['mask']))
else:
batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
evaluation_fn(ensemble_params, batch['image'],
batch['labels'], batch['mask']))
# All results are a replicated array shaped as follows:
# (local_devices, per_device_batch_size, elem_shape...)
# with each local device's entry being identical as they got psum'd.
# So let's just take the first one to the host as numpy.
ncorrect += np.sum(np.array(batch_ncorrect[0]))
loss += np.sum(np.array(batch_losses[0]))
nseen += np.sum(np.array(batch_n[0]))
if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
# Here we parse batch_metric_args to compute uncertainty metrics.
# (e.g., ECE or Calibration AUC).
logits, labels, _, masks = batch_metric_args
masks = np.array(masks[0], dtype=np.bool)
logits = np.array(logits[0])
probs = jax.nn.softmax(logits)
# From one-hot to integer labels, as required by ECE.
int_labels = np.argmax(np.array(labels[0]), axis=-1)
int_preds = np.argmax(logits, axis=-1)
confidence = np.max(probs, axis=-1)
for p, c, l, d, m, label in zip(probs, confidence, int_labels,
int_preds, masks, labels[0]):
ece.add_batch(p[m, :], label=l[m])
calib_auc.add_batch(d[m], label=l[m], confidence=c[m])
# TODO(jereliu): Extend to support soft multi-class probabilities.
oc_auc_0_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])
oc_auc_1.add_batch(d[m], label=l[m], custom_binning_score=c[m])
oc_auc_2.add_batch(d[m], label=l[m], custom_binning_score=c[m])
oc_auc_5.add_batch(d[m], label=l[m], custom_binning_score=c[m])
if val_name == 'cifar_10h' or val_name == 'imagenet_real':
batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
label[m], p[m, :], config.num_classes)
label_diversity.update_state(batch_label_diversity)
sample_diversity.update_state(batch_sample_diversity)
ged.update_state(batch_ged)
val_loss[val_name] = loss / nseen # Keep for reproducibility tests.
val_measurements = {
f'{val_name}_prec@1': ncorrect / nseen,
f'{val_name}_loss': val_loss[val_name],
}
if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
val_measurements[f'{val_name}_ece'] = ece.result()['ece']
val_measurements[f'{val_name}_calib_auc'] = calib_auc.result()[
'calibration_auc']
val_measurements[f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result()[
'collaborative_auc']
val_measurements[f'{val_name}_oc_auc_1%'] = oc_auc_1.result()[
'collaborative_auc']
val_measurements[f'{val_name}_oc_auc_2%'] = oc_auc_2.result()[
'collaborative_auc']
val_measurements[f'{val_name}_oc_auc_5%'] = oc_auc_5.result()[
'collaborative_auc']
writer.write_scalars(step, val_measurements)
if val_name == 'cifar_10h' or val_name == 'imagenet_real':
cifar_10h_measurements = {
f'{val_name}_label_diversity': label_diversity.result(),
f'{val_name}_sample_diversity': sample_diversity.result(),
f'{val_name}_ged': ged.result(),
}
writer.write_scalars(step, cifar_10h_measurements)
# OOD eval
# Entries in the ood_ds dict include:
# (ind_dataset, ood_dataset1, ood_dataset2, ...).
# OOD metrics are computed using ind_dataset paired with each of the
# ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
# is also included in the ood_ds.
if ood_ds and config.ood_methods:
ood_measurements = ood_utils.eval_ood_metrics(
ood_ds,
ood_ds_names,
config.ood_methods,
evaluation_fn,
ensemble_params,
n_prefetch=config.get('prefetch_to_device', 1))
writer.write_scalars(step, ood_measurements)
if 'fewshot' in config and fewshotter is not None:
# Compute few-shot on-the-fly evaluation.
write_note('Few-shot evaluation...')
# Keep `results` to return for reproducibility tests.
fewshot_results, best_l2 = fewshotter.run_all(ensemble_params,
config.fewshot.datasets)
# TODO(dusenberrymw): Remove this once fewshot.py is updated.
def make_writer_measure_fn(step):
  """Builds a `(name, value)` callback that writes one scalar at `step`."""

  def _write_single_scalar(name, value):
    writer.write_scalars(step, {name: value})

  return _write_single_scalar
fewshotter.walk_results(
make_writer_measure_fn(step), fewshot_results, best_l2)
write_note('Done!')
pool.close()
pool.join()
writer.close()
# Return final training loss, validation loss, and fewshot results for
# reproducibility test cases.
return val_loss, fewshot_results
if __name__ == '__main__':
See More Examples