# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.


from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import Mlp
from torchinfo import summary
from torchvision.transforms.functional import resize
import pretrainedmodels

class PatchEmbed(nn.Module):
    """Split an image-like tensor into non-overlapping patches and embed them.

    A single ``Conv2d`` whose kernel size and stride both equal ``patch_size``
    maps every patch to an ``embed_dim``-dimensional vector.

    Args:
        patch_size: Side length of each square patch (conv kernel and stride).
        in_chans: Number of input channels.
        embed_dim: Output channels per patch.
        norm_layer: Optional norm constructor applied to the output; when falsy
            an ``nn.Identity`` is used instead.
        flatten: If True, the ``(B, C, H, W)`` conv output is turned into a
            ``(B, H*W, C)`` token sequence before normalisation.
    """

    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.flatten = flatten
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        patches = self.proj(x)
        if self.flatten:
            # (B, C, H, W) -> (B, H*W, C)
            patches = patches.flatten(2).transpose(1, 2)
        return self.norm(patches)


class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token sequence.

    Adapted from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models). Queries, keys and
    values come from one fused linear layer; queries are scaled by
    ``qk_scale`` (or ``head_dim ** -0.5`` when ``qk_scale`` is falsy).
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q / k / v.
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))


class Block(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each residual.

    Adapted from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models). ``init_values`` is
    accepted only so this class shares a constructor signature with the
    LayerScale blocks; it is not used here.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on each residual branch; identity when disabled.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)


class CNN_Model(nn.Module):
    """Frozen CNN feature extractor built from a pretrained InceptionResNetV2.

    ``forward(x_dis, x_ref)`` runs both images through the same truncated
    backbone without gradients and returns ``features(x_ref) - features(x_dis)``.
    Here ``features`` concatenates, along the channel axis, the stem output and
    the output of each subsequent stage.
    """

    def __init__(self):
        super(CNN_Model, self).__init__()

        # Flattened traversal of every sub-module of the pretrained network; the
        # fixed indices below select layers by their position in that traversal.
        # NOTE(review): the indices are tied to the module layout of the installed
        # `pretrainedmodels` version; confirm they still point at the intended
        # layers if that package changes.
        backbone_modules = list(pretrainedmodels.inceptionresnetv2().modules())

        # Stem (up to Mixed_5b, judging by the attribute name).
        self.mixed5b = nn.Sequential(
            backbone_modules[1], backbone_modules[5], backbone_modules[9],
            backbone_modules[13], backbone_modules[14], backbone_modules[18],
            backbone_modules[22], backbone_modules[23],
        )

        # Subsequent stages; each chains two modules from the traversal.
        self.block35_2 = nn.Sequential(backbone_modules[57], backbone_modules[86])
        self.block35_4 = nn.Sequential(backbone_modules[115], backbone_modules[144])
        self.block35_6 = nn.Sequential(backbone_modules[173], backbone_modules[202])
        self.block35_8 = nn.Sequential(backbone_modules[231], backbone_modules[260])
        self.block35_10 = nn.Sequential(backbone_modules[289], backbone_modules[318])

        # Groups the backbone stages in execution order. The modules are already
        # registered through the attributes above; this plain list only groups
        # them. No gradient reaches them because forward() runs under no_grad.
        self.frozen = [self.mixed5b, self.block35_2, self.block35_4, self.block35_6, self.block35_8, self.block35_10]

    def summary(self, show_dict=False):
        """Print parameter counts of the frozen backbone and of everything else.

        "Trainable" here means every registered parameter that does not belong
        to a module listed in ``self.frozen``.

        Args:
            show_dict: If True, also list the name and shape of every parameter
                in each group.
        """
        frozen_ids = {id(p) for module in self.frozen for p in module.parameters()}
        named = list(self.named_parameters())
        groups = (
            ("Frozen", [(n, p) for n, p in named if id(p) in frozen_ids]),
            ("Trainable", [(n, p) for n, p in named if id(p) not in frozen_ids]),
        )
        for label, params in groups:
            total = sum(p.numel() for _, p in params)
            print("{} Parameters: {:,}".format(label, total))
            if show_dict:
                for name, p in params:
                    print("  {}: {}".format(name, tuple(p.shape)))

    def backbone(self, x):
        """Run every backbone stage and concatenate all stage outputs on dim 1.

        Each stage is switched to eval mode first, so mode-dependent layers
        (e.g. normalisation/dropout, if present) behave as at inference even
        when the parent model is in training mode.
        """
        stages = (
            self.mixed5b,
            self.block35_2,
            self.block35_4,
            self.block35_6,
            self.block35_8,
            self.block35_10,
        )
        for stage in stages:
            stage.eval()

        # Per the original shape notes, a (B, 3, 224, 224) input gives a
        # (B, 320, 25, 25) stem output and (B, 1920, 25, 25) after concatenation.
        outputs = []
        for stage in stages:
            x = stage(x)
            outputs.append(x)
        return torch.cat(outputs, dim=1)

    def forward(self, x_dis, x_ref):
        """Return ``backbone(x_ref) - backbone(x_dis)``, computed without gradients."""
        with torch.no_grad():
            x_ref = self.backbone(x_ref)
            x_dis = self.backbone(x_dis)

        return x_ref - x_dis

class Experts_MOS(nn.Module):
    """Query-based regression head producing one score per sample.

    ``juery_nums`` learnable query embeddings, each shifted by the same
    per-sample ``ref`` vector, cross-attend to the token sequence ``x`` through a
    single pre-norm ``nn.TransformerDecoder`` layer. Each decoded query is mapped
    to a scalar by a bias-free linear layer, and the scalars are averaged.

    Args:
        embed_dim: Token width; must be divisible by ``nhead``.
        juery_nums: Number of learnable queries.
        nhead: Attention heads in the decoder layer (default 3, the previously
            hard-coded value).
    """

    def __init__(
        self,
        embed_dim=768,
        juery_nums=6,
        nhead=3,
    ):
        super().__init__()
        self.juery = juery_nums
        bunch_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            dropout=0.0,
            nhead=nhead,
            activation=F.gelu,
            batch_first=True,
            dim_feedforward=(embed_dim * 4),
            norm_first=True,
        )
        self.bunch_decoder = nn.TransformerDecoder(bunch_layer, num_layers=1)
        self.bunch_embedding = nn.Parameter(torch.randn(1, juery_nums, embed_dim))
        self.heads = nn.Linear(embed_dim, 1, bias=False)
        trunc_normal_(self.bunch_embedding, std=0.02)

    def forward(self, x, ref):
        """Score a batch.

        Args:
            x: Token sequence used as decoder memory, ``(B, L, D)``.
            ref: One conditioning vector per sample with ``D`` elements
                (e.g. ``(B, D)``); it is added to every query.

        Returns:
            tuple: ``(scores, head_weight)``. ``scores`` has shape ``(B, 1)`` and
            is the mean over queries. ``head_weight`` is the ``(1, D)`` weight
            of the scoring linear layer.
        """
        B = x.shape[0]
        # (1, J, D) queries + (B, 1, D) reference -> (B, J, D) via broadcasting.
        queries = self.bunch_embedding.expand(B, -1, -1) + ref.reshape(B, 1, -1)
        decoded = self.bunch_decoder(queries, x)  # (B, J, D)
        per_query = self.heads(decoded)  # (B, J, 1)
        return per_query.mean(dim=1), self.heads.weight  # (B, 1)


class Layer_scale_init_Block(nn.Module):
    """Pre-norm transformer block with LayerScale (CaiT / DeiT III style).

    Adapted, with slight modifications, from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models). Each residual branch
    output is multiplied by a learnable per-channel vector (``gamma_1`` for
    attention, ``gamma_2`` for the MLP) initialised to ``init_values``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # Stochastic depth on each residual branch; identity when disabled.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        # Per-channel LayerScale factors for the two residual branches.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        attn_branch = self.gamma_1 * self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.gamma_2 * self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)


class Layer_scale_init_Block_paralx2(nn.Module):
    """LayerScale transformer block with two parallel attention and MLP branches.

    Adapted, with slight modifications, from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models). The attention stage
    adds two independently parameterised, LayerScaled attention branches
    (``attn``/``attn1``) to the residual. The MLP stage does the same with
    ``mlp``/``mlp1``. Every branch has its own pre-norm and ``gamma`` vector.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        hidden = int(dim * mlp_ratio)

        def make_attention():
            return Attention_block(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )

        def make_mlp():
            return Mlp_block(
                in_features=dim,
                hidden_features=hidden,
                act_layer=act_layer,
                drop=drop,
            )

        def make_scale():
            return nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

        # Sub-modules are created in a fixed order so that parameter init and
        # state_dict layout stay stable.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = make_attention()
        self.attn1 = make_attention()
        # Stochastic depth shared by all four branches; identity when disabled.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = make_mlp()
        self.mlp1 = make_mlp()
        self.gamma_1 = make_scale()
        self.gamma_1_1 = make_scale()
        self.gamma_2 = make_scale()
        self.gamma_2_1 = make_scale()

    def forward(self, x):
        # Both attention branches read the same input x.
        attn_a = self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.gamma_1_1 * self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b
        # Both MLP branches read the post-attention x.
        mlp_a = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.gamma_2_1 * self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b


class Block_paralx2(nn.Module):
    """Transformer block with two parallel attention and MLP branches (no LayerScale).

    Adapted, with slight modifications, from timm's vision_transformer.py
    (https://github.com/rwightman/pytorch-image-models). Same structure as
    ``Layer_scale_init_Block_paralx2`` without the ``gamma`` factors.
    ``init_values`` is accepted only for a shared constructor signature and is
    not used.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_values=1e-4,
    ):
        super().__init__()
        hidden = int(dim * mlp_ratio)

        def make_attention():
            return Attention_block(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
            )

        def make_mlp():
            return Mlp_block(
                in_features=dim,
                hidden_features=hidden,
                act_layer=act_layer,
                drop=drop,
            )

        # Sub-modules are created in a fixed order so that parameter init and
        # state_dict layout stay stable.
        self.norm1 = norm_layer(dim)
        self.norm11 = norm_layer(dim)
        self.attn = make_attention()
        self.attn1 = make_attention()
        # Stochastic depth shared by all four branches; identity when disabled.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm21 = norm_layer(dim)
        self.mlp = make_mlp()
        self.mlp1 = make_mlp()

    def forward(self, x):
        # Both attention branches read the same input x.
        attn_a = self.drop_path(self.attn(self.norm1(x)))
        attn_b = self.drop_path(self.attn1(self.norm11(x)))
        x = x + attn_a + attn_b
        # Both MLP branches read the post-attention x.
        mlp_a = self.drop_path(self.mlp(self.norm2(x)))
        mlp_b = self.drop_path(self.mlp1(self.norm21(x)))
        return x + mlp_a + mlp_b


class vit_models_NAR(nn.Module):
    """Vision Transformer with LayerScale (https://arxiv.org/abs/2103.17239) support.

    Pipeline used by ``forward(x_dis, x_ref)``:

    1. ``CNN_Model`` produces a reference-minus-distorted feature map.
    2. A 1x1 conv (``conv_proj``) projects it from ``in_chans`` to
       ``embed_dim`` channels.
    3. The map is flattened into ``num_patches`` tokens, ``pos_embed`` is
       added, and a class token is prepended.
    4. The tokens pass through ``depth`` transformer blocks and a final norm.
    5. ``Experts_MOS`` regresses a score from the patch tokens, conditioned on
       the class token.

    Args:
        in_chans: Channels of the backbone feature map (1920 for ``CNN_Model``).
        embed_dim: Transformer width; also used by ``conv_proj`` and the head.
        num_patches: Spatial tokens per sample, i.e. H*W of the backbone map
            (625 = 25 x 25 for 224x224 inputs, per the backbone's shape notes).
        global_pool: Accepted for interface compatibility; unused.
    """

    def __init__(
        self,
        patch_size=16,
        in_chans=1920,
        num_classes=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        global_pool=None,
        block_layers=Block,
        Patch_layer=PatchEmbed,
        act_layer=nn.GELU,
        Attention_block=Attention,
        Mlp_block=Mlp,
        init_scale=1e-4,
        num_patches=625,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        # NOTE(review): patch_embed is never used in forward(), and its
        # Conv2d(in_chans, embed_dim, patch_size) weight is very large. It is
        # kept so state_dict keys stay compatible with existing checkpoints,
        # which are loaded strictly by the teacher builder.
        self.patch_embed = Patch_layer(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        # 1x1 projection from backbone channels to the transformer width. It
        # follows in_chans/embed_dim so widths other than 384 stay consistent
        # with pos_embed and the head.
        self.conv_proj = nn.Conv2d(
            in_channels=in_chans, out_channels=embed_dim, kernel_size=(1, 1)
        )
        self.feature_extract = CNN_Model()

        # Uniform drop-path rate across blocks.
        dpr = [drop_path_rate] * depth
        self.blocks = nn.ModuleList(
            [
                block_layers(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=0.0,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    Attention_block=Attention_block,
                    Mlp_block=Mlp_block,
                    init_values=init_scale,
                )
                for i in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)

        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module="head")]

        # Query-decoder regression head; must match the token width.
        self.head = Experts_MOS(embed_dim=embed_dim)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        return self.head

    def get_num_layers(self):
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head with a plain linear classifier.

        NOTE(review): forward() calls ``self.head(x, ref)`` and unpacks two
        return values, which an ``nn.Linear``/``nn.Identity`` head does not
        support; forward() will fail after calling this.
        """
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def forward_features(self, x_dis, x_ref):
        """Encode a (distorted, reference) image pair into transformer tokens.

        Returns:
            tuple: ``(cls_token, patch_tokens, inner_features)``.
            ``cls_token`` is ``(B, embed_dim)``. ``patch_tokens`` is
            ``(B, H*W, embed_dim)`` after the final norm. ``inner_features``
            holds one ``(B, H*W, embed_dim)`` tensor per block: the patch
            tokens (class token dropped) before the final norm.
        """
        x = self.feature_extract(x_dis, x_ref)  # (B, in_chans, H, W)
        B = x.shape[0]
        x = self.conv_proj(x)  # (B, embed_dim, H, W)
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, embed_dim)

        # Broadcasting requires H*W == num_patches.
        x = x + self.pos_embed

        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 1 + H*W, embed_dim)

        inner_features = []
        for blk in self.blocks:
            x = blk(x)
            inner_features.append(x[:, 1:, :])

        x = self.norm(x)
        return x[:, 0], x[:, 1:, :], inner_features

    def forward(self, x_dis, x_ref):
        """Predict a quality score for ``x_dis`` given the reference ``x_ref``.

        Returns:
            tuple: ``(score, inner_features, head_weight)``. ``score`` is
            ``(B, 1)`` from the ``Experts_MOS`` decoder head. ``inner_features``
            are the per-block patch tokens. ``head_weight`` is the head's
            scoring weight.
        """
        ref, x, encode_inner_feature = self.forward_features(x_dis, x_ref)
        # The decoder head regresses the score from the patch tokens,
        # conditioned on the class token.
        x, weight = self.head(x, ref)
        return x, encode_inner_feature, weight


def _load_checkpoint_state_dict(pretrained_model_path):
    """Load ``checkpoint["model"]`` from ``pretrained_model_path`` onto the CPU.

    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.

    Raises:
        ValueError: If ``pretrained_model_path`` is empty.
    """
    if not pretrained_model_path:
        raise ValueError("pretrained_model_path must be set when pretrained=True")
    checkpoint = torch.load(pretrained_model_path, map_location="cpu")
    return checkpoint["model"]


# DeiT III: Revenge of the ViT (https://arxiv.org/abs/2204.07118)
def build_vit_NAR(
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    block_layers=Layer_scale_init_Block,
    pretrained=False,
    pretrained_model_path="",
):
    """Build a ``vit_models_NAR``, optionally warm-started from a checkpoint.

    When ``pretrained`` is True, ``checkpoint["model"]`` is loaded non-strictly.
    ``patch_embed.proj.weight`` and ``pos_embed`` are dropped first because
    their shapes generally differ from this model's. ``load_state_dict``
    rejects shape mismatches even with ``strict=False``.

    Raises:
        ValueError: If ``pretrained`` is True but no path is given.
    """
    model = vit_models_NAR(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        norm_layer=norm_layer,
        block_layers=block_layers,
    )
    if pretrained:
        state_dict = _load_checkpoint_state_dict(pretrained_model_path)
        for key in ("patch_embed.proj.weight", "pos_embed"):
            state_dict.pop(key, None)  # tolerate checkpoints without the key
        model.load_state_dict(state_dict, strict=False)
        del state_dict
        torch.cuda.empty_cache()
    return model


def build_vit_NAR_teacher(
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    block_layers=Layer_scale_init_Block,
    pretrained=False,
    pretrained_model_path="",
):
    """Build a ``vit_models_NAR`` teacher, optionally from a full checkpoint.

    Unlike ``build_vit_NAR``, ``checkpoint["model"]`` is loaded strictly, so it
    must contain exactly this model's keys and shapes (e.g. a checkpoint saved
    from this same architecture). No parameters are frozen here.

    Raises:
        ValueError: If ``pretrained`` is True but no path is given.
    """
    model = vit_models_NAR(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        norm_layer=norm_layer,
        block_layers=block_layers,
    )
    if pretrained:
        state_dict = _load_checkpoint_state_dict(pretrained_model_path)
        model.load_state_dict(state_dict)
        del state_dict
        torch.cuda.empty_cache()
    return model


if __name__ == "__main__":
    # Smoke test: build the model from a local checkpoint and run one forward
    # pass on random (distorted, reference) batches. Falls back to CPU when
    # CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vit = build_vit_NAR(
        pretrained=True,
        pretrained_model_path="E:/NR-IQA-IJCAJ/IQA-Data-Eff-IQA/ckpt_epoch_8.pth",
        # "/home/Xudong_Li/NR-IQA-IJCAJ/IQA-Data-Eff-IQA/ckpt_epoch_8.pth",
    ).to(device)
    vit.eval()
    data1 = torch.randn([12, 3, 224, 224], device=device)
    data2 = torch.randn([12, 3, 224, 224], device=device)
    with torch.no_grad():
        score, inner_features, head_weight = vit(data1, data2)
    print(tuple(score.shape), len(inner_features), tuple(head_weight.shape))