inox.random

Extended utilities for random number generation

Functions

set_rng

Sets the PRNG within a context.

get_rng

Returns the context-bound PRNG.

Classes

PRNG

Creates a pseudo-random number generator (PRNG).

Descriptions

class inox.random.PRNG(seed, **kwargs)

Creates a pseudo-random number generator (PRNG).

This class is a thin wrapper around the jax.random module. It lets you generate new PRNG keys or sample from distributions without having to split keys by hand with jax.random.split.
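For comparison, this is roughly the manual key handling that the wrapper is meant to hide when calling jax.random directly. The sketch assumes the legacy uint32 key format produced by jax.random.PRNGKey (the same format shown in the doctest output below); the variable names are illustrative only.

```python
import jax

# Without a wrapper, every draw needs a fresh subkey split off by hand,
# and the carried key must be replaced each time so no key is reused.
key = jax.random.PRNGKey(42)

key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (5,))  # 5 standard-normal samples

key, subkey = jax.random.split(key)
y = jax.random.uniform(subkey, (3,))  # 3 samples in [0, 1)
```

Forgetting the `key, subkey = jax.random.split(key)` step before a draw silently reuses randomness, which is the bookkeeping a stateful wrapper removes.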

Parameters:

Example

>>> rng = PRNG(42)
>>> rng.split()  # generates a key
Array([2465931498, 3679230171], dtype=uint32)
>>> rng.split(3)  # generates a vector of 3 keys
Array([[ 956272045, 3465119146],
       [1903583750,  988321301],
       [3226638877, 2833683589]], dtype=uint32)
>>> rng.normal((5,))
Array([ 0.5694761 , -1.4582146 ,  0.2309113 , -0.03029377,  0.11095619], dtype=float32)

split(num=None)
Parameters:

num (int | None) – The number of keys to generate.

Returns:

A single new key if num is None, and a vector of num keys otherwise.

Return type:

Array
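The documented behaviour of split (one key when num is None, a vector of num keys otherwise) can be approximated with a small stateful class. This is a minimal sketch, not inox's actual implementation: the class name SimplePRNG and its normal method are hypothetical, and it assumes the legacy uint32 key format returned by jax.random.PRNGKey.

```python
import jax


class SimplePRNG:
    """Hypothetical stateful key holder; a sketch, not inox.random.PRNG."""

    def __init__(self, seed: int):
        # Legacy uint32[2] key, matching the key arrays in the PRNG doctest.
        self.state = jax.random.PRNGKey(seed)

    def split(self, num=None):
        # Split off one extra key that becomes the new internal state,
        # so keys handed to the caller are never reused internally.
        n = 1 if num is None else num
        keys = jax.random.split(self.state, n + 1)
        self.state = keys[0]
        return keys[1] if num is None else keys[1:]

    def normal(self, shape):
        # Consume a fresh subkey, so repeated calls give different samples.
        return jax.random.normal(self.split(), shape)
```

Usage: `SimplePRNG(42).split(3)` returns an array of shape `(3, 2)`, one uint32 key per row, while `split()` returns a single key of shape `(2,)`.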

inox.random.set_rng(rng)

Sets the PRNG within a context.

See also

PRNG and get_rng

Parameters:

rng (PRNG) – A PRNG instance.

Example

>>> with set_rng(PRNG(0)):
...     a = get_rng().split()
...     b = get_rng().normal((2, 3))

inox.random.get_rng()

Returns the context-bound PRNG.

See also

PRNG and set_rng
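set_rng and get_rng bind a PRNG to a context so that code inside a with block can draw keys without passing a generator around explicitly. One plausible mechanism is sketched here under the assumption that a contextvars.ContextVar holds the current generator; inox may implement this differently. The names set_rng_sketch and get_rng_sketch are hypothetical, and Python's random.Random stands in for a PRNG instance so the example needs only the standard library.

```python
import contextlib
import contextvars
import random

# Hypothetical context variable holding the currently bound generator.
_current_rng = contextvars.ContextVar("current_rng")


@contextlib.contextmanager
def set_rng_sketch(rng):
    # Bind rng for the duration of the with-block, then restore
    # whatever binding (if any) was active before.
    token = _current_rng.set(rng)
    try:
        yield rng
    finally:
        _current_rng.reset(token)


def get_rng_sketch():
    # Raises LookupError if no generator has been bound in this context.
    return _current_rng.get()


with set_rng_sketch(random.Random(0)):
    a = get_rng_sketch().random()
    with set_rng_sketch(random.Random(1)):
        b = get_rng_sketch().random()  # inner binding shadows the outer one
    c = get_rng_sketch().random()  # outer binding is active again
```

Resetting with the token returned by `ContextVar.set` is what makes nested blocks restore the outer generator on exit, even if the inner block raises.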