# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.


from functools import partial

import numpy as np
import torch.nn.functional as F
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import Mlp
from torchinfo import summary
from torchvision.transforms.functional import resize
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
# from mmcls.models.backbones.rednet import RedNet
from mmcls.models.backbones.rednet import RedNet
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every patch to ``embed_dim`` channels.

    Args:
        patch_size: Patch side length; used as kernel size and stride.
        in_chans: Number of channels of the input image.
        embed_dim: Number of output channels produced per patch.
        norm_layer: Optional factory called as ``norm_layer(embed_dim)``.
            When falsy, no normalisation is applied.
        flatten: When true, the ``(B, C, H, W)`` projection is returned as
            ``(B, H*W, C)``; otherwise the spatial layout is kept.
    """

    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.flatten = flatten
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed ``x`` and, if configured, flatten the patch grid to a sequence."""
        patches = self.proj(x)
        if self.flatten:
            # BCHW -> BNC
            patches = patches.flatten(2).transpose(1, 2)
        return self.norm(patches)


class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence.

    Taken from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Args:
        dim: Channel width ``C`` of the tokens.
        num_heads: Number of attention heads; the per-head width is
            ``dim // num_heads``.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Scale applied to the queries. A falsy value selects
            ``head_dim ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale if qk_scale else per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # One projection yields q, k and v; move the "3" axis to the front so
        # each of them ends up as (B, heads, N, per_head).
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(torch.softmax(scores, dim=-1))

        # Merge the heads back into the channel axis.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with a residual.

    Taken from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to the attention block.
        qk_scale: Forwarded to the attention block.
        drop: Dropout used for the attention projection and inside the MLP.
        attn_drop: Dropout on the attention weights.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Activation class handed to the MLP.
        norm_layer: Factory called as ``norm_layer(dim)``.
        Attention_block: Class used to build the attention sub-layer.
        Mlp_block: Class used to build the MLP sub-layer.
        init_values: Not used by this block; accepted so every block type can
            be constructed with the same keyword arguments.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        mlp_kwargs = dict(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        # Stochastic depth is only instantiated when it is actually enabled.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(**mlp_kwargs)

    def forward(self, x):
        """Apply the attention and MLP residual branches in sequence."""
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path(attended)
        expanded = self.mlp(self.norm2(x))
        return x + self.drop_path(expanded)


class Experts_MOS(nn.Module):
    """Decoder head that turns three summary tokens into three scalar scores.

    Learnable query embeddings are added to the (cls, conv, inv) summary
    tokens; the resulting queries go through one ``nn.TransformerDecoder``
    layer with the patch tokens as memory, and a shared bias-free linear layer
    maps each decoded query to a single value.

    Args:
        embed_dim: Channel width of all tokens.
        juery_nums: Number of learnable query embeddings. ``forward`` adds
            them to exactly three tokens, so the value must broadcast
            against a length-3 sequence.
    """

    def __init__(
        self,
        embed_dim=768,
        juery_nums=3,
    ):
        super().__init__()
        self.juery = juery_nums

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=6,
            dim_feedforward=(embed_dim * 4),
            dropout=0.0,
            activation=F.gelu,
            batch_first=True,
            norm_first=True,
        )
        self.bunch_decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)
        # Shape: [1, juery_nums, embed_dim].
        self.bunch_embedding = nn.Parameter(torch.randn(1, juery_nums, embed_dim))
        self.heads = nn.Linear(embed_dim, 1, bias=False)

        trunc_normal_(self.bunch_embedding, std=0.02)

    def forward(self, x, x_conv, x_inv, ref):
        """Score the image from its patch tokens and three summary tokens.

        Args:
            x: Patch tokens, 3-D ``[B, L, D]`` (e.g. ``[32, 196, 384]``).
            x_conv: Conv summary token, viewable as ``[B, 1, D]``.
            x_inv: Involution summary token, viewable as ``[B, 1, D]``.
            ref: Class token, viewable as ``[B, 1, D]``.

        Returns:
            Tuple ``(s_ref, s_conv, s_inv, weight)``: the three ``[B, 1]``
            scores decoded from the ref, conv and inv queries (in that order)
            and the weight matrix of the shared scoring layer.
        """
        batch, _, _ = x.shape

        # Repeat the learnable queries for every sample in the batch.
        queries = self.bunch_embedding.expand(batch, -1, -1)
        # Stack the summary tokens as [B, 3, D] in the order ref, conv, inv.
        summary_tokens = torch.cat(
            [tok.view(batch, 1, -1) for tok in (ref, x_conv, x_inv)], dim=1
        )

        # Queries are the decoder target; the patch tokens are its memory.
        decoded = self.bunch_decoder(queries + summary_tokens, x)
        scores = self.heads(decoded)  # [B, 3, D] -> [B, 3, 1]
        return scores[:, 0], scores[:, 1], scores[:, 2], self.heads.weight


class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block with a learnable per-channel scale per branch.

    Taken, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Each residual branch (attention, MLP) is multiplied by its own vector
    (``gamma_1`` / ``gamma_2``) initialised to ``init_values`` before it is
    added back to the input.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to the attention block.
        qk_scale: Forwarded to the attention block.
        drop: Dropout used for the attention projection and inside the MLP.
        attn_drop: Dropout on the attention weights.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Activation class handed to the MLP.
        norm_layer: Factory called as ``norm_layer(dim)``.
        Attention_block: Class used to build the attention sub-layer.
        Mlp_block: Class used to build the MLP sub-layer.
        init_values: Initial value of every entry of the two scale vectors.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()

        def new_scale():
            # Fresh trainable vector of length ``dim`` filled with init_values.
            return nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.gamma_1 = new_scale()
        self.gamma_2 = new_scale()

    def forward(self, x):
        """Apply the scaled attention and MLP residual branches in sequence."""
        attn_branch = self.gamma_1 * self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.gamma_2 * self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)


class Layer_scale_init_Block_paralx2(nn.Module):
    """Transformer block with two parallel, individually scaled branches per stage.

    Taken, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    The attention stage sums two independent attention branches and the MLP
    stage sums two independent MLP branches. Every branch has its own norm
    layer and its own per-channel scale vector initialised to ``init_values``.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads (both attention branches).
        mlp_ratio: Hidden width of each MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to both attention blocks.
        qk_scale: Forwarded to both attention blocks.
        drop: Dropout used for the attention projections and inside the MLPs.
        attn_drop: Dropout on the attention weights.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Activation class handed to the MLPs.
        norm_layer: Factory called as ``norm_layer(dim)``.
        Attention_block: Class used to build the attention sub-layers.
        Mlp_block: Class used to build the MLP sub-layers.
        init_values: Initial value of every entry of the four scale vectors.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        hidden = int(dim * mlp_ratio)

        def new_attention():
            return Attention_block(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )

        def new_mlp():
            return Mlp_block(
                in_features=dim,
                hidden_features=hidden,
                act_layer=act_layer,
                drop=drop,
            )

        def new_scale():
            return nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

        # Attention stage: two branches, each with its own norm.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = new_attention()
        self.attn1 = new_attention()
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # MLP stage: two branches, each with its own norm.
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = new_mlp()
        self.mlp1 = new_mlp()

        # One scale vector per branch.
        self.gamma_1 = new_scale()
        self.gamma_1_1 = new_scale()
        self.gamma_2 = new_scale()
        self.gamma_2_1 = new_scale()

    def forward(self, x):
        """Add both attention branches to ``x``, then both MLP branches."""
        attn_a = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b

        mlp_a = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b


class Block_paralx2(nn.Module):
    """Transformer block with two parallel branches per stage and no scaling.

    Taken, with slight modifications, from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    The attention stage sums two independent attention branches and the MLP
    stage sums two independent MLP branches; every branch has its own norm.

    Args:
        dim: Token channel width.
        num_heads: Number of attention heads (both attention branches).
        mlp_ratio: Hidden width of each MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to both attention blocks.
        qk_scale: Forwarded to both attention blocks.
        drop: Dropout used for the attention projections and inside the MLPs.
        attn_drop: Dropout on the attention weights.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Activation class handed to the MLPs.
        norm_layer: Factory called as ``norm_layer(dim)``.
        Attention_block: Class used to build the attention sub-layers.
        Mlp_block: Class used to build the MLP sub-layers.
        init_values: Not used by this block; accepted so every block type can
            be constructed with the same keyword arguments.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        hidden = int(dim * mlp_ratio)

        def new_attention():
            return Attention_block(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )

        def new_mlp():
            return Mlp_block(
                in_features=dim,
                hidden_features=hidden,
                act_layer=act_layer,
                drop=drop,
            )

        # Attention stage: two branches, each with its own norm.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = new_attention()
        self.attn1 = new_attention()
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # MLP stage: two branches, each with its own norm.
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = new_mlp()
        self.mlp1 = new_mlp()

    def forward(self, x):
        """Add both attention branches to ``x``, then both MLP branches."""
        attn_a = self.drop_path(self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b

        mlp_a = self.drop_path(self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b


# Embedding for the conv token
class embed_conv(nn.Module):
    """Convolutional tokenizer: a three-layer conv stem plus a patch projection.

    The stem maps a 3-channel image to 64 channels (the first convolution has
    stride 2, the other two keep the resolution); an 8x8 convolution with
    stride 8 then projects to ``embed_dim`` and the grid is flattened into a
    token sequence.

    Args:
        embed_dim: Channel width of the produced tokens.
    """

    def __init__(self, embed_dim):
        super().__init__()
        stem_layers = [
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        # Two further 3x3 conv / BN / ReLU stages at constant width and size.
        for _ in range(2):
            stem_layers += [
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stem_layers)
        self.proj = nn.Conv2d(64, embed_dim, kernel_size=8, stride=8)

    def forward(self, x):
        """Return the tokens of ``x`` as ``(B, N, embed_dim)``."""
        feature_map = self.proj(self.conv(x))
        # BCHW -> BNC
        return feature_map.flatten(2).transpose(1, 2)

# Embedding for the inv (involution) token
class embed_inv(nn.Module):
    """Involution tokenizer: the RedNet-26 stem plus a patch projection.

    Only the ``stem`` of a freshly built ``RedNet(26)`` is kept. Its output is
    projected by an 8x8, stride-8 convolution that takes 64 input channels and
    the resulting grid is flattened into a token sequence.

    Args:
        embed_dim: Channel width of the produced tokens.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # The rest of the backbone is discarded once the stem is extracted.
        # NOTE(review): the projection below assumes the stem emits 64
        # channels - confirm against the RedNet implementation in use.
        self.stem = RedNet(26).stem
        self.proj = nn.Conv2d(64, embed_dim, kernel_size=8, stride=8)

    def forward(self, x):
        """Return the tokens of ``x`` as ``(B, N, embed_dim)``."""
        feature_map = self.proj(self.stem(x))
        # BCHW -> BNC
        return feature_map.flatten(2).transpose(1, 2)

class FeatureLoss(nn.Module):
    """Masked Generative Distillation (MGD) loss on token sequences.

    The student tokens are resampled along the token axis to the teacher's
    token count, reshaped to a square feature map, randomly masked, passed
    through a small generation network, and compared with the teacher's
    feature map by a summed squared error.

    Args:
        student_channels (int): Channel width of the student tokens.
        teacher_channels (int): Channel width of the teacher tokens.
        alpha_mgd (float, optional): Weight of the distillation loss.
            Defaults to 0.00002.
        lambda_mgd (float, optional): Fraction of spatial positions that are
            masked out. Defaults to 0.65.
        student_tokens (int, optional): Number of student tokens per sample.
            Defaults to 196.
        teacher_tokens (int, optional): Number of teacher tokens per sample;
            must be a perfect square. Defaults to 625.
        device (optional): Device the sub-modules are created on. ``None``
            (the default) selects CUDA when it is available and the CPU
            otherwise, so the module can also be built on CPU-only hosts.

    Raises:
        ValueError: If ``teacher_tokens`` is not a perfect square.
    """

    def __init__(
        self,
        student_channels,
        teacher_channels,
        alpha_mgd=0.00002,
        lambda_mgd=0.65,
        student_tokens=196,
        teacher_tokens=625,
        device=None,
    ):
        super(FeatureLoss, self).__init__()
        side = int(round(teacher_tokens ** 0.5))
        if side * side != teacher_tokens:
            raise ValueError(
                f"teacher_tokens must be a perfect square, got {teacher_tokens}"
            )

        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd

        # 1x1 conv that maps student channels onto teacher channels when they
        # differ; not needed otherwise.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(
                student_channels, teacher_channels, kernel_size=1, stride=1, padding=0
            )
        else:
            self.align = None

        # Generation network that reconstructs the teacher features from the
        # masked student features.
        self.generation = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(teacher_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
        )
        # Resamples the token axis from the student's to the teacher's length.
        self.fc = nn.Linear(student_tokens, teacher_tokens, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if device is None:
            # Previously the sub-modules were moved with a hard-coded .cuda(),
            # which raises on hosts without a GPU. Keep GPU placement when a
            # GPU exists and fall back to the CPU otherwise.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)

    @staticmethod
    def _as_square_map(x):
        """Reshape ``[B, C, N]`` to ``[B, C, s, s]`` with ``s = int(sqrt(N))``."""
        batch, channels, tokens = x.shape
        side = int(np.sqrt(tokens))
        return x.reshape(batch, channels, side, side)

    def forward(self, preds_S, preds_T):
        """Compute the weighted MGD loss.

        Args:
            preds_S (Tensor): Student tokens ``[B, N_s, C_s]``
                (e.g. ``[B, 196, 384]``).
            preds_T (Tensor): Teacher tokens ``[B, N_t, C_t]``
                (e.g. ``[B, 625, 384]``); ``N_t`` must be a perfect square.

        Returns:
            Scalar tensor: summed squared error between the generated and the
            teacher feature map, divided by the batch size and multiplied by
            ``alpha_mgd``.
        """
        # [B, N_s, C_s] -> [B, C_s, N_s] -> [B, C_s, N_t] -> [B, C_s, s, s]
        preds_S = self.relu(self.fc(preds_S.transpose(1, 2)))
        preds_S = self._as_square_map(preds_S)
        # [B, N_t, C_t] -> [B, C_t, N_t] -> [B, C_t, s, s]
        preds_T = self._as_square_map(preds_T.transpose(1, 2))

        if self.align is not None:
            preds_S = self.align(preds_S)

        N, _, H_t, W_t = preds_T.shape
        device = preds_S.device

        # One random value per spatial position, shared by all channels. A
        # position is kept when its value is <= 1 - lambda_mgd, i.e. roughly a
        # lambda_mgd fraction of the positions is zeroed.
        rand = torch.rand((N, 1, H_t, W_t)).to(device)
        keep = (rand <= 1 - self.lambda_mgd).to(preds_S.dtype)
        masked_fea = preds_S * keep

        new_fea = self.generation(masked_fea)
        loss = F.mse_loss(new_fea, preds_T, reduction="sum") / N
        return loss * self.alpha_mgd

class vit_models_distill(nn.Module):
    """ViT student with conv/involution tokens, a decoder head and feature distillation.

    Built on the Vision Transformer with LayerScale
    (https://arxiv.org/abs/2103.17239). The sequence fed to the transformer
    blocks is ``[cls, conv_token, inv_token, patch tokens...]``: the conv and
    inv tokens are the token-wise means of the ``embed_conv`` and
    ``embed_inv`` outputs. After the blocks, the three leading tokens and the
    patch tokens are handed to an ``Experts_MOS`` head that produces three
    scores. When teacher features are supplied, an MGD feature loss between
    the per-block patch tokens of student and teacher is accumulated.

    Args:
        patch_size: Patch size forwarded to ``Patch_layer``.
        in_chans: Input channels forwarded to ``Patch_layer``.
        num_classes: Stored on the instance; not used to build the head.
        embed_dim: Token channel width of the whole model.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        qkv_bias: Forwarded to every block.
        qk_scale: Forwarded to every block.
        attn_drop_rate: Attention dropout forwarded to every block.
        drop_path_rate: Stochastic-depth rate, the same for every block.
        norm_layer: Factory called as ``norm_layer(embed_dim)``.
        global_pool: Accepted for signature compatibility; unused.
        block_layers: Class used to build each transformer block.
        Patch_layer: Class used to build the patch embedding.
        act_layer: Activation class forwarded to every block.
        Attention_block: Attention class forwarded to every block.
        Mlp_block: MLP class forwarded to every block.
        init_scale: Forwarded to every block as ``init_values``.
    """

    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        num_classes=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        global_pool=None,
        block_layers=Block,
        Patch_layer=PatchEmbed,
        act_layer=nn.GELU,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_scale=1e-4,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = Patch_layer(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # NOTE(review): fixed token count; presumably 224x224 inputs with
        # 16x16 patches (14 * 14 = 196) - confirm before changing either.
        num_patches = 196

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Extra tokenizers whose mean-pooled outputs become the conv / inv tokens.
        self.conv_embed = embed_conv(self.embed_dim)
        self.inv_embed = embed_inv(self.embed_dim)
        # Separate positional embeddings for the patch tokens and for the
        # three leading tokens (cls, conv, inv).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_embed_c_i = nn.Parameter(torch.zeros(1, 3, embed_dim))

        # init_weight already walks every sub-module, so it is called once on
        # each tokenizer. Routing it through Module.apply would re-initialise
        # every convolution once per ancestor module.
        self.init_weight(self.conv_embed)
        self.init_weight(self.inv_embed)

        # The student width follows embed_dim so the loss matches the tokens
        # this model actually produces.
        # NOTE(review): the teacher width stays at 384 as before - confirm it
        # matches the teacher network in use.
        self.MGDLOSS = FeatureLoss(student_channels=embed_dim, teacher_channels=384)

        dpr = [drop_path_rate] * depth
        self.blocks = nn.ModuleList(
            [
                block_layers(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=0.0,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    Attention_block=Attention_block,
                    Mlp_block=Mlp_block,
                    init_values=init_scale,
                )
                for i in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")]

        # The head must share the token width of the encoder; a fixed 384
        # here broke every configuration with a different embed_dim.
        self.head = Experts_MOS(embed_dim=embed_dim)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` modules; ignore others."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weight(self, module):
        """Kaiming-normal initialise every ``nn.Conv2d`` inside ``module``."""
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names to be excluded from weight decay."""
        return {"pos_embed", "cls_token", "conv_embed", "inv_embed"}

    def get_classifier(self):
        """Return the current head module."""
        return self.head

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head with a linear layer (or identity if ``num_classes`` <= 0).

        NOTE(review): ``forward`` calls the head with four arguments and
        unpacks four results, which a linear or identity head does not
        support; ``forward`` cannot be used after calling this method.
        """
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def forward_features(self, x):
        """Run the encoder.

        Args:
            x: Image batch, e.g. ``[B, 3, 224, 224]``.

        Returns:
            Tuple ``(cls, conv, inv, patches, per_block)`` where the first
            three are the normalised leading tokens (each ``[B, D]``),
            ``patches`` are the normalised patch tokens ``[B, N, D]`` and
            ``per_block`` lists the (un-normalised) patch tokens after every
            transformer block.
        """
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, D]

        x_conv = self.conv_embed(x)  # B N C
        # The involution branch is evaluated with CUDA autocast disabled.
        with torch.cuda.amp.autocast(False):
            x_inv = self.inv_embed(x)
        # Average over the token axis to get one summary token per branch.
        conv_token = x_conv.mean(1, keepdim=True)
        inv_token = x_inv.mean(1, keepdim=True)

        # Leading tokens and patch tokens get their own positional embeddings
        # before being concatenated into a single sequence.
        x_3 = torch.cat((cls_tokens, conv_token, inv_token), dim=1)
        x_3 = x_3 + self.pos_embed_c_i
        x_196 = self.patch_embed(x) + self.pos_embed
        x = torch.cat((x_3, x_196), dim=1)

        # Keep the patch tokens after every block for feature distillation.
        mlp_inner_feature = []
        for blk in self.blocks:
            x = blk(x)
            mlp_inner_feature.append(x[:, 3:, :])

        x = self.norm(x)
        return x[:, 0], x[:, 1], x[:, 2], x[:, 3:, :], mlp_inner_feature

    def forward(self, x, t_encode_inner_feature=None):
        """Predict three scores and, optionally, the distillation loss.

        Args:
            x: Image batch.
            t_encode_inner_feature: Optional iterable of teacher patch-token
                tensors, paired in order with this model's per-block patch
                tokens.

        Returns:
            ``(x1, x2, x3)`` when no teacher features are given; otherwise
            ``(x1, x2, x3, student_features, head_weight, feature_loss)``.
        """
        feature_loss = 0.0
        ref, x_conv, x_inv, x, s_encode_inner_feature = self.forward_features(x)
        # The decoder head regresses one score per leading token.
        x1, x2, x3, weight = self.head(x, x_conv, x_inv, ref)

        # Identity check: "!= None" would be evaluated element-wise if a
        # tensor were passed instead of a list.
        if t_encode_inner_feature is None:
            return x1, x2, x3

        # Feature distillation: sum the MGD loss over all paired layers.
        for t_encode_feature, s_encode_feature in zip(
            t_encode_inner_feature, s_encode_inner_feature
        ):
            feature_loss += self.MGDLOSS(s_encode_feature, t_encode_feature)
        return x1, x2, x3, s_encode_inner_feature, weight, feature_loss

# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
def build_vit_distill(
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    block_layers=Layer_scale_init_Block,
    pretrained=False,
    pretrained_model_path="",
):
    """Build a ``vit_models_distill`` and optionally load pretrained weights.

    Args:
        patch_size: Forwarded to the model.
        embed_dim: Forwarded to the model.
        depth: Forwarded to the model.
        num_heads: Forwarded to the model.
        mlp_ratio: Forwarded to the model.
        qkv_bias: Forwarded to the model.
        norm_layer: Forwarded to the model.
        block_layers: Forwarded to the model.
        pretrained: When true, load weights from ``pretrained_model_path``.
        pretrained_model_path: Checkpoint file whose ``"model"`` entry holds
            the state dict. Only read when ``pretrained`` is true.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``pretrained`` is true but no checkpoint path is given.
    """
    # Validate before the (expensive) model construction. An explicit raise
    # is used because assert statements are removed under "python -O".
    if pretrained and not pretrained_model_path:
        raise ValueError("pretrained=True requires a non-empty pretrained_model_path")

    model = vit_models_distill(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        norm_layer=norm_layer,
        block_layers=block_layers,
    )
    if pretrained:
        checkpoint = torch.load(pretrained_model_path, map_location="cpu")
        state_dict = checkpoint["model"]
        # strict=False: parameters missing from the checkpoint keep their
        # fresh initialisation and unexpected checkpoint keys are ignored.
        model.load_state_dict(state_dict, strict=False)
        del checkpoint
        torch.cuda.empty_cache()
    return model

if __name__ == "__main__":
    # Smoke test: build the model with random weights and print its summary.
    vit = build_vit_distill(
        pretrained=False,
        # Only read when pretrained=True.
        pretrained_model_path="C:/Users/76105/Desktop/mac/Data-Eff-IQA/deit_3_small_224_1k.pth",
    )
    # Use the GPU when there is one instead of requiring CUDA unconditionally.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    summary(vit, (32, 3, 224, 224), device=device)
