複素値Attention
一言でいうと: 複素値Attentionは、複素値信号の位相と振幅の構造を保ってTransformerの類似度計算を行うために、共役内積や複素softmaxの扱いを明示的に設計するAttention機構である。
複素値Attentionは、query、key、valueを複素数として扱うAttention機構である。 複素数では と共役内積 が異なるため、どの量を類似度に使うかが設計上の論点になる。
複素値Transformer構成要素は、複素内積の実部を使うCAttを主提案とする。 この類似度は、同じ角度回転をqueryとkeyに加えても変わらないため、複素信号の回転構造と整合する。
比較候補として、内積の絶対値、絶対値と位相、実部と虚部の別softmaxを使う方式も評価される。 MusicNetの分類と系列生成では、複素内積ベースのCAttが安定して高い結果を示した。
ComplexOrlicz撤回プレプリントではCauchy-Riemann条件や複素平面上の直交性が言及されるが、Attention機構そのものではなく異分散回帰の損失設計に関する主張である。 撤回済みプレプリント由来のため、検証済み知見とは分けて参照する。
実装メモ
複素値Transformer構成要素のPDFには、公式実装として ag-pria/cv-transformer が記載されている。
複素値Multi-head Attentionの実装は attention/mha.py にある。
実装上の分かりづらい点は、積の種類とsoftmaxの種類が別々に切り替えられることである。
sm_variante の末尾2文字が cp のときは通常の に相当する積を使い、ip のときは torch.conj_physical(k) によりkeyを共役にしてから積を取る。
論文で主提案とされるCAttに対応するのは、共役内積側のスコアから実部だけを取り出し、その実数値にsoftmaxをかける処理である。
公式実装では、積の切り替えは次のように書かれている。
if self.product == 'cp':
attn_weights = torch.bmm(q, k.transpose(1, 2)) * self.scaling
elif self.product == 'ip':
attn_weights = torch.bmm(q, torch.conj_physical(k).transpose(1, 2)) * self.scaling出典: attention/mha.py
softmax variantも数式と対応している。
real は にsoftmaxをかける。
abs は にsoftmaxをかけた後、torch.sgn(input) を掛けて位相を戻すため、APAttに近い。
absonly は絶対値softmaxの重みだけを複素数型へ変換するため、位相を戻さないAAttに近い。
naiv は実部と虚部に別々のsoftmaxをかけて複素重みを作るため、RIAttに対応する。
対応する実装断片は次である。
def softmax_abs(self, input, attn_mask=None):
abso = torch.abs(input)
if attn_mask is not None:
abso += attn_mask.unsqueeze(0).real.to(self.device)
return softmax(abso, dim=-1).type(torch.complex64) * torch.sgn(input)
def softmax_real(self, input, attn_mask=None):
real = torch.real(input)
if attn_mask is not None:
real += attn_mask.unsqueeze(0).real.to(self.device)
return softmax(real, dim=-1).type(torch.complex64)出典: attention/mha.py
maskの扱いも実装では注意が必要である。
real と absonly と abs では、maskの実部だけを実数スコアに足す。
naiv では実部maskと虚部maskをそれぞれ足してから、実部softmaxと虚部softmaxを別々に計算する。
したがって、複素値Attentionを再実装するときは、単にPyTorchの実数MultiheadAttentionを複素型に置き換えるだけではなく、スコアを実数化するタイミングとmaskを足す対象を明示する必要がある。