inox.numpy

Extended NumPy interface

Functions

flatten

Flattens an axis range of an array.

unflatten

Reshapes an axis of an array.

vectorize

Vectorizes a function with broadcasting.

Descriptions

inox.numpy.flatten(x, start=0, stop=None)

Flattens an axis range of an array.

Parameters:
  • x (Array) – An array.

  • start (int) – The start of the axis range to flatten.

  • stop (int | None) – The end of the axis range to flatten (excluded). If None, x.ndim is used instead.

Returns:

The flattened array.

Return type:

Array

Example

>>> x = jax.numpy.zeros((2, 3, 5))
>>> flatten(x, 0, 2).shape
(6, 5)
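As a rough sketch of the shape semantics only, and assuming flatten behaves like a row-major reshape, the axes in the range [start, stop) are merged into one axis whose size is the product of their sizes. The helper flatten_sketch below is hypothetical and written with NumPy; jax.numpy.reshape follows the same rules.

```python
import math

import numpy as np


def flatten_sketch(x, start=0, stop=None):
    # Hypothetical reference, not inox's code: merge axes start..stop-1 into one.
    if stop is None:
        stop = x.ndim
    merged = math.prod(x.shape[start:stop])
    return x.reshape(x.shape[:start] + (merged,) + x.shape[stop:])


x = np.zeros((2, 3, 5))
print(flatten_sketch(x, 0, 2).shape)  # (6, 5)
print(flatten_sketch(x, 1).shape)     # (2, 15)
print(flatten_sketch(x).shape)        # (30,)
```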

inox.numpy.unflatten(x, axis, shape)

Reshapes an axis of an array.

Parameters:
  • x (Array) – An array.

  • axis (int) – The axis to reshape.

  • shape (Sequence[int]) – The shape that replaces the reshaped axis in the output.

Returns:

The array with the reshaped axis.

Return type:

Array

Example

>>> x = jax.numpy.zeros((6, 5))
>>> unflatten(x, 0, (2, 3)).shape
(2, 3, 5)
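Under the same row-major assumption, unflatten can be sketched as a reshape that replaces x.shape[axis] with the dimensions in shape, which undoes a flatten over the same range. The helper unflatten_sketch is hypothetical, uses NumPy, and only handles a non-negative axis.

```python
import numpy as np


def unflatten_sketch(x, axis, shape):
    # Hypothetical reference, not inox's code; assumes axis >= 0.
    # Replace the single dimension x.shape[axis] by the dimensions in `shape`.
    return x.reshape(x.shape[:axis] + tuple(shape) + x.shape[axis + 1:])


x = np.arange(30).reshape(6, 5)
y = unflatten_sketch(x, 0, (2, 3))
print(y.shape)  # (2, 3, 5)
# Merging axes 0 and 1 again recovers the original array.
print(np.array_equal(y.reshape(6, 5), x))  # True
```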

inox.numpy.vectorize(f, ndims, exclude=())

Vectorizes a function with broadcasting.

vectorize is similar to jax.numpy.vectorize, except that it takes the number of core dimensions of each argument instead of a generalized universal function signature (e.g. "(m,n),(n)->(m)") describing their shapes.
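For comparison, the signature-based interface looks like this in NumPy, whose numpy.vectorize accepts the same kind of signature string as jax.numpy.vectorize. In "(m,n),(n)->(m)" the first argument has 2 core dimensions and the second has 1, which is the information that an ndims of (2, 1) carries as plain integers, without naming the dimensions or the output shape.

```python
import numpy as np

# The signature names every core dimension, including those of the output.
mvp_signature = np.vectorize(np.dot, signature="(m,n),(n)->(m)")

A = np.ones((8, 3, 4))  # a batch of 8 matrices, each 3 x 4
x = np.ones((4,))       # a single vector, broadcast against the batch
y = mvp_signature(A, x)
print(y.shape)  # (8, 3)
```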

Parameters:
  • f (Callable) – A function to vectorize.

  • ndims (int | Sequence[int]) – The number of core dimensions expected for each positional argument. An int applies the same number to every argument.

  • exclude (Iterable[int]) – A set of indices (starting at 0) of positional arguments that should not be vectorized.
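The exclude parameter presumably plays a role similar to the excluded argument of numpy.vectorize and jax.numpy.vectorize. The NumPy sketch below uses a made-up function, scale_by, and excludes its second positional argument (index 1), so the list passed there reaches every call unchanged instead of being broadcast against the array.

```python
import numpy as np


def scale_by(v, factor):
    # Hypothetical function: `factor` is a list that must not be broadcast.
    return v * len(factor)


# Index 1 (the second positional argument) is not vectorized.
# Without excluded={1}, NumPy would treat [0, 0] as an array and the call would fail.
f = np.vectorize(scale_by, excluded={1})
print(f(np.array([1, 2, 3]), [0, 0]))  # [2 4 6]
```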

Returns:

The vectorized function.

Example

>>> A, x = jax.numpy.zeros((8, 3, 4)), jax.numpy.zeros(4)
>>> mvp = vectorize(jax.numpy.dot, (2, 1))
>>> mvp(A, x).shape  # broadcasting matrix-vector product
(8, 3)
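To make the broadcasting rule concrete, here is a deliberately naive, loop-based NumPy sketch of vectorizing with ndims. It assumes that the last ndims[i] axes of argument i are its core dimensions and that all leading axes are batch dimensions broadcast together. It is not inox's implementation (a JAX version would more likely use jax.vmap), vectorize_sketch is a hypothetical name, and exclude and keyword arguments are omitted for brevity.

```python
import numpy as np


def vectorize_sketch(f, ndims):
    """Hypothetical loop-based reference; not inox's implementation."""

    def wrapped(*args):
        arrays = [np.asarray(a) for a in args]
        # An int means the same number of core dimensions for every argument.
        nds = [ndims] * len(arrays) if isinstance(ndims, int) else list(ndims)
        # Batch shape of each argument: the axes before its core dimensions.
        batch = np.broadcast_shapes(
            *(a.shape[: a.ndim - n] for a, n in zip(arrays, nds))
        )
        # Broadcast every argument to the common batch shape, keeping its core shape.
        full = [
            np.broadcast_to(a, batch + a.shape[a.ndim - n:])
            for a, n in zip(arrays, nds)
        ]
        # Call f once per batch index on the core slices, then stack the results.
        outputs = [
            np.asarray(f(*(b[idx] for b in full))) for idx in np.ndindex(*batch)
        ]
        stacked = np.stack(outputs)  # assumes at least one batch element
        return stacked.reshape(batch + stacked.shape[1:])

    return wrapped


mvp = vectorize_sketch(np.dot, (2, 1))
print(mvp(np.ones((8, 3, 4)), np.ones(4)).shape)  # (8, 3)
```

Broadcasting applies to every argument: a single 3 x 4 matrix combined with a batch of 5 vectors of shape (5, 4) yields an output of shape (5, 3).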