inox.numpy

Extended NumPy interface.

Functions

flatten

Flattens an axis range of an array.

unflatten

Reshapes an axis of an array.

vectorize

Vectorizes a function with broadcasting.

Descriptions

inox.numpy.flatten(x, start=0, stop=None)

Flattens an axis range of an array.

Parameters:
  • x (Array) – An array.

  • start (int) – The start of the axis range to flatten.

  • stop (int) – The end of the axis range to flatten (excluded). If None, x.ndim is used instead.

Returns:

The flattened array.

Return type:

Array

Example

>>> import jax
>>> from inox.numpy import flatten
>>> x = jax.numpy.zeros((2, 3, 5))
>>> flatten(x, 0, 2).shape
(6, 5)
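Flattening an axis range amounts to a single reshape in which the sizes of axes start to stop - 1 are multiplied together. The sketch below illustrates this with plain jax.numpy, assuming the documented defaults (start=0, stop=None meaning x.ndim). flatten_sketch is a hypothetical helper, not inox's actual implementation.

```python
import math

import jax.numpy as jnp


def flatten_sketch(x, start=0, stop=None):
    """Merges axes start..stop-1 of x into one axis (hypothetical helper)."""
    if stop is None:
        stop = x.ndim  # same default as documented for inox.numpy.flatten
    # Keep the axes before and after the range; replace the range by its total size.
    merged = math.prod(x.shape[start:stop])
    new_shape = x.shape[:start] + (merged,) + x.shape[stop:]
    return jnp.reshape(x, new_shape)


x = jnp.zeros((2, 3, 5))
print(flatten_sketch(x, 0, 2).shape)  # (6, 5)
print(flatten_sketch(x, 1).shape)  # (2, 15)
```

Because reshape is row-major, the flattened axis enumerates the merged axes with the last one varying fastest.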

inox.numpy.unflatten(x, axis, shape)

Reshapes an axis of an array.

Parameters:
  • x (Array) – An array.

  • axis (int) – The axis to reshape.

  • shape (Sequence[int]) – The shape of the reshaped axis.

Returns:

The array with the reshaped axis.

Return type:

Array

Example

>>> import jax
>>> from inox.numpy import unflatten
>>> x = jax.numpy.zeros((6, 5))
>>> unflatten(x, 0, (2, 3)).shape
(2, 3, 5)
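Unflattening is the inverse operation: the chosen axis is replaced by the axes of shape, whose sizes must multiply to the original axis size. The sketch below is a hypothetical helper (unflatten_sketch) written with plain jax.numpy; the handling of negative axes is an assumption, not something the documentation above states.

```python
import math

import jax.numpy as jnp


def unflatten_sketch(x, axis, shape):
    """Splits one axis of x into the given shape (hypothetical helper)."""
    axis = axis % x.ndim  # assumption: negative axes count from the end
    shape = tuple(shape)
    # The new axes must hold exactly as many elements as the original axis.
    if math.prod(shape) != x.shape[axis]:
        raise ValueError(f"cannot reshape axis of size {x.shape[axis]} into {shape}")
    new_shape = x.shape[:axis] + shape + x.shape[axis + 1:]
    return jnp.reshape(x, new_shape)


x = jnp.zeros((6, 5))
print(unflatten_sketch(x, 0, (2, 3)).shape)  # (2, 3, 5)

# Unflattening undoes a flatten of the same axes.
y = jnp.arange(30).reshape(2, 3, 5)
z = jnp.reshape(y, (6, 5))
print(bool(jnp.array_equal(unflatten_sketch(z, 0, (2, 3)), y)))  # True
```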

inox.numpy.vectorize(f, ndims, exclude=())

Vectorizes a function with broadcasting.

vectorize is similar to jax.numpy.vectorize, except that its signature is the number of core dimensions of each positional argument rather than a string describing their core shapes.
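For comparison, the broadcasting matrix-vector product from the example below would be written with jax.numpy.vectorize by spelling out the core shapes in a signature string. The ndims tuple (2, 1) only records how many core dimensions each argument has (2 for the matrix, 1 for the vector).

```python
import jax
import jax.numpy as jnp

# jax.numpy.vectorize describes core dimensions with a shape signature string:
# dot maps an (m, n) matrix and an (n,) vector to an (m,) vector.
mvp = jnp.vectorize(jnp.dot, signature="(m,n),(n)->(m)")

A = jnp.ones((5, 3))
x = jax.random.normal(jax.random.key(0), (16, 3))
y = mvp(A, x)
print(y.shape)  # (16, 5)
```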

Parameters:
  • f (Callable) – A function to vectorize.

  • ndims (int | Sequence[int]) – The number of dimensions expected for each positional argument.

  • exclude (Iterable[int]) – A set of non-negative indices of the positional arguments that should not be vectorized.

Returns:

The vectorized function.

Return type:

Callable

Example

>>> import jax
>>> from inox.numpy import vectorize
>>> mvp = vectorize(jax.numpy.dot, (2, 1))  # broadcasting matrix-vector product
>>> A = jax.numpy.ones((5, 3))
>>> x = jax.random.normal(jax.random.key(0), (16, 3))
>>> y = mvp(A, x)
>>> y.shape
(16, 5)
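One way to obtain this behavior is to treat every dimension in front of an argument's core dimensions as a batch dimension, broadcast the batch shapes of all vectorized arguments together, and apply one jax.vmap per batch dimension. The sketch below follows that approach; vectorize_sketch is a hypothetical helper and not necessarily how inox implements vectorize. It also assumes that an int ndims applies to every positional argument and that excluded arguments are passed through unchanged.

```python
import jax
import jax.numpy as jnp


def vectorize_sketch(f, ndims, exclude=()):
    """Broadcasting vectorization built from nested jax.vmap (hypothetical helper)."""
    exclude = set(exclude)

    def wrapped(*args):
        # Assumption: an int ndims applies to every positional argument.
        nd = [ndims] * len(args) if isinstance(ndims, int) else list(ndims)
        mapped = [i for i in range(len(args)) if i not in exclude]

        # Batch shape = the leading dims in front of each argument's core dims.
        batch = jnp.broadcast_shapes(
            (), *(jnp.shape(args[i])[: jnp.ndim(args[i]) - nd[i]] for i in mapped)
        )

        # Broadcast every vectorized argument to the common batch shape.
        args = list(args)
        for i in mapped:
            core = jnp.shape(args[i])[jnp.ndim(args[i]) - nd[i]:]
            args[i] = jnp.broadcast_to(args[i], tuple(batch) + core)

        # One vmap per batch dimension; excluded arguments are never mapped.
        in_axes = tuple(None if i in exclude else 0 for i in range(len(args)))
        g = f
        for _ in batch:
            g = jax.vmap(g, in_axes=in_axes)
        return g(*args)

    return wrapped


mvp = vectorize_sketch(jnp.dot, (2, 1))
A = jnp.ones((5, 3))
x = jax.random.normal(jax.random.key(0), (16, 3))
print(mvp(A, x).shape)  # (16, 5)

# Batch dimensions of both arguments broadcast against each other.
B = jnp.ones((4, 1, 5, 3))
print(mvp(B, x).shape)  # (4, 16, 5)
```

Nesting vmap once per batch dimension keeps f itself unaware of batching: it always receives arguments with exactly their core dimensions.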