import torch
import torch.nn as nn
from einops import rearrange


class CrossAttn(nn.Module):
    """
    Cross attention between an image feature map and a sequence of text tokens.

    Queries come from the image through a strided convolution whose kernel
    size and stride both equal ``scale_factor``. Each non-overlapping
    ``scale_factor x scale_factor`` patch therefore yields a single query.
    Keys and values are bias-free linear projections of the text tokens.
    No attention mask is applied: every query attends to every text token.

    Args:
        dim: channel count of the image features; also the output feature size.
        dim_text: feature size of each text token.
        scale_factor: kernel size and stride of the query convolution.
        heads: number of attention heads.
        dim_head: feature size of each head.
    """

    def __init__(self, dim, dim_text, scale_factor=4, heads=16, dim_head=64):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.scale = dim_head**-0.5
        # Keep the creation order to_q -> to_kv -> to_out. Parameter
        # initialisation consumes the global RNG in this order.
        self.to_q = nn.Conv2d(
            dim, inner, kernel_size=scale_factor, stride=scale_factor, bias=False
        )
        self.to_kv = nn.Linear(dim_text, 2 * inner, bias=False)
        self.to_out = nn.Linear(inner, dim)

    def forward(self, x, text):
        """
        Attend from downsampled image positions to text tokens.

        Args:
            x: image features of shape (batch, dim, height, width).
            text: text features of shape (batch, tokens, dim_text).

        Returns:
            Tensor of shape (batch, (height // s) * (width // s), dim), where
            s = scale_factor. Rows and columns that do not fill a whole patch
            are dropped by the convolution. Positions are in row-major order.
        """
        heads = self.heads

        def split_heads(t):
            # (batch, n, heads * d) -> (batch, heads, n, d)
            bsz, n, width = t.shape
            return t.reshape(bsz, n, heads, width // heads).transpose(1, 2)

        keys, values = self.to_kv(text).chunk(2, dim=-1)
        keys = split_heads(keys)
        values = split_heads(values)

        queries = self.to_q(x)
        bsz, channels, qh, qw = queries.shape
        # (batch, heads * d, qh, qw) -> (batch, heads, qh * qw, d)
        queries = queries.reshape(bsz, heads, channels // heads, qh * qw).transpose(2, 3)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = torch.softmax(scores, dim=-1)
        mixed = weights @ values

        # (batch, heads, positions, d) -> (batch, positions, heads * d)
        bsz, n_heads, positions, d = mixed.shape
        mixed = mixed.transpose(1, 2).reshape(bsz, positions, n_heads * d)
        return self.to_out(mixed)


class CrossAttnImage(nn.Module):
    """
    Per-pixel cross attention from an image feature map to text tokens.

    Queries come from a 1x1, stride-1 convolution, so every spatial position
    of the input yields one query and the resolution is not reduced.
    Keys and values are bias-free linear projections of the text tokens.
    No attention mask is applied: every query attends to every text token.

    Args:
        dim: channel count of the image features; also the output feature size.
        dim_text: feature size of each text token.
        heads: number of attention heads.
        dim_head: feature size of each head.
    """

    def __init__(self, dim, dim_text, heads=16, dim_head=64):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_q = nn.Conv2d(dim, hidden_dim, 1, bias=False, stride=1)
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def _split_heads(self, t):
        """Reshape (batch, n, heads * d) into (batch, heads, n, d)."""
        b, n, hd = t.shape
        return t.reshape(b, n, self.heads, hd // self.heads).transpose(1, 2)

    def forward(self, x, text):
        """
        Attend from every pixel of ``x`` to the tokens of ``text``.

        Args:
            x: image features of shape (batch, dim, height, width).
            text: text features of shape (batch, tokens, dim_text).

        Returns:
            Tensor of shape (batch, height * width, dim): one attended feature
            vector per pixel, in row-major (height, width) order.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        # Conv2d would silently accept an unbatched 3-D input, so reject it
        # explicitly instead of failing later on an obscure reshape.
        if x.dim() != 4:
            raise ValueError(
                f"expected x of shape (batch, dim, height, width), got {tuple(x.shape)}"
            )

        k, v = map(self._split_heads, self.to_kv(text).chunk(2, dim=-1))

        q = self.to_q(x)
        b, hd, height, width = q.shape
        # (batch, heads * d, H, W) -> (batch, heads, H * W, d)
        q = q.reshape(b, self.heads, hd // self.heads, height * width).transpose(2, 3)

        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, v)

        # (batch, heads, H * W, d) -> (batch, H * W, heads * d)
        b, heads, n, d = out.shape
        out = out.transpose(1, 2).reshape(b, n, heads * d)
        return self.to_out(out)
