inox.nn.padding

Padding layers

Classes

Pad

Creates a spatial padding layer.

Descriptions

class inox.nn.padding.Pad(padding, mode='constant', value=0.0)

Creates a spatial padding layer.

This module is a thin wrapper around jax.numpy.pad.

Parameters:
  • padding (Sequence[Tuple[int, int]]) – The (before, after) padding applied to each spatial axis, one pair per axis.

  • mode (str) – The padding mode, one of {'constant', 'edge', 'reflect', 'wrap'}.

  • value (float | Array) – The padding value, used only when mode='constant'.
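Because the layer wraps jax.numpy.pad, one plausible reading is that the given pairs cover only the spatial axes \(H_1, \dots, H_n\). The leading batch axes \((*)\) and the trailing channel axis \(C\) would then receive (0, 0). The sketch below rebuilds the full pad_width under that assumption. The helper name spatial_pad is hypothetical and not part of inox, and this is not the library's source.

```python
import jax.numpy as jnp


def spatial_pad(x, padding, mode="constant", value=0.0):
    """Pad the spatial axes of x, with shape (*, H_1, ..., H_n, C).

    Hypothetical helper approximating the documented behaviour of Pad.
    """
    padding = tuple(tuple(p) for p in padding)
    n = len(padding)  # number of spatial axes
    batch = x.ndim - n - 1  # number of leading (*) axes, left unpadded
    assert batch >= 0, "input has fewer axes than (H_1, ..., H_n, C)"

    # Assumption: batch axes and the trailing channel axis are never padded.
    pad_width = ((0, 0),) * batch + padding + ((0, 0),)

    if mode == "constant":
        return jnp.pad(x, pad_width, mode="constant", constant_values=value)
    # 'edge', 'reflect' and 'wrap' take no fill value.
    return jnp.pad(x, pad_width, mode=mode)
```

With one spatial axis holding [0, 1, 2, 3] and padding [(1, 2)], the four modes give [-1, 0, 1, 2, 3, -1, -1] (constant, value=-1), [0, 0, 1, 2, 3, 3, 3] (edge), [1, 0, 1, 2, 3, 2, 1] (reflect) and [3, 0, 1, 2, 3, 0, 1] (wrap).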

__call__(x)
Parameters:

x (Array) – The input tensor \(x\), with shape \((*, H_1, \dots, H_n, C)\).

Returns:

The output tensor \(y\), with shape \((*, H_1', \dots, H_n', C)\), such that

\[H_i' = H_i + p_i\]

where \(p_i\) is the total padding of the \(i\)-th spatial axis, that is, the sum of the before and after entries of its pair in padding.
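The shape rule can be checked with a small worked example. The sketch below assumes that \(p_i\) is the sum of the two entries of the \(i\)-th padding pair, and that the \((*)\) axes and \(C\) are left unchanged. The helper padded_shape is hypothetical and uses only plain Python tuple arithmetic.

```python
def padded_shape(shape, padding):
    """Output shape (*, H_1', ..., H_n', C) with H_i' = H_i + p_i.

    Hypothetical helper; assumes p_i = before_i + after_i.
    """
    n = len(padding)  # number of spatial axes
    lead = tuple(shape[: len(shape) - n - 1])  # (*) axes, unchanged
    spatial = shape[len(shape) - n - 1 : -1]  # H_1, ..., H_n
    channels = shape[-1]  # C, unchanged
    new_spatial = tuple(h + before + after for h, (before, after) in zip(spatial, padding))
    return lead + new_spatial + (channels,)
```

For an input of shape (8, 32, 32, 3) and padding [(1, 1), (2, 0)], the totals are \(p_1 = 2\) and \(p_2 = 2\), so the output shape is (8, 34, 34, 3).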

Return type:

Array