import torch
import torch.nn as nn
import torchaudio
import torchvision
from models.adapter import w2v2_adapter_nlp, w2v2_adapter_conv, vit_adapter_nlp, vit_adapter_conv
from models.visual_model import cnn_face
from torch.nn import functional as F

class CognitiveLoadFeatureExtractor(nn.Module):
    """Frozen ``Fusion`` backbone that exposes its per-layer cross-modal features.

    Builds ``Fusion(fusion_type='cross2', num_encoders=4, adapter=True,
    adapter_type='efficient_conv')`` (any adapter_type other than 'nlp'
    selects the conv adapters), loads ``checkpoint['model_state_dict']``
    from ``pretrained_model_path`` without the classifier head, and freezes
    every parameter. ``forward`` returns the output of each cross-fusion
    layer instead of logits.
    """

    # Classifier-head entries removed from the checkpoint before loading, so
    # only the encoder / fusion weights are restored.
    _EXCLUDED_KEYS = (
        'classifier.0.weight',
        'classifier.0.bias',
        'classifier.2.weight',
        'classifier.2.bias',
    )

    def __init__(self, pretrained_model_path):
        """Load and freeze the pretrained fusion model.

        Args:
            pretrained_model_path: path to a ``torch.save``-d dict that has a
                ``'model_state_dict'`` entry.

        Raises:
            KeyError: if the checkpoint has no ``'model_state_dict'``.
        """
        super().__init__()
        # map_location='cpu': load_state_dict copies values into the existing
        # parameters, so the device the checkpoint was saved from does not
        # matter. Loading to CPU lets GPU-saved checkpoints open on CPU-only
        # hosts and avoids a transient extra copy of the weights on the GPU.
        checkpoint = torch.load(pretrained_model_path, map_location='cpu')

        self.model = Fusion(
            fusion_type='cross2',
            num_encoders=4,
            adapter=True,
            adapter_type='efficient_conv',
            multi=False,
        )

        state_dict = checkpoint['model_state_dict']
        for key in self._EXCLUDED_KEYS:
            state_dict.pop(key, None)

        # NOTE(review): strict=False silently ignores missing/unexpected keys
        # (e.g. a 'module.' prefix from DataParallel would load nothing);
        # consider checking the returned missing_keys/unexpected_keys.
        self.model.load_state_dict(state_dict, strict=False)

        # Freeze everything. NOTE(review): the module's train/eval mode is not
        # changed here, so dropout inside the backbone follows the caller's
        # mode.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x, y):
        """Return the cross-fusion output of every encoder layer.

        Args:
            x: raw audio handed to the wav2vec2 feature extractor (presumably
                (batch, samples) waveforms -- TODO confirm).
            y: face frames of shape (batch, frames, C, H, W).

        Returns:
            list with one fused tensor per encoder layer
            (``self.model.num_encoders`` entries).
        """
        x, _ = self.model.FEATURE_EXTRACTOR(x, None)
        audios = self.model.FEATURE_PROJECTOR(x)

        # Run the face CNN on every frame independently, then restore the
        # (batch, frames, 256) layout and lift to the ViT width.
        b_s, no_of_frames, C, H, W = y.shape
        y = torch.reshape(y, (b_s * no_of_frames, C, H, W))
        faces = self.model.cnn_feature_extractor(y)
        faces = torch.reshape(faces, (b_s, no_of_frames, 256))
        faces = self.model.projection(faces) + self.model.pos_embedding

        # Advance both streams layer by layer and fuse after each layer.
        cognitive_crossmodal_outputs = []
        for i in range(self.model.num_encoders):
            audios = self.model.TRANSFORMER[i](audios)
            faces = self.model.ViT_Encoder[i](faces)
            fused = self.model.cross_conv_layer[i](audios, faces)
            cognitive_crossmodal_outputs.append(fused)

        return cognitive_crossmodal_outputs

class CrossFusionModule(nn.Module):
    """Fuse an audio and a visual sequence with bilinear cross-attention.

    Both inputs are linearly projected from ``in_dim`` to ``dim`` features.
    A learned ``dim x dim`` correlation matrix scores every audio step
    against every visual step. Each modality is re-weighted by the softmaxed
    scores and added back to its projection as a residual. The concatenated
    result goes through a 64-unit Linear/LayerNorm/ReLU bottleneck.

    Args:
        dim: width of the projected features and of ``corr_weights``.
        in_dim: feature width of the incoming audio and visual sequences
            (default 768, the wav2vec2-base / ViT-B width used in this file).
    """

    def __init__(self, dim=256, in_dim=768):
        super().__init__()

        self.project_audio = nn.Linear(in_dim, dim)
        self.project_vision = nn.Linear(in_dim, dim)
        # Allocated on the default device so it follows the module on
        # .to()/.cuda() like every other weight. A hard-coded CUDA tensor type
        # would make construction fail on machines without CUDA.
        self.corr_weights = nn.Parameter(torch.empty(dim, dim))
        nn.init.xavier_normal_(self.corr_weights)
        self.project_bottleneck = nn.Sequential(
            nn.Linear(dim * 2, 64),
            nn.LayerNorm((64,), eps=1e-05, elementwise_affine=True),
            nn.ReLU(),
        )

    def forward(self, audio_feat, visual_feat):
        """Cross-attend the two sequences and return fused features.

        Args:
            audio_feat: (batch, T, in_dim) audio sequence.
            visual_feat: (batch, T, in_dim) visual sequence. The residual
                additions below require both sequences to share the same
                length T (up to broadcasting).

        Returns:
            (batch, T, 64) fused features.
        """
        audio_feat = self.project_audio(audio_feat)        # (B, Ta, dim)
        visual_feat = self.project_vision(visual_feat)     # (B, Tv, dim)

        visual_feat = visual_feat.transpose(1, 2)          # (B, dim, Tv)

        # Bilinear score between every audio step and every visual step.
        a1 = torch.matmul(audio_feat, self.corr_weights)   # (B, Ta, dim)
        cc_mat = torch.bmm(a1, visual_feat)                # (B, Ta, Tv)

        # Normalised over the audio axis / over the visual axis respectively.
        audio_att = F.softmax(cc_mat, dim=1)                               # (B, Ta, Tv)
        visual_att = F.softmax(cc_mat.transpose(1, 2), dim=1)              # (B, Tv, Ta)
        atten_audiofeatures = torch.bmm(audio_feat.transpose(1, 2), audio_att)  # (B, dim, Tv)
        atten_visualfeatures = torch.bmm(visual_feat, visual_att)               # (B, dim, Ta)
        # Residual connections back to each modality's own projection.
        atten_audiofeatures = atten_audiofeatures + audio_feat.transpose(1, 2)
        atten_visualfeatures = atten_visualfeatures + visual_feat

        # Concatenate along features, move time back to dim 1, squeeze to 64.
        fused_features = self.project_bottleneck(
            torch.cat((atten_audiofeatures, atten_visualfeatures), dim=1).transpose(1, 2)
        )

        return fused_features

class Fusion(nn.Module):
    """Audio-visual model built from wav2vec2-base and ViT-B/16 encoder layers.

    The first ``num_encoders`` transformer layers of torchaudio's pretrained
    WAV2VEC2_BASE (audio branch) and of torchvision's ``vit_b_16`` encoder
    (face branch) run side by side. Each layer is either wrapped in an
    adapter (backbone frozen) or unfrozen directly. ``fusion_type`` then
    selects how the two streams are combined:

    * ``"concat"``  -- concatenate the final audio and face sequences.
    * ``"cross2"``  -- fuse after every layer with a ``CrossFusionModule`` and
      append projected features from three external cognitive extractors.

    Args:
        fusion_type: ``"concat"`` or ``"cross2"``. Any other value still
            builds a classifier, but the forward methods raise ``ValueError``.
        num_encoders: number of layers taken from each backbone.
        adapter: if true, wrap layers in adapters; otherwise fine-tune the raw
            layers.
        adapter_type: ``"nlp"`` selects the ``*_adapter_nlp`` wrappers; any
            other value selects the ``*_adapter_conv`` wrappers.
        multi: (``"cross2"`` only) also build ``audio_classifier`` and
            ``vision_classifier`` heads; the methods here do not use them.

    NOTE(review): the ``"cross2"`` path reads ``cognitive_feature_extractor_a``,
    ``_b`` and ``_c``, which ``__init__`` never creates. The caller must
    attach them (presumably ``CognitiveLoadFeatureExtractor`` instances --
    confirm).
    """

    def __init__(self, fusion_type, num_encoders, adapter, adapter_type, multi=False):
        super().__init__()
        self.fusion_type = fusion_type
        self.num_encoders = num_encoders
        self.adapter = adapter
        self.adapter_type = adapter_type
        self.multi = multi

        # ---- audio branch: frozen wav2vec2 front end + first N layers ----
        model = torchaudio.pipelines.WAV2VEC2_BASE.get_model()
        for p in model.parameters():
            p.requires_grad = False

        self.FEATURE_EXTRACTOR = model.feature_extractor

        # The wav2vec2 encoder steps that run before its transformer layers.
        self.FEATURE_PROJECTOR = nn.Sequential(
            model.encoder.feature_projection,
            model.encoder.transformer.pos_conv_embed,
            model.encoder.transformer.layer_norm,
            model.encoder.transformer.dropout,
        )

        self.TRANSFORMER = nn.Sequential(*self._build_encoder_layers(
            model.encoder.transformer.layers, w2v2_adapter_nlp, w2v2_adapter_conv))

        # Lifts the 256-d per-frame face CNN features to the ViT width (768).
        self.projection = nn.Sequential(
            nn.Linear(256, 768),
            nn.ReLU(),
        )

        # ---- face branch: frozen ViT-B/16 encoder layers ----
        # NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in
        # favour of `weights=ViT_B_16_Weights.IMAGENET1K_V1`; kept so older
        # torchvision installs keep working.
        vit_b_16 = torchvision.models.vit_b_16(pretrained=True)
        for p in vit_b_16.parameters():
            p.requires_grad = False

        vit = vit_b_16.encoder

        # Learned positional embedding; its shape fixes the face sequence at
        # 64 frames of width 768.
        self.pos_embedding = nn.Parameter(torch.empty(1, 64, 768).normal_(std=0.02))

        face_layers = self._build_encoder_layers(vit.layers, vit_adapter_nlp, vit_adapter_conv)

        self.cnn_feature_extractor = cnn_face()
        self.ViT_Encoder = nn.Sequential(*face_layers)

        # 768 -> 256. Presumably sized for 3 extractors x 4 layers x 64-d fused
        # features (CognitiveLoadFeatureExtractor defaults) -- confirm.
        self.projection_layer = nn.Linear(64 * 12, 64 * 4)

        if self.fusion_type == "concat":
            # 768-d audio + 768-d face features.
            self.classifier = nn.Sequential(nn.Linear(768 * 2, 2))
        elif self.fusion_type == "cross2":
            self.cross_conv_layer = nn.Sequential(
                *[CrossFusionModule(dim=256) for _ in range(self.num_encoders)])
            # 512 = num_encoders x 64 fused + 256 projected cognitive features
            # (assumes num_encoders == 4).
            self.classifier = nn.Sequential(nn.Linear(512, 64),
                                            nn.Dropout(p=0.5),
                                            nn.Linear(64, 2))
            if self.multi:
                self.audio_classifier = nn.Linear(768, 2)
                self.vision_classifier = nn.Linear(768, 2)
        else:
            self.classifier = nn.Sequential(nn.Linear(768 * 2, 2))

    def _build_encoder_layers(self, layers, nlp_adapter, conv_adapter):
        """Return the first ``num_encoders`` of ``layers``, adapter-wrapped or unfrozen."""
        built = []
        for i in range(self.num_encoders):
            layer = layers[i]
            if not self.adapter:
                # No adapter: fine-tune the pretrained layer itself.
                for p in layer.parameters():
                    p.requires_grad = True
                built.append(layer)
            elif self.adapter_type == 'nlp':
                built.append(nlp_adapter(transformer_encoder=layer))
            else:
                built.append(conv_adapter(transformer_encoder=layer))
        return built

    def _cognitive_features(self, x, y):
        """Concatenate (last dim) the per-layer outputs of the three attached extractors."""
        extractors = (self.cognitive_feature_extractor_a,
                      self.cognitive_feature_extractor_b,
                      self.cognitive_feature_extractor_c)
        return torch.cat([torch.cat(extractor(x, y), dim=-1) for extractor in extractors], dim=-1)

    def _embed_inputs(self, x, y):
        """Run the audio front end and the per-frame face CNN; return ``(audios, faces)``.

        ``y`` must be (batch, frames, C, H, W); the face CNN is applied to each
        frame independently and must yield 256 features per frame.
        """
        x, _ = self.FEATURE_EXTRACTOR(x, None)
        audios = self.FEATURE_PROJECTOR(x)

        b_s, no_of_frames, C, H, W = y.shape
        faces = self.cnn_feature_extractor(torch.reshape(y, (b_s * no_of_frames, C, H, W)))
        faces = torch.reshape(faces, (b_s, no_of_frames, 256))
        faces = self.projection(faces) + self.pos_embedding
        return audios, faces

    def _encode_streams(self, x, y):
        """Encode both modalities according to ``fusion_type``.

        Returns the concatenated final audio/face sequences for ``"concat"``,
        or the concatenation of every layer's cross-fusion output for
        ``"cross2"``. Callers have already validated ``fusion_type``.
        """
        audios, faces = self._embed_inputs(x, y)
        if self.fusion_type == "concat":
            return torch.cat((self.TRANSFORMER(audios), self.ViT_Encoder(faces)), dim=-1)

        assert len(self.TRANSFORMER) == len(self.ViT_Encoder), "unmatched encoders between audio and face"
        feat_ls = []
        for audio_net, visual_net, cross_conv in zip(self.TRANSFORMER, self.ViT_Encoder, self.cross_conv_layer):
            audios = audio_net(audios)
            faces = visual_net(faces)
            feat_ls.append(cross_conv(audios, faces))
        return torch.cat(feat_ls, dim=-1)

    def forward(self, x, y):
        """Classify an (audio, face-frames) pair.

        Returns:
            ``(logits, None, None)``. ``logits`` is the classifier output
            averaged over dimension 1 (the sequence axis).

        Raises:
            ValueError: if ``fusion_type`` is neither ``"concat"`` nor ``"cross2"``.
        """
        if self.fusion_type == "concat":
            final_fused_output = self._encode_streams(x, y)
        elif self.fusion_type == "cross2":
            # Cognitive features are computed first, matching the original
            # call order (relevant for RNG consumption in train mode).
            projected_features = self.projection_layer(self._cognitive_features(x, y))
            final_fused_output = torch.cat((self._encode_streams(x, y), projected_features), dim=-1)
        else:
            raise ValueError("undefined fusion type")

        logits = self.classifier(final_fused_output)
        return torch.mean(logits, 1), None, None

    def forward_without_classifier(self, x, y):
        """Return ``(features, None)`` -- the classifier input without projecting.

        For ``"cross2"`` the cognitive features are appended unprojected
        (unlike ``forward``, which passes them through ``projection_layer``).

        Raises:
            ValueError: if ``fusion_type`` is neither ``"concat"`` nor ``"cross2"``.
        """
        if self.fusion_type == "concat":
            return self._encode_streams(x, y), None
        if self.fusion_type == "cross2":
            final_cognitive_features = self._cognitive_features(x, y)
            return torch.cat((self._encode_streams(x, y), final_cognitive_features), dim=-1), None
        raise ValueError("undefined fusion type")
