import torch
from torch import nn, einsum


def linear_attn(q, k, v):
    """Softmax-kernel linear attention (never materialises an n x n matrix).

    Queries are softmax-normalised over their feature axis and keys over
    their sequence axis. Keys and values are contracted first into a
    per-head ``(d, e)`` context, which every query then reads from, so the
    cost grows linearly with sequence length.

    Args:
        q: Queries, shape ``(b, h, n_q, d)``.
        k: Keys, shape ``(b, h, n_k, d)``.
        v: Values, shape ``(b, h, n_k, e)``; ``e`` need not equal ``d``.

    Returns:
        Tensor of shape ``(b, h, n_q, e)``.
    """
    dim = q.shape[-1]

    q = q.softmax(dim=-1)
    k = k.softmax(dim=-2)

    # 1/sqrt(d) scaling; applied after the softmax, so it rescales the
    # output rather than sharpening the query distribution.
    q = q * dim ** -0.5

    context = einsum('bhnd,bhne->bhde', k, v)
    # The einsum already yields (b, h, n_q, e). The former
    # ``reshape(*q.shape)`` was a no-op when e == d and raised when e != d.
    return einsum('bhnd,bhde->bhne', q, context)


class SelfAttention(nn.Module):
    """Multi-head self-attention whose per-head mixing is done by ``linear_attn``.

    Args:
        dim: Feature size of the input and of the output.
        heads: Number of attention heads. Each head gets ``dim // heads``
            channels; if ``dim`` is not divisible by ``heads`` the q/k/v
            projections use ``(dim // heads) * heads`` inner channels and
            ``to_out`` maps back to ``dim``.
        dropout: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``heads`` is not in ``1..dim`` (each head would
            otherwise have zero or negative width).
    """

    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()
        if not 0 < heads <= dim:
            raise ValueError(f"heads must be in [1, {dim}], got {heads}")
        d_heads = dim // heads
        inner_dim = d_heads * heads

        self.heads = heads
        self.d_heads = d_heads

        # Attention kernel applied to (b, heads, t, d_heads) q/k/v tensors.
        self.global_attn_fn = linear_attn

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape ``(b, t, dim)``; returns ``(b, t, dim)``."""
        b, t, _ = x.shape
        dh = self.d_heads

        def split_heads(tensor):
            # (b, t, heads * dh) -> (b, heads, t, dh)
            return tensor.reshape(b, t, -1, dh).transpose(1, 2)

        q, k, v = (split_heads(proj(x)) for proj in (self.to_q, self.to_k, self.to_v))

        attn = self.global_attn_fn(q, k, v)
        # (b, heads, t, dh) -> (b, t, heads * dh) for the output projection.
        attn = attn.transpose(1, 2).reshape(b, t, -1)
        return self.dropout(self.to_out(attn))
