Here are examples of the Python API jax.numpy.mean taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.
193 Examples
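Before the examples, a quick reminder of the call signature most of them rely on. jnp.mean mirrors numpy.mean: with no arguments it averages over every element, axis may be a single axis or a tuple/list of axes, and keepdims=True keeps the reduced dimensions as size-1 axes. A minimal sketch (the array values are illustrative only):

import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
jnp.mean(x)                         # scalar: mean over all 12 elements
jnp.mean(x, axis=0)                 # shape (4,): per-column means
jnp.mean(x, axis=(0, 1))            # scalar: reduce both axes
jnp.mean(x, axis=1, keepdims=True)  # shape (3, 1): per-row means, reduced axis kept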
Source : cnn_flax.py
with MIT License
from cgarciae
def loss_fn(
    params: tx.FlaxModule, model: tx.FlaxModule, x: jnp.ndarray, y: jnp.ndarray
) -> tp.Tuple[jnp.ndarray, tp.Tuple[tx.FlaxModule, jnp.ndarray]]:
    model = model.merge(params)
    preds = model(x)
    loss = jnp.mean(
        optax.softmax_cross_entropy(
            preds,
            jax.nn.one_hot(y, 10),
        )
    )
    acc_batch = preds.argmax(axis=1) == y
    return loss, (model, acc_batch)
@jax.jit
Source : cnn_model_distributed.py
with MIT License
from cgarciae
def loss_fn(
    params: Module,
    module: Module,
    x: jnp.ndarray,
    y: jnp.ndarray,
) -> tp.Tuple[jnp.ndarray, tp.Tuple[Module, jnp.ndarray]]:
    module = module.merge(params)
    preds = module(x)
    loss = jnp.mean(
        optax.softmax_cross_entropy(
            preds,
            jax.nn.one_hot(y, 10),
        )
    )
    return loss, (module, preds)
@partial(jax.pmap, axis_name="device")
Source : vae.py
with MIT License
from cgarciae
def loss_fn(params: VAE, model: VAE, x: np.ndarray) -> tp.Tuple[jnp.ndarray, VAE]:
    model = model.merge(params)
    x_pred = model(x)
    crossentropy_loss = jnp.mean(optax.sigmoid_binary_cross_entropy(x_pred, x))
    aux_losses = jax.tree_leaves(model.filter(tx.LossLog))
    loss = crossentropy_loss + sum(aux_losses, 0.0)
    return loss, model
@jax.jit
Source : squeeze_and_excite_layer.py
with Apache License 2.0
from DarshanDeshpande
def __call__(self, inputs):
    batch, _, _, channels = inputs.shape
    global_avg_pool = jnp.mean(inputs, axis=[1, 2], keepdims=False)
    dense1 = nn.Dense(channels // self.reduction, use_bias=False)(global_avg_pool)
    dense1 = nn.relu(dense1)
    dense2 = nn.Dense(channels, use_bias=False)(dense1)
    dense2 = nn.sigmoid(dense2)
    expand_dims = jnp.reshape(dense2, (batch, 1, 1, channels))
    shape_broadcasted = jnp.broadcast_to(expand_dims, inputs.shape)
    return inputs * shape_broadcasted
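Worth noting about the example above: jnp.mean(inputs, axis=[1, 2]) over an NHWC tensor is a global average pool, collapsing the spatial dimensions so each channel is summarized by a single value. A small shape check with made-up dimensions (not taken from the source):

import jax.numpy as jnp

inputs = jnp.ones((8, 32, 32, 64))      # (batch, height, width, channels), illustrative sizes
pooled = jnp.mean(inputs, axis=[1, 2])  # average over the two spatial axes
print(pooled.shape)                     # (8, 64)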
Source : conv_mixer.py
with Apache License 2.0
from DarshanDeshpande
def __call__(self, inputs, deterministic=None):
    extract_patches = nn.Conv(
        self.features, (self.patch_size, self.patch_size), self.patch_size
    )(inputs)
    x = nn.gelu(extract_patches)
    x = nn.BatchNorm(deterministic)(x)
    for _ in range(self.num_mixer_layers):
        x = ConvMixerLayer(self.features, self.filter_size)(x, deterministic)
    if self.attach_head:
        x = jnp.mean(x, [1, 2])
        x = nn.Dense(self.num_classes)(x)
        x = nn.softmax(x)
    return x
@register_model
Source : losses.py
with Apache License 2.0
from deepmind
def manual_cross_entropy(labels, logits, weight):
  ce = - weight * jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
  return jnp.mean(ce)
def byol_nce_detcon(pred1, pred2, target1, target2,
Source : utils_test.py
with Apache License 2.0
from deepmind
def test_segment_normalize(self, nan_data):
  norm = lambda x: (x - jnp.mean(x)) * jax.lax.rsqrt(jnp.var(x))
  data = jnp.arange(8, dtype=jnp.float32)
  segment_ids = jnp.array([0, 0, 0, 1, 1, 2, 2, 2])
  expected_out = jnp.concatenate(
      [norm(jnp.arange(3, dtype=jnp.float32)),
       norm(jnp.arange(3, 5, dtype=jnp.float32)),
       norm(jnp.arange(5, 8, dtype=jnp.float32))])
  if nan_data:
    data = data.at[0].set(jnp.nan)
    expected_out = expected_out.at[:3].set(jnp.nan)
  result = utils.segment_normalize(data, segment_ids, 3)
  np.testing.assert_allclose(result, expected_out)
@parameterized.parameters((False, False),
Source : lookahead_mnist.py
with Apache License 2.0
from deepmind
def categorical_crossentropy(logits: jnp.ndarray,
                             labels: jnp.ndarray) -> jnp.ndarray:
  losses = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=1)
  return jnp.mean(losses)
def accuracy(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
Source : loss_test.py
with Apache License 2.0
from deepmind
def testSigmoidCrossEntropy(self, preds, labels, expected):
  tested = jnp.mean(loss.sigmoid_binary_cross_entropy(preds, labels))
  np.testing.assert_allclose(tested, expected, rtol=1e-6, atol=1e-6)
class CosineDistanceTest(parameterized.TestCase):
Source : kernels_test.py
with Apache License 2.0
from deepmind
def test_rff_approximate_hsic_xx(self, dim, c):
  num_features = 512
  rng = random.PRNGKey(42)
  rng1, rng2, rng_x = random.split(rng, 3)
  amp, amp_probs = kernels.imq_amplitude_frequency_and_probs(dim)
  rff_kwargs = {'amp': amp, 'amp_probs': amp_probs}
  x = random.uniform(rng_x, [1, dim])
  k = inverse_multiquadratic(x, x, c)
  k = k - jnp.mean(k, axis=1, keepdims=True)
  hsic_xx = (k * k).mean()
  hsic_xx_approx = kernels.rff_approximate_hsic_xx([x], num_features,
                                                   rng1, rng2, c, rff_kwargs)
  self.assertEqual(hsic_xx_approx, hsic_xx)
if __name__ == '__main__':
Source : graph.py
with MIT License
from ericmjl
def mseloss(p, model, Fs, As, y):
    yhat = model(p, Fs, As)
    return np.mean(_mse_loss(y, yhat))
def model(params, Fs, As):
Source : rnn.py
with MIT License
from ericmjl
def mseloss(p, model, x, y):
    yhat = model(p, x)
    return np.mean(_mse_loss(y, yhat))
dloss = grad(mseloss)
Source : vae.py
with MIT License
from ericmjl
def vae_loss(flat_params, unflattener, model, x, y):
    # Make predictions
    params = unflattener(flat_params)
    y_hat = model(params, x)
    # Define KL-divergence loss
    z_mean, z_log_var = encoder(params, x)
    # Loss is sum of cross-entropy loss and KL loss.
    return np.mean(
        cross_entropy_loss(y, y_hat) + kl_divergence(z_mean, z_log_var)
    )
Source : api_test.py
with Apache License 2.0
from google
def test_staging_out_multi_replica(self):
  def f(x):
    return api.pmap(jnp.mean)(x)
  xla_comp = api.xla_computation(f)
  xla_comp(jnp.arange(8)).as_hlo_text()  # doesn't crash
def test_xla_computation_instantiate_constant_outputs(self):
Source : no_dependency_learned_optimizer.py
with Apache License 2.0
from google
def meta_loss(meta_params, key, sequence_of_batches):
  def step(opt_state, batch):
    loss, grads = value_grad_fn(opt_state[0], batch)
    opt_state = lopt.update_inner_opt_state(meta_params, opt_state, grads)
    return opt_state, loss
  params = task.init(key)
  opt_state = lopt.initial_inner_opt_state(meta_params, params)
  # Iterate N times where N is the number of batches in sequence_of_batches
  opt_state, losses = jax.lax.scan(step, opt_state, sequence_of_batches)
  return jnp.mean(losses)
key = jax.random.PRNGKey(0)
Source : no_dependency_learned_optimizer.py
with Apache License 2.0
from google
def vectorized_meta_loss(meta_params, key, sequence_of_batches):
  vec_loss = jax.vmap(
      meta_loss, in_axes=(None, 0, 0))(meta_params, key, sequence_of_batches)
  return jnp.mean(vec_loss)
vec_meta_loss_grad = jax.jit(jax.value_and_grad(vectorized_meta_loss))
Source : no_dependency_learned_optimizer.py
with Apache License 2.0
from google
def vec_antithetic_es_estimate(meta_params, keys, vec_seq_batches):
  """Compute an ES-estimated gradient along multiple directions."""
  losses, grads = jax.vmap(
      antithetic_es_estimate, in_axes=(None, 0, 0))(meta_params, keys,
                                                    vec_seq_batches)
  return jnp.mean(losses), [jnp.mean(g, axis=0) for g in grads]
keys = jax.random.split(key, 8)
Source : no_dependency_learned_optimizer.py
with Apache License 2.0
from google
def vec_short_segment_unroll(meta_params, keys, inner_opt_states, on_iterations,
                             vec_seq_of_batches):
  losses, inner_opt_states, on_iterations = jax.vmap(
      short_segment_unroll,
      in_axes=(None, 0, 0, 0, 0))(meta_params, keys, inner_opt_states,
                                  on_iterations, vec_seq_of_batches)
  return jnp.mean(losses), (inner_opt_states, on_iterations)
vec_short_segment_grad = jax.jit(
Source : summary_tutorial.py
with Apache License 2.0
from google
def forward(params):
  to_look_at = jnp.mean(params) * 2.
  return params
def loss(parameters):
Source : summary_tutorial.py
with Apache License 2.0
from google
def forward(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at)
  return params
@jax.jit
Source : summary_tutorial.py
with Apache License 2.0
from google
def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at)
  return params * 2
@jax.jit
Source : summary_tutorial.py
with Apache License 2.0
from google
def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", to_look_at, aggregation="sample")
  return params * 2
@jax.jit
Source : summary_tutorial.py
with Apache License 2.0
from google
def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary(
      "to_look_at", jnp.arange(10) * to_look_at, aggregation="collect")
  return params * 2
@jax.jit
Source : summary_tutorial.py
with Apache License 2.0
from google
def layer(params):
  to_look_at = jnp.mean(params) * 2.
  summary.summary("to_look_at", jnp.arange(10) * to_look_at)
  return params * 2
@functools.partial(jax.jit, static_argnames="with_summary")
Source : summary_tutorial.py
with Apache License 2.0
from google
def lossb(p):
  to_look_at = jnp.mean(123.)
  return p * 2, to_look_at
def loss(parameters):
Source : summary_tutorial.py
with Apache License 2.0
from google
def loss(parameters):
  loss = jnp.mean(parameters**2)
  to_look_at = jnp.mean(123.)
  hcb.id_print(to_look_at, name="to_look_at")
  return loss
value_grad_fn = jax.jit(jax.value_and_grad(loss))
Source : summary.py
with Apache License 2.0
from google
def tree_scalar_mean(prefix, values):
  for li, l in enumerate(jax.tree_leaves(values)):
    summary(prefix + "/" + str(li), jnp.mean(l))
def tree_step(prefix, values):
Source : summary.py
with Apache License 2.0
from google
def summarize_inner_params(params: Any):
  for k, v in _nested_to_names(params):
    summary(k + "/mean", jnp.mean(v))
    summary(k + "/mean_abs", jnp.mean(jnp.abs(v)))
class SummaryWriterBase:
Source : image_mlp.py
with Apache License 2.0
from google
def loss(self, params: Params, key: PRNGKey, data: Any) -> jnp.ndarray:
  num_classes = self.datasets.extra_info["num_classes"]
  logits = self._mod.apply(params, key, data["image"])
  labels = jax.nn.one_hot(data["label"], num_classes)
  vec_loss = base.softmax_cross_entropy(logits=logits, labels=labels)
  return jnp.mean(vec_loss)
def normalizer(self, loss):
Source : image_mlp.py
with Apache License 2.0
from google
def loss_with_state(self, params: Params, state: ModelState, key: PRNGKey,
                    data: Any) -> Tuple[jnp.ndarray, ModelState]:
  num_classes = self.datasets.extra_info["num_classes"]
  logits, state = self._mod.apply(params, state, key, data["image"])
  labels = jax.nn.one_hot(data["label"], num_classes)
  vec_loss = base.softmax_cross_entropy(logits=logits, labels=labels)
  return jnp.mean(vec_loss), state
def loss_with_state_and_aux(
Source : resnet.py
with Apache License 2.0
from google
def _hk_forward(self, batch):
  args = [
      'blocks_per_group', 'use_projection', 'channels_per_group',
      'initial_conv_kernel_size', 'initial_conv_stride', 'max_pool',
      'resnet_v2'
  ]
  num_classes = self.datasets.extra_info['num_classes']
  mod = resnet.ResNet(
      num_classes=num_classes, **{k: self._cfg[k] for k in args})
  logits = mod(batch['image'], is_training=True)
  loss = base.softmax_cross_entropy(
      logits=logits, labels=jax.nn.one_hot(batch['label'], num_classes))
  return jnp.mean(loss)
def init_with_state(self, key: chex.PRNGKey) -> base.Params:
Source : resnet.py
with Apache License 2.0
from google
def __call__(self, inputs, is_training, test_local_stats=False):
  out = inputs
  out = self.initial_conv(out)
  if not self.resnet_v2:
    out = self.initial_batchnorm(out, is_training, test_local_stats)
    out = self.act_fn(out)
  if self.max_pool:
    out = hk.max_pool(
        out, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME")
  for block_group in self.block_groups:
    out = block_group(out, is_training, test_local_stats)
  if self.resnet_v2:
    out = self.final_batchnorm(out, is_training, test_local_stats)
    out = self.act_fn(out)
  out = jnp.mean(out, axis=[1, 2])
  return self.logits(out)
Source : plot_test.py
with Apache License 2.0
from google
def test_plot_model_fit_plot_called_with_scaler(self):
  target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
  target = target_scaler.fit_transform(jnp.ones(50))
  mmm = lightweight_mmm.LightweightMMM()
  mmm.fit(
      media=jnp.ones((50, 3)),
      target=target,
      total_costs=jnp.repeat(50, 3),
      number_warmup=5,
      number_samples=5,
      number_chains=1)
  plot.plot_model_fit(media_mix_model=mmm, target_scaler=target_scaler)
  self.assertTrue(self.mock_plt_plot.called)
def test_plot_model_fit_plot_called_without_scaler(self):
Source : gpt2.py
with Apache License 2.0
from google
def __call__(self, x):
  # mean across the axis
  u = np.mean(x, axis=self.axis, keepdims=True)
  # variance
  s = np.mean((x - u) ** 2, axis=self.axis, keepdims=True)
  # rescaled values
  x = (x - u) * objax.functional.rsqrt(s + self.epsilon)
  x = x * self.g.value + self.b.value
  return x
def split_states(x, n):
Source : train_utils.py
with Apache License 2.0
from google
def sigmoid_xent(*, logits, labels, reduction=True):
  """Computes a sigmoid cross-entropy (Bernoulli NLL) loss over examples."""
  log_p = jax.nn.log_sigmoid(logits)
  log_not_p = jax.nn.log_sigmoid(-logits)
  nll = -jnp.sum(labels * log_p + (1. - labels) * log_not_p, axis=-1)
  return jnp.mean(nll) if reduction else nll
def softmax_xent(*, logits, labels, reduction=True, kl=False):
Source : bit_resnet.py
with Apache License 2.0
from google
def weight_standardize(w: jnp.ndarray,
                       axis: Union[Sequence[int], int],
                       eps: float):
  """Standardize (mean=0, var=1) a weight."""
  w = w - jnp.mean(w, axis=axis, keepdims=True)
  w = w / jnp.sqrt(jnp.mean(jnp.square(w), axis=axis, keepdims=True) + eps)
  return w
class IdentityLayer(nn.Module):
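A hedged usage sketch for the weight_standardize helper above: for an HWIO convolution kernel, reducing with jnp.mean over every axis except the output-channel axis gives each output filter zero mean and unit variance. The kernel shape and eps value below are illustrative, not taken from the source:

import jax.numpy as jnp
from jax import random

w = random.normal(random.PRNGKey(0), (3, 3, 16, 32))  # (H, W, in, out) conv kernel, made up
axis, eps = (0, 1, 2), 1e-5                            # standardize everything except output channels
w = w - jnp.mean(w, axis=axis, keepdims=True)
w = w / jnp.sqrt(jnp.mean(jnp.square(w), axis=axis, keepdims=True) + eps)
# each of the 32 output filters now has approximately zero mean and unit variance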
Source : loss.py
with MIT License
from gortizji
def softmax_cross_entropy_loss(logits, labels):
    if len(labels.shape) <= 1:
        num_classes = logits.shape[-1]
        soft_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    else:
        soft_labels = labels
    return jnp.mean(optax.softmax_cross_entropy(logits, soft_labels))
def binary_cross_entropy_loss_with_logits(logits, labels):
Source : modeling_flax_beit.py
with Apache License 2.0
from huggingface
def __call__(self, hidden_states):
    if self.config.use_mean_pooling:
        # Mean pool the final hidden states of the patch tokens
        patch_tokens = hidden_states[:, 1:, :]
        pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1))
    else:
        # Pool by simply taking the final hidden state of the [CLS] token
        pooled_output = hidden_states[:, 0]
    return pooled_output
class FlaxBeitModule(nn.Module):
Source : utils.py
with MIT License
from jamesvuc
def confidence_bands(y, sample_axis=-1):
    """Computes confidence bands for samples y.

    Args:
        y: array of samples.
        sample_axis: axis along which the samples are stored.

    Returns:
        lower and upper bands (mean -/+ one standard deviation), computed by
        reducing over sample_axis.
    """
    m = jnp.mean(y, axis=sample_axis)
    s = jnp.std(y, axis=sample_axis)
    return m - s, m + s
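A brief usage sketch for confidence_bands, with made-up sample data: for a (num_points, num_samples) array it returns two (num_points,) arrays, the mean minus and plus one standard deviation along the sample axis:

import jax.numpy as jnp
from jax import random

y = random.normal(random.PRNGKey(0), (100, 20))  # 100 x-locations, 20 samples each (illustrative)
m = jnp.mean(y, axis=-1)
s = jnp.std(y, axis=-1)
lower, upper = m - s, m + s                      # what confidence_bands(y) returns
print(lower.shape, upper.shape)                  # (100,) (100,)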
Source : mpi_wrapper_t.py
with MIT License
from markusschmitt
def test_mean(self):
    data = jnp.array(np.arange(720 * 4 * global_defs.device_count()).reshape((global_defs.device_count() * 720, 4)))
    myNumSamples = mpi.distribute_sampling(global_defs.device_count() * 720)
    myData = data[mpi.rank * myNumSamples:(mpi.rank + 1) * myNumSamples].reshape(get_shape((-1, 4)))
    self.assertTrue(jnp.sum(mpi.global_mean(myData) - jnp.mean(data, axis=0)) < 1e-10)
def test_var(self):
Source : test_mpi_stats.py
with Apache License 2.0
from netket
def test_mean(axis, arr, arr_loc):
    arr_mean = jnp.mean(arr, axis=axis)
    nk_mean = nk.stats.mean(arr_loc, axis=axis)
    np.testing.assert_allclose(nk_mean, arr_mean)
    nk_mean = nk.stats.mean(arr_loc, keepdims=True, axis=axis)
    assert nk_mean.ndim == arr.ndim
    np.testing.assert_allclose(nk_mean, arr_mean.reshape(1, -1))
@pytest.mark.parametrize("axis", [None, 0])
Source : test_mpi_stats.py
with Apache License 2.0
from netket
def test_subtract_mean(axis, arr, arr_loc, _mpi_rank):
    arr_sub = arr - jnp.mean(arr, axis=axis)
    nk_sub = nk.stats.subtract_mean(arr_loc, axis=axis)
    np.testing.assert_allclose(
        nk_sub, arr_sub[100 * _mpi_rank : 100 * (_mpi_rank + 1), :]
    )
@pytest.mark.parametrize("axis", [None, 0])
Source : data.py
with MIT License
from nikikilbertus
def whiten(
    inputs: Dict[Text, np.ndarray]
) -> Dict[Text, Union[float, np.ndarray, None]]:
    """Whiten each input."""
    res = {}
    for k, v in inputs.items():
        if v is not None:
            mu = np.mean(v, 0)
            std = np.maximum(np.std(v, 0), 1e-7)
            res[k + "_mu"] = mu
            res[k + "_std"] = std
            res[k] = (v - mu) / std
        else:
            res[k] = v
    return res
def whiten_with_mu_std(val: np.ndarray, mu: float, std: float) -> np.ndarray:
Source : plotting.py
with MIT License
from nikikilbertus
def plot_hist_at_z(y: np.ndarray, bin_ids: np.ndarray, idx: int) -> plt.Figure:
    fig = plt.figure()
    plt.hist(y[bin_ids == idx], bins=30)
    plt.xlabel('y')
    mean = np.mean(y[bin_ids == idx])
    var = np.var(y[bin_ids == idx])
    plt.title(
        f"mean {mean:.2f} and variance {var:.2f} for z bin {idx}")
    return fig
@empty_fig_on_failure
Source : run.py
with MIT License
from nikikilbertus
def get_phi(y: np.ndarray) -> np.ndarray:
    """The phis for the constraints."""
    return np.array([np.mean(y, axis=-1), np.var(y, axis=-1)]).T
@jit
Source : char_rnn.py
with MIT License
from NTT123
def update_fn(model, optimizer, multi_batch: jnp.ndarray):
    (model, optimizer), losses = pax.scan(update_step, (model, optimizer), multi_batch)
    return model, optimizer, jnp.mean(losses)
net = LM(vocab_size=vocab_size, hidden_dim=hidden_dim)
Source : train.py
with MIT License
from NTT123
def loss_fn(model: WaveGRU, inputs):
    logmel, wav = inputs
    input_wav = wav[:, :-1]
    target_wav = wav[:, 1:]
    model, logits = pax.purecall(model, (logmel, input_wav))
    log_pr = jax.nn.log_softmax(logits, axis=-1)
    target_wave = jax.nn.one_hot(target_wav, num_classes=logits.shape[-1])
    log_pr = jnp.sum(log_pr * target_wave, axis=-1)
    loss = -jnp.mean(log_pr)
    return loss, (loss, model)
def generate_test_sample(step, test_logmel, wave_gru, length, sample_rate, mu):
Source : __init__.py
with GNU General Public License v3.0
from PKU-NIP-Lab
def mean_absolute_error(x, y, axis=None):
  r"""Computes the mean absolute error between x and y.

  Args:
    x: a tensor of shape (d0, .. dN-1).
    y: a tensor of shape (d0, .. dN-1).
    axis: axis or axes along which to average; use `None` to average over all
      elements and return a scalar value.

  Returns:
    tensor containing the mean absolute error, reduced along `axis`.
  """
  r = ops.abs(x - y)
  return jn.mean(ops.as_device_array(r), axis=axis)
def mean_squared_error(predicts, targets, axis=None):
Source : __init__.py
with GNU General Public License v3.0
from PKU-NIP-Lab
def mean_squared_error(predicts, targets, axis=None):
  r"""Computes the mean squared error between predicts and targets.

  Args:
    predicts: a tensor of shape (d0, .. dN-1).
    targets: a tensor of shape (d0, .. dN-1).
    axis: axis or axes along which to average; use `None` to average over all
      elements and return a scalar value.

  Returns:
    tensor containing the mean squared error, reduced along `axis`.
  """
  r = (predicts - targets) ** 2
  return jn.mean(ops.as_device_array(r), axis=axis)
def mean_squared_log_error(y_true, y_pred, axis=None):
Source : __init__.py
with GNU General Public License v3.0
from PKU-NIP-Lab
def mean_squared_log_error(y_true, y_pred, axis=None):
  r"""Computes the mean squared logarithmic error between y_true and y_pred.

  Args:
    y_true: a tensor of shape (d0, .. dN-1).
    y_pred: a tensor of shape (d0, .. dN-1).
    axis: axis or axes along which to average; use `None` to average over all
      elements and return a scalar value.

  Returns:
    tensor containing the mean squared logarithmic error, reduced along `axis`.
  """
  r = (ops.log1p(y_true) - ops.log1p(y_pred)) ** 2
  return jn.mean(ops.as_device_array(r), axis=axis)
def huber_loss(predicts, targets, delta: float = 1.0):