inox.random
Extended utilities for random number generation.
Classes
PRNG | Creates a pseudo-random number generator (PRNG).
Functions
set_rng | Sets the PRNG within a context.
get_rng | Returns the context-bound PRNG.
Descriptions
- class inox.random.PRNG(seed, **kwargs)
Creates a pseudo-random number generator (PRNG).
This class is a thin wrapper around the jax.random module. It can generate new PRNG keys or sample from distributions without requiring you to split keys with jax.random.split by hand.
- Parameters:
kwargs – Keyword arguments passed to jax.random.PRNGKey.
Example
>>> rng = PRNG(42)
>>> rng.split()  # generates a key
Array([2465931498, 3679230171], dtype=uint32)
>>> rng.split(3)  # generates a vector of 3 keys
Array([[ 956272045, 3465119146],
       [1903583750,  988321301],
       [3226638877, 2833683589]], dtype=uint32)
>>> rng.normal((5,))
Array([ 0.5694761 , -1.4582146 ,  0.2309113 , -0.03029377,  0.11095619], dtype=float32)
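To make the key splitting that PRNG hides more concrete, here is a minimal sketch of a key-holding wrapper built directly on jax.random.PRNGKey, jax.random.split and jax.random.normal. The class name KeySplitter and its "split, keep the first key as new state" strategy are hypothetical and not taken from inox; the real PRNG may manage its state differently, and the numbers this sketch produces will not necessarily match the PRNG example above.

```python
from typing import Optional

import jax


class KeySplitter:
    """Hypothetical key-holding wrapper; not inox's actual PRNG class."""

    def __init__(self, seed: int):
        # Derive the initial key from an integer seed.
        self.state = jax.random.PRNGKey(seed)

    def split(self, num: Optional[int] = None):
        # Split the current state into (num or 1) + 1 keys: the first becomes
        # the new state, the others are returned to the caller.
        n = 1 if num is None else num
        keys = jax.random.split(self.state, n + 1)
        self.state = keys[0]
        # A single key when num is None, otherwise a stack of num keys.
        return keys[1] if num is None else keys[1:]

    def normal(self, shape):
        # Draw standard-normal samples with a fresh key, so repeated calls
        # return different values.
        return jax.random.normal(self.split(), shape)


rng = KeySplitter(42)
key = rng.split()     # one key
keys = rng.split(3)   # three keys stacked along axis 0
x = rng.normal((5,))  # five standard-normal samples
```

Because the wrapper never reuses a key it has already handed out, callers get fresh randomness on every call without threading keys through their own code, which is the convenience the PRNG description refers to.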
- inox.random.set_rng(rng)
Sets the PRNG within a context. Inside the with block, get_rng() returns rng.
- Parameters:
rng (PRNG) – A PRNG instance.
Example
>>> with set_rng(PRNG(0)):
...     a = get_rng().split()
...     b = get_rng().normal((2, 3))
- inox.random.get_rng()
Returns the context-bound PRNG.
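The context-bound behaviour of set_rng and get_rng can be sketched with Python's contextvars and contextlib modules. This is an assumption about the mechanism, not inox's source: the _RNG variable is hypothetical, and in this sketch get_rng raises LookupError when no generator is bound, which may not match what inox does.

```python
import contextlib
import contextvars
from typing import Any, Iterator

# Hypothetical storage for the active generator; inox may store it differently.
_RNG = contextvars.ContextVar("rng")


@contextlib.contextmanager
def set_rng(rng: Any) -> Iterator[None]:
    # Bind rng for the duration of the with block.
    token = _RNG.set(rng)
    try:
        yield
    finally:
        # Restore whatever was bound before, even if the block raised.
        _RNG.reset(token)


def get_rng() -> Any:
    # Return the innermost bound generator; raises LookupError if none is bound.
    return _RNG.get()
```

Since each set_rng call saves a token and resets it on exit, nested with blocks unwind to the previously bound generator, and contextvars keeps bindings separate across threads and asyncio tasks.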