inox.numpy
Extended NumPy interface.
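Functions
- flatten: Flattens an axis range of an array.
- unflatten: Reshapes an axis of an array.
- vectorize: Vectorizes a function with broadcasting.
Descriptions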
- inox.numpy.flatten(x, start=0, stop=None)
Flattens an axis range of an array.
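- Parameters:
x (Array) – The array to flatten.
start (int) – The first axis of the range to flatten.
stop (int | None) – The end of the axis range to flatten (exclusive).
- Returns:
The flattened array.
- Return type:
Array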
Example
>>> x = jax.numpy.zeros((2, 3, 5))
>>> flatten(x, 0, 2).shape
(6, 5)
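The example merges axes 0 and 1 (stop is exclusive), which amounts to a single reshape. The sketch below is a minimal illustration of that idea, not inox's own implementation; flatten_sketch is a hypothetical name, and treating stop=None as "through the last axis" is an assumption.

```python
import jax.numpy as jnp


def flatten_sketch(x, start=0, stop=None):
    """Hypothetical sketch: merge axes start..stop-1 of x into a single axis."""
    # Assumption: stop=None means the range extends through the last axis.
    if stop is None:
        stop = x.ndim
    # Keep the axes before and after the range; -1 absorbs the merged range.
    return x.reshape(*x.shape[:start], -1, *x.shape[stop:])


x = jnp.zeros((2, 3, 5))
print(flatten_sketch(x, 0, 2).shape)  # (6, 5)
print(flatten_sketch(x, 1).shape)     # (2, 15)
```

Letting reshape infer the merged size with -1 avoids computing the product of the flattened axes by hand.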
- inox.numpy.unflatten(x, axis, shape)
Reshapes an axis of an array.
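- Parameters:
x (Array) – The array to reshape.
axis (int) – The axis to reshape.
shape (Sequence[int]) – The shape that replaces the axis. Its product must equal the size of the axis.
- Returns:
The array with the reshaped axis.
- Return type:
Array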
Example
>>> x = jax.numpy.zeros((6, 5))
>>> unflatten(x, 0, (2, 3)).shape
(2, 3, 5)
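Splitting one axis into several is also a single reshape that keeps the neighbouring axes intact. The sketch below assumes exactly that; unflatten_sketch is a hypothetical name and may differ from how inox implements unflatten (for instance in how it validates its arguments).

```python
import jax.numpy as jnp


def unflatten_sketch(x, axis, shape):
    """Hypothetical sketch: replace axis `axis` of x by the axes in `shape`."""
    # Normalise a negative axis so the slices below select the right neighbours.
    axis = axis % x.ndim
    # reshape raises an error if the product of `shape` differs from x.shape[axis].
    return x.reshape(*x.shape[:axis], *shape, *x.shape[axis + 1:])


x = jnp.zeros((6, 5))
print(unflatten_sketch(x, 0, (2, 3)).shape)   # (2, 3, 5)
print(unflatten_sketch(x, -1, (5, 1)).shape)  # (6, 5, 1)
```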
- inox.numpy.vectorize(f, ndims, exclude=())
Vectorizes a function with broadcasting.
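vectorize is similar to jax.numpy.vectorize, except that its signature gives the number of core dimensions of each argument instead of their shapes.
- Parameters:
f (Callable) – The function to vectorize.
ndims (Sequence[int]) – The number of core dimensions of each argument.
exclude – The positional arguments to exclude from vectorization.
- Returns:
The vectorized function.
- Return type:
Callable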
Example
>>> mvp = vectorize(jax.numpy.dot, (2, 1))  # broadcasting matrix-vector product
>>> A = jax.numpy.ones((5, 3))
>>> x = jax.random.normal(jax.random.key(0), (16, 3))
>>> y = mvp(A, x)
>>> y.shape
(16, 5)
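To illustrate what declaring core dimensions by count means, here is a minimal sketch: the trailing ndims[i] axes of argument i are its core, the remaining leading axes are broadcast against each other, and f is mapped over them with jax.vmap. It is not inox's implementation. vectorize_sketch is a hypothetical name, the exclude option is omitted, and f is assumed to return a single array.

```python
import math

import jax
import jax.numpy as jnp


def vectorize_sketch(f, ndims):
    """Hypothetical sketch: ndims[i] is the number of trailing core dims of argument i."""

    def wrapped(*args):
        args = [jnp.asarray(a) for a in args]
        # Split each shape into leading batch dims and trailing core dims.
        cores = [a.shape[a.ndim - n:] for a, n in zip(args, ndims)]
        batches = [a.shape[: a.ndim - n] for a, n in zip(args, ndims)]
        # Broadcast the batch dims against each other (NumPy broadcasting rules).
        batch = jnp.broadcast_shapes(*batches)
        size = math.prod(batch)
        # Give every argument the common batch shape, then merge it into one axis.
        flat = [
            jnp.broadcast_to(a, batch + core).reshape((size,) + core)
            for a, core in zip(args, cores)
        ]
        # Map f over the merged batch axis; assumes f returns a single array.
        out = jax.vmap(f)(*flat)
        # Split the merged axis back into the broadcast batch shape.
        return out.reshape(batch + out.shape[1:])

    return wrapped


mvp = vectorize_sketch(jnp.dot, (2, 1))  # matrix-vector product, as in the example
A = jnp.ones((5, 3))
x = jnp.arange(48.0).reshape(16, 3)
print(mvp(A, x).shape)  # (16, 5)
```

Broadcasting explicitly with broadcast_to keeps the sketch short but may copy data; a more careful implementation could instead vmap only over the axes that actually vary.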