inox.nn.attention#
Attention layers
Classes#
MultiheadAttention | Creates a multihead attention layer.
Descriptions#
- class inox.nn.attention.MultiheadAttention(heads, in_features, out_features=None, hid_features=None, bias=True, causal=False, dropout=0.0, key=None)#
Creates a multihead attention layer.
\[Y = \sum_i \mathrm{attention}(X_q W_q^i + b_q^i, X_k W_k^i + b_k^i, X_v W_v^i + b_v^i) W_y^i\]
where
\[\mathrm{attention}(Q, K, V) = \mathrm{softmax}\left( \frac{Q K^T}{\sqrt{H}} \right) V\]
denotes the scaled dot-product attention.
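For concreteness, here is a minimal JAX sketch of this formula for a single head (the helper name is illustrative and not part of inox):

```python
import jax
import jax.numpy as jnp

def scaled_dot_product_attention(q, k, v):
    # weights = softmax(Q K^T / sqrt(H)), with H the feature dimension of Q and K
    weights = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)
    # output = weights V
    return weights @ v

q = k = v = jax.random.normal(jax.random.PRNGKey(0), (16, 32))  # S = 16, H = 32
y = scaled_dot_product_attention(q, k, v)
assert y.shape == (16, 32)
```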
References
Attention Is All You Need (Vaswani et al., 2017)
- Parameters:
heads (int) – The number of attention heads \(N\).
in_features (int) – The number of input features \(C\).
out_features (int) – The number of output features \(C'\). If None, \(C' = C\).
hid_features (int) – The number of hidden features \(H\) per head. If None, \(H = \frac{C}{N}\).
bias (bool) – Whether the layer learns additive biases \((b_q, b_k, b_v)\) or not.
causal (bool) – Whether the attention mask is causal or not. If True, the \(i\)-th query is only allowed to attend the \(j\)-th key if \(j - i \leq T - S\).
dropout (float | Array) – The dropout rate on attention weights.
key (Array) – A PRNG key for initialization. If None, inox.random.get_rng is used instead.
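A minimal construction sketch, assuming the class is imported from its documented module path; the hyperparameter values are illustrative:

```python
import jax
from inox.nn.attention import MultiheadAttention

# N = 8 heads over C = 256 input features; H defaults to C / N = 32 per head.
attn = MultiheadAttention(
    heads=8,
    in_features=256,
    causal=True,  # each query only attends to keys that are not "in its future"
    dropout=0.1,  # dropout rate on the attention weights
    key=jax.random.PRNGKey(0),
)
```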
- __call__(xq, xk=None, xv=None, mask=None, key=None)#
- Parameters:
xq (Array) – The query tensor \(X_q\), with shape \((*, S, C)\).
xk (Array | None) – The key tensor \(X_k\), with shape \((*, T, C)\). If None, \(X_k = X_q\).
xv (Array | None) – The value tensor \(X_v\), with shape \((*, T, C)\). If None, \(X_v = X_k\).
mask (Array | None) – A boolean attention mask, with shape \((*, S, T)\). A False value indicates that the corresponding attention weight is set to \(-\infty\).
key (Array | None) – A PRNG key. If None, dropout is not applied.
- Returns:
The output tensor \(Y\), with shape \((*, S, C')\).
- Return type:
Array
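A usage sketch of the call signature, under the same import assumption as above; shapes follow the documented conventions:

```python
import jax
from inox.nn.attention import MultiheadAttention

attn = MultiheadAttention(heads=4, in_features=64, key=jax.random.PRNGKey(0))

xq = jax.random.normal(jax.random.PRNGKey(1), (2, 16, 64))  # (*, S, C)
y = attn(xq)  # self-attention: X_k and X_v default to X_q
assert y.shape == (2, 16, 64)  # (*, S, C') with C' = C since out_features is None
```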