Here are examples of the Python API `jax.numpy.ones` taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.
619 Examples
3
Source: test_model.py
with Apache License 2.0
from akbir
def test_transformer_block():
    """UTBlock(3, 5, 0.1) applied to rank-3 inputs in inference mode yields a rank-3 output."""
    mod = UTBlock(3, 5, 0.1)
    # Named `inputs` (not `input`) so the builtin is not shadowed.
    inputs = jnp.ones([1, 2, 3])
    hidden = jnp.ones([1, 2, 3])
    out = mod(inputs, hidden, mask=None, is_training=False)
    assert out.ndim == 3
@pytest.fixture
3
Source: test_integration.py
with MIT License
from brentyi
def make(variable: Variable) -> "UniFactor":
    """Construct a UniFactor attached to a single variable.

    The noise model comes from ``DiagonalGaussian.make_from_covariance``
    applied to a length-4 vector of ones.
    """
    covariance = jnp.ones(4)
    noise = jaxfg.noises.DiagonalGaussian.make_from_covariance(covariance)
    return UniFactor(variables=(variable,), noise_model=noise)
@overrides
3
Source: test_integration.py
with MIT License
from brentyi
def make(variable1: Variable, variable2: Variable) -> "BiFactor":
    """Construct a BiFactor linking two variables.

    The noise model comes from ``DiagonalGaussian.make_from_covariance``
    applied to a length-4 vector of ones.
    """
    covariance = jnp.ones(4)
    noise = jaxfg.noises.DiagonalGaussian.make_from_covariance(covariance)
    return BiFactor(variables=(variable1, variable2), noise_model=noise)
@overrides
3
Source: test_recurrent.py
with MIT License
from cgarciae
def test_optional_initial_state(self):
    """Calling the GRU without a state equals passing zeros or initialize_state()."""
    key = tx.Key(8)
    hidden_dim, features = 5, 10
    batch_size, time = 32, 10
    gru = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=False)
    # init() receives a sample (inputs, state) pair.
    sample = (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
    gru = gru.init(key, sample)
    inputs = np.random.rand(time, batch_size, features)
    zero_state = np.zeros((batch_size, hidden_dim))
    assert np.allclose(gru(inputs), gru(inputs, zero_state))
    assert np.allclose(gru(inputs), gru(inputs, gru.initialize_state(batch_size)))
def test_stateful(self):
3
Source: test_treex.py
with MIT License
from cgarciae
def test_shape_inference(self):
    """LazyLinear(1) initialised on a (5, 2) sample maps it to shape (5, 1)."""
    batch = jnp.ones((5, 2))
    layer = LazyLinear(1).init(42, batch)
    output = layer(batch)
    assert output.shape == (5, 1)
def test_init_error(self):
3
Source: test_treex.py
with MIT License
from cgarciae
def test_init_error(self):
    """Calling a LazyLinear that was never init()-ed must raise RuntimeError."""
    x = jnp.ones((5, 2))
    module = LazyLinear(1)
    with pytest.raises(RuntimeError):
        # The result is discarded; only the raised error matters.
        module(x)
def test_compact(self):
3
Source: test_poolformer.py
with Apache License 2.0
from DarshanDeshpande
def test_s12_init(self):
    """poolformer_s12 maps a [1, 256, 256, 3] input to shape (1, 8, 8, 512)."""
    params_rng, dropout_rng = random.split(random.PRNGKey(0), 2)
    model = poolformer_s12()
    images = jnp.ones([1, 256, 256, 3])
    variables = model.init(
        {"params": params_rng, "dropout": dropout_rng}, images, False)
    output = model.apply(
        {"params": variables["params"]}, images, False,
        rngs={"dropout": dropout_rng})
    self.assertEqual(output.shape, (1, 8, 8, 512))
def test_s24_init(self):
3
Source: test_poolformer.py
with Apache License 2.0
from DarshanDeshpande
def test_s24_init(self):
    """poolformer_s24 maps a [1, 256, 256, 3] input to shape (1, 8, 8, 512)."""
    params_rng, dropout_rng = random.split(random.PRNGKey(0), 2)
    model = poolformer_s24()
    images = jnp.ones([1, 256, 256, 3])
    variables = model.init(
        {"params": params_rng, "dropout": dropout_rng}, images, False)
    output = model.apply(
        {"params": variables["params"]}, images, False,
        rngs={"dropout": dropout_rng})
    self.assertEqual(output.shape, (1, 8, 8, 512))
def test_s36_init(self):
3
Source: test_poolformer.py
with Apache License 2.0
from DarshanDeshpande
def test_s36_init(self):
    """poolformer_s36 maps a [1, 256, 256, 3] input to shape (1, 8, 8, 512)."""
    params_rng, dropout_rng = random.split(random.PRNGKey(0), 2)
    model = poolformer_s36()
    images = jnp.ones([1, 256, 256, 3])
    variables = model.init(
        {"params": params_rng, "dropout": dropout_rng}, images, False)
    output = model.apply(
        {"params": variables["params"]}, images, False,
        rngs={"dropout": dropout_rng})
    self.assertEqual(output.shape, (1, 8, 8, 512))
def test_m36_init(self):
3
Source: test_poolformer.py
with Apache License 2.0
from DarshanDeshpande
def test_m36_init(self):
    """poolformer_m36 maps a [1, 256, 256, 3] input to shape (1, 8, 8, 768)."""
    params_rng, dropout_rng = random.split(random.PRNGKey(0), 2)
    model = poolformer_m36()
    images = jnp.ones([1, 256, 256, 3])
    variables = model.init(
        {"params": params_rng, "dropout": dropout_rng}, images, False)
    output = model.apply(
        {"params": variables["params"]}, images, False,
        rngs={"dropout": dropout_rng})
    self.assertEqual(output.shape, (1, 8, 8, 768))
def test_m48_init(self):
3
Source: test_poolformer.py
with Apache License 2.0
from DarshanDeshpande
def test_m48_init(self):
    """poolformer_m48 maps a [1, 256, 256, 3] input to shape (1, 8, 8, 768)."""
    params_rng, dropout_rng = random.split(random.PRNGKey(0), 2)
    model = poolformer_m48()
    images = jnp.ones([1, 256, 256, 3])
    variables = model.init(
        {"params": params_rng, "dropout": dropout_rng}, images, False)
    output = model.apply(
        {"params": variables["params"]}, images, False,
        rngs={"dropout": dropout_rng})
    self.assertEqual(output.shape, (1, 8, 8, 768))
3
Source: densities.py
with Apache License 2.0
from deepmind
def evaluate_log_density(self, x: Array) -> Array:
    """Multivariate-normal log-density of ``x`` built from ``self._config``.

    The mean has every one of the ``self._num_dim`` entries equal to
    ``shared_mean``. The covariance is ``diag(ones(num_dim) * diagonal_cov)``.
    """
    dim = self._num_dim
    loc = jnp.ones(dim) * self._config.shared_mean
    covariance = jnp.diag(jnp.ones(dim) * self._config.diagonal_cov)
    return multivariate_normal.logpdf(x, mean=loc, cov=covariance)
class FunnelDistribution(LogDensity):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_pass(self):
    """assert_axis_dimension accepts the true size of every axis, positive or negative."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        asserts.assert_axis_dimension(
            tensor, axis=axis, expected=tensor.shape[axis])
def test_assert_axis_dimension_fail(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_fail(self):
    """An expected size one larger than the real size fails on every axis."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        wrong_size = tensor.shape[axis] + 1
        with self.assertRaisesRegex(
                AssertionError,
                _get_err_regex('Expected tensor to have dimension')):
            asserts.assert_axis_dimension(tensor, axis=axis, expected=wrong_size)
def test_assert_axis_dimension_axis_invalid(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_axis_invalid(self):
    """Axes 2 and -3 are out of range for a rank-2 tensor and must be rejected."""
    tensor = jnp.ones((3, 2))
    bad_axes = (2, -3)
    for axis in bad_axes:
        with self.assertRaisesRegex(AssertionError, _get_err_regex('not available')):
            asserts.assert_axis_dimension(tensor, axis=axis, expected=1)
def test_assert_axis_dimension_gt_pass(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_gt_pass(self):
    """assert_axis_dimension_gt accepts val = size - 1 on every axis."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        just_below = tensor.shape[axis] - 1
        asserts.assert_axis_dimension_gt(tensor, axis=axis, val=just_below)
def test_assert_axis_dimension_gt_fail(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_gt_fail(self):
    """assert_axis_dimension_gt rejects val equal to the axis size."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        size = tensor.shape[axis]
        with self.assertRaisesRegex(
                AssertionError,
                _get_err_regex('Expected tensor to have dimension greater than')):
            asserts.assert_axis_dimension_gt(tensor, axis=axis, val=size)
def test_assert_axis_dimension_gt_axis_invalid(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_gt_axis_invalid(self):
    """assert_axis_dimension_gt rejects out-of-range axes 2 and -3 on a rank-2 tensor."""
    tensor = jnp.ones((3, 2))
    bad_axes = (2, -3)
    for axis in bad_axes:
        with self.assertRaisesRegex(AssertionError, _get_err_regex('not available')):
            asserts.assert_axis_dimension_gt(tensor, axis=axis, val=0)
def test_assert_axis_dimension_gteq_pass(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_gteq_pass(self):
    """assert_axis_dimension_gteq accepts val equal to each axis size."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        asserts.assert_axis_dimension_gteq(
            tensor, axis=axis, val=tensor.shape[axis])
def test_assert_axis_dimension_gteq_fail(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_gteq_fail(self):
    """assert_axis_dimension_gteq rejects val = size + 1 on every axis."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        too_large = tensor.shape[axis] + 1
        with self.assertRaisesRegex(
                AssertionError,
                _get_err_regex('Expected tensor to have dimension greater than or')):
            asserts.assert_axis_dimension_gteq(tensor, axis=axis, val=too_large)
def test_assert_axis_dimension_gteq_axis_invalid(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_gteq_axis_invalid(self):
    """assert_axis_dimension_gteq rejects out-of-range axes 2 and -3 on a rank-2 tensor."""
    tensor = jnp.ones((3, 2))
    bad_axes = (2, -3)
    for axis in bad_axes:
        with self.assertRaisesRegex(AssertionError, _get_err_regex('not available')):
            asserts.assert_axis_dimension_gteq(tensor, axis=axis, val=0)
def test_assert_axis_dimension_lt_pass(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_lt_pass(self):
    """assert_axis_dimension_lt accepts val = size + 1 on every axis."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        just_above = tensor.shape[axis] + 1
        asserts.assert_axis_dimension_lt(tensor, axis=axis, val=just_above)
def test_assert_axis_dimension_lt_fail(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_lt_fail(self):
    """assert_axis_dimension_lt rejects val equal to the axis size."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        size = tensor.shape[axis]
        with self.assertRaisesRegex(
                AssertionError,
                _get_err_regex('Expected tensor to have dimension less than')):
            asserts.assert_axis_dimension_lt(tensor, axis=axis, val=size)
def test_assert_axis_dimension_lt_axis_invalid(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_lt_axis_invalid(self):
    """assert_axis_dimension_lt rejects out-of-range axes 2 and -3 on a rank-2 tensor."""
    tensor = jnp.ones((3, 2))
    bad_axes = (2, -3)
    for axis in bad_axes:
        with self.assertRaisesRegex(AssertionError, _get_err_regex('not available')):
            asserts.assert_axis_dimension_lt(tensor, axis=axis, val=0)
def test_assert_axis_dimension_lteq_pass(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_lteq_pass(self):
    """assert_axis_dimension_lteq accepts val equal to each axis size."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        asserts.assert_axis_dimension_lteq(
            tensor, axis=axis, val=tensor.shape[axis])
def test_assert_axis_dimension_lteq_fail(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_lteq_fail(self):
    """assert_axis_dimension_lteq rejects val = size - 1 on every axis."""
    tensor = jnp.ones((3, 2, 7, 2))
    rank = tensor.ndim
    for axis in range(-rank, rank):
        too_small = tensor.shape[axis] - 1
        with self.assertRaisesRegex(
                AssertionError,
                _get_err_regex('Expected tensor to have dimension less than or')):
            asserts.assert_axis_dimension_lteq(tensor, axis=axis, val=too_small)
def test_assert_axis_dimension_lteq_axis_invalid(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_axis_dimension_lteq_axis_invalid(self):
    """assert_axis_dimension_lteq rejects out-of-range axes 2 and -3 on a rank-2 tensor."""
    tensor = jnp.ones((3, 2))
    bad_axes = (2, -3)
    for axis in bad_axes:
        with self.assertRaisesRegex(AssertionError, _get_err_regex('not available')):
            asserts.assert_axis_dimension_lteq(tensor, axis=axis, val=0)
class TreeAssertionsTest(parameterized.TestCase):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_tree_all_finite_should_fail_inf(self):
    """A tree holding an infinite entry must fail assert_tree_all_finite."""
    inf_tree = dict(
        finite_var=jnp.ones((3,)),
        inf_var=jnp.array([0.0, jnp.inf]),
    )
    expected_msg = _get_err_regex('Tree contains non-finite value')
    with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_tree_all_finite(inf_tree)
def test_assert_trees_all_equal_passes_same_tree(self):
3
Source: asserts_test.py
with Apache License 2.0
from deepmind
def test_assert_equal_pass_on_arrays(self):
    """assert_equal is expected to pass for equal-valued JAX/NumPy scalar arrays,
    including when their dtypes differ (int32 vs float64)."""
    # Not using named_parameters, because JAX cannot be used before app.run().
    asserts.assert_equal(jnp.ones([]), np.ones([]))
    asserts.assert_equal(
        jnp.ones([], dtype=jnp.int32), np.ones([], dtype=np.float64))
@parameterized.named_parameters(
3
Source: fake_test.py
with Apache License 2.0
from deepmind
def test_assert_pmapped(self):
    """_assert_pmapped(foo, ..., True) passes; passing False raises AssertionError."""
    fn_input = jnp.ones((4,))

    def foo(x):
        return x * 2

    _assert_pmapped(foo, fn_input, True)
    with self.assertRaises(AssertionError):
        _assert_pmapped(foo, fn_input, False)
def test_assert_jitted(self):
3
Source: fake_test.py
with Apache License 2.0
from deepmind
def test_assert_jitted(self):
    """_assert_jitted(foo, ..., True) passes; passing False raises AssertionError."""

    def foo(x):
        return x * 2

    fn_input = jnp.ones((4,))
    _assert_jitted(foo, fn_input, True)
    with self.assertRaises(AssertionError):
        _assert_jitted(foo, fn_input, False)
@parameterized.named_parameters([
3
Source: fake_test.py
with Apache License 2.0
from deepmind
def test_fake_jit(self, fake_kwargs, is_jitted):
    """Check fake.fake_jit both as a context manager and via start()/stop().

    Args:
      fake_kwargs: keyword arguments forwarded to ``fake.fake_jit``
        (presumably supplied by a parameterization decorator — confirm).
      is_jitted: expectation passed through to ``_assert_jitted``.
    """
    fn_input = jnp.ones((4,))

    def foo(x):
        return x * 2

    # Call with context manager.
    with fake.fake_jit(**fake_kwargs):
        _assert_jitted(foo, fn_input, is_jitted)

    # Call with start/stop. stop() sits in `finally` so a failing assertion
    # cannot leave whatever start() enabled active for later tests.
    ctx = fake.fake_jit(**fake_kwargs)
    ctx.start()
    try:
        _assert_jitted(foo, fn_input, is_jitted)
    finally:
        ctx.stop()
@parameterized.named_parameters([
3
Source: fake_test.py
with Apache License 2.0
from deepmind
def test_fake_pmap_(self, is_pmapped, jit_result):
    """Check fake.fake_pmap both as a context manager and via start()/stop().

    Patching is enabled exactly when ``is_pmapped`` is False.

    Args:
      is_pmapped: expectation passed to ``_assert_pmapped``.
      jit_result: forwarded to both ``fake.fake_pmap`` and ``_assert_pmapped``.
    """
    enable_patching = not is_pmapped
    fn_input = jnp.ones((4,))

    def foo(x):
        return x * 2

    # Call with context manager.
    with fake.fake_pmap(enable_patching=enable_patching, jit_result=jit_result):
        _assert_pmapped(foo, fn_input, is_pmapped, jit_result)

    # Call with start/stop. stop() sits in `finally` so a failing assertion
    # cannot leave whatever start() enabled active for later tests.
    ctx = fake.fake_pmap(enable_patching=enable_patching, jit_result=jit_result)
    ctx.start()
    try:
        _assert_pmapped(foo, fn_input, is_pmapped, jit_result)
    finally:
        ctx.stop()
def test_fake_pmap_axis_name(self):
3
Source: fake_test.py
with Apache License 2.0
from deepmind
def test_pmap_and_jit(self, fake_kwargs, is_pmapped, is_jitted):
    """Check fake.fake_pmap_and_jit as a context manager and via start()/stop().

    Args:
      fake_kwargs: keyword arguments forwarded to ``fake.fake_pmap_and_jit``.
      is_pmapped: expectation passed to ``_assert_pmapped``.
      is_jitted: expectation passed to ``_assert_jitted``.
    """
    fn_input = jnp.ones((4,))

    def foo(x):
        return x * 2

    # Call with context manager.
    with fake.fake_pmap_and_jit(**fake_kwargs):
        _assert_pmapped(foo, fn_input, is_pmapped)
        _assert_jitted(foo, fn_input, is_jitted)

    # Call with start/stop. stop() sits in `finally` so a failing assertion
    # cannot leave whatever start() enabled active for later tests.
    ctx = fake.fake_pmap_and_jit(**fake_kwargs)
    ctx.start()
    try:
        _assert_pmapped(foo, fn_input, is_pmapped)
        _assert_jitted(foo, fn_input, is_jitted)
    finally:
        ctx.stop()
@parameterized.named_parameters([
3
Source: pretrain_common.py
with Apache License 2.0
from deepmind
def _fake_data_generator(self, image_shape: Tuple[int, int, int, int]):
mask1 = np.random.uniform(low=0, high=7, size=image_shape + (1,))
mask2 = np.random.uniform(low=0, high=7, size=image_shape + (1,))
while True:
yield {
'view1': jnp.ones(image_shape + (3,)) * 0.5,
'view2': jnp.ones(image_shape + (3,)) * 0.3,
'fh_segmentations1': jnp.array(np.round(mask1), dtype=jnp.uint8),
'fh_segmentations2': jnp.array(np.round(mask2), dtype=jnp.uint8),
'labels': jnp.ones([image_shape[0], image_shape[1], 1],
dtype=jnp.int64),
}
def _build_train_input(self) -> Generator[image_dataset.Batch, None, None]:
3
Source: utils.py
with Apache License 2.0
from deepmind
def rendezvous() -> None:
    """Forces all hosts to check in.

    Each local device contributes a 1. A pmapped psum over axis "i" must then
    equal jax.device_count(); otherwise ValueError is raised.
    """
    with log_activity("rendezvous"):
        ones = jnp.ones([jax.local_device_count()])
        summed = jax.pmap(lambda v: jax.lax.psum(v, "i"), "i")(ones)
        counts = jax.device_get(summed)
        if counts[0] != jax.device_count():
            raise ValueError(f"Expected {jax.device_count()} got {counts}")
class PeriodicAction:
3
Source: utils_test.py
with Apache License 2.0
from deepmind
def test_bcast_local_devices_tree(self):
    """bcast_local_devices turns each scalar leaf into a [num_devices] array.

    Comparing two dicts of arrays with assertEqual calls bool() on an
    elementwise ``==`` result. That raises ValueError ("truth value ... is
    ambiguous") once more than one local device exists, so the comparison is
    done leaf by leaf with jnp.array_equal instead.
    """
    num_devices = jax.local_device_count()
    tree = utils.bcast_local_devices({"ones": jnp.ones([]),
                                      "zeros": jnp.zeros([])})
    expected = {"ones": jnp.ones([num_devices]),
                "zeros": jnp.zeros([num_devices])}
    self.assertEqual(set(tree), set(expected))
    for name, want in expected.items():
        got = tree[name]
        self.assertEqual(got.shape, want.shape, msg=name)
        self.assertTrue(bool(jnp.array_equal(got, want)), msg=name)
class TestLogActivity(absltest.TestCase):
3
Source: linear_bound_utils.py
with Apache License 2.0
from deepmind
def initial_params(
    self,
    *inps: Optional[Union[Bound, Tensor]],
) -> Nest[Tensor]:
    """Return one relaxation parameter per input, each filled with 0.5.

    An input that is neither None nor a Bound is treated as a fixed tensor
    and gets None instead of a parameter. Inputs are paired with
    ``self._input_shapes`` position by position.
    """
    params = []
    for inp, inp_shape in zip(inps, self._input_shapes):
        needs_param = inp is None or isinstance(inp, Bound)
        params.append(.5 * jnp.ones(shape=inp_shape) if needs_param else None)
    return params
def project_params(self, relax_params: Nest[Tensor]) -> Nest[Tensor]:
3
Source: nonconvex.py
with Apache License 2.0
from deepmind
def lower(self) -> Tensor:
if self._concretized_bounds is None:
logging.warning('.lower called on a non-concretized bound.'
'Returning spurious bounds.')
return -float('inf') * jnp.ones(self.shape)
return self._concretized_bounds.lower
@property
3
Source: nonconvex.py
with Apache License 2.0
from deepmind
def upper(self) -> Tensor:
if self._concretized_bounds is None:
logging.warning('.upper called on a non-concretized bound.'
'Returning spurious bounds.')
return float('inf') * jnp.ones(self.shape)
return self._concretized_bounds.upper
def evaluate(self,
3
Source: pga_test.py
with Apache License 2.0
from deepmind
def setUp(self):
    super().setUp()
    self.prng_key = jax.random.PRNGKey(1234)
    # First entry: lb = -1 and ub = +1 elementwise over X_SHAPE; lb_pre/ub_pre unset.
    first_bound = sdp_utils.IntBound(
        lb=-jnp.ones(X_SHAPE), ub=jnp.ones(X_SHAPE), lb_pre=None, ub_pre=None)
    # Second entry: every field left as None.
    second_bound = sdp_utils.IntBound(lb=None, ub=None, lb_pre=None, ub_pre=None)
    self.bounds = [first_bound, second_bound]
def test_intermediate_problem(self):
3
Source: uncertainty_spec_test.py
with Apache License 2.0
from deepmind
def setUp(self):
    super(UncertaintySpecTest, self).setUp()
    self._prng_seq = hk.PRNGSequence(13579)
    # The class count is taken from the second entry of X_SHAPE.
    self._n_classes = X_SHAPE[1]
    half_width = 0.1
    # A single bound with lb_pre/ub_pre = -/+0.1 everywhere; lb/ub unset.
    input_bound = sdp_utils.IntBound(
        lb_pre=-half_width * jnp.ones(X_SHAPE),
        ub_pre=half_width * jnp.ones(X_SHAPE),
        lb=None,
        ub=None)
    self.bounds = [input_bound]
def test_softmax_upper(self):
3
Source: cvxpy_verify_test.py
with Apache License 2.0
from deepmind
def _fgsm_example_and_bound(params, target_label, label):
    """Run FGSM from an all-0.5 input and score the resulting example.

    Returns:
      A ``(x_adv, objective)`` pair. ``x_adv`` comes from ``utils.fgsm_single``
      (epsilon 0.5, 30 steps of size 0.03). ``objective`` is
      ``utils.adv_objective`` evaluated at ``x_adv``.
    """
    def model_fn(x):
        return utils.predict_mlp(params, x)

    input_size = utils.nn_layer_sizes(params)[0]
    start = 0.5 * jnp.ones(input_size)
    eps = 0.5
    x_adv = utils.fgsm_single(model_fn, start, label, target_label, eps,
                              num_steps=30, step_size=0.03)
    objective = utils.adv_objective(model_fn, x_adv, label, target_label)
    return x_adv, objective
MARGIN = 1e-6
3
Source: privacy_test.py
with Apache License 2.0
from deepmind
def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
    """Standard dev. of noise should be l2_norm_clip * noise_multiplier.

    All-ones gradients (batch size 1) go through the DP aggregator three
    times. The std of the first update leaf must be within 10% of the
    expected value.
    """
    dp_agg = privacy.differentially_private_aggregate(
        l2_norm_clip=l2_norm_clip,
        noise_multiplier=noise_multiplier,
        seed=1337)
    state = dp_agg.init(None)
    update_fn = self.variant(dp_agg.update)
    expected_std = l2_norm_clip * noise_multiplier
    grads = [jnp.ones((1, 100, 100))]  # batch size 1
    for _ in range(3):
        updates, state = update_fn(grads, state)
        # `assert_trees_all_close` is the current chex name; the singular
        # `assert_tree_all_close` is a deprecated alias.
        chex.assert_trees_all_close(expected_std,
                                    jnp.std(updates[0]),
                                    atol=0.1 * expected_std)
def test_aggregated_updates_as_input_fails(self):
3
Source: schedule_test.py
with Apache License 2.0
from deepmind
def test_static_args(self, static_args):
    """inject_hyperparams with static args on a masked scale transform.

    Leaves with ndim > 1 are scaled by -learning_rate; the others pass
    through unchanged. Only 'learning_rate' may appear in the injected
    hyperparams.
    """
    @functools.partial(schedule.inject_hyperparams, static_args=static_args)
    def custom_optim(learning_rate, mask):
        return wrappers.masked(transform.scale(-learning_rate), mask)

    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    optim = custom_optim(
        0.1, functools.partial(jax.tree_util.tree_map, lambda x: x.ndim > 1))
    params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))]
    grads = params
    state = self.variant(optim.init)(params)
    updates, state = self.variant(optim.update)(grads, state)
    expected_updates = jax.tree_util.tree_map(
        lambda x: -0.1 * x if x.ndim > 1 else x, grads)
    assert set(state.hyperparams.keys()) == {'learning_rate'}, state.hyperparams
    chex.assert_trees_all_close(updates, expected_updates)
@chex.all_variants
3
Source: schedule_test.py
with Apache License 2.0
from deepmind
def test_numeric_static_args(self, static_args):
    """Names listed in static_args must not appear among the injected hyperparams."""
    adam_factory = schedule.inject_hyperparams(
        transform.scale_by_adam, static_args=static_args)
    optim = adam_factory(b1=0.9, b2=0.95)
    params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))]
    grads = params
    state = self.variant(optim.init)(params)
    _, state = self.variant(optim.update)(grads, state)
    injected = set(state.hyperparams.keys())
    assert not injected.intersection(set(static_args))
@parameterized.named_parameters(('string', 'lr'), ('list', ['lr']))
3
Source: update_test.py
with Apache License 2.0
from deepmind
def test_apply_updates(self):
    """apply_updates(params, 2 * params) should equal 3 * params leaf-wise."""
    params = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    grads = jax.tree_util.tree_map(lambda t: 2 * t, params)
    exp_params = jax.tree_util.tree_map(lambda t: 3 * t, params)
    new_params = self.variant(update.apply_updates)(params, grads)
    chex.assert_trees_all_close(
        exp_params, new_params, atol=1e-10, rtol=1e-5)
@chex.all_variants()
3
Source: update_test.py
with Apache License 2.0
from deepmind
def test_incremental_update(self):
    """incremental_update(2p, p, 0.5) should equal 1.5 * p leaf-wise."""
    params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
    # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
    params_2 = jax.tree_util.tree_map(lambda t: 2 * t, params_1)
    exp_params = jax.tree_util.tree_map(lambda t: 1.5 * t, params_1)
    new_params = self.variant(
        update.incremental_update)(params_2, params_1, 0.5)
    chex.assert_trees_all_close(
        exp_params, new_params, atol=1e-10, rtol=1e-5)
@chex.all_variants()
3
Source: wrappers_test.py
with Apache License 2.0
from deepmind
def test_mask_fn(self):
    """A callable mask limits add_decayed_weights(0.1) to leaves with ndim > 1.

    Masked-in leaves receive grads + 0.1 * params; the 1-D leaf keeps its
    raw gradient.
    """
    params = {'a': jnp.ones((1, 2)), 'b': (jnp.ones((1,)), np.ones((1, 2, 3)))}

    def mask_fn(p):
        # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
        return jax.tree_util.tree_map(lambda x: x.ndim > 1, p)

    init_fn, update_fn = wrappers.masked(transform.add_decayed_weights(0.1),
                                         mask_fn)
    update_fn = self.variant(update_fn)
    state = self.variant(init_fn)(params)
    grads = jax.tree_util.tree_map(lambda x: x * 2, params)
    updates, state = update_fn(grads, state, params)
    np.testing.assert_allclose(updates['a'], grads['a'] + 0.1 * params['a'])
    np.testing.assert_allclose(updates['b'][0], grads['b'][0])
    np.testing.assert_allclose(updates['b'][1],
                               grads['b'][1] + 0.1 * params['b'][1])
@chex.all_variants
3
Source: wrappers_test.py
with Apache License 2.0
from deepmind
def test_masked_state_structure(self):
    """Inner state of a masked transform keeps only the masked-in leaves.

    Masked-out leaves (and the fully masked-out 'b' subtree) appear as None
    in the trace.
    """
    # https://github.com/deepmind/optax/issues/271
    params = {'a': [jnp.ones(1), (jnp.ones(2), jnp.ones(3))],
              'b': {'c': jnp.ones(4), 'd': jnp.ones(5)}}
    mask = {'a': [True, (True, False)], 'b': False}
    tx = wrappers.masked(_build_stateful_sgd(), mask)
    trace = self.variant(tx.init)(params).inner_state[0].trace
    expected_trace = {'a': [jnp.zeros(1), (jnp.zeros(2), None)], 'b': None}
    # `assert_trees_all_equal_structs` is the current chex name; the singular
    # `assert_tree_all_equal_structs` is a deprecated alias.
    chex.assert_trees_all_equal_structs(trace, expected_trace)
class MaybeUpdateTest(chex.TestCase):
See More Examples