inox.random

Extended utilities for random number generation.

Classes

PRNG

Creates a pseudo-random number generator (PRNG).

Functions

set_rng

Sets named RNG states within a context.

get_rng

Returns a context-bound RNG state given its name.

Descriptions

class inox.random.PRNG(seed, **kwargs)

Creates a pseudo-random number generator (PRNG).

This class is a thin wrapper around the jax.random module. It lets you generate new PRNG keys or sample from distributions without having to split keys by hand with jax.random.split.
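Without a wrapper like this, every sampling call needs its own subkey obtained from jax.random.split. The sketch below illustrates that bookkeeping with a hypothetical KeyStream class. It is not inox's implementation. It only assumes the split(num=None) semantics documented further down, and it keeps one half of each split as its new internal state so that consecutive calls never reuse a key.

```python
import jax
import jax.numpy as jnp


class KeyStream:
    """Hypothetical stand-in for a PRNG wrapper; NOT inox's actual code."""

    def __init__(self, seed: int):
        # Legacy uint32 key of shape (2,), built from an integer seed.
        self.state = jax.random.PRNGKey(seed)

    def split(self, num=None):
        if num is None:
            # Split in two: keep one key as the new state, hand out the other.
            self.state, key = jax.random.split(self.state)
            return key
        # Split into num + 1 keys: the first becomes the new state.
        keys = jax.random.split(self.state, num + 1)
        self.state = keys[0]
        return keys[1:]

    def normal(self, shape):
        # Each sampling call consumes a fresh subkey, so repeated calls differ.
        return jax.random.normal(self.split(), shape)


rng = KeyStream(42)
a = rng.normal((5,))
b = rng.normal((5,))
keys = rng.split(3)  # shape (3, 2)
```

Because the state advances on every call, a and b are different samples even though the generator was seeded once.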

Parameters:

Example

>>> rng = PRNG(42)
>>> rng.split()  # generates a key
Array([1832780943,  270669613], dtype=uint32)
>>> rng.split(3)  # generates a vector of 3 keys
Array([[3187376881,  129218101],
       [2350016172, 1168365246],
       [ 257214496,  567757975]], dtype=uint32)
>>> rng.normal((5,))
Array([ 0.6611632 , -1.0414096 ,  0.5554834 , -1.8841821 ,  0.36664668],
      dtype=float32)

split(num=None)
Parameters:

num (int) – The number of keys to generate.

Returns:

A single new key if num is None, otherwise a vector of num keys.

Return type:

Array
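A vector of keys, as returned when num is given, pairs naturally with jax.vmap because each row is an independent key. The sketch below uses plain jax.random.split to produce such a vector (standing in for rng.split(3)) and draws one sample per key; the sample shape (4,) is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 3)  # shape (3, 2): one legacy uint32 key per row

# Map jax.random.normal over the leading axis of `keys`;
# the sample shape is fixed inside the lambda, so it stays static.
samples = jax.vmap(lambda k: jax.random.normal(k, (4,)))(keys)
print(samples.shape)  # (3, 4)
```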

inox.random.set_rng(**rngs)

Sets named RNG states within a context.

See also

get_rng

Parameters:

rngs (Dict[str, PRNG]) – Named PRNG instances.

Example

>>> with set_rng(init=PRNG(0), dropout=PRNG(42)):
...     keys = get_rng("init").split(3)
...     mask = get_rng("dropout").bernoulli(shape=(2, 3))

inox.random.get_rng(name)

Returns a context-bound RNG state given its name.

See also

set_rng

Parameters:

name (str) – The RNG state’s name.