from copy import deepcopy
import numpy as np
import torch
from torch import nn

# from pytorch_pretrained_bert.modeling import BertModel
from transformers import BertConfig, RobertaConfig, RobertaModel
from .modeling_bert_new import QVBertModel

from pathlib import Path
import os
import loralib as lora
from transformers.models.bert.modeling_bert import BertSelfAttention, BertAttention
import math


class VisionQuerySelfAttention(BertSelfAttention):
    """BERT self-attention in which text tokens additionally attend to vision queries.

    ``forward`` receives ``hidden_states`` as a ``(vision, text)`` pair. The
    vision features are projected to the hidden size, layer-normalised and
    concatenated in front of the text tokens. The shared query/key/value
    projections are then applied to the concatenation, and the text queries
    attend separately:

    (a) to the text keys (the original BERT text self-attention), and
    (b) to the vision keys.

    Result (b) is scaled per head by ``tanh(vision_query_gate)`` and added to
    (a). The gate is initialised to zero, so a freshly constructed module
    produces exactly the plain text self-attention output.
    """

    def __init__(self, config, vision_dim=256):
        """
        Args:
            config: BERT config; ``num_attention_heads`` sizes the per-head gate
                and ``hidden_size`` sizes the vision projection.
            vision_dim: feature size of the incoming vision tokens (default 256,
                the value this module was originally hard-coded to).
        """
        super().__init__(config)

        # One learnable gate per head, shaped to broadcast over the
        # [batch, heads, seq, head_dim] context tensor. Zero init => tanh == 0,
        # so the vision branch is switched off at the start of training.
        self.register_parameter(
            "vision_query_gate",
            nn.Parameter(torch.zeros(1, config.num_attention_heads, 1, 1)),
        )

        # Project vision tokens into the text hidden space so they can be
        # concatenated with text tokens and share the Q/K/V projections.
        self.vision_query_linear = nn.Linear(vision_dim, config.hidden_size)
        self.vision_query_linear_norm = nn.LayerNorm(config.hidden_size)

    def transform_vision_mask(self, attention_mask, dtype):
        """Turn a binary vision mask into an additive attention bias.

        Args:
            attention_mask: 3-D binary mask; after the transpose below it is laid
                out as [bs, 1, num_text, num_vision], so the input is presumably
                [bs, num_vision, num_text] (TODO confirm against the caller).
            dtype: floating dtype of the returned bias.

        Returns:
            A tensor that is 0 where the mask is 1 (attend) and
            ``torch.finfo(dtype).min`` where it is 0 (blocked).
        """
        extended_attention_mask = attention_mask[:, None, :, :].transpose(-1, -2).to(dtype=dtype)
        extended_attention_mask = (1 - extended_attention_mask) * torch.finfo(dtype).min

        return extended_attention_mask

    def forward(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        past_key_value,
        output_attentions):
        """Gated text->vision attention added to ordinary text self-attention.

        Args:
            hidden_states: pair ``(vision, text)``; ``vision`` has ``vision_dim``
                features and ``text`` has ``hidden_size`` features.
            attention_mask: pair ``(vision_mask, text_mask)``. ``vision_mask`` is
                binary and converted via ``transform_vision_mask``; ``text_mask``
                is added to the scores as-is (presumably an already-extended
                additive mask; verify against the caller).
            head_mask: accepted for interface compatibility but NOT applied.
            past_key_value: accepted for interface compatibility but NOT used.
            output_attentions: when true, also return the masked, pre-softmax
                text-to-text attention scores (not probabilities).

        Returns:
            ``(context_layer,)`` or ``(context_layer, text_attention_scores)``;
            ``context_layer`` covers the text positions only.
        """
        # Concatenate projected vision tokens in front of the text tokens.
        vision = hidden_states[0]
        text = hidden_states[1]
        vision = self.vision_query_linear(vision)
        vision = self.vision_query_linear_norm(vision)
        hidden_states = torch.cat((vision, text), dim=1)

        text_attention_mask = attention_mask[1]
        vision_attention_mask = self.transform_vision_mask(attention_mask[0], text_attention_mask.dtype)  # [bs, 1, num_text, num_vision]
        num_vision = vision_attention_mask.shape[-1]

        # Shared Q/K/V over the concatenated sequence.
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        vision_value_layer = value_layer[:, :, :num_vision, :]
        text_value_layer = value_layer[:, :, num_vision:, :]

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Only text queries are used; vision rows of the score matrix are dropped.
        vision_attention_scores = attention_scores[:, :, num_vision:, :num_vision]
        text_attention_scores = attention_scores[:, :, num_vision:, num_vision:]

        # Original text self-attention (softmax over text keys only).
        text_attention_scores = text_attention_scores + text_attention_mask
        text_attention_probs = nn.functional.softmax(text_attention_scores, dim=-1)
        text_attention_probs = self.dropout(text_attention_probs)
        text_context_layer = torch.matmul(text_attention_probs, text_value_layer)

        # Text-to-vision attention (separate softmax over vision keys; no
        # dropout is applied here, unlike the text branch).
        vision_attention_scores = vision_attention_scores + vision_attention_mask
        vision_attention_probs = nn.functional.softmax(vision_attention_scores, dim=-1)
        vision_context_layer = torch.matmul(vision_attention_probs, vision_value_layer)

        # Per-head gated merge of the two branches.
        context_layer = torch.tanh(self.vision_query_gate) * vision_context_layer + text_context_layer

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, text_attention_scores) if output_attentions else (context_layer,)

        return outputs


class VisionQueryAttention(BertAttention):
    """``BertAttention`` whose self-attention is a ``VisionQuerySelfAttention``.

    The output sub-module inherited from ``BertAttention`` is reused unchanged.
    Its residual input is the *text* half of the ``(vision, text)`` pair, because
    the self-attention only produces outputs for the text positions.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock self-attention created by BertAttention.__init__.
        self.self = VisionQuerySelfAttention(config)

    def forward(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        past_key_value,
        output_attentions,
    ):
        """Run vision-query self-attention and the BERT output sub-layer.

        Args:
            hidden_states: pair ``(vision, text)``, forwarded untouched.
            attention_mask: pair ``(vision_mask, text_mask)``, forwarded untouched.
            head_mask, past_key_value, output_attentions: forwarded to the
                self-attention.

        Returns:
            Tuple whose first element is the attention output for the text
            tokens, followed by any extra outputs of the self-attention.
        """
        attn_result = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            past_key_value,
            output_attentions,
        )
        context, extras = attn_result[0], attn_result[1:]
        text_states = hidden_states[1]
        return (self.output(context, text_states),) + extras


class BertEncoder(nn.Module):
    """Language backbone wrapping ``QVBertModel`` (BERT with vision-query inputs).

    Only ``bert-base-uncased`` (matched on the basename of
    ``cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE``) is supported; every other
    backbone raises ``NotImplementedError``.
    """

    def __init__(self, cfg):
        super(BertEncoder, self).__init__()
        self.cfg = cfg
        self.bert_name = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
        print("LANGUAGE BACKBONE USE GRADIENT CHECKPOINTING: ", self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT)

        if os.path.basename(self.bert_name) != "bert-base-uncased":
            # RoBERTa (and anything else) is not wired up for vision queries.
            raise NotImplementedError(f"Unsupported language backbone: {self.bert_name}")

        config = BertConfig.from_pretrained(self.bert_name)
        config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT
        self.model = QVBertModel.from_pretrained(
            self.bert_name,
            dim_t=config.hidden_size,
            dim_v=cfg.MODEL.BACKBONE.OUT_CHANNELS,
            share_kv=cfg.VISION_QUERY.SHARE_KV,
            cfg=cfg,
            add_pooling_layer=False,
            config=config,
        )
        self.language_dim = config.hidden_size  # 768 for bert-base-uncased
        self.config = config

        self.num_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS

    def add_lora(self, start_layer=6, r=8):
        """Install VisionQueryAttention and LoRA query/value projections.

        For every encoder layer from ``start_layer`` to the last one:

        1. the layer's attention module is replaced by ``VisionQueryAttention``.
           The old module's weights are loaded with ``strict=False``, so the new
           vision-query parameters keep their fresh initialisation;
        2. the query and value projections of that attention are replaced by
           ``loralib.Linear`` layers of rank ``r`` carrying copies of the
           original weights and biases.

        New modules are placed on the device of the module they replace, so this
        also works after the model has been moved to an accelerator.

        Args:
            start_layer: index of the first encoder layer to modify.
            r: LoRA rank passed to ``loralib.Linear``.
        """
        layers = self.model.encoder.layer

        for layer_idx in range(start_layer, len(layers)):
            old_attention = layers[layer_idx].attention
            device = next(old_attention.parameters()).device
            new_attention = VisionQueryAttention(self.config).to(device)
            new_attention.load_state_dict(old_attention.state_dict(), strict=False)
            layers[layer_idx].attention = new_attention
            print("Enabled VisionQueryAttention in BertLayer {}".format(layer_idx))

        for layer_idx in range(start_layer, len(layers)):
            self_attention = layers[layer_idx].attention.self
            self_attention.query = self._to_lora_linear(self_attention.query, r)
            self_attention.value = self._to_lora_linear(self_attention.value, r)
            print("Enabled LoRA in BertLayer {}".format(layer_idx))

    @staticmethod
    def _to_lora_linear(linear, r):
        """Return a rank-``r`` ``loralib.Linear`` initialised from ``linear``."""
        lora_linear = lora.Linear(linear.in_features, linear.out_features, r=r)
        lora_linear.weight.data = linear.weight.data.clone()
        lora_linear.bias.data = linear.bias.data.clone()
        # Move the freshly created LoRA factors next to the copied weight.
        return lora_linear.to(linear.weight.device)

    def forward(self, x):
        """Encode token ids with the vision-conditioned BERT.

        Args:
            x: dict with ``"input_ids"``, ``"attention_mask"`` and
                ``"vision_inputs"``; the latter must provide ``"vision"``,
                ``"images"``, ``"vision_attention_mask"`` and
                ``"batched_pos_category_map"``.

        Returns:
            dict with the following keys:

            - ``"aggregate"``: masked mean of the token features;
            - ``"embedded"``: token features zeroed at masked positions;
            - ``"masks"``: the untruncated input attention mask;
            - ``"hidden"``: the last hidden layer;
            - ``"bert_attn"``: the attention maps returned by the model in the
              dot-product-token-loss path, otherwise ``None``;
            - ``"augmented_vision"`` and ``"vision_attention_mask"`` from the
              model outputs, only when ``cfg.VISION_QUERY.QUERY_FUSION`` is set.
        """
        input_ids = x["input_ids"]
        mask = x["attention_mask"]
        vision_inputs = x["vision_inputs"]

        model_kwargs = dict(
            output_hidden_states=True,
            vision=vision_inputs["vision"],
            images=vision_inputs["images"],
            vision_attention_mask=vision_inputs["vision_attention_mask"],
        )
        batched_pos_category_map = vision_inputs["batched_pos_category_map"]

        # Previously only bound in the dot-product branch, which made the other
        # branch fail with UnboundLocalError when building ``ret``.
        bert_attn = None
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            # With padding, always 256 tokens.
            outputs, bert_attn = self.model(
                input_ids=input_ids,
                attention_mask=mask,
                batched_pos_category_map=batched_pos_category_map,
                **model_kwargs,
            )
            token_mask = mask
        else:
            # Without padding: keep only the columns up to the longest
            # non-zero token sequence in the batch.
            # NOTE(review): whether QVBertModel returns a single output object
            # (rather than an (outputs, attn) pair) when batched_pos_category_map
            # is omitted is not visible here -- confirm.
            max_len = (input_ids != 0).sum(1).max().item()
            token_mask = mask[:, :max_len]
            outputs = self.model(
                input_ids=input_ids[:, :max_len],
                attention_mask=token_mask,
                **model_kwargs,
            )

        # hidden_states holds the embedding output followed by one entry per layer.
        encoded_layers = outputs.hidden_states[1:]
        features = torch.stack(encoded_layers[-self.num_layers:], 1).mean(1)
        # NOTE: ``mean`` already averages over the last ``num_layers`` layers; this
        # extra division is kept unchanged for numerical compatibility with
        # existing weights (it is a no-op when num_layers == 1).
        features = features / self.num_layers

        embedded = features * token_mask.unsqueeze(-1).float()
        aggregate = embedded.sum(1) / (mask.sum(-1).unsqueeze(-1).float())

        ret = {
            "aggregate": aggregate,
            "embedded": embedded,
            "masks": mask,
            "hidden": encoded_layers[-1],
            "bert_attn": bert_attn
        }
        if self.cfg.VISION_QUERY.QUERY_FUSION:
            ret['augmented_vision'] = outputs.augmented_vision
            ret['vision_attention_mask'] = outputs.vision_attention_mask
        return ret
