jax.numpy.ones

Here are examples of the Python API `jax.numpy.ones` taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.

619 Examples 7

3 Source : test_model.py
with Apache License 2.0
from akbir

def test_transformer_block():
    """Calling a UTBlock on (1, 2, 3) inputs and hidden state yields a rank-3 output."""
    mod = UTBlock(3, 5, 0.1)
    # Named `inputs` (not `input`) so the builtin `input` is not shadowed.
    inputs = jnp.ones([1, 2, 3])
    hidden = jnp.ones([1, 2, 3])
    out = mod(inputs, hidden, mask=None, is_training=False)
    assert out.ndim == 3


@pytest.fixture

3 Source : test_integration.py
with MIT License
from brentyi

    def make(variable: Variable) -> "UniFactor":
        """Create a UniFactor on ``variable`` with unit diagonal-Gaussian noise.

        The noise model is built from a length-4 covariance vector of ones.
        """
        unit_covariance = jnp.ones(4)
        noise = jaxfg.noises.DiagonalGaussian.make_from_covariance(unit_covariance)
        return UniFactor(variables=(variable,), noise_model=noise)

    @overrides

3 Source : test_integration.py
with MIT License
from brentyi

    def make(variable1: Variable, variable2: Variable) -> "BiFactor":
        """Create a BiFactor linking two variables with unit diagonal-Gaussian noise.

        The noise model is built from a length-4 covariance vector of ones.
        """
        linked = (variable1, variable2)
        unit_covariance = jnp.ones(4)
        noise = jaxfg.noises.DiagonalGaussian.make_from_covariance(unit_covariance)
        return BiFactor(variables=linked, noise_model=noise)

    @overrides

3 Source : test_recurrent.py
with MIT License
from cgarciae

    def test_optional_initial_state(self):
        """Omitting the GRU initial state matches passing zeros or initialize_state()."""
        key = tx.Key(8)
        batch_size, features, hidden_dim, time = 32, 10, 5, 10

        gru = recurrent.GRU(hidden_dim, time_axis=0, go_backwards=False)
        sample = (jnp.ones((1, 1, features)), jnp.zeros((1, hidden_dim)))
        gru = gru.init(key, sample)

        # Inputs are laid out time-major because time_axis=0.
        inputs = np.random.rand(time, batch_size, features)
        zero_state = np.zeros((batch_size, hidden_dim))
        assert np.allclose(gru(inputs), gru(inputs, zero_state))
        assert np.allclose(gru(inputs), gru(inputs, gru.initialize_state(batch_size)))

    def test_stateful(self):

3 Source : test_treex.py
with MIT License
from cgarciae

    def test_shape_inference(self):
        """LazyLinear(1) initialised on a (5, 2) batch maps it to shape (5, 1)."""
        batch = jnp.ones((5, 2))
        layer = LazyLinear(1).init(42, batch)

        assert layer(batch).shape == (5, 1)

    def test_init_error(self):

3 Source : test_treex.py
with MIT License
from cgarciae

    def test_init_error(self):
        """Calling a LazyLinear that was never initialised raises RuntimeError."""
        x = jnp.ones((5, 2))
        module = LazyLinear(1)

        with pytest.raises(RuntimeError):
            # No result is bound: only the raised exception matters here.
            module(x)

    def test_compact(self):

3 Source : test_poolformer.py
with Apache License 2.0
from DarshanDeshpande

    def test_s12_init(self):
        """poolformer_s12 maps a (1, 256, 256, 3) input of ones to (1, 8, 8, 512)."""
        init_key, dropout_key = random.split(random.PRNGKey(0), 2)
        model = poolformer_s12()
        image = jnp.ones([1, 256, 256, 3])
        init_rngs = {"params": init_key, "dropout": dropout_key}
        weights = model.init(init_rngs, image, False)["params"]
        out = model.apply(
            {"params": weights}, image, False, rngs={"dropout": dropout_key}
        )
        self.assertEqual(out.shape, (1, 8, 8, 512))

    def test_s24_init(self):

3 Source : test_poolformer.py
with Apache License 2.0
from DarshanDeshpande

    def test_s24_init(self):
        """poolformer_s24 maps a (1, 256, 256, 3) input of ones to (1, 8, 8, 512)."""
        init_key, dropout_key = random.split(random.PRNGKey(0), 2)
        model = poolformer_s24()
        image = jnp.ones([1, 256, 256, 3])
        init_rngs = {"params": init_key, "dropout": dropout_key}
        weights = model.init(init_rngs, image, False)["params"]
        out = model.apply(
            {"params": weights}, image, False, rngs={"dropout": dropout_key}
        )
        self.assertEqual(out.shape, (1, 8, 8, 512))

    def test_s36_init(self):

3 Source : test_poolformer.py
with Apache License 2.0
from DarshanDeshpande

    def test_s36_init(self):
        """poolformer_s36 maps a (1, 256, 256, 3) input of ones to (1, 8, 8, 512)."""
        init_key, dropout_key = random.split(random.PRNGKey(0), 2)
        model = poolformer_s36()
        image = jnp.ones([1, 256, 256, 3])
        init_rngs = {"params": init_key, "dropout": dropout_key}
        weights = model.init(init_rngs, image, False)["params"]
        out = model.apply(
            {"params": weights}, image, False, rngs={"dropout": dropout_key}
        )
        self.assertEqual(out.shape, (1, 8, 8, 512))

    def test_m36_init(self):

3 Source : test_poolformer.py
with Apache License 2.0
from DarshanDeshpande

    def test_m36_init(self):
        """poolformer_m36 maps a (1, 256, 256, 3) input of ones to (1, 8, 8, 768)."""
        init_key, dropout_key = random.split(random.PRNGKey(0), 2)
        model = poolformer_m36()
        image = jnp.ones([1, 256, 256, 3])
        init_rngs = {"params": init_key, "dropout": dropout_key}
        weights = model.init(init_rngs, image, False)["params"]
        out = model.apply(
            {"params": weights}, image, False, rngs={"dropout": dropout_key}
        )
        self.assertEqual(out.shape, (1, 8, 8, 768))

    def test_m48_init(self):

3 Source : test_poolformer.py
with Apache License 2.0
from DarshanDeshpande

    def test_m48_init(self):
        """poolformer_m48 maps a (1, 256, 256, 3) input of ones to (1, 8, 8, 768)."""
        init_key, dropout_key = random.split(random.PRNGKey(0), 2)
        model = poolformer_m48()
        image = jnp.ones([1, 256, 256, 3])
        init_rngs = {"params": init_key, "dropout": dropout_key}
        weights = model.init(init_rngs, image, False)["params"]
        out = model.apply(
            {"params": weights}, image, False, rngs={"dropout": dropout_key}
        )
        self.assertEqual(out.shape, (1, 8, 8, 768))

3 Source : densities.py
with Apache License 2.0
from deepmind

  def evaluate_log_density(self, x: Array) -> Array:
    """Log-density of ``x`` under an isotropic Gaussian taken from the config.

    The mean is ``shared_mean`` in every one of the ``self._num_dim``
    coordinates, and the covariance is ``diagonal_cov`` times the identity.
    """
    dim = self._num_dim
    cfg = self._config
    mean = cfg.shared_mean * jnp.ones(dim)
    cov = jnp.diag(cfg.diagonal_cov * jnp.ones(dim))
    return multivariate_normal.logpdf(x, mean=mean, cov=cov)


class FunnelDistribution(LogDensity):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_pass(self):
    """Every axis, indexed negatively and positively, matches its own size."""
    arr = jnp.ones((3, 2, 7, 2))
    for axis in range(-arr.ndim, arr.ndim):
      size = arr.shape[axis]
      asserts.assert_axis_dimension(arr, axis=axis, expected=size)

  def test_assert_axis_dimension_fail(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_fail(self):
    """Expecting size + 1 on any axis must raise a dimension mismatch."""
    arr = jnp.ones((3, 2, 7, 2))
    expected_msg = _get_err_regex('Expected tensor to have dimension')
    for axis in range(-arr.ndim, arr.ndim):
      too_big = arr.shape[axis] + 1
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension(arr, axis=axis, expected=too_big)

  def test_assert_axis_dimension_axis_invalid(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_axis_invalid(self):
    """Axes 2 and -3 do not exist on a 2-D array and must be rejected."""
    arr = jnp.ones((3, 2))
    expected_msg = _get_err_regex('not available')
    for bad_axis in (2, -3):
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension(arr, axis=bad_axis, expected=1)

  def test_assert_axis_dimension_gt_pass(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_gt_pass(self):
    """Each axis size is strictly greater than that size minus one."""
    arr = jnp.ones((3, 2, 7, 2))
    for axis in range(-arr.ndim, arr.ndim):
      just_below = arr.shape[axis] - 1
      asserts.assert_axis_dimension_gt(arr, axis=axis, val=just_below)

  def test_assert_axis_dimension_gt_fail(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_gt_fail(self):
    """A size is never strictly greater than itself, so every axis fails."""
    arr = jnp.ones((3, 2, 7, 2))
    expected_msg = _get_err_regex(
        'Expected tensor to have dimension greater than')
    for axis in range(-arr.ndim, arr.ndim):
      size = arr.shape[axis]
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension_gt(arr, axis=axis, val=size)

  def test_assert_axis_dimension_gt_axis_invalid(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_gt_axis_invalid(self):
    """The gt check rejects axes 2 and -3 on a 2-D array."""
    arr = jnp.ones((3, 2))
    expected_msg = _get_err_regex('not available')
    for bad_axis in (2, -3):
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension_gt(arr, axis=bad_axis, val=0)

  def test_assert_axis_dimension_gteq_pass(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_gteq_pass(self):
    """Each axis size is greater than or equal to itself."""
    arr = jnp.ones((3, 2, 7, 2))
    for axis in range(-arr.ndim, arr.ndim):
      size = arr.shape[axis]
      asserts.assert_axis_dimension_gteq(arr, axis=axis, val=size)

  def test_assert_axis_dimension_gteq_fail(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_gteq_fail(self):
    """Requiring size + 1 or more on any axis must fail."""
    arr = jnp.ones((3, 2, 7, 2))
    expected_msg = _get_err_regex(
        'Expected tensor to have dimension greater than or')
    for axis in range(-arr.ndim, arr.ndim):
      too_big = arr.shape[axis] + 1
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension_gteq(arr, axis=axis, val=too_big)

  def test_assert_axis_dimension_gteq_axis_invalid(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_gteq_axis_invalid(self):
    """The gteq check rejects axes 2 and -3 on a 2-D array."""
    arr = jnp.ones((3, 2))
    expected_msg = _get_err_regex('not available')
    for bad_axis in (2, -3):
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension_gteq(arr, axis=bad_axis, val=0)

  def test_assert_axis_dimension_lt_pass(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_lt_pass(self):
    """Each axis size is strictly less than that size plus one."""
    arr = jnp.ones((3, 2, 7, 2))
    for axis in range(-arr.ndim, arr.ndim):
      just_above = arr.shape[axis] + 1
      asserts.assert_axis_dimension_lt(arr, axis=axis, val=just_above)

  def test_assert_axis_dimension_lt_fail(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_lt_fail(self):
    """A size is never strictly less than itself, so every axis fails."""
    arr = jnp.ones((3, 2, 7, 2))
    expected_msg = _get_err_regex(
        'Expected tensor to have dimension less than')
    for axis in range(-arr.ndim, arr.ndim):
      size = arr.shape[axis]
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension_lt(arr, axis=axis, val=size)

  def test_assert_axis_dimension_lt_axis_invalid(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_lt_axis_invalid(self):
    """The lt check rejects axes 2 and -3 on a 2-D array."""
    arr = jnp.ones((3, 2))
    expected_msg = _get_err_regex('not available')
    for bad_axis in (2, -3):
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension_lt(arr, axis=bad_axis, val=0)

  def test_assert_axis_dimension_lteq_pass(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_lteq_pass(self):
    """Each axis size is less than or equal to itself."""
    arr = jnp.ones((3, 2, 7, 2))
    for axis in range(-arr.ndim, arr.ndim):
      size = arr.shape[axis]
      asserts.assert_axis_dimension_lteq(arr, axis=axis, val=size)

  def test_assert_axis_dimension_lteq_fail(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_lteq_fail(self):
    """Requiring size - 1 or less on any axis must fail."""
    arr = jnp.ones((3, 2, 7, 2))
    expected_msg = _get_err_regex(
        'Expected tensor to have dimension less than or')
    for axis in range(-arr.ndim, arr.ndim):
      too_small = arr.shape[axis] - 1
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension_lteq(arr, axis=axis, val=too_small)

  def test_assert_axis_dimension_lteq_axis_invalid(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_axis_dimension_lteq_axis_invalid(self):
    """The lteq check rejects axes 2 and -3 on a 2-D array."""
    arr = jnp.ones((3, 2))
    expected_msg = _get_err_regex('not available')
    for bad_axis in (2, -3):
      with self.assertRaisesRegex(AssertionError, expected_msg):
        asserts.assert_axis_dimension_lteq(arr, axis=bad_axis, val=0)


class TreeAssertionsTest(parameterized.TestCase):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_tree_all_finite_should_fail_inf(self):
    """One +inf leaf is enough to make assert_tree_all_finite fail."""
    tree_with_inf = dict(
        finite_var=jnp.ones((3,)),
        inf_var=jnp.array([0.0, jnp.inf]),
    )
    expected_msg = _get_err_regex('Tree contains non-finite value')
    with self.assertRaisesRegex(AssertionError, expected_msg):
      asserts.assert_tree_all_finite(tree_with_inf)

  def test_assert_trees_all_equal_passes_same_tree(self):

3 Source : asserts_test.py
with Apache License 2.0
from deepmind

  def test_assert_equal_pass_on_arrays(self):
    """assert_equal accepts a JAX scalar one against a NumPy scalar one.

    The second check pairs an int32 JAX value with a float64 NumPy value.
    """
    # Plain calls instead of named_parameters, because JAX cannot be used
    # before app.run().
    jax_float, np_float = jnp.ones([]), np.ones([])
    asserts.assert_equal(jax_float, np_float)
    jax_int = jnp.ones([], dtype=jnp.int32)
    np_double = np.ones([], dtype=np.float64)
    asserts.assert_equal(jax_int, np_double)

  @parameterized.named_parameters(

3 Source : fake_test.py
with Apache License 2.0
from deepmind

  def test_assert_pmapped(self):
    """_assert_pmapped passes when told True and raises when told False."""
    fn_input = jnp.ones((4,))

    def foo(x):
      return x * 2

    _assert_pmapped(foo, fn_input, True)
    with self.assertRaises(AssertionError):
      _assert_pmapped(foo, fn_input, False)

  def test_assert_jitted(self):

3 Source : fake_test.py
with Apache License 2.0
from deepmind

  def test_assert_jitted(self):
    """_assert_jitted passes when told True and raises when told False."""
    def foo(x):
      return x * 2

    fn_input = jnp.ones((4,))
    _assert_jitted(foo, fn_input, True)
    with self.assertRaises(AssertionError):
      _assert_jitted(foo, fn_input, False)

  @parameterized.named_parameters([

3 Source : fake_test.py
with Apache License 2.0
from deepmind

  def test_fake_jit(self, fake_kwargs, is_jitted):
    """fake_jit yields the expected jit status in both usage styles.

    ``_assert_jitted`` is checked under the context-manager form and under
    explicit ``start()``/``stop()`` calls.
    """
    fn_input = jnp.ones((4,))
    def foo(x):
      return x * 2

    # Call with context manager
    with fake.fake_jit(**fake_kwargs):
      _assert_jitted(foo, fn_input, is_jitted)

    # Call with start/stop. `stop()` runs in `finally` so a failing
    # assertion cannot leave the fake active for later tests.
    ctx = fake.fake_jit(**fake_kwargs)
    ctx.start()
    try:
      _assert_jitted(foo, fn_input, is_jitted)
    finally:
      ctx.stop()

  @parameterized.named_parameters([

3 Source : fake_test.py
with Apache License 2.0
from deepmind

  def test_fake_pmap_(self, is_pmapped, jit_result):
    """fake_pmap yields the expected pmap status in both usage styles.

    Patching is enabled exactly when ``is_pmapped`` is False.
    """
    enable_patching = not is_pmapped

    fn_input = jnp.ones((4,))
    def foo(x):
      return x * 2

    # Call with context manager
    with fake.fake_pmap(enable_patching=enable_patching, jit_result=jit_result):
      _assert_pmapped(foo, fn_input, is_pmapped, jit_result)

    # Call with start/stop. `stop()` runs in `finally` so a failing
    # assertion cannot leave the fake active for later tests.
    ctx = fake.fake_pmap(enable_patching=enable_patching, jit_result=jit_result)
    ctx.start()
    try:
      _assert_pmapped(foo, fn_input, is_pmapped, jit_result)
    finally:
      ctx.stop()

  def test_fake_pmap_axis_name(self):

3 Source : fake_test.py
with Apache License 2.0
from deepmind

  def test_pmap_and_jit(self, fake_kwargs, is_pmapped, is_jitted):
    """fake_pmap_and_jit yields the expected pmap and jit status.

    Both statuses are checked under the context-manager form and under
    explicit ``start()``/``stop()`` calls.
    """
    fn_input = jnp.ones((4,))
    def foo(x):
      return x * 2

    # Call with context manager
    with fake.fake_pmap_and_jit(**fake_kwargs):
      _assert_pmapped(foo, fn_input, is_pmapped)
      _assert_jitted(foo, fn_input, is_jitted)

    # Call with start/stop. `stop()` runs in `finally` so a failing
    # assertion cannot leave the fakes active for later tests.
    ctx = fake.fake_pmap_and_jit(**fake_kwargs)
    ctx.start()
    try:
      _assert_pmapped(foo, fn_input, is_pmapped)
      _assert_jitted(foo, fn_input, is_jitted)
    finally:
      ctx.stop()

  @parameterized.named_parameters([

3 Source : pretrain_common.py
with Apache License 2.0
from deepmind

  def _fake_data_generator(self, image_shape: Tuple[int, int, int, int]):
    mask1 = np.random.uniform(low=0, high=7, size=image_shape + (1,))
    mask2 = np.random.uniform(low=0, high=7, size=image_shape + (1,))
    while True:
      yield {
          'view1': jnp.ones(image_shape + (3,)) * 0.5,
          'view2': jnp.ones(image_shape + (3,)) * 0.3,
          'fh_segmentations1': jnp.array(np.round(mask1), dtype=jnp.uint8),
          'fh_segmentations2': jnp.array(np.round(mask2), dtype=jnp.uint8),
          'labels': jnp.ones([image_shape[0], image_shape[1], 1],
                             dtype=jnp.int64),
      }

  def _build_train_input(self) -> Generator[image_dataset.Batch, None, None]:

3 Source : utils.py
with Apache License 2.0
from deepmind

def rendezvous() -> None:
  """Forces all hosts to check in via a cross-device ``psum``.

  Each local device contributes a 1. After the collective, the first entry
  is compared with ``jax.device_count()``.

  Raises:
    ValueError: if the summed value differs from ``jax.device_count()``.
  """
  with log_activity("rendezvous"):
    ones = jnp.ones([jax.local_device_count()])
    summed = jax.pmap(lambda v: jax.lax.psum(v, "i"), "i")(ones)
    x = jax.device_get(summed)
    if x[0] != jax.device_count():
      raise ValueError(f"Expected {jax.device_count()} got {x}")


class PeriodicAction:

3 Source : utils_test.py
with Apache License 2.0
from deepmind

  def test_bcast_local_devices_tree(self):
    """bcast_local_devices expands each scalar leaf to (local_device_count,).

    Leaves are compared one at a time with ``jnp.array_equal``. Comparing the
    dicts with ``==`` would call ``bool()`` on an element-wise array
    comparison, which raises whenever there is more than one local device.
    """
    num_devices = jax.local_device_count()
    tree = utils.bcast_local_devices({"ones": jnp.ones([]),
                                      "zeros": jnp.zeros([])})
    expected = {"ones": jnp.ones([num_devices]),
                "zeros": jnp.zeros([num_devices])}
    self.assertEqual(set(tree), set(expected))
    for key, want in expected.items():
      self.assertEqual(tree[key].shape, want.shape)
      self.assertTrue(jnp.array_equal(tree[key], want))


class TestLogActivity(absltest.TestCase):

3 Source : linear_bound_utils.py
with Apache License 2.0
from deepmind

  def initial_params(
      self,
      *inps: Optional[Union[Bound, Tensor]],
  ) -> Nest[Tensor]:
    """Returns one relaxation parameter per input, in input order.

    An input that is ``None`` or a ``Bound`` gets a tensor filled with 0.5,
    shaped like the matching entry of ``self._input_shapes``. Any other
    input is treated as a fixed tensor and gets ``None``, so no parameter is
    allocated for it.
    """
    params = []
    for inp, inp_shape in zip(inps, self._input_shapes):
      needs_param = inp is None or isinstance(inp, Bound)
      params.append(.5 * jnp.ones(shape=inp_shape) if needs_param else None)
    return params

  def project_params(self, relax_params: Nest[Tensor]) -> Nest[Tensor]:

3 Source : nonconvex.py
with Apache License 2.0
from deepmind

  def lower(self) -> Tensor:
    """Returns the concretized lower bound, or ``-inf`` everywhere if absent.

    If the bound has not been concretized yet, a warning is logged and a
    tensor of ``-inf`` with shape ``self.shape`` is returned instead.
    """
    if self._concretized_bounds is None:
      # Adjacent literals are concatenated, so the trailing space keeps the
      # two sentences from running together ("bound.Returning").
      logging.warning('.lower called on a non-concretized bound. '
                      'Returning spurious bounds.')
      return -float('inf') * jnp.ones(self.shape)
    return self._concretized_bounds.lower

  @property

3 Source : nonconvex.py
with Apache License 2.0
from deepmind

  def upper(self) -> Tensor:
    """Returns the concretized upper bound, or ``+inf`` everywhere if absent.

    If the bound has not been concretized yet, a warning is logged and a
    tensor of ``+inf`` with shape ``self.shape`` is returned instead.
    """
    if self._concretized_bounds is None:
      # Adjacent literals are concatenated, so the trailing space keeps the
      # two sentences from running together ("bound.Returning").
      logging.warning('.upper called on a non-concretized bound. '
                      'Returning spurious bounds.')
      return float('inf') * jnp.ones(self.shape)
    return self._concretized_bounds.upper

  def evaluate(self,

3 Source : pga_test.py
with Apache License 2.0
from deepmind

  def setUp(self):
    """Seeds the PRNG and builds two IntBounds: a +/-1 box and an empty one."""
    super().setUp()

    self.prng_key = jax.random.PRNGKey(1234)

    # First layer: post-activation bounds of -1/+1 with shape X_SHAPE and no
    # pre-activation bounds.
    unit_box = sdp_utils.IntBound(
        lb=-jnp.ones(X_SHAPE),
        ub=jnp.ones(X_SHAPE),
        lb_pre=None,
        ub_pre=None)
    # Second entry: every bound field left as None.
    empty_bound = sdp_utils.IntBound(lb=None, ub=None, lb_pre=None,
                                     ub_pre=None)
    self.bounds = [unit_box, empty_bound]

  def test_intermediate_problem(self):

3 Source : uncertainty_spec_test.py
with Apache License 2.0
from deepmind

  def setUp(self):
    """Creates a PRNG sequence and a single +/-0.1 pre-activation IntBound."""
    super(UncertaintySpecTest, self).setUp()

    self._prng_seq = hk.PRNGSequence(13579)
    # The class count is read from the second dimension of X_SHAPE.
    self._n_classes = X_SHAPE[1]

    half_width = 0.1 * jnp.ones(X_SHAPE)
    self.bounds = [
        sdp_utils.IntBound(lb_pre=-half_width, ub_pre=half_width,
                           lb=None, ub=None)
    ]

  def test_softmax_upper(self):

3 Source : cvxpy_verify_test.py
with Apache License 2.0
from deepmind

def _fgsm_example_and_bound(params, target_label, label):
  """Runs ``utils.fgsm_single`` from a constant input and scores the result.

  The attack starts from a vector of 0.5s sized to the MLP's first layer and
  uses epsilon 0.5, 30 steps and step size 0.03.

  Returns:
    ``(x_adv, objective)``, where ``objective`` is ``utils.adv_objective``
    evaluated at ``x_adv``.
  """
  def model_fn(x):
    return utils.predict_mlp(params, x)

  input_dim = utils.nn_layer_sizes(params)[0]
  x_start = 0.5 * jnp.ones(input_dim)
  epsilon = 0.5
  x_adv = utils.fgsm_single(model_fn, x_start, label, target_label, epsilon,
                            num_steps=30, step_size=0.03)
  objective = utils.adv_objective(model_fn, x_adv, label, target_label)
  return x_adv, objective

# Small numeric slack (1e-6). Its uses are outside this view; presumably it is
# a tolerance for comparing computed bounds. TODO confirm against usages.
MARGIN = 1e-6

3 Source : privacy_test.py
with Apache License 2.0
from deepmind

  def test_noise_multiplier(self, l2_norm_clip, noise_multiplier):
    """Noise std must equal l2_norm_clip * noise_multiplier (within 10%)."""
    dp_agg = privacy.differentially_private_aggregate(
        l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier,
        seed=1337)
    state = dp_agg.init(None)
    update_fn = self.variant(dp_agg.update)
    expected_std = l2_norm_clip * noise_multiplier
    tolerance = 0.1 * expected_std

    # A single all-ones gradient: the leading axis is a batch of size 1.
    grads = [jnp.ones((1, 100, 100))]
    for _ in range(3):
      updates, state = update_fn(grads, state)
      measured_std = jnp.std(updates[0])
      chex.assert_tree_all_close(expected_std, measured_std, atol=tolerance)

  def test_aggregated_updates_as_input_fails(self):

3 Source : schedule_test.py
with Apache License 2.0
from deepmind

  def test_static_args(self, static_args):
    """Only ``learning_rate`` becomes a hyperparam; the mask is still applied.

    Leaves with ``ndim > 1`` are expected to be scaled by -0.1, while all
    other leaves pass through unchanged.
    """
    @functools.partial(schedule.inject_hyperparams, static_args=static_args)
    def custom_optim(learning_rate, mask):
      return wrappers.masked(transform.scale(-learning_rate), mask)

    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias,
    # which recent JAX releases have removed.
    optim = custom_optim(
        0.1, functools.partial(jax.tree_util.tree_map, lambda x: x.ndim > 1))
    params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))]
    grads = params
    state = self.variant(optim.init)(params)
    updates, state = self.variant(optim.update)(grads, state)
    expected_updates = jax.tree_util.tree_map(
        lambda x: -0.1 * x if x.ndim > 1 else x, grads)

    assert set(state.hyperparams.keys()) == {'learning_rate'}, state.hyperparams
    chex.assert_tree_all_close(updates, expected_updates)

  @chex.all_variants

3 Source : schedule_test.py
with Apache License 2.0
from deepmind

  def test_numeric_static_args(self, static_args):
    """No name listed in ``static_args`` may show up among the hyperparams."""
    adam_factory = schedule.inject_hyperparams(
        transform.scale_by_adam, static_args=static_args)
    optim = adam_factory(b1=0.9, b2=0.95)

    params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))]
    grads = params
    init_fn = self.variant(optim.init)
    state = init_fn(params)
    update_fn = self.variant(optim.update)
    _, state = update_fn(grads, state)

    hyperparam_names = set(state.hyperparams.keys())
    assert not hyperparam_names.intersection(set(static_args))

  @parameterized.named_parameters(('string', 'lr'), ('list', ['lr']))

3 Source : update_test.py
with Apache License 2.0
from deepmind

  def test_apply_updates(self):
    """apply_updates(params, 2 * params) is expected to equal 3 * params.

    Uses `jax.tree_util.tree_map`. The top-level `jax.tree_map` alias is
    deprecated and has been removed in recent JAX releases.
    """
    params = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
    grads = jax.tree_util.tree_map(lambda t: 2 * t, params)
    exp_params = jax.tree_util.tree_map(lambda t: 3 * t, params)
    new_params = self.variant(update.apply_updates)(params, grads)

    chex.assert_tree_all_close(
        exp_params, new_params, atol=1e-10, rtol=1e-5)

  @chex.all_variants()

3 Source : update_test.py
with Apache License 2.0
from deepmind

  def test_incremental_update(self):
    """incremental_update(2 * p, p, 0.5) is expected to equal 1.5 * p.

    Uses `jax.tree_util.tree_map`. The top-level `jax.tree_map` alias is
    deprecated and has been removed in recent JAX releases.
    """
    params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,)))
    params_2 = jax.tree_util.tree_map(lambda t: 2 * t, params_1)
    exp_params = jax.tree_util.tree_map(lambda t: 1.5 * t, params_1)
    new_params = self.variant(
        update.incremental_update)(params_2, params_1, 0.5)

    chex.assert_tree_all_close(
        exp_params, new_params, atol=1e-10, rtol=1e-5)

  @chex.all_variants()

3 Source : wrappers_test.py
with Apache License 2.0
from deepmind

  def test_mask_fn(self):
    """A callable mask applies weight decay only to leaves with ndim > 1.

    Masked-in leaves are expected to get grads + 0.1 * params, and masked-out
    leaves the raw grads. Uses `jax.tree_util.tree_map` in place of the
    deprecated (and, in recent JAX, removed) `jax.tree_map` alias.
    """
    params = {'a': jnp.ones((1, 2)), 'b': (jnp.ones((1,)), np.ones((1, 2, 3)))}
    mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim > 1, p)
    init_fn, update_fn = wrappers.masked(transform.add_decayed_weights(0.1),
                                         mask_fn)
    update_fn = self.variant(update_fn)

    state = self.variant(init_fn)(params)
    grads = jax.tree_util.tree_map(lambda x: x*2, params)
    updates, state = update_fn(grads, state, params)
    np.testing.assert_allclose(updates['a'], grads['a'] + 0.1*params['a'])
    np.testing.assert_allclose(updates['b'][0], grads['b'][0])
    np.testing.assert_allclose(updates['b'][1],
                               grads['b'][1] + 0.1*params['b'][1])

  @chex.all_variants

3 Source : wrappers_test.py
with Apache License 2.0
from deepmind

  def test_masked_state_structure(self):
    """The inner trace mirrors the mask: ``None`` wherever a leaf is masked out.

    Regression test for https://github.com/deepmind/optax/issues/271.
    """
    a_leaves = [jnp.ones(1), (jnp.ones(2), jnp.ones(3))]
    b_leaves = {'c': jnp.ones(4), 'd': jnp.ones(5)}
    params = {'a': a_leaves, 'b': b_leaves}
    mask = {'a': [True, (True, False)], 'b': False}
    masked_sgd = wrappers.masked(_build_stateful_sgd(), mask)
    state = self.variant(masked_sgd.init)(params)
    trace = state.inner_state[0].trace
    expected_trace = {'a': [jnp.zeros(1), (jnp.zeros(2), None)], 'b': None}
    chex.assert_tree_all_equal_structs(trace, expected_trace)


class MaybeUpdateTest(chex.TestCase):

See More Examples