import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from torch.nn.functional import adaptive_avg_pool3d
from functools import partial, reduce
from .swin_backbone import SwinTransformer3D as VideoBackbone
from .swin_backbone import swin_3d_tiny, swin_3d_small
from .conv_backbone import convnext_3d_tiny, convnext_3d_small
from .xclip_backbone import build_x_clip_model
from .swin_backbone import SwinTransformer2D as ImageBackbone
from .head import VQAHead, IQAHead, VARHead

class My_DiViDeAddEvaluator_newtrain(nn.Module):
    """Quality evaluator fusing backbone/head score maps with an extra feature tensor.

    ``__init__`` registers one backbone per preserved key as ``<key>_backbone``.
    It registers either one head per key (``<key>_head``) or a single shared
    ``vqa_head``. It also builds ``mi_fc1``/``mi_fc2``, which project
    ``clip_feature``, and ``ann`` (``My_ANN(837, 1)``), which regresses the
    concatenated vector to one score per row.
    """

    def __init__(
        self,
        backbone_size="divided",
        backbone_preserve_keys='fragments,resize',
        multi=False,
        layer=-1,
        backbone=None,
        divide_head=False,
        vqa_head=None,
        var=False,
    ):
        """Build backbones, heads and the fusion layers.

        Args:
            backbone_size: "divided" reads each backbone's type from
                ``backbone[key]["type"]``. Any other value is used as the type
                name for every preserved key.
            backbone_preserve_keys: comma-separated keys of ``backbone`` to build.
            multi: forwarded as ``multi=`` to every backbone call in ``forward``.
            layer: forwarded as ``layer=`` to every backbone call in ``forward``.
            backbone: mapping key -> backbone hyper-parameters. Defaults to
                ``{"resize": {"window_size": (4, 4, 4)},
                "fragments": {"window_size": (4, 4, 4)}}``. That default has
                no "type" entry, so with ``backbone_size="divided"`` it raises
                KeyError.
            divide_head: build one head per preserved key instead of a shared
                ``vqa_head``.
            vqa_head: keyword arguments for the head(s). Defaults to
                ``{"in_channels": 768}``.
            var: build ``VARHead`` instead of ``VQAHead``.

        Raises:
            NotImplementedError: unknown backbone type name.
        """
        # nn.Module must be initialised before any attribute is assigned.
        super().__init__()
        # Fresh dicts per instance instead of shared mutable default arguments.
        if backbone is None:
            backbone = dict(resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)})
        if vqa_head is None:
            vqa_head = dict(in_channels=768)
        self.backbone_preserve_keys = backbone_preserve_keys.split(",")
        self.multi = multi
        self.layer = layer
        # ann's input width (837) must equal the flattened head width plus the
        # clip-branch width; the original comments suggest 784 + 53 (confirm).
        self.ann = My_ANN(837, 1)
        self.mi_fc1 = nn.Linear(9, 1)
        self.mi_fc2 = nn.Linear(400, 1)
        for key, hypers in backbone.items():
            print(backbone_size)
            print('backbone为', backbone)

            if key not in self.backbone_preserve_keys:
                continue
            if backbone_size == "divided":
                t_backbone_size = hypers["type"]  # TODO: type -> window_size
            else:
                t_backbone_size = backbone_size
            b = self._make_backbone(t_backbone_size, hypers)
            print("Setting backbone:", key + "_backbone")
            # Registers e.g. self.fragments_backbone; forward finds it by key prefix.
            setattr(self, key + "_backbone", b)
        if divide_head:
            print(divide_head)
            for key in backbone:
                if key not in self.backbone_preserve_keys:
                    continue
                if var:
                    b = VARHead(**vqa_head)
                    print(b)
                else:
                    b = VQAHead(**vqa_head)
                print("Setting head:", key + "_head")
                setattr(self, key + "_head", b)
        else:
            if var:
                self.vqa_head = VARHead(**vqa_head)
                # Print the head just built; the old code printed the last
                # backbone here, or raised NameError when none was built.
                print(self.vqa_head)
            else:
                self.vqa_head = VQAHead(**vqa_head)

    @staticmethod
    def _make_backbone(t_backbone_size, hypers):
        """Instantiate the backbone named ``t_backbone_size``.

        ``hypers`` is unpacked into the configurable constructors.

        Raises:
            NotImplementedError: unknown backbone name.
        """
        if t_backbone_size == 'swin_tiny':
            return swin_3d_tiny(**hypers)
        if t_backbone_size == 'swin_tiny_grpb':
            # to reproduce clif-vqa
            return VideoBackbone()
        if t_backbone_size == 'swin_tiny_grpb_m':
            # to reproduce clif-vqa-m
            return VideoBackbone(window_size=(4, 4, 4), frag_biases=[0, 0, 0, 0])
        if t_backbone_size == 'swin_small':
            return swin_3d_small(**hypers)
        if t_backbone_size == 'conv_tiny':
            return convnext_3d_tiny(pretrained=True)
        if t_backbone_size == 'conv_small':
            return convnext_3d_small(pretrained=True)
        if t_backbone_size == 'xclip':
            return build_x_clip_model(**hypers)
        raise NotImplementedError

    def forward(self, vclips, clip_feature, inference=True, return_pooled_feats=False, reduce_scores=True, pooled=False, **kwargs):
        """Score a batch.

        Args:
            vclips: mapping name -> input tensor. ``name.split("_")[0]`` selects
                ``<prefix>_backbone``, plus ``<prefix>_head`` if present (else
                ``vqa_head``). Every entry's backbone runs, but only the FIRST
                entry's head output enters the score.
            clip_feature: extra feature tensor, cast to float32. The layer sizes
                imply a 4-D (B, 400, C, 9) layout -- TODO confirm with the loader.
            inference: True runs in eval mode entirely under ``torch.no_grad()``;
                False runs in train mode with autograd.
            return_pooled_feats: also return {name: backbone feature averaged
                over its last three dims}.
            reduce_scores: accepted for interface compatibility; unused.
            pooled: accepted for interface compatibility; unused.
            **kwargs: forwarded to every backbone call.

        Returns:
            ``ann`` output with one row per row of the first head output, and
            the pooled-feature dict when ``return_pooled_feats`` is set. Both
            paths leave the module in training mode.
        """
        if inference:
            self.eval()
            # The whole computation, including the fusion MLP, runs without autograd.
            with torch.no_grad():
                scores, feats = self._fused_scores(vclips, clip_feature, return_pooled_feats, **kwargs)
            self.train()
        else:
            self.train()
            scores, feats = self._fused_scores(vclips, clip_feature, return_pooled_feats, **kwargs)
        if return_pooled_feats:
            return scores, feats
        return scores

    def _head_outputs(self, vclips, return_pooled_feats, **kwargs):
        """Run backbone + head for every entry of ``vclips``.

        Returns:
            A list of head outputs (in ``vclips`` order) and a dict of pooled
            backbone features. The dict is empty unless ``return_pooled_feats``.
        """
        outputs = []
        feats = {}
        for key in vclips:
            prefix = key.split("_")[0]
            feat = getattr(self, prefix + "_backbone")(vclips[key], multi=self.multi, layer=self.layer, **kwargs)
            if hasattr(self, prefix + "_head"):
                outputs.append(getattr(self, prefix + "_head")(feat))
            else:
                outputs.append(self.vqa_head(feat))
            if return_pooled_feats:
                feats[key] = feat.mean((-3, -2, -1))
        return outputs, feats

    def _clip_scores(self, clip_feature):
        """Project ``clip_feature`` to a 2-D (rows, C) tensor via mi_fc1 then mi_fc2."""
        x = clip_feature.to(torch.float32)
        x = self.mi_fc1(x).squeeze(3)       # collapse the size-9 last dim
        x = self.mi_fc2(x.transpose(1, 2))  # collapse the size-400 dim
        return x.squeeze(2)

    def _fused_scores(self, vclips, clip_feature, return_pooled_feats, **kwargs):
        """Concatenate the flattened first head output with the clip projection and apply ``ann``."""
        outputs, feats = self._head_outputs(vclips, return_pooled_feats, **kwargs)
        first = outputs[0]
        scores1 = first.view(first.shape[0], -1)
        scores2 = self._clip_scores(clip_feature)
        # A single clip-feature row is shared by every head row. The original
        # inference code hard-coded this for exactly 4 rows.
        if scores2.shape[0] == 1 and scores1.shape[0] > 1:
            scores2 = scores2.expand(scores1.shape[0], -1)
        return self.ann(torch.cat([scores1, scores2], 1)), feats

class My_ANN(nn.Module):
    """Small MLP regressor with a fixed hidden width of 256.

    The first stage is ``fc_m1`` (input_size -> 256). The second stage is
    GELU -> Dropout -> ``fc_m2`` (256 -> reduced_size) and is applied
    ``n_ANNlayers - 1`` times.

    NOTE(review): ``fc_m2`` consumes 256 features, so ``n_ANNlayers > 2``
    only works when ``reduced_size == 256``. With ``n_ANNlayers <= 1`` the
    raw 256-wide ``fc_m1`` output is returned.
    """

    def __init__(self, input_size=2385, reduced_size=1, n_ANNlayers=2, dropout_p=0.5):
        super().__init__()
        self.n_ANNlayers = n_ANNlayers
        hidden_width = 256
        # Registration order is kept (fc_m1, dropout, fc_m2, gelu) so parameter
        # initialisation consumes the RNG in the same order.
        self.fc_m1 = nn.Linear(input_size, hidden_width)
        self.dropout = nn.Dropout(p=dropout_p)
        self.fc_m2 = nn.Linear(hidden_width, reduced_size)
        self.My_gelu = nn.GELU()  # GELU activation (replaces an earlier ReLU)

    def forward(self, input):
        """Map (..., input_size) features to (..., reduced_size) outputs."""
        hidden = self.fc_m1(input)
        remaining = self.n_ANNlayers - 1
        while remaining > 0:
            activated = self.My_gelu(hidden)
            hidden = self.fc_m2(self.dropout(activated))
            remaining -= 1
        return hidden

def TP(q, tau=24, beta=0.5):
    """Subjectively-inspired temporal pooling of a per-frame score sequence.

    Each time step t blends two terms as ``beta * ahead + (1 - beta) * behind``:
      * behind: the minimum of q over steps t-tau+1 .. t;
      * ahead: the mean of q over steps t .. t+tau-1, weighted by exp(-q).

    Args:
        q: per-frame scores with one column, i.e. shape (T, 1). The concat
            with (1, 1, tau-1) padding requires exactly one column after
            ``torch.t``.
        tau: window length.
        beta: weight of the forward-looking term.

    Returns:
        Tensor of shape (1, 1, T).
    """
    seq = torch.t(q).unsqueeze(0)  # (1, 1, T)
    pad_shape = (1, 1, tau - 1)
    # Left pad of -inf: the max-pool over -seq never selects padding.
    left_pad = torch.full(pad_shape, -float('inf'), device=q.device)
    # Right pad of 10000: exp(-10000) underflows to 0, so padding adds nothing
    # to either the weighted numerator or the denominator.
    right_pad = torch.full(pad_shape, 10000.0, device=q.device)

    behind = -F.max_pool1d(torch.cat((left_pad, -seq), 2), tau, stride=1)

    seq_weight = torch.exp(-seq)
    pad_weight = torch.exp(-right_pad)
    weighted_sum = F.avg_pool1d(torch.cat((seq * seq_weight, right_pad * pad_weight), 2), tau, stride=1)
    weight_sum = F.avg_pool1d(torch.cat((seq_weight, pad_weight), 2), tau, stride=1)
    ahead = weighted_sum / weight_sum

    return beta * ahead + (1 - beta) * behind