import torch
import torch.nn as nn
from pathlib import Path
import copy
# from timm.models.vision_transformer import Attention, Block
class Attention(nn.Module):
    """Multi-head self-attention with a single fused q/k/v projection.

    Args:
        dim: embedding size of each input token (also the output size).
        num_heads: number of attention heads; each head has width dim // num_heads.
        qkv_bias: whether the fused qkv projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # Logits are multiplied by 1/sqrt(head width).
        self.scale = per_head ** -0.5

        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of a 3-D input `x` of shape (B, N, C)."""
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, C // heads)
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Index instead of tuple-unpacking the tensor (TorchScript limitation).
        query, key, value = packed[0], packed[1], packed[2]

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = torch.matmul(weights, value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
    
class AttentionHead(torch.nn.Module):
    """A single self-attention head with its own fused q/k/v projection.

    Args:
        dim: size of each input token vector.
        head_dim: width of this head's query, key and value vectors.
        qkv_bias: whether the q/k/v projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: accepted for signature compatibility with the multi-head
            modules; not used by this class.
    """

    def __init__(self, dim, head_dim, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.scale = head_dim ** -0.5
        self.qk_dim = head_dim
        self.v_dim = head_dim

        # Output layout along the last axis: [q | k | v].
        self.qkv = torch.nn.Linear(dim, self.qk_dim * 2 + self.v_dim, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)

    def forward(self, x):
        """Return attention output whose last dimension is `v_dim`.

        The last axis of `x` must be `dim`; the second-to-last axis is treated
        as the token axis that attention mixes over.
        """
        projected = self.qkv(x)
        qk = self.qk_dim
        query = projected[..., :qk]
        key = projected[..., qk:2 * qk]
        value = projected[..., 2 * qk:]

        logits = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))
        return weights @ value
class PrunedAttention_subhead(nn.Module):
    """Multi-head attention built from independent `AttentionHead` modules.

    Each head owns its own q/k/v projection (presumably so heads can be
    pruned individually). Head outputs are concatenated along the last
    dimension and passed through one shared output projection.

    Args:
        dim: embedding size of each input token (also the output size).
        num_heads: number of heads; each has width dim // num_heads.
        qkv_bias: whether each head's q/k/v projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.heads = torch.nn.ModuleList([
            AttentionHead(dim, head_dim, qkv_bias, attn_drop, proj_drop) for _ in range(num_heads)
        ])
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)

    def forward(self, x):
        # Concatenating head outputs in head order reproduces the channel layout
        # the fused Attention produces right before its output projection.
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @classmethod
    def convert_from(cls, attention: "Attention", preserved_indices=None):
        """Build an equivalent per-head module from a fused `Attention`.

        Args:
            attention: source module exposing `num_heads`, `qkv`, `proj`,
                `attn_drop` and `proj_drop`.
            preserved_indices: reserved for pruning; only `None` is supported.

        Returns:
            A new `PrunedAttention_subhead` whose heads hold the corresponding
            rows of the source qkv projection and whose `proj` matches the source.

        Raises:
            NotImplementedError: if `preserved_indices` is not None.
        """
        if preserved_indices is not None:
            raise NotImplementedError(
                "PrunedAttention_subhead.convert_from does not support preserved_indices yet"
            )

        num_heads = attention.num_heads
        dim = attention.qkv.weight.shape[1]
        head_dim = dim // num_heads
        qkv_bias = attention.qkv.bias is not None
        ret = cls(dim, num_heads, qkv_bias, attention.attn_drop.p, attention.proj_drop.p)

        split_sizes = [head_dim] * num_heads
        # The fused weight is (3*dim, dim) stacked as [q; k; v]. Viewing it as
        # (3, dim, dim) lets one split along dim=1 slice out each head's q, k and
        # v rows together, giving (3, head_dim, dim) per head.
        head_ws = attention.qkv.weight.reshape(3, dim, dim).split(split_sizes, dim=1)
        if qkv_bias:
            head_bs = attention.qkv.bias.reshape(3, dim).split(split_sizes, dim=1)
        else:
            head_bs = [None] * num_heads

        # Copy in place so head parameters keep the new module's dtype/device,
        # the same way load_state_dict treats `proj` below.
        with torch.no_grad():
            for head, w, b in zip(ret.heads, head_ws, head_bs):
                head.qkv.weight.copy_(w.flatten(end_dim=1))
                if b is not None:
                    head.qkv.bias.copy_(b.flatten(end_dim=1))
        ret.proj.load_state_dict(attention.proj.state_dict())
        return ret

class PrunedAttention_ind(nn.Module):
    """Fused-qkv multi-head attention with explicit per-head split widths.

    The per-head q/k/v widths are stored as lists (`q_split`, `k_split`,
    `v_split`), and heads are computed one at a time in a Python loop. The loop
    does not assume equal widths across heads, but each head's q and k widths
    must match for the q @ k^T product.

    Args:
        dim: embedding size of each input token (also the output size).
        num_heads: number of heads; each starts with width dim // num_heads.
        qkv_bias: whether the fused qkv projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Per-head widths of the q, k and v sections of the fused projection.
        self.q_split = [head_dim] * num_heads
        self.k_split = [head_dim] * num_heads
        self.v_split = [head_dim] * num_heads
        # Total widths of the q, k and v sections of the qkv output.
        self.qkv_split = [sum(self.q_split), sum(self.k_split), sum(self.v_split)]

    def forward(self, x):
        B, N, C = x.shape  # unpacking rejects inputs that are not 3-D
        q_all, k_all, v_all = self.qkv(x).split(self.qkv_split, dim=-1)
        q_heads = q_all.split(self.q_split, dim=-1)
        k_heads = k_all.split(self.k_split, dim=-1)
        v_heads = v_all.split(self.v_split, dim=-1)

        attn_heads = []
        for q, k, v in zip(q_heads, k_heads, v_heads):
            # Scaling q before the product is mathematically equal to scaling the logits.
            attn = (q * self.scale) @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            attn_heads.append(attn @ v)
        x = torch.cat(attn_heads, dim=-1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @classmethod
    def convert_from(cls, attention: "Attention", preserved_indices=None):
        """Build an equivalent module that reuses copies of `attention`'s projections.

        Args:
            attention: source module exposing `num_heads`, `qkv`, `proj`,
                `attn_drop` and `proj_drop`. If it has an `activation`
                attribute, `qkv`/`proj` are taken from that instead
                (presumably a wrapper around the real attention — TODO confirm).
            preserved_indices: reserved for pruning; only `None` is supported.
                Accepted for signature parity with `PrunedAttention_subhead`.

        Returns:
            A new `PrunedAttention_ind` holding deep copies of the source
            `qkv` and `proj` modules.

        Raises:
            NotImplementedError: if `preserved_indices` is not None.
        """
        if preserved_indices is not None:
            raise NotImplementedError(
                "PrunedAttention_ind.convert_from does not support preserved_indices yet"
            )

        num_heads = attention.num_heads
        dim = attention.qkv.weight.shape[1]
        qkv_bias = attention.qkv.bias is not None
        ret = cls(dim, num_heads, qkv_bias, attention.attn_drop.p, attention.proj_drop.p)

        source = getattr(attention, 'activation', attention)
        ret.register_module('qkv', copy.deepcopy(source.qkv))
        ret.register_module('proj', copy.deepcopy(source.proj))
        return ret

class PrunedAttention_spatial(nn.Module):
    """Multi-head attention with configurable per-head q/k width and v width.

    Args:
        dim: embedding size of each input token (also the output size).
        qk_size: per-head width of the query and key vectors.
        v_size: per-head width of the value vectors.
        num_heads: number of attention heads.
        qkv_bias: whether the fused qkv projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, qk_size, v_size, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.qk_size = qk_size
        self.v_size = v_size
        head_dim = dim // num_heads
        # NOTE: the scale is derived from dim // num_heads, not from qk_size, so
        # it equals the source Attention's scale even if qk_size is reduced.
        self.scale = head_dim ** -0.5
        # Output layout: [q (H*qk_size) | k (H*qk_size) | v (H*v_size)].
        self.qkv = nn.Linear(dim, num_heads * (qk_size * 2 + v_size), bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # forward feeds the concatenated values of all heads into proj, so its
        # input width is num_heads * v_size (not v_size).
        self.proj = nn.Linear(num_heads * v_size, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, _ = x.shape
        H = self.num_heads
        q, k, v = self.qkv(x).split(
            [self.qk_size * H, self.qk_size * H, self.v_size * H], dim=2)

        # (B, N, H * size) -> (B, H, N, size)
        q = q.reshape(B, N, H, self.qk_size).transpose(1, 2)
        k = k.reshape(B, N, H, self.qk_size).transpose(1, 2)
        v = v.reshape(B, N, H, self.v_size).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, H * self.v_size)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @classmethod
    def convert_from(cls, attention: "Attention", preserved_indices=None):
        """Build an equivalent module (qk_size = v_size = dim // num_heads).

        Args:
            attention: source module exposing `num_heads`, `qkv`, `proj`,
                `attn_drop` and `proj_drop`. If it has an `activation`
                attribute, `qkv`/`proj` are taken from that instead
                (presumably a wrapper around the real attention — TODO confirm).
            preserved_indices: reserved for pruning; only `None` is supported.

        Returns:
            A new `PrunedAttention_spatial` holding deep copies of the source
            `qkv` and `proj` modules.

        Raises:
            NotImplementedError: if `preserved_indices` is not None.
        """
        if preserved_indices is not None:
            raise NotImplementedError(
                "PrunedAttention_spatial.convert_from does not support preserved_indices yet"
            )

        num_heads = attention.num_heads
        dim = attention.qkv.weight.shape[1]
        head_dim = dim // num_heads
        qkv_bias = attention.qkv.bias is not None
        ret = cls(dim, head_dim, head_dim, num_heads, qkv_bias,
                  attention.attn_drop.p, attention.proj_drop.p)

        source = getattr(attention, 'activation', attention)
        ret.register_module('qkv', copy.deepcopy(source.qkv))
        ret.register_module('proj', copy.deepcopy(source.proj))
        return ret

if __name__ == "__main__":
    # Sanity check: each converted variant should reproduce the pretrained
    # module's output on the same input, up to floating-point rounding.
    pretrained_attn = Attention(768, num_heads=12, qkv_bias=True)
    state_path = Path(__file__).parent / "attn_dev_state_dict.pth"
    pretrained_attn.load_state_dict(torch.load(state_path))
    pretrained_attn.eval()

    converted = {
        "subhead": PrunedAttention_subhead.convert_from(pretrained_attn, preserved_indices=None),
        "ind": PrunedAttention_ind.convert_from(pretrained_attn),
        "spatial": PrunedAttention_spatial.convert_from(pretrained_attn),
    }
    for module in converted.values():
        module.eval()

    example_input = torch.randn(1, 197, 768)
    with torch.no_grad():
        pretrained_output = pretrained_attn(example_input)
        for name, module in converted.items():
            output = module(example_input)
            max_diff = (output - pretrained_output).abs().max().item()
            match = torch.allclose(output, pretrained_output, atol=1e-5)
            print(f"{name}: max |diff| = {max_diff:.3e}, allclose(atol=1e-5) = {match}")