import torch
import torch.nn as nn
from transformers import AutoImageProcessor, ViTModel, BertModel, BertTokenizer, CLIPModel, ASTModel
from peft import LoraConfig, TaskType, get_peft_model
import torch.nn.functional as F


# Define ViT model
class ViTClassifier(nn.Module):
    """Image classifier: CLIP ViT-B/16 vision tower plus a linear head.

    The backbone is the ``vision_model`` of ``openai/clip-vit-base-patch16``.
    LoRA adapters can optionally be injected into it with PEFT.

    Args:
        num_classes: Number of logits produced by the linear head.
        max_text_length: Unused by this class; kept so every classifier in
            this module shares one constructor signature.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        target_modules: Backbone sub-module names (or a PEFT regex string)
            that receive LoRA adapters. ``None`` selects
            ``_DEFAULT_TARGET_MODULES``.
        lora_dropout: Dropout used inside the LoRA adapters.
        enable_lora: If True, wrap the backbone with LoRA adapters.
        classifier_type: Optional tag stored on ``self.classifier_type``;
            not read by this class.
    """

    # Immutable default so no list object is shared between instances.
    _DEFAULT_TARGET_MODULES = ("k_proj", "v_proj", "q_proj", "out_proj")

    def __init__(
        self,
        num_classes,
        max_text_length=128,
        r=8,
        lora_alpha=8,
        target_modules=None,
        lora_dropout=0.1,
        enable_lora=False,
        classifier_type=None,
    ):
        super().__init__()

        # Previously this read an undefined global and raised NameError;
        # it is now an optional constructor argument.
        self.classifier_type = classifier_type
        self.enable_lora = enable_lora

        self.vit = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").vision_model
        # Take the feature width from the backbone config (768 for this
        # checkpoint) instead of hard-coding it; read before PEFT wrapping.
        self.hidden_dim = self.vit.config.hidden_size

        if self.enable_lora:
            print("LoRA enabled. Inserting lora layers in ViT...")
            if target_modules is None:
                target_modules = list(self._DEFAULT_TARGET_MODULES)
            self.vit_lora_config = LoraConfig(
                r=r,
                lora_alpha=lora_alpha,
                target_modules=target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                # NOTE(review): self.classifier is created on this wrapper,
                # not inside self.vit, so this entry presumably matches
                # nothing in the wrapped backbone — confirm intent.
                modules_to_save=["classifier"],
            )
            self.vit = get_peft_model(self.vit, self.vit_lora_config)

        # The head lives outside the (possibly LoRA-frozen) backbone, so it
        # stays trainable.
        self.classifier = nn.Linear(self.hidden_dim, num_classes)

    def forward(self, input, normalize=False):
        """Classify a batch of images.

        Args:
            input: Indexable batch. ``input[0]`` is a mapping of keyword
                arguments for the vision backbone (presumably image-processor
                output such as ``pixel_values`` — confirm against the caller).
            normalize: Currently unused; accepted for interface compatibility.

        Returns:
            ``(logits, image_features)``. ``image_features`` is the backbone's
            first output at token position 0 (presumably the CLS token).
        """
        backbone_out = self.vit(**input[0])
        # First output indexed as (batch, tokens, hidden); keep token 0.
        image_features = backbone_out[0][:, 0, :]
        output = self.classifier(image_features)
        return output, image_features


# Define BERT model
class BertClassifier(nn.Module):
    """Text classifier: ``bert-base-uncased`` encoder plus a linear head.

    LoRA adapters can optionally be injected into the encoder with PEFT.
    Sentence features are the mean of the final hidden states over non-padding
    tokens.

    Args:
        num_classes: Number of logits produced by the linear head.
        max_text_length: Stored on ``self.max_text_length``; not read by this
            class.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        target_modules: Encoder sub-module names (or a PEFT regex string)
            that receive LoRA adapters. ``None`` selects
            ``_DEFAULT_TARGET_MODULES``.
        lora_dropout: Dropout used inside the LoRA adapters.
        enable_lora: If True, wrap the encoder with LoRA adapters.
        classifier_type: Optional tag stored on ``self.classifier_type``;
            not read by this class.
    """

    # Immutable default so no list object is shared between instances.
    _DEFAULT_TARGET_MODULES = ("query", "value")

    def __init__(
        self,
        num_classes,
        max_text_length=128,
        r=8,
        lora_alpha=8,
        target_modules=None,
        lora_dropout=0.1,
        enable_lora=False,
        classifier_type=None,
    ):
        super().__init__()

        # Previously this read an undefined global and raised NameError;
        # it is now an optional constructor argument.
        self.classifier_type = classifier_type
        self.enable_lora = enable_lora

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Read the width before PEFT wrapping so we query the plain model.
        self.hidden_dim = self.bert.config.hidden_size

        if self.enable_lora:
            print("LoRA enabled. Inserting lora layers in BERT...")
            if target_modules is None:
                target_modules = list(self._DEFAULT_TARGET_MODULES)
            self.bert_lora_config = LoraConfig(
                r=r,
                lora_alpha=lora_alpha,
                target_modules=target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                # NOTE(review): self.classifier is created on this wrapper,
                # not inside self.bert, so this entry presumably matches
                # nothing in the wrapped encoder — confirm intent.
                modules_to_save=["classifier"],
            )
            self.bert = get_peft_model(self.bert, self.bert_lora_config)
        self.max_text_length = max_text_length

        self.classifier = nn.Linear(self.hidden_dim, num_classes)

    def forward(self, input, normalize=False):
        """Classify a batch of tokenized texts.

        Args:
            input: Indexable batch. ``input[1]`` is a mapping of keyword
                arguments for the encoder (presumably tokenizer output with
                ``input_ids`` and ``attention_mask`` — confirm with caller).
            normalize: Currently unused; accepted for interface compatibility.

        Returns:
            ``(logits, text_features)``. ``text_features`` is the mean of the
            final hidden states over the sequence axis (dim 1), restricted to
            positions where ``attention_mask`` is non-zero when a mask is
            supplied.
        """
        encoded = input[1]
        hidden = self.bert(**encoded).last_hidden_state
        mask = encoded.get("attention_mask")
        if mask is None:
            text_features = hidden.mean(dim=1)
        else:
            # Exclude padding so a sentence's features do not depend on how
            # much padding the rest of the batch forces on it.
            weights = mask.unsqueeze(-1).to(hidden.dtype)
            token_counts = weights.sum(dim=1).clamp(min=1.0)
            text_features = (hidden * weights).sum(dim=1) / token_counts
        output = self.classifier(text_features)
        return output, text_features


# Define AST model
class ASTClassifier(nn.Module):
    """Audio classifier: AudioSet-finetuned AST encoder plus a linear head.

    The backbone is ``MIT/ast-finetuned-audioset-10-10-0.4593``. LoRA adapters
    can optionally be injected into it with PEFT.

    Args:
        num_classes: Number of logits produced by the linear head.
        max_text_length: Unused by this class; kept so every classifier in
            this module shares one constructor signature.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        target_modules: Backbone sub-module names (or a PEFT regex string)
            that receive LoRA adapters. ``None`` selects
            ``_DEFAULT_TARGET_MODULES``.
        lora_dropout: Dropout used inside the LoRA adapters.
        enable_lora: If True, wrap the backbone with LoRA adapters.
        classifier_type: Optional tag stored on ``self.classifier_type``;
            not read by this class.
    """

    # Immutable default so no list object is shared between instances.
    _DEFAULT_TARGET_MODULES = ("query", "value")

    def __init__(
        self,
        num_classes,
        max_text_length=128,
        r=8,
        lora_alpha=8,
        target_modules=None,
        lora_dropout=0.1,
        enable_lora=False,
        classifier_type=None,
    ):
        super().__init__()

        # Previously this read an undefined global and raised NameError;
        # it is now an optional constructor argument.
        self.classifier_type = classifier_type
        self.enable_lora = enable_lora

        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        # Take the feature width from the backbone config (768 for this
        # checkpoint) instead of hard-coding it; read before PEFT wrapping.
        self.hidden_dim = self.ast.config.hidden_size

        if self.enable_lora:
            print("LoRA enabled. Inserting lora layers in AST...")
            if target_modules is None:
                target_modules = list(self._DEFAULT_TARGET_MODULES)
            self.ast_lora_config = LoraConfig(
                r=r,
                lora_alpha=lora_alpha,
                target_modules=target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                # NOTE(review): self.classifier is created on this wrapper,
                # not inside self.ast, so this entry presumably matches
                # nothing in the wrapped backbone — confirm intent.
                modules_to_save=["classifier"],
            )
            self.ast = get_peft_model(self.ast, self.ast_lora_config)

        self.classifier = nn.Linear(self.hidden_dim, num_classes)

    def forward(self, input, normalize=False):
        """Classify a batch of audio inputs.

        Args:
            input: Indexable batch. ``input[1]`` is passed positionally as the
                backbone's first argument (presumably spectrogram features
                from the AST feature extractor — confirm with caller).
            normalize: Currently unused; accepted for interface compatibility.

        Returns:
            ``(logits, audio_features)``. ``audio_features`` is the backbone's
            second output (presumably its pooled representation).
        """
        audio_features = self.ast(input[1])[1]
        output = self.classifier(audio_features)
        return output, audio_features


# Define Video model
class VideoClassifier(nn.Module):
    """Video classifier: per-frame CLIP ViT-B/16 features, averaged over time.

    Every frame is encoded independently by the ``vision_model`` of
    ``openai/clip-vit-base-patch16``. Frame features are mean-pooled over the
    time axis and fed to a linear head. LoRA adapters can optionally be
    injected into the backbone with PEFT.

    Args:
        num_classes: Number of logits produced by the linear head.
        max_text_length: Unused by this class; kept so every classifier in
            this module shares one constructor signature.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        target_modules: Backbone sub-module names (or a PEFT regex string)
            that receive LoRA adapters. ``None`` selects
            ``_DEFAULT_TARGET_MODULES``.
        lora_dropout: Dropout used inside the LoRA adapters.
        enable_lora: If True, wrap the backbone with LoRA adapters.
        classifier_type: Optional tag stored on ``self.classifier_type``;
            not read by this class.
    """

    # Immutable default so no list object is shared between instances.
    _DEFAULT_TARGET_MODULES = ("k_proj", "v_proj", "q_proj", "out_proj")

    def __init__(
        self,
        num_classes,
        max_text_length=128,
        r=8,
        lora_alpha=8,
        target_modules=None,
        lora_dropout=0.1,
        enable_lora=False,
        classifier_type=None,
    ):
        super().__init__()

        # Previously this read an undefined global and raised NameError;
        # it is now an optional constructor argument.
        self.classifier_type = classifier_type
        self.enable_lora = enable_lora

        self.vit = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").vision_model
        # Take the feature width from the backbone config (768 for this
        # checkpoint) instead of hard-coding it; read before PEFT wrapping.
        self.hidden_dim = self.vit.config.hidden_size

        if self.enable_lora:
            print("LoRA enabled. Inserting lora layers in ViT...")
            if target_modules is None:
                target_modules = list(self._DEFAULT_TARGET_MODULES)
            self.vit_lora_config = LoraConfig(
                r=r,
                lora_alpha=lora_alpha,
                target_modules=target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                # NOTE(review): self.classifier is created on this wrapper,
                # not inside self.vit, so this entry presumably matches
                # nothing in the wrapped backbone — confirm intent.
                modules_to_save=["classifier"],
            )
            self.vit = get_peft_model(self.vit, self.vit_lora_config)

        self.classifier = nn.Linear(self.hidden_dim, num_classes)

    def forward(self, input, normalize=False):
        """Classify a batch of clips by averaging per-frame features.

        Args:
            input: Indexable batch (list or tuple). ``input[0]`` is a 5-D frame
                tensor ``(B, T, C, H, W)``. It is left unmodified.
            normalize: Currently unused; accepted for interface compatibility.

        Returns:
            ``(logits, video_features)``. ``video_features`` has shape
            ``(B, hidden)``: the backbone's second output per frame, averaged
            over the T frames.
        """
        frames = input[0]
        batch, num_frames, channels, height, width = frames.shape
        # Fold time into the batch axis in a local variable. Writing back into
        # input[0] would corrupt the caller's batch, and fails on tuples.
        # reshape (unlike view) also accepts non-contiguous tensors.
        flat_frames = frames.reshape(batch * num_frames, channels, height, width)
        frame_features = self.vit(flat_frames)[1]
        video_features = frame_features.reshape(batch, num_frames, -1).mean(dim=1)

        output = self.classifier(video_features)
        return output, video_features
