inox.api

Extended user-facing transformations and utilities

The transformations provided in the inox.api module are lifted versions of native JAX transformations for which all non-array leaves (float, str, functions, …) are considered static, that is, part of the tree structure.

Functions

automask

Lifts a transformation to consider all non-array leaves as static.

Descriptions

inox.api.automask(transform)

Lifts a transformation to consider all non-array leaves as static.

For a function f and a JAX transformation jax.tf,

y = automask(jax.tf)(f)(x)

is equivalent to

g = lambda x: inox.tree_mask(f(inox.tree_unmask(x)))
y = inox.tree_unmask(jax.tf(g)(inox.tree_mask(x)))

Parameters:

transform (Callable) – The transformation to lift.

Returns:

The lifted transformation.

Return type:

Callable
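
To make the idea of static non-array leaves concrete, the following sketch emulates it for jax.jit in pure JAX. The helper jit_static_leaves and the sentinel _ARRAY are hypothetical names introduced for illustration. This is not how inox implements automask; it only shows one way such a lifting could look for a function that takes a single pytree argument and returns arrays.

```python
from functools import partial

import jax
import jax.numpy as jnp

_ARRAY = object()  # hypothetical sentinel marking where an array leaf was removed


def jit_static_leaves(f):
    """Hypothetical helper: jit `f`, treating non-array leaves of its argument as static."""

    @partial(jax.jit, static_argnums=(0, 1))
    def inner(treedef, static, arrays):
        # Re-insert the traced arrays at the positions marked by the sentinel.
        it = iter(arrays)
        leaves = [next(it) if s is _ARRAY else s for s in static]
        return f(jax.tree_util.tree_unflatten(treedef, leaves))

    def wrapped(x):
        leaves, treedef = jax.tree_util.tree_flatten(x)
        # Non-array leaves (str, float, ...) stay in the hashable `static` tuple.
        static = tuple(_ARRAY if isinstance(l, jax.Array) else l for l in leaves)
        arrays = [l for l in leaves if isinstance(l, jax.Array)]
        return inner(treedef, static, arrays)

    return wrapped


def apply(layer):
    # Python control flow on a str leaf: only possible if that leaf is static.
    w = layer["weight"]
    out = jnp.maximum(w, 0.0) if layer["activation"] == "relu" else w
    return layer["scale"] * out


layer = {"weight": jnp.array([-1.0, 2.0, 3.0]), "activation": "relu", "scale": 2.0}
y = jit_static_leaves(apply)(layer)
print(y)
```

Passing the same dictionary to a function wrapped with plain jax.jit would fail, because a str is not a valid JAX type for a traced argument.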

inox.api.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)

Sets up fun for just-in-time compilation with XLA.

Parameters:
  • fun (Callable) –

    Function to be jitted. fun should be a pure function, as side-effects may only be executed once.

    The arguments and return value of fun should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.

    JAX keeps a weak reference to fun for use as a compilation cache key, so the object fun must be weakly-referenceable. Most Callable objects will already satisfy this requirement.

  • in_shardings

    Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.

    The in_shardings argument is optional. JAX will infer the shardings from the input jax.Arrays and default to replicating the input if the sharding cannot be inferred.

    The valid resource assignment specifications are:
    • XLACompatibleSharding, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.

    • None, which gives JAX the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.

    The size of every dimension has to be a multiple of the total number of resources assigned to it. This is similar to pjit’s in_shardings.

  • out_shardings

    Like in_shardings, but specifies resource assignment for function outputs. This is similar to pjit’s out_shardings.

    The out_shardings argument is optional. If not specified, jax.jit will use GSPMD’s sharding propagation to figure out what the sharding of the output(s) should be.

  • static_argnums (int | Sequence[int] | None) –

    An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.

    Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.

    If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.

  • static_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • donate_argnums (int | Sequence[int] | None) –

    Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated.

    If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.

    For more details on buffer donation see the FAQ.

  • donate_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • keep_unused (bool) – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.

  • device (xc.Device | None) – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

  • backend (str | None) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • inline (bool) – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.

Returns:

A wrapped version of fun, set up for just-in-time compilation.

Return type:

pjit.JitWrapped

Examples

In the following example, selu can be compiled into a single fused kernel by XLA:

>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))  
[-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232  0.76827  0.59566 ]

To pass arguments such as static_argnames when decorating a function, a common pattern is to use functools.partial:

>>> from functools import partial
>>> import jax.numpy as jnp
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
...   for i in range(n):
...     x = x ** 2
...   return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)
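
The recompilation behavior of static arguments described above can be observed with a Python side effect, which only runs while the function is being traced. The following is a minimal sketch; the function power and the list traces are names chosen for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

traces = []  # records each time Python tracing actually runs


@partial(jax.jit, static_argnums=1)
def power(x, n):
    traces.append(n)  # side effect: executed at trace time only
    for _ in range(n):
        x = x * x
    return x


x = jnp.arange(4.0)
power(x, 2)  # first call: traced and compiled for n=2
power(x, 2)  # same static value and input shape: cache hit, no retrace
power(x, 3)  # new static value: traced and compiled again
print(traces)
```

Only two entries are recorded for three calls, since the second call reuses the executable compiled for n=2.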

inox.api.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

Creates a function that evaluates the gradient of fun.

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

  • reduce_axes (Sequence[AxisName]) – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while grad(f) will create one that computes the per-example gradient.

Returns:

A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.

Return type:

Callable

For example:

>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
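
The argnums and has_aux parameters can be combined. The sketch below uses jax.grad directly and a hypothetical loss function; it assumes that the lifted inox.api.grad behaves the same way when all inputs are arrays or scalars.

```python
import jax
import jax.numpy as jnp


def loss(w, b, x, y):
    pred = w * x + b
    # Return (scalar loss, auxiliary data); only the loss is differentiated.
    return jnp.mean((pred - y) ** 2), pred


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

# Differentiate with respect to w (position 0) and b (position 1).
(dw, db), pred = jax.grad(loss, argnums=(0, 1), has_aux=True)(2.0, 0.5, x, y)
print(dw, db, pred)
```

Because argnums is a tuple, the gradient is a tuple with one entry per selected argument, and the auxiliary output is returned next to it.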

inox.api.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())

Create a function that evaluates both fun and the gradient of fun.

Parameters:
  • fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

  • reduce_axes (Sequence[AxisName]) – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, value_and_grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while value_and_grad(f) will create one that computes the per-example gradient.

Returns:

A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.

Return type:

Callable[…, tuple[Any, Any]]
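
As a minimal sketch of the two return layouts described above, the following uses jax.value_and_grad with and without has_aux. The functions f and f_aux are hypothetical examples.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


def f_aux(x):
    # Same scalar output, plus auxiliary data that is not differentiated.
    return jnp.sum(x ** 2), {"norm": jnp.linalg.norm(x)}


x = jnp.array([1.0, 2.0, 3.0])

value, grad = jax.value_and_grad(f)(x)
(value_aux, aux), grad_aux = jax.value_and_grad(f_aux, has_aux=True)(x)
print(value, grad, aux)
```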

inox.api.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)

Jacobian of fun evaluated column-by-column using forward-mode AD.

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed.

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns:

A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.

Return type:

Callable

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]

inox.api.jacrev(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)

Jacobian of fun evaluated row-by-row using reverse-mode AD.

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed.

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

  • allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.

Returns:

A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.

Return type:

Callable

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
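
Forward and reverse mode compute the same Jacobian and differ only in cost: forward mode needs one pass per input dimension (column), reverse mode one pass per output dimension (row). A common rule of thumb is therefore to prefer jacfwd for "tall" Jacobians and jacrev for "wide" ones. The sketch below checks the agreement on the function used above.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.asarray(
        [x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])


x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(f)(x)  # built column by column (3 forward passes)
J_rev = jax.jacrev(f)(x)  # built row by row (4 backward passes)
print(J_fwd.shape, bool(jnp.allclose(J_fwd, J_rev)))
```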

inox.api.hessian(fun, argnums=0, has_aux=False, holomorphic=False)

Hessian of fun as a dense array.

Parameters:
  • fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns:

A function with the same arguments as fun, that evaluates the Hessian of fun.

Return type:

Callable

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]

hessian is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.

In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree.
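
When a single dense Hessian matrix is preferred over the block structure, the input pytree can be flattened with jax.flatten_util.ravel_pytree. The sketch below uses a hypothetical function f over a dictionary of parameters.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(p):
    return jnp.sum(p["a"] ** 2) * p["b"]


params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

# `flat` is a 1D vector of all leaves; `unravel` rebuilds the original pytree.
flat, unravel = ravel_pytree(params)

# Differentiate with respect to the flat vector instead of the pytree.
H = jax.hessian(lambda v: f(unravel(v)))(flat)
print(H)
```

Here flat orders the leaves as (a[0], a[1], b), so H is the 3 x 3 matrix of second derivatives in that order.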

inox.api.checkpoint(fun, *, prevent_cse=True, policy=None, static_argnums=())

Make fun recompute internal linearization points when differentiated.

The jax.checkpoint decorator, aliased to jax.remat, provides a way to trade off computation time and memory cost in the context of automatic differentiation, especially with reverse-mode autodiff like jax.grad and jax.vjp but also with jax.linearize.

When differentiating a function in reverse-mode, by default all the linearization points (e.g. inputs to elementwise nonlinear primitive operations) are stored when evaluating the forward pass so that they can be reused on the backward pass. This evaluation strategy can lead to a high memory cost, or even to poor performance on hardware accelerators where memory access is much more expensive than FLOPs.

An alternative evaluation strategy is for some of the linearization points to be recomputed (i.e. rematerialized) rather than stored. This approach can reduce memory usage at the cost of increased computation.

This function decorator produces a new version of fun which follows the rematerialization strategy rather than the default store-everything strategy. That is, it returns a new version of fun which, when differentiated, doesn’t store any of its intermediate linearization points. Instead, these linearization points are recomputed from the function’s saved inputs.

See the examples below.

Parameters:
  • fun (Callable) – Function for which the autodiff evaluation strategy is to be changed from the default of storing all intermediate linearization points to recomputing them. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.

  • prevent_cse (bool) – Optional, boolean keyword-only argument indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit or pmap, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan, this CSE prevention mechanism is unnecessary, in which case prevent_cse can be set to False.

  • static_argnums (int | tuple[int, ...]) – Optional, int or sequence of ints, a keyword-only argument indicating the argument values on which to specialize for tracing and caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overheads. See the example below.

  • policy (Callable[..., bool] | None) – Optional, callable keyword-only argument. It should be one of the attributes of jax.checkpoint_policies. The callable takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) can be saved as residuals (or instead must be recomputed in the (co)tangent computation if needed).

Returns:

A function (callable) with the same input/output behavior as fun but which, when differentiated using e.g. jax.grad, jax.vjp, or jax.linearize, recomputes rather than stores intermediate linearization points, thus potentially saving memory at the cost of extra computation.

Return type:

Callable

Here is a simple example:

>>> import jax
>>> import jax.numpy as jnp
>>> @jax.checkpoint
... def g(x):
...   y = jnp.sin(x)
...   z = jnp.sin(y)
...   return z
...
>>> jax.value_and_grad(g)(2.0)
(Array(0.78907233, dtype=float32, weak_type=True), Array(-0.2556391, dtype=float32, weak_type=True))

Here, the same value is produced whether or not the jax.checkpoint decorator is present. When the decorator is not present, the values jnp.cos(2.0) and jnp.cos(jnp.sin(2.0)) are computed on the forward pass and are stored for use in the backward pass, because they are needed on the backward pass and depend only on the primal inputs. When using jax.checkpoint, the forward pass will compute only the primal outputs and only the primal inputs (2.0) will be stored for the backward pass. At that time, the value jnp.sin(2.0) is recomputed, along with the values jnp.cos(2.0) and jnp.cos(jnp.sin(2.0)).
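
The claim that the decorator changes only what is stored, not what is computed, can be checked directly. This is a minimal sketch comparing the function above with and without jax.checkpoint.

```python
import jax
import jax.numpy as jnp


def g(x):
    return jnp.sin(jnp.sin(x))


g_remat = jax.checkpoint(g)  # recomputes intermediates on the backward pass

value, grad = jax.value_and_grad(g)(2.0)
value_remat, grad_remat = jax.value_and_grad(g_remat)(2.0)
print(value, grad)
```

The gradient equals cos(sin(2.0)) * cos(2.0) in both cases; only the memory and compute trade-off differs.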

While jax.checkpoint controls what values are stored from the forward-pass to be used on the backward pass, the total amount of memory required to evaluate a function or its VJP depends on many additional internal details of that function. Those details include which numerical primitives are used, how they’re composed, where jit and control flow primitives like scan are used, and other factors.

The jax.checkpoint decorator can be applied recursively to express sophisticated autodiff rematerialization strategies. For example:

>>> def recursive_checkpoint(funs):
...   if len(funs) == 1:
...     return funs[0]
...   elif len(funs) == 2:
...     f1, f2 = funs
...     return lambda x: f1(f2(x))
...   else:
...     f1 = recursive_checkpoint(funs[:len(funs)//2])
...     f2 = recursive_checkpoint(funs[len(funs)//2:])
...     return lambda x: f1(jax.checkpoint(f2)(x))
...
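
A possible use of recursive_checkpoint is sketched below. The definition is repeated so that the snippet is self-contained, and the chain of eight sine functions is a hypothetical workload chosen for illustration.

```python
import jax
import jax.numpy as jnp


def recursive_checkpoint(funs):
    if len(funs) == 1:
        return funs[0]
    elif len(funs) == 2:
        f1, f2 = funs
        return lambda x: f1(f2(x))
    else:
        f1 = recursive_checkpoint(funs[:len(funs) // 2])
        f2 = recursive_checkpoint(funs[len(funs) // 2:])
        return lambda x: f1(jax.checkpoint(f2)(x))


funs = [jnp.sin] * 8
chain = recursive_checkpoint(funs)


def plain(x):
    # Reference: the same composition without any rematerialization.
    for f in funs:
        x = f(x)
    return x


grad_chain = jax.grad(chain)(1.0)
grad_plain = jax.grad(plain)(1.0)
print(grad_chain, grad_plain)
```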

If fun involves Python control flow that depends on argument values, it may be necessary to use the static_argnums parameter. For example, consider a boolean flag argument:

from functools import partial

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, is_training):
  if is_training:
    ...
  else:
    ...

Here, the use of static_argnums allows the if statement’s condition to depend on the value of is_training. The cost of using static_argnums is that it introduces re-tracing overheads across calls: in the example, foo is re-traced every time it is called with a new value of is_training. In some situations, jax.ensure_compile_time_eval is needed as well:

@partial(jax.checkpoint, static_argnums=(1,))
def foo(x, y):
  with jax.ensure_compile_time_eval():
    y_pos = y > 0
  if y_pos:
    ...
  else:
    ...

As an alternative to using static_argnums (and jax.ensure_compile_time_eval), it may be easier to compute some values outside the jax.checkpoint-decorated function and then close over them.
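
The closure-based alternative could look like the sketch below. The factory make_foo and the two branches are hypothetical; the point is only that is_training is an ordinary Python value captured by the closure, so no static_argnums is needed.

```python
import jax
import jax.numpy as jnp


def make_foo(is_training):
    # `is_training` is resolved in Python before tracing ever sees it.
    @jax.checkpoint
    def foo(x):
        if is_training:
            return 2.0 * jnp.sin(x)
        else:
            return jnp.sin(x)

    return foo


foo_train = make_foo(True)
foo_eval = make_foo(False)

grad_train = jax.grad(foo_train)(0.0)
grad_eval = jax.grad(foo_eval)(0.0)
print(grad_train, grad_eval)
```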

inox.api.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)

Vectorizing map. Creates a function which maps fun over argument axes.

Parameters:
  • fun (F) – Function to be mapped over additional axes.

  • in_axes (int | None | Sequence[Any]) –

    An integer, None, or sequence of values specifying which input array axes to map over.

    If each positional argument to fun is an array, then in_axes can be an integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range [-ndim, ndim) for each array, where ndim is the number of dimensions (axes) of the corresponding input array.

    If the positional arguments to fun are container (pytree) types, in_axes must be a sequence with length equal to the number of positional arguments to fun, and for each argument the corresponding element of in_axes can be a container with a matching pytree structure specifying the mapping of its container elements. In other words, in_axes must be a container tree prefix of the positional argument tuple passed to fun. See this link for more detail: https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

    Either axis_size must be provided explicitly, or at least one positional argument must have in_axes not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.

    Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0).

    See below for examples.

  • out_axes (Any) – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification. Axis integers must be in the range [-ndim, ndim) for each output array, where ndim is the number of dimensions (axes) of the array returned by the vmap-ed function, which is one more than the number of dimensions (axes) of the corresponding array returned by fun.

  • axis_name (AxisName | None) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.

  • axis_size (int | None) – Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments.

Returns:

Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.

Return type:

F

For example, we can implement a matrix-matrix product using a vector dot product:

>>> import jax.numpy as jnp
>>> from jax import vmap
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:

>>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
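
As a quick sanity check of the shape annotations, the nested vmap construction of mm can be compared against the built-in matrix product. The input values are arbitrary and chosen for illustration.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)  # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)       # ([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 1)       # ([b,a], [a,c]) -> [b,c]

x = jnp.arange(6.0).reshape(2, 3)   # shape [b,a] = (2, 3)
y = jnp.arange(12.0).reshape(3, 4)  # shape [a,c] = (3, 4)

out = mm(x, y)
print(out.shape)
```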

Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:

>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)

Here’s another example using container types in in_axes, this time a dictionary, to specify the elements of the container to map over:

>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...  return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]

The results of a vectorized function can be mapped or unmapped. For example, the function below returns a pair with the first element mapped and the second unmapped. Only for unmapped results can we specify out_axes to be None (to keep it unmapped).

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), 8.0)

If out_axes is specified for an unmapped result, the result is broadcast across the mapped axis:

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))

If out_axes is specified for a mapped result, the result is transposed accordingly.
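
A minimal sketch of this transposition: the input is mapped over its leading axis, while out_axes=1 places the mapped axis second in the output.

```python
import jax.numpy as jnp
from jax import vmap

xs = jnp.arange(12.0).reshape(3, 4)  # mapped axis 0 has size 3

# Each call sees a row of shape (4,); the 3 results are stacked along axis 1.
out = vmap(lambda x: 2.0 * x, in_axes=0, out_axes=1)(xs)
print(out.shape)
```

The result has the same values as (2.0 * xs).T.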

Finally, here’s an example using axis_name together with collectives:

>>> from jax import lax
>>> xs = jnp.arange(3. * 4.).reshape(3, 4)
>>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
 [12. 15. 18. 21.]
 [12. 15. 18. 21.]]

See the jax.pmap docstring for more examples involving collectives.

inox.api.pmap(fun, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None)

Parallel map with support for collective operations.

The purpose of pmap is to express single-program multiple-data (SPMD) programs. Applying pmap to a function will compile the function with XLA (similarly to jit), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap because both transformations map a function over array axes, but where vmap vectorizes functions by pushing the mapped axis down into primitive operations, pmap instead replicates the function and executes each replica on its own XLA device in parallel.

The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count (unless devices is specified, see below). For nested pmap calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.
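
The constraint on the mapped axis size can be respected by deriving it from jax.local_device_count(). The following is a minimal sketch using jax.pmap directly; it also runs on a single device, in which case the collective sums over one row only.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # the mapped axis size must not exceed this
xs = jnp.arange(n * 4.0).reshape(n, 4)

# Each device receives one row; psum adds the rows across the mapped axis 'i'.
total = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
print(total.shape)
```

Every row of the result holds the same column sums, one copy per participating device.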

Note

pmap compiles fun, so while it can be combined with jit, it’s usually unnecessary.

pmap requires that all of the participating devices are identical. For example, it is not possible to use pmap to parallelize a computation across two different models of GPU. It is currently an error for the same device to participate twice in the same pmap.

Multi-process platforms: On multi-process platforms such as TPU pods, pmap is designed to be used in SPMD Python programs, where every process is running the same Python code such that all processes run the same pmapped function in the same order. Each process should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other processes, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across processes, where each process “sees” only its local shard of the input and output. The SPMD model requires that the same multi-process pmaps must be run in the same order on all devices, but they can be interspersed with arbitrary operations running in a single process.

Parameters:
  • fun (Callable) – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_broadcasted_argnums can be anything at all, provided they are hashable and have an equality operation defined.

  • axis_name (AxisName | None) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.

  • in_axes – A non-negative integer, None, or nested Python container thereof that specifies which axes of positional arguments to map over. Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0). See vmap for details.

  • out_axes – A non-negative integer, None, or nested Python container thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification (see vmap).

  • static_broadcasted_argnums (int | Iterable[int]) –

    An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by static_broadcasted_argnums then an error is raised. Each of the static arguments will be broadcasted to all devices. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().

    Static arguments must be hashable, meaning both __hash__ and __eq__ are implemented, and should be immutable.

  • devices (Sequence[xc.Device] | None) – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over (available devices can be retrieved via jax.devices()). Must be given identically for each process in multi-process settings (and will therefore include devices across processes). If specified, the size of the mapped axis must be equal to the number of devices in the sequence local to the given process. Nested pmaps with devices specified in either the inner or outer pmap are not yet supported.

  • backend (str | None) – This is an experimental feature and the API is likely to change. Optional, a string naming the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.

  • axis_size (int | None) – Optional; the size of the mapped axis.

  • donate_argnums (int | Iterable[int]) –

    Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example by recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. Note that donate_argnums only works for positional arguments; keyword arguments will not be donated.

    For more details on buffer donation see the FAQ.

Returns:

A parallelized version of fun with arguments that correspond to those of fun but with extra array axes at positions indicated by in_axes and with output that has an additional leading array axis (with the same size).

Return type:

Any
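As a hedged sketch of the static_broadcasted_argnums parameter described above, the following passes a Python string as a static argument. It uses jax.pmap directly, assuming inox.api.pmap forwards this parameter unchanged. The function reduce_row and its mode argument are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

def reduce_row(row, mode):
    # `mode` is a plain Python string, not an array, so it has to be static.
    # The branch is resolved at trace time, once per distinct value of `mode`.
    if mode == "sum":
        return row.sum()
    return row.max()

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape((n, 4))

# Argument 1 (`mode`) is treated as a compile-time constant and
# broadcast to every device; argument 0 is mapped over its leading axis.
parallel_reduce = jax.pmap(reduce_row, static_broadcasted_argnums=1)

row_sums = parallel_reduce(x, "sum")
row_maxes = parallel_reduce(x, "max")  # new static value, so a new compilation
```

Because each distinct static value triggers a recompilation, this pattern suits arguments that take only a few values, such as mode flags or layer counts.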

For example, assuming 8 XLA devices are available, pmap can be used as a map along a leading array axis:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> out = pmap(lambda x: x ** 2)(jnp.arange(8))  
>>> print(out)  
[0, 1, 4, 9, 16, 25, 36, 49]

When the leading dimension is smaller than the number of available devices JAX will simply run on a subset of devices:

>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(jnp.dot)(x, y)  
>>> print(out)  
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]

If your leading dimension is larger than the number of available devices you will get an error:

>>> pmap(lambda x: x ** 2)(jnp.arange(9))  
ValueError: ... requires 9 replicas, but only 8 XLA devices are available

As with vmap, using None in in_axes indicates that an argument doesn’t have an extra axis and should be broadcasted, rather than mapped, across the replicas:

>>> x, y = jnp.arange(2.), 4.
>>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)  
>>> print(out)  
([4., 5.], [8., 8.])

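Note that, by default, pmap returns values mapped over their leading axis, equivalent to using out_axes=0 in vmap. This applies even to outputs that depend only on broadcasted arguments, such as y * 2. in the example above.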
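A None entry in in_axes can also cover a whole container, which is a common way to replicate model parameters while splitting a batch across devices. The following is a minimal sketch under that reading. It uses jax.pmap directly (assumed equivalent to inox.api.pmap for array-only pytrees), and the names params, batch and predict are hypothetical.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# Parameters shared by every replica (broadcast, not mapped).
params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(0.5)}

# One length-3 example per device (mapped over the leading axis).
batch = jnp.ones((n, 3))

def predict(params, x):
    # Runs once per device on a single example `x` of shape (3,).
    return x @ params["w"] + params["b"]

# None applies to the whole `params` dict; 0 maps `batch` over axis 0.
out = jax.pmap(predict, in_axes=(None, 0))(params, batch)

print(out.shape)  # (n,)
```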

In addition to expressing pure maps, pmap can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(jnp.arange(4.))  
>>> print(out)  
[ 0.          0.16666667  0.33333334  0.5       ]
>>> print(out.sum())  
1.0

In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.
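The same pattern works for any device count when the input is sized from jax.local_device_count(). The sketch below computes a global mean with jax.lax.pmean, again calling jax.pmap directly as an assumed stand-in for inox.api.pmap. The axis name "devices" and the function global_mean are illustrative choices.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# One length-4 shard per device.
x = jnp.arange(n * 4.0).reshape((n, 4))

def global_mean(shard):
    # Mean of the local shard, then averaged across all replicas
    # through the named axis.
    local_mean = shard.mean()
    return jax.lax.pmean(local_mean, axis_name="devices")

out = jax.pmap(global_mean, axis_name="devices")(x)

# Every replica ends up holding the same value, the mean of all of x.
print(out.shape)  # (n,)
```

Since all shards have equal size, the average of the per-shard means equals the mean over the whole array.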

The argument axis_name to pmap names the mapped axis so that collective operations, like jax.lax.psum, can refer to it. Axis names are important particularly in the case of nested pmap functions, where collective operations can operate over distinct axes:

>>> from functools import partial
>>> import jax
>>>
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...   row_normed = x / jax.lax.psum(x, 'rows')
...   col_normed = x / jax.lax.psum(x, 'cols')
...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...   return row_normed, col_normed, doubly_normed
>>>
>>> x = jnp.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)  
>>> print(row_normed.sum(0))  
[ 1.  1.]
>>> print(col_normed.sum(1))  
[ 1.  1.  1.  1.]
>>> print(doubly_normed.sum((0, 1)))  
1.0

On multi-process platforms, collective operations operate over all devices, including those on other processes. For example, assuming the following code runs on two processes with 4 XLA devices each:

>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
>>> out = pmap(f, axis_name='i')(data)  
>>> print(out)  
[28 29 30 31] # on process 0
[32 33 34 35] # on process 1

Each process passes in a different length-4 array, corresponding to its 4 local devices, and the psum operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis given name ‘i’. The pmap call on each process then returns the corresponding length-4 output shard.
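The shard construction described above can be written without hard-coding the device counts. This is a sketch only: it assumes every process has the same number of local devices and that the distributed runtime, if any, has already been initialized. It uses jax.pmap directly as a stand-in for inox.api.pmap. In a single-process run it reduces to an ordinary pmap over the local devices.

```python
import jax
import jax.numpy as jnp

n_local = jax.local_device_count()   # devices owned by this process
n_global = jax.device_count()        # devices across all processes

# This process's slice of the conceptual global array jnp.arange(n_global),
# assuming equal device counts per process.
start = jax.process_index() * n_local
local_shard = jnp.arange(start, start + n_local)

f = lambda x: x + jax.lax.psum(x, axis_name="i")
out = jax.pmap(f, axis_name="i")(local_shard)

# The psum runs over all n_global values, so each entry is shifted
# by 0 + 1 + ... + (n_global - 1).
expected_shift = n_global * (n_global - 1) // 2
```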

The devices argument can be used to specify exactly which devices are used to run the parallel computation. For example, again assuming a single process with 8 devices, the following code defines two parallel computations, one which runs on the first six devices and one on the remaining two:

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...   return x / jax.lax.psum(x, axis_name='i')
>>>
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
... def f2(x):
...   return jax.lax.psum(x ** 2, axis_name='i')
>>>
>>> print(f1(jnp.arange(6.)))  
[0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
>>> print(f2(jnp.array([2., 3.])))  
[ 13.  13.]