import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import Conv2dNormActivation, MLP
from torchvision.transforms._presets import ImageClassification, InterpolationMode
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
from collections import OrderedDict
from copy import deepcopy
from .model_utils import reshape_conv_input_activation



# Names exported by ``from <this module> import *``.
# NOTE(review): ConvStemConfig, EncoderBlock and Encoder are not exported even
# though they carry the activation/projection helpers — confirm this is intended.
__all__ = [
    "VisionTransformer",
    "ViT_B_16_Weights",
    "ViT_B_32_Weights",
    "ViT_L_16_Weights",
    "ViT_L_32_Weights",
    "ViT_H_14_Weights",
    "vit_b_16",
    "vit_b_32",
    "vit_l_16",
    "vit_l_32",
    "vit_h_14",
]


class ConvStemConfig(NamedTuple):
    """Configuration of one layer of an optional convolutional stem.

    When ``VisionTransformer`` receives ``conv_stem_configs``, it builds one
    ``Conv2dNormActivation`` per entry (in list order) followed by a final 1x1
    conv to ``hidden_dim``, instead of the single patchify conv
    (convolutional stem as in https://arxiv.org/abs/2106.14881).
    """

    out_channels: int  # output channels; also the input channels of the next stem layer
    kernel_size: int  # passed as ``kernel_size`` to Conv2dNormActivation
    stride: int  # passed as ``stride`` to Conv2dNormActivation
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d  # normalization factory for Conv2dNormActivation
    activation_layer: Callable[..., nn.Module] = nn.ReLU  # activation factory for Conv2dNormActivation


class MLPBlock(MLP):
    """Transformer MLP block.

    A torchvision ``MLP`` with hidden sizes ``[mlp_dim, in_dim]`` and GELU
    activation. Its two ``nn.Linear`` layers sit at sequential indices 0 and 3.
    Linear weights are Xavier-uniform initialised and biases are drawn from
    N(0, 1e-6).
    """

    # Version 2 marks checkpoints that already use the MLP-style key names.
    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.normal_(linear.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Rename legacy ``linear_1``/``linear_2`` keys before the standard load.

        Checkpoints saved before version 2 (or without a version) store the two
        linears as ``linear_1``/``linear_2``; in the MLP layout they are the
        modules at indices 0 and 3. See https://github.com/pytorch/vision/pull/6053
        """
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            legacy_to_current = {
                f"{prefix}linear_{idx + 1}.{param_kind}": f"{prefix}{3 * idx}.{param_kind}"
                for idx in range(2)
                for param_kind in ("weight", "bias")
            }
            for legacy_key, current_key in legacy_to_current.items():
                if legacy_key in state_dict:
                    state_dict[current_key] = state_dict.pop(legacy_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )



class EncoderBlock(nn.Module):
    """Transformer encoder block: pre-norm self-attention and MLP, each with a residual.

    Besides ``forward`` the block offers two helpers used for weight projection:

    * ``get_activations`` re-runs the block with the attention unrolled
      (packed QKV projection -> scaled dot-product attention -> output
      projection). It records the input ("pre") and output ("post") of every
      linear map as NumPy arrays of shape ``(-1, features)``.
    * ``project_weights`` rewrites each linear map in place as
      ``W <- P_post^T @ W @ P_pre^T`` and ``b <- b @ P_post``, using matrices
      keyed exactly like the recorded activations.

    The unrolled path applies no attention dropout. It therefore only matches
    ``forward`` when attention dropout is inactive (e.g. in eval mode).
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        # To check the manual unrolling, replace the line above with: x = self.test_unrolled(x)
        x = self.dropout(x)
        x = x + input
        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _snapshot(t: torch.Tensor):
        """Return an independent NumPy copy of ``t`` flattened to ``(-1, t.shape[-1])``."""
        return t.detach().reshape(-1, t.shape[-1]).cpu().numpy().copy()

    @staticmethod
    def _project_weight(weight: torch.Tensor, post: torch.Tensor, pre: torch.Tensor) -> torch.Tensor:
        """Return ``post^T @ weight @ pre^T`` reshaped like ``weight``."""
        return torch.mm(post.transpose(0, 1), torch.mm(weight.flatten(1), pre.transpose(0, 1))).view_as(weight)

    @staticmethod
    def _project_bias(bias: torch.Tensor, post: torch.Tensor) -> torch.Tensor:
        """Return ``bias @ post`` (i.e. ``post^T @ bias``), matching ``_project_weight``."""
        return torch.mm(bias.unsqueeze(0), post).squeeze(0)

    def _merged_attention(self, q, k, v, bsz: int, seq_len: int, embed_dim: int) -> torch.Tensor:
        """Run per-head scaled dot-product attention and merge heads.

        ``q``, ``k`` and ``v`` are viewed as ``(bsz, seq_len, embed_dim)``. The
        result is ``(bsz * seq_len, embed_dim)``, the input expected by the
        output projection.
        """
        head_dim = embed_dim // self.num_heads
        q, k, v = (t.view(bsz, seq_len, self.num_heads, head_dim).permute(0, 2, 1, 3) for t in (q, k, v))
        attn_output = F.scaled_dot_product_attention(q, k, v)
        return attn_output.permute(0, 2, 1, 3).contiguous().view(bsz * seq_len, embed_dim)

    # --------------------------------------------------------------- public API

    def test_unrolled(self, x):
        """Debug helper: recompute ``self_attention(x, x, x)`` by hand.

        Dropout is ignored because representations are only collected in eval
        mode. Returns a tensor of shape ``(bsz, seq_len, embed_dim)``.
        """
        bsz, seq_len, embed_dim = x.shape
        attn = self.self_attention
        q, k, v = F._in_projection_packed(x, x, x, attn.in_proj_weight, attn.in_proj_bias)
        merged = self._merged_attention(q, k, v, bsz, seq_len, embed_dim)
        attn_output = F.linear(merged, attn.out_proj.weight, attn.out_proj.bias)
        return attn_output.view(bsz, seq_len, embed_dim)

    def get_activations(self, input):
        """Run the block while recording the inputs/outputs of its linear maps.

        Returns ``(act, output)``. ``output`` is the block output. ``act`` has
        "pre" and "post" OrderedDicts with these keys:

        * ``self_attn.qkv`` (pre): shared input of the Q/K/V projections.
        * ``self_attn.query``/``key``/``value`` (post): their outputs.
        * ``self_attn.out_proj`` (pre/post): input/output of the output projection.
        * ``mlp.linear{i}`` (pre/post): the i-th ``nn.Linear`` inside ``self.mlp``.
        """
        act = {"pre": OrderedDict(), "post": OrderedDict()}
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)

        # Unrolled self-attention so every intermediate linear map is observable.
        attn = self.self_attention
        bsz, seq_len, embed_dim = x.shape
        act["pre"]["self_attn.qkv"] = self._snapshot(x)
        q, k, v = F._in_projection_packed(x, x, x, attn.in_proj_weight, attn.in_proj_bias)
        act["post"]["self_attn.query"] = self._snapshot(q)
        act["post"]["self_attn.key"] = self._snapshot(k)
        act["post"]["self_attn.value"] = self._snapshot(v)
        attn_output = self._merged_attention(q, k, v, bsz, seq_len, embed_dim)
        act["pre"]["self_attn.out_proj"] = self._snapshot(attn_output)
        attn_output = F.linear(attn_output, attn.out_proj.weight, attn.out_proj.bias)
        act["post"]["self_attn.out_proj"] = self._snapshot(attn_output)

        x = self.dropout(attn_output.view(bsz, seq_len, embed_dim))
        x = x + input

        y = self.ln_2(x)
        mlp_ind = 0
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                act["pre"][f"mlp.linear{mlp_ind}"] = self._snapshot(y)
                y = layer(y)
                act["post"][f"mlp.linear{mlp_ind}"] = self._snapshot(y)
                mlp_ind += 1
            else:
                y = layer(y)
        return act, x + y

    def project_weights(self, projection_mat_dict):
        """Project every linear map of the block in place.

        ``projection_mat_dict`` has "pre" and "post" mappings of torch tensors,
        keyed like ``get_activations``. Each weight becomes
        ``post^T @ W @ pre^T`` and each bias ``b @ post``. Q, K and V share the
        ``self_attn.qkv`` input-side matrix.
        """
        pre = projection_mat_dict["pre"]
        post = projection_mat_dict["post"]
        attn = self.self_attention

        # The packed in_proj weight stacks W_q, W_k and W_v along dim 0.
        qkv_in = pre["self_attn.qkv"]
        w_q, w_k, w_v = attn.in_proj_weight.data.chunk(3)
        attn.in_proj_weight.data = torch.cat(
            (
                self._project_weight(w_q, post["self_attn.query"], qkv_in),
                self._project_weight(w_k, post["self_attn.key"], qkv_in),
                self._project_weight(w_v, post["self_attn.value"], qkv_in),
            ),
            0,
        )

        if attn.in_proj_bias is not None:
            b_q, b_k, b_v = attn.in_proj_bias.data.chunk(3)
            attn.in_proj_bias.data = torch.cat(
                (
                    self._project_bias(b_q, post["self_attn.query"]),
                    self._project_bias(b_k, post["self_attn.key"]),
                    self._project_bias(b_v, post["self_attn.value"]),
                ),
                0,
            )

        # Reshape against the out_proj weight itself, not a slice of the QKV weight.
        attn.out_proj.weight.data = self._project_weight(
            attn.out_proj.weight.data, post["self_attn.out_proj"], pre["self_attn.out_proj"]
        )
        if attn.out_proj.bias is not None:
            attn.out_proj.bias.data = self._project_bias(attn.out_proj.bias.data, post["self_attn.out_proj"])

        mlp_ind = 0
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                key = f"mlp.linear{mlp_ind}"
                layer.weight.data = self._project_weight(layer.weight.data, post[key], pre[key])
                if layer.bias is not None:
                    layer.bias.data = self._project_bias(layer.bias.data, post[key])
                mlp_ind += 1
        return


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Adds a learned positional embedding, applies dropout, runs the stacked
    ``EncoderBlock``s and finishes with a norm layer. ``get_activations`` and
    ``project_weights`` forward to each block, namespacing keys as
    ``encoderblock{i}.<block key>``.
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiAttention() by default
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))

    def get_activations(self, input):
        """Run the encoder while collecting every block's linear-map activations.

        Returns ``(act, output)``. ``act`` merges each block's "pre"/"post"
        activations under ``encoderblock{i}.<key>``, and ``output`` equals what
        the unrolled blocks produce, followed by ``self.ln``.
        """
        act = {"pre": OrderedDict(), "post": OrderedDict()}
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.dropout(input + self.pos_embedding)
        for block_ind, layer in enumerate(self.layers):
            layer_acts, x = layer.get_activations(x)
            for loc, loc_acts in layer_acts.items():
                for key, value in loc_acts.items():
                    act[loc][f"encoderblock{block_ind}.{key}"] = value
        return act, self.ln(x)

    def project_weights(self, projection_mat_dict):
        """Route projection matrices to the blocks and project them in place.

        Keys of the form ``encoderblock{i}.<rest>`` are handed to block ``i`` as
        ``<rest>``. Matching uses an exact prefix, so e.g. ``encoderblock1.``
        never captures keys that merely contain that text elsewhere.
        """
        for block_ind, layer in enumerate(self.layers):
            prefix = f"encoderblock{block_ind}."
            layer_proj_mat = {"pre": OrderedDict(), "post": OrderedDict()}
            for loc, loc_mats in projection_mat_dict.items():
                for key, mat in loc_mats.items():
                    if key.startswith(prefix):
                        layer_proj_mat[loc][key[len(prefix):]] = mat
            layer.project_weights(layer_proj_mat)
        return


class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929.

    ``forward`` returns log-probabilities (``log_softmax`` over the class
    dimension). Two extra entry points support weight projection:

    * ``get_activations`` runs the model and records the input ("pre") and
      output ("post") of every recorded linear/conv map. Keys are
      ``conv_proj.conv{i}``, ``encoder.encoderblock{j}.<...>`` and
      ``heads.linear{k}``.
    * ``project_weights`` rewrites those maps in place, given matrices keyed the
      same way. The final head layer is only projected on its input side, so
      class outputs are left unprojected.

    Only plain ``nn.Conv2d`` modules of the stem are recorded/projected. Other
    stem submodules (e.g. ``Conv2dNormActivation`` blocks of a conv stem) are
    run without recording.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    Conv2dNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)
        x = F.log_softmax(x, dim=1)
        return x

    # ------------------------------------------------------------------ helpers

    def _stem_layers(self) -> List[nn.Module]:
        """Modules of the patch stem, in execution order."""
        return [self.conv_proj] if isinstance(self.conv_proj, nn.Conv2d) else list(self.conv_proj)

    def _head_layers(self) -> List[nn.Module]:
        """Modules of the classification head, in execution order."""
        return [self.heads] if isinstance(self.heads, nn.Linear) else list(self.heads)

    @staticmethod
    def _snapshot(t: torch.Tensor):
        """Return an independent NumPy copy of ``t`` flattened to ``(-1, t.shape[-1])``."""
        return t.detach().reshape(-1, t.shape[-1]).cpu().numpy().copy()

    @staticmethod
    def _snapshot_conv_output(t: torch.Tensor):
        """Return an (N, C, H, W) conv output as an independent ``(N*H*W, C)`` NumPy array."""
        return t.detach().permute(0, 2, 3, 1).reshape(-1, t.shape[1]).cpu().numpy().copy()

    @staticmethod
    def _project_layer(layer: nn.Module, post_mat: torch.Tensor, pre_mat: torch.Tensor) -> None:
        """In place: ``W <- post^T @ W.flatten(1) @ pre^T`` (reshaped back) and ``b <- b @ post``."""
        layer.weight.data = torch.mm(
            post_mat.transpose(0, 1), torch.mm(layer.weight.data.flatten(1), pre_mat.transpose(0, 1))
        ).view_as(layer.weight.data)
        if layer.bias is not None:
            layer.bias.data = torch.mm(layer.bias.data.unsqueeze(0), post_mat).squeeze(0)

    # --------------------------------------------------------------- public API

    def get_activations(self, x):
        """Run the model on images ``x`` and record every linear/conv map's input and output.

        Returns ``act`` with "pre" and "post" OrderedDicts of NumPy arrays. Also
        stores the number of linear layers in the head in
        ``self.max_head_linear_layers``.
        """
        act = {"pre": OrderedDict(), "post": OrderedDict()}
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # Stem (same computation as self.conv_proj(x)), recording each plain conv.
        conv_proj_ind = 0
        for layer in self._stem_layers():
            if isinstance(layer, nn.Conv2d):
                key = f"conv_proj.conv{conv_proj_ind}"
                act["pre"][key] = reshape_conv_input_activation(x.detach().clone(), layer).cpu().numpy()
                x = layer(x)
                act["post"][key] = self._snapshot_conv_output(x)
                conv_proj_ind += 1
            else:
                x = layer(x)

        # (n, hidden_dim, n_h, n_w) -> (n, (n_h * n_w), hidden_dim), as in _process_input.
        x = x.reshape(n, self.hidden_dim, n_h * n_w).permute(0, 2, 1)

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        encoder_activation, x = self.encoder.get_activations(x)
        for loc, loc_acts in encoder_activation.items():
            for key, value in loc_acts.items():
                act[loc][f"encoder.{key}"] = value

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        # Head (same computation as self.heads(x)), recording each linear layer.
        heads_ind = 0
        for layer in self._head_layers():
            if isinstance(layer, nn.Linear):
                key = f"heads.linear{heads_ind}"
                act["pre"][key] = self._snapshot(x)
                x = layer(x)
                act["post"][key] = self._snapshot(x)
                heads_ind += 1
            else:
                x = layer(x)
        self.max_head_linear_layers = heads_ind
        return act

    def project_weights(self, projection_mat_dict):
        """Project the recorded linear/conv maps in place.

        ``projection_mat_dict`` has "pre" and "post" mappings of torch tensors,
        keyed as in ``get_activations``. The last head linear layer is only
        multiplied on its input side (``W <- W @ pre^T``), and its bias is kept.
        Works whether or not ``get_activations`` was called first.
        """
        pre = projection_mat_dict["pre"]
        post = projection_mat_dict["post"]

        conv_proj_ind = 0
        for layer in self._stem_layers():
            if isinstance(layer, nn.Conv2d):
                key = f"conv_proj.conv{conv_proj_ind}"
                self._project_layer(layer, post[key], pre[key])
                conv_proj_ind += 1

        # Strip the leading "encoder." so keys match Encoder.project_weights.
        encoder_prefix = "encoder."
        encoder_mat_dict = {"pre": OrderedDict(), "post": OrderedDict()}
        for loc, loc_mats in projection_mat_dict.items():
            for key, mat in loc_mats.items():
                if key.startswith(encoder_prefix):
                    encoder_mat_dict[loc][key[len(encoder_prefix):]] = mat
        self.encoder.project_weights(encoder_mat_dict)

        head_layers = self._head_layers()
        # Derived from the model itself rather than the attribute set by get_activations.
        num_head_linears = sum(isinstance(layer, nn.Linear) for layer in head_layers)
        heads_ind = 0
        for layer in head_layers:
            if not isinstance(layer, nn.Linear):
                continue
            key = f"heads.linear{heads_ind}"
            if heads_ind == num_head_linears - 1:
                # LAST layer: avoid output projections so class logits stay unprojected.
                layer.weight.data = torch.mm(layer.weight.data, pre[key].transpose(0, 1))
            else:
                self._project_layer(layer, post[key], pre[key])
            heads_ind += 1
        return
    

def _vision_transformer(
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    """Build a ``VisionTransformer`` and optionally load pretrained weights.

    When ``weights`` is given, ``num_classes`` and ``image_size`` in ``kwargs``
    are set from the weights' metadata (``categories`` and the square
    ``min_size``) via ``_ovewrite_named_param``. Otherwise ``image_size``
    defaults to 224. Remaining ``kwargs`` go to the ``VisionTransformer``
    constructor.
    """
    if weights is not None:
        meta = weights.meta
        _ovewrite_named_param(kwargs, "num_classes", len(meta["categories"]))
        min_size = meta["min_size"]
        assert min_size[0] == min_size[1]
        _ovewrite_named_param(kwargs, "image_size", min_size[0])

    size = kwargs.pop("image_size", 224)
    model = VisionTransformer(
        image_size=size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )

    if weights:
        state_dict = weights.get_state_dict(progress=progress, check_hash=True)
        model.load_state_dict(state_dict)

    return model


# Metadata shared by every ViT weights entry in this module.
_COMMON_META: Dict[str, Any] = dict(categories=_IMAGENET_CATEGORIES)

# Common metadata plus recipe/license links for weights derived from SWAG
# (https://arxiv.org/abs/2201.08371).
_COMMON_SWAG_META = dict(
    _COMMON_META,
    recipe="https://github.com/facebookresearch/SWAG",
    license="https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
)


class ViT_B_16_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for ViT-B/16 (``DEFAULT`` is ``IMAGENET1K_V1``).

    When loaded through ``_vision_transformer``, ``meta["min_size"]`` sets the
    model's ``image_size`` and ``len(meta["categories"])`` sets ``num_classes``.
    """

    # Trained from scratch; 224x224 input.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 86567656,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.072,
                    "acc@5": 95.318,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    # SWAG end-to-end fine-tuned; 384x384 input with bicubic resizing.
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 86859496,
            "min_size": (384, 384),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.304,
                    "acc@5": 97.650,
                }
            },
            "_ops": 55.484,
            "_file_size": 331.398,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    # Frozen SWAG trunk + linear classifier; 224x224 input with bicubic resizing.
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_16_lc_swag-4e70ced5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 86567656,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 96.180,
                }
            },
            "_ops": 17.564,
            "_file_size": 330.285,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ViT_B_32_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoint for ViT-B/32 (``DEFAULT`` is ``IMAGENET1K_V1``)."""

    # Trained from scratch; 224x224 input.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88224232,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.912,
                    "acc@5": 92.466,
                }
            },
            "_ops": 4.409,
            "_file_size": 336.604,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ViT_L_16_Weights(WeightsEnum):
    """Pretrained ImageNet-1K checkpoints for ViT-L/16 (``DEFAULT`` is ``IMAGENET1K_V1``)."""

    # Trained from scratch; 224x224 crop after resizing to 242.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=242),
        meta={
            **_COMMON_META,
            "num_params": 304326632,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.662,
                    "acc@5": 94.638,
                }
            },
            "_ops": 61.555,
            "_file_size": 1161.023,
            "_docs": """
                These weights were trained from scratch by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # SWAG end-to-end fine-tuned; 512x512 input with bicubic resizing.
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16_swag-4f3808c9.pth",
        transforms=partial(
            ImageClassification,
            crop_size=512,
            resize_size=512,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 305174504,
            "min_size": (512, 512),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.064,
                    "acc@5": 98.512,
                }
            },
            "_ops": 361.986,
            "_file_size": 1164.258,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    # Frozen SWAG trunk + linear classifier; 224x224 input with bicubic resizing.
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_16_lc_swag-4d563306.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 304326632,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.146,
                    "acc@5": 97.422,
                }
            },
            "_ops": 61.555,
            "_file_size": 1161.023,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ViT_L_32_Weights(WeightsEnum):
    """Pretrained weight entries accepted by :func:`vit_l_32` (patch size 32, "Large" variant)."""

    # Weights trained from scratch on ImageNet-1K (per the ``_docs`` text below).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
        # Only crop_size is set here; resize_size / interpolation fall back to
        # whatever ImageClassification uses by default.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 306535400,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.972,
                    "acc@5": 93.07,
                }
            },
            # NOTE(review): presumably GFLOPS and file size in MB respectively — confirm
            # against the WeightsEnum metadata conventions.
            "_ops": 15.378,
            "_file_size": 1169.449,
            "_docs": """
                These weights were trained from scratch by using a modified version of `DeIT
                <https://arxiv.org/abs/2012.12877>`_'s training recipe.
            """,
        },
    )
    # DEFAULT is bound to the very same Weights object as IMAGENET1K_V1, so it acts as an alias.
    DEFAULT = IMAGENET1K_V1


class ViT_H_14_Weights(WeightsEnum):
    """Pretrained weight entries accepted by :func:`vit_h_14` (patch size 14, "Huge" variant).

    Unlike the other ViT weight tables there is no ``IMAGENET1K_V1`` entry; both entries
    are derived from SWAG weights and ``DEFAULT`` points at the end-to-end fine-tuned one.
    """

    # SWAG weights fine-tuned end-to-end on ImageNet-1K; evaluated at 518x518 with
    # bicubic resizing, and min_size is declared as 518x518 accordingly.
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_swag-80465313.pth",
        transforms=partial(
            ImageClassification,
            crop_size=518,
            resize_size=518,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 633470440,
            "min_size": (518, 518),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.552,
                    "acc@5": 98.694,
                }
            },
            # NOTE(review): presumably GFLOPS and file size in MB respectively — confirm.
            "_ops": 1016.717,
            "_file_size": 2416.643,
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    # Frozen SWAG trunk plus a linear classifier trained on ImageNet-1K; uses 224x224
    # inputs with bicubic resizing.
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/vit_h_14_lc_swag-c1eb923e.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            resize_size=224,
            interpolation=InterpolationMode.BICUBIC,
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 632045800,
            "min_size": (224, 224),
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.708,
                    "acc@5": 97.730,
                }
            },
            "_ops": 167.295,
            "_file_size": 2411.209,
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    # DEFAULT is bound to the same Weights object as IMAGENET1K_SWAG_E2E_V1 (an alias).
    DEFAULT = IMAGENET1K_SWAG_E2E_V1


# @register_model()
# NOTE(review): registration is disabled for this builder; presumably because the same
# name is already registered elsewhere — confirm before re-enabling.
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Build a ViT-B/16 model (16x16 patches, 12 encoder layers, 12 attention heads,
    hidden dim 768, MLP dim 3072) following
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_16_Weights`, optional): Pretrained weights
            to load; see :class:`~torchvision.models.ViT_B_16_Weights` below for the
            available values. When omitted, the model is randomly initialised.
        progress (bool, optional): Show a download progress bar on stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.vision_transformer.VisionTransformer`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for the full list.

    .. autoclass:: torchvision.models.ViT_B_16_Weights
        :members:
    """
    checked_weights = ViT_B_16_Weights.verify(weights)
    # Architecture hyper-parameters of the "Base" variant with 16x16 patches.
    arch = dict(patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072)
    return _vision_transformer(weights=checked_weights, progress=progress, **arch, **kwargs)


# @register_model()
# NOTE(review): registration is disabled for this builder; presumably because the same
# name is already registered elsewhere — confirm before re-enabling.
@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Build a ViT-B/32 model (32x32 patches, 12 encoder layers, 12 attention heads,
    hidden dim 768, MLP dim 3072) following
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_B_32_Weights`, optional): Pretrained weights
            to load; see :class:`~torchvision.models.ViT_B_32_Weights` below for the
            available values. When omitted, the model is randomly initialised.
        progress (bool, optional): Show a download progress bar on stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.vision_transformer.VisionTransformer`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for the full list.

    .. autoclass:: torchvision.models.ViT_B_32_Weights
        :members:
    """
    checked_weights = ViT_B_32_Weights.verify(weights)
    # Architecture hyper-parameters of the "Base" variant with 32x32 patches.
    arch = dict(patch_size=32, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072)
    return _vision_transformer(weights=checked_weights, progress=progress, **arch, **kwargs)


# @register_model()
# NOTE(review): registration is disabled for this builder; presumably because the same
# name is already registered elsewhere — confirm before re-enabling.
@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Build a ViT-L/16 model (16x16 patches, 24 encoder layers, 16 attention heads,
    hidden dim 1024, MLP dim 4096) following
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_16_Weights`, optional): Pretrained weights
            to load; see :class:`~torchvision.models.ViT_L_16_Weights` below for the
            available values. When omitted, the model is randomly initialised.
        progress (bool, optional): Show a download progress bar on stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.vision_transformer.VisionTransformer`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for the full list.

    .. autoclass:: torchvision.models.ViT_L_16_Weights
        :members:
    """
    checked_weights = ViT_L_16_Weights.verify(weights)
    # Architecture hyper-parameters of the "Large" variant with 16x16 patches.
    arch = dict(patch_size=16, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096)
    return _vision_transformer(weights=checked_weights, progress=progress, **arch, **kwargs)


# @register_model()
# NOTE(review): registration is disabled for this builder; presumably because the same
# name is already registered elsewhere — confirm before re-enabling.
@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Build a ViT-L/32 model (32x32 patches, 24 encoder layers, 16 attention heads,
    hidden dim 1024, MLP dim 4096) following
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_L_32_Weights`, optional): Pretrained weights
            to load; see :class:`~torchvision.models.ViT_L_32_Weights` below for the
            available values. When omitted, the model is randomly initialised.
        progress (bool, optional): Show a download progress bar on stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.vision_transformer.VisionTransformer`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for the full list.

    .. autoclass:: torchvision.models.ViT_L_32_Weights
        :members:
    """
    checked_weights = ViT_L_32_Weights.verify(weights)
    # Architecture hyper-parameters of the "Large" variant with 32x32 patches.
    arch = dict(patch_size=32, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096)
    return _vision_transformer(weights=checked_weights, progress=progress, **arch, **kwargs)


# @register_model()
# NOTE(review): registration is disabled for this builder; presumably because the same
# name is already registered elsewhere — confirm before re-enabling.
# The legacy ``pretrained`` mapping is None here because ViT_H_14_Weights has no
# IMAGENET1K_V1 entry, unlike the other ViT builders.
@handle_legacy_interface(weights=("pretrained", None))
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Build a ViT-H/14 model (14x14 patches, 32 encoder layers, 16 attention heads,
    hidden dim 1280, MLP dim 5120) following
    `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Args:
        weights (:class:`~torchvision.models.ViT_H_14_Weights`, optional): Pretrained weights
            to load; see :class:`~torchvision.models.ViT_H_14_Weights` below for the
            available values. When omitted, the model is randomly initialised.
        progress (bool, optional): Show a download progress bar on stderr. Default is True.
        **kwargs: Extra keyword arguments forwarded to the
            ``torchvision.models.vision_transformer.VisionTransformer`` base class. See the
            `source code <https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py>`_
            for the full list.

    .. autoclass:: torchvision.models.ViT_H_14_Weights
        :members:
    """
    checked_weights = ViT_H_14_Weights.verify(weights)
    # Architecture hyper-parameters of the "Huge" variant with 14x14 patches.
    arch = dict(patch_size=14, num_layers=32, num_heads=16, hidden_dim=1280, mlp_dim=5120)
    return _vision_transformer(weights=checked_weights, progress=progress, **arch, **kwargs)


def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolate positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    When the number of patches implied by ``image_size // patch_size`` differs from the one
    stored in ``model_state["encoder.pos_embedding"]``, the patch position embeddings are
    resampled on a 2d grid. The class-token embedding (first position) is kept unchanged.
    Note that this resized embedding is written back into ``model_state`` in place.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
            Its ``encoder.pos_embedding`` entry is overwritten when interpolation happens.
        interpolation_mode (str): The algorithm used for upsampling, forwarded as ``mode`` to
            :func:`torch.nn.functional.interpolate`. Default: bicubic.
        reset_heads (bool): If true, entries whose key starts with ``"heads"`` are not copied
            into the returned state dict, whether or not interpolation was needed.
            Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.

    Raises:
        ValueError: If the position embedding's leading dimension is not 1, or if the number
            of stored patch positions is not a perfect square.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length_1d = image_size // patch_size
    # +1 for the class token, which occupies the first position.
    new_seq_length = new_seq_length_1d**2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the positions embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated, so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        # math.isqrt is exact on integers, unlike int(math.sqrt(...)) which goes through a float.
        seq_length_1d = math.isqrt(seq_length)
        if seq_length_1d * seq_length_1d != seq_length:
            raise ValueError(
                f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d } and seq_length = {seq_length}"
            )

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)

        # align_corners may only be given for the interpolating modes; F.interpolate raises
        # ValueError when it is combined with e.g. "nearest" or "area", so pass None there.
        align_corners = True if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear") else None

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=align_corners,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

    # Dropping the classification heads is independent of whether the position embedding
    # had to be resized (previously it was only honoured when interpolation happened).
    if reset_heads:
        model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict(
            (k, v) for k, v in model_state.items() if not k.startswith("heads")
        )
        model_state = model_state_copy

    return model_state
