# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.


from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
# from timm.models.layers import DropPath, trunc_normal_
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.vision_transformer import _cfg, Mlp
from torchinfo import summary
from torchvision.transforms.functional import resize
import pdb
from .llama import LLaMATransformer
import math

class PatchEmbed(nn.Module):
    """Image to patch embedding.

    Projects a fixed-size image into a token sequence of width ``embed_dim``. With
    ``multi_conv=False`` a single ``patch_size``-strided convolution is used, giving
    ``num_patches`` tokens. With ``multi_conv=True`` a small three-convolution stem is used
    instead; it is only defined for patch sizes 12 (strides 4*3*1) and 16 (strides 4*2*2).

    Args:
        img_size: expected input size (int or pair); ``forward`` asserts the input matches.
        patch_size: patch size (int or pair).
        in_chans: number of input channels.
        embed_dim: output token width.
        multi_conv: use the convolutional stem instead of a single strided convolution.

    Raises:
        ValueError: if ``multi_conv`` is True and ``patch_size`` is neither 12 nor 16.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        if multi_conv:
            if patch_size[0] == 12:
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
                )
            elif patch_size[0] == 16:
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
                )
            else:
                # Without this, ``self.proj`` would be left undefined and the error would only
                # surface later, as an AttributeError inside ``forward``.
                raise ValueError(
                    f"multi_conv stem only supports patch sizes 12 and 16, got {patch_size[0]}"
                )
        else:
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed ``x`` of shape (B, C, H, W) into tokens of shape (B, L, embed_dim)."""
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, h, w) -> (B, embed_dim, h*w) -> (B, h*w, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Adapted from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py).

    Args:
        dim: token width C (split evenly across heads).
        num_heads: number of attention heads.
        qkv_bias: add a bias to the fused q/k/v projection.
        qk_scale: override for the ``head_dim ** -0.5`` query scale (a falsy value uses the default).
        attn_drop: dropout on the attention probabilities.
        proj_drop: dropout on the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        # Submodule creation order is kept stable (qkv before proj) so random init is unchanged.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over ``x`` of shape (B, N, C) and return a tensor of the same shape."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))


class Experts_MOS(nn.Module):
    """Query-based regression head producing one scalar per sample.

    ``juery_nums`` learned query embeddings, each shifted by the same per-sample reference
    vector, cross-attend to the token sequence through one pre-norm
    ``nn.TransformerDecoderLayer``. Each query is mapped to a scalar by ``heads``, and the
    scalars are averaged.

    Args:
        embed_dim: token width D (must be divisible by ``num_heads``).
        juery_nums: number of learned queries.
        num_heads: attention heads of the decoder layer. Defaults to 6, the value that was
            previously hard-coded.
    """

    def __init__(
        self,
        embed_dim=768,
        juery_nums=6,
        num_heads=6,
    ):
        super().__init__()
        self.juery = juery_nums
        bunch_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            dropout=0.0,
            nhead=num_heads,
            activation=F.gelu,
            batch_first=True,
            dim_feedforward=(embed_dim * 4),
            norm_first=True,
        )
        self.bunch_decoder = nn.TransformerDecoder(bunch_layer, num_layers=1)
        self.bunch_embedding = nn.Parameter(torch.randn(1, juery_nums, embed_dim))
        self.heads = nn.Linear(embed_dim, 1, bias=False)
        trunc_normal_(self.bunch_embedding, std=0.02)

    def forward(self, x, ref):
        """Score a batch.

        Args:
            x: memory tokens, (B, L, D).
            ref: per-sample reference vector, reshaped to (B, 1, D) -- presumably (B, D).

        Returns:
            (B, 1) tensor: the mean over queries of the per-query scalar.
        """
        B, L, D = x.shape
        queries = self.bunch_embedding.expand(B, -1, -1)
        ref = ref.view(B, 1, -1).expand(B, self.juery, -1)
        queries = queries + ref

        x = self.bunch_decoder(queries, x)  # (B, juery, D)
        x = self.heads(x)  # (B, juery, 1)
        x = x.view(B, -1).mean(dim=1)
        return x.view(B, 1)


class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block with LayerScale (learnable per-channel residual gains).

    Adapted, with slight modifications, from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py).
    The output of each residual branch is multiplied by ``gamma_1`` / ``gamma_2``, which are
    initialised to ``init_values``, before stochastic depth and the residual add.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        hidden = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop,
        )
        # Stochastic depth on each residual branch; a no-op when drop_path is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """Apply the attention and MLP residual branches to ``x``."""
        attn_update = self.gamma_1 * self.attn(self.norm1(x))
        x = x + self.drop_path(attn_update)
        mlp_update = self.gamma_2 * self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_update)

class Block(nn.Module):
    """Standard pre-norm transformer block: ``x + attn(norm(x))`` then ``x + mlp(norm(x))``.

    Adapted from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py).
    ``init_values`` is accepted so the signature matches ``Layer_scale_init_Block``, but it is
    not used.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        hidden = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop,
        )
        # Stochastic depth on each residual branch; a no-op when drop_path is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Apply the attention and MLP residual branches to ``x``."""
        attn_update = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_update)
        mlp_update = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_update)

class Block_paralx2(nn.Module):
    """Transformer block with two parallel attention branches and two parallel MLP branches.

    Adapted, with slight modifications, from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py).
    Each stage adds the outputs of two independently parameterised branches, each with its own
    pre-norm, to the residual stream. ``init_values`` is accepted but unused.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        attn_kwargs = dict(
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop,
        )
        mlp_kwargs = dict(
            in_features=dim, hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer, drop=drop,
        )

        # Submodules are registered in the original order (state_dict / init order unchanged).
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = Attention_block(dim, **attn_kwargs)
        self.attn1 = Attention_block(dim, **attn_kwargs)
        # Stochastic depth shared by all four branches; a no-op when drop_path is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = Mlp_block(**mlp_kwargs)
        self.mlp1 = Mlp_block(**mlp_kwargs)

    def forward(self, x):
        """Apply the two parallel attention branches, then the two parallel MLP branches."""
        attn_a = self.drop_path(self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b

        mlp_a = self.drop_path(self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b
    
    
class CrossAttention(nn.Module):
    """Top-k sparse channel attention producing an update for the first (cls) token.

    Per head, q/k/v are laid out as (head_dim, N): one row per channel, spanning all N
    tokens. Rows of q and k are L2-normalised and ``q @ k^T`` gives a (head_dim, head_dim)
    channel affinity. Four sparsified copies keep the top 1/2, 2/3, 3/4 and 4/5 entries of
    each affinity row. The remaining entries are set to 0 (not -inf) before the softmax. The
    four attended outputs are mixed with the learnable scalars ``attn1``..``attn4``. Only the
    column of token 0 is returned.

    Args:
        dim: token width C; must be divisible by ``num_heads``.
        num_heads: number of heads.
        qkv_bias: add a bias to the q/k/v projections.
        qk_scale: override for the ``head_dim ** -0.5`` affinity scale (a falsy value uses the default).
        attn_drop: accepted for interface compatibility; unused.
        proj_drop: dropout applied to the output.
    """

    # (numerator, denominator) of the fraction of each affinity row kept by each branch,
    # paired in order with the mixing weights attn1..attn4.
    _TOPK_FRACTIONS = ((1, 2), (2, 3), (3, 4), (4, 5))

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        # Learnable mixing weights of the four top-k branches.
        self.attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)

    def forward(self, x):
        """Compute the token-0 update.

        Args:
            x: tokens, (B, N, C).

        Returns:
            (B, 1, C) tensor.
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        def to_channel_rows(t):
            # (B, N, C) -> (B, C, N) -> (B, heads, head_dim, N). The transpose is required:
            # reshaping the token-major (B, N, C) tensor directly (as the previous version
            # did) interleaves tokens and channels.
            return t.transpose(1, 2).reshape(B, self.num_heads, head_dim, N)

        q = F.normalize(to_channel_rows(self.wq(x)), dim=-1)
        k = F.normalize(to_channel_rows(self.wk(x)), dim=-1)
        v = to_channel_rows(self.wv(x))

        # Channel-to-channel affinity: (B, H, d, N) @ (B, H, N, d) -> (B, H, d, d).
        attn = (q @ k.transpose(-2, -1)) * self.scale

        mix_weights = (self.attn1, self.attn2, self.attn3, self.attn4)
        out = 0
        for (num, den), weight in zip(self._TOPK_FRACTIONS, mix_weights):
            index = torch.topk(attn, k=head_dim * num // den, dim=-1, largest=True)[1]
            keep = torch.zeros_like(attn).scatter_(-1, index, 1.)
            sparse = torch.where(keep > 0, attn, torch.zeros_like(attn))
            out = out + (sparse.softmax(dim=-1) @ v) * weight

        # Column of token 0: (B, H, d) -> (B, 1, H * d).
        out = out[:, :, :, 0].reshape(B, 1, self.num_heads * head_dim)
        return self.proj_drop(out)
    
class CrossAttentionBlock(nn.Module):
    """Residual wrapper around ``CrossAttention`` that returns only the updated first token.

    ``forward`` maps (B, N, C) to (B, 1, C): token 0 plus the (stochastic-depth) cross-attention
    update, optionally followed by a pre-norm MLP residual when ``has_mlp`` is True.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on the residual branches; a no-op when drop_path is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.has_mlp = has_mlp
        if not has_mlp:
            return
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Return the updated first token, shape (B, 1, C)."""
        first_token = x[:, 0:1, ...]
        updated = first_token + self.drop_path(self.attn(self.norm1(x)))
        if not self.has_mlp:
            return updated
        return updated + self.drop_path(self.mlp(self.norm2(updated)))


class MultiScaleBlock(nn.Module):
    """One CrossViT stage: per-branch transformer blocks followed by cls-token fusion.

    For every branch ``b``:
      1. ``blocks[b]`` runs ``depth[b]`` ``Block`` layers on that branch's tokens;
      2. its cls token is projected (``projs[b]``) to the width of branch ``b+1`` (cyclically);
      3. that projected cls token, concatenated with branch ``b+1``'s patch tokens, goes through
         ``fusion[b]`` (``CrossAttentionBlock``(s), ``depth[-1]`` of them, or one if that is 0);
      4. the fused cls token is projected back (``revert_projs[b]``) and replaces the cls token
         of branch ``b``.

    ``patches`` is accepted but unused. ``drop_path`` must be indexable: ``drop_path[i]`` is
    used for the i-th ``Block`` and ``drop_path[-1]`` for the fusion blocks.
    """

    def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        n = len(dim)
        self.num_branches = n
        partner = [(b + 1) % n for b in range(n)]

        # Per-branch transformer stacks. Branches with zero depth contribute no entry, and the
        # whole attribute becomes None when no branch has any layer.
        self.blocks = nn.ModuleList()
        for b in range(n):
            layers = [
                Block(dim=dim[b], num_heads=num_heads[b], mlp_ratio=mlp_ratio[b], qkv_bias=qkv_bias,
                      drop=drop, attn_drop=attn_drop, drop_path=drop_path[j], norm_layer=norm_layer)
                for j in range(depth[b])
            ]
            if layers:
                self.blocks.append(nn.Sequential(*layers))
        if len(self.blocks) == 0:
            self.blocks = None

        # cls token of branch b -> width of its partner branch.
        self.projs = nn.ModuleList()
        for b in range(n):
            self.projs.append(nn.Sequential(
                norm_layer(dim[b]), act_layer(), nn.Linear(dim[b], dim[partner[b]])))

        self.fusion = nn.ModuleList()
        for b in range(n):
            target = partner[b]
            fusion_kwargs = dict(
                dim=dim[target], num_heads=num_heads[target], mlp_ratio=mlp_ratio[b],
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[-1], norm_layer=norm_layer, has_mlp=False,
            )
            if depth[-1] == 0:  # backward capability: a single, unwrapped fusion block
                self.fusion.append(CrossAttentionBlock(**fusion_kwargs))
            else:
                self.fusion.append(nn.Sequential(
                    *[CrossAttentionBlock(**fusion_kwargs) for _ in range(depth[-1])]))

        # Fused cls token (partner width) -> back to branch b's width.
        self.revert_projs = nn.ModuleList()
        for b in range(n):
            self.revert_projs.append(nn.Sequential(
                norm_layer(dim[partner[b]]), act_layer(), nn.Linear(dim[partner[b]], dim[b])))

    def forward(self, x):
        """Run the stage on a list of per-branch token tensors; returns a list of the same shapes."""
        n = self.num_branches
        branch_out = [blk(feat) for feat, blk in zip(x, self.blocks)]
        # Only the cls token of each branch is projected.
        cls_proj = [proj(out[:, 0:1]) for out, proj in zip(branch_out, self.projs)]

        fused = []
        for b in range(n):
            partner_patches = branch_out[(b + 1) % n][:, 1:, ...]
            mixed = self.fusion[b](torch.cat((cls_proj[b], partner_patches), dim=1))
            new_cls = self.revert_projs[b](mixed[:, 0:1, ...])
            fused.append(torch.cat((new_cls, branch_out[b][:, 1:, ...]), dim=1))
        return fused


def _compute_num_patches(img_size, patches):
    return [i // p * i // p for i, p in zip(img_size,patches)]


class vit_models(nn.Module):
    """Two-branch CrossViT encoder followed by a frozen LLaMA layer and an ``Experts_MOS`` head.

    Pipeline (see ``forward_features``):
      1. each branch resizes the image to ``img_size[i]`` (bicubic, only when the input height
         differs), patch-embeds it, prepends a cls token and adds positional embeddings;
      2. the branches are processed and fused by a stack of ``MultiScaleBlock`` stages;
      3. the first branch is mapped by ``linear_else1``/``linear_else2`` onto the second
         branch's width/token count and the two branches are averaged;
      4. the merged sequence is mapped to 4096 dims, passed through ``LLaMATransformer`` (only
         parameters whose name contains ``'dim_mapper'`` are trainable), mapped back to 384
         dims and normalised;
      5. ``Experts_MOS`` regresses one value per image from the tokens and the first token.

    Derived from the LayerScale ViT interface (https://arxiv.org/abs/2103.17239).

    NOTE(review): the scalar defaults of ``embed_dim``/``depth``/``num_heads``/``mlp_ratio`` do not
    work, because ``__init__`` indexes them per branch. ``forward_features`` also hard-codes the
    192/384-dim, 400->196-token geometry of the ``build_vit`` configuration, so construct the
    model through ``build_vit``. ``global_pool``, ``Patch_layer``, ``act_layer``,
    ``Attention_block``, ``Mlp_block`` and ``init_scale`` are accepted but unused.

    Args:
        lda: weight of the contrastive loss; it must be set when ``forward`` is called with a
            truthy ``training``.
        llama_checkpoint_path: checkpoint whose ``'model'`` entry provides the LLaMA weights
            (previously hard-coded to the same default).
    """

    def __init__(
        self,
        img_size=(224, 384),
        patch_size=(12, 16),
        in_chans=3,
        num_classes=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        global_pool=None,
        Patch_layer=PatchEmbed,
        act_layer=nn.GELU,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_scale=1e-4,
        lda=None,
        llama_checkpoint_path="./checkpoint.pth",
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.img_size = img_size
        self.lda = lda

        num_patches = _compute_num_patches(img_size, patch_size)
        self.num_branches = len(patch_size)

        # NOTE: the registration/creation order below is kept identical to the original
        # implementation so state_dict ordering and the random-init sequence do not change.
        self.patch_embed = nn.ModuleList()
        self.pos_drop = nn.Dropout(p=0.0)
        self.cls_token = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)])
        self.pos_embed = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)])
        for im_s, p, d in zip(img_size, patch_size, embed_dim):
            self.patch_embed.append(
                PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=False))

        # Stochastic-depth rates, linearly spaced from 0 to drop_path_rate, sliced per stage.
        total_depth = sum(sum(cfg[-2:]) for cfg in depth)
        dpr = [r.item() for r in torch.linspace(0, drop_path_rate, total_depth)]
        dpr_ptr = 0
        self.blocks = nn.ModuleList()
        for block_cfg in depth:
            curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
            self.blocks.append(MultiScaleBlock(
                embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=0, attn_drop=attn_drop_rate,
                drop_path=dpr[dpr_ptr:dpr_ptr + curr_depth], norm_layer=norm_layer))
            dpr_ptr += curr_depth

        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
        # Per-branch linear heads. They are replaced by ``Experts_MOS`` below but still created
        # here so the init sequence stays as before.
        self.head = nn.ModuleList([
            nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
            for i in range(self.num_branches)])

        for i in range(self.num_branches):
            if self.pos_embed[i].requires_grad:
                trunc_normal_(self.pos_embed[i], std=.02)
            trunc_normal_(self.cls_token[i], std=.02)

        # Only the modules created up to this point are re-initialised by ``_init_weights``.
        self.apply(self._init_weights)

        # Map the first branch (192-dim tokens, 400 patch tokens) onto the second branch's
        # 384-dim, 196-token layout so the two can be averaged in ``forward_features``.
        self.linear_else1 = nn.Linear(192, 384)
        self.linear_else2 = nn.Linear(400, 196)

        self.norm_llm = nn.LayerNorm(384)
        self.head = Experts_MOS(embed_dim=384)

        self.llama = self._load_frozen_llama(llama_checkpoint_path)

        self.llama_dim_mapper1 = nn.Linear(384, 4096, bias=False)
        self.llama_dim_mapper2 = nn.Linear(4096, 384, bias=False)

    @staticmethod
    def _load_frozen_llama(checkpoint_path):
        """Build the LLaMA layer, load its weights non-strictly, and freeze all but ``dim_mapper``.

        Weights are read from the ``'model'`` entry of ``checkpoint_path``. Only keys containing
        ``'llama'`` are kept, with ``'llama.'`` removed from them.
        """
        llama_config = {"dim": 4096, "multiple_of": 256,
                        "n_heads": 32, "n_layers": 32, "norm_eps": 1.0e-6,
                        "vocab_size": -1, "first_layer": 31}
        llama = LLaMATransformer(llama_config)

        checkpoint = torch.load(checkpoint_path, map_location="cpu")['model']
        selected = {k.replace('llama.', ''): v for k, v in checkpoint.items() if 'llama' in k}
        llama.load_state_dict(selected, strict=False)

        for name, param in llama.named_parameters():
            param.requires_grad = 'dim_mapper' in name
        return llama

    def _init_weights(self, m):
        """Init Linear weights with trunc-normal (std 0.02) and zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay.

        NOTE(review): the real parameter names are ``cls_token.<i>`` / ``pos_embed.<i>``
        (ParameterList entries), so an exact-name match against these strings will not hit them.
        Confirm how the optimizer factory consumes this set.
        """
        out = {'cls_token'}
        if self.pos_embed[0].requires_grad:
            out.add('pos_embed')
        return out

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        # NOTE(review): ``self.embed_dim`` is a per-branch sequence here, so ``nn.Linear`` will
        # reject it. ``forward`` also calls ``self.head(x, ref)`` with two arguments. This looks
        # like unused legacy code -- confirm before relying on it.
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    @staticmethod
    def _contrastive_pair_weights(score, batch_size):
        """Build (B, B) 0/1 positive and negative pair masks from ground-truth scores.

        Row i marks as positive the ``round(0.1 * B)`` samples that follow the closest-score
        entry, and as negative the ``round(0.4 * B)`` samples with the most distant scores.
        ``score`` must be expandable to (B, B) -- presumably shape (B, 1).
        """
        score_mat = score.expand(batch_size, batch_size)
        dist = abs(score_mat - score_mat.t())
        # The closest entry of each row is normally sample i itself (distance 0) and is dropped.
        # With tied scores, the dropped index need not be i.
        near = torch.topk(dist, k=round(0.1 * batch_size) + 1, dim=1, largest=False)[1][:, 1:]
        far = torch.topk(dist, k=round(0.4 * batch_size), dim=1)[1]
        weight_pos = torch.zeros_like(dist).scatter_(1, near, 1.0)
        weight_neg = torch.zeros_like(dist).scatter_(1, far, 1.0)
        return weight_pos, weight_neg

    def forward_features(self, x, score, training=None):
        """Encode ``x`` and optionally accumulate the cross-branch contrastive loss.

        Args:
            x: image batch, (B, C, H, W).
            score: ground-truth scores, only read when ``training`` is truthy.
            training: when truthy, a contrastive loss (weighted by ``self.lda``) is added after
                every ``MultiScaleBlock`` stage. This flag is independent of ``self.training``
                (``train()``/``eval()``), which is not modified here.

        Returns:
            ``(ref, tokens, loss_sum)``: the first output token (B, 384), the remaining tokens
            (B, L, 384), and the accumulated contrastive loss (``0`` when not training).
        """
        B, C, H, W = x.shape
        xs = []
        for i in range(self.num_branches):
            size = self.img_size[i]
            x_ = F.interpolate(x, size=(size, size), mode='bicubic') if H != size else x
            tokens = self.patch_embed[i](x_)
            cls_tokens = self.cls_token[i].expand(B, -1, -1)
            tokens = torch.cat((cls_tokens, tokens), dim=1) + self.pos_embed[i]
            xs.append(self.pos_drop(tokens))

        loss_sum = 0
        if training:
            weight_pos, weight_neg = self._contrastive_pair_weights(score, B)
            # NOTE(review): Loss_contrastive owns an nn.Linear(192, 384) that is freshly and
            # randomly initialised on every call. It is not a registered submodule, so it is
            # never trained or saved; registering it would change the state_dict. It is created
            # on the activations' device rather than with a hard-coded ``.cuda()``.
            loss_con = Loss_contrastive().to(xs[0].device)

        for blk in self.blocks:
            xs = blk(xs)
            if training:
                loss_sum += self.lda * loss_con(xs, weight_pos, weight_neg, 0.3)

        xs = [norm(t) for norm, t in zip(self.norm, xs)]
        ref = [t[:, 0] for t in xs]
        ref_else = [t[:, 1:] for t in xs]

        # Branch 0: (B, 400, 192) -> (B, 400, 384) -> (B, 384, 196) -> (B, 196, 384).
        ref_else[0] = self.linear_else2(self.linear_else1(ref_else[0]).permute(0, 2, 1)).permute(0, 2, 1)
        ref[0] = self.linear_else1(ref[0])

        ref = torch.stack(ref, dim=0).mean(dim=0)
        ref_else = torch.stack(ref_else, dim=0).mean(dim=0)

        x = torch.cat((ref.unsqueeze(dim=1), ref_else), dim=1)
        x = self.llama_dim_mapper1(x)
        x = self.llama(x)
        x = self.llama_dim_mapper2(x)
        x = self.norm_llm(x)

        return x[:, 0], x[:, 1:], loss_sum

    def forward(self, x, score, training=None):
        """Return ``(prediction, loss_sum)``; the prediction is ``self.head(tokens, ref)``."""
        ref, x, loss_sum = self.forward_features(x, score, training)
        x = self.head(x, ref)
        return x, loss_sum


def build_vit(
    img_size=(240, 224),
    patch_size=(12, 16),
    embed_dim=(192, 384),
    depth=((1, 4, 0), (1, 4, 0), (1, 4, 0)),
    num_heads=(6, 6),
    mlp_ratio=(4, 4, 1),
    qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    pretrained=True,
    pretrained_model_path="./crossvit_small_224.pth",
    lda=0.1,
):
    """Build the two-branch CrossViT + LLaMA model and optionally load CrossViT weights.

    The architecture arguments are now forwarded to ``vit_models``. They used to be silently
    ignored in favor of a fixed configuration; the defaults here equal that configuration, so
    default calls build the same model. Note that ``vit_models`` hard-codes the 192/384-dim,
    400->196-token geometry of these defaults.

    Args:
        pretrained: load ``pretrained_model_path`` (a flat state dict) into the model
            non-strictly.
        pretrained_model_path: path of the state dict to load.
        lda: weight of the contrastive loss, forwarded to ``vit_models``.

    Raises:
        ValueError: if ``pretrained`` is True and ``pretrained_model_path`` is empty.
    """
    model = vit_models(
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        norm_layer=norm_layer,
        lda=lda,
    )
    if pretrained:
        print("pretrained_model_path", pretrained_model_path)
        if not pretrained_model_path:
            raise ValueError("pretrained=True requires a non-empty pretrained_model_path")
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        # Non-strict: the checkpoint only covers the CrossViT part of the model.
        model.load_state_dict(state_dict, strict=False)
        del state_dict
        torch.cuda.empty_cache()
    return model

class Loss_contrastive(torch.nn.Module):
    """Contrastive loss between the two CrossViT branches, driven by score-based pair masks.

    The loss has two parts:
      * four InfoNCE-style terms on the L2-normalised cls tokens (branch 0 projected from 192
        to 384 dims by ``self.linear``), for the 0-0, 1-1, 0-1 and 1-0 pairings. Each term is
        ``-log(pos / (pos + neg))``, where ``pos`` is the mean of ``exp(sim / T)`` over the
        ``weight_pos`` entries of a row and ``neg`` is the sum over its ``weight_neg`` entries;
      * a local term weighted by 1e-4. The patch tokens of each branch are split into square
        windows (5x5 for branch 0, 7x7 for branch 1) and averaged per channel. For every
        branch-1 window, the minimum |cosine| to the branch-0 windows of the same image is
        taken, and windows whose minimum is below 0.4 contribute ``exp(-min / T)``.

    NOTE(review): the hard-coded 192/384 widths and 5/7 window sizes presumably match the
    ``build_vit`` configuration (20x20 and 14x14 patch grids) -- confirm before reuse.
    """

    def __init__(self):
        super().__init__()
        # Projects branch-0 features (192-dim) into the branch-1 space (384-dim).
        self.linear = nn.Linear(192, 384)

    def get_image_cluster(self, tensor, length):
        """Split a (B, N, D) token grid into ``length`` x ``length`` windows.

        N must be a perfect square, and ``length`` should divide its side so all windows have
        the same size (otherwise ``torch.stack`` fails). Returns (num_windows, B, length, length, D).
        """
        B, N, D = tensor.shape
        image_shape = int(math.sqrt(N))
        image = tensor.view(B, image_shape, image_shape, D)
        split_blocks = []
        for i in range(0, image_shape, length):
            for j in range(0, image_shape, length):
                split_blocks.append(image[:, i:i + length, j:j + length, :])  # (B, length, length, D)
        return torch.stack(split_blocks)

    @staticmethod
    def _pair_loss(feat_a, feat_b, weight_pos, weight_neg, temperature):
        """Mean over rows of ``-log(pos / (pos + neg))`` with ``sim = exp(feat_a @ feat_b.T / T)``."""
        sim = torch.exp(torch.mm(feat_a, feat_b.t()) / temperature)
        pos = (sim * weight_pos).sum(1) / weight_pos.sum(1)
        neg = (sim * weight_neg).sum(1)
        return (-torch.log(pos / (pos + neg))).mean()

    def _window_loss(self, patches0, patches1, temperature, threshold=0.4):
        """Local window term; returns a 0-dim zero when no window falls below ``threshold``."""
        # (num_windows, B, L, L, D) -> per-channel spatial mean -> (B, num_windows, D).
        # The previous code reshaped each channel-last window to (D, L, L) before pooling,
        # which mixed channels.
        pooled0 = self.get_image_cluster(patches0, 5).mean(dim=(2, 3)).transpose(0, 1)
        pooled1 = self.get_image_cluster(patches1, 7).mean(dim=(2, 3)).transpose(0, 1)
        pooled0 = self.linear(pooled0)
        pooled0 = pooled0 / pooled0.norm(dim=-1, keepdim=True)
        pooled1 = pooled1 / pooled1.norm(dim=-1, keepdim=True)

        # |cos| between every branch-1 window and every branch-0 window of the same image.
        sim = (pooled1 @ pooled0.transpose(1, 2)).abs()  # (B, N1, N0)
        min_sim = sim.min(dim=-1).values  # (B, N1)
        selected = min_sim[min_sim < threshold]
        if selected.numel() == 0:
            # The previous code took the mean of an empty tensor here, which made the loss NaN.
            return min_sim.new_zeros(())
        # Kept in the autograd graph; ``torch.tensor(list)`` previously detached this term.
        return torch.exp(-selected / temperature).mean()

    def forward(self, xs, weight_pos, weight_neg, temperature):
        """Compute the loss.

        Args:
            xs: ``[branch0, branch1]`` token tensors of shape (B, 1 + N_i, D_i), cls token first.
            weight_pos: (B, B) positive-pair mask.
            weight_neg: (B, B) negative-pair mask.
            temperature: softmax temperature T.

        Returns:
            0-dim loss tensor.
        """
        xs0, xs1 = xs[0], xs[1]

        # Global features: cls tokens (branch 0 projected), L2-normalised per sample.
        feat0 = self.linear(xs0[:, 0])
        feat0 = feat0 / feat0.norm(dim=1, keepdim=True)
        feat1 = xs1[:, 0]
        feat1 = feat1 / feat1.norm(dim=1, keepdim=True)

        loss = (
            self._pair_loss(feat0, feat0, weight_pos, weight_neg, temperature)
            + self._pair_loss(feat1, feat1, weight_pos, weight_neg, temperature)
            + self._pair_loss(feat0, feat1, weight_pos, weight_neg, temperature)
            + self._pair_loss(feat1, feat0, weight_pos, weight_neg, temperature)
        )
        return loss + 0.0001 * self._window_loss(xs0[:, 1:], xs1[:, 1:], temperature)

if __name__ == "__main__":
    # Smoke test: build the model (this also reads ./checkpoint.pth for the LLaMA layer and the
    # CrossViT weights below) and report how many parameters are trainable.
    # The previous version passed an undefined ``args`` that build_vit does not accept.
    vit = build_vit(
        pretrained=True,
        pretrained_model_path="./crossvit_small_224.pth",
    )
    n_trainable = sum(p.numel() for p in vit.parameters() if p.requires_grad)
    print(f"trainable parameters: {n_trainable}")
