inox.nn.ssm
State space model (SSM) layers
Classes
Descriptions
- class inox.nn.ssm.SISO(**kwargs)
Abstract single-input single-output (SISO) state space model class.
A SISO state space model defines a system of equations of the form
\[\begin{split}x'(t) & = A x(t) + B u(t) \\ y(t) & = C x(t)\end{split}\]
where \(u(t), y(t) \in \mathbb{C}\) are the input and output signals and \(x(t) \in \mathbb{C}^{H}\) is a latent (hidden) state. In practice, the input and output signals are sampled every \(\Delta\) time units, leading to sequences \((u_1, u_2, \dots)\) and \((y_1, y_2, \dots)\) whose dynamics are governed by the discrete-time form of the system
\[\begin{split}x_i & = \bar{A} x_{i-1} + \bar{B} u_i \\ y_i & = \bar{C} x_i\end{split}\]
where \(\bar{A} = \exp(\Delta A)\), \(\bar{B} = A^{-1} (\bar{A} - I) B\) and \(\bar{C} = C\) (the zero-order hold discretization, valid as written when \(A\) is invertible). With \(x_0 = 0\), the dynamics can also be written as a discrete-time convolution
\[y_{1:L} = \bar{k}_{1:L} * u_{1:L}\]
where \(*\) denotes the causal convolution \(y_i = \sum_{j=1}^{i} \bar{k}_{i-j+1} u_j\) and \(\bar{k}_i = \bar{C} \bar{A}^{i-1} \bar{B} \in \mathbb{C}\).
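The following minimal JAX sketch checks numerically that the recurrence and the convolution describe the same outputs when \(x_0 = 0\). It assumes a small real-valued system with a dense, invertible \(A\) and computes \(\exp(\Delta A)\) with `jax.scipy.linalg.expm`. The helpers `discretize`, `recurrence`, `kernel` and `causal_conv` are hypothetical names for illustration, not part of inox.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def discretize(A, B, C, delta):
    """Zero-order hold: A_bar = exp(delta A), B_bar = A^{-1} (A_bar - I) B, C_bar = C."""
    A_bar = expm(delta * A)
    identity = jnp.eye(A.shape[0], dtype=A_bar.dtype)
    # Solve A X = (A_bar - I) B instead of forming A^{-1} explicitly (A must be invertible).
    B_bar = jnp.linalg.solve(A, (A_bar - identity) @ B)
    return A_bar, B_bar, C


def recurrence(A_bar, B_bar, C_bar, u):
    """y_i = C_bar x_i with x_i = A_bar x_{i-1} + B_bar u_i and x_0 = 0."""
    def step(x, u_i):
        x = A_bar @ x + B_bar * u_i
        return x, C_bar @ x

    _, y = jax.lax.scan(step, jnp.zeros_like(B_bar), u)
    return y


def kernel(A_bar, B_bar, C_bar, length):
    """k_bar_i = C_bar A_bar^{i-1} B_bar for i = 1, ..., length."""
    def step(v, _):
        # v holds A_bar^{i-1} B_bar: emit C_bar v, then advance to A_bar^i B_bar.
        return A_bar @ v, C_bar @ v

    _, k = jax.lax.scan(step, B_bar, None, length=length)
    return k


def causal_conv(k, u):
    """y_i = sum_{j <= i} k_{i-j+1} u_j, i.e. the first L entries of the full convolution."""
    return jnp.convolve(k, u)[: u.shape[0]]


# Usage: a small stable system driven by a random input signal.
A = jnp.array([[-1.0, 0.5], [0.0, -2.0]])
B = jnp.array([1.0, 1.0])
C = jnp.array([1.0, 0.0])
u = jax.random.normal(jax.random.PRNGKey(0), (64,))

A_bar, B_bar, C_bar = discretize(A, B, C, delta=0.1)
y_rec = recurrence(A_bar, B_bar, C_bar, u)
y_conv = causal_conv(kernel(A_bar, B_bar, C_bar, u.shape[0]), u)
```

`y_rec` and `y_conv` agree up to floating-point error. The recurrence is inherently sequential, whereas the convolution view computes the kernel once and then applies it to the whole sequence at once, which is the usual motivation for the convolutional form.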
Wikipedia
https://wikipedia.org/wiki/State-space_representation
- discrete()
- class inox.nn.ssm.SISOLayer(siso, reverse=False)
Creates a SISO layer.
- Parameters:
Example
>>> keys = jax.random.split(key, in_features)
>>> siso = jax.vmap(S4, in_axes=(0, None))(keys, hid_features)
>>> layer = SISOLayer(siso)
>>> y = layer(u)
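To make the shapes in this example concrete, here is a hypothetical, library-free sketch of what applying a bank of SISO models to a multi-channel sequence amounts to. It assumes that each of the `in_features` channels is filtered independently by its own real-valued kernel, that inputs have shape `(L, C)`, and that `reverse=True` processes the sequence from the last time step to the first. `siso_layer_apply` and `fft_causal_conv` are illustrative names, not inox's implementation.

```python
import jax
import jax.numpy as jnp


def fft_causal_conv(k, u):
    """Causal convolution y_i = sum_{j <= i} k_{i-j+1} u_j of two real length-L signals."""
    length = u.shape[-1]
    n = 2 * length  # zero-pad so that circular convolution equals linear convolution
    y = jnp.fft.irfft(jnp.fft.rfft(k, n=n) * jnp.fft.rfft(u, n=n), n=n)
    return y[:length]


def siso_layer_apply(kernels, u, reverse=False):
    """Filters each channel of u with its own kernel (hypothetical layer semantics).

    kernels: (C, L) array, one real kernel per channel.
    u: (L, C) array, a sequence of L time steps with C channels.
    reverse: if True, run the filters from the last time step to the first.
    """
    if reverse:
        u = jnp.flip(u, axis=0)
    # Map over channels: kernel row c is paired with input column c.
    y = jax.vmap(fft_causal_conv, in_axes=(0, 1), out_axes=1)(kernels, u)
    if reverse:
        y = jnp.flip(y, axis=0)
    return y


# Usage: 3 channels, 16 time steps.
key_k, key_u = jax.random.split(jax.random.PRNGKey(0))
kernels = jax.random.normal(key_k, (3, 16))
u = jax.random.normal(key_u, (16, 3))
y = siso_layer_apply(kernels, u)  # shape (16, 3)
y_rev = siso_layer_apply(kernels, u, reverse=True)
```

Computing the convolution with FFTs costs \(O(L \log L)\) per channel instead of \(O(L^2)\) for a direct sum, which matters for long sequences.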
- class inox.nn.ssm.S4(key, hid_features)
Creates an S4 state space model.
References
Efficiently Modeling Long Sequences with Structured State Spaces (Gu et al., 2021)
The Annotated S4 (Rush et al., 2023)
- Parameters:
key (PRNGKeyArray) – A PRNG key for initialization.
hid_features (int) – The number of hidden features \(H\).
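As an illustration of how a structured SSM can be built from just a PRNG key and \(H\), the sketch below initializes a diagonal state matrix with the S4D-Lin scheme from the diagonal variant of S4, \(A_n = -\frac{1}{2} + i \pi n\), and evaluates the kernel \(\bar{k}_i = \bar{C} \bar{A}^{i-1} \bar{B}\) in closed form. This is a swapped-in technique: it is not necessarily how `inox.nn.ssm.S4` is parameterized or initialized, and `init_diag_ssm`, `diag_kernel`, `dt_min` and `dt_max` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def init_diag_ssm(key, hid_features, dt_min=1e-3, dt_max=1e-1):
    """Hypothetical diagonal SSM parameters with an S4D-Lin style initialization."""
    key_re, key_im, key_dt = jax.random.split(key, 3)
    n = jnp.arange(hid_features)
    # S4D-Lin: A_n = -1/2 + i pi n; Re(A_n) < 0 makes every mode decay (stable).
    A = -0.5 + 1j * jnp.pi * n
    B = jnp.ones(hid_features, dtype=A.dtype)
    # Complex Gaussian C with unit variance.
    C_re = jax.random.normal(key_re, (hid_features,))
    C_im = jax.random.normal(key_im, (hid_features,))
    C = (C_re + 1j * C_im) / jnp.sqrt(2.0)
    # Step size Delta drawn log-uniformly in [dt_min, dt_max].
    log_delta = jax.random.uniform(
        key_dt, (), minval=jnp.log(dt_min), maxval=jnp.log(dt_max)
    )
    return A, B, C, log_delta


def diag_kernel(A, B, C, log_delta, length):
    """k_bar_i = sum_h C_h A_bar_h^{i-1} B_bar_h for a diagonal A, i = 1, ..., length."""
    delta = jnp.exp(log_delta)
    A_bar = jnp.exp(delta * A)  # exp of a diagonal matrix acts elementwise
    B_bar = (A_bar - 1.0) / A * B  # A^{-1} (A_bar - I) B, elementwise
    # A_bar^{i-1} = exp((i - 1) Delta A), arranged as an (L, H) Vandermonde matrix.
    powers = jnp.exp(jnp.arange(length)[:, None] * (delta * A)[None, :])
    return powers @ (C * B_bar)


# Usage: H = 8 hidden features, kernel of length L = 64.
A, B, C, log_delta = init_diag_ssm(jax.random.PRNGKey(0), 8)
k = diag_kernel(A, B, C, log_delta, 64)  # complex array of shape (64,)
```

With a diagonal \(A\), both the matrix exponential and \(A^{-1}\) reduce to elementwise operations, so the whole kernel is a single \((L, H)\) matrix-vector product rather than \(L\) repeated matrix powers.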