# Identifiers will be added once the code is made public.


from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from typing import Optional
from torch.nn import CrossEntropyLoss
from torch.autograd import Variable
from torch.nn.parameter import Parameter

def normalize_scores(scores):
    """
    Rescale every row of saliency scores to [0, 1] with min-max normalization.

    :param scores: Tensor of saliency scores (Batch, L_v)
    :return: Normalized float32 scores (Batch, L_v)
    """
    # Cast to float32 before doing any arithmetic
    as_float = scores.float()

    # Per-row extrema along dim 1, kept as (Batch, 1) so they broadcast
    row_min = as_float.min(dim=1, keepdim=True).values
    row_max = as_float.max(dim=1, keepdim=True).values

    # The small epsilon keeps the division finite when a row is constant
    value_range = row_max - row_min + 1e-8
    return (as_float - row_min) / value_range

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Initial value of the `vth` tensors created by several modules in this file.
# NOTE(review): presumably a threshold value — confirm the intended meaning with the authors.
v_th_Val = 1.
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.

    Every setting is stored as a plain instance attribute, so the whole
    configuration can be serialized from `__dict__`.
    """

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 pre_trained='',
                 training=''):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`,
                or the path of a JSON config file. When a path is given, the keyword
                arguments below are ignored and only the keys found in the file
                become attributes.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, it must be a key of `ACT2FN`
                ("gelu" or "relu" in this file).
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            pre_trained: Stored unchanged on the config; not interpreted by this class.
            training: Stored unchanged on the config; not interpreted by this class.

        Raises:
            ValueError: If the first argument is neither an int nor a string path.
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                                                               and isinstance(vocab_size_or_config_json_file, unicode)):
            # A string is treated as the path of a JSON file holding the config.
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.pre_trained = pre_trained
            self.training = training
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        The instance is built from `cls`, so subclasses get an instance of
        their own type. Defaults are set first (with `vocab_size` = -1) and
        then overridden by every key of `json_object`.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (deep copy of the attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())

class BertLayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learned scale and shift."""

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize `x` along its last dimension, then apply `weight` and `bias`."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Biased variance: the mean of the squared deviations
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias

    def copy(self, target):
        """Overwrite this module's parameters with clones of `target`'s."""
        for name in ("weight", "bias"):
            getattr(self, name).data = getattr(target, name).data.clone()

def gelu(x):
    """Apply the GELU activation through ``torch.nn.functional.gelu``.

        With default arguments this is the erf form
        x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))).
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    activated = F.gelu(x)
    return activated

# Lookup from an activation name to its callable; only "gelu" and "relu" are registered here.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu}
# Lookup from a normalization-layer name to its module class.
NORM = {'layer_norm': BertLayerNorm}

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        """Create the three embedding tables, the LayerNorm and the dropout.

        Reads `vocab_size`, `hidden_size`, `max_position_embeddings`,
        `type_vocab_size` and `hidden_dropout_prob` from `config`.
        """
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """Sum the three embeddings, then apply LayerNorm and dropout.

        Args:
            input_ids: Token ids; `size(1)` is used as the sequence length.
            token_type_ids: Optional segment ids with the same shape as
                `input_ids`; all zeros are used when omitted.

        Returns:
            The embedded input with a trailing hidden dimension.
        """
        seq_length = input_ids.size(1)
        # Positions 0..seq_length-1, repeated for every row of the batch
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def copy(self, target):
        """Copy the LayerNorm parameters and all embedding tables from `target`.

        The tables are cloned so the two modules do not share storage, which
        matches the other `copy` methods in this file.
        """
        self.LayerNorm.copy(target.LayerNorm)
        self.word_embeddings.weight.data = target.word_embeddings.weight.data.clone()
        self.position_embeddings.weight.data = target.position_embeddings.weight.data.clone()
        self.token_type_embeddings.weight.data = target.token_type_embeddings.weight.data.clone()

class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with its own Q/K/V projections."""

    def __init__(self, config):
        """Build the query/key/value projections and the attention dropout.

        Raises:
            ValueError: If `hidden_size` is not a multiple of `num_attention_heads`.
        """
        super().__init__()
        heads = config.num_attention_heads
        if config.hidden_size % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def copy(self, target):
        """Overwrite the three projections with clones of `target`'s weights and biases."""
        for name in ("query", "key", "value"):
            mine, theirs = getattr(self, name), getattr(target, name)
            mine.weight.data = theirs.weight.data.clone()
            if mine.bias is not None:
                mine.bias.data = theirs.bias.data.clone()

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to position 1."""
        per_head_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*per_head_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Attend over `hidden_states`.

        Args:
            hidden_states: Input whose last dimension is the hidden size.
            attention_mask: Additive mask, added to the scaled scores before the softmax.

        Returns:
            Tuple of the context layer (heads merged back into the last
            dimension) and the masked attention scores taken before the softmax.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores are the query/key dot products, scaled by sqrt(head size)
        scale = math.sqrt(self.attention_head_size)
        attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) / scale
        # The mask is additive (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Turn the scores into probabilities. The dropout here drops entire
        # tokens to attend to, which might seem a bit unusual, but is taken
        # from the original Transformer paper.
        attention_probs = self.dropout(F.softmax(attention_scores, dim=-1))

        context = torch.matmul(attention_probs, values).permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return context.view(*merged_shape), attention_scores

class BertSelfAttentionSplit(nn.Module):
    """Multi-head attention that owns only the query projection.

    Keys and values are passed to `forward` already projected; this module
    projects the queries, computes scaled dot-product attention and merges
    the heads back together.
    """

    def __init__(self, config):
        """Build the query projection and the attention dropout.

        Raises:
            ValueError: If `hidden_size` is not a multiple of `num_attention_heads`.
        """
        super(BertSelfAttentionSplit, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Plain tensor attribute (neither a Parameter nor a registered buffer)
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def copy(self, target):
        """Take `target`'s `vth` and clone its query weight and bias into this module."""
        self.vth = target.vth
        self.query.weight.data = target.query.weight.data.clone()
        if self.query.bias is not None:
            self.query.bias.data = target.query.bias.data.clone()

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to position 1."""
        new_x_shape = x.size()[
            :-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, mixed_key_layer, mixed_value_layer, attention_mask, pos = None, output_att=False):
        """Attend from `hidden_states` to externally supplied keys and values.

        Args:
            hidden_states: Query-side input; projected by `self.query`.
            mixed_key_layer: Keys whose last dimension equals `all_head_size`.
            mixed_value_layer: Values whose last dimension equals `all_head_size`.
            attention_mask: Additive mask added to the scaled scores, or
                `None` to skip masking.
            pos: Optional tensor added to `hidden_states` before the query projection.
            output_att: Accepted for interface compatibility; not used here.

        Returns:
            Tuple of the context layer (heads merged back into the last
            dimension) and the (optionally masked) scores taken before the softmax.
        """
        mixed_query_layer = self.query(hidden_states if pos is None else hidden_states + pos)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(
            query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        # Identity check: comparing a tensor to None with `!=` goes through the
        # tensor's comparison operator instead of testing for absence.
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = F.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer, attention_scores

class BertSelfOutput(nn.Module):
    """Dropout, residual connection and LayerNorm applied after attention.

    `self.dense` is created and handled by `copy`, but `forward` does not
    apply it.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Plain tensor attribute (neither a Parameter nor a registered buffer)
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, input_tensor):
        """Return LayerNorm(dropout(hidden_states) + input_tensor)."""
        dropped = self.dropout(hidden_states)
        residual_sum = dropped + input_tensor
        return self.LayerNorm(residual_sum)

    def copy(self, target):
        """Take `target`'s `vth` and clone its dense and LayerNorm parameters."""
        self.vth = target.vth
        self.LayerNorm.copy(target.LayerNorm)
        self.dense.weight.data = target.dense.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = target.dense.bias.data.clone()

class BertFC(nn.Module):
    """A single hidden-to-hidden linear projection with an optional additive input term."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        # Plain tensor attribute (neither a Parameter nor a registered buffer)
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, pos: Optional[Tensor] = None):
        """Project `hidden_states`, adding `pos` to it first when `pos` is given."""
        # Add positional embedding to Query and Key
        projection_input = hidden_states
        if pos is not None:
            projection_input = hidden_states + pos
        return self.dense(projection_input)

    def copy(self, target):
        """Take `target`'s `vth` and clone its dense weight and bias."""
        source = target.dense
        self.vth = target.vth
        self.dense.weight.data = source.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = source.bias.data.clone()

class BertAttention(nn.Module):
    """Self-attention followed by its output block."""

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """Run attention on `input_tensor`, then the output block with `input_tensor` as residual.

        Returns:
            Tuple of the output block's result and the second value returned
            by the attention module.
        """
        context, layer_att = self.self(input_tensor, attention_mask)
        return self.output(context, input_tensor), layer_att

    def copy(self, target):
        """Delegate copying to both sub-modules."""
        pairs = ((self.self, target.self), (self.output, target.output))
        for mine, theirs in pairs:
            mine.copy(theirs)

class BertIntermediate(nn.Module):
    """Feed-forward expansion: a linear layer followed by an activation."""

    def __init__(self, config, intermediate_size=-1):
        """Build the expansion layer.

        Args:
            config: Provides `hidden_size`, `hidden_act` and, when
                `intermediate_size` is negative, `intermediate_size`.
            intermediate_size: Output width of the linear layer; a negative
                value means "use `config.intermediate_size`".
        """
        super().__init__()
        width = config.intermediate_size if intermediate_size < 0 else intermediate_size
        self.dense = nn.Linear(config.hidden_size, width)

        # A string activation is resolved through ACT2FN; anything else is used as the callable
        act = config.hidden_act
        named = isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode))
        self.intermediate_act_fn = ACT2FN[act] if named else act
        # Plain tensor attribute (neither a Parameter nor a registered buffer)
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states):
        """Return activation(dense(hidden_states))."""
        return self.intermediate_act_fn(self.dense(hidden_states))

    def copy(self, target):
        """Take `target`'s `vth` and clone its dense weight and bias."""
        source = target.dense
        self.vth = target.vth
        self.dense.weight.data = source.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = source.bias.data.clone()

class BertOutput(nn.Module):
    """Feed-forward contraction: linear layer, dropout, residual connection and LayerNorm."""

    def __init__(self, config, intermediate_size=-1):
        """Build the contraction layer.

        Args:
            config: Provides `hidden_size`, `hidden_dropout_prob` and, when
                `intermediate_size` is negative, `intermediate_size`.
            intermediate_size: Input width of the linear layer; a negative
                value means "use `config.intermediate_size`".
        """
        super().__init__()
        width = config.intermediate_size if intermediate_size < 0 else intermediate_size
        self.dense = nn.Linear(width, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Plain tensor attribute (neither a Parameter nor a registered buffer)
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, input_tensor):
        """Return LayerNorm(dropout(dense(hidden_states)) + input_tensor)."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

    def copy(self, target):
        """Take `target`'s `vth` and clone its dense and LayerNorm parameters."""
        self.vth = target.vth
        self.LayerNorm.copy(target.LayerNorm)
        self.dense.weight.data = target.dense.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = target.dense.bias.data.clone()

class VTGSaliency(nn.Module):
    """Saliency head in which the video slice attends to the text slice.

    The sequence dimension holds `vid_shape` video positions followed by the
    text positions. Only `attention1` is used by `forward`; `attention2` and
    `dense` are created and copied but not applied.
    """

    def __init__(self, config):
        super(VTGSaliency, self).__init__()
        self.attention1 = BertSelfAttentionSplit(config)
        self.attention2 = BertSelfAttentionSplit(config)

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Plain tensor attribute (neither a Parameter nor a registered buffer)
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, input, attention_mask, vid_shape = 75):
        """Compute attention of the video slice over the original text slice.

        Args:
            hidden_states: Tensor indexed as [batch, position, hidden]; the
                first `vid_shape` positions are used as queries.
            input: Tensor indexed the same way; the positions from
                `vid_shape` onward are used as keys and values.
            attention_mask: Accepted for interface compatibility; the
                attention call always receives `attention_mask=None`.
            vid_shape: Number of leading positions that belong to the video.

        Returns:
            The attention output for the video positions concatenated along
            dim 1 with a tensor of ones shaped like the text slice of `input`.
        """
        vid_mem_proj = hidden_states[:, :vid_shape, :]
        txt_ori_proj = input[:, vid_shape:, :]
        # word-level -> sentence-level
        hidden_states1, _ = self.attention1(vid_mem_proj, txt_ori_proj, txt_ori_proj, attention_mask=None)
        return torch.cat([hidden_states1, torch.ones_like(txt_ori_proj)], dim=1)

    def copy(self, target):
        """Take `target`'s `vth` and copy every learned tensor from `target`.

        The dense weight and bias are cloned and both attention sub-modules
        are copied through their own `copy` methods, so no learned parameter
        of `target` is skipped.
        """
        self.vth = target.vth
        self.attention1.copy(target.attention1)
        self.attention2.copy(target.attention2)
        self.dense.weight.data = target.dense.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = target.dense.bias.data.clone()

def mask_logits(inputs, mask, mask_value=-1e30):
    """Add `mask_value` to `inputs` wherever `mask` is 0 and leave entries where it is 1 unchanged.

    The mask is cast to float32, so intermediate mask values scale the
    added term by (1 - mask).
    """
    keep = mask.to(torch.float32)
    penalty = (1.0 - keep) * mask_value
    return inputs + penalty

class WeightedPool(nn.Module):
    """Pool a sequence into one vector using learned, masked softmax weights.

    `bias` is registered as a parameter but is not applied in `forward`.
    """

    def __init__(self, dim):
        super().__init__()
        # Xavier-initialised scoring vector of shape (dim, 1)
        scoring = nn.init.xavier_uniform_(torch.empty(dim, 1))
        self.weight = nn.Parameter(scoring, requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, mask):
        """Return the weighted sum of `x` over dim 1.

        Args:
            x: Tensor of shape (batch_size, seq_length, dim).
            mask: Tensor of shape (batch_size, seq_length); it is passed to
                `mask_logits` so that positions where it is 0 are pushed
                towards zero weight.

        Returns:
            Tensor of shape (batch_size, dim).
        """
        logits = torch.tensordot(x, self.weight, dims=1)  # (batch_size, seq_length, 1)
        logits = mask_logits(logits, mask=mask.unsqueeze(2))
        weights = torch.softmax(logits, dim=1)
        pooled = torch.matmul(x.transpose(1, 2), weights)  # (batch_size, dim, 1)
        return pooled.squeeze(2)



class VTGSaliencyPool(nn.Module):
    """Saliency head scoring video positions against a pooled text vector.

    The sequence dimension holds `vid_shape` video positions followed by the
    text positions. `dense` is created and copied but not applied in
    `forward`.
    """

    def __init__(self, config):
        # Initialise nn.Module before any attribute is assigned
        super(VTGSaliencyPool, self).__init__()
        self.hidden_dim = config.hidden_size
        self.weightedpool = WeightedPool(config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

        # Plain tensor attribute (neither a Parameter nor a registered buffer)
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, input, src_mask, vid_shape = 75):
        """Score every video position by cosine similarity with the pooled text.

        Args:
            hidden_states: Tensor indexed as [batch, position, hidden]; the
                first `vid_shape` positions are the video features.
            input: Tensor indexed the same way; the positions from
                `vid_shape` onward are pooled into one text vector.
            src_mask: Tensor indexed as [batch, position], split at `vid_shape`
                into a video mask and a text mask.
            vid_shape: Number of leading positions that belong to the video.

        Returns:
            The min-max normalized scores broadcast over the hidden dimension
            for the video positions, concatenated along dim 1 with a tensor
            of ones shaped like the text slice of `input`.
        """
        src_vid_mask = src_mask[:, :vid_shape]
        src_txt_mask = src_mask[:, vid_shape:]

        vid_mem_proj = hidden_states[:, :vid_shape, :]
        txt_original_proj = input[:, vid_shape:, :]

        # word-level -> sentence-level
        txt_mem_proj_val = self.weightedpool(txt_original_proj, src_txt_mask).unsqueeze(1)

        # The log of the video mask is added to the similarity before normalization.
        # NOTE(review): assuming a 0/1 mask, masked positions receive a large
        # negative score and therefore set the row minimum used by
        # normalize_scores — confirm this is the intended scaling.
        sim1 = normalize_scores(F.cosine_similarity(vid_mem_proj, txt_mem_proj_val, dim=-1) + (src_vid_mask + 1e-45).log())

        hidden_states1 = sim1.unsqueeze(2).expand(-1, -1, self.hidden_dim)
        return torch.cat([hidden_states1, torch.ones_like(txt_original_proj)], dim=1)

    def copy(self, target):
        """Take `target`'s `vth` and clone its pooling and dense parameters.

        Cloning the learned tensors keeps this method consistent with the
        other `copy` methods in this file, which transfer weights as well as
        `vth`.
        """
        self.vth = target.vth
        self.weightedpool.weight.data = target.weightedpool.weight.data.clone()
        self.weightedpool.bias.data = target.weightedpool.bias.data.clone()
        self.dense.weight.data = target.dense.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = target.dense.bias.data.clone()

class BertPooler(nn.Module):
    """Pool a stack of hidden states into one vector per sequence."""

    def __init__(self, config, recurs=None):
        """Build the pooling projection.

        The dense weight is drawn from a normal distribution with std
        `config.initializer_range` and the bias is zeroed. `recurs` is
        accepted but not used.
        """
        super().__init__()
        projection = nn.Linear(config.hidden_size, config.hidden_size)
        projection.weight.data.normal_(mean=0.0, std=config.initializer_range)
        projection.bias.data.zero_()
        self.dense = projection
        self.activation = nn.Tanh()
        self.config = config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states):
        """Return tanh(dropout(dense(h))) where h is position 0 of the last entry of `hidden_states`."""
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. "-1" refers to last layer
        first_token = hidden_states[-1][:, 0]
        projected = self.dense(first_token)
        return self.activation(self.dropout(projected))