inox.random
Extended utilities for random number generation.
Classes
Creates a pseudo-random number generator (PRNG).
Functions
Descriptions
- class inox.random.PRNG(seed, **kwargs)
Creates a pseudo-random number generator (PRNG).
This class is a thin wrapper around the jax.random module. It lets you generate new PRNG keys or sample from distributions without splitting keys by hand with jax.random.split.
- Parameters:
kwargs – Keyword arguments passed to jax.random.PRNGKey.
Example
>>> rng = PRNG(42)
>>> rng.split()  # generates a key
Array([1832780943, 270669613], dtype=uint32)
>>> rng.split(3)  # generates a vector of 3 keys
Array([[3187376881, 129218101],
       [2350016172, 1168365246],
       [ 257214496, 567757975]], dtype=uint32)
>>> rng.normal((5,))
Array([ 0.6611632 , -1.0414096 , 0.5554834 , -1.8841821 , 0.36664668], dtype=float32)
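To illustrate what "without splitting keys by hand" means, here is a minimal sketch that assumes only the public jax.random functions PRNGKey, split and normal. The class MiniPRNG is hypothetical and is not inox's actual implementation; it just keeps a key as internal state and splits it on every request, so its sampled values will not match the outputs shown above.

```python
import jax


class MiniPRNG:
    """Hypothetical sketch of a stateful PRNG wrapper (not inox's implementation)."""

    def __init__(self, seed: int):
        # Internal state: a raw key derived from an integer seed.
        self.state = jax.random.PRNGKey(seed)

    def split(self, num=None):
        # Split the current state into a new state plus fresh subkey(s),
        # so the caller never calls jax.random.split directly.
        if num is None:
            self.state, key = jax.random.split(self.state)
            return key
        keys = jax.random.split(self.state, num + 1)
        self.state = keys[0]
        return keys[1:]

    def normal(self, shape=()):
        # Sample from a standard normal using a freshly split key.
        return jax.random.normal(self.split(), shape)


rng = MiniPRNG(42)
key = rng.split()     # a single key
keys = rng.split(3)   # a stack of 3 keys
x = rng.normal((5,))  # 5 standard-normal samples
```

Every call replaces the stored state with a new subkey before handing out fresh keys, so no key is used twice; this is the bookkeeping a wrapper of this kind hides from the caller.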
- inox.random.set_rng(**rngs)
Sets named RNG states within a context.
See also
Example
>>> with set_rng(init=PRNG(0), dropout=PRNG(42)):
...     keys = get_rng("init").split(3)
...     mask = get_rng("dropout").bernoulli(shape=(2, 3))
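The scoping behavior of set_rng can be approximated in plain Python. The sketch below is an assumption about how such a context could work, not inox's code. The hypothetical helpers set_rng_sketch and get_rng_sketch store a mapping from names to states in a contextvars.ContextVar, merge new names over outer ones, and restore the previous mapping on exit. To stay dependency-free, it uses random.Random objects as stand-ins for PRNG.

```python
import contextlib
import contextvars
import random

# Hypothetical registry of named RNG states (not inox's actual mechanism).
_RNGS = contextvars.ContextVar("rngs", default={})


@contextlib.contextmanager
def set_rng_sketch(**rngs):
    # Merge the new named states over any outer ones, without mutating them.
    token = _RNGS.set({**_RNGS.get(), **rngs})
    try:
        yield
    finally:
        # Restore the mapping that was active before entering the block.
        _RNGS.reset(token)


def get_rng_sketch(name):
    # Look up a named state; raises KeyError if no enclosing block set it.
    return _RNGS.get()[name]


with set_rng_sketch(init=random.Random(0), dropout=random.Random(42)):
    # Stand-ins for "3 keys" and a (2, 3) Bernoulli mask with p = 0.5.
    keys = [get_rng_sketch("init").getrandbits(32) for _ in range(3)]
    mask = [[get_rng_sketch("dropout").random() < 0.5 for _ in range(3)] for _ in range(2)]
```

In this sketch, a nested block overrides only the names it sets, and leaving a block restores the outer mapping. A lookup outside every block therefore raises KeyError.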