import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F
from einops import rearrange
from matplotlib import pyplot as plt

class WindowSpectralAttn(nn.Module):
    """Spectral (channel-wise) self-attention restricted to local windows.

    The feature map is cut into non-overlapping ``window_size`` x
    ``window_size`` tiles. Inside every tile and head, a ``d x d`` attention
    map between channels is built from L2-normalised queries and keys and
    applied to the values. A depth-wise convolutional branch computed from
    the value projection is added to the projected attention output.

    Args:
        dim: Number of input and output channels.
        dim_head: Channels per attention head.
        heads: Number of attention heads.
        window_size: Side length of the square attention window.

    Note:
        ``pos_emb`` is built for ``dim`` channels but is fed the value
        projection, which has ``dim_head * heads`` channels, so ``forward``
        only works when ``dim == dim_head * heads``.
    """

    def __init__(
            self,
            dim,
            dim_head,
            heads,
            window_size=8
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        # Single 1x1 convolution producing q, k and v stacked on the channel axis.
        self.to_qkv = nn.Conv2d(dim, dim_head * heads * 3, 1, bias=False)
        # Learnable per-head scale of the attention logits, initialised to 1.
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1, 1))
        self.proj = nn.Conv2d(dim_head * heads, dim, 1, bias=True)
        # Depth-wise 3x3 -> GELU -> depth-wise 3x3 branch applied to the values.
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim
        self.window_size = window_size

    def forward(self, x_in):
        """
        x_in: [b,c,h,w]
        return out: [b,c,h,w]

        h and w do not have to be multiples of ``window_size``: the q/k/v
        projections are zero-padded on the bottom/right up to the next
        multiple and the attention output is cropped back to [h, w].
        """
        _, _, h, w = x_in.shape
        N = self.window_size

        qkv = self.to_qkv(x_in)
        # Keep the un-padded values for the convolutional branch.
        v_inp = torch.chunk(qkv, chunks=3, dim=1)[2]

        # Zero rows neither change the L2 norm over the spatial axis nor add
        # anything to K^T Q, so border windows attend over real pixels only.
        pad_h = (-h) % N
        pad_w = (-w) % N
        if pad_h or pad_w:
            qkv = F.pad(qkv, (0, pad_w, 0, pad_h))
        nH = (h + pad_h) // N
        nW = (w + pad_w) // N

        # Each of q, k, v: [b, heads, windows, N*N, d]
        q, k, v = (
            rearrange(t, 'b (n d) (nh h) (nw w) -> b n (nh nw) (h w) d',
                      n=self.num_heads, h=N, w=N)
            for t in torch.chunk(qkv, chunks=3, dim=1)
        )
        v = v.transpose(-2, -1)              # [b, heads, windows, d, N*N]
        # Normalise every channel over the spatial positions of its window.
        q = F.normalize(q, dim=-2, p=2)
        k = F.normalize(k, dim=-2, p=2)

        attn = k.transpose(-2, -1) @ q       # A = K^T*Q, [b, heads, windows, d, d]
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)

        x = attn @ v                         # [b, heads, windows, d, N*N]
        x = rearrange(x, 'b n (nh nw) d (h w) -> b (n d) (nh h) (nw w)',
                      nh=nH, nw=nW, h=N, w=N)
        if pad_h or pad_w:
            x = x[:, :, :h, :w]              # drop the padded border again

        out_c = self.proj(x)
        out_p = self.pos_emb(v_inp)
        return out_c + out_p


class TopkOperation(nn.Module):
    """Weighted sum of several top-k sparsified softmax attentions.

    For every ratio in ``ratio_list`` only the ``int(dim * ratio)`` largest
    logits along the last axis are kept, the rest are set to ``-inf`` before
    the softmax. Each resulting attention map is applied to ``v`` and scaled
    by its own learnable scalar (initialised to 0.2); the scaled results are
    summed.

    Args:
        dim: Reference size used to turn a ratio into a ``k``.
        ratio_list: Fractions of ``dim`` to keep, one sparsity level each.
    """

    def __init__(self, dim, ratio_list) -> None:
        super().__init__()

        self.dim = dim
        self.ratio_list = ratio_list
        # One learnable mixing weight per sparsity level.
        self.attn_list = nn.ParameterList([
            nn.Parameter(torch.tensor([0.2]), requires_grad=True) for _ in ratio_list
        ])

    def forward(self, attn, v):
        """Apply the top-k attention mixture.

        Args:
            attn: Attention logits; the top-k selection and the softmax run
                over the last axis.
            v: Values, matrix-multiplied from the right (``attn_i @ v``).

        Returns:
            Sum over all ratios of ``(softmax(topk(attn)) @ v) * weight``.
            With an empty ``ratio_list`` the plain integer ``0`` is returned.
        """
        out = 0
        n_cols = attn.shape[-1]
        for weight, ratio in zip(self.attn_list, self.ratio_list):
            # Clamp k: k == 0 would mask every logit (softmax -> NaN) and
            # k > n_cols makes torch.topk raise.
            k = min(max(int(self.dim * ratio), 1), n_cols)
            index = torch.topk(attn, k=k, dim=-1, largest=True)[1]
            # 1 at the k selected positions of every row, 0 elsewhere.
            mask = torch.zeros_like(attn, requires_grad=False).scatter_(-1, index, 1.)
            attn_i = attn.masked_fill(mask == 0, float('-inf')).softmax(dim=-1)
            out = out + (attn_i @ v) * weight

        return out


class Topk_SpectralAttn(nn.Module):
    """Global spectral (channel-wise) self-attention with top-k sparsification.

    Per head, a ``d x d`` attention map between channels is built from
    queries and keys that are L2-normalised over all spatial positions. The
    map is passed, together with the values, through ``TopkOperation`` with
    keep-ratios 1/2, 2/3, 3/4 and 4/5. A depth-wise convolutional branch
    computed from the value projection is added to the projected output.

    Args:
        dim: Number of input and output channels.
        dim_head: Channels per attention head.
        heads: Number of attention heads.

    Note:
        ``pos_emb`` is built for ``dim`` channels but is fed the value
        projection, which has ``dim_head * heads`` channels, so ``forward``
        only works when ``dim == dim_head * heads``.
    """

    def __init__(
            self,
            dim,
            dim_head,
            heads
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        # Single 1x1 convolution producing q, k and v stacked on the channel axis.
        self.to_qkv = nn.Conv2d(dim, dim_head * heads * 3, 1, bias=False)
        # Learnable per-head scale of the attention logits, initialised to 1.
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Conv2d(dim_head * heads, dim, 1, bias=True)
        # Depth-wise 3x3 -> GELU -> depth-wise 3x3 branch applied to the values.
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim
        self.topk = TopkOperation(self.dim_head, ratio_list=[1/2, 2/3, 3/4, 4/5])

    def forward(self, x_in):
        """
        x_in: [b,c,h,w]
        return out: [b,c,h,w]
        """
        b, _, h, w = x_in.shape
        qkv = self.to_qkv(x_in)
        q_inp, k_inp, v_inp = torch.chunk(qkv, chunks=3, dim=1)

        # Split heads and flatten space in one step: [b, heads, d, h*w].
        # The layout is derived from the projections themselves rather than
        # from the input channel count.
        q, k, v = (
            rearrange(t, 'b (n d) h w -> b n d (h w)', n=self.num_heads)
            for t in (q_inp, k_inp, v_inp)
        )
        # Normalise every channel over all spatial positions.
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)

        attn = k @ q.transpose(-2, -1)       # A = K^T*Q, [b, heads, d, d]
        attn = attn * self.rescale

        x = self.topk(attn, v)               # [b, heads, d, h*w]
        x = x.reshape(b, self.num_heads * self.dim_head, h, w)

        out_c = self.proj(x)
        out_p = self.pos_emb(v_inp)
        return out_c + out_p
    

class Topk_WindowSpectralAttn(nn.Module):
    """Windowed spectral (channel-wise) self-attention with top-k sparsification.

    The feature map is cut into non-overlapping ``window_size`` x
    ``window_size`` tiles. Inside every tile and head, a ``d x d`` attention
    map between channels is built from L2-normalised queries and keys and
    passed, together with the values, through ``TopkOperation`` with
    keep-ratios 1/2, 2/3, 3/4 and 4/5. A depth-wise convolutional branch
    computed from the value projection is added to the projected output.

    Args:
        dim: Number of input and output channels.
        dim_head: Channels per attention head.
        heads: Number of attention heads.
        window_size: Side length of the square attention window.

    Note:
        ``pos_emb`` is built for ``dim`` channels but is fed the value
        projection, which has ``dim_head * heads`` channels, so ``forward``
        only works when ``dim == dim_head * heads``.
    """

    def __init__(
            self,
            dim,
            dim_head,
            heads,
            window_size=8
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        # Single 1x1 convolution producing q, k and v stacked on the channel axis.
        self.to_qkv = nn.Conv2d(dim, dim_head * heads * 3, 1, bias=False)
        # Learnable per-head scale of the attention logits, initialised to 1.
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1, 1))
        self.proj = nn.Conv2d(dim_head * heads, dim, 1, bias=True)
        # Depth-wise 3x3 -> GELU -> depth-wise 3x3 branch applied to the values.
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim
        self.window_size = window_size
        self.topk = TopkOperation(self.dim_head, ratio_list=[1/2, 2/3, 3/4, 4/5])

    def forward(self, x_in):
        """
        x_in: [b,c,h,w]
        return out: [b,c,h,w]

        h and w do not have to be multiples of ``window_size``: the q/k/v
        projections are zero-padded on the bottom/right up to the next
        multiple and the attention output is cropped back to [h, w].
        """
        _, _, h, w = x_in.shape
        N = self.window_size

        qkv = self.to_qkv(x_in)
        # Keep the un-padded values for the convolutional branch.
        v_inp = torch.chunk(qkv, chunks=3, dim=1)[2]

        # Zero rows neither change the L2 norm over the spatial axis nor add
        # anything to K^T Q, so border windows attend over real pixels only.
        pad_h = (-h) % N
        pad_w = (-w) % N
        if pad_h or pad_w:
            qkv = F.pad(qkv, (0, pad_w, 0, pad_h))
        nH = (h + pad_h) // N
        nW = (w + pad_w) // N

        # Each of q, k, v: [b, heads, windows, N*N, d]
        q, k, v = (
            rearrange(t, 'b (n d) (nh h) (nw w) -> b n (nh nw) (h w) d',
                      n=self.num_heads, h=N, w=N)
            for t in torch.chunk(qkv, chunks=3, dim=1)
        )
        v = v.transpose(-2, -1)              # [b, heads, windows, d, N*N]
        # Normalise every channel over the spatial positions of its window.
        q = F.normalize(q, dim=-2, p=2)
        k = F.normalize(k, dim=-2, p=2)

        attn = k.transpose(-2, -1) @ q       # A = K^T*Q, [b, heads, windows, d, d]
        attn = attn * self.rescale

        x = self.topk(attn, v)               # [b, heads, windows, d, N*N]
        x = rearrange(x, 'b n (nh nw) d (h w) -> b (n d) (nh h) (nw w)',
                      nh=nH, nw=nW, h=N, w=N)
        if pad_h or pad_w:
            x = x[:, :, :h, :w]              # drop the padded border again

        out_c = self.proj(x)
        out_p = self.pos_emb(v_inp)
        return out_c + out_p
    



if __name__ == "__main__":
    # Profiling entry point: report parameter and FLOP tables for one
    # Topk_WindowSpectralAttn layer on a random 2x28x256x256 input.
    from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count_table

    attn_layer = Topk_WindowSpectralAttn(dim=28, dim_head=28, heads=1)
    dummy_input = torch.randn(2, 28, 256, 256)

    flop_analysis = FlopCountAnalysis(attn_layer, dummy_input)
    print(parameter_count_table(attn_layer))
    print(flop_count_table(flop_analysis))