inox.nn.ssm

State space model (SSM) layers

Classes

SISO – Abstract single-input single-output (SISO) state space model class.

S4 – Creates an S4 state space model.

Descriptions

class inox.nn.ssm.SISO(**kwargs)

Abstract single-input single-output (SISO) state space model class.

A SISO state space model defines a system of equations of the form

\[\begin{split}\dot{x}(t) & = A x(t) + B u(t) \\ y(t) & = C x(t)\end{split}\]

where \(u(t), y(t) \in \mathbb{C}\) are the input and output signals and \(x(t) \in \mathbb{C}^{H}\) is a latent (hidden) state. In practice, the input and output signals are sampled every \(\Delta\) time units, leading to sequences \((u_1, u_2, \dots)\) and \((y_1, y_2, \dots)\) whose dynamics are governed by the discrete-time form of the system

\[\begin{split}x_i & = \bar{A} x_{i-1} + \bar{B} u_i \\ y_i & = \bar{C} x_i\end{split}\]

where \(\bar{A} = \exp(\Delta A)\), \(\bar{B} = A^{-1} (\bar{A} - I) B\) and \(\bar{C} = C\), which corresponds to a zero-order hold (ZOH) discretization. Assuming \(x_0 = 0\), the dynamics can also be represented as a discrete-time convolution

\[y_{1:L} = \bar{k}_{1:L} * u_{1:L}\]

where \(\bar{k}_i = \bar{C} \bar{A}^{i-1} \bar{B} \in \mathbb{C}\).
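The equivalence between the recurrence and the convolution can be checked numerically. The sketch below uses hypothetical helpers (`recurrence`, `naive_kernel`, `causal_conv`) that are not part of inox: it runs the discrete recurrence with `jax.lax.scan` from \(x_0 = 0\) and compares the result with the causal convolution of the kernel \(\bar{k}_i = \bar{C} \bar{A}^{i-1} \bar{B}\). For simplicity it assumes a small random real-valued system; the formulas are identical for complex values.

```python
import jax
import jax.numpy as jnp


def recurrence(A_bar, B_bar, C_bar, u):
    """Runs x_i = A_bar x_{i-1} + B_bar u_i, y_i = C_bar x_i, starting from x_0 = 0."""

    def step(x, u_i):
        x = A_bar @ x + B_bar * u_i
        return x, C_bar @ x

    x0 = jnp.zeros(A_bar.shape[0], dtype=A_bar.dtype)
    _, y = jax.lax.scan(step, x0, u)
    return y


def naive_kernel(A_bar, B_bar, C_bar, length):
    """Builds k_i = C_bar A_bar^(i-1) B_bar for i = 1, ..., length (stored 0-indexed)."""
    return jnp.stack(
        [C_bar @ jnp.linalg.matrix_power(A_bar, i) @ B_bar for i in range(length)]
    )


def causal_conv(k, u):
    """Keeps the first L terms of the full convolution: y_i = sum_{j <= i} k_{i-j+1} u_j."""
    return jnp.convolve(k, u)[: u.shape[0]]


# Small random real-valued system; A_bar is scaled down so its powers stay small.
H, L = 4, 16
kA, kB, kC, ku = jax.random.split(jax.random.PRNGKey(0), 4)
A_bar = 0.5 / jnp.sqrt(H) * jax.random.normal(kA, (H, H))
B_bar = jax.random.normal(kB, (H,))
C_bar = jax.random.normal(kC, (H,))
u = jax.random.normal(ku, (L,))

y_rec = recurrence(A_bar, B_bar, C_bar, u)
y_conv = causal_conv(naive_kernel(A_bar, B_bar, C_bar, L), u)
print(jnp.max(jnp.abs(y_rec - y_conv)))  # expected to be at float32 round-off level
```

The recurrence costs \(O(L H^2)\) sequential work, while the convolution form exposes the whole output as one parallel operation once \(\bar{k}_{1:L}\) is known.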

Wikipedia

https://wikipedia.org/wiki/State-space_representation

__call__(u)
Parameters:

u (Array) – The input signal \(u_{1:L}\), with shape \((*, L)\). Floating-point arrays are promoted to complex arrays.

Returns:

The output signal \(y_{1:L}\), with shape \((*, L)\).

Return type:

Array
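For long sequences, the convolution form is commonly evaluated with the FFT in \(O(L \log L)\) time instead of \(O(L^2)\). The following is a minimal sketch of that technique for an input of shape \((*, L)\), assuming the kernel \(\bar{k}_{1:L}\) has already been computed. `fft_causal_conv` is a hypothetical helper, and whether inox's `__call__` is implemented this way is not stated here.

```python
import jax
import jax.numpy as jnp


def fft_causal_conv(k, u):
    """Computes y_{1:L} = k_{1:L} * u_{1:L} along the last axis of u.

    k has shape (L,) and u has shape (*, L). Zero-padding both to length 2L
    turns the FFT's circular convolution into a linear one, and the first
    L entries of the result are the causal outputs.
    """
    L = u.shape[-1]
    # Promote real inputs to complex, mirroring the __call__ description above.
    u = u.astype(jnp.result_type(u.dtype, jnp.complex64))
    k_f = jnp.fft.fft(k, n=2 * L)
    u_f = jnp.fft.fft(u, n=2 * L, axis=-1)
    y = jnp.fft.ifft(k_f * u_f, axis=-1)  # k_f broadcasts over the batch axes
    return y[..., :L]


k = jax.random.normal(jax.random.PRNGKey(1), (32,))
u = jax.random.normal(jax.random.PRNGKey(2), (3, 32))  # batch of 3 signals
y = fft_causal_conv(k, u)
```

Padding to \(2L\) rather than \(2L - 1\) is only a convenience; any FFT length of at least \(2L - 1\) avoids wrap-around.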

discrete()
Returns:

The matrices \(\bar{A}\), \(\bar{B}\) and \(\bar{C}\), with shapes \((H, H)\), \((H)\) and \((H)\), respectively.

Return type:

Tuple[Array, Array, Array]
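The discretization formulas can be sketched with `jax.scipy.linalg.expm` and `jnp.linalg.solve`. The function `zoh_discretize` and its `delta` argument (the step \(\Delta\)) are hypothetical; `discrete()` itself takes no arguments, so how inox stores \(A\), \(B\), \(C\) and \(\Delta\) is not shown here. The sketch assumes \(A\) is invertible, as the formula for \(\bar{B}\) requires.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def zoh_discretize(A, B, C, delta):
    """Zero-order-hold discretization using the formulas of the SISO description.

    A_bar = exp(delta A), B_bar = A^{-1} (A_bar - I) B and C_bar = C.
    A is assumed invertible; solving a linear system avoids forming A^{-1}.
    """
    H = A.shape[0]
    A_bar = expm(delta * A)
    B_bar = jnp.linalg.solve(A, (A_bar - jnp.eye(H, dtype=A.dtype)) @ B)
    return A_bar, B_bar, C


# For a diagonal A the formulas reduce to elementwise expressions,
# which makes the result easy to check by hand.
a = jnp.array([-1.0, -2.0, -0.5])
A = jnp.diag(a)
B = jnp.array([1.0, 2.0, 3.0])
C = jnp.ones(3)
A_bar, B_bar, C_bar = zoh_discretize(A, B, C, delta=0.1)
```

For diagonal \(A = \operatorname{diag}(a)\), this gives \(\bar{A} = \operatorname{diag}(e^{\Delta a})\) and \(\bar{B}_n = (e^{\Delta a_n} - 1) B_n / a_n\).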

kernel(length)
Parameters:

length (int) – The kernel length \(L\).

Returns:

The kernel \(\bar{k}_{1:L}\), with shape \((L)\).

Return type:

Array
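Computing \(\bar{k}_{1:L}\) through explicit powers of \(\bar{A}\) requires repeated \(H \times H\) matrix products. When \(\bar{A}\) is diagonal, \(\bar{A} = \operatorname{diag}(\lambda)\), the kernel reduces to a Vandermonde product \(\bar{k}_i = \sum_n \bar{C}_n \bar{B}_n \lambda_n^{i-1}\). The sketch below assumes that diagonal structure, as used by diagonal SSM variants; `diagonal_kernel` is hypothetical, and the S4 paper instead evaluates its DPLR kernel through a Cauchy kernel, which this sketch does not reproduce.

```python
import jax.numpy as jnp


def diagonal_kernel(lam, B_bar, C_bar, length):
    """k_i = sum_n C_bar[n] * lam[n]**(i-1) * B_bar[n] for i = 1, ..., length.

    Assumes A_bar = diag(lam), so A_bar^(i-1) never has to be formed.
    """
    powers = lam[:, None] ** jnp.arange(length)  # (H, L) Vandermonde matrix
    return (C_bar * B_bar) @ powers  # (L,)


lam = jnp.array([0.9 + 0.1j, 0.5 - 0.3j, -0.2 + 0.4j], dtype=jnp.complex64)
B_bar = jnp.array([1.0, 0.5, -1.0], dtype=jnp.complex64)
C_bar = jnp.array([0.3, -0.7, 1.2], dtype=jnp.complex64)
k = diagonal_kernel(lam, B_bar, C_bar, 10)
```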

class inox.nn.ssm.S4(hid_features, key=None)

Creates an S4 state space model.

References

  • Efficiently Modeling Long Sequences with Structured State Spaces (Gu et al., 2021)

  • The Annotated S4 (Rush et al., 2023)

Parameters:
  • hid_features (int) – The number of hidden features \(H\).

  • key (Array) – A PRNG key for initialization. If None, inox.random.get_rng is used instead.

Example

>>> import jax
>>> from inox.nn.ssm import S4
>>> key = jax.random.PRNGKey(0)
>>> ssm = S4(hid_features=64, key=key)
>>> u = jax.numpy.linspace(0.0, 1.0, 1024)
>>> y = ssm(u)

static DPLR_HiPPO(n)

Returns the diagonal plus low-rank (DPLR) form of the HiPPO matrix.

\[A = \Lambda - PP^*\]
Parameters:

n (int) – The size \(n\) of the HiPPO matrix.
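The DPLR form can be sketched following the construction in The Annotated S4: take the HiPPO-LegS matrix \(A\), add \(pp^\top\) with \(p_j = \sqrt{j + 1/2}\) to obtain a normal matrix \(S = -\tfrac{1}{2} I + K\) with \(K\) skew-symmetric, and diagonalize \(S\) with a unitary \(V\). In that eigenbasis, \(V^* A V = \Lambda - PP^*\) with \(P = V^* p\). The names `hippo_legs` and `dplr_hippo` are hypothetical, and the exact values returned by inox's `DPLR_HiPPO` (for example, whether it also returns \(V\) or a transformed \(B\)) are not documented here.

```python
import jax.numpy as jnp


def hippo_legs(n):
    """HiPPO-LegS: A_jk = -sqrt((2j+1)(2k+1)) if j > k, -(j+1) if j == k, 0 otherwise."""
    p = jnp.sqrt(1.0 + 2.0 * jnp.arange(n))
    A = jnp.tril(p[:, None] * p[None, :]) - jnp.diag(jnp.arange(n, dtype=p.dtype))
    return -A


def dplr_hippo(n):
    """Returns (Lam, P, V) such that V^* A V = diag(Lam) - P P^*."""
    A = hippo_legs(n)
    p = jnp.sqrt(jnp.arange(n) + 0.5)
    # S = A + p p^T equals -1/2 I plus a skew-symmetric matrix, hence it is normal.
    S = A + p[:, None] * p[None, :]
    K = (S - S.T) / 2  # skew-symmetric part of S
    # -1j K is Hermitian; if -1j K v = mu v, then S v = (-1/2 + 1j mu) v.
    mu, V = jnp.linalg.eigh(-1j * K)
    Lam = -0.5 + 1j * mu
    P = V.conj().T @ p
    return Lam, P, V


Lam, P, V = dplr_hippo(8)
```

Since \(V\) is unitary, \(A\) can be recovered as \(V (\Lambda - PP^*) V^*\), which is a convenient sanity check on the decomposition.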