from torch.jit import Final
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
from timm.layers import use_fused_attn#, Mlp
import torch.nn.functional as F
import torch.nn as nn
import torch
from timm.models import create_model
from itertools import repeat
import collections.abc

def _ntuple(n):
    """Return a converter that expands a scalar into an ``n``-tuple.

    The returned function passes non-string iterables through as
    ``tuple(x)`` (whatever their length) and repeats anything else
    ``n`` times.
    """
    def parse(x):
        # Strings are iterable but are treated as scalars here.
        already_sequence = (
            not isinstance(x, str)
            and isinstance(x, collections.abc.Iterable)
        )
        return tuple(x) if already_sequence else (x,) * n
    return parse
# Convenience converter: scalar -> (scalar, scalar); iterables pass through.
to_2tuple = _ntuple(2)

class Mlp(nn.Module):
    """Two-layer feed-forward network (fc1 -> act -> drop -> norm -> fc2 -> drop).

    Both linear layers are wrapped in ``LinearLayerWithL2norm``.

    NOTE: ``use_conv`` is accepted for signature compatibility but is not
    used; the layers are always ``nn.Linear``.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the output; falls back to ``in_features``
            when falsy.
        act_layer: Callable building the activation module.
        norm_layer: Optional callable building a norm over the hidden
            width; ``nn.Identity`` is used when ``None``.
        bias: Bool or pair of bools (fc1, fc2).
        drop: Dropout probability or pair of probabilities (after act,
            after fc2).
        use_conv: Unused.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)

        # Sub-modules are created in the same order they are applied.
        first_linear = nn.Linear(in_features, hidden_dim, bias=bias_pair[0])
        self.fc1 = LinearLayerWithL2norm(first_linear)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_dim)
        second_linear = nn.Linear(hidden_dim, out_dim, bias=bias_pair[1])
        self.fc2 = LinearLayerWithL2norm(second_linear)
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        """Apply the six stages in sequence and return the result."""
        stages = (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2)
        for stage in stages:
            x = stage(x)
        return x

class LinearLayerWithL2norm(nn.Module):
    """Wrap a linear layer and track an estimate of its weight's spectral norm.

    The forward output is exactly ``linear_layer(x)``. As a side effect each
    forward call stores ``u^T W v`` in ``self.sigma``, where ``u`` and ``v``
    are the persistent power-iteration vectors ``_u`` and ``_v``.

    Args:
        linear_layer: Module exposing a 2-D ``weight`` (e.g. ``nn.Linear``).
        n_iters: Power-iteration steps performed per training-mode forward.

    Attributes:
        sigma: ``1.0`` until the first forward call, afterwards a 0-dim
            tensor that is differentiable w.r.t. the wrapped weight.
    """
    def __init__(self, linear_layer, n_iters=1):
        super().__init__()
        self.linear = linear_layer
        self.n_iters = n_iters
        self.sigma = 1.0
        W = self.linear.weight
        # Kept as non-trainable parameters so they are saved in the state
        # dict and follow the module across devices.
        self._u = nn.Parameter(F.normalize(torch.randn(W.size(0)), dim=0), requires_grad=False)
        self._v = nn.Parameter(F.normalize(torch.randn(W.size(1)), dim=0), requires_grad=False)

    def forward(self, x):
        """Refresh ``self.sigma`` and return ``self.linear(x)``."""
        W = self.linear.weight
        if self.training:
            # The power iteration only refines the estimate of the singular
            # vectors, so it is run outside autograd. copy_ (rather than
            # assigning .data) keeps the stored dtype/device stable, e.g.
            # when autocast produces half-precision intermediates.
            with torch.no_grad():
                u, v = self._u, self._v
                for _ in range(self.n_iters):
                    v = F.normalize(torch.mv(W.t(), u), dim=0)
                    u = F.normalize(torch.mv(W, v), dim=0)
                self._u.copy_(u)
                self._v.copy_(v)
        # Clone so a later in-place update of _u/_v cannot invalidate tensors
        # saved for sigma's backward pass.
        u = self._u.detach().clone()
        v = self._v.detach().clone()
        # Computed in every mode so eval never sees a stale value or the
        # initial float placeholder.
        self.sigma = torch.dot(u, torch.mv(W, v))
        return self.linear(x)

class Attention(nn.Module):
    """Multi-head self-attention that also returns a per-token norm term.

    ``forward`` returns ``(projected_output, attn_norm)``, where
    ``attn_norm`` is computed by ``get_attn_norm`` from the per-head
    attention output and the output-projection parameters.

    Args:
        dim: Embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias.
        qk_norm: Apply ``norm_layer`` over the head dimension of q and k.
        proj_bias: Whether the output projection has a bias.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        norm_layer: Norm constructor used when ``qk_norm`` is true.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def get_attn_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Return a per-token scalar built from ``x`` and ``self.proj``.

        Args:
            x: Per-head attention output of shape
                ``(batch, num_heads, tokens, head_dim)``.

        Returns:
            Tensor of shape ``(batch, tokens)``.
        """
        # View the projection weight as (out_dim, head, head_dim) and form
        # the per-head Gram matrix W_h^T W_h.
        W_split = self.proj.weight.view(self.dim, self.num_heads, -1)
        W_square = torch.einsum('dhs, dhn -> hsn', W_split, W_split)
        # Quadratic term sum_h x_h^T (W_h^T W_h) x_h / (dim * num_heads);
        # identical with or without a bias.
        term1 = torch.einsum('bhtd,hde,bhte->bt', x/self.dim, W_square, x/self.num_heads)
        if self.proj.bias is None:
            return term1
        # Bias contributions: a cross term (term2) and a bias-only term (term3).
        bias_split = self.proj.bias/self.num_heads
        bias_W = torch.einsum('D,Dhd->hd', bias_split, W_split)/(self.num_heads)
        term2 = 2*torch.einsum('hd,bhtd->bt', bias_W, x)/self.dim
        term3 = bias_split.square().mean().view(1, 1)
        return term1 + term2 + term3

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention over ``x`` of shape ``(B, N, C)``.

        Returns:
            ``(out, attn_norm)`` with ``out`` of shape ``(B, N, C)`` and
            ``attn_norm`` of shape ``(B, N)``.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        # The norm term needs the per-head layout, so compute it before
        # the heads are merged back into the channel dimension.
        attn_norm = self.get_attn_norm(x)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_norm

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    In addition to the activation ``x``, an optional per-sample side tensor
    ``attn_norm`` can be passed; it is masked with the same per-sample draw
    and stored on ``self.dropped_attn_norm`` after each forward call.

    Args:
        drop_prob: Probability of dropping a sample's path.
        scale_by_keep: Divide kept samples by ``1 - drop_prob``.
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep
        self.dropped_attn_norm = None

    def drop_path(self, x, attn_norm, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
        """Apply one per-sample Bernoulli mask to ``x`` and ``attn_norm``.

        Args:
            x: Tensor whose first dimension indexes samples.
            attn_norm: Optional tensor masked with the same draw, or ``None``.
                Its first dimension is assumed to index the same samples as
                ``x`` -- TODO confirm against callers.
            drop_prob: Probability of zeroing a sample.
            training: Masking is only applied when true.
            scale_by_keep: Rescale kept samples by ``1 / (1 - drop_prob)``.

        Returns:
            ``(x_out, attn_norm_out)``; the inputs are returned unchanged
            when ``drop_prob == 0`` or ``training`` is false.
        """
        if drop_prob == 0. or not training:
            return x, attn_norm
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and scale_by_keep:
            random_tensor.div_(keep_prob)
        if attn_norm is None:
            return x * random_tensor, None
        # attn_norm may have a different rank than x; give the mask one
        # leading sample axis plus singleton axes matching attn_norm so the
        # product keeps attn_norm's shape instead of broadcasting a new axis.
        norm_mask = random_tensor.reshape((x.shape[0],) + (1,) * (attn_norm.ndim - 1))
        return x * random_tensor, attn_norm * norm_mask

    def forward(self, x, attn_norm=None):
        """Return the masked ``x``; the masked ``attn_norm`` is stored on the module."""
        x, dropped_attn_norm = self.drop_path(x, attn_norm, self.drop_prob, self.training, self.scale_by_keep)
        self.dropped_attn_norm = dropped_attn_norm
        return x

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'

class VisionTransformer_Reg(nn.Module):
    """Wrap a timm model and swap every entry of ``model.blocks`` for a ``Block``.

    Each replacement ``Block`` is newly constructed from hyperparameters
    read off the block it replaces; no weights are copied across.

    Args:
        model_name: Name passed to ``create_model``.
        num_classes: Number of classes passed to ``create_model``.
    """
    def __init__(self, model_name, num_classes):
        super().__init__()
        self.model = create_model(model_name, num_classes=num_classes, pretrained=False)
        # Iterate over a snapshot because entries are replaced in the loop.
        for idx, old_block in enumerate(list(self.model.blocks)):
            old_attn = old_block.attn
            old_fc1 = old_block.mlp.fc1
            old_drop_path = old_block.drop_path1
            if isinstance(old_drop_path, nn.Identity):
                drop_path_rate = 0.0
            else:
                drop_path_rate = old_drop_path.drop_prob
            self.model.blocks[idx] = Block(
                dim=old_attn.qkv.in_features,
                num_heads=old_attn.num_heads,
                mlp_ratio=old_fc1.out_features/old_fc1.in_features,
                qkv_bias=old_attn.qkv.bias is not None,
                qk_norm=not isinstance(old_attn.q_norm, nn.Identity),
                proj_bias=old_attn.proj.bias is not None,
                proj_drop=old_attn.proj_drop.p,
                attn_drop=old_attn.attn_drop.p,
                init_values=None,
                drop_path=drop_path_rate,
            )

    def forward(self, x: torch.Tensor):
        """Return ``(model_output, mean of Reg_loss over all blocks)``."""
        output = self.model(x)
        blocks = self.model.blocks
        # Each block stores its Reg_loss during the model forward pass above.
        reg_total = 0.0
        for blk in blocks:
            reg_total += blk.Reg_loss
        return output, reg_total/len(blocks)

class Block(nn.Module):
    """Transformer block that also records a regularisation term.

    Every forward pass stores a scalar in ``self.Reg_loss``: the mean of the
    product of an attention-branch term (``fdual_norm``) and an MLP-branch
    term (``ffn_norm``). ``VisionTransformer_Reg.forward`` reads it.

    Args:
        dim: Embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden MLP width as a multiple of ``dim``.
        qkv_bias: Bias on the q/k/v projection.
        qk_norm: Normalise q and k per head.
        proj_bias: Bias on the attention output projection and MLP layers.
        proj_drop: Dropout after the projection and inside the MLP.
        attn_drop: Dropout on the attention weights.
        init_values: If truthy, enables LayerScale with this initial value.
        drop_path: Stochastic-depth probability.
        act_layer: Activation constructor for the MLP.
        norm_layer: Norm constructor for ``norm1``/``norm2`` (and q/k norm).
        mlp_layer: MLP constructor; its ``fc1``/``fc2`` must expose ``sigma``.
    """
    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            mlp_layer: Type[nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        if init_values:
            # LayerScale is not imported at the top of this file; fetch it
            # lazily so the default path (init_values=None) needs nothing extra.
            from timm.models.vision_transformer import LayerScale
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # Always a DropPath (never nn.Identity): forward() reads
        # drop_path1.dropped_attn_norm even when drop_path == 0.
        self.drop_path1 = DropPath(drop_path)

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            bias=proj_bias,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.num_heads = num_heads
        self.Reg_loss = None

    def get_ffn_norm(self, x):
        """Return ``(max(norm2.weight) * sigma_fc1 * sigma_fc2 / std(x) + 1) ** 2``.

        ``std`` is taken over the last dimension of ``x``, so the result has
        the shape of ``x`` without its last dimension. The computation runs
        in float32 with CUDA autocast disabled.
        """
        with torch.cuda.amp.autocast(enabled=False):
            x = x.float()
            gamma_max = torch.max(self.norm2.weight.detach())
            w1norm = self.mlp.fc1.sigma
            w2norm = self.mlp.fc2.sigma
            # NOTE(review): no epsilon is added to std, so a constant token
            # vector yields inf here -- confirm whether that can occur.
            ffnnorm = (gamma_max*w1norm*w2norm/x.std(dim=-1) + 1).pow(2)
        return ffnnorm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP residual branches; update ``self.Reg_loss``."""
        x_norm = x.pow(2).mean(dim=-1)/(self.num_heads)
        y, attn_norm = self.attn(self.norm1(x))
        # NOTE(review): attn_norm is computed before ls1, so it is not
        # rescaled by LayerScale when init_values is set -- confirm intent.
        x = x + self.drop_path1(self.ls1(y), attn_norm)
        fdual_norm = self.drop_path1.dropped_attn_norm + x_norm
        # Run the MLP before get_ffn_norm: the MLP forward is what refreshes
        # fc1.sigma / fc2.sigma, so reading them first would use values from
        # the previous forward pass (or the initial placeholder).
        mlp_out = self.mlp(self.norm2(x))
        ffn_norm = self.get_ffn_norm(x)
        x = x + self.drop_path2(self.ls2(mlp_out))
        self.Reg_loss = (fdual_norm*ffn_norm).mean()
        return x



