inox.nn.attention#

Attention layers

Classes#

MultiheadAttention

Creates a multihead attention layer.

Descriptions#

class inox.nn.attention.MultiheadAttention(heads, in_features, out_features=None, hid_features=None, bias=True, causal=False, dropout=0.0, key=None)#

Creates a multihead attention layer.

\[Y = \sum_i \mathrm{attention}(X_q W_q^i + b_q^i, X_k W_k^i + b_k^i, X_v W_v^i + b_v^i) W_y^i\]

where

\[\mathrm{attention}(Q, K, V) = \mathrm{softmax}\left( \frac{Q K^T}{\sqrt{H}} \right) V\]

denotes the scaled dot-product attention.
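
For concreteness, the scaled dot-product operation above can be written in a few lines of JAX. This is a minimal sketch of the formula itself, not the layer's internal implementation; the function name and shapes are illustrative.

import jax
import jax.numpy as jnp

def scaled_dot_product_attention(q, k, v):
    """Computes softmax(Q K^T / sqrt(H)) V.

    q: (*, S, H), k: (*, T, H), v: (*, T, H) -> output: (*, S, H)
    """
    h = q.shape[-1]
    logits = q @ jnp.swapaxes(k, -2, -1) / jnp.sqrt(h)  # (*, S, T)
    weights = jax.nn.softmax(logits, axis=-1)           # rows sum to 1
    return weights @ v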

References

Attention Is All You Need (Vaswani et al., 2017)
Parameters:
  • heads (int) – The number of attention heads \(N\).

  • in_features (int) – The number of input features \(C\).

  • out_features (int) – The number of output features \(C'\). If None, \(C' = C\).

  • hid_features (int) – The number of hidden features \(H\) per head. If None, \(H = \frac{C}{N}\).

  • bias (bool) – Whether the layer learns additive biases \((b_q, b_k, b_v)\) or not.

  • causal (bool) – Whether the attention mask is causal or not. If True, the \(i\)-th query is only allowed to attend to the \(j\)-th key if \(j - i \leq T - S\).

  • dropout (float | Array) – The dropout rate on attention weights.

  • key (Array) – A PRNG key for initialization. If None, inox.random.get_rng is used instead.
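
As a usage sketch of these parameters, a causal layer could be built as below. The sizes are arbitrary and assume only the signature documented above; they are chosen so that hid_features defaults to \(H = C / N = 16\).

import jax
from inox.nn.attention import MultiheadAttention

key = jax.random.PRNGKey(0)  # explicit PRNG key for parameter initialization

attn = MultiheadAttention(
    heads=4,         # N = 4 attention heads
    in_features=64,  # C = 64 input features
    causal=True,     # queries cannot attend to future positions
    key=key,
)
# out_features defaults to C' = C = 64 and hid_features to H = C / N = 16.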

__call__(xq, xk=None, xv=None, mask=None, key=None)#
Parameters:
  • xq (Array) – The query tensor \(X_q\), with shape \((*, S, C)\).

  • xk (Array | None) – The key tensor \(X_k\), with shape \((*, T, C)\). If None, \(X_k = X_q\).

  • xv (Array | None) – The value tensor \(X_v\), with shape \((*, T, C)\). If None, \(X_v = X_k\).

  • mask (Array | None) – A boolean attention mask, with shape \((*, S, T)\). A False value indicates that the corresponding attention logit is set to \(-\infty\) before the softmax, preventing the query from attending to that key.

  • key (Array | None) – A PRNG key. If None, dropout is not applied.

Returns:

The output tensor \(Y\), with shape \((*, S, C')\).

Return type:

Array
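
Putting construction and call together, a self-attention and a masked cross-attention call could look as follows. This sketch relies only on the shapes documented above; the batch size, sequence lengths, and the all-True mask are illustrative.

import jax
import jax.numpy as jnp
from inox.nn.attention import MultiheadAttention

attn = MultiheadAttention(heads=4, in_features=64, key=jax.random.PRNGKey(0))

# Self-attention: xk and xv default to xq.
x = jnp.ones((8, 16, 64))  # batch of 8 sequences, S = T = 16, C = 64
y = attn(x)
print(y.shape)  # (8, 16, 64), since out_features defaults to C

# Cross-attention with an explicit boolean mask of shape (S, T),
# broadcast over the leading batch axes.
xq = jnp.ones((8, 4, 64))             # S = 4 queries
xk = jnp.ones((8, 16, 64))            # T = 16 keys (values default to keys)
mask = jnp.ones((4, 16), dtype=bool)  # True = attend, False = masked out
y = attn(xq, xk, mask=mask)           # shape (8, 4, 64)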