import torch
import torch.nn as nn
from pathlib import Path
import copy
from timm.models.vision_transformer import Attention, Block

class PrunedAttention_spatial(nn.Module):
    """Multi-head self-attention with independent per-head q/k and v widths.

    The fused ``qkv`` projection emits, for every token,
    ``num_heads * qk_size`` query features, then ``num_heads * qk_size`` key
    features, then ``num_heads * v_size`` value features. The concatenated
    per-head outputs (``num_heads * v_size`` features) are mapped back to
    ``dim`` by ``proj``.

    Args:
        dim: Input and output feature size of the layer.
        qk_size: Per-head width of the query and key projections.
        v_size: Per-head width of the value projection.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused ``qkv`` projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim: int, qk_size: int, v_size: int, num_heads: int = 8,
                 qkv_bias: bool = False, attn_drop: float = 0.,
                 proj_drop: float = 0.):
        super().__init__()
        self.num_heads = num_heads
        self.qk_size = qk_size
        self.v_size = v_size
        # The softmax scale is derived from dim // num_heads, not from
        # qk_size, so it does not change when qk_size differs from that value.
        # NOTE(review): presumably intentional, to keep the scaling of the
        # model this layer was derived from - confirm.
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, num_heads * (qk_size * 2 + v_size), bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # forward() feeds proj the concatenation of all heads, so its input
        # width is num_heads * v_size (a single head's v_size is too narrow
        # whenever num_heads > 1).
        self.proj = nn.Linear(num_heads * v_size, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention over the token axis.

        Args:
            x: Tensor of shape ``(B, N, C)``; ``C`` must match the input size
                of ``self.qkv``.

        Returns:
            Tensor of shape ``(B, N, out)`` where ``out`` is the output size
            of ``self.proj``.
        """
        B, N, _ = x.shape
        num_heads = self.num_heads
        qk_total = self.qk_size * num_heads
        v_total = self.v_size * num_heads
        q, k, v = self.qkv(x).split([qk_total, qk_total, v_total], dim=2)

        # (B, N, heads * width) -> (B, heads, N, width)
        q = q.reshape(B, N, num_heads, self.qk_size).transpose(1, 2)
        k = k.reshape(B, N, num_heads, self.qk_size).transpose(1, 2)
        v = v.reshape(B, N, num_heads, self.v_size).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back; the last axis is exactly num_heads * v_size.
        x = (attn @ v).transpose(1, 2).reshape(B, N, v_total)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @classmethod
    def convert_from(cls, attention: "Attention") -> "PrunedAttention_spatial":
        """Build an instance that reuses the projections of ``attention``.

        ``dim`` is read from the input size of ``attention.qkv``, and both
        ``qk_size`` and ``v_size`` are set to ``dim // num_heads``. The
        ``qkv`` and ``proj`` layers are deep-copied, so the returned module
        shares no parameters with the source.

        Args:
            attention: Module exposing ``num_heads``, ``qkv``, ``proj``,
                ``attn_drop`` and ``proj_drop``.

        Returns:
            A new ``PrunedAttention_spatial`` carrying copies of the source's
            ``qkv`` and ``proj`` layers.
        """
        num_heads = attention.num_heads
        dim = attention.qkv.weight.shape[1]
        head_dim = dim // num_heads
        qkv_bias = attention.qkv.bias is not None
        ret = cls(dim, head_dim, head_dim, num_heads, qkv_bias,
                  attention.attn_drop.p, attention.proj_drop.p)

        # If the source exposes an `activation` attribute, the layers are
        # taken from that object instead of from `attention` itself.
        # NOTE(review): presumably a wrapper around the real attention
        # module - confirm against callers.
        source = getattr(attention, 'activation', attention)
        # Assigning a Module to an attribute registers it as a submodule and
        # replaces the freshly initialised layer created by __init__.
        ret.qkv = copy.deepcopy(source.qkv)
        ret.proj = copy.deepcopy(source.proj)
        return ret
