inox.numpy
Extended NumPy interface
Functions
Descriptions
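- inox.numpy.flatten(x, start=0, stop=None)
Flattens an axis range of an array.
- Parameters:
  x: The input array.
  start: The first axis of the range to flatten.
  stop: The end of the range to flatten (exclusive). If None, the range extends to the last axis.
- Returns:
The flattened array.
- Return type:
Array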
Example
>>> x = jax.numpy.zeros((2, 3, 5))
>>> flatten(x, 0, 2).shape
(6, 5)
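The semantics of flatten can be pictured as a single reshape that multiplies the sizes of axes start through stop - 1 into one axis. The function below is a hypothetical NumPy sketch (flatten_sketch is not part of inox), assuming that stop is exclusive, as the example suggests, and that stop=None means "through the last axis". The same code should also work on jax.numpy arrays, whose reshape mirrors NumPy's.

```python
import math

import numpy as np


def flatten_sketch(x, start=0, stop=None):
    """Hypothetical illustration of flatten: merge axes [start, stop) into one.

    Assumes stop is exclusive and that stop=None means the last axis.
    """
    if stop is None:
        stop = x.ndim
    # Size of the merged axis: product of the sizes of the flattened axes.
    merged = math.prod(x.shape[start:stop])
    # Keep the axes before and after the range untouched.
    return x.reshape((*x.shape[:start], merged, *x.shape[stop:]))


x = np.zeros((2, 3, 5))
print(flatten_sketch(x, 0, 2).shape)
print(flatten_sketch(x).shape)
```

With x of shape (2, 3, 5), flatten_sketch(x, 0, 2).shape is (6, 5), matching the example, and flatten_sketch(x).shape is (30,). Computing the merged size with math.prod, rather than passing -1 to reshape, avoids relying on reshape to infer a dimension, which NumPy rejects when another axis has size zero.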
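- inox.numpy.unflatten(x, axis, shape)
Reshapes an axis of an array.
- Parameters:
  x: The input array.
  axis: The axis to reshape.
  shape: The shape that replaces the axis, e.g. (2, 3) to split an axis of size 6.
- Returns:
The array with the reshaped axis.
- Return type:
Array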
Example
>>> x = jax.numpy.zeros((6, 5))
>>> unflatten(x, 0, (2, 3)).shape
(2, 3, 5)
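unflatten can likewise be expressed as one reshape in which the chosen axis is replaced by the dimensions in shape. The function below is a hypothetical NumPy sketch (unflatten_sketch is not an inox function); its bounds check and size check are assumptions about how invalid input might be reported, not documented inox behavior.

```python
import math

import numpy as np


def unflatten_sketch(x, axis, shape):
    """Hypothetical illustration of unflatten: split one axis into `shape`."""
    # Accept negative axes, as NumPy does, but reject out-of-range ones.
    if not -x.ndim <= axis < x.ndim:
        raise ValueError(f"axis {axis} is out of bounds for a {x.ndim}-d array")
    axis %= x.ndim
    # In this sketch, the new dimensions must multiply to the old axis size.
    if math.prod(shape) != x.shape[axis]:
        raise ValueError(
            f"cannot reshape an axis of size {x.shape[axis]} into {tuple(shape)}"
        )
    return x.reshape((*x.shape[:axis], *shape, *x.shape[axis + 1:]))


x = np.zeros((6, 5))
print(unflatten_sketch(x, 0, (2, 3)).shape)
```

This prints (2, 3, 5), matching the example. Unflattening an axis into the sizes it was flattened from restores the original shape, so the two operations are inverses for matching arguments.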
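- inox.numpy.vectorize(f, ndims, exclude=())
Vectorizes a function with broadcasting.
vectorize is similar to jax.numpy.vectorize, except that it takes the number of core dimensions of each argument as signature instead of their shape.
- Parameters:
  f: The function to vectorize.
  ndims: The number of core dimensions of each argument.
  exclude: The arguments to exclude from vectorization.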
- Returns:
The vectorized function.
Example
>>> mvp = vectorize(jax.numpy.dot, (2, 1))
>>> mvp(A, x)  # broadcasting matrix-vector product
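In the matrix-vector example, ndims=(2, 1) conveys what a jax.numpy.vectorize signature such as "(m,n),(n)->(m)" would: the last two axes of A and the last axis of x are core dimensions, and all leading axes are batch dimensions that broadcast against each other. The function below is a hypothetical, loop-based NumPy sketch of those semantics (vectorize_sketch is not inox's implementation and ignores the exclude parameter). jax.numpy.vectorize itself is built on jax.vmap and avoids Python loops.

```python
import numpy as np


def vectorize_sketch(f, ndims):
    """Hypothetical illustration: broadcast f over leading (batch) axes.

    ndims[i] is the number of trailing core dimensions of argument i.
    """

    def wrapped(*args):
        args = [np.asarray(a) for a in args]
        if len(args) != len(ndims):
            raise TypeError("expected one ndims entry per argument")
        # Split each shape into leading batch axes and trailing core axes.
        batch_shapes = []
        for a, n in zip(args, ndims):
            if n > a.ndim:
                raise ValueError(f"argument has {a.ndim} dims, expected at least {n}")
            batch_shapes.append(a.shape[: a.ndim - n])
        batch = np.broadcast_shapes(*batch_shapes)
        # Broadcast every argument to the common batch shape, keeping its core axes.
        args = [
            np.broadcast_to(a, batch + a.shape[a.ndim - n :])
            for a, n in zip(args, ndims)
        ]
        # Call f once per batch index, then stack and restore the batch shape.
        # (Zero-size batch shapes are not handled in this sketch.)
        outs = [f(*(a[idx] for a in args)) for idx in np.ndindex(*batch)]
        out = np.stack(outs)
        return out.reshape(batch + out.shape[1:])

    return wrapped


mvp = vectorize_sketch(np.dot, (2, 1))
A = np.arange(24.0).reshape(4, 3, 2)
x = np.ones(2)
print(mvp(A, x).shape)
```

Here A has shape (4, 3, 2) and x has shape (2,), so x is broadcast against the single batch axis of A and the result has shape (4, 3).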