inox.debug

Extended utilities for debugging.

Functions

same_trace

Checks whether two arrays have the same trace source.

Descriptions

inox.debug.same_trace(x, y, ignore_primal=False)

Checks whether two arrays have the same trace source.

Parameters:
  • x (Array) – The first array.

  • y (Array) – The second array.

  • ignore_primal (bool) – Whether to ignore primal traces, such as those introduced by jax.grad.

Example

>>> import jax
>>> from inox.debug import same_trace
>>> x, y = jax.numpy.zeros(2)
>>> same_trace(x, y)
True
>>> jax.jit(lambda x, y: same_trace(x, y))(x, y)
Array(True, dtype=bool)
>>> jax.jit(lambda x: same_trace(x, y))(x)
Array(False, dtype=bool)
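
The last case in the example returns False because, inside the jitted function, the argument x is an abstract tracer while the closed-over y is still a concrete array. The sketch below illustrates that distinction using plain JAX only. It is not how same_trace is implemented: is_traced is a hypothetical helper introduced here for illustration, and it only tells tracers apart from concrete values, it does not compare trace sources.

```python
import jax
import jax.numpy as jnp
from jax.core import Tracer


def is_traced(x):
    """Hypothetical helper (not part of inox): True if `x` is a JAX tracer."""
    return isinstance(x, Tracer)


x, y = jnp.zeros(2)

# Outside any transformation, both values are concrete arrays.
outside = (is_traced(x), is_traced(y))

seen = {}


def f(a):
    # `a` is an argument of the jitted function, so it is a tracer while
    # the function is being traced. `y` is captured from the enclosing
    # scope and stays a concrete array.
    seen["arg"] = is_traced(a)
    seen["closure"] = is_traced(y)
    return a + y


jax.jit(f)(x)

print(outside)  # (False, False)
print(seen)     # {'arg': True, 'closure': False}
```

A value captured by closure, as y is here, does not take part in the trace of the enclosing jit call, which is the kind of mismatch a check like same_trace(x, y) can surface while debugging.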