import torch
import math
from torch import nn


class PerceiverAttentionCA(nn.Module):
    """Cross-attention block: latents act as queries over image features.

    Both inputs are layer-normalised over their last axis, projected to a
    single attention head of width ``dim_head`` and combined with softmax
    attention. The result is projected back to ``dim``.

    Notes:
        * ``heads`` is stored on the module but does not enter any
          computation here; the projection width is ``dim_head`` alone.
        * ``self.scale`` is stored for reference only; ``forward`` derives
          its own scaling factor from ``self.dim_head``.
        * The output projection starts at all zeros, so a freshly built
          module returns a zero tensor.
    """

    def __init__(self, dim=3072, dim_head=1024, heads=33):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads

        # The projection width is dim_head only (not multiplied by heads).
        inner_dim = dim_head

        # norm1 is applied to the image features, norm2 to the latents.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Bias-free projections: queries, fused keys+values, and output.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # to_out has no bias, so zeroing its weight zeroes the whole projection.
        nn.init.zeros_(self.to_out.weight)

    def forward(self, x, latents):
        """Attend from ``latents`` (queries) to ``x`` (keys and values).

        Args:
            x (torch.Tensor): image features; last axis must have size
                ``dim``. Presumably shaped (b, t, aa, D) -- not enforced here.
            latents (torch.Tensor): latent features; last axis must have size
                ``dim``. Presumably shaped (b, t, hw, D) -- not enforced here.

        Returns:
            torch.Tensor: attended latents projected back to ``dim`` on the
            last axis; leading axes follow ``latents``.
        """
        context = self.norm1(x)
        normed_latents = self.norm2(latents)

        queries = self.to_q(normed_latents)
        keys, values = torch.chunk(self.to_kv(context), 2, dim=-1)

        # Apply dim_head ** -0.25 to both operands before the product rather
        # than dividing the scores afterwards: more stable in float16.
        root_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = torch.matmul(
            queries * root_scale,
            (keys * root_scale).transpose(-2, -1),
        )

        # Softmax runs in float32, then is cast back to the scores' dtype.
        probs = torch.softmax(scores.float(), dim=-1).to(scores.dtype)
        attended = torch.matmul(probs, values)

        return self.to_out(attended)
