It’s well known that Multi-Head Attention is permutation equivariant (see, e.g., here). Let’s verify this in PyTorch.

import torch
from torch import nn

batch_size = 16
seq_length = 10
embed_dim = 384
n_heads = 8

attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
X = torch.rand(batch_size, seq_length, embed_dim)
o = torch.randperm(seq_length)  # random permutation of the sequence positions
z1, _ = attn(X, X, X)                    # self-attention on the original sequence
z2, _ = attn(X[:, o], X[:, o], X[:, o])  # self-attention on the permuted sequence
print(torch.allclose(z1[:, o], z2))      # equivariance: permuting z1 should match z2

Almost certainly, this will print False. What’s going wrong? PyTorch uses torch.float32 by default, and floating-point addition is not associative: permuting the sequence changes the order in which the attention and projection layers accumulate their sums, so z1[:, o] and z2 differ by tiny rounding errors that exceed torch.allclose’s default tolerance. Let’s increase the precision to torch.float64:

import torch
from torch import nn

batch_size = 16
seq_length = 10
embed_dim = 384
n_heads = 8

# Same experiment, but with the module and the inputs in double precision.
attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True).to(torch.float64)
X = torch.rand(batch_size, seq_length, embed_dim, dtype=torch.float64)
o = torch.randperm(seq_length)
z1, _ = attn(X, X, X)
z2, _ = attn(X[:, o], X[:, o], X[:, o])
print(torch.allclose(z1[:, o], z2))

It should print True now: in float64 the rounding errors are small enough to fall within torch.allclose’s default tolerance.
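
Alternatively, you can stay in float32, measure the mismatch directly, and compare with a looser tolerance. Below is a minimal sketch; the atol=1e-5 value is an assumption chosen to sit comfortably above the typical rounding noise, whose exact magnitude depends on the random weights and inputs.

import torch
from torch import nn

batch_size = 16
seq_length = 10
embed_dim = 384
n_heads = 8

attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
X = torch.rand(batch_size, seq_length, embed_dim)
o = torch.randperm(seq_length)
z1, _ = attn(X, X, X)
z2, _ = attn(X[:, o], X[:, o], X[:, o])

# How large is the float32 mismatch? Typically on the order of 1e-6.
print((z1[:, o] - z2).abs().max())
# With a tolerance above that noise level, the check passes.
print(torch.allclose(z1[:, o], z2, atol=1e-5))  # atol=1e-5 is an assumed, looser tolerance

Both approaches make the same point: the mismatch is purely numerical rounding, not a violation of permutation equivariance.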