jax.numpy.max

Here are examples of the Python API jax.numpy.max taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.

51 Examples
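
For reference, a minimal usage sketch of jax.numpy.max itself (not one of the collected examples): it follows NumPy reduction semantics, and the axis, keepdims, initial and where parameters used throughout the snippets below are all part of the public API.

import jax.numpy as jnp

x = jnp.array([[1., 5., 3.],
               [4., 2., 6.]])

jnp.max(x)                         # 6.0 -- global maximum
jnp.max(x, axis=-1)                # [5., 6.] -- row-wise maximum
jnp.max(x, axis=0, keepdims=True)  # [[4., 5., 6.]] -- reduced axis kept
# `where` restricts the reduction; `initial` supplies the identity for masked-out entries.
mask = jnp.array([[True, False, True],
                  [True, True, False]])
jnp.max(x, axis=-1, initial=-jnp.inf, where=mask)  # [3., 4.]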

3 Source : decoders.py
with Apache License 2.0
from deepmind

def _decode_graph_diffs(decoders, h_t: _Array, graph_fts: _Array) -> _Array:
  """Decodes graph diffs."""

  gr_emb = jnp.max(h_t, axis=-2)
  g_pred_n = decoders[0](gr_emb)
  g_pred_g = decoders[1](graph_fts)
  preds = jnp.squeeze(g_pred_n + g_pred_g, -1)

  return preds

3 Source : policies.py
with Apache License 2.0
from deepmind

def _mask_invalid_actions(logits, invalid_actions):
  """Returns logits with zero mass to invalid actions."""
  if invalid_actions is None:
    return logits
  chex.assert_equal_shape([logits, invalid_actions])
  logits = logits - jnp.max(logits, axis=-1, keepdims=True)
  # At the end of an episode, all actions can be invalid. A softmax would then
  # produce NaNs, if using -inf for the logits. We avoid the NaNs by using
  # a finite `min_logit` for the invalid actions.
  min_logit = jnp.finfo(logits.dtype).min
  return jnp.where(invalid_actions, min_logit, logits)


def _get_logits_from_probs(probs):

3 Source : policies.py
with Apache License 2.0
from deepmind

def _apply_temperature(logits, temperature):
  """Returns `logits / temperature`, supporting also temperature=0."""
  # The max subtraction prevents +inf after dividing by a small temperature.
  logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  tiny = jnp.finfo(logits.dtype).tiny
  return logits / jnp.maximum(tiny, temperature)
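
A minimal sketch (not from the source) of why the max subtraction above matters: with a near-zero temperature the raw division overflows float32 to +inf, whereas after subtracting jnp.max the largest logit maps to 0 and the rest to -inf, which is harmless for a downstream argmax or softmax (exp(-inf) == 0).

import jax.numpy as jnp

logits = jnp.array([10.0, 5.0, 1.0], dtype=jnp.float32)
tiny = jnp.finfo(logits.dtype).tiny  # ~1.18e-38 for float32, i.e. temperature=0

logits / jnp.maximum(tiny, 0.0)
# [inf, inf, 8.5e+37] -- the large logits have already overflowed

(logits - jnp.max(logits, keepdims=True, axis=-1)) / jnp.maximum(tiny, 0.0)
# [0., -inf, -inf] -- finite where it counts; the best action is still recoverable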

3 Source : seq_halving.py
with Apache License 2.0
from deepmind

def score_considered(considered_visit, gumbel, logits, normalized_qvalues,
                     visit_counts):
  """Returns a score usable for an argmax."""
  # We allow visiting a child if it is the only considered child.
  low_logit = -1e9
  logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  penalty = jnp.where(
      visit_counts == considered_visit,
      0, -jnp.inf)
  chex.assert_equal_shape([gumbel, logits, normalized_qvalues, penalty])
  return jnp.maximum(low_logit, gumbel + logits + normalized_qvalues) + penalty


def get_sequence_of_considered_visits(max_num_considered_actions,

3 Source : masks_test.py
with Apache License 2.0
from google-research

def create_input(lengths):
  input_tokens = jnp.ones((len(lengths), jnp.max(lengths)), dtype=jnp.float32)
  input_tokens = input_tokens * (
      jnp.arange(jnp.max(lengths)) < jnp.reshape(lengths, (-1, 1)))
  return input_tokens, lengths
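
A worked example (not from the original test file, assuming the create_input above and import jax.numpy as jnp): for lengths [2, 3] the batch is padded to the longest length, jnp.max(lengths) == 3, and positions beyond each length are zeroed.

import jax.numpy as jnp

tokens, lengths = create_input(jnp.array([2, 3]))
# tokens == [[1., 1., 0.],
#            [1., 1., 1.]]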


def create_decoder_input(prefix_lengths, target_lengths):

3 Source : masks_test.py
with Apache License 2.0
from google-research

def create_decoder_input(prefix_lengths, target_lengths):
  targets, example_lengths = create_input(prefix_lengths + target_lengths)
  causal, _ = create_input(prefix_lengths)
  causal = jnp.concatenate([
      causal,
      jnp.zeros((causal.shape[0], jnp.max(example_lengths) - causal.shape[1]))
  ],
                           axis=1)
  return targets, causal, example_lengths


class PromptDecoderAttentionTest(parameterized.TestCase):

3 Source : masks_test.py
with Apache License 2.0
from google-research

def create_decoder_input(prefix_lengths, target_lengths):
  targets, example_lengths = create_input(prefix_lengths + target_lengths)
  causal, _ = create_input(prefix_lengths)
  causal = jnp.concatenate([
      causal,
      jnp.zeros((causal.shape[0], jnp.max(example_lengths) - causal.shape[1]))
  ],
                           axis=1)
  return targets, causal, example_lengths


class CreatePromptDecoderOnlyMaskTest(parameterized.TestCase):

3 Source : ode.py
with MIT License
from jacobjinkelly

def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
                      dfactor=0.2, order=5.0):
  """Compute optimal Runge-Kutta stepsize."""
  mean_error_ratio = np.max(mean_error_ratio)
  dfactor = np.where(mean_error_ratio < 1, 1.0, dfactor)

  err_ratio = np.sqrt(mean_error_ratio)
  factor = np.maximum(1.0 / ifactor,
                      np.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
  return np.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)

def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):

3 Source : dqn.py
with MIT License
from ku2482

    def _calculate_target(
        self,
        params: hk.Params,
        params_target: hk.Params,
        reward: np.ndarray,
        done: np.ndarray,
        next_state: np.ndarray,
    ) -> jnp.ndarray:
        if self.double_q:
            next_action = self._forward(params, next_state)[..., None]
            next_q = self._calculate_value(params_target, next_state, next_action)
        else:
            next_q = jnp.max(self.net.apply(params_target, next_state), axis=-1, keepdims=True)
        return jax.lax.stop_gradient(reward + (1.0 - done) * self.discount * next_q)

    @partial(jax.jit, static_argnums=0)

3 Source : numpy_ops.py
with GNU General Public License v3.0
from PKU-NIP-Lab

def max(a, axis=None, keepdims=None, initial=None, where=None):
  a = _remove_jaxarray(a)
  r = jnp.max(a, axis=axis, keepdims=keepdims, initial=initial, where=where)
  return r if axis is None else JaxArray(r)


def min(a, axis=None, keepdims=None, initial=None, where=None):

3 Source : test_ssm.py
with MIT License
from SamDuffield

    def _test_bootstrap(self):
        if not hasattr(self, 'sim_samps'):
            self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
        self.pf_samps = run_particle_filter_for_marginals(self.ssm_scenario,
                                                          BootstrapFilter(),
                                                          self.sim_samps.y,
                                                          self.t,
                                                          random.PRNGKey(0),
                                                          n=self.n)

        npt.assert_array_less(self.sim_samps.x[:, 0], jnp.max(self.pf_samps.value, axis=1)[:, 0])
        npt.assert_array_less(jnp.min(self.pf_samps.value, axis=1)[:, 0], self.sim_samps.x[:, 0])

    def backward_preprocess(self):

3 Source : test_ssm.py
with MIT License
from SamDuffield

    def backward_postprocess(self):
        npt.assert_array_less(self.sim_samps.x[:, 0], jnp.max(self.backward_samps.value, axis=1)[:, 0])
        npt.assert_array_less(jnp.min(self.backward_samps.value, axis=1)[:, 0], self.sim_samps.x[:, 0])
        self.assertFalse(hasattr(self.backward_samps, 'log_weight'))

    def _test_ffbsi_full(self):

3 Source : test_ssm.py
with MIT License
from SamDuffield

    def _test_online_smoothing_pf_full(self):
        if not hasattr(self, 'sim_samps'):
            self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
        pf = BootstrapFilter()
        len_t = len(self.t)
        rkeys = random.split(random.PRNGKey(0), len_t)

        particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                       rkeys[0], self.sim_samps.y[0], self.t[0])
        for i in range(1, len_t):
            particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                    self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                    False)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)

    def _test_online_smoothing_pf_rejection(self):

3 Source : test_ssm.py
with MIT License
from SamDuffield

    def _test_online_smoothing_pf_rejection(self):
        if not hasattr(self, 'sim_samps'):
            self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
        pf = BootstrapFilter()
        len_t = len(self.t)
        rkeys = random.split(random.PRNGKey(0), len_t)

        particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                       rkeys[0], self.sim_samps.y[0], self.t[0])
        for i in range(1, len_t):
            particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                    self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                    False, maximum_rejections=10)

        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)

    def _test_online_smoothing_bs_full(self):

3 Source : test_ssm.py
with MIT License
from SamDuffield

    def _test_online_smoothing_bs_full(self):
        if not hasattr(self, 'sim_samps'):
            self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
        pf = BootstrapFilter()
        len_t = len(self.t)
        rkeys = random.split(random.PRNGKey(0), len_t)

        particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                       rkeys[0], self.sim_samps.y[0], self.t[0])
        for i in range(1, len_t):
            particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                    self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                    True)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)

    def _test_online_smoothing_bs_rejection(self):

3 Source : test_ssm.py
with MIT License
from SamDuffield

    def _test_online_smoothing_bs_rejection(self):
        if not hasattr(self, 'sim_samps'):
            self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
        pf = BootstrapFilter()
        len_t = len(self.t)
        rkeys = random.split(random.PRNGKey(0), len_t)

        particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                       rkeys[0], self.sim_samps.y[0], self.t[0])
        for i in range(1, len_t):
            particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                    self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                    True, maximum_rejections=10)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
        npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)


if __name__ == '__main__':

3 Source : jax_loss.py
with GNU Affero General Public License v3.0
from synsense

def logsoftmax(x: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """
    Efficient implementation of the log softmax function

    .. math ::

        log S(x, \\tau) = (l / \\tau) - \\log \\Sigma { \\exp (l / \\tau) }

        l = x - \\max (x)

    Args:
        x (np.ndarray): Input vector of scores
        temperature (float): Temperature :math:`\\tau` of the softmax. As :math:`\\tau \\rightarrow 0`, the function becomes a hard :math:`\\max` operation. Default: ``1.0``.

    Returns:
        np.ndarray: The output of the logsoftmax.
    """
    logits = x - np.max(x)
    return (logits / temperature) - np.log(np.sum(np.exp(logits / temperature)))
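
A quick sanity check (a sketch, assuming `np` is jax.numpy as in the snippet above): at the default temperature of 1.0 the max subtraction only shifts the logits, so the result matches the standard log-softmax.

import jax
import jax.numpy as np

x = np.array([2.0, 1.0, 0.1])
logsoftmax(x)          # uses temperature=1.0
jax.nn.log_softmax(x)  # same values up to floating point error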

3 Source : embedding_softmax.py
with Apache License 2.0
from tensorflow

  def compute_z_loss(self, logits):
    """Returns a z_loss regularization which stabilizes logits."""
    # Applies stop_gradient to max_logit instead of logits.
    max_logit = jax.lax.stop_gradient(jnp.max(logits, axis=-1, keepdims=True))
    exp_x = jnp.exp(logits - max_logit)
    sum_exp_x = jnp.sum(exp_x, axis=-1, keepdims=True)
    log_z = jnp.log(sum_exp_x) + max_logit
    return jnp.square(log_z)
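
Numerically, log_z above is just the log-sum-exp of the logits (the max is added back after the shift), so the forward value of this regularizer can be checked against jax.scipy.special.logsumexp; a minimal standalone sketch, assuming standard JAX imports:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 10))
z_loss = jnp.square(logsumexp(logits, axis=-1, keepdims=True))
# matches compute_z_loss(logits) up to floating point error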

  def fprop(self,

3 Source : poolings.py
with Apache License 2.0
from tensorflow

  def fprop(self, inputs: JTensor) -> JTensor:
    """Applies global spatial pooling to inputs.

    Args:
      inputs: An input tensor.

    Returns:
      Output tensor with global pooling applied.
    """
    p = self.params
    if p.pooling_type == 'MAX':
      outputs = jnp.max(inputs, p.pooling_dims, keepdims=p.keepdims)
    elif p.pooling_type == 'AVG':
      outputs = jnp.mean(inputs, p.pooling_dims, keepdims=p.keepdims)
    return outputs

0 Source : jaxent.py
with MIT License
from adamhaber

    def train(self,data,data_kind="samples",data_n_samp=None,alpha=0.32,lr=1e-1,threshold=1.,kind=None,n_samps=5000):
        """fit a maximum entropy model to data.
        
        Parameters
        ----------
        data : array_like
            either an array of binary samples, or an array of desired marginals
        data_kind : str, optional
            "samples" - data samples are passed
            "marginals" - desired marginals are passed
        data_n_samp : int, optional
            number of trials, needed to compute confidence intervals
        alpha : float, optional
            confidence level
        lr : float, optional
            learning rate, by default 1e-1
        threshold : float, optional
            maximum allowed difference between model marginals and empirical marginals, in empirical standard deviations units. by default 1
        kind : str, optional
            "exhuastive" means analytical computation of model marginals, "sample" means MCMC estimation, by default None and estimated from the number of units N.
        n_samps : int, optional
            number of samples to generate in each MCMC estimation of model marginals, by default 5000
        """
        @jit
        def _training_step_ex(i,opt_state):
            params = get_params(opt_state)
            model_marg = self.calc_marginals_ex(self._calc_p(params))
            g = self.empirical_marginals-model_marg
            return opt_update(i, g, opt_state),model_marg
        
        @jit
        def _training_step(i,opt_state):
            params = get_params(opt_state)
            samples = self._sample(random.PRNGKey(i),n_samps,params, -1, -1)
            model_marg = self.calc_marginals(samples)
            g = self.empirical_marginals-model_marg
            return opt_update(i, g, opt_state),model_marg
        
        @jit
        def _training_loop(loop_carry):
                i,opt_state, params,_ = loop_carry
                opt_state,marginals = step(i,opt_state)
                params = get_params(opt_state)
                return i+1,opt_state, params,marginals

        if kind is None:
            kind = onp.where(self.N>20,'sample','exhuastive')

        if kind=='exhuastive':
            step = _training_step_ex
            if self.words is None:
                self.create_words()
            self.model_marginals = self.calc_marginals_ex(self._calc_p(self.factors))
        elif kind=='sample':
            step = _training_step
            self.model_marginals = self.calc_marginals(self._sample(random.PRNGKey(0),n_samps,self.factors,-1,-1))

        lower, upper = self.calc_empirical_marginals_and_stds(data,data_kind,data_n_samp,alpha)
        self.empirical_std = upper-lower
    
        opt_init, opt_update, get_params = optimizers.adam(lr)
        training_steps, opt_state, params,marginals = while_loop(lambda x: np.max(self.calc_deviations(x[3])) > threshold,_training_loop, (0,opt_init(self.factors), self.factors,self.model_marginals))
        
        self.factors = params
        self.training_steps = training_steps
        self.model_marginals = marginals
        self.trained = True

        if kind=='exhuastive':
            self.p_model = self._calc_p(self.factors)
            self.Z = np.exp(self.calc_logZ(self.calc_logp_unnormed(self.factors)))  
            self.entropy = self._calc_entropy(self.p_model)

    def _calc_entropy(self,p):

0 Source : networks.py
with MIT License
from brentyi

    def __call__(self, inputs: jnp.ndarray):  # type: ignore
        x = inputs
        N = x.shape[0]
        assert x.shape == (N, 120, 120, 3), x.shape

        # Some conv layers
        for _ in range(3):
            x = nn.Conv(features=32, kernel_size=(3, 3), kernel_init=relu_layer_init)(x)
            x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), kernel_init=linear_layer_init)(x)

        # Channel-wise max pool
        x = jnp.max(x, axis=3, keepdims=True)

        # Spanning mean pools (to regularize X/Y coordinate regression)
        assert x.shape == (N, 120, 120, 1)
        x_horizontal = nn.avg_pool(x, window_shape=(120, 1))
        x_vertical = nn.avg_pool(x, window_shape=(1, 120))

        # Concatenate, feed through MLP
        x = jnp.concatenate(
            [x_horizontal.reshape((N, -1)), x_vertical.reshape((N, -1))], axis=1
        )
        assert x.shape == (N, 240)
        x = MLP.make(units=32, layers=3, output_dim=self.output_dim)(x)

        return x


def make_position_cnn(seed: int = 0) -> Tuple[DiskVirtualSensor, Pytree]:

0 Source : tmp_potential.py
with BSD 3-Clause "New" or "Revised" License
from CCQC

def tmp_potential(geom, basis, charges):
    """
    Build potential one-electron integrals array
    """
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = np.max(ams)
    A_vals = np.zeros(2*max_am+1)

    # Save various AM distributions for indexing
    # Obtain all possible primitive duet index combinations 
    primitive_duets = cartesian_product(np.arange(nprim), np.arange(nprim))

    with loops.Scope() as s:
      s.V = np.zeros((nbf,nbf))
      s.a = 0  # center A angular momentum iterator 
      s.b = 0  # center B angular momentum iterator 

      for prim_duet in s.range(primitive_duets.shape[0]):
        p1,p2 = primitive_duets[prim_duet]
        coef = coeffs[p1] * coeffs[p2]
        aa, bb = exps[p1], exps[p2]
        atom1, atom2 = atoms[p1], atoms[p2]
        am1, am2 = ams[p1], ams[p2]
        A, B = geom[atom1], geom[atom2]
        ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

        gamma = aa + bb
        prefactor = np.exp(-aa * bb * np.dot(A-B,A-B) / gamma)
        P = (aa * A + bb * B) / gamma
        # Maximum angular momentum: hard coded
        # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi. 
        # We need 2+max_am since kinetic requires incrementing angular momentum by +2
        PA_pow = np.power(np.broadcast_to(P-A, (max_am+3,3)).T, np.arange(max_am+3))
        PB_pow = np.power(np.broadcast_to(P-B, (max_am+3,3)).T, np.arange(max_am+3))

        # For potential integrals, we need the difference between 
        # the gaussian product center P and ALL atoms in the molecule, 
        # and then take all possible powers up to 2*max_am. 
        # We pre-collect this into a 3d array, and then just pull out what we need via indexing in the loops, so they need not be recomputed.
        # The resulting array has dimensions (atom, cartesian component, power) so index (0, 1, 3) would return (Py - atom0_y)^3
        P_minus_geom = np.broadcast_to(P, geom.shape) - geom
        Pgeom_pow = np.power(np.transpose(np.broadcast_to(P_minus_geom, (2*max_am + 1,geom.shape[0],geom.shape[1])), (1,2,0)), np.arange(2*max_am + 1))
        # All possible np.dot(P-atom,P-atom) 
        rcp2 = np.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
        # All needed (and unneeded, for am < max_am) boys function evaluations
        boys_arg = np.broadcast_to(rcp2 * gamma, (2*max_am+1, geom.shape[0]))
        boys_nu = np.tile(np.arange(2*max_am+1), (geom.shape[0],1)).T
        boys_eval = boys(boys_nu,boys_arg)

        s.a = 0
        for _ in s.while_range(lambda: s.a < dims[p1]):
          s.b = 0
          for _ in s.while_range(lambda: s.b < dims[p2]):
            # Gather angular momentum and index
            la,ma,na = angular_momentum_combinations[s.a + ld1]
            lb,mb,nb = angular_momentum_combinations[s.b + ld2]
            # To only create unique indices, need to have separate indices arrays for i and j.
            i = indices[p1] + s.a
            j = indices[p2] + s.b
            # Compute one electron integrals and add to appropriate index
            potential_int = potential(la,ma,na,lb,mb,nb,aa,bb,PA_pow,PB_pow,Pgeom_pow,boys_eval,prefactor,charges,A_vals) * coef
            s.V = jax.ops.index_add(s.V, jax.ops.index[i,j],  potential_int)

            s.b += 1
          s.a += 1
      return s.V


0 Source : oei.py
with BSD 3-Clause "New" or "Revised" License
from CCQC

def oei_arrays(geom, basis, charges):
    """
    Build one electron integral arrays (overlap, kinetic, and potential integrals)
    """
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = jnp.max(ams)
    A_vals = jnp.zeros(2*max_am+1)

    # Save various AM distributions for indexing
    # Obtain all possible primitive quartet index combinations 
    primitive_duets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim))

    with loops.Scope() as s:
      s.oei = jnp.zeros((3,nbf,nbf))
      s.a = 0  # center A angular momentum iterator 
      s.b = 0  # center B angular momentum iterator 

      for prim_duet in s.range(primitive_duets.shape[0]):
        p1,p2 = primitive_duets[prim_duet]
        coef = coeffs[p1] * coeffs[p2]
        aa, bb = exps[p1], exps[p2]
        atom1, atom2 = atoms[p1], atoms[p2]
        am1, am2 = ams[p1], ams[p2]
        A, B = geom[atom1], geom[atom2]
        ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

        gamma = aa + bb
        prefactor = jnp.exp(-aa * bb * jnp.dot(A-B,A-B) / gamma)
        P = (aa * A + bb * B) / gamma
        # Maximum angular momentum: hard coded
        #max_am = 3 # f function support
        # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi. 
        # We need 2+max_am since kinetic requires incrementing angular momentum by +2
        PA_pow = jnp.power(jnp.broadcast_to(P-A, (max_am+3,3)).T, jnp.arange(max_am+3))
        PB_pow = jnp.power(jnp.broadcast_to(P-B, (max_am+3,3)).T, jnp.arange(max_am+3))

        # For potential integrals, we need the difference between 
        # the gaussian product center P and ALL atoms in the molecule, 
        # and then take all possible powers up to 2*max_am. 
        # We pre-collect this into a 3d array, and then just pull out what we need via indexing in the loops, so they need not be recomputed.
        # The resulting array has dimensions (atom, cartesian component, power) so index (0, 1, 3) would return (Py - atom0_y)^3
        P_minus_geom = jnp.broadcast_to(P, geom.shape) - geom
        Pgeom_pow = jnp.power(jnp.transpose(jnp.broadcast_to(P_minus_geom, (2*max_am + 1,geom.shape[0],geom.shape[1])), (1,2,0)), jnp.arange(2*max_am + 1))
        # All possible jnp.dot(P-atom,P-atom) 
        rcp2 = jnp.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
        # All needed (and unneeded, for am < max_am) boys function evaluations
        boys_arg = jnp.broadcast_to(rcp2 * gamma, (2*max_am+1, geom.shape[0]))
        boys_nu = jnp.tile(jnp.arange(2*max_am+1), (geom.shape[0],1)).T
        boys_eval = boys(boys_nu,boys_arg)

        s.a = 0
        for _ in s.while_range(lambda: s.a < dims[p1]):
          s.b = 0
          for _ in s.while_range(lambda: s.b < dims[p2]):
            # Gather angular momentum and index
            la,ma,na = angular_momentum_combinations[s.a + ld1]
            lb,mb,nb = angular_momentum_combinations[s.b + ld2]
            # To only create unique indices, need to have separate indices arrays for i and j.
            i = indices[p1] + s.a
            j = indices[p2] + s.b
            # Compute one electron integrals and add to appropriate index
            overlap_int = overlap(la,ma,na,lb,mb,nb,aa,bb,PA_pow,PB_pow,prefactor) * coef
            kinetic_int = kinetic(la,ma,na,lb,mb,nb,aa,bb,PA_pow,PB_pow,prefactor) * coef
            potential_int = potential(la,ma,na,lb,mb,nb,aa,bb,PA_pow,PB_pow,Pgeom_pow,boys_eval,prefactor,charges,A_vals) * coef
            s.oei = jax.ops.index_add(s.oei, ([0,1,2],[i,i,i],[j,j,j]), (overlap_int, kinetic_int, potential_int))

            s.b += 1
          s.a += 1
    S, T, V = s.oei[0], s.oei[1], s.oei[2]
    return S, T, V

0 Source : tei.py
with BSD 3-Clause "New" or "Revised" License
from CCQC

def tei_array(geom, basis):
    """
    Build two electron integral array from a jax.numpy array of the cartesian geometry in Bohr, 
    and a basis dictionary as defined by basis_utils.build_basis_set
    We have to loop over primitives rather than shells because JAX needs intermediates to be consistent 
    sizes in order to compile.
    """
    # Smush primitive data together into vectors
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    max_am = jnp.max(ams)
    max_am_idx = max_am * 4 + 1 
    #TODO add exception raise if angular momentum is too high
    B_vals = jnp.zeros(4*max_am+1)  
    nprim = coeffs.shape[0]
    # Obtain all possible primitive quartet index combinations 
    primitive_quartets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim))

    #print("Number of basis functions: ", nbf)
    #print("Number of primitve quartets: ", primitive_quartets.shape[0])

    #TODO Experimental: precompute quantities and lookup inside loop
    # Compute all possible Gaussian products for this basis set
    aa_plus_bb = jnp.broadcast_to(exps, (nprim,nprim)) + jnp.transpose(jnp.broadcast_to(exps, (nprim,nprim)), (1,0))
    aa_times_A = jnp.einsum('i,ij->ij', exps, geom[atoms])
    aaxA_plus_bbxB = aa_times_A[:,None,:] + aa_times_A[None,:,:]
    gaussian_products = jnp.einsum('ijk,ij->ijk', aaxA_plus_bbxB, 1/aa_plus_bb)  

    # Compute all rab2 (rcd2), every possible jnp.dot(A-B,A-B)
    natom = geom.shape[0]
    tmpA = jnp.broadcast_to(geom, (natom,natom,3))
    AminusB = (tmpA - jnp.transpose(tmpA, (1,0,2)))
    AmBdot = jnp.einsum('ijk,ijk->ij', AminusB, AminusB) # shape: (natom,natom)

    # Compute all differences between gaussian product centers with all atom centers
    tmpP = jnp.tile(gaussian_products, natom).reshape(nprim,nprim,natom,3)
    PminusA = tmpP - jnp.broadcast_to(geom, tmpP.shape)

    # Compute all powers (up to max_am) of differences between gaussian product centers and atom centers
    # Shape: (nprim, nprim, natom, 3, max_am+1). In loop index PA_pow as [p1,p2,atoms[p1],:,:]
    PminusA_pow = jnp.power(jnp.transpose(jnp.broadcast_to(PminusA, (max_am+1,nprim,nprim,natom,3)), (1,2,3,4,0)), jnp.arange(max_am+1))

    with loops.Scope() as s:
      s.G = jnp.zeros((nbf,nbf,nbf,nbf))
      s.a = 0  # center A angular momentum iterator 
      s.b = 0  # center B angular momentum iterator 
      s.c = 0  # center C angular momentum iterator 
      s.d = 0  # center D angular momentum iterator 

      # Loop over primitive quartets, compute integral, add to appropriate index in G
      for prim_quar in s.range(primitive_quartets.shape[0]):
        # Load in primitive indices, coeffs, exponents, centers, angular momentum index, and leading placement index in TEI array
        p1,p2,p3,p4 = primitive_quartets[prim_quar] 
        coef = coeffs[p1] * coeffs[p2] * coeffs[p3] * coeffs[p4]
        aa, bb, cc, dd = exps[p1], exps[p2], exps[p3], exps[p4]
        ld1, ld2, ld3, ld4 = am_leading_indices[ams[p1]],am_leading_indices[ams[p2]],am_leading_indices[ams[p3]],am_leading_indices[ams[p4]]
        idx1, idx2, idx3, idx4 = indices[p1],indices[p2],indices[p3],indices[p4],
        #A, B, C, D = geom[atoms[p1]], geom[atoms[p2]], geom[atoms[p3]], geom[atoms[p4]]

        # Compute common intermediates before looping over AM distributions.
        # Avoids redundant recomputations/reassignment for all classes other than (ss|ss).
        #AB = A - B
        #CD = C - D
        #rab2 = jnp.dot(AB,AB)
        #rcd2 = jnp.dot(CD,CD)
        #P = (aa * A + bb * B) / gamma1
        #Q = (cc * C + dd * D) / gamma2
        gamma1 = aa + bb
        gamma2 = cc + dd

        #TODO
        P = gaussian_products[p1,p2]
        Q = gaussian_products[p3,p4]
        rab2 = AmBdot[atoms[p1],atoms[p2]]
        rcd2 = AmBdot[atoms[p3],atoms[p4]]
        #PA = PminusA[p1,p2,atoms[p1]]
        #PB = PminusA[p1,p2,atoms[p2]]
        #QC = PminusA[p3,p4,atoms[p3]]
        #QD = PminusA[p3,p4,atoms[p4]]
        #TODO

        PQ = P - Q
        rpq2 = jnp.dot(PQ,PQ)
        delta = 0.25*(1/gamma1+1/gamma2)

        boys_arg = 0.25 * rpq2 / delta
        boys_eval = boys(jnp.arange(max_am_idx), boys_arg) 

        # Need all powers of Pi-Ai,Pi-Bi,Qi-Ci,Qi-Di (i=x,y,z) up to max_am and Qi-Pi up to max_am_idx
        # note: this computes unnecessary quantities for lower angular momentum,
        # but avoids repeated computation of the same quantities in loops for higher angular momentum

        #PA_pow = jnp.power(jnp.broadcast_to(P-A, (max_am+1,3)).T, jnp.arange(max_am+1))
        #PB_pow = jnp.power(jnp.broadcast_to(P-B, (max_am+1,3)).T, jnp.arange(max_am+1))
        #QC_pow = jnp.power(jnp.broadcast_to(Q-C, (max_am+1,3)).T, jnp.arange(max_am+1))
        #QD_pow = jnp.power(jnp.broadcast_to(Q-D, (max_am+1,3)).T, jnp.arange(max_am+1))

        PA_pow = PminusA_pow[p1,p2,atoms[p1],:,:]
        PB_pow = PminusA_pow[p1,p2,atoms[p2],:,:]
        QC_pow = PminusA_pow[p3,p4,atoms[p3],:,:]
        QD_pow = PminusA_pow[p3,p4,atoms[p4],:,:]

        QP_pow = jnp.power(jnp.broadcast_to(Q-P, (max_am_idx,3)).T, jnp.arange(max_am_idx))
        # Gamma powers are negative, up to -(l1+l2). 
        # Make array such that the given negative index returns the same negative power.
        g1_pow = jnp.power(4*gamma1, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1)) 
        g2_pow = jnp.power(4*gamma2, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1)) 
        oodelta_pow = jnp.power(1 / delta, jnp.arange(max_am_idx))  # l1 + l2 + l3 + l4 + 1

        prefactor = 34.986836655249726 / (gamma1*gamma2*jnp.sqrt(gamma1+gamma2)) \
                    * jnp.exp(-aa*bb*rab2/gamma1 + -cc*dd*rcd2/gamma2) * coef

        # TODO is there symmetry here?
        s.a = 0
        for _ in s.while_range(lambda: s.a < dims[p1]):
          s.b = 0
          for _ in s.while_range(lambda: s.b < dims[p2]):
            s.c = 0
            for _ in s.while_range(lambda: s.c < dims[p3]):
              s.d = 0
              for _ in s.while_range(lambda: s.d < dims[p4]):
                # Collect angular momentum and index in G
                la, ma, na = angular_momentum_combinations[s.a + ld1]
                lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                lc, mc, nc = angular_momentum_combinations[s.c + ld3]
                ld, md, nd = angular_momentum_combinations[s.d + ld4]
                i = idx1 + s.a
                j = idx2 + s.b
                k = idx3 + s.c
                l = idx4 + s.d
                # Compute the primitive quartet tei and add to appropriate index in G
                Bx = B_array(la,lb,lc,ld,PA_pow[0],PB_pow[0],QC_pow[0],QD_pow[0],QP_pow[0],g1_pow,g2_pow,oodelta_pow,B_vals)
                By = B_array(ma,mb,mc,md,PA_pow[1],PB_pow[1],QC_pow[1],QD_pow[1],QP_pow[1],g1_pow,g2_pow,oodelta_pow,B_vals)
                Bz = B_array(na,nb,nc,nd,PA_pow[2],PB_pow[2],QC_pow[2],QD_pow[2],QP_pow[2],g1_pow,g2_pow,oodelta_pow,B_vals)

                with loops.Scope() as S:
                  S.primitive = 0.
                  S.I = 0
                  S.J = 0
                  S.K = 0
                  for _ in S.while_range(lambda: S.I < la + lb + lc + ld + 1):
                    S.J = 0
                    tmp = Bx[S.I]
                    for _ in S.while_range(lambda: S.J < ma + mb + mc + md + 1):
                      S.K = 0
                      tmp *= By[S.J]
                      for _ in S.while_range(lambda: S.K < na + nb + nc + nd + 1):
                        tmp *= Bz[S.K] * boys_eval[S.I + S.J + S.K]
                        S.primitive += tmp
                        S.K += 1
                      S.J += 1
                    S.I += 1
                tei = prefactor * S.primitive
                s.G = jax.ops.index_add(s.G, jax.ops.index[i,j,k,l], tei) 

                s.d += 1
              s.c += 1
            s.b += 1
          s.a += 1
      return s.G

0 Source : decoders.py
with Apache License 2.0
from deepmind

def _decode_graph_fts(decoders, t: str, h_t: _Array,
                      graph_fts: _Array) -> _Array:
  """Decodes graph features."""

  gr_emb = jnp.max(h_t, axis=-2)
  pred_n = decoders[0](gr_emb)
  pred_g = decoders[1](graph_fts)
  pred = pred_n + pred_g
  if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
    preds = jnp.squeeze(pred, -1)
  elif t == _Type.CATEGORICAL:
    preds = pred
  elif t == _Type.POINTER:
    pred_2 = decoders[2](h_t)
    ptr_p = jnp.matmul(
        jnp.expand_dims(pred, 1), jnp.transpose(pred_2, (0, 2, 1)))
    preds = jnp.squeeze(ptr_p, 1)

  return preds


def maybe_decode_diffs(

0 Source : agent.py
with Apache License 2.0
from deepmind

  def __init__(
      self,
      preprocessor: processors.Processor,
      sample_network_input: jnp.ndarray,
      network: parts.Network,
      optimizer: optax.GradientTransformation,
      transition_accumulator: replay_lib.TransitionAccumulator,
      replay: replay_lib.PrioritizedTransitionReplay,
      batch_size: int,
      exploration_epsilon: Callable[[int], float],
      min_replay_capacity_fraction: float,
      learn_period: int,
      target_network_update_period: int,
      grad_error_bound: float,
      rng_key: parts.PRNGKey,
  ):
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._exploration_epsilon = exploration_epsilon
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key = jax.random.split(rng_key)
    self._online_params = network.init(network_rng_key,
                                       sample_network_input[None, ...])
    self._target_params = self._online_params
    self._opt_state = optimizer.init(self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.
    self._statistics = {'state_value': np.nan}
    self._max_seen_priority = 1.

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.

    def loss_fn(online_params, target_params, transitions, weights, rng_key):
      """Calculates loss given network parameters and transitions."""
      _, *apply_keys = jax.random.split(rng_key, 4)
      q_tm1 = network.apply(online_params, apply_keys[0],
                            transitions.s_tm1).q_values
      q_t = network.apply(online_params, apply_keys[1],
                          transitions.s_t).q_values
      q_target_t = network.apply(target_params, apply_keys[2],
                                 transitions.s_t).q_values
      td_errors = _batch_double_q_learning(
          q_tm1,
          transitions.a_tm1,
          transitions.r_t,
          transitions.discount_t,
          q_target_t,
          q_t,
      )
      td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                     grad_error_bound)
      losses = rlax.l2_loss(td_errors)
      chex.assert_shape((losses, weights), (self._batch_size,))
      # This is not the same as using a huber loss and multiplying by weights.
      loss = jnp.mean(losses * weights)
      return loss, td_errors

    def update(rng_key, opt_state, online_params, target_params, transitions,
               weights):
      """Computes learning update from batch of replay transitions."""
      rng_key, update_key = jax.random.split(rng_key)
      d_loss_d_params, td_errors = jax.grad(
          loss_fn, has_aux=True)(online_params, target_params, transitions,
                                 weights, update_key)
      updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
      new_online_params = optax.apply_updates(online_params, updates)
      return rng_key, new_opt_state, new_online_params, td_errors

    self._update = jax.jit(update)

    def select_action(rng_key, network_params, s_t, exploration_epsilon):
      """Samples action from eps-greedy policy wrt Q-values at given state."""
      rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
      q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0]
      a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
      v_t = jnp.max(q_t, axis=-1)
      return rng_key, a_t, v_t

    self._select_action = jax.jit(select_action)

  def step(self, timestep: dm_env.TimeStep) -> parts.Action:

0 Source : agent.py
with Apache License 2.0
from deepmind

  def __init__(
      self,
      preprocessor: processors.Processor,
      sample_network_input: jnp.ndarray,
      network: parts.Network,
      support: jnp.ndarray,
      optimizer: optax.GradientTransformation,
      transition_accumulator: Any,
      replay: replay_lib.PrioritizedTransitionReplay,
      batch_size: int,
      min_replay_capacity_fraction: float,
      learn_period: int,
      target_network_update_period: int,
      rng_key: parts.PRNGKey,
  ):
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key = jax.random.split(rng_key)
    self._online_params = network.init(network_rng_key,
                                       sample_network_input[None, ...])
    self._target_params = self._online_params
    self._opt_state = optimizer.init(self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.
    self._statistics = {'state_value': np.nan}
    self._max_seen_priority = 1.

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.

    def loss_fn(online_params, target_params, transitions, weights, rng_key):
      """Calculates loss given network parameters and transitions."""
      _, *apply_keys = jax.random.split(rng_key, 4)
      logits_q_tm1 = network.apply(online_params, apply_keys[0],
                                   transitions.s_tm1).q_logits
      q_t = network.apply(online_params, apply_keys[1],
                          transitions.s_t).q_values
      logits_q_target_t = network.apply(target_params, apply_keys[2],
                                        transitions.s_t).q_logits
      losses = _batch_categorical_double_q_learning(
          support,
          logits_q_tm1,
          transitions.a_tm1,
          transitions.r_t,
          transitions.discount_t,
          support,
          logits_q_target_t,
          q_t,
      )
      loss = jnp.mean(losses * weights)
      chex.assert_shape((losses, weights), (self._batch_size,))
      return loss, losses

    def update(rng_key, opt_state, online_params, target_params, transitions,
               weights):
      """Computes learning update from batch of replay transitions."""
      rng_key, update_key = jax.random.split(rng_key)
      d_loss_d_params, losses = jax.grad(
          loss_fn, has_aux=True)(online_params, target_params, transitions,
                                 weights, update_key)
      updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
      new_online_params = optax.apply_updates(online_params, updates)
      return rng_key, new_opt_state, new_online_params, losses

    self._update = jax.jit(update)

    def select_action(rng_key, network_params, s_t):
      """Computes greedy (argmax) action wrt Q-values at given state."""
      rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
      q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0]
      a_t = rlax.greedy().sample(policy_key, q_t)
      v_t = jnp.max(q_t, axis=-1)
      return rng_key, a_t, v_t

    self._select_action = jax.jit(select_action)

  def step(self, timestep: dm_env.TimeStep) -> parts.Action:

0 Source : networks.py
with Apache License 2.0
from deepmind

def logdet_matmul(xs: Sequence[jnp.ndarray],
                  w: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Combines determinants and takes dot product with weights in log-domain.

  We use the log-sum-exp trick to reduce numerical instabilities.

  Args:
    xs: FermiNet orbitals in each determinant. Either of length 1 with shape
      (ndet, nelectron, nelectron) (full_det=True) or length 2 with shapes
      (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta) (full_det=False,
      determinants are factorised into block-diagonals for each spin channel).
    w: weight of each determinant. If none, a uniform weight is assumed.

  Returns:
    sum_i w_i D_i in the log domain, where w_i is the weight of D_i, the i-th
    determinant (or product of the i-th determinant in each spin channel, if
    full_det is not used).
  """
  # Special case to avoid taking log(0) if any matrix is of size 1x1.
  # We can avoid this by not going into the log domain and skipping the
  # log-sum-exp trick.
  det1 = functools.reduce(
      lambda a, b: a * b,
      [x.reshape(-1) for x in xs if x.shape[-1] == 1],
      1
  )

  # Compute the logdet for all matrices larger than 1x1
  sign_in, logdet = functools.reduce(
      lambda a, b: (a[0] * b[0], a[1] + b[1]),
      [slogdet(x) for x in xs if x.shape[-1] > 1],
      (1, 0)
  )
  # log-sum-exp trick
  maxlogdet = jnp.max(logdet)
  det = sign_in * det1 * jnp.exp(logdet - maxlogdet)

  if w is None:
    result = jnp.sum(det)
  else:
    result = jnp.dot(det, w)

  sign_out = jnp.sign(result)
  log_out = jnp.log(jnp.abs(result)) + maxlogdet
  return sign_out, log_out
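
The log-sum-exp trick above is the standard one: subtract jnp.max before exponentiating so the largest term becomes exp(0) = 1, then add the max back in the log domain. A minimal standalone sketch (a hypothetical helper, not the FermiNet code itself):

import jax.numpy as jnp

def weighted_sum_in_log_domain(logdet, sign, w):
  """Computes sign and log|sum_i w_i * sign_i * exp(logdet_i)| stably."""
  maxlogdet = jnp.max(logdet)               # largest exponent
  det = sign * jnp.exp(logdet - maxlogdet)  # every factor has magnitude <= 1
  result = jnp.dot(det, w)
  return jnp.sign(result), jnp.log(jnp.abs(result)) + maxlogdet

weighted_sum_in_log_domain(jnp.array([1000.0, 999.0]),
                           jnp.array([1.0, -1.0]),
                           jnp.array([0.5, 0.5]))
# exp(1000.) alone would overflow float32; here the log-magnitude (~998.85) stays finite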


def fermi_net_orbitals(

0 Source : pga_strategy.py
with Apache License 2.0
from deepmind

  def _build_optimizer(self, method, n_iter, lr, lower_bound, upper_bound):
    epsilon = jnp.max(upper_bound - lower_bound) / 2

    if method == 'square':
      init_fn = utils.bounded_initialize_fn(bounds=(lower_bound, upper_bound))
      return square.Square(
          num_steps=n_iter,
          epsilon=epsilon,
          bounds=(lower_bound, upper_bound),
          initialize_fn=init_fn)
    elif method == 'pgd':
      init_fn = utils.noop_initialize_fn()
      project_fn = utils.linf_project_fn(
          epsilon=epsilon, bounds=(lower_bound, upper_bound))
      optimizer = optimizer_module.IteratedFGSM(lr)
      return optimizer_module.PGD(optimizer, n_iter, init_fn, project_fn)

    else:
      raise ValueError(f'Unknown method: "{method}"')

  def supports_stochastic_parameters(self):

0 Source : policies.py
with Apache License 2.0
from deepmind

def gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
    max_num_considered_actions: int = 16,
    gumbel_scale: chex.Numeric = 1.,
) -> base.PolicyOutput[action_selection.GumbelMuZeroExtraData]:
  """Runs Gumbel MuZero search and returns the `PolicyOutput`.

  This policy implements Full Gumbel MuZero from
  "Policy improvement by planning with Gumbel".
  https://openreview.net/forum?id=bERaNdoegnO

  At the root of the search tree, actions are selected by Sequential Halving
  with Gumbel. At non-root nodes (aka interior nodes), actions are selected by
  the Full Gumbel MuZero deterministic action selection.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are
      `([B, num_actions], [B], [B, ...])`, respectively.
    recurrent_fn: a callable to be called on the leaf nodes and unvisited
      actions retrieved by the simulation step, which takes as args
      `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
      and the new state embedding. The `rng_key` argument is consumed.
    num_simulations: the number of simulations.
    invalid_actions: a mask with invalid actions. Invalid actions
      have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    qtransform: function to obtain completed Q-values for a node.
    max_num_considered_actions: the maximum number of actions expanded at the
      root node. A smaller number of actions will be expanded if the number of
      valid actions is smaller.
    gumbel_scale: scale for the Gumbel noise. Evaluation on perfect-information
      games can use gumbel_scale=0.0.

  Returns:
    `PolicyOutput` containing the proposed action, action_weights and the used
    search tree.
  """
  # Masking invalid actions.
  root = root.replace(
      prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions))

  # Generating Gumbel.
  rng_key, gumbel_rng = jax.random.split(rng_key)
  gumbel = gumbel_scale * jax.random.gumbel(
      gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype)

  # Searching.
  extra_data = action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel)
  search_tree = search.search(
      params=params,
      rng_key=rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=functools.partial(
          action_selection.gumbel_muzero_root_action_selection,
          num_simulations=num_simulations,
          max_num_considered_actions=max_num_considered_actions,
          qtransform=qtransform,
      ),
      interior_action_selection_fn=functools.partial(
          action_selection.gumbel_muzero_interior_action_selection,
          qtransform=qtransform,
      ),
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      extra_data=extra_data)
  summary = search_tree.summary()

  # Acting with the best action from the most visited actions.
  # The "best" action has the highest `gumbel + logits + q`.
  # Inside the minibatch, the considered_visit can be different on states with
  # a smaller number of valid actions.
  considered_visit = jnp.max(summary.visit_counts, axis=-1, keepdims=True)
  # The completed_qvalues include imputed values for unvisited actions.
  completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])(
      search_tree, search_tree.ROOT_INDEX)
  to_argmax = seq_halving.score_considered(
      considered_visit, gumbel, root.prior_logits, completed_qvalues,
      summary.visit_counts)
  action = action_selection.masked_argmax(to_argmax, invalid_actions)

  # Producing action_weights usable to train the policy network.
  completed_search_logits = _mask_invalid_actions(
      root.prior_logits + completed_qvalues, invalid_actions)
  action_weights = jax.nn.softmax(completed_search_logits)
  return base.PolicyOutput(
      action=action,
      action_weights=action_weights,
      search_tree=search_tree)


def _mask_invalid_actions(logits, invalid_actions):

0 Source : qtransforms.py
with Apache License 2.0
from deepmind

def qtransform_completed_by_mix_value(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    value_scale: chex.Numeric = 0.1,
    maxvisit_init: chex.Numeric = 50.0,
    rescale_values: bool = True,
    use_mixed_value: bool = True,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  """Returns completed qvalues.

  The missing Q-values of the unvisited actions are replaced by the
  mixed value, defined in Appendix D of
  "Policy improvement by planning with Gumbel":
  https://openreview.net/forum?id=bERaNdoegnO

  The Q-values are transformed by a linear transformation:
    `(maxvisit_init + max(visit_counts)) * value_scale * qvalues`.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    value_scale: scale for the Q-values.
    maxvisit_init: offset to the `max(visit_counts)` in the scaling factor.
    rescale_values: if True, scale the qvalues by `1 / (max_q - min_q)`.
    use_mixed_value: if True, complete the Q-values with mixed value,
      otherwise complete the Q-values with the raw value.
    epsilon: the minimum denominator when using `rescale_values`.

  Returns:
    Completed Q-values. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  qvalues = tree.qvalues(node_index)
  visit_counts = tree.children_visits[node_index]

  # Computing the mixed value and producing completed_qvalues.
  raw_value = tree.raw_values[node_index]
  prior_probs = jax.nn.softmax(
      tree.children_prior_logits[node_index])
  if use_mixed_value:
    value = _compute_mixed_value(
        raw_value,
        qvalues=qvalues,
        visit_counts=visit_counts,
        prior_probs=prior_probs)
  else:
    value = raw_value
  completed_qvalues = _complete_qvalues(
      qvalues, visit_counts=visit_counts, value=value)

  # Scaling the Q-values.
  if rescale_values:
    completed_qvalues = _rescale_qvalues(completed_qvalues, epsilon)
  maxvisit = jnp.max(visit_counts, axis=-1)
  visit_scale = maxvisit_init + maxvisit
  return visit_scale * value_scale * completed_qvalues


def _rescale_qvalues(qvalues, epsilon):

0 Source : loop_test.py
with MIT License
from gehring

    def testUnrollGrad(self, jit):
        max_steps = 10

        def step(x):
            return x*0.1

        def converge_test(x_new, x_old):
            return np.max(x_new - x_old) < 1e-3

        init_x = np.ones(())

        def run_unrolled(x):
            return loop.fixed_point_iteration(
                init_x=x,
                func=step,
                convergence_test=converge_test,
                max_iter=max_steps,
                batched_iter_size=1,
                unroll=True,
            ).value

        grad_fun = jax.grad(run_unrolled)
        if jit:
            grad_fun = jax.jit(grad_fun)
        grad_fun(init_x)

    def testBatchedRaise(self):

0 Source : loss.py
with Apache License 2.0
from google

def compute_normalization_fixed_point(activations: jnp.ndarray,
                                      t: float,
                                      num_iters: int = 5):
  """Returns the normalization value for each example (t > 1.0).

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (> 1.0 for tail heaviness).
    num_iters: Number of iterations to run the method.
  Return: An array of same rank as activation with the last dimension being 1.
  """

  mu = jnp.max(activations, -1, keepdims=True)
  normalized_activations_step_0 = activations - mu

  def cond_fun(carry):
    _, iters = carry
    return iters < num_iters

  def body_fun(carry):
    normalized_activations, iters = carry
    logt_partition = jnp.sum(
        exp_t(normalized_activations, t), -1, keepdims=True)
    normalized_activations_t = normalized_activations_step_0 * jnp.power(
        logt_partition, 1.0 - t)
    return normalized_activations_t, iters + 1

  normalized_activations_t, _ = while_loop(cond_fun, body_fun,
                                           (normalized_activations_step_0, 0))
  logt_partition = jnp.sum(
      exp_t(normalized_activations_t, t), -1, keepdims=True)
  return -log_t(1.0 / logt_partition, t) + mu


@jax.jit

0 Source : loss.py
with Apache License 2.0
from google

def compute_normalization_binary_search(activations: jnp.ndarray,
                                        t: float,
                                        num_iters: int = 10):
  """Returns the normalization value for each example (t < 1.0).

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support).
    num_iters: Number of iterations to run the method.
  Return: An array of same rank as activation with the last dimension being 1.
  """
  mu = jnp.max(activations, -1, keepdims=True)
  normalized_activations = activations - mu
  shape_activations = activations.shape
  effective_dim = jnp.float32(
      jnp.sum(
          jnp.int32(normalized_activations > -1.0 / (1.0 - t)),
          -1,
          keepdims=True))
  shape_partition = list(shape_activations[:-1]) + [1]
  lower = jnp.zeros(shape_partition)
  upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)

  def cond_fun(carry):
    _, _, iters = carry
    return iters < num_iters

  def body_fun(carry):
    lower, upper, iters = carry
    logt_partition = (upper + lower) / 2.0
    sum_probs = jnp.sum(
        exp_t(normalized_activations - logt_partition, t), -1, keepdims=True)
    update = jnp.float32(sum_probs < 1.0)
    lower = jnp.reshape(lower * update + (1.0 - update) * logt_partition,
                        shape_partition)
    upper = jnp.reshape(upper * (1.0 - update) + update * logt_partition,
                        shape_partition)
    return lower, upper, iters + 1

  lower = jnp.zeros(shape_partition)
  upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)
  lower, upper, _ = while_loop(cond_fun, body_fun, (lower, upper, 0))

  logt_partition = (upper + lower) / 2.0
  return logt_partition + mu


@jax.jit

0 Source : functions.py
with Apache License 2.0
from google

def log_softmax(x: Array,
                axis: Optional[Union[int, Tuple[int, ...]]] = -1,
                where: Optional[Array] = None,
                initial: Optional[Array] = None) -> Array:
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales
  elements to the range :math:`[-\infty, 0)`.

  .. math ::
    \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    \right)

  Args:
    x : input array
    axis: the axis or axes along which the :code:`log_softmax` should be
      computed. Either an integer or a tuple of integers.
    where: Elements to include in the :code:`log_softmax`.
    initial: The minimum value used to shift the input array. Must be present
      when :code:`where` is not None.
  """
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  shifted = x - lax.stop_gradient(x_max)
  shifted_logsumexp = jnp.log(
      jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
  return shifted - shifted_logsumexp


# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: Array,

0 Source : functions.py
with Apache License 2.0
from google

def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.

  .. math ::
    \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Args:
    x : input array
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.
    where: Elements to include in the :code:`softmax`.
    initial: The minimum value used to shift the input array. Must be present
      when :code:`where` is not None.
  """
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
  return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)

@partial(jax.jit, static_argnames=("axis",))
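
The jnp.max shift is what keeps softmax finite for large inputs; a small sketch assuming the function above is the one exposed as jax.nn.softmax:

import jax.numpy as jnp
from jax import nn

big = jnp.array([1000.0, 1001.0, 1002.0])  # a naive exp() would overflow to inf here
print(nn.softmax(big))                     # ~[0.09, 0.24, 0.67], sums to 1.0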

0 Source : jet_test.py
with Apache License 2.0
from google

  def test_all_max(self):    self.unary_check(jnp.max)
  @jtu.skip_on_devices("tpu")

0 Source : vector.py
with Apache License 2.0
from google

  def max(self):
    parts = map(jnp.max, tree_util.tree_leaves(self))
    return jnp.asarray(list(parts)).max()

@tree_util.register_pytree_node_class
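
The method above maps jnp.max over every leaf of the pytree and then reduces once more; the same pattern as a stand-alone sketch:

import jax.numpy as jnp
from jax import tree_util

tree = {"w": jnp.arange(6.0).reshape(2, 3), "b": jnp.array([-1.0, 7.0])}
leaf_maxes = list(map(jnp.max, tree_util.tree_leaves(tree)))  # max of each leaf
print(jnp.asarray(leaf_maxes).max())                          # overall max: 7.0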

0 Source : active_learning.py
with Apache License 2.0
from google

def get_msp_scores(logits, masks):
  """Obtain scores using maximum softmax probability scoring.

  Args:
    logits: the logits of the pool set.
    masks: the masks belonging to the pool set.

  Returns:
    a list of scores belonging to the pool set.
  """
  probs = jax.nn.softmax(logits)
  max_probs = jnp.max(probs, axis=-1)

  # High max prob means low uncertainty, so we invert the value.
  msp_scores = -max_probs
  msp_scores = jnp.where(masks, msp_scores, NINF_SCORE)

  return msp_scores


def get_uniform_scores(masks, rng):
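
A usage sketch for get_msp_scores; NINF_SCORE is a module-level constant not shown here, so a value is assumed:

import jax
import jax.numpy as jnp

NINF_SCORE = -jnp.inf  # assumption: stands in for the module's constant

logits = jnp.array([[4.0, 0.1, 0.1],    # confident prediction -> score near -1
                    [0.3, 0.3, 0.4]])   # uncertain prediction -> higher score
masks = jnp.array([True, True])
print(get_msp_scores(logits, masks))    # ~[-0.96, -0.36]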

0 Source : moe.py
with Apache License 2.0
from google-research

def _get_top_experts_per_item_einsum_dispatcher(
    gates: Array, num_selected_experts: int, capacity: int,
    batch_priority: bool, **dispatcher_kwargs) -> EinsumDispatcher:
  """Returns an EinsumDispatcher performing Top-Experts-Per-Item routing.

  Args:
    gates: (S, E) array with the gating values for each (item, expert).
      These values will also be used as combine_weights for the selected pairs.
    num_selected_experts: Maximum number of experts to select per each item.
    capacity: Maximum number of items processed by each expert.
    batch_priority: Whether to use batch priority routing or not.
    **dispatcher_kwargs: Additional arguments for the EinsumDispatcher.

  Returns:
    An EinsumDispatcher object.
  """
  _, _, buffer_idx = _get_top_experts_per_item_common(
      gates, num_selected_experts, batch_priority)
  # (S, K, E) -> (S, E). Select the only buffer index for each (item, expert).
  buffer_idx = jnp.max(buffer_idx, axis=1)
  # (S, E, C). Convert the buffer indices to a one-hot matrix. We rely on the
  # fact that indices < 0 or >= capacity will be ignored by the dispatcher.
  dispatch_weights = jax.nn.one_hot(buffer_idx, capacity, dtype=jnp.bool_)
  einsum_precision = dispatcher_kwargs.get("einsum_precision",
                                           jax.lax.Precision.DEFAULT)
  combine_weights = jnp.einsum(
      "SE,SEC->SEC", gates, dispatch_weights, precision=einsum_precision)
  return EinsumDispatcher(
      combine_weights=combine_weights,
      dispatch_weights=dispatch_weights,
      **dispatcher_kwargs)


def _get_top_experts_per_item_expert_indices_dispatcher(
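
A minimal sketch of the (S, K, E) -> (S, E) reduction followed by the one-hot dispatch, using hypothetical shapes and assuming (as the comment above suggests) that unselected (item, expert) pairs carry negative buffer indices:

import jax
import jax.numpy as jnp

buffer_idx = jnp.array([[[0, -1, -1], [-1, 2, -1]],   # hypothetical (S=2, K=2, E=3) indices
                        [[-1, -1, 1], [0, -1, -1]]])
per_pair = jnp.max(buffer_idx, axis=1)                   # (S, E): at most one valid slot survives
dispatch = jax.nn.one_hot(per_pair, 4, dtype=jnp.bool_)  # (S, E, C=4); negatives give all-False rows
print(per_pair)
print(dispatch.shape)  # (2, 3, 4)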

0 Source : moe.py
with Apache License 2.0
from google-research

def _get_top_experts_per_item_expert_indices_dispatcher(
    gates: Array, num_selected_experts: int, capacity: int,
    batch_priority: bool, **dispatcher_kwargs) -> ExpertIndicesDispatcher:
  """Returns an ExpertIndicesDispatcher performing Top-Experts-Per-Item routing.

  Args:
    gates: (S, E) array with the gating values for each (item, expert).
      These values will also be used as combine_weights for the selected pairs.
    num_selected_experts: Maximum number of experts to select per each item.
    capacity: Maximum number of items processed by each expert.
    batch_priority: Whether to use batch priority routing or not.
    **dispatcher_kwargs: Additional arguments for the ExpertIndicesDispatcher.

  Returns:
    An ExpertIndicesDispatcher object.
  """
  _, num_experts = gates.shape
  combine_weights, expert_idx, buffer_idx = _get_top_experts_per_item_common(
      gates, num_selected_experts, batch_priority)
  # (S, K, E) -> (S, K). Select the only buffer index for each (item, k_choice).
  buffer_idx = jnp.max(buffer_idx, axis=2)
  return ExpertIndicesDispatcher(
      indices=jnp.stack([expert_idx, buffer_idx], axis=-1),
      combine_weights=combine_weights,
      num_experts=num_experts,
      capacity=capacity,
      **dispatcher_kwargs)

0 Source : semirings.py
with MIT License
from harvardnlp

    def sum(xs, dim=-1):
        return np.max(xs, axis=dim)

0 Source : train_tools.py
with MIT License
from jemisjoky

def minibatches(str_set, mini_size, keep_end=False):
    """
    Convert StrSet object into iterator over StrSet's of size `mini_size`

    Args:
        str_set:    StrSet object, which in practice will be a large one
                    holding an entire dataset
        mini_size:  Size of the minibatches we want to yield
        keep_end:   Whether to use the last part of our dataset when 
                    mini_size doesn't evenly divide the number of strings 
                    in str_set (Default: False)

    Returns:
        batch_iter: An iterator over minibatches of size mini_size, each
                    of which is itself a StrSet instance
    """
    index_mat, str_lens = str_set.index_mat, str_set.str_lens
    num_batches, tiny_batch = divmod(len(index_mat), mini_size)
    if keep_end and tiny_batch > 0:
        num_batches += 1

    for ind in range(num_batches):
        # Pull out part of index_mat and str_lens, truncate former so that
        # it doesn't contain unnecessary padding
        ind_mat = index_mat[ind*mini_size: (ind+1)*mini_size]
        s_lens = str_lens[ind*mini_size: (ind+1)*mini_size]
        m_len = jnp.max(s_lens)
        assert jnp.all(ind_mat[:, m_len:] == 0)
        ind_mat = ind_mat[:, :m_len]

        yield StrSet(index_mat=ind_mat, str_lens=s_lens)

def to_string(str_set, alphabet):
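
jnp.max(s_lens) above strips columns that are padding for every string in the minibatch; the same trimming step in isolation, with a hypothetical padded index matrix:

import jax.numpy as jnp

index_mat = jnp.array([[3, 1, 4, 0, 0, 0],  # hypothetical indices, zero-padded to width 6
                       [2, 7, 0, 0, 0, 0],
                       [5, 0, 0, 0, 0, 0],
                       [1, 1, 1, 1, 0, 0]])
str_lens = jnp.array([3, 2, 1, 4])
m_len = jnp.max(str_lens)        # longest string in the batch
trimmed = index_mat[:, :m_len]   # drop the all-padding columns
print(trimmed.shape)             # (4, 4)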

0 Source : base.py
with MIT License
from markusschmitt

    def get_s_primes(self, s, *args):
        """Compute matrix elements

        For a list of computational basis states :math:`s` this member function computes the corresponding \
        matrix elements :math:`O_{s,s'}=\langle s|\hat O|s'\\rangle` and their respective configurations \
        :math:`s'`.

        Arguments:
            * ``s``: Array of computational basis states.
            * ``*args``: Further positional arguments that are passed on to the specific operator implementation.

        Returns:
            An array holding `all` configurations :math:`s'` and the corresponding matrix elements :math:`O_{s,s'}`.

        """

        if (not self.compiled) or self.compiled_argnum!=len(args):
            _get_s_primes = jax.vmap(self.compile(), in_axes=(0,)+(None,)*len(args))
            self._get_s_primes_pmapd = global_defs.pmap_for_my_devices(_get_s_primes, in_axes=(0,)+(None,)*len(args))
            self.compiled = True
            self.compiled_argnum = len(args)

        # Compute matrix elements
        self.sp, self.matEl = self._get_s_primes_pmapd(s, *args)

        # Get only non-zero contributions
        idx, self.numNonzero = self._find_nonzero_pmapd(self.matEl)
        self.matEl = self._set_zero_to_zero_pmapd(self.matEl, idx[..., :jnp.max(self.numNonzero)], self.numNonzero)
        self.sp = self._array_idx_pmapd(self.sp, idx[..., :jnp.max(self.numNonzero)])

        return self._flatten_pmapd(self.sp), self.matEl


    def _get_O_loc(self, matEl, logPsiS, logPsiSP):

0 Source : adamp.py
with Apache License 2.0
from nestordemeure

def _is_scale_invariant(param, grad, delta, eps):
    """test to determine if the tensor is scale invariant"""
    return jnp.max(_cosine_similarity(param, grad, eps)) < delta / jnp.sqrt(param.shape[1])

def _make_view_conditionalupdate(update, param, grad, delta, view_func, eps):
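
_cosine_similarity is defined elsewhere in that optimizer; the sketch below substitutes a plausible row-wise implementation purely for illustration (an assumption, not the package's actual helper):

import jax.numpy as jnp

def _cosine_similarity(x, y, eps):
    # Hypothetical stand-in: cosine similarity between flattened rows of x and y.
    x = x.reshape(x.shape[0], -1)
    y = y.reshape(y.shape[0], -1)
    return jnp.sum(x * y, axis=1) / (
        jnp.linalg.norm(x, axis=1) * jnp.linalg.norm(y, axis=1) + eps)

param = jnp.ones((4, 8))
grad = 0.1 * jnp.ones((4, 8))   # gradient parallel to the parameter
# With cosine similarity ~1, the max easily exceeds delta / sqrt(8), so not scale invariant.
print(_is_scale_invariant(param, grad, 0.1, 1e-8))   # False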

0 Source : hmm_logspace_lib.py
with MIT License
from probml

def logdotexp(u, v, axis=-1):
    '''
    Calculates jnp.log(jnp.exp(u) * jnp.exp(v)) in a stable way.
    Parameters
    ----------
    u : array
    v : array
    axis : int
    Returns
    -------
    * array
        Logarithm of the Hadamard product of u and v
    '''
    max_u = jnp.max(u, axis=axis, keepdims=True)
    max_v = jnp.max(v, axis=axis, keepdims=True)

    diff_u = jnp.nan_to_num(u - max_u, -jnp.inf)
    diff_v = jnp.nan_to_num(v - max_v, -jnp.inf)

    u_dot_v = jnp.log(jnp.exp(diff_u) * jnp.exp(diff_v))
    u_dot_v = u_dot_v + max_u + max_v

    return u_dot_v


def log_normalize(u, axis=-1):
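
A quick check of logdotexp with values that are easy to verify by hand:

import jax.numpy as jnp

u = jnp.log(jnp.array([0.2, 0.3, 0.5]))
v = jnp.log(jnp.array([0.5, 0.4, 0.1]))
print(jnp.exp(logdotexp(u, v)))  # elementwise product: ~[0.10, 0.12, 0.05]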

0 Source : vectorise.py
with MIT License
from SamDuffield

def auto_axes_lims(vec_dens: Callable,
                   xlim: float = 10.,
                   ylim: float = 10.):
    # Assumes input is a vectorised function that goes to 0 in the tails

    # Initial evaluation grid
    ix, iy, grid = _generate_plot_grid([-xlim, xlim], [-ylim, ylim], resolution=100, linspace=True)
    z = vec_dens(grid)

    # Find mode
    max_z = jnp.max(z)
    if jnp.isnan(max_z):
        raise TypeError('nan found attempting auto_axes_lims, try giving manual xlim and ylim as kwargs')

    if max_z == 0.:
        return auto_axes_lims(vec_dens, xlim=xlim/1.5, ylim=ylim/1.5)

    # Area with probability mass
    z_keep = z > max_z / 10

    # Find bounds of area with probability mass
    xlim_new = jnp.array([ix[_find_first_non_zero_row(z_keep.T, direction)] for direction in [1, -1]])
    ylim_new = jnp.array([iy[_find_first_non_zero_row(z_keep, direction)] for direction in [1, -1]])

    if xlim in jnp.abs(xlim_new) or ylim in jnp.abs(ylim_new):
        return auto_axes_lims(vec_dens,
                              xlim=2*xlim if xlim in jnp.abs(xlim_new) else xlim,
                              ylim=2*ylim if ylim in jnp.abs(ylim_new) else ylim)

    # Expand
    expansion = 0.05
    xlim += (xlim_new[1] - xlim_new[0]) * expansion * jnp.array([-1, 1])
    ylim += (ylim_new[1] - ylim_new[0]) * expansion * jnp.array([-1, 1])

    return tuple(xlim_new), tuple(ylim_new)


def _plot_densf(ax, x, y, z, **kwargs):

0 Source : jax_loss.py
with GNU Affero General Public License v3.0
from synsense

def softmax(x: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """
    Implements the softmax function

    .. math::

        S(x, \\tau) = \\exp(l / \\tau) / { \\Sigma { \\exp(l / \\tau)} }

        l = x - \\max(x)

    Args:
        x (np.ndarray): Input vector of scores
        temperature (float): Temperature :math:`\\tau` of the softmax. As :math:`\\tau \\rightarrow 0`, the function becomes a hard :math:`\\max` operation. Default: ``1.0``.

    Returns:
        np.ndarray: The output of the softmax.
    """
    logits = x - np.max(x)
    eta = np.exp(logits / temperature)
    return eta / np.sum(eta)


def logsoftmax(x: np.ndarray, temperature: float = 1.0) -> np.ndarray:
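
A brief sketch of the temperature behaviour, assuming the module's np alias is jax.numpy (as the type hints suggest):

import jax.numpy as np

scores = np.array([1.0, 2.0, 3.0])
print(softmax(scores, temperature=1.0))    # ~[0.09, 0.24, 0.67]
print(softmax(scores, temperature=0.01))   # ~[0.0, 0.0, 1.0], close to a hard max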

0 Source : attentions.py
with Apache License 2.0
from tensorflow

  def _log_softmax_with_extra_logit(self, logits: JTensor) -> JTensor:
    """Compute log softmax with extra logit.

    self.params.attention_extra_logit is a user-defined float value that
    helps to stabilize logit values so that they don't drift too much from it.

    Args:
      logits: input logit tensor

    Returns:
      Log softmax with extra logit value.
    """
    # Applies stop_gradient to max_logit instead of logits.
    max_logit = jnp.max(jax.lax.stop_gradient(logits), axis=-1, keepdims=True)
    extra_logit = self.params.attention_extra_logit
    if extra_logit is not None:
      extra_logit = jnp.asarray(extra_logit, dtype=max_logit.dtype)
      max_logit = jnp.maximum(max_logit, extra_logit)
    exp_x = jnp.exp(logits - max_logit)
    sum_exp_x = jnp.sum(exp_x, axis=-1, keepdims=True)
    if extra_logit is not None:
      sum_exp_x += jnp.exp(extra_logit - max_logit)
    return logits - jnp.log(sum_exp_x) - max_logit

  def _dot_atten(
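
A stand-alone functional sketch of the same computation (hypothetical: extracted from the layer class, with extra_logit passed explicitly instead of read from self.params):

import jax
import jax.numpy as jnp

def log_softmax_with_extra_logit(logits, extra_logit=0.0):
    # Stabilize with a stop-gradient max, clipped from below by the extra logit.
    max_logit = jnp.maximum(
        jnp.max(jax.lax.stop_gradient(logits), axis=-1, keepdims=True), extra_logit)
    sum_exp = jnp.sum(jnp.exp(logits - max_logit), axis=-1, keepdims=True)
    sum_exp = sum_exp + jnp.exp(extra_logit - max_logit)
    return logits - jnp.log(sum_exp) - max_logit

probs = jnp.exp(log_softmax_with_extra_logit(jnp.array([0.0, 0.0])))
print(probs.sum())  # ~0.667: the remaining mass sits on the implicit extra logit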

0 Source : matching.py
with MIT License
from tristandeleu

def matching_log_probas(embeddings, targets, test_embeddings, num_classes, eps=1e-8):
    num_samples = test_embeddings.shape[0]
    similarities = pairwise_cosine_similarity(embeddings, test_embeddings, eps=eps)
    logsumexp = nn.logsumexp(similarities, axis=0, keepdims=True)

    max_similarities = jnp.max(similarities, axis=0, keepdims=True)
    exp_similarities = jnp.exp(similarities - max_similarities)

    sum_exp = jnp.zeros((num_classes, num_samples), dtype=exp_similarities.dtype)
    indices = jnp.expand_dims(targets, axis=-1)
    dimension_numbers = ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    sum_exp = scatter_add(sum_exp, indices, exp_similarities, dimension_numbers)

    return jnp.log(sum_exp) + max_similarities - logsumexp


def matching_probas(embeddings, targets, test_embeddings, num_classes, eps=1e-8):
