from peft import LoraConfig, TaskType, get_peft_model
import torch
import torch.nn as nn
from transformers import ASTModel, ViTModel
from .clip import CLIPModel
from .clip import ASTModel as MyAST


class AudioVideoModelWithMT(nn.Module):
    """Audio-video classifier that averages a ViT token and an AST token.

    The video backbone is the vision tower of the project-local ``CLIPModel``
    loaded from ``openai/clip-vit-base-patch16``; the audio backbone is the
    project-local AST loaded from ``MIT/ast-finetuned-audioset-10-10-0.4593``.
    Both backbones are called with an ``enable_mt`` keyword and their first
    output is indexed as a rank-3 tensor (batch, tokens, hidden).

    With ``enable_mt`` set, the *last* token of each backbone's output is used
    as an estimate of the other modality's token, and ``missing_type`` selects
    per sample which pair of tokens is fused.
    NOTE(review): presumably the custom backbones append a dedicated
    missing-modality token when ``enable_mt`` is true -- confirm in ``.clip``.
    """

    # Immutable defaults; copied into a fresh list per instance so that no
    # mutable object is shared through the signature.
    _DEFAULT_VIT_TARGET_MODULES = ("k_proj", "v_proj", "q_proj", "out_proj")
    _DEFAULT_AST_TARGET_MODULES = ("query", "value", "key", "attention.output.dense")

    def __init__(
        self,
        num_classes,
        r=1,
        lora_alpha=1,
        vit_target_modules=None,
        ast_target_modules=None,
        lora_dropout=0.1,
        enable_lora=False,
        enable_mt=False,
    ):
        """Load both backbones and build the normalisation and classifier heads.

        Args:
            num_classes: Output width of the final linear classifier.
            r: LoRA rank, used for both backbones.
            lora_alpha: LoRA alpha, used for both backbones.
            vit_target_modules: ``target_modules`` for the ViT LoRA config.
                ``None`` selects ``["k_proj", "v_proj", "q_proj", "out_proj"]``.
            ast_target_modules: ``target_modules`` for the AST LoRA config.
                ``None`` selects
                ``["query", "value", "key", "attention.output.dense"]``.
            lora_dropout: LoRA dropout, used for both backbones.
            enable_lora: When true, wrap both backbones with ``get_peft_model``.
            enable_mt: Stored and forwarded to both backbones on every call;
                also switches ``forward`` to missing-modality fusion.
        """
        super().__init__()

        if vit_target_modules is None:
            vit_target_modules = list(self._DEFAULT_VIT_TARGET_MODULES)
        if ast_target_modules is None:
            ast_target_modules = list(self._DEFAULT_AST_TARGET_MODULES)

        self.enable_lora = enable_lora
        self.enable_mt = enable_mt

        self.vit = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").vision_model
        self.ast = MyAST.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

        if self.enable_lora:
            # NOTE(review): modules_to_save=["classifier"] is passed for both
            # backbones; whether either contains a submodule of that name is
            # not visible from this file -- confirm.
            print("LoRA enabled. Inserting lora layers in ViT...")
            self.vit_lora_config = LoraConfig(
                r=r,
                lora_alpha=lora_alpha,
                target_modules=vit_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                modules_to_save=["classifier"],
            )
            self.vit = get_peft_model(self.vit, self.vit_lora_config)

            print("LoRA enabled. Inserting lora layers in AST...")
            self.ast_lora_config = LoraConfig(
                r=r,
                lora_alpha=lora_alpha,
                target_modules=ast_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                modules_to_save=["classifier"],
            )
            self.ast = get_peft_model(self.ast, self.ast_lora_config)

        # The AST hidden size is used for every head. The video tokens go
        # through the same-sized LayerNorm and are added to the audio tokens,
        # so this assumes the ViT hidden size is identical -- TODO confirm.
        self.hidden_dim = self.ast.config.hidden_size
        self.video_norm = nn.LayerNorm(self.hidden_dim)
        self.audio_norm = nn.LayerNorm(self.hidden_dim)
        self.classifier = nn.Linear(self.hidden_dim, num_classes)

    def forward(self, input, missing_type):
        """Classify a batch of (video, audio) pairs.

        Args:
            input: Indexable pair. ``input[0]`` is a 5-D video tensor unpacked
                as ``(B, T, C, H, W)``; ``input[1]`` is handed to the AST
                backbone unchanged. The container is not modified.
            missing_type: One code per sample, only read when ``enable_mt`` is
                true: ``0`` = video missing, ``1`` = audio missing,
                ``2`` = nothing missing. Any other code leaves that sample's
                fused vector at zero. Expected to be 1-D of length ``B``.

        Returns:
            Tuple ``(output, real_tokens, estimated_tokens, fused_features)``.
            ``output`` is the classifier applied to ``fused_features``. The two
            lists are empty unless ``enable_mt`` is true; otherwise, for every
            sample with code ``2`` (in batch order), ``real_tokens`` receives
            the normalised video then audio token and ``estimated_tokens``
            receives the estimated video then estimated audio token.
        """
        # Work on locals: writing the flattened tensor back into ``input[0]``
        # would mutate the caller's batch and fail outright on tuples.
        frames = input[0]
        audio_input = input[1]

        B, T, C, H, W = frames.shape
        # Fold time into the batch axis so every frame is encoded on its own.
        # ``reshape`` also accepts non-contiguous clips, unlike ``view``.
        frames = frames.reshape(B * T, C, H, W)
        video_features = self.vit(frames, enable_mt=self.enable_mt)[0]
        audio_features = self.ast(audio_input, enable_mt=self.enable_mt)[0]

        # Token 0 of each frame, averaged over the T frames of a clip.
        video_cls_tokens = video_features[:, 0, :].reshape(B, T, -1).mean(dim=1)
        video_cls_tokens = self.video_norm(video_cls_tokens)
        audio_cls_tokens = self.audio_norm(audio_features[:, 0, :])

        real_tokens = []
        estimated_tokens = []

        if not self.enable_mt:
            fused_features = (video_cls_tokens + audio_cls_tokens) / 2.0
        else:
            # The last token of one backbone stands in for the other modality,
            # so it is normalised with the *other* modality's LayerNorm.
            estimated_audio_tokens = video_features[:, -1, :].reshape(B, T, -1).mean(dim=1)
            estimated_audio_tokens = self.audio_norm(estimated_audio_tokens)
            estimated_video_tokens = self.video_norm(audio_features[:, -1, :])

            # Accept lists and tensors living on another device.
            missing_type = torch.as_tensor(missing_type, device=video_cls_tokens.device)
            video_missing = missing_type == 0
            audio_missing = missing_type == 1
            nothing_missing = missing_type == 2

            # Exactly one mask is set per valid sample, so the masked sum picks
            # the matching pair; 0 -> video, 1 -> audio, 2 -> none missing.
            fused_features = (
                nothing_missing.unsqueeze(1) * ((video_cls_tokens + audio_cls_tokens) / 2.0)
                + video_missing.unsqueeze(1) * ((estimated_video_tokens + audio_cls_tokens) / 2.0)
                + audio_missing.unsqueeze(1) * ((video_cls_tokens + estimated_audio_tokens) / 2.0)
            )

            # One host transfer for the whole mask instead of a tensor truth
            # test (and device sync) per sample.
            for i, is_complete in enumerate(nothing_missing.tolist()):
                if is_complete:
                    real_tokens.append(video_cls_tokens[i])
                    real_tokens.append(audio_cls_tokens[i])
                    estimated_tokens.append(estimated_video_tokens[i])
                    estimated_tokens.append(estimated_audio_tokens[i])

        # Classifier
        output = self.classifier(fused_features)
        return output, real_tokens, estimated_tokens, fused_features
