Here are examples of the Python API jax.numpy.max, taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.
51 Examples
3 votes
Source : decoders.py
with Apache License 2.0
from deepmind
def _decode_graph_diffs(decoders, h_t: _Array, graph_fts: _Array) -> _Array:
    """Decodes graph diffs."""
    gr_emb = jnp.max(h_t, axis=-2)
    g_pred_n = decoders[0](gr_emb)
    g_pred_g = decoders[1](graph_fts)
    preds = jnp.squeeze(g_pred_n + g_pred_g, -1)
    return preds
3 votes
Source : policies.py
with Apache License 2.0
from deepmind
def _mask_invalid_actions(logits, invalid_actions):
    """Returns logits with zero mass assigned to invalid actions."""
    if invalid_actions is None:
        return logits
    chex.assert_equal_shape([logits, invalid_actions])
    logits = logits - jnp.max(logits, axis=-1, keepdims=True)
    # At the end of an episode, all actions can be invalid. A softmax would
    # then produce NaNs if -inf were used for the logits. We avoid the NaNs
    # by using a finite `min_logit` for the invalid actions.
    min_logit = jnp.finfo(logits.dtype).min
    return jnp.where(invalid_actions, min_logit, logits)
def _get_logits_from_probs(probs):
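A minimal sketch of how this masking behaves, assuming `jax` and the function above are in scope (the array values are illustrative, not from the mctx tests):

import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.5])
invalid = jnp.array([0, 0, 1])  # the last action is invalid
masked = _mask_invalid_actions(logits, invalid)
# A softmax over `masked` assigns (numerically) zero probability to the
# invalid action, and stays NaN-free even if every action is invalid.
probs = jax.nn.softmax(masked)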
3 votes
Source : policies.py
with Apache License 2.0
from deepmind
def _apply_temperature(logits, temperature):
    """Returns `logits / temperature`, also supporting temperature=0."""
    # The max subtraction prevents +inf after dividing by a small temperature.
    logits = logits - jnp.max(logits, keepdims=True, axis=-1)
    tiny = jnp.finfo(logits.dtype).tiny
    return logits / jnp.maximum(tiny, temperature)
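A short illustration of the temperature=0 case, assuming the function above is in scope: after the max subtraction the best logit is exactly 0, so dividing by `tiny` leaves it at 0 while every other entry becomes a very large negative number. In other words, the softmax degenerates to a hard argmax:

import jax.numpy as jnp

logits = jnp.array([1.0, 3.0, 2.0])
sharp = _apply_temperature(logits, 0.0)
# sharp[1] == 0.0; the remaining entries are hugely negative, so a
# softmax over `sharp` is numerically a one-hot vector at the argmax.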
3 votes
Source : seq_halving.py
with Apache License 2.0
from deepmind
def score_considered(considered_visit, gumbel, logits, normalized_qvalues,
                     visit_counts):
    """Returns a score usable for an argmax."""
    # We allow a child to be visited if it is the only considered child.
    low_logit = -1e9
    logits = logits - jnp.max(logits, keepdims=True, axis=-1)
    penalty = jnp.where(
        visit_counts == considered_visit,
        0, -jnp.inf)
    chex.assert_equal_shape([gumbel, logits, normalized_qvalues, penalty])
    return jnp.maximum(low_logit, gumbel + logits + normalized_qvalues) + penalty
def get_sequence_of_considered_visits(max_num_considered_actions,
3 votes
Source : masks_test.py
with Apache License 2.0
from google-research
def create_input(lengths):
    input_tokens = jnp.ones((len(lengths), jnp.max(lengths)), dtype=jnp.float32)
    input_tokens = input_tokens * (
        jnp.arange(jnp.max(lengths)) < jnp.reshape(lengths, (-1, 1)))
    return input_tokens, lengths
def create_decoder_input(prefix_lengths, target_lengths):
3 votes
Source : masks_test.py
with Apache License 2.0
from google-research
def create_decoder_input(prefix_lengths, target_lengths):
    targets, example_lengths = create_input(prefix_lengths + target_lengths)
    causal, _ = create_input(prefix_lengths)
    causal = jnp.concatenate([
        causal,
        jnp.zeros((causal.shape[0], jnp.max(example_lengths) - causal.shape[1]))
    ], axis=1)
    return targets, causal, example_lengths
class PromptDecoderAttentionTest(parameterized.TestCase):
3 votes
Source : masks_test.py
with Apache License 2.0
from google-research
def create_decoder_input(prefix_lengths, target_lengths):
    targets, example_lengths = create_input(prefix_lengths + target_lengths)
    causal, _ = create_input(prefix_lengths)
    causal = jnp.concatenate([
        causal,
        jnp.zeros((causal.shape[0], jnp.max(example_lengths) - causal.shape[1]))
    ], axis=1)
    return targets, causal, example_lengths
class CreatePromptDecoderOnlyMaskTest(parameterized.TestCase):
3 votes
Source : ode.py
with MIT License
from jacobjinkelly
def optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0,
                      dfactor=0.2, order=5.0):
    """Compute optimal Runge-Kutta stepsize."""
    mean_error_ratio = np.max(mean_error_ratio)
    dfactor = np.where(mean_error_ratio < 1, 1.0, dfactor)
    err_ratio = np.sqrt(mean_error_ratio)
    factor = np.maximum(1.0 / ifactor,
                        np.minimum(err_ratio**(1.0 / order) / safety, 1.0 / dfactor))
    return np.where(mean_error_ratio == 0, last_step * ifactor, last_step / factor)
def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=np.inf):
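Two worked calls make the control law concrete (here np is jax.numpy, as in the snippet above; the numbers are rough values computed from the defaults): an error ratio below 1 lets the step grow, a ratio above 1 shrinks it.

import jax.numpy as np

# Error well under tolerance: factor = max(0.1, (1e-2)**0.2 / 0.9) ~= 0.44,
# so the step grows from 0.1 to roughly 0.23.
print(optimal_step_size(0.1, np.array(1e-4)))
# Error over tolerance: factor ~= 2**0.2 / 0.9 ~= 1.28, so the step
# shrinks to roughly 0.078.
print(optimal_step_size(0.1, np.array(4.0)))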
3 votes
Source : dqn.py
with MIT License
from ku2482
def _calculate_target(
    self,
    params: hk.Params,
    params_target: hk.Params,
    reward: np.ndarray,
    done: np.ndarray,
    next_state: np.ndarray,
) -> jnp.ndarray:
    if self.double_q:
        next_action = self._forward(params, next_state)[..., None]
        next_q = self._calculate_value(params_target, next_state, next_action)
    else:
        next_q = jnp.max(self.net.apply(params_target, next_state), axis=-1, keepdims=True)
    return jax.lax.stop_gradient(reward + (1.0 - done) * self.discount * next_q)
@partial(jax.jit, static_argnums=0)
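The else branch above is the standard Q-learning bootstrap: take the max over next-state Q-values from the target network. A self-contained sketch of that target with toy numbers (none of these values come from the repository):

import jax
import jax.numpy as jnp

q_next = jnp.array([[0.2, 1.5, -0.3]])  # Q-values for one next state
reward = jnp.array([[1.0]])
done = jnp.array([[0.0]])
discount = 0.99

target = jax.lax.stop_gradient(
    reward + (1.0 - done) * discount * jnp.max(q_next, axis=-1, keepdims=True))
# target == 1.0 + 0.99 * 1.5 == 2.485; stop_gradient keeps the bootstrap
# value out of the backward pass.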
3 votes
Source : numpy_ops.py
with GNU General Public License v3.0
from PKU-NIP-Lab
def max(a, axis=None, keepdims=None, initial=None, where=None):
    a = _remove_jaxarray(a)
    r = jnp.max(a, axis=axis, keepdims=keepdims, initial=initial, where=where)
    return r if axis is None else JaxArray(r)
def min(a, axis=None, keepdims=None, initial=None, where=None):
3 votes
Source : test_ssm.py
with MIT License
from SamDuffield
def _test_bootstrap(self):
    if not hasattr(self, 'sim_samps'):
        self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
    self.pf_samps = run_particle_filter_for_marginals(self.ssm_scenario,
                                                      BootstrapFilter(),
                                                      self.sim_samps.y,
                                                      self.t,
                                                      random.PRNGKey(0),
                                                      n=self.n)
    npt.assert_array_less(self.sim_samps.x[:, 0], jnp.max(self.pf_samps.value, axis=1)[:, 0])
    npt.assert_array_less(jnp.min(self.pf_samps.value, axis=1)[:, 0], self.sim_samps.x[:, 0])
def backward_preprocess(self):
3 votes
Source : test_ssm.py
with MIT License
from SamDuffield
def backward_postprocess(self):
    npt.assert_array_less(self.sim_samps.x[:, 0], jnp.max(self.backward_samps.value, axis=1)[:, 0])
    npt.assert_array_less(jnp.min(self.backward_samps.value, axis=1)[:, 0], self.sim_samps.x[:, 0])
    self.assertFalse(hasattr(self.backward_samps, 'log_weight'))
def _test_ffbsi_full(self):
3 votes
Source : test_ssm.py
with MIT License
from SamDuffield
def _test_online_smoothing_pf_full(self):
    if not hasattr(self, 'sim_samps'):
        self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
    pf = BootstrapFilter()
    len_t = len(self.t)
    rkeys = random.split(random.PRNGKey(0), len_t)
    particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                   rkeys[0], self.sim_samps.y[0], self.t[0])
    for i in range(1, len_t):
        particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                False)
    npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
    npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)
def _test_online_smoothing_pf_rejection(self):
3 votes
Source : test_ssm.py
with MIT License
from SamDuffield
def _test_online_smoothing_pf_rejection(self):
    if not hasattr(self, 'sim_samps'):
        self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
    pf = BootstrapFilter()
    len_t = len(self.t)
    rkeys = random.split(random.PRNGKey(0), len_t)
    particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                   rkeys[0], self.sim_samps.y[0], self.t[0])
    for i in range(1, len_t):
        particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                False, maximum_rejections=10)
    npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
    npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)
def _test_online_smoothing_bs_full(self):
3 votes
Source : test_ssm.py
with MIT License
from SamDuffield
def _test_online_smoothing_bs_full(self):
    if not hasattr(self, 'sim_samps'):
        self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
    pf = BootstrapFilter()
    len_t = len(self.t)
    rkeys = random.split(random.PRNGKey(0), len_t)
    particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                   rkeys[0], self.sim_samps.y[0], self.t[0])
    for i in range(1, len_t):
        particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                True)
    npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
    npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)
def _test_online_smoothing_bs_rejection(self):
3 votes
Source : test_ssm.py
with MIT License
from SamDuffield
def _test_online_smoothing_bs_rejection(self):
    if not hasattr(self, 'sim_samps'):
        self.sim_samps = self.ssm_scenario.simulate(self.t, random.PRNGKey(0))
    pf = BootstrapFilter()
    len_t = len(self.t)
    rkeys = random.split(random.PRNGKey(0), len_t)
    particles = initiate_particles(self.ssm_scenario, pf, self.n,
                                   rkeys[0], self.sim_samps.y[0], self.t[0])
    for i in range(1, len_t):
        particles = propagate_particle_smoother(self.ssm_scenario, pf, particles,
                                                self.sim_samps.y[i], self.t[i], rkeys[i], 3,
                                                True, maximum_rejections=10)
    npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.max(particles.value, axis=1)[:, 0]) > 0).mean(), 0.1)
    npt.assert_array_less(((self.sim_samps.x[:, 0] - jnp.min(particles.value, axis=1)[:, 0]) < 0).mean(), 0.1)
if __name__ == '__main__':
3 votes
Source : jax_loss.py
with GNU Affero General Public License v3.0
from synsense
def logsoftmax(x: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """
    Efficient implementation of the log softmax function

    .. math ::

        \\log S(x, \\tau) = (l / \\tau) - \\log \\Sigma { \\exp (l / \\tau) }

        l = x - \\max (x)

    Args:
        x (np.ndarray): Input vector of scores
        temperature (float): Temperature :math:`\\tau` of the softmax. As
            :math:`\\tau \\rightarrow 0`, the function becomes a hard
            :math:`\\max` operation. Default: ``1.0``.

    Returns:
        np.ndarray: The output of the logsoftmax.
    """
    logits = x - np.max(x)
    return (logits / temperature) - np.log(np.sum(np.exp(logits / temperature)))
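At temperature 1 this reduces to the ordinary log softmax, so a quick consistency check against jax.nn.log_softmax (a standard JAX function) is possible; this sketch assumes the function above is in scope:

import jax
import jax.numpy as np

x = np.array([1.0, 2.0, 3.0])
# Both lines should print the same values up to floating-point error.
print(logsoftmax(x))
print(jax.nn.log_softmax(x))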
3 votes
Source : embedding_softmax.py
with Apache License 2.0
from tensorflow
def compute_z_loss(self, logits):
    """Returns a z_loss regularization which stabilizes logits."""
    # Applies stop_gradient to max_logit instead of logits.
    max_logit = jax.lax.stop_gradient(jnp.max(logits, axis=-1, keepdims=True))
    exp_x = jnp.exp(logits - max_logit)
    sum_exp_x = jnp.sum(exp_x, axis=-1, keepdims=True)
    log_z = jnp.log(sum_exp_x) + max_logit
    return jnp.square(log_z)
def fprop(self,
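The z-loss is simply the squared log-partition function log Z, computed with the usual max shift for stability; it pulls log Z toward zero so logits cannot drift without bound. A standalone sketch, free of the class plumbing above:

import jax
import jax.numpy as jnp

def z_loss(logits):
    max_logit = jax.lax.stop_gradient(jnp.max(logits, axis=-1, keepdims=True))
    log_z = jnp.log(jnp.sum(jnp.exp(logits - max_logit), axis=-1,
                            keepdims=True)) + max_logit
    return jnp.square(log_z)

# For logits [0, 0]: log Z = log 2, so the loss is (log 2)**2 ~= 0.48.
print(z_loss(jnp.array([0.0, 0.0])))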
3 votes
Source : poolings.py
with Apache License 2.0
from tensorflow
def fprop(self, inputs: JTensor) -> JTensor:
    """Applies global spatial pooling to inputs.

    Args:
        inputs: An input tensor.

    Returns:
        Output tensor with global pooling applied.
    """
    p = self.params
    if p.pooling_type == 'MAX':
        outputs = jnp.max(inputs, p.pooling_dims, keepdims=p.keepdims)
    elif p.pooling_type == 'AVG':
        outputs = jnp.mean(inputs, p.pooling_dims, keepdims=p.keepdims)
    return outputs
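For reference, the 'MAX' branch with pooling over the spatial dimensions of an NHWC tensor is a single jnp.max call; a minimal sketch with made-up shapes:

import jax.numpy as jnp

x = jnp.ones((2, 8, 8, 16))  # NHWC input
pooled = jnp.max(x, axis=(1, 2), keepdims=True)
print(pooled.shape)          # (2, 1, 1, 16)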
0 votes
Source : jaxent.py
with MIT License
from adamhaber
def train(self, data, data_kind="samples", data_n_samp=None, alpha=0.32, lr=1e-1, threshold=1., kind=None, n_samps=5000):
    """Fit a maximum entropy model to data.

    Parameters
    ----------
    data : array_like
        either an array of binary samples, or an array of desired marginals
    data_kind : str, optional
        "samples" - data samples are passed
        "marginals" - desired marginals are passed
    data_n_samp : int, optional
        number of trials, needed to compute confidence intervals
    alpha : float, optional
        confidence level
    lr : float, optional
        learning rate, by default 1e-1
    threshold : float, optional
        maximum allowed difference between model marginals and empirical
        marginals, in units of the empirical standard deviation. by default 1
    kind : str, optional
        "exhaustive" means analytical computation of model marginals,
        "sample" means MCMC estimation; by default None, in which case it is
        inferred from the number of units N.
    n_samps : int, optional
        number of samples to generate in each MCMC estimation of model
        marginals, by default 5000
    """
    @jit
    def _training_step_ex(i, opt_state):
        params = get_params(opt_state)
        model_marg = self.calc_marginals_ex(self._calc_p(params))
        g = self.empirical_marginals - model_marg
        return opt_update(i, g, opt_state), model_marg

    @jit
    def _training_step(i, opt_state):
        params = get_params(opt_state)
        samples = self._sample(random.PRNGKey(i), n_samps, params, -1, -1)
        model_marg = self.calc_marginals(samples)
        g = self.empirical_marginals - model_marg
        return opt_update(i, g, opt_state), model_marg

    @jit
    def _training_loop(loop_carry):
        i, opt_state, params, _ = loop_carry
        opt_state, marginals = step(i, opt_state)
        params = get_params(opt_state)
        return i + 1, opt_state, params, marginals

    if kind is None:
        kind = onp.where(self.N > 20, 'sample', 'exhaustive')
    if kind == 'exhaustive':
        step = _training_step_ex
        if self.words is None:
            self.create_words()
        self.model_marginals = self.calc_marginals_ex(self._calc_p(self.factors))
    elif kind == 'sample':
        step = _training_step
        self.model_marginals = self.calc_marginals(self._sample(random.PRNGKey(0), n_samps, self.factors, -1, -1))
    lower, upper = self.calc_empirical_marginals_and_stds(data, data_kind, data_n_samp, alpha)
    self.empirical_std = upper - lower
    opt_init, opt_update, get_params = optimizers.adam(lr)
    training_steps, opt_state, params, marginals = while_loop(
        lambda x: np.max(self.calc_deviations(x[3])) > threshold,
        _training_loop,
        (0, opt_init(self.factors), self.factors, self.model_marginals))
    self.factors = params
    self.training_steps = training_steps
    self.model_marginals = marginals
    self.trained = True
    if kind == 'exhaustive':
        self.p_model = self._calc_p(self.factors)
        self.Z = np.exp(self.calc_logZ(self.calc_logp_unnormed(self.factors)))
        self.entropy = self._calc_entropy(self.p_model)
def _calc_entropy(self,p):
0 votes
Source : networks.py
with MIT License
from brentyi
def __call__(self, inputs: jnp.ndarray):  # type: ignore
    x = inputs
    N = x.shape[0]
    assert x.shape == (N, 120, 120, 3), x.shape

    # Some conv layers
    for _ in range(3):
        x = nn.Conv(features=32, kernel_size=(3, 3), kernel_init=relu_layer_init)(x)
        x = nn.relu(x)
    x = nn.Conv(features=32, kernel_size=(3, 3), kernel_init=linear_layer_init)(x)

    # Channel-wise max pool
    x = jnp.max(x, axis=3, keepdims=True)

    # Spanning mean pools (to regularize X/Y coordinate regression)
    assert x.shape == (N, 120, 120, 1)
    x_horizontal = nn.avg_pool(x, window_shape=(120, 1))
    x_vertical = nn.avg_pool(x, window_shape=(1, 120))

    # Concatenate, feed through MLP
    x = jnp.concatenate(
        [x_horizontal.reshape((N, -1)), x_vertical.reshape((N, -1))], axis=1
    )
    assert x.shape == (N, 240)
    x = MLP.make(units=32, layers=3, output_dim=self.output_dim)(x)
    return x
def make_position_cnn(seed: int = 0) -> Tuple[DiskVirtualSensor, Pytree]:
0 votes
Source : tmp_potential.py
with BSD 3-Clause "New" or "Revised" License
from CCQC
def tmp_potential(geom, basis, charges):
    """
    Build potential one-electron integrals array
    """
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = np.max(ams)
    A_vals = np.zeros(2*max_am+1)
    # Save various AM distributions for indexing
    # Obtain all possible primitive duet index combinations
    primitive_duets = cartesian_product(np.arange(nprim), np.arange(nprim))

    with loops.Scope() as s:
        s.V = np.zeros((nbf,nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator
        for prim_duet in s.range(primitive_duets.shape[0]):
            p1,p2 = primitive_duets[prim_duet]
            coef = coeffs[p1] * coeffs[p2]
            aa, bb = exps[p1], exps[p2]
            atom1, atom2 = atoms[p1], atoms[p2]
            am1, am2 = ams[p1], ams[p2]
            A, B = geom[atom1], geom[atom2]
            ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

            gamma = aa + bb
            prefactor = np.exp(-aa * bb * np.dot(A-B,A-B) / gamma)
            P = (aa * A + bb * B) / gamma
            # Maximum angular momentum: hard coded
            # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi.
            # We need 2+max_am since kinetic requires incrementing angular momentum by +2
            PA_pow = np.power(np.broadcast_to(P-A, (max_am+3,3)).T, np.arange(max_am+3))
            PB_pow = np.power(np.broadcast_to(P-B, (max_am+3,3)).T, np.arange(max_am+3))
            # For potential integrals, we need the difference between
            # the gaussian product center P and ALL atoms in the molecule,
            # and then take all possible powers up to 2*max_am.
            # We pre-collect this into a 3d array, and then just pull out what we
            # need via indexing in the loops, so they need not be recomputed.
            # The resulting array has dimensions (atom, cartesian component, power),
            # so index (0, 1, 3) would return (Py - atom0_y)^3
            P_minus_geom = np.broadcast_to(P, geom.shape) - geom
            Pgeom_pow = np.power(np.transpose(np.broadcast_to(P_minus_geom, (2*max_am + 1,geom.shape[0],geom.shape[1])), (1,2,0)), np.arange(2*max_am + 1))
            # All possible np.dot(P-atom,P-atom)
            rcp2 = np.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
            # All needed (and unneeded, for am < max_am) boys function evaluations
            boys_arg = np.broadcast_to(rcp2 * gamma, (2*max_am+1, geom.shape[0]))
            boys_nu = np.tile(np.arange(2*max_am+1), (geom.shape[0],1)).T
            boys_eval = boys(boys_nu,boys_arg)

            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    # Gather angular momentum and index
                    la,ma,na = angular_momentum_combinations[s.a + ld1]
                    lb,mb,nb = angular_momentum_combinations[s.b + ld2]
                    # To only create unique indices, need to have separate indices arrays for i and j.
                    i = indices[p1] + s.a
                    j = indices[p2] + s.b
                    # Compute one electron integrals and add to appropriate index
                    potential_int = potential(la,ma,na,lb,mb,nb,aa,bb,PA_pow,PB_pow,Pgeom_pow,boys_eval,prefactor,charges,A_vals) * coef
                    s.V = jax.ops.index_add(s.V, jax.ops.index[i,j], potential_int)
                    s.b += 1
                s.a += 1
        return s.V
0 votes
Source : oei.py
with BSD 3-Clause "New" or "Revised" License
from CCQC
def oei_arrays(geom, basis, charges):
    """
    Build one electron integral arrays (overlap, kinetic, and potential integrals)
    """
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = jnp.max(ams)
    A_vals = jnp.zeros(2*max_am+1)
    # Save various AM distributions for indexing
    # Obtain all possible primitive duet index combinations
    primitive_duets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim))

    with loops.Scope() as s:
        s.oei = jnp.zeros((3,nbf,nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator
        for prim_duet in s.range(primitive_duets.shape[0]):
            p1,p2 = primitive_duets[prim_duet]
            coef = coeffs[p1] * coeffs[p2]
            aa, bb = exps[p1], exps[p2]
            atom1, atom2 = atoms[p1], atoms[p2]
            am1, am2 = ams[p1], ams[p2]
            A, B = geom[atom1], geom[atom2]
            ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

            gamma = aa + bb
            prefactor = jnp.exp(-aa * bb * jnp.dot(A-B,A-B) / gamma)
            P = (aa * A + bb * B) / gamma
            # Maximum angular momentum: hard coded
            #max_am = 3 # f function support
            # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi.
            # We need 2+max_am since kinetic requires incrementing angular momentum by +2
            PA_pow = jnp.power(jnp.broadcast_to(P-A, (max_am+3,3)).T, jnp.arange(max_am+3))
            PB_pow = jnp.power(jnp.broadcast_to(P-B, (max_am+3,3)).T, jnp.arange(max_am+3))
            # For potential integrals, we need the difference between
            # the gaussian product center P and ALL atoms in the molecule,
            # and then take all possible powers up to 2*max_am.
            # We pre-collect this into a 3d array, and then just pull out what we
            # need via indexing in the loops, so they need not be recomputed.
            # The resulting array has dimensions (atom, cartesian component, power),
            # so index (0, 1, 3) would return (Py - atom0_y)^3
            P_minus_geom = jnp.broadcast_to(P, geom.shape) - geom
            Pgeom_pow = jnp.power(jnp.transpose(jnp.broadcast_to(P_minus_geom, (2*max_am + 1,geom.shape[0],geom.shape[1])), (1,2,0)), jnp.arange(2*max_am + 1))
            # All possible jnp.dot(P-atom,P-atom)
            rcp2 = jnp.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
            # All needed (and unneeded, for am < max_am) boys function evaluations
            boys_arg = jnp.broadcast_to(rcp2 * gamma, (2*max_am+1, geom.shape[0]))
            boys_nu = jnp.tile(jnp.arange(2*max_am+1), (geom.shape[0],1)).T
            boys_eval = boys(boys_nu,boys_arg)

            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    # Gather angular momentum and index
                    la,ma,na = angular_momentum_combinations[s.a + ld1]
                    lb,mb,nb = angular_momentum_combinations[s.b + ld2]
                    # To only create unique indices, need to have separate indices arrays for i and j.
                    i = indices[p1] + s.a
                    j = indices[p2] + s.b
                    # Compute one electron integrals and add to appropriate index
                    overlap_int = overlap(la,ma,na,lb,mb,nb,aa,bb,PA_pow,PB_pow,prefactor) * coef
                    kinetic_int = kinetic(la,ma,na,lb,mb,nb,aa,bb,PA_pow,PB_pow,prefactor) * coef
                    potential_int = potential(la,ma,na,lb,mb,nb,aa,bb,PA_pow,PB_pow,Pgeom_pow,boys_eval,prefactor,charges,A_vals) * coef
                    s.oei = jax.ops.index_add(s.oei, ([0,1,2],[i,i,i],[j,j,j]), (overlap_int, kinetic_int, potential_int))
                    s.b += 1
                s.a += 1
        S, T, V = s.oei[0], s.oei[1], s.oei[2]
        return S, T, V
0 votes
Source : tei.py
with BSD 3-Clause "New" or "Revised" License
from CCQC
def tei_array(geom, basis):
    """
    Build two electron integral array from a jax.numpy array of the cartesian
    geometry in Bohr, and a basis dictionary as defined by
    basis_utils.build_basis_set. We have to loop over primitives rather than
    shells because JAX needs intermediates to be consistent sizes in order to
    compile.
    """
    # Smush primitive data together into vectors
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    max_am = jnp.max(ams)
    max_am_idx = max_am * 4 + 1
    #TODO add exception raise if angular momentum is too high
    B_vals = jnp.zeros(4*max_am+1)
    nprim = coeffs.shape[0]
    # Obtain all possible primitive quartet index combinations
    primitive_quartets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim))
    #print("Number of basis functions: ", nbf)
    #print("Number of primitive quartets: ", primitive_quartets.shape[0])

    #TODO Experimental: precompute quantities and lookup inside loop
    # Compute all possible Gaussian products for this basis set
    aa_plus_bb = jnp.broadcast_to(exps, (nprim,nprim)) + jnp.transpose(jnp.broadcast_to(exps, (nprim,nprim)), (1,0))
    aa_times_A = jnp.einsum('i,ij->ij', exps, geom[atoms])
    aaxA_plus_bbxB = aa_times_A[:,None,:] + aa_times_A[None,:,:]
    gaussian_products = jnp.einsum('ijk,ij->ijk', aaxA_plus_bbxB, 1/aa_plus_bb)

    # Compute all rab2 (rcd2), every possible jnp.dot(A-B,A-B)
    natom = geom.shape[0]
    tmpA = jnp.broadcast_to(geom, (natom,natom,3))
    AminusB = (tmpA - jnp.transpose(tmpA, (1,0,2)))
    AmBdot = jnp.einsum('ijk,ijk->ij', AminusB, AminusB)  # shape: (natom,natom)

    # Compute all differences between gaussian product centers with all atom centers
    tmpP = jnp.tile(gaussian_products, natom).reshape(nprim,nprim,natom,3)
    PminusA = tmpP - jnp.broadcast_to(geom, tmpP.shape)
    # Compute all powers (up to max_am) of differences between gaussian product centers and atom centers
    # Shape: (nprim, nprim, natom, 3, max_am+1). In loop index PA_pow as [p1,p2,atoms[p1],:,:]
    PminusA_pow = jnp.power(jnp.transpose(jnp.broadcast_to(PminusA, (max_am+1,nprim,nprim,natom,3)), (1,2,3,4,0)), jnp.arange(max_am+1))

    with loops.Scope() as s:
        s.G = jnp.zeros((nbf,nbf,nbf,nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator
        s.c = 0  # center C angular momentum iterator
        s.d = 0  # center D angular momentum iterator

        # Loop over primitive quartets, compute integral, add to appropriate index in G
        for prim_quar in s.range(primitive_quartets.shape[0]):
            # Load in primitive indices, coeffs, exponents, centers, angular
            # momentum index, and leading placement index in TEI array
            p1,p2,p3,p4 = primitive_quartets[prim_quar]
            coef = coeffs[p1] * coeffs[p2] * coeffs[p3] * coeffs[p4]
            aa, bb, cc, dd = exps[p1], exps[p2], exps[p3], exps[p4]
            ld1, ld2, ld3, ld4 = am_leading_indices[ams[p1]],am_leading_indices[ams[p2]],am_leading_indices[ams[p3]],am_leading_indices[ams[p4]]
            idx1, idx2, idx3, idx4 = indices[p1],indices[p2],indices[p3],indices[p4]
            #A, B, C, D = geom[atoms[p1]], geom[atoms[p2]], geom[atoms[p3]], geom[atoms[p4]]

            # Compute common intermediates before looping over AM distributions.
            # Avoids redundant recomputations/reassignment for all classes other than (ss|ss).
            #AB = A - B
            #CD = C - D
            #rab2 = jnp.dot(AB,AB)
            #rcd2 = jnp.dot(CD,CD)
            #P = (aa * A + bb * B) / gamma1
            #Q = (cc * C + dd * D) / gamma2
            gamma1 = aa + bb
            gamma2 = cc + dd
            #TODO
            P = gaussian_products[p1,p2]
            Q = gaussian_products[p3,p4]
            rab2 = AmBdot[atoms[p1],atoms[p2]]
            rcd2 = AmBdot[atoms[p3],atoms[p4]]
            #PA = PminusA[p1,p2,atoms[p1]]
            #PB = PminusA[p1,p2,atoms[p2]]
            #QC = PminusA[p3,p4,atoms[p3]]
            #QD = PminusA[p3,p4,atoms[p4]]
            #TODO
            PQ = P - Q
            rpq2 = jnp.dot(PQ,PQ)
            delta = 0.25*(1/gamma1+1/gamma2)
            boys_arg = 0.25 * rpq2 / delta
            boys_eval = boys(jnp.arange(max_am_idx), boys_arg)

            # Need all powers of Pi-Ai,Pi-Bi,Qi-Ci,Qi-Di (i=x,y,z) up to max_am and Qi-Pi up to max_am_idx
            # note: this computes unnecessary quantities for lower angular momentum,
            # but avoids repeated computation of the same quantities in loops for higher angular momentum
            #PA_pow = jnp.power(jnp.broadcast_to(P-A, (max_am+1,3)).T, jnp.arange(max_am+1))
            #PB_pow = jnp.power(jnp.broadcast_to(P-B, (max_am+1,3)).T, jnp.arange(max_am+1))
            #QC_pow = jnp.power(jnp.broadcast_to(Q-C, (max_am+1,3)).T, jnp.arange(max_am+1))
            #QD_pow = jnp.power(jnp.broadcast_to(Q-D, (max_am+1,3)).T, jnp.arange(max_am+1))
            PA_pow = PminusA_pow[p1,p2,atoms[p1],:,:]
            PB_pow = PminusA_pow[p1,p2,atoms[p2],:,:]
            QC_pow = PminusA_pow[p3,p4,atoms[p3],:,:]
            QD_pow = PminusA_pow[p3,p4,atoms[p4],:,:]
            QP_pow = jnp.power(jnp.broadcast_to(Q-P, (max_am_idx,3)).T, jnp.arange(max_am_idx))
            # Gamma powers are negative, up to -(l1+l2).
            # Make array such that the given negative index returns the same negative power.
            g1_pow = jnp.power(4*gamma1, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1))
            g2_pow = jnp.power(4*gamma2, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1))
            oodelta_pow = jnp.power(1 / delta, jnp.arange(max_am_idx))  # l1 + l2 + l3 + l4 + 1
            prefactor = 34.986836655249726 / (gamma1*gamma2*jnp.sqrt(gamma1+gamma2)) \
                        * jnp.exp(-aa*bb*rab2/gamma1 + -cc*dd*rcd2/gamma2) * coef

            # TODO is there symmetry here?
            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    s.c = 0
                    for _ in s.while_range(lambda: s.c < dims[p3]):
                        s.d = 0
                        for _ in s.while_range(lambda: s.d < dims[p4]):
                            # Collect angular momentum and index in G
                            la, ma, na = angular_momentum_combinations[s.a + ld1]
                            lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                            lc, mc, nc = angular_momentum_combinations[s.c + ld3]
                            ld, md, nd = angular_momentum_combinations[s.d + ld4]
                            i = idx1 + s.a
                            j = idx2 + s.b
                            k = idx3 + s.c
                            l = idx4 + s.d
                            # Compute the primitive quartet tei and add to appropriate index in G
                            Bx = B_array(la,lb,lc,ld,PA_pow[0],PB_pow[0],QC_pow[0],QD_pow[0],QP_pow[0],g1_pow,g2_pow,oodelta_pow,B_vals)
                            By = B_array(ma,mb,mc,md,PA_pow[1],PB_pow[1],QC_pow[1],QD_pow[1],QP_pow[1],g1_pow,g2_pow,oodelta_pow,B_vals)
                            Bz = B_array(na,nb,nc,nd,PA_pow[2],PB_pow[2],QC_pow[2],QD_pow[2],QP_pow[2],g1_pow,g2_pow,oodelta_pow,B_vals)

                            with loops.Scope() as S:
                                S.primitive = 0.
                                S.I = 0
                                S.J = 0
                                S.K = 0
                                for _ in S.while_range(lambda: S.I < la + lb + lc + ld + 1):
                                    S.J = 0
                                    tmp = Bx[S.I]
                                    for _ in S.while_range(lambda: S.J < ma + mb + mc + md + 1):
                                        S.K = 0
                                        tmp *= By[S.J]
                                        for _ in S.while_range(lambda: S.K < na + nb + nc + nd + 1):
                                            tmp *= Bz[S.K] * boys_eval[S.I + S.J + S.K]
                                            S.primitive += tmp
                                            S.K += 1
                                        S.J += 1
                                    S.I += 1
                                tei = prefactor * S.primitive
                                s.G = jax.ops.index_add(s.G, jax.ops.index[i,j,k,l], tei)
                            s.d += 1
                        s.c += 1
                    s.b += 1
                s.a += 1
        return s.G
0 votes
Source : decoders.py
with Apache License 2.0
from deepmind
def _decode_graph_fts(decoders, t: str, h_t: _Array,
                      graph_fts: _Array) -> _Array:
    """Decodes graph features."""
    gr_emb = jnp.max(h_t, axis=-2)
    pred_n = decoders[0](gr_emb)
    pred_g = decoders[1](graph_fts)
    pred = pred_n + pred_g
    if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
        preds = jnp.squeeze(pred, -1)
    elif t == _Type.CATEGORICAL:
        preds = pred
    elif t == _Type.POINTER:
        pred_2 = decoders[2](h_t)
        ptr_p = jnp.matmul(
            jnp.expand_dims(pred, 1), jnp.transpose(pred_2, (0, 2, 1)))
        preds = jnp.squeeze(ptr_p, 1)
    return preds
def maybe_decode_diffs(
0 votes
Source : agent.py
with Apache License 2.0
from deepmind
def __init__(
    self,
    preprocessor: processors.Processor,
    sample_network_input: jnp.ndarray,
    network: parts.Network,
    optimizer: optax.GradientTransformation,
    transition_accumulator: replay_lib.TransitionAccumulator,
    replay: replay_lib.PrioritizedTransitionReplay,
    batch_size: int,
    exploration_epsilon: Callable[[int], float],
    min_replay_capacity_fraction: float,
    learn_period: int,
    target_network_update_period: int,
    grad_error_bound: float,
    rng_key: parts.PRNGKey,
):
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._exploration_epsilon = exploration_epsilon
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key = jax.random.split(rng_key)
    self._online_params = network.init(network_rng_key,
                                       sample_network_input[None, ...])
    self._target_params = self._online_params
    self._opt_state = optimizer.init(self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.
    self._statistics = {'state_value': np.nan}
    self._max_seen_priority = 1.

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.

    def loss_fn(online_params, target_params, transitions, weights, rng_key):
        """Calculates loss given network parameters and transitions."""
        _, *apply_keys = jax.random.split(rng_key, 4)
        q_tm1 = network.apply(online_params, apply_keys[0],
                              transitions.s_tm1).q_values
        q_t = network.apply(online_params, apply_keys[1],
                            transitions.s_t).q_values
        q_target_t = network.apply(target_params, apply_keys[2],
                                   transitions.s_t).q_values
        td_errors = _batch_double_q_learning(
            q_tm1,
            transitions.a_tm1,
            transitions.r_t,
            transitions.discount_t,
            q_target_t,
            q_t,
        )
        td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                       grad_error_bound)
        losses = rlax.l2_loss(td_errors)
        chex.assert_shape((losses, weights), (self._batch_size,))
        # This is not the same as using a huber loss and multiplying by weights.
        loss = jnp.mean(losses * weights)
        return loss, td_errors

    def update(rng_key, opt_state, online_params, target_params, transitions,
               weights):
        """Computes learning update from batch of replay transitions."""
        rng_key, update_key = jax.random.split(rng_key)
        d_loss_d_params, td_errors = jax.grad(
            loss_fn, has_aux=True)(online_params, target_params, transitions,
                                   weights, update_key)
        updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
        new_online_params = optax.apply_updates(online_params, updates)
        return rng_key, new_opt_state, new_online_params, td_errors

    self._update = jax.jit(update)

    def select_action(rng_key, network_params, s_t, exploration_epsilon):
        """Samples action from eps-greedy policy wrt Q-values at given state."""
        rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
        q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0]
        a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
        v_t = jnp.max(q_t, axis=-1)
        return rng_key, a_t, v_t

    self._select_action = jax.jit(select_action)
def step(self, timestep: dm_env.TimeStep) -> parts.Action:
0 votes
Source : agent.py
with Apache License 2.0
from deepmind
def __init__(
    self,
    preprocessor: processors.Processor,
    sample_network_input: jnp.ndarray,
    network: parts.Network,
    support: jnp.ndarray,
    optimizer: optax.GradientTransformation,
    transition_accumulator: Any,
    replay: replay_lib.PrioritizedTransitionReplay,
    batch_size: int,
    min_replay_capacity_fraction: float,
    learn_period: int,
    target_network_update_period: int,
    rng_key: parts.PRNGKey,
):
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key = jax.random.split(rng_key)
    self._online_params = network.init(network_rng_key,
                                       sample_network_input[None, ...])
    self._target_params = self._online_params
    self._opt_state = optimizer.init(self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.
    self._statistics = {'state_value': np.nan}
    self._max_seen_priority = 1.

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.

    def loss_fn(online_params, target_params, transitions, weights, rng_key):
        """Calculates loss given network parameters and transitions."""
        _, *apply_keys = jax.random.split(rng_key, 4)
        logits_q_tm1 = network.apply(online_params, apply_keys[0],
                                     transitions.s_tm1).q_logits
        q_t = network.apply(online_params, apply_keys[1],
                            transitions.s_t).q_values
        logits_q_target_t = network.apply(target_params, apply_keys[2],
                                          transitions.s_t).q_logits
        losses = _batch_categorical_double_q_learning(
            support,
            logits_q_tm1,
            transitions.a_tm1,
            transitions.r_t,
            transitions.discount_t,
            support,
            logits_q_target_t,
            q_t,
        )
        loss = jnp.mean(losses * weights)
        chex.assert_shape((losses, weights), (self._batch_size,))
        return loss, losses

    def update(rng_key, opt_state, online_params, target_params, transitions,
               weights):
        """Computes learning update from batch of replay transitions."""
        rng_key, update_key = jax.random.split(rng_key)
        d_loss_d_params, losses = jax.grad(
            loss_fn, has_aux=True)(online_params, target_params, transitions,
                                   weights, update_key)
        updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
        new_online_params = optax.apply_updates(online_params, updates)
        return rng_key, new_opt_state, new_online_params, losses

    self._update = jax.jit(update)

    def select_action(rng_key, network_params, s_t):
        """Computes greedy (argmax) action wrt Q-values at given state."""
        rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
        q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0]
        a_t = rlax.greedy().sample(policy_key, q_t)
        v_t = jnp.max(q_t, axis=-1)
        return rng_key, a_t, v_t

    self._select_action = jax.jit(select_action)
def step(self, timestep: dm_env.TimeStep) -> parts.Action:
0 votes
Source : networks.py
with Apache License 2.0
from deepmind
def logdet_matmul(xs: Sequence[jnp.ndarray],
                  w: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    """Combines determinants and takes dot product with weights in log-domain.

    We use the log-sum-exp trick to reduce numerical instabilities.

    Args:
        xs: FermiNet orbitals in each determinant. Either of length 1 with shape
            (ndet, nelectron, nelectron) (full_det=True) or length 2 with shapes
            (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta) (full_det=False,
            determinants are factorised into block-diagonals for each spin
            channel).
        w: weight of each determinant. If none, a uniform weight is assumed.

    Returns:
        sum_i w_i D_i in the log domain, where w_i is the weight of D_i, the
        i-th determinant (or product of the i-th determinant in each spin
        channel, if full_det is not used).
    """
    # Special case to avoid taking log(0) if any matrix is of size 1x1.
    # We can avoid this by not going into the log domain and skipping the
    # log-sum-exp trick.
    det1 = functools.reduce(
        lambda a, b: a * b,
        [x.reshape(-1) for x in xs if x.shape[-1] == 1],
        1
    )
    # Compute the logdet for all matrices larger than 1x1
    sign_in, logdet = functools.reduce(
        lambda a, b: (a[0] * b[0], a[1] + b[1]),
        [slogdet(x) for x in xs if x.shape[-1] > 1],
        (1, 0)
    )
    # log-sum-exp trick
    maxlogdet = jnp.max(logdet)
    det = sign_in * det1 * jnp.exp(logdet - maxlogdet)
    if w is None:
        result = jnp.sum(det)
    else:
        result = jnp.dot(det, w)
    sign_out = jnp.sign(result)
    log_out = jnp.log(jnp.abs(result)) + maxlogdet
    return sign_out, log_out
def fermi_net_orbitals(
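The maxlogdet subtraction is the log-sum-exp trick applied to log-determinants: exponentials are only ever taken of differences, which cannot overflow. A tiny standalone illustration with invented values:

import jax.numpy as jnp

logdet = jnp.array([100.0, 101.0, 99.0])  # log|D_i|; exp(100.) overflows float32
m = jnp.max(logdet)
# log(sum_i exp(logdet_i)), computed without ever forming exp(100.):
stable = jnp.log(jnp.sum(jnp.exp(logdet - m))) + m
print(stable)  # ~101.41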
0 votes
Source : pga_strategy.py
with Apache License 2.0
from deepmind
def _build_optimizer(self, method, n_iter, lr, lower_bound, upper_bound):
    epsilon = jnp.max(upper_bound - lower_bound) / 2
    if method == 'square':
        init_fn = utils.bounded_initialize_fn(bounds=(lower_bound, upper_bound))
        return square.Square(
            num_steps=n_iter,
            epsilon=epsilon,
            bounds=(lower_bound, upper_bound),
            initialize_fn=init_fn)
    elif method == 'pgd':
        init_fn = utils.noop_initialize_fn()
        project_fn = utils.linf_project_fn(
            epsilon=epsilon, bounds=(lower_bound, upper_bound))
        optimizer = optimizer_module.IteratedFGSM(lr)
        return optimizer_module.PGD(optimizer, n_iter, init_fn, project_fn)
    else:
        raise ValueError(f'Unknown method: "{method}"')
def supports_stochastic_parameters(self):
0 votes
Source : policies.py
with Apache License 2.0
from deepmind
def gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
    max_num_considered_actions: int = 16,
    gumbel_scale: chex.Numeric = 1.,
) -> base.PolicyOutput[action_selection.GumbelMuZeroExtraData]:
    """Runs Gumbel MuZero search and returns the `PolicyOutput`.

    This policy implements Full Gumbel MuZero from
    "Policy improvement by planning with Gumbel".
    https://openreview.net/forum?id=bERaNdoegnO

    At the root of the search tree, actions are selected by Sequential Halving
    with Gumbel. At non-root nodes (aka interior nodes), actions are selected
    by the Full Gumbel MuZero deterministic action selection.

    In the shape descriptions, `B` denotes the batch dimension.

    Args:
        params: params to be forwarded to root and recurrent functions.
        rng_key: random number generator state, the key is consumed.
        root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
            `prior_logits` are from a policy network. The shapes are
            `([B, num_actions], [B], [B, ...])`, respectively.
        recurrent_fn: a callable to be called on the leaf nodes and unvisited
            actions retrieved by the simulation step, which takes as args
            `(params, rng_key, action, embedding)` and returns a
            `RecurrentFnOutput` and the new state embedding. The `rng_key`
            argument is consumed.
        num_simulations: the number of simulations.
        invalid_actions: a mask with invalid actions. Invalid actions have
            ones, valid actions have zeros in the mask. Shape
            `[B, num_actions]`.
        max_depth: maximum search tree depth allowed during simulation.
        qtransform: function to obtain completed Q-values for a node.
        max_num_considered_actions: the maximum number of actions expanded at
            the root node. A smaller number of actions will be expanded if the
            number of valid actions is smaller.
        gumbel_scale: scale for the Gumbel noise. Evaluation on
            perfect-information games can use gumbel_scale=0.0.

    Returns:
        `PolicyOutput` containing the proposed action, action_weights and the
        used search tree.
    """
    # Masking invalid actions.
    root = root.replace(
        prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions))

    # Generating Gumbel.
    rng_key, gumbel_rng = jax.random.split(rng_key)
    gumbel = gumbel_scale * jax.random.gumbel(
        gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype)

    # Searching.
    extra_data = action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel)
    search_tree = search.search(
        params=params,
        rng_key=rng_key,
        root=root,
        recurrent_fn=recurrent_fn,
        root_action_selection_fn=functools.partial(
            action_selection.gumbel_muzero_root_action_selection,
            num_simulations=num_simulations,
            max_num_considered_actions=max_num_considered_actions,
            qtransform=qtransform,
        ),
        interior_action_selection_fn=functools.partial(
            action_selection.gumbel_muzero_interior_action_selection,
            qtransform=qtransform,
        ),
        num_simulations=num_simulations,
        max_depth=max_depth,
        invalid_actions=invalid_actions,
        extra_data=extra_data)
    summary = search_tree.summary()

    # Acting with the best action from the most visited actions.
    # The "best" action has the highest `gumbel + logits + q`.
    # Inside the minibatch, the considered_visit can be different on states
    # with a smaller number of valid actions.
    considered_visit = jnp.max(summary.visit_counts, axis=-1, keepdims=True)
    # The completed_qvalues include imputed values for unvisited actions.
    completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])(
        search_tree, search_tree.ROOT_INDEX)
    to_argmax = seq_halving.score_considered(
        considered_visit, gumbel, root.prior_logits, completed_qvalues,
        summary.visit_counts)
    action = action_selection.masked_argmax(to_argmax, invalid_actions)

    # Producing action_weights usable to train the policy network.
    completed_search_logits = _mask_invalid_actions(
        root.prior_logits + completed_qvalues, invalid_actions)
    action_weights = jax.nn.softmax(completed_search_logits)
    return base.PolicyOutput(
        action=action,
        action_weights=action_weights,
        search_tree=search_tree)
def _mask_invalid_actions(logits, invalid_actions):
0 votes
Source : qtransforms.py
with Apache License 2.0
from deepmind
def qtransform_completed_by_mix_value(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    value_scale: chex.Numeric = 0.1,
    maxvisit_init: chex.Numeric = 50.0,
    rescale_values: bool = True,
    use_mixed_value: bool = True,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
    """Returns completed qvalues.

    The missing Q-values of the unvisited actions are replaced by the
    mixed value, defined in Appendix D of
    "Policy improvement by planning with Gumbel":
    https://openreview.net/forum?id=bERaNdoegnO

    The Q-values are transformed by a linear transformation:
        `(maxvisit_init + max(visit_counts)) * value_scale * qvalues`.

    Args:
        tree: _unbatched_ MCTS tree state.
        node_index: scalar index of the parent node.
        value_scale: scale for the Q-values.
        maxvisit_init: offset to the `max(visit_counts)` in the scaling factor.
        rescale_values: if True, scale the qvalues by `1 / (max_q - min_q)`.
        use_mixed_value: if True, complete the Q-values with mixed value,
            otherwise complete the Q-values with the raw value.
        epsilon: the minimum denominator when using `rescale_values`.

    Returns:
        Completed Q-values. Shape `[num_actions]`.
    """
    chex.assert_shape(node_index, ())
    qvalues = tree.qvalues(node_index)
    visit_counts = tree.children_visits[node_index]

    # Computing the mixed value and producing completed_qvalues.
    raw_value = tree.raw_values[node_index]
    prior_probs = jax.nn.softmax(
        tree.children_prior_logits[node_index])
    if use_mixed_value:
        value = _compute_mixed_value(
            raw_value,
            qvalues=qvalues,
            visit_counts=visit_counts,
            prior_probs=prior_probs)
    else:
        value = raw_value
    completed_qvalues = _complete_qvalues(
        qvalues, visit_counts=visit_counts, value=value)

    # Scaling the Q-values.
    if rescale_values:
        completed_qvalues = _rescale_qvalues(completed_qvalues, epsilon)
    maxvisit = jnp.max(visit_counts, axis=-1)
    visit_scale = maxvisit_init + maxvisit
    return visit_scale * value_scale * completed_qvalues
def _rescale_qvalues(qvalues, epsilon):
0 votes
Source : loop_test.py
with MIT License
from gehring
def testUnrollGrad(self, jit):
    max_steps = 10

    def step(x):
        return x*0.1

    def converge_test(x_new, x_old):
        return np.max(x_new - x_old) < 1e-3

    init_x = np.ones(())

    def run_unrolled(x):
        return loop.fixed_point_iteration(
            init_x=x,
            func=step,
            convergence_test=converge_test,
            max_iter=max_steps,
            batched_iter_size=1,
            unroll=True,
        ).value

    grad_fun = jax.grad(run_unrolled)
    if jit:
        grad_fun = jax.jit(grad_fun)
    grad_fun(init_x)
def testBatchedRaise(self):
0 votes
Source : loss.py
with Apache License 2.0
from google
def compute_normalization_fixed_point(activations: jnp.ndarray,
                                      t: float,
                                      num_iters: int = 5):
    """Returns the normalization value for each example (t > 1.0).

    Args:
        activations: A multi-dimensional array with last dimension `num_classes`.
        t: Temperature 2 (> 1.0 for tail heaviness).
        num_iters: Number of iterations to run the method.

    Return: An array of same rank as activation with the last dimension being 1.
    """
    mu = jnp.max(activations, -1, keepdims=True)
    normalized_activations_step_0 = activations - mu

    def cond_fun(carry):
        _, iters = carry
        return iters < num_iters

    def body_fun(carry):
        normalized_activations, iters = carry
        logt_partition = jnp.sum(
            exp_t(normalized_activations, t), -1, keepdims=True)
        normalized_activations_t = normalized_activations_step_0 * jnp.power(
            logt_partition, 1.0 - t)
        return normalized_activations_t, iters + 1

    normalized_activations_t, _ = while_loop(cond_fun, body_fun,
                                             (normalized_activations_step_0, 0))
    logt_partition = jnp.sum(
        exp_t(normalized_activations_t, t), -1, keepdims=True)
    return -log_t(1.0 / logt_partition, t) + mu
@jax.jit
0 votes
Source : loss.py
with Apache License 2.0
from google
def compute_normalization_binary_search(activations: jnp.ndarray,
                                        t: float,
                                        num_iters: int = 10):
    """Returns the normalization value for each example (t < 1.0).

    Args:
        activations: A multi-dimensional array with last dimension `num_classes`.
        t: Temperature 2 (< 1.0 for finite support).
        num_iters: Number of iterations to run the method.

    Return: An array of same rank as activation with the last dimension being 1.
    """
    mu = jnp.max(activations, -1, keepdims=True)
    normalized_activations = activations - mu
    shape_activations = activations.shape
    effective_dim = jnp.float32(
        jnp.sum(
            jnp.int32(normalized_activations > -1.0 / (1.0 - t)),
            -1,
            keepdims=True))
    shape_partition = list(shape_activations[:-1]) + [1]
    lower = jnp.zeros(shape_partition)
    upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)

    def cond_fun(carry):
        _, _, iters = carry
        return iters < num_iters

    def body_fun(carry):
        lower, upper, iters = carry
        logt_partition = (upper + lower) / 2.0
        sum_probs = jnp.sum(
            exp_t(normalized_activations - logt_partition, t), -1, keepdims=True)
        update = jnp.float32(sum_probs < 1.0)
        lower = jnp.reshape(lower * update + (1.0 - update) * logt_partition,
                            shape_partition)
        upper = jnp.reshape(upper * (1.0 - update) + update * logt_partition,
                            shape_partition)
        return lower, upper, iters + 1

    lower = jnp.zeros(shape_partition)
    upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)
    lower, upper, _ = while_loop(cond_fun, body_fun, (lower, upper, 0))
    logt_partition = (upper + lower) / 2.0
    return logt_partition + mu
@jax.jit
0 votes
Source : functions.py
with Apache License 2.0
from google
def log_softmax(x: Array,
                axis: Optional[Union[int, Tuple[int, ...]]] = -1,
                where: Optional[Array] = None,
                initial: Optional[Array] = None) -> Array:
    r"""Log-Softmax function.

    Computes the logarithm of the :code:`softmax` function, which rescales
    elements to the range :math:`[-\infty, 0)`.

    .. math ::
        \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)

    Args:
        x : input array
        axis: the axis or axes along which the :code:`log_softmax` should be
            computed. Either an integer or a tuple of integers.
        where: Elements to include in the :code:`log_softmax`.
        initial: The minimum value used to shift the input array. Must be
            present when :code:`where` is not None.
    """
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
    shifted = x - lax.stop_gradient(x_max)
    shifted_logsumexp = jnp.log(
        jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
    return shifted - shifted_logsumexp
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: Array,
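A quick sketch of the `where`/`initial` pattern used above (these are standard jnp.max keywords): masked-out entries are excluded from the reduction, and `initial` supplies a starting value so that even a fully masked row has a defined max:

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
mask = jnp.array([True, True, False])
# The masked max ignores the 3.0 entry and returns 2.0.
print(jnp.max(x, where=mask, initial=-jnp.inf, keepdims=True))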
0 votes
Source : functions.py
with Apache License 2.0
from google
def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
    r"""Softmax function.

    Computes the function which rescales elements to the range :math:`[0, 1]`
    such that the elements along :code:`axis` sum to :math:`1`.

    .. math ::
        \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Args:
        x : input array
        axis: the axis or axes along which the softmax should be computed. The
            softmax output summed across these dimensions should sum to
            :math:`1`. Either an integer or a tuple of integers.
        where: Elements to include in the :code:`softmax`.
        initial: The minimum value used to shift the input array. Must be
            present when :code:`where` is not None.
    """
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
    unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
    return unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
@partial(jax.jit, static_argnames=("axis",))
0 votes
Source : jet_test.py
with Apache License 2.0
from google
def test_all_max(self): self.unary_check(jnp.max)
@jtu.skip_on_devices("tpu")
0 votes
Source : vector.py
with Apache License 2.0
from google
def max(self):
    parts = map(jnp.max, tree_util.tree_leaves(self))
    return jnp.asarray(list(parts)).max()
@tree_util.register_pytree_node_class
0 votes
Source : active_learning.py
with Apache License 2.0
from google
def get_msp_scores(logits, masks):
    """Obtain scores using maximum softmax probability scoring.

    Args:
        logits: the logits of the pool set.
        masks: the masks belonging to the pool set.

    Returns:
        a list of scores belonging to the pool set.
    """
    probs = jax.nn.softmax(logits)
    max_probs = jnp.max(probs, axis=-1)
    # High max prob means low uncertainty, so we invert the value.
    msp_scores = -max_probs
    msp_scores = jnp.where(masks, msp_scores, NINF_SCORE)
    return msp_scores
def get_uniform_scores(masks, rng):
0 votes
Source : moe.py
with Apache License 2.0
from google-research
def _get_top_experts_per_item_einsum_dispatcher(
    gates: Array, num_selected_experts: int, capacity: int,
    batch_priority: bool, **dispatcher_kwargs) -> EinsumDispatcher:
    """Returns an EinsumDispatcher performing Top-Experts-Per-Item routing.

    Args:
        gates: (S, E) array with the gating values for each (item, expert).
            These values will also be used as combine_weights for the selected
            pairs.
        num_selected_experts: Maximum number of experts to select per each item.
        capacity: Maximum number of items processed by each expert.
        batch_priority: Whether to use batch priority routing or not.
        **dispatcher_kwargs: Additional arguments for the EinsumDispatcher.

    Returns:
        An EinsumDispatcher object.
    """
    _, _, buffer_idx = _get_top_experts_per_item_common(
        gates, num_selected_experts, batch_priority)
    # (S, K, E) -> (S, E). Select the only buffer index for each (item, expert).
    buffer_idx = jnp.max(buffer_idx, axis=1)
    # (S, E, C). Convert the buffer indices to a one-hot matrix. We rely on the
    # fact that indices < 0 or >= capacity will be ignored by the dispatcher.
    dispatch_weights = jax.nn.one_hot(buffer_idx, capacity, dtype=jnp.bool_)
    einsum_precision = dispatcher_kwargs.get("einsum_precision",
                                             jax.lax.Precision.DEFAULT)
    combine_weights = jnp.einsum(
        "SE,SEC->SEC", gates, dispatch_weights, precision=einsum_precision)
    return EinsumDispatcher(
        combine_weights=combine_weights,
        dispatch_weights=dispatch_weights,
        **dispatcher_kwargs)
def _get_top_experts_per_item_expert_indices_dispatcher(
0 votes
Source : moe.py
with Apache License 2.0
from google-research
def _get_top_experts_per_item_expert_indices_dispatcher(
    gates: Array, num_selected_experts: int, capacity: int,
    batch_priority: bool, **dispatcher_kwargs) -> ExpertIndicesDispatcher:
    """Returns an ExpertIndicesDispatcher performing Top-Experts-Per-Item routing.

    Args:
        gates: (S, E) array with the gating values for each (item, expert).
            These values will also be used as combine_weights for the selected
            pairs.
        num_selected_experts: Maximum number of experts to select per each item.
        capacity: Maximum number of items processed by each expert.
        batch_priority: Whether to use batch priority routing or not.
        **dispatcher_kwargs: Additional arguments for the ExpertIndicesDispatcher.

    Returns:
        An ExpertIndicesDispatcher object.
    """
    _, num_experts = gates.shape
    combine_weights, expert_idx, buffer_idx = _get_top_experts_per_item_common(
        gates, num_selected_experts, batch_priority)
    # (S, K, E) -> (S, K). Select the only buffer index for each (item, k_choice).
    buffer_idx = jnp.max(buffer_idx, axis=2)
    return ExpertIndicesDispatcher(
        indices=jnp.stack([expert_idx, buffer_idx], axis=-1),
        combine_weights=combine_weights,
        num_experts=num_experts,
        capacity=capacity,
        **dispatcher_kwargs)
0 votes
Source : semirings.py
with MIT License
from harvardnlp
def sum(xs, dim=-1):
    return np.max(xs, axis=dim)
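In this max semiring (the "tropical" semiring used for Viterbi-style decoding), the semiring sum is a maximum, so dynamic programs that would otherwise accumulate total mass instead track the single best score. A minimal illustration of the substitution:

import jax.numpy as np

scores = np.array([0.5, 2.0, 1.0])
print(np.sum(scores))  # ordinary semiring sum: 3.5
print(np.max(scores))  # max semiring "sum": 2.0, the best score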
0 votes
Source : train_tools.py
with MIT License
from jemisjoky
def minibatches(str_set, mini_size, keep_end=False):
    """
    Convert StrSet object into iterator over StrSet's of size `mini_size`

    Args:
        str_set: StrSet object, which in practice will be a large one
            holding an entire dataset
        mini_size: Size of the minibatches we want to yield
        keep_end: Whether to use the last part of our dataset when
            mini_size doesn't evenly divide the number of strings
            in str_set (Default: False)

    Returns:
        batch_iter: An iterator over minibatches of size mini_size, each
            of which is itself a StrSet instance
    """
    index_mat, str_lens = str_set.index_mat, str_set.str_lens
    num_batches, tiny_batch = divmod(len(index_mat), mini_size)
    if keep_end and tiny_batch > 0:
        num_batches += 1
    for ind in range(num_batches):
        # Pull out part of index_mat and str_lens, truncate the former so
        # that it doesn't contain unnecessary padding
        ind_mat = index_mat[ind*mini_size: (ind+1)*mini_size]
        s_lens = str_lens[ind*mini_size: (ind+1)*mini_size]
        m_len = jnp.max(s_lens)
        assert jnp.all(ind_mat[:, m_len:] == 0)
        ind_mat = ind_mat[:, :m_len]
        yield StrSet(index_mat=ind_mat, str_lens=s_lens)
def to_string(str_set, alphabet):
0 votes
Source : base.py
with MIT License
from markusschmitt
def get_s_primes(self, s, *args):
    """Compute matrix elements

    For a list of computational basis states :math:`s` this member function
    computes the corresponding matrix elements
    :math:`O_{s,s'}=\langle s|\hat O|s'\rangle` and their respective
    configurations :math:`s'`.

    Arguments:
        * ``s``: Array of computational basis states.
        * ``*args``: Further positional arguments that are passed on to the
          specific operator implementation.

    Returns:
        An array holding `all` configurations :math:`s'` and the corresponding
        matrix elements :math:`O_{s,s'}`.
    """
    if (not self.compiled) or self.compiled_argnum != len(args):
        _get_s_primes = jax.vmap(self.compile(), in_axes=(0,)+(None,)*len(args))
        self._get_s_primes_pmapd = global_defs.pmap_for_my_devices(_get_s_primes, in_axes=(0,)+(None,)*len(args))
        self.compiled = True
        self.compiled_argnum = len(args)

    # Compute matrix elements
    self.sp, self.matEl = self._get_s_primes_pmapd(s, *args)

    # Get only non-zero contributions
    idx, self.numNonzero = self._find_nonzero_pmapd(self.matEl)
    self.matEl = self._set_zero_to_zero_pmapd(self.matEl, idx[..., :jnp.max(self.numNonzero)], self.numNonzero)
    self.sp = self._array_idx_pmapd(self.sp, idx[..., :jnp.max(self.numNonzero)])

    return self._flatten_pmapd(self.sp), self.matEl
def _get_O_loc(self, matEl, logPsiS, logPsiSP):
0 votes
Source : adamp.py
with Apache License 2.0
from nestordemeure
def _is_scale_invariant(param, grad, delta, eps):
    """test to determine if the tensor is scale invariant"""
    return jnp.max(_cosine_similarity(param, grad, eps)) < delta / jnp.sqrt(param.shape[1])
def _make_view_conditionalupdate(update, param, grad, delta, view_func, eps):
0 votes
Source : hmm_logspace_lib.py
with MIT License
from probml
def logdotexp(u, v, axis=-1):
    '''
    Calculates jnp.log(jnp.exp(u) * jnp.exp(v)) in a stable way.

    Parameters
    ----------
    u : array
    v : array
    axis : int

    Returns
    -------
    * array
        Logarithm of the Hadamard product of exp(u) and exp(v)
    '''
    max_u = jnp.max(u, axis=axis, keepdims=True)
    max_v = jnp.max(v, axis=axis, keepdims=True)
    diff_u = jnp.nan_to_num(u - max_u, -jnp.inf)
    diff_v = jnp.nan_to_num(v - max_v, -jnp.inf)
    u_dot_v = jnp.log(jnp.exp(diff_u) * jnp.exp(diff_v))
    u_dot_v = u_dot_v + max_u + max_v
    return u_dot_v
def log_normalize(u, axis=-1):
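For finite entries logdotexp(u, v) is just u + v elementwise; the max subtraction and the nan_to_num only matter for stability at the extremes, where -inf encodes a zero probability. A sanity-check sketch, assuming the function above is in scope:

import jax.numpy as jnp

u = jnp.array([0.0, -jnp.inf])  # -inf encodes a zero probability
v = jnp.array([1.0, 2.0])
print(logdotexp(u, v))  # prints [1., -inf], i.e. u + v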
0 votes
Source : vectorise.py
with MIT License
from SamDuffield
def auto_axes_lims(vec_dens: Callable,
                   xlim: float = 10.,
                   ylim: float = 10.):
    # Assumes input is a vectorised function that goes to 0 in the tails
    # Initial evaluation grid
    ix, iy, grid = _generate_plot_grid([-xlim, xlim], [-ylim, ylim], resolution=100, linspace=True)
    z = vec_dens(grid)

    # Find mode
    max_z = jnp.max(z)
    if jnp.isnan(max_z):
        raise TypeError('nan found attempting auto_axes_lims, try giving manual xlim and ylim as kwargs')
    if max_z == 0.:
        return auto_axes_lims(vec_dens, xlim=xlim/1.5, ylim=ylim/1.5)

    # Area with probability mass
    z_keep = z > max_z / 10

    # Find bounds of area with probability mass
    xlim_new = jnp.array([ix[_find_first_non_zero_row(z_keep.T, direction)] for direction in [1, -1]])
    ylim_new = jnp.array([iy[_find_first_non_zero_row(z_keep, direction)] for direction in [1, -1]])
    if xlim in jnp.abs(xlim_new) or ylim in jnp.abs(ylim_new):
        return auto_axes_lims(vec_dens,
                              xlim=2*xlim if xlim in jnp.abs(xlim_new) else xlim,
                              ylim=2*ylim if ylim in jnp.abs(ylim_new) else ylim)

    # Expand
    expansion = 0.05
    xlim += (xlim_new[1] - xlim_new[0]) * expansion * jnp.array([-1, 1])
    ylim += (ylim_new[1] - ylim_new[0]) * expansion * jnp.array([-1, 1])
    return tuple(xlim_new), tuple(ylim_new)
def _plot_densf(ax, x, y, z, **kwargs):
0 votes
Source : jax_loss.py
with GNU Affero General Public License v3.0
from synsense
def softmax(x: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """
    Implements the softmax function

    .. math::

        S(x, \\tau) = \\exp(l / \\tau) / { \\Sigma \\exp(l / \\tau) }

        l = x - \\max(x)

    Args:
        x (np.ndarray): Input vector of scores
        temperature (float): Temperature :math:`\\tau` of the softmax. As
            :math:`\\tau \\rightarrow 0`, the function becomes a hard
            :math:`\\max` operation. Default: ``1.0``.

    Returns:
        np.ndarray: The output of the softmax.
    """
    logits = x - np.max(x)
    eta = np.exp(logits / temperature)
    return eta / np.sum(eta)
def logsoftmax(x: np.ndarray, temperature: float = 1.0) -> np.ndarray:
0 votes
Source : attentions.py
with Apache License 2.0
from tensorflow
def _log_softmax_with_extra_logit(self, logits: JTensor) -> JTensor:
    """Compute log softmax with extra logit.

    self.params.attention_extra_logit is a user-defined float value that
    helps to stabilize logit values so that they don't drift too much from it.

    Args:
        logits: input logit tensor

    Returns:
        Log softmax with extra logit value.
    """
    # Applies stop_gradient to max_logit instead of logits.
    max_logit = jnp.max(jax.lax.stop_gradient(logits), axis=-1, keepdims=True)
    extra_logit = self.params.attention_extra_logit
    if extra_logit is not None:
        extra_logit = jnp.asarray(extra_logit, dtype=max_logit.dtype)
        max_logit = jnp.maximum(max_logit, extra_logit)
    exp_x = jnp.exp(logits - max_logit)
    sum_exp_x = jnp.sum(exp_x, axis=-1, keepdims=True)
    if extra_logit is not None:
        sum_exp_x += jnp.exp(extra_logit - max_logit)
    return logits - jnp.log(sum_exp_x) - max_logit
def _dot_atten(
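The extra logit effectively appends one more constant logit to the partition function, acting like an implicit "null" attention slot. A standalone sketch of the same computation, stripped of the class plumbing above (the function name and extra_logit=0.0 default here are illustrative assumptions, not the library's API):

import jax
import jax.numpy as jnp

def log_softmax_with_extra_logit(logits, extra_logit=0.0):
    max_logit = jnp.max(jax.lax.stop_gradient(logits), axis=-1, keepdims=True)
    max_logit = jnp.maximum(max_logit, extra_logit)
    sum_exp = jnp.sum(jnp.exp(logits - max_logit), axis=-1, keepdims=True)
    sum_exp += jnp.exp(extra_logit - max_logit)
    return logits - jnp.log(sum_exp) - max_logit

probs = jnp.exp(log_softmax_with_extra_logit(jnp.array([0.0, 0.0])))
# probs sums to 2/3: the remaining 1/3 sits on the implicit extra logit.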
0 votes
Source : matching.py
with MIT License
from tristandeleu
def matching_log_probas(embeddings, targets, test_embeddings, num_classes, eps=1e-8):
    num_samples = test_embeddings.shape[0]
    similarities = pairwise_cosine_similarity(embeddings, test_embeddings, eps=eps)
    logsumexp = nn.logsumexp(similarities, axis=0, keepdims=True)
    max_similarities = jnp.max(similarities, axis=0, keepdims=True)
    exp_similarities = jnp.exp(similarities - max_similarities)

    sum_exp = jnp.zeros((num_classes, num_samples), dtype=exp_similarities.dtype)
    indices = jnp.expand_dims(targets, axis=-1)
    dimension_numbers = ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,))
    sum_exp = scatter_add(sum_exp, indices, exp_similarities, dimension_numbers)

    return jnp.log(sum_exp) + max_similarities - logsumexp
def matching_probas(embeddings, targets, test_embeddings, num_classes, eps=1e-8):