inox.nn.recurrent#

Recurrent layers

Classes#

Cell

Abstract cell class.

Recurrent

Creates a recurrent layer.

GRUCell

Creates a gated recurrent unit (GRU) cell.

LSTMCell

Creates a long short-term memory (LSTM) cell.

Descriptions#

class inox.nn.recurrent.Cell(**kwargs)#

Abstract cell class.

A cell defines a recurrence function \(f\) of the form

\[(h_i, y_i) = f(h_{i-1}, x_i)\]

and an initial hidden state \(h_0\).

Warning

The recurrence function \(f\) should have no side effects.
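To make the interface concrete, here is a minimal pure-function sketch of a cell in plain NumPy. The `tanh_cell` function and its weight names are hypothetical, not part of inox; it only illustrates the \((h_{i-1}, x_i) \mapsto (h_i, y_i)\) contract and a zero initial state, with no side effects.

```python
import numpy as np

def tanh_cell(h, x, W, U, b):
    """A minimal cell (h_{i-1}, x_i) -> (h_i, y_i).

    Hypothetical illustration of the Cell interface: the new hidden
    state doubles as the output, and the function has no side effects.
    """
    h_new = np.tanh(W @ x + U @ h + b)
    return h_new, h_new

rng = np.random.default_rng(0)
C, H = 3, 4  # input and hidden feature sizes
W = rng.standard_normal((H, C))
U = rng.standard_normal((H, H))
b = np.zeros(H)

h0 = np.zeros(H)  # initial hidden state h_0, cf. init()
h1, y1 = tanh_cell(h0, rng.standard_normal(C), W, U, b)
```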

__call__(h, x)#
Parameters:
  • h (Any) – The previous hidden state \(h_{i-1}\).

  • x (Any) – The input \(x_i\).

Returns:

The hidden state and output \((h_i, y_i)\).

Return type:

Tuple[Any, Any]

init()#
Returns:

The initial hidden state \(h_0\).

Return type:

Any

class inox.nn.recurrent.Recurrent(cell, reverse=False)#

Creates a recurrent layer.

Parameters:
  • cell (Cell) – A recurrent cell.

  • reverse (bool) – Whether to apply the recurrence in reverse, from the last element of the sequence to the first.

__call__(xs)#
Parameters:

xs (Any) – A sequence of inputs \(x_i\), stacked on the leading axis. When inputs are vectors, xs has shape \((L, C)\).

Returns:

A sequence of outputs \(y_i\), stacked on the leading axis. When outputs are vectors, ys has shape \((L, C')\).

Return type:

Any
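The following NumPy sketch mirrors what a recurrent layer does with a cell: thread the hidden state through the sequence and stack the outputs on the leading axis. The `unroll` helper and the toy cumulative-sum cell are assumptions for illustration; inox itself would presumably build on a scan primitive such as `jax.lax.scan` rather than a Python loop.

```python
import numpy as np

def unroll(cell, h0, xs, reverse=False):
    """Apply a cell over inputs xs stacked on the leading axis.

    Returns the outputs y_i stacked on the leading axis, in the
    original input order even when the recurrence runs in reverse.
    """
    h = h0
    ys = []
    order = reversed(range(len(xs))) if reverse else range(len(xs))
    for i in order:
        h, y = cell(h, xs[i])
        ys.append(y)
    if reverse:
        ys.reverse()  # restore the original ordering of outputs
    return np.stack(ys)

# A toy cell that accumulates its inputs: h_i = y_i = h_{i-1} + x_i.
cumsum_cell = lambda h, x: (h + x, h + x)
xs = np.arange(5.0).reshape(5, 1)  # sequence of shape (L, C) = (5, 1)
ys = unroll(cumsum_cell, np.zeros(1), xs)
ys_rev = unroll(cumsum_cell, np.zeros(1), xs, reverse=True)
```

With `reverse=True`, each output \(y_i\) depends on the inputs \(x_i, \dots, x_L\) instead of \(x_1, \dots, x_i\), which is how the backward half of a bidirectional RNN is typically obtained.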

class inox.nn.recurrent.GRUCell(in_features, hid_features, bias=True, key=None)#

Creates a gated recurrent unit (GRU) cell.

References

Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation (Cho et al., 2014)
Parameters:
  • in_features (int) – The number of input features \(C\).

  • hid_features (int) – The number of hidden features \(H\).

  • bias (bool) – Whether the cell learns additive biases.

  • key (Array) – A PRNG key for initialization. If None, inox.random.get_rng is used instead.

__call__(h, x)#
Parameters:
  • h (Array) – The previous hidden state \(h_{i-1}\), with shape \((*, H)\).

  • x (Array) – The input vector \(x_i\), with shape \((*, C)\).

Returns:

The hidden state and output \((h_i, h_i)\). For a GRU cell, the output is the hidden state itself.

Return type:

Tuple[Array, Array]

init()#
Returns:

The initial hidden state \(h_0 = 0\), with shape \((H)\).

Return type:

Array
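As a reference for the computation a GRU cell performs, here is a NumPy sketch of the update equations from Cho et al. (2014). Gate conventions and weight layout vary between implementations, so this is an illustration of the math, not inox's actual code; `gru_step` and its parameter names are assumptions.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(h, x, params):
    """One GRU update (h_{i-1}, x_i) -> (h_i, h_i).

    Follows the formulation of Cho et al. (2014); the interpolation
    direction of the update gate differs in some implementations.
    """
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = params
    z = sigmoid(Wz @ x + Uz @ h + bz)               # update gate
    r = sigmoid(Wr @ x + Ur @ h + br)               # reset gate
    h_tilde = np.tanh(Wh @ x + Uh @ (r * h) + bh)   # candidate state
    h_new = z * h + (1.0 - z) * h_tilde
    return h_new, h_new

rng = np.random.default_rng(0)
C, H = 3, 4
params = [rng.standard_normal(s) * 0.1
          for s in [(H, C), (H, H), (H,)] * 3]
h, y = gru_step(np.zeros(H), rng.standard_normal(C), params)
```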

class inox.nn.recurrent.LSTMCell(in_features, hid_features, bias=True, key=None)#

Creates a long short-term memory (LSTM) cell.

References

Long Short-Term Memory (Hochreiter et al., 1997)
Parameters:
  • in_features (int) – The number of input features \(C\).

  • hid_features (int) – The number of hidden features \(H\).

  • bias (bool) – Whether the cell learns additive biases.

  • key (Array) – A PRNG key for initialization. If None, inox.random.get_rng is used instead.

__call__(hc, x)#
Parameters:
  • hc (Tuple[Array, Array]) – The previous hidden and cell states \((h_{i-1}, c_{i-1})\), each with shape \((*, H)\).

  • x (Array) – The input vector \(x_i\), with shape \((*, C)\).

Returns:

The new state and output \(((h_i, c_i), h_i)\). For an LSTM cell, the output is the hidden state \(h_i\).

Return type:

Tuple[Tuple[Array, Array], Array]

init()#
Returns:

The initial hidden and cell states \(h_0 = c_0 = 0\), each with shape \((H)\).

Return type:

Tuple[Array, Array]
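Similarly, the textbook LSTM update can be sketched in NumPy as below (Hochreiter & Schmidhuber, 1997, with the now-standard forget gate). Weight layout and minor details may differ from inox's implementation; `lstm_step` and its parameter names are assumptions for illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(hc, x, params):
    """One LSTM update ((h_{i-1}, c_{i-1}), x_i) -> ((h_i, c_i), h_i)."""
    h, c = hc
    Wi, Ui, bi, Wf, Uf, bf, Wo, Uo, bo, Wg, Ug, bg = params
    i = sigmoid(Wi @ x + Ui @ h + bi)   # input gate
    f = sigmoid(Wf @ x + Uf @ h + bf)   # forget gate
    o = sigmoid(Wo @ x + Uo @ h + bo)   # output gate
    g = np.tanh(Wg @ x + Ug @ h + bg)   # cell candidate
    c_new = f * c + i * g               # cell state update
    h_new = o * np.tanh(c_new)          # hidden state / output
    return (h_new, c_new), h_new

rng = np.random.default_rng(0)
C, H = 3, 4
params = [rng.standard_normal(s) * 0.1
          for s in [(H, C), (H, H), (H,)] * 4]
h0, c0 = np.zeros(H), np.zeros(H)       # h_0 = c_0 = 0, cf. init()
(h, c), y = lstm_step((h0, c0), rng.standard_normal(C), params)
```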