inox.random
Extended utilities for random number generation
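Functions
get_rng() | Returns the context-bound PRNG.
set_rng(rng) | Sets the PRNG within a context.
Classes
PRNG(seed, **kwargs) | Creates a pseudo-random number generator (PRNG).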
Descriptions
- class inox.random.PRNG(seed, **kwargs)
Creates a pseudo-random number generator (PRNG).
This class is a thin wrapper around the jax.random module. It lets you generate new PRNG keys or sample from distributions without splitting keys by hand with jax.random.split.
- Parameters:
kwargs – Keyword arguments passed to jax.random.PRNGKey.
Example
>>> rng = PRNG(42)
>>> rng.split()  # generates a key
Array([2465931498, 3679230171], dtype=uint32)
>>> rng.split(3)  # generates a vector of 3 keys
Array([[ 956272045, 3465119146],
       [1903583750,  988321301],
       [3226638877, 2833683589]], dtype=uint32)
>>> rng.normal((5,))
Array([ 0.5694761 , -1.4582146 ,  0.2309113 , -0.03029377,  0.11095619], dtype=float32)
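To make "without splitting keys by hand" concrete, here is a minimal sketch of a stateful wrapper written directly against jax.random. The class name SketchPRNG is hypothetical: this is not inox's source, and it is not guaranteed to reproduce the exact key values shown above. Each call splits the internal key, keeps one subkey as the new state, and hands out the rest, so no key is ever consumed twice. Without such a wrapper, every draw would need an explicit `key, subkey = jax.random.split(key)`.

```python
import jax


class SketchPRNG:
    """Hypothetical stateful wrapper around a JAX key (illustration only, not inox's code)."""

    def __init__(self, seed: int):
        # Legacy JAX key: a uint32 array of shape (2,)
        self.state = jax.random.PRNGKey(seed)

    def split(self, num=None):
        # Split the internal key, keep one subkey as the new state,
        # and return the others so that no key is reused.
        if num is None:
            self.state, key = jax.random.split(self.state)
            return key
        keys = jax.random.split(self.state, num + 1)
        self.state = keys[0]
        return keys[1:]

    def normal(self, shape):
        # Each draw consumes a fresh subkey
        return jax.random.normal(self.split(), shape)
```

With this sketch, `SketchPRNG(42).split(3)` returns an array of shape (3, 2), one legacy key per row, mirroring the shape of the `rng.split(3)` output above.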
- inox.random.set_rng(rng)
Sets the PRNG within a context.
- Parameters:
rng (PRNG) – A PRNG instance.
Example
>>> with set_rng(PRNG(0)):
...     a = get_rng().split()
...     b = get_rng().normal((2, 3))
- inox.random.get_rng()
Returns the context-bound PRNG.
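Together, set_rng and get_rng let code deep inside a call stack draw random numbers without a key being passed through every function. The sketch below shows one way to build that pattern with the standard library's contextvars and contextlib. It is not inox's implementation: the variable name _current_rng and the RuntimeError raised outside any context are assumptions. Restoring the previous value on exit makes nested contexts unwind like a stack.

```python
import contextlib
import contextvars

# Hypothetical slot holding the active PRNG; None means "no context entered"
_current_rng = contextvars.ContextVar("current_rng", default=None)


@contextlib.contextmanager
def set_rng(rng):
    """Bind `rng` as the current PRNG for the duration of a with-block."""
    token = _current_rng.set(rng)
    try:
        yield
    finally:
        # Restore whatever was bound before, so nested contexts unwind like a stack
        _current_rng.reset(token)


def get_rng():
    """Return the PRNG bound by the innermost enclosing set_rng block."""
    rng = _current_rng.get()
    if rng is None:
        # Assumption: fail loudly when called outside any set_rng context
        raise RuntimeError("no PRNG set; wrap the call in `with set_rng(...)`")
    return rng
```

A ContextVar is used here rather than a plain module-level global so that a binding made in one thread or asyncio task does not leak into another. In single-threaded code, a global variable saved and restored in the same way would behave identically.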