inox.nn.attention
Attention layers
Classes
MultiheadAttention | Creates a multihead attention layer.
Descriptions
- class inox.nn.attention.MultiheadAttention(in_features, out_features=None, hid_features=None, heads=1, bias=True, causal=False, dropout=0.0, key=None)
Creates a multihead attention layer.
\[Y = \sum_i \mathrm{attention}(X_q W_q^i + b_q^i, X_k W_k^i + b_k^i, X_v W_v^i + b_v^i) W_y^i\]
where
\[\mathrm{attention}(Q, K, V) = \mathrm{softmax}\left( \frac{Q K^T}{\sqrt{H}} \right) V\]
denotes the scaled dot-product attention.
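The two formulas above can be sketched in plain JAX as follows. This is only an illustration of the math, not the layer's actual implementation: the helper names `attention` and `multihead`, the stacked weight layout `(N, C, H)`, and the sizes are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v):
    """Scaled dot-product attention for q: (S, H), k: (T, H), v: (T, H)."""
    h = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(h)             # (S, T)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v                         # (S, H)


def multihead(xq, xk, xv, wq, bq, wk, bk, wv, bv, wy):
    """Sum over heads i of attention(...) @ W_y^i.

    Assumed layout: wq, wk, wv: (N, C, H); bq, bk, bv: (N, H); wy: (N, H, C').
    """
    y = 0.0
    for i in range(wq.shape[0]):
        q = xq @ wq[i] + bq[i]
        k = xk @ wk[i] + bk[i]
        v = xv @ wv[i] + bv[i]
        y = y + attention(q, k, v) @ wy[i]
    return y


# Hypothetical sizes: S queries, T keys, C features, N heads, H = C / N.
S, T, C, N = 3, 5, 8, 2
H = C // N
keys = jax.random.split(jax.random.PRNGKey(0), 10)
xq = jax.random.normal(keys[0], (S, C))
xk = jax.random.normal(keys[1], (T, C))
xv = jax.random.normal(keys[2], (T, C))
wq, wk, wv = (jax.random.normal(k, (N, C, H)) for k in keys[3:6])
bq, bk, bv = (jax.random.normal(k, (N, H)) for k in keys[6:9])
wy = jax.random.normal(keys[9], (N, H, C))

y = multihead(xq, xk, xv, wq, bq, wk, bk, wv, bv, wy)
print(y.shape)  # (3, 8)
```

The explicit loop over heads mirrors the sum over \(i\) in the formula; a practical implementation would more likely batch the heads into a single tensor contraction.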
References
Attention Is All You Need (Vaswani et al., 2023)
- Parameters:
in_features (int) – The number of input features \(C\).
out_features (int) – The number of output features \(C'\). If None, \(C' = C\).
hid_features (int) – The number of hidden features \(H\) per head. If None, \(H = \frac{C}{N}\).
heads (int) – The number of attention heads \(N\).
bias (bool) – Whether the layer learns additive biases \((b_q, b_k, b_v)\) or not.
causal (bool) – Whether the attention mask is causal or not. If True, the \(i\)-th query is only allowed to attend to the \(j\)-th key if \(j - i \leq T - S\), where \(S\) and \(T\) are the numbers of queries and keys.
dropout (float | Array) – The dropout rate on attention weights.
key (Array) – A PRNG key for initialization. If None, inox.random.get_rng("init") is used instead.
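To make the causal condition \(j - i \leq T - S\) concrete, here is a minimal sketch that builds the corresponding boolean mask for \(S\) queries and \(T\) keys. The helper name `causal_mask` is hypothetical and is not part of the documented API.

```python
import jax.numpy as jnp


def causal_mask(S, T):
    """Boolean (S, T) mask; entry (i, j) is True when j - i <= T - S."""
    i = jnp.arange(S)[:, None]  # query indices, as a column
    j = jnp.arange(T)[None, :]  # key indices, as a row
    return j - i <= T - S


print(causal_mask(3, 3).astype(jnp.int32))
# [[1 0 0]
#  [1 1 0]
#  [1 1 1]]
print(causal_mask(2, 4).astype(jnp.int32))
# [[1 1 1 0]
#  [1 1 1 1]]
```

With \(S = T\) the condition reduces to the usual lower-triangular mask. With \(S < T\) the mask is aligned to the bottom-right corner, so the last query can attend to every key; one plausible reading is that the queries correspond to the last \(S\) positions of a length-\(T\) sequence.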
- __call__(xq, xk=None, xv=None, mask=None, key=None)
- Parameters:
xq (Array) – The query tensor \(X_q\), with shape \((*, S, C)\).
xk (Array) – The key tensor \(X_k\), with shape \((*, T, C)\). If None, \(X_k = X_q\).
xv (Array) – The value tensor \(X_v\), with shape \((*, T, C)\). If None, \(X_v = X_k\).
mask (Array) – A boolean attention mask, with shape \((*, N, S, T)\). A False value indicates that the corresponding attention weight is set to \(-\infty\).
key (Array) – A PRNG key. If None, dropout is not applied.
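The effect of the mask and key arguments on the attention weights can be sketched as below. This is an illustration under stated assumptions, not the library's code: the function name `attention_weights` is hypothetical, the \(-\infty\) is read as applying to the pre-softmax scores (so the resulting weight is zero), and dropout is assumed to rescale the kept weights by \(1 / (1 - p)\), which is the common convention.

```python
import jax
import jax.numpy as jnp


def attention_weights(scores, mask=None, dropout=0.0, key=None):
    """scores: pre-softmax logits with shape (*, N, S, T)."""
    if mask is not None:
        # False entries are pushed to -inf, so softmax gives them zero weight.
        scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    if key is not None and dropout > 0.0:
        # Assumed convention: drop with probability `dropout`, rescale the rest.
        keep = jax.random.bernoulli(key, 1.0 - dropout, weights.shape)
        weights = jnp.where(keep, weights / (1.0 - dropout), 0.0)
    return weights


scores = jnp.zeros((1, 2, 4))  # N = 1, S = 2, T = 4
mask = jnp.array([[True, True, False, False],
                  [True, True, True, True]])  # broadcast over the head axis

w = attention_weights(scores, mask)
# Row 0 puts weight 0.5 on each of the two visible keys and 0 on the others;
# row 1 puts weight 0.25 on every key.

w_drop = attention_weights(scores, mask, dropout=0.5, key=jax.random.PRNGKey(0))
# Each entry of w_drop is either 0 or twice the corresponding entry of w.
```

Note that a query row whose mask is entirely False would produce NaN weights in this sketch, since the softmax of a row of \(-\infty\) is undefined.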
- Returns:
The output tensor \(Y\), with shape \((*, S, C')\).
- Return type: