import timm
import torch
import torch.nn as nn
from timm.models.vision_transformer import LayerScale, DropPath, VisionTransformer
from typing import Optional, Type
from torch.jit import Final
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
    get_act_layer, get_norm_layer, LayerType
import torch.nn.functional as F
from Positional_Embeddings.NoPE import No_PE
from timm.models.layers import trunc_normal_

class CustomPE_ViT(nn.Module):
    """Vision Transformer wrapper that swaps timm's attention for ``Custom_Attention``.

    When ``PE_method`` is ``None`` (or ``use_default`` is set) the wrapped timm
    ``VisionTransformer`` is used unchanged, including its learned absolute
    position embedding. Otherwise the absolute embedding is removed and a
    ``PE_method`` instance is injected into the attention of every block (or
    only the first block when ``stem_only`` is set).

    Args:
        ViT_kwargs: keyword arguments for timm's ``VisionTransformer``. Must contain
            ``img_size``, ``patch_size``, ``num_heads``, ``embed_dim`` and ``drop_rate``.
            The dict is copied; the caller's dict is never mutated.
        attn_kwargs: extra keyword arguments forwarded to ``Custom_Attention``.
        pretrained: unused; kept for interface compatibility.
        PE_method: callable ``PE_method(dim, P_x=..., P_y=...)`` returning a
            positional-encoding module, or ``None`` for the default ViT.
        stem_only: only replace the attention of the first block.
        num_heads: unused; the head count is read from ``ViT_kwargs['num_heads']``.
        shared_pe: share a single PE instance across all replaced blocks.
        rot_x: apply the PE to the block input (full embedding dim) instead of q/k.
        rot_value: additionally apply the PE to the attention values.
        flash_att: use ``F.scaled_dot_product_attention`` in the attention.
        use_cls_token: in custom-PE mode, prepend a learnable class token and
            classify from it; when False the final tokens are mean-pooled instead.
            In default mode, True forces timm's token pooling.
        use_default: force the default timm ViT even if ``PE_method`` is given.
    """

    def __init__(self, ViT_kwargs, attn_kwargs, pretrained=False, PE_method=None, stem_only=False, num_heads=12,
                 shared_pe=False, rot_x=False, rot_value=False, flash_att=True, use_cls_token=True,
                 use_default=False):
        super(CustomPE_ViT, self).__init__()
        # Work on copies so the caller's configuration dicts are left untouched.
        ViT_kwargs = dict(ViT_kwargs)
        attn_kwargs = dict(attn_kwargs or {})

        P_x = P_y = int(ViT_kwargs['img_size'] / ViT_kwargs['patch_size'])
        self.P_x, self.P_y = P_x, P_y

        self.patch_size = ViT_kwargs['patch_size']
        # The ``num_heads`` argument is ignored: the ViT config is authoritative.
        num_heads = ViT_kwargs['num_heads']
        full_dim = ViT_kwargs['embed_dim']
        head_dim = full_dim // num_heads
        self.dropout_rate = ViT_kwargs['drop_rate']
        self.num_heads = num_heads

        default_mode = PE_method is None or use_default
        if use_cls_token and default_mode:
            # Let timm manage the class token and pool from it.
            ViT_kwargs['global_pool'] = 'token'
            ViT_kwargs['class_token'] = True
        if use_cls_token and not default_mode:
            # Custom-PE mode prepends its own class token in ``forward``.
            self.cls_token = nn.Parameter(torch.zeros(1, 1, full_dim))
        else:
            self.cls_token = None

        self.model = VisionTransformer(**ViT_kwargs)
        self.pes = []

        if default_mode:
            self.PE_method = None
            self.use_default = True
            return

        self.use_default = False
        # The attention must know whether token 0 is a class token so it can keep
        # it out of the positional encoding; an explicit attn_kwargs entry wins.
        attn_kwargs.setdefault('use_class_token', use_cls_token)
        # rot_x encodes the whole embedding; otherwise the PE works per head.
        pe_dim = full_dim if rot_x else head_dim

        def make_pe():
            return PE_method(pe_dim, P_x=P_x, P_y=P_y)

        def make_attn(block, pe):
            return Custom_Attention(dim=self.model.embed_dim, num_heads=block.attn.num_heads, PE=pe,
                                    rot_x=rot_x, rot_val=rot_value, flash_att=flash_att, **attn_kwargs)

        pe = make_pe()
        # Positions come from the injected PE instead of the learned absolute table.
        self.model.pos_embed = None
        if shared_pe:
            self.pes = [pe]

        if stem_only:
            first = next(iter(self.model.blocks.children()), None)
            if first is not None and hasattr(first, 'attn'):
                first.attn = make_attn(first, pe)
                # Track the PE so device moves / mode switches / resizes reach it.
                if not shared_pe:
                    self.pes.append(pe)
        else:
            for block in self.model.blocks.children():
                if not hasattr(block, 'attn'):
                    continue
                if not shared_pe:
                    pe = make_pe()
                    self.pes.append(pe)
                block.attn = make_attn(block, pe)

    def forward(self, x):
        '''
        Forward pass through the model.
        Args:
            x: Input tensor of shape (B, C, H, W), where B is the batch size, C is the number of channels,
                H is the height, and W is the width.
        Returns:
            out: Output of ``self.model.head`` applied to the pooled final tokens.
        '''
        if self.use_default:
            return self.model(x)
        B = x.shape[0]
        x = self.model.patch_embed(x)
        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
        for blk in self.model.blocks:
            x = blk(x)
        x = self.model.norm(x)
        # Pool from the class token when present, otherwise average all tokens.
        x = x[:, 0] if self.cls_token is not None else x.mean(dim=1)

        if self.dropout_rate:
            x = F.dropout(x, p=float(self.dropout_rate), training=self.model.training)
        return self.model.head(x)

    def to(self, device):
        """Move the module to ``device`` and notify every tracked PE via ``_set_device_``."""
        super().to(device)
        for pe in self.pes:
            pe._set_device_(device)
        return self

    def set_PE_patches(self, P_x, P_y):
        """Forward a new patch-grid size to every tracked PE."""
        for pe in self.pes:
            pe.set_Patches(P_x, P_y)

    def set_Extrapolate_Mode(self):
        """Relax the patch embedding's strict image-size check so other input sizes are accepted."""
        self.model.patch_embed.strict_img_size = False

    def set_Train_Mode(self):
        """Put every tracked PE into its train mode."""
        for pe in self.pes:
            pe.train_mode()

    def set_Image_Size(self, img_size):
        """Adapt the model to a new input resolution.

        In default mode the learned absolute position embedding is bilinearly
        resized to the new patch grid (prefix tokens are kept as-is). In
        custom-PE mode each tracked PE is told the new grid size.

        Args:
            img_size: int (square) or (H, W) pair.
        """
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if isinstance(self.patch_size, int):
            self.patch_size = (self.patch_size, self.patch_size)
        else:
            self.patch_size = (self.patch_size[0], self.patch_size[1])
        P_x = img_size[0] // self.patch_size[0]
        P_y = img_size[1] // self.patch_size[1]

        if self.use_default:
            self.model.patch_embed.strict_img_size = False
            pos_embed = self.model.pos_embed
            # Tokens at the front of pos_embed that are not part of the patch grid.
            if getattr(self.model, 'no_embed_class', False):
                n_prefix = 0
            else:
                n_prefix = getattr(self.model, 'num_prefix_tokens', 1)
            with torch.no_grad():
                prefix = pos_embed[:, :n_prefix]
                grid = pos_embed[:, n_prefix:].reshape(1, self.P_x, self.P_y, -1).permute(0, 3, 1, 2)
                grid = F.interpolate(grid, (P_x, P_y), mode='bilinear')
                grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
                self.model.pos_embed = nn.Parameter(torch.cat([prefix, grid], dim=1))

        # Remember the current grid so a later resize interpolates from it.
        self.P_x, self.P_y = P_x, P_y
        for pe in self.pes:
            pe.set_Patches(P_x, P_y)

class Custom_Attention(nn.Module):
    '''
        Multi-head self-attention based on timm's attention module, with a
        pluggable positional-encoding module ``PE``.

        The PE is applied along the token axis (dim 2 of a (B, heads, tokens, dim)
        tensor). When ``use_class_token`` is True, token 0 is treated as a class
        token and excluded from the encoding. With ``rot_x`` the PE is applied to
        the block input before the qkv projection; otherwise it is applied to q and
        k (and to v as well when ``rot_val`` is set). ``PE=None`` disables the encoding.
    '''
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            PE: Optional[nn.Module] = None,
            rot_x: bool = False,
            rot_val: bool = False,
            flash_att: bool = True,
            use_class_token: bool = True
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Named after timm's flag; selects F.scaled_dot_product_attention.
        self.fused_attn = flash_att

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.pe = PE
        self.rot_val = rot_val
        self.rot_x = rot_x
        self.use_class_token = use_class_token

    def apply_pe(self, z):
        """Apply the PE along dim 2 of ``z``, leaving a leading class token untouched.

        Returns ``z`` unchanged when no PE module was given.
        """
        if self.pe is None:
            return z
        if self.use_class_token:
            cls_token = z[:, :, :1]  # sliced (not detached) to keep the gradient path
            return torch.cat((cls_token, self.pe(z[:, :, 1:])), dim=2)
        return self.pe(z)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attention over ``x`` of shape (B, P, C); returns the same shape."""
        B, P, C = x.shape
        if self.rot_x:
            # Add a singleton "head" axis so tokens sit on dim 2, as apply_pe expects.
            x = self.apply_pe(x.reshape(B, 1, P, C)).reshape(B, P, C)
        qkv = self.qkv(x).reshape(B, P, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if not self.rot_x:
            q = self.apply_pe(q)
            k = self.apply_pe(k)
            if self.rot_val:
                # Same class-token handling as q/k so token counts stay consistent.
                v = self.apply_pe(v)
        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            attn = (q * self.scale) @ k.transpose(-2, -1)
            attn = self.attn_drop(attn.softmax(dim=-1))
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, P, C)
        x = self.proj(x)
        return self.proj_drop(x)

    def set_PE_mode(self, mode):
        """Switch the PE to ``'train'`` or ``'extrapolate'`` mode.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'train':
            self.pe.train_mode()
        elif mode == 'extrapolate':
            self.pe.extrapolate_mode()
        else:
            raise ValueError("Invalid mode. Choose 'train' or 'extrapolate'.")