import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

# NOTE(review): not referenced by any code in this module; presumably a leftover
# from a ViT-style model that enforced a minimum patch count. Confirm no importer
# depends on it before removing.
MIN_NUM_PATCHES = 16

class Residual(nn.Module):
    """Skip connection: add the wrapped callable's output to its input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``; extra kwargs are forwarded to ``fn``."""
        branch = self.fn(x, **kwargs)
        return branch + x

class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then call the wrapped ``fn``.

    Args:
        dim: Size of the last dimension normalised by ``nn.LayerNorm``.
        fn: Callable applied to the normalised tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension.

    Computes ``Linear(dim -> hidden_dim) -> GELU -> Linear(hidden_dim -> dim)``.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Built as a list and unpacked so the Sequential indices (net.0, net.2)
        # and parameter-creation order stay fixed for checkpoint compatibility.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the output's last dimension is ``dim``."""
        result = self.net(x)
        return result

class CAAttention(nn.Module):
    """Multi-head channel ("transposed") self-attention over a token sequence.

    Attention is computed between channels rather than between tokens. For each
    head, the query/key/value tensors have shape ``(batch, channels // heads,
    tokens)``, so the softmax runs over a ``(channels // heads) x
    (channels // heads)`` matrix. Before the softmax, queries and keys are
    L2-normalised along the token axis, and the logits are scaled by a learnable
    per-head temperature.

    Args:
        channels: Size of the last (feature) dimension of the input.
        heads: Number of attention heads; must be a positive divisor of
            ``channels``.

    Raises:
        ValueError: If ``heads`` is not a positive divisor of ``channels``.
    """

    def __init__(self, channels, heads=4):
        super().__init__()
        # Fail fast at construction instead of with an opaque reshape error
        # on the first forward pass.
        if heads < 1 or channels % heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive number "
                f"of heads (got heads={heads})"
            )
        self.heads = heads
        # Learnable per-head scale; shape (heads, 1, 1) broadcasts over the
        # (batch, heads, d, d) attention logits.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        # 1x1 convolutions act as per-token linear maps over the channel axis.
        self.query_filter = nn.Conv2d(channels, channels, kernel_size=(1, 1))
        self.key_filter = nn.Conv2d(channels, channels, kernel_size=(1, 1))
        self.value_filter = nn.Conv2d(channels, channels, kernel_size=(1, 1))
        self.project_out = nn.Conv2d(channels, channels, kernel_size=(1, 1))

    def forward(self, x):
        """Apply channel attention to ``x`` and add ``x`` back.

        Args:
            x: Tensor of shape ``(batch, tokens, channels)``.

        Returns:
            Tensor with the same shape as ``x``:
            ``project_out(attention(x)) + x``.
        """
        # NOTE(review): this module adds its own residual, and Transformerca also
        # wraps it in Residual(PreNorm(...)), giving attn(norm(x)) + norm(x) + x.
        # Kept as-is because removing either skip would change the outputs of
        # already-trained weights.
        batch, tokens, channels = x.shape
        head_dim = channels // self.heads

        # Conv2d expects (B, C, H, W): move channels first and add a dummy W of 1.
        x_img = x.transpose(1, 2).unsqueeze(-1)  # (B, C, N, 1)

        # (B, C, N, 1) -> (B, heads, C // heads, N). Channel c is assigned to head
        # c // head_dim, the same grouping as einops 'b (head c) h w -> b head c (h w)'.
        q = self.query_filter(x_img).reshape(batch, self.heads, head_dim, tokens)
        k = self.key_filter(x_img).reshape(batch, self.heads, head_dim, tokens)
        v = self.value_filter(x_img).reshape(batch, self.heads, head_dim, tokens)

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature  # (B, heads, d, d)
        attn = attn.softmax(dim=-1)
        out = attn @ v  # (B, heads, d, N)

        # Merge heads back into channels and restore the dummy spatial axis.
        out = out.reshape(batch, channels, tokens, 1)
        out = self.project_out(out)
        out = out.squeeze(-1).transpose(1, 2)  # (B, N, C)
        return out + x

class Transformerca(nn.Module):
    """Stack of ``depth`` layers, each channel attention followed by an MLP.

    Every layer is the pair
    ``[Residual(PreNorm(CAAttention)), Residual(PreNorm(FeedForward))]``,
    applied in that order.

    Args:
        dim: Feature size of the input's last dimension.
        depth: Number of attention + feed-forward layers.
        heads: Attention heads per CAAttention block.
        mlp_dim: Hidden width of each FeedForward block.
    """

    def __init__(self, dim, depth, heads,  mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.ModuleList([
                Residual(PreNorm(dim, CAAttention(dim, heads=heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim))),
            ])
            for _ in range(depth)
        )

    def forward(self, x, mask=None):
        """Run ``x`` through every layer in order and return the result.

        ``mask`` is accepted for interface compatibility but is not used.
        """
        for attn_block, ff_block in self.layers:
            x = ff_block(attn_block(x))
        return x
