from peft import LoraConfig, TaskType, get_peft_model
import torch
import torch.nn as nn
from transformers.models.vit.modeling_vit import *
from transformers.models.bert.modeling_bert import *
from transformers.models.clip.modeling_clip import *
from transformers import RobertaModel


class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower extended with a single learnable "missing token".

    The module is built from ``CLIPVisionEmbeddings``, a pre-encoder
    ``LayerNorm``, ``CLIPEncoder`` and a post-encoder ``LayerNorm``. On top of
    that it owns ``missing_tokens``, a zero-initialised parameter of shape
    ``(1, 1, hidden_size)`` that can be appended to the token sequence before
    the encoder runs (see ``forward(enable_mt=True)``).
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        # NOTE: the "layrnorm" spelling is part of the state-dict key, so it must
        # not be "fixed" here (presumably it matches the pretrained checkpoints
        # -- confirm before renaming).
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        ################### CUSTOMIZATION ###################
        # Learnable missing token, shared by every sample of a batch.
        self.missing_tokens = nn.Parameter(torch.zeros(1, 1, embed_dim))
        ################### CUSTOMIZATION ###################

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        enable_mt: bool = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode ``pixel_values`` and optionally append the missing token.

        Args:
            pixel_values: Input images; required.
            output_attentions: Forwarded to the encoder; defaults to
                ``config.output_attentions`` when ``None``.
            output_hidden_states: Forwarded to the encoder; defaults to
                ``config.output_hidden_states`` when ``None``.
            return_dict: Return a ``BaseModelOutputWithPooling`` instead of a
                tuple; defaults to ``config.use_return_dict`` when ``None``.
            enable_mt: When true, ``missing_tokens`` is concatenated as the
                LAST position of the sequence (after the pre-encoder layer
                norm), so the encoder output has one extra token.

        Returns:
            ``(last_hidden_state, pooled_output, *extra_encoder_outputs)`` or
            the equivalent ``BaseModelOutputWithPooling``. ``pooled_output`` is
            the post-layer-normed token at position 0.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        ################### CUSTOMIZATION ###################
        if enable_mt:
            # Take the batch size from the tensor we concatenate with, so the
            # two operands of `cat` always agree on dim 0.
            missing_tokens = self.missing_tokens.expand(hidden_states.shape[0], -1, -1)
            hidden_states = torch.cat((hidden_states, missing_tokens), dim=1)
        ################### CUSTOMIZATION ###################

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pooling always reads position 0, regardless of the appended token.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class CLIPModel(CLIPPreTrainedModel):
    """CLIP container that wires in this file's ``CLIPVisionTransformer``.

    Only construction is implemented here: the text tower, the vision tower,
    both projection layers and the logit scale are created so that weights can
    be loaded into them. ``get_text_features``, ``get_image_features`` and
    ``forward`` are stubs that return ``None``.
    """

    config_class = CLIPConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        """Build both towers and the projection heads from ``config``.

        Raises:
            ValueError: If ``config.text_config`` is not a ``CLIPTextConfig``
                or ``config.vision_config`` is not a ``CLIPVisionConfig``.
        """
        super().__init__(config)

        # Validate the text sub-config first, then the vision one.
        text_config = config.text_config
        if not isinstance(text_config, CLIPTextConfig):
            raise ValueError(
                f"config.text_config is expected to be of type CLIPTextConfig but is of type {type(text_config)}."
            )

        vision_config = config.vision_config
        if not isinstance(vision_config, CLIPVisionConfig):
            raise ValueError(
                f"config.vision_config is expected to be of type CLIPVisionConfig but is of type {type(vision_config)}."
            )

        projection_dim = config.projection_dim
        text_width = text_config.hidden_size
        vision_width = vision_config.hidden_size

        self.projection_dim = projection_dim
        self.text_embed_dim = text_width
        self.vision_embed_dim = vision_width

        # Sub-module creation order is kept: text tower, vision tower,
        # visual projection, text projection, logit scale.
        self.text_model = CLIPTextTransformer(text_config)
        self.vision_model = CLIPVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(vision_width, projection_dim, bias=False)
        self.text_projection = nn.Linear(text_width, projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Stub: not implemented in this file; always returns ``None``."""
        return None

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Stub: not implemented in this file; always returns ``None``."""
        return None

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        """Stub: not implemented in this file; always returns ``None``."""
        return None


class BertModel(BertPreTrainedModel):
    """BERT backbone assembled from ``BertEmbeddings``, ``BertEncoder`` and an
    optional ``BertPooler``.

    This class is defined locally and therefore shadows any ``BertModel``
    brought in by the wildcard imports at the top of the file.
    """

    def __init__(self, config, add_pooling_layer=True):
        """Create the embeddings, the encoder and (optionally) the pooler.

        Args:
            config: Model configuration passed to every sub-module.
            add_pooling_layer: When false, ``self.pooler`` is ``None`` and
                ``forward`` returns ``None`` as the pooled output.
        """
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module of ``self.embeddings``."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module of ``self.embeddings`` with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prune attention heads.

        Args:
            heads_to_prune: Mapping ``{layer_index: heads}``; ``heads`` is
                handed to that layer's ``attention.prune_heads``.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Run embeddings -> encoder -> (optional) pooler.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        Missing ``attention_mask`` defaults to all ones and missing
        ``token_type_ids`` default to the embeddings' buffered ids (if that
        buffer exists) or to zeros. ``output_attentions``,
        ``output_hidden_states`` and ``return_dict`` fall back to the config
        when ``None``; ``use_cache`` is forced to ``False`` unless
        ``config.is_decoder`` is set.

        Returns:
            ``(sequence_output, pooled_output, *extra_encoder_outputs)`` when
            ``return_dict`` is false, otherwise a
            ``BaseModelOutputWithPoolingAndCrossAttentions``.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is only honoured for decoder configurations.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        # Derive (batch, seq) from whichever input was supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            # Drop the trailing embedding dimension.
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        # NOTE(review): indexes the cache as past_key_values[layer][0] and reads
        # dim 2 as the cached length -- assumes a tuple-of-tuples cache layout;
        # TODO confirm against the installed transformers version.
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        # Default mask: attend to every current and cached position.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            # Prefer the token-type buffer stored on the embeddings, sliced to
            # the current length and expanded over the batch; otherwise zeros.
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            # No cross-attention mask outside the decoder + encoder-states case.
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        # pooled_output is None when the model was built without a pooler.
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


# Define multimodal model
class ViTBertMMT(nn.Module):
    """Image + text classifier that averages one vector per modality.

    The image encoder is the vision tower of ``openai/clip-vit-base-patch16``
    (loaded through this file's ``CLIPModel``); the text encoder is this file's
    ``BertModel`` when ``text_model == 'bert-base-uncased'`` and
    ``RobertaModel`` otherwise. Both encoders can be wrapped with LoRA
    adapters. With ``enable_mt`` each encoder additionally yields an estimate
    of the *other* modality, which replaces that modality during fusion when
    ``missing_type`` says so.

    Args:
        num_classes: Output size of the linear classifier. The value ``101``
            additionally enables a ``LayerNorm`` per modality
            (NOTE(review): presumably a dataset-specific switch -- confirm).
        max_text_length: Stored on the instance; not used by this class.
        r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        vit_target_modules: Module names that receive LoRA layers in the image
            encoder. ``None`` selects
            ``["k_proj", "v_proj", "q_proj", "out_proj"]``.
        bert_target_modules: Module names that receive LoRA layers in the text
            encoder. ``None`` selects ``["query", "value"]``.
        lora_dropout: Dropout used inside the LoRA layers.
        enable_lora: Wrap both encoders with ``get_peft_model``.
        enable_mt: Use the missing-token estimates (see ``forward``).
        text_model: Checkpoint name passed to ``from_pretrained``.
        no_placeholder: Stored on the instance; not used by this class.
    """

    def __init__(
        self,
        num_classes,
        max_text_length=128,
        r=1,
        lora_alpha=1,
        vit_target_modules=None,
        bert_target_modules=None,
        lora_dropout=0.1,
        enable_lora=False,
        enable_mt=False,
        text_model='bert-base-uncased',
        no_placeholder=False, # DO NOT use features from missing modality for fusion
    ):
        super().__init__()

        # `None` sentinels instead of list defaults: a list default is created
        # once at definition time and would be shared by every instance.
        if vit_target_modules is None:
            vit_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
        if bert_target_modules is None:
            bert_target_modules = ["query", "value"]

        self.enable_lora = enable_lora
        self.enable_mt = enable_mt
        self.text_model = text_model
        self.no_placeholder = no_placeholder

        # Image encoder: only the vision tower of the CLIP checkpoint is kept.
        self.vit = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").vision_model

        if self.enable_lora:
            print("LoRA enabled. Inserting lora layers in ViT...")
            self.vit_lora_config = self._make_lora_config(r, lora_alpha, vit_target_modules, lora_dropout)
            # NOTE(review): get_peft_model presumably freezes every non-LoRA
            # parameter, which would include the vision tower's
            # `missing_tokens` -- confirm the training script re-enables it.
            self.vit = get_peft_model(self.vit, self.vit_lora_config)

        # Text features extraction using BERT/RoBERTa
        if self.text_model == 'bert-base-uncased':
            self.bert = BertModel.from_pretrained(self.text_model)
        else:
            self.bert = RobertaModel.from_pretrained(self.text_model)

        if self.enable_lora:
            print("LoRA enabled. Inserting lora layers in BERT...")
            self.bert_lora_config = self._make_lora_config(r, lora_alpha, bert_target_modules, lora_dropout)
            self.bert = get_peft_model(self.bert, self.bert_lora_config)
        self.max_text_length = max_text_length

        # Hidden Dim (taken from the text encoder; the fusion in `forward`
        # adds image and text vectors, so both encoders must share this width).
        self.hidden_dim = self.bert.config.hidden_size

        # Norm and classifier head
        self.num_classes = num_classes

        if self.num_classes == 101:
            self.visual_norm = nn.LayerNorm(self.hidden_dim)
            self.text_norm = nn.LayerNorm(self.hidden_dim)

        self.classifier = nn.Linear(self.hidden_dim, num_classes)

    @staticmethod
    def _make_lora_config(r, lora_alpha, target_modules, lora_dropout):
        """Build the ``LoraConfig`` shared by the image and the text encoder."""
        return LoraConfig(
            r=r,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            modules_to_save=["classifier"],
        )

    def forward(self, input, missing_type):
        """Classify a batch, substituting estimated features where needed.

        Args:
            input: Indexable pair ``(image_kwargs, text_kwargs)``; each element
                is unpacked as keyword arguments into the respective encoder.
            missing_type: 1-D tensor with one code per sample. It is only read
                when ``enable_mt`` is true:
                ``0`` fuses real image and real text features,
                ``1`` fuses the real image with the text estimate produced by
                the image encoder,
                ``2`` fuses the image estimate produced by the text encoder
                with the real text. Any other code leaves that sample's fused
                vector at zero.

        Returns:
            Tuple ``(output, real_tokens, estimated_tokens, fused_features)``.
            ``output`` is the classifier result for ``fused_features``. For
            every sample with code ``0`` (and only when ``enable_mt`` is true)
            ``real_tokens`` receives its image then text vector and
            ``estimated_tokens`` its estimated image then estimated text
            vector; otherwise both lists are empty.
        """
        # Base model: position 0 of the image sequence is the image feature.
        image_features = self.vit(**input[0], enable_mt=self.enable_mt)[0]
        img_cls_tokens = image_features[:, 0, :]

        # Text feature extraction: mean over every position except index 1.
        text_features = self.bert(**input[1]).last_hidden_state
        text_cls_tokens = torch.cat((text_features[:, :1, :], text_features[:, 2:, :]), dim=1).mean(dim=1)

        use_norm = self.num_classes == 101
        if use_norm:
            img_cls_tokens = self.visual_norm(img_cls_tokens)
            text_cls_tokens = self.text_norm(text_cls_tokens)

        real_tokens = []
        estimated_tokens = []

        if self.enable_mt:
            # The last image position estimates the text modality; text
            # position 1 estimates the image modality.
            estimated_text_tokens = image_features[:, -1, :]
            estimated_image_tokens = text_features[:, 1, :]
            if use_norm:
                estimated_image_tokens = self.visual_norm(estimated_image_tokens)
                estimated_text_tokens = self.text_norm(estimated_text_tokens)

            mask_0 = missing_type == 0
            mask_1 = missing_type == 1
            mask_2 = missing_type == 2

            # Exactly one mask is true per sample, so the masked sum selects
            # the fusion that matches the sample's code.
            fused_features = torch.zeros_like(img_cls_tokens)  # Initialize with zeros
            fused_features += mask_0.unsqueeze(1) * ((img_cls_tokens + text_cls_tokens) / 2.0)
            fused_features += mask_1.unsqueeze(1) * ((img_cls_tokens + estimated_text_tokens) / 2.0)
            fused_features += mask_2.unsqueeze(1) * ((estimated_image_tokens + text_cls_tokens) / 2.0)

            # One host transfer for all indices instead of one tensor
            # comparison per sample.
            for i in mask_0.nonzero().flatten().tolist():
                real_tokens.append(img_cls_tokens[i])
                real_tokens.append(text_cls_tokens[i])
                estimated_tokens.append(estimated_image_tokens[i])
                estimated_tokens.append(estimated_text_tokens[i])
        else:
            fused_features = (img_cls_tokens + text_cls_tokens) / 2.0

        # Classifier
        output = self.classifier(fused_features)
        return output, real_tokens, estimated_tokens, fused_features
