# Identifiers will be added once the code is made public.


from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open
from platform import uname

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from typing import Optional
from torch.nn import CrossEntropyLoss
from torch.autograd import Variable
from torch.nn.parameter import Parameter

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Initial value used for the `vth` tensor attribute that several modules in
# this file create in their constructors.
v_th_Val = 1.

def normalize_scores(scores):
    """
    Rescale every row of saliency scores with min-max normalization.

    :param scores: Tensor of saliency scores (Batch, L_v)
    :return: Normalized scores (Batch, L_v), as a float tensor
    """
    # Work in floating point so the division below is not an integer division.
    scores = scores.float()

    # Per-row extrema along dim 1, kept as (Batch, 1) so they broadcast.
    row_min, _ = scores.min(dim=1, keepdim=True)
    row_max, _ = scores.max(dim=1, keepdim=True)

    # The small epsilon keeps the denominator non-zero for constant rows.
    value_range = row_max - row_min + 1e-8
    return (scores - row_min) / value_range

class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.

    Every setting is stored as a plain instance attribute, so the whole
    configuration can be round-tripped through a dictionary or a JSON file.
    """

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 pre_trained='',
                 training=''):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in
                `BertModel` (int), or the path to a JSON config file (str).
                When a path is given, every key of the JSON object becomes an
                attribute verbatim and the keyword arguments below are ignored.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, it must be a key of `ACT2FN`
                ("gelu" or "relu").
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            pre_trained: Stored as-is on the instance; not interpreted by this class.
            training: Stored as-is on the instance; not interpreted by this class.

        Raises:
            ValueError: If the first argument is neither an int nor a str.
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                                                               and isinstance(vocab_size_or_config_json_file, unicode)):  # noqa: F821
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.pre_trained = pre_trained
            self.training = training
        else:
            # The two literals are concatenated, so the separating space must
            # be part of the first one.
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        The instance is first built with default values (and a placeholder
        vocabulary size of -1); every key of `json_object` then overrides or
        adds an attribute. `cls` is used so subclasses get their own type back.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy)."""
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())

class BertLayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift."""

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize `x` along its last dimension, then apply weight and bias."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Biased variance (mean of squared deviations); epsilon is added
        # before taking the square root.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias

    def copy(self, target):
        """Overwrite this module's weight and bias with clones of `target`'s."""
        for name in ("weight", "bias"):
            getattr(self, name).data = getattr(target, name).data.clone()

def gelu(x):
    """Apply the GELU activation element-wise by delegating to ``F.gelu``.

    ``F.gelu`` is called with its default arguments, which corresponds to the
    erf-based form ``x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))``.
    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    activated = F.gelu(x)
    return activated

# Maps the activation names accepted as a string `config.hidden_act` to the
# callable that implements them.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu}
# Maps normalization-layer names to the module class that implements them.
NORM = {'layer_norm': BertLayerNorm}

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three embeddings are summed, layer-normalized and passed through
    dropout.
    """

    def __init__(self, config):
        """Create the embedding tables, layer norm and dropout from `config`.

        Reads `vocab_size`, `hidden_size`, `max_position_embeddings`,
        `type_vocab_size` and `hidden_dropout_prob` from `config`.
        """
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """Embed `input_ids` and return the normalized, dropped-out sum.

        Args:
            input_ids: Integer tensor of token ids; its dim 1 is taken as the
                sequence length.
            token_type_ids: Optional tensor shaped like `input_ids`; defaults
                to all zeros.

        Returns:
            The embedding tensor after LayerNorm and dropout.
        """
        seq_length = input_ids.size(1)
        # Positions 0..seq_length-1, broadcast to the shape of input_ids.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def copy(self, target):
        """Overwrite this module's parameters with copies of `target`'s.

        The embedding tables are cloned, like every other `copy` method in
        this file, so that later in-place updates to one module do not leak
        into the other through shared storage.
        """
        self.LayerNorm.copy(target.LayerNorm)
        self.word_embeddings.weight.data = target.word_embeddings.weight.data.clone()
        self.position_embeddings.weight.data = target.position_embeddings.weight.data.clone()
        self.token_type_embeddings.weight.data = target.token_type_embeddings.weight.data.clone()

class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with softmax weights.

    Queries, keys and values are all linear projections of the same
    `hidden_states` input.
    """

    def __init__(self, config):
        """Build the query/key/value projections and the attention dropout.

        Raises:
            ValueError: If `config.hidden_size` is not a multiple of
                `config.num_attention_heads`.
        """
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def copy(self, target):
        """Overwrite the q/k/v projection parameters with clones of `target`'s."""
        self.query.weight.data = target.query.weight.data.clone()
        if self.query.bias is not None:
            self.query.bias.data = target.query.bias.data.clone()

        self.key.weight.data = target.key.weight.data.clone()
        if self.key.bias is not None:
            self.key.bias.data = target.key.bias.data.clone()

        self.value.weight.data = target.value.weight.data.clone()
        if self.value.bias is not None:
            self.value.bias.data = target.value.bias.data.clone()

    def transpose_for_scores(self, x):
        """Split the last dim into heads and move the head axis to position 1.

        (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        """
        new_x_shape = x.size()[
            :-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, r1, r2, attention_mask, output_att=False):
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input tensor, (batch, seq, hidden_size).
            r1: Not read by this method; kept for interface compatibility.
            r2: Not read by this method; kept for interface compatibility.
            attention_mask: Additive mask that is summed onto the raw scores
                (must broadcast against (batch, heads, seq, seq)), or None
                for no masking.
            output_att: Not read by this method; the scores are always returned.

        Returns:
            Tuple `(context_layer, attention_scores)`: the attended values
            reshaped to (batch, seq, all_head_size), and the scaled, masked
            pre-softmax scores.
        """
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(
            query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back: (batch, heads, seq, head) -> (batch, seq, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer, attention_scores

class BertSelfAttentionSplit(nn.Module):
    """Multi-head attention that owns only the query projection.

    Keys and values are handed to `forward` already projected, and the
    attention weights are a length-scaled ReLU of the scores rather than a
    softmax.
    """

    def __init__(self, config):
        """Build the query projection, the dropout and the `vth` tensor.

        Raises:
            ValueError: If `config.hidden_size` is not a multiple of
                `config.num_attention_heads`.
        """
        super(BertSelfAttentionSplit, self).__init__()
        heads = config.num_attention_heads
        if config.hidden_size % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Plain tensor attribute (neither a Parameter nor a registered buffer).
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def copy(self, target):
        """Take over `target`'s `vth` and clone its query projection parameters."""
        self.vth = target.vth
        self.query.weight.data = target.query.weight.data.clone()
        if self.query.bias is not None:
            self.query.bias.data = target.query.bias.data.clone()

    def transpose_for_scores(self, x):
        """(batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
        per_head_shape = x.size()[:-1] + (
            self.num_attention_heads, self.attention_head_size)
        return x.view(*per_head_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, mixed_key_layer, mixed_value_layer, attention_mask, pos = None, output_att=False):
        """Attend from `hidden_states` to externally supplied keys and values.

        Args:
            hidden_states: Query-side input, (batch, seq_q, hidden_size).
            mixed_key_layer: Keys, (batch, seq_k, all_head_size); used as-is.
            mixed_value_layer: Values, (batch, seq_k, all_head_size); used as-is.
            attention_mask: Additive mask summed onto the scores, or None.
            pos: Optional tensor added to `hidden_states` before the query
                projection.
            output_att: Not read by this method; the scores are always returned.

        Returns:
            Tuple `(context_layer, attention_scores)`.
        """
        query_input = hidden_states if pos is None else hidden_states + pos
        query_layer = self.transpose_for_scores(self.query(query_input))
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scaled dot product between queries and keys gives the raw scores.
        scale = math.sqrt(self.attention_head_size)
        attention_scores = torch.matmul(
            query_layer, key_layer.transpose(-1, -2)) / scale
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Attention weights: ReLU of the scores divided by the query-side
        # sequence length, followed by dropout.
        num_queries = hidden_states.shape[1]
        attention_probs = self.dropout(
            (1. / num_queries) * F.relu(attention_scores))

        # Weighted sum of the values, then merge the heads back together.
        context = torch.matmul(attention_probs, value_layer)
        context = context.permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return context.view(*merged_shape), attention_scores

class BertSelfOutput(nn.Module):
    """Residual connection applied after the attention block.

    `dense` and `LayerNorm` are created (and copied by `copy`) but `forward`
    currently applies neither: it only runs dropout and adds the residual.
    """

    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Plain tensor attribute (neither a Parameter nor a registered buffer).
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, input_tensor):
        """Return `dropout(hidden_states) + input_tensor`."""
        dropped = self.dropout(hidden_states)
        return dropped + input_tensor

    def copy(self, target):
        """Take over `target`'s `vth` and clone its dense / LayerNorm parameters."""
        self.vth = target.vth
        self.LayerNorm.copy(target.LayerNorm)
        self.dense.weight.data = target.dense.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = target.dense.bias.data.clone()

class BertFC(nn.Module):
    """A single hidden_size -> hidden_size linear projection."""

    def __init__(self, config):
        super(BertFC, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Plain tensor attribute (neither a Parameter nor a registered buffer).
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, pos: Optional[Tensor] = None):
        """Project `hidden_states`, adding `pos` to it first when it is given."""
        if pos is not None:
            # Positional embedding is added before the projection.
            hidden_states = hidden_states + pos
        return self.dense(hidden_states)

    def copy(self, target):
        """Take over `target`'s `vth` and clone its dense parameters."""
        self.vth = target.vth
        source = target.dense
        self.dense.weight.data = source.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = source.bias.data.clone()

class BertAttention(nn.Module):
    """Self-attention followed by its output (residual) block."""

    def __init__(self, config):
        super(BertAttention, self).__init__()

        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """Run self-attention on `input_tensor` and add the residual.

        Args:
            input_tensor: Hidden states fed to the attention layer and reused
                as the residual input of the output block.
            attention_mask: Additive attention mask forwarded to the
                attention layer.

        Returns:
            Tuple `(attention_output, layer_att)` where `layer_att` is the
            second value returned by the attention layer.
        """
        # BertSelfAttention.forward takes (hidden_states, r1, r2, attention_mask);
        # r1 and r2 are not read there, so None placeholders are passed to keep
        # the mask in the right positional slot.
        self_output, layer_att = self.self(input_tensor, None, None, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output, layer_att

    def copy(self, target):
        """Copy the parameters of both sub-blocks from `target`."""
        self.self.copy(target.self)
        self.output.copy(target.output)

class BertIntermediate(nn.Module):
    """Feed-forward expansion: a linear layer followed by an activation."""

    def __init__(self, config, intermediate_size=-1):
        """Build the projection and resolve the activation.

        Args:
            config: Provides `hidden_size`, `hidden_act` and, when
                `intermediate_size` is negative, `intermediate_size`.
            intermediate_size: Output width of the projection; a negative
                value means "use `config.intermediate_size`".
        """
        super(BertIntermediate, self).__init__()
        out_features = config.intermediate_size if intermediate_size < 0 else intermediate_size
        self.dense = nn.Linear(config.hidden_size, out_features)

        # A string activation is looked up in ACT2FN; anything else is used
        # directly as the callable.
        activation = config.hidden_act
        given_by_name = isinstance(activation, str) or (
            sys.version_info[0] == 2 and isinstance(activation, unicode))  # noqa: F821
        self.intermediate_act_fn = ACT2FN[activation] if given_by_name else activation
        # Plain tensor attribute (neither a Parameter nor a registered buffer).
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states):
        """Return `intermediate_act_fn(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))

    def copy(self, target):
        """Take over `target`'s `vth` and clone its dense parameters."""
        self.vth = target.vth
        source = target.dense
        self.dense.weight.data = source.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = source.bias.data.clone()

class VTGSaliency(nn.Module):
    """Attend from the leading (video) positions to the trailing (text) positions.

    `dense` is created and copied but is not applied in `forward`.
    """

    def __init__(self, config):
        super(VTGSaliency, self).__init__()
        self.attention = BertSelfAttentionSplit(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Plain tensor attribute (neither a Parameter nor a registered buffer).
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, input, attention_mask, vid_shape = 75):
        """Replace the first `vid_shape` positions by their attention over the rest.

        Args:
            hidden_states: Tensor whose first `vid_shape` positions along
                dim 1 are used as attention queries.
            input: Tensor whose positions from `vid_shape` onward along dim 1
                are used as both keys and values.
            attention_mask: Not read by this method; attention runs unmasked.
            vid_shape: Index along dim 1 at which the sequence is split.

        Returns:
            The attended query positions concatenated, along dim 1, with the
            untouched key/value slice of `input`.
        """
        vid_mem_proj = hidden_states[:, :vid_shape,:]
        txt_mem_proj = input[:, vid_shape:,:]
        # word-level -> sentence-level
        hidden_states, _ = self.attention(vid_mem_proj, txt_mem_proj, txt_mem_proj, attention_mask=None)

        return torch.cat([hidden_states, txt_mem_proj], dim=1)

    def copy(self, target):
        """Take over `target`'s `vth` and clone its learnable parameters.

        Copies the dense weight and bias (matching the other `copy` methods
        in this file) and, when `target` has an `attention` submodule, the
        attention parameters that `forward` actually relies on.
        """
        self.vth = target.vth
        self.dense.weight.data = target.dense.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = target.dense.bias.data.clone()
        # Guarded so that a target without an attention submodule still works.
        source_attention = getattr(target, "attention", None)
        if source_attention is not None:
            self.attention.copy(source_attention)

class BertOutput(nn.Module):
    """Feed-forward contraction back to hidden_size, plus a residual connection.

    `LayerNorm` is created (and copied by `copy`) but is not applied in
    `forward`.
    """

    def __init__(self, config, intermediate_size=-1):
        """Build the projection, layer norm and dropout.

        Args:
            config: Provides `hidden_size`, `hidden_dropout_prob` and, when
                `intermediate_size` is negative, `intermediate_size`.
            intermediate_size: Input width of the projection; a negative
                value means "use `config.intermediate_size`".
        """
        super(BertOutput, self).__init__()
        in_features = config.intermediate_size if intermediate_size < 0 else intermediate_size
        self.dense = nn.Linear(in_features, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Plain tensor attribute (neither a Parameter nor a registered buffer).
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, input_tensor):
        """Return `dropout(dense(hidden_states)) + input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor

    def copy(self, target):
        """Take over `target`'s `vth` and clone its dense / LayerNorm parameters."""
        self.vth = target.vth
        self.LayerNorm.copy(target.LayerNorm)
        self.dense.weight.data = target.dense.weight.data.clone()
        if self.dense.bias is not None:
            self.dense.bias.data = target.dense.bias.data.clone()

def mask_logits(inputs, mask, mask_value=-1e30):
    """Add `mask_value` to `inputs` wherever `mask` is 0.

    `mask` is cast to float32; positions where it equals 1 are left
    unchanged, and positions where it equals 0 are shifted by `mask_value`.
    """
    keep = mask.to(torch.float32)
    penalty = (1.0 - keep) * mask_value
    return inputs + penalty


class WeightedPool(nn.Module):
    """Pool a sequence into one vector using learned, ReLU-gated weights.

    `bias` is registered as a parameter but is not used in `forward`.
    """

    def __init__(self, dim):
        super(WeightedPool, self).__init__()
        # (dim, 1) scoring vector, Xavier-uniform initialized.
        scoring = nn.init.xavier_uniform_(torch.empty(dim, 1))
        self.weight = nn.Parameter(scoring, requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, mask):
        """Return the weighted sum of `x` over its sequence dimension.

        Args:
            x: Tensor of shape (batch_size, seq_length, dim).
            mask: Tensor of shape (batch_size, seq_length); positions where
                it is 0 get their score pushed far below zero and therefore
                contribute nothing after the ReLU.

        Returns:
            Tensor of shape (batch_size, dim).
        """
        scores = torch.tensordot(x, self.weight, dims=1)  # (batch_size, seq_length, 1)
        num_positions = scores.shape[1]
        scores = mask_logits(scores, mask=mask.unsqueeze(2))
        # ReLU scores scaled by the sequence length act as pooling weights.
        weights = F.relu(scores) / num_positions
        return torch.matmul(x.transpose(1, 2), weights).squeeze(2)

class VTGSaliencyPool(nn.Module):
    """Score the leading (video) positions against a pooled text vector.

    `dense` is created but is not applied in `forward`.
    """

    def __init__(self, config):
        # nn.Module requires the parent constructor to run before any
        # attribute is assigned on the instance.
        super(VTGSaliencyPool, self).__init__()
        self.hidden_dim = config.hidden_size
        self.weightedpool = WeightedPool(config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

        # Plain tensor attribute (neither a Parameter nor a registered buffer).
        self.vth = torch.tensor(v_th_Val, requires_grad=False)

    def forward(self, hidden_states, input, src_mask, vid_shape = 75):
        """Build a per-position saliency gate for the first `vid_shape` positions.

        Args:
            hidden_states: Tensor whose first `vid_shape` positions along
                dim 1 are compared against the pooled text vector.
            input: Tensor whose positions from `vid_shape` onward along dim 1
                are pooled into a single vector per batch element.
            src_mask: Mask over dim 1 of the full sequence; it is split at
                `vid_shape` into a video part and a text part.
            vid_shape: Index along dim 1 at which the sequence is split.

        Returns:
            A tensor concatenated along dim 1: the min-max normalized
            similarity of each leading position, broadcast over
            `hidden_dim`, followed by ones shaped like the trailing slice of
            `input`.
        """
        src_vid_mask = src_mask[:, :vid_shape]
        src_txt_mask = src_mask[:, vid_shape:]

        vid_mem_proj = hidden_states[:, :vid_shape,:]
        txt_original_proj = input[:, vid_shape:,:]

        # word-level -> sentence-level
        txt_mem_proj_val = self.weightedpool(txt_original_proj, src_txt_mask).unsqueeze(1)

        # Cosine similarity to the pooled text vector; the log term is ~0
        # where the mask is 1 and strongly negative where it is 0. The
        # result is min-max normalized per row.
        sim1 = normalize_scores(F.cosine_similarity(vid_mem_proj, txt_mem_proj_val, dim=-1) + (src_vid_mask + 1e-45).log())

        # Broadcast each scalar score across the hidden dimension.
        hidden_states1 = sim1.unsqueeze(2).expand(-1, -1, self.hidden_dim)

        return torch.cat([hidden_states1, torch.ones_like(txt_original_proj)], dim=1)

    def copy(self, target):
        """Take over `target`'s `vth`.

        NOTE(review): the `weightedpool` and `dense` parameters are not
        copied here — confirm whether that is intentional.
        """
        self.vth = target.vth

class BertPooler(nn.Module):
    """Pool a sequence by projecting the first token of the last layer."""

    def __init__(self, config, recurs=None):
        """Build the projection, activation and dropout.

        Args:
            config: Provides `hidden_size`, `initializer_range` and
                `hidden_dropout_prob`; also stored on the instance.
            recurs: Not read by this constructor.
        """
        super(BertPooler, self).__init__()
        projection = nn.Linear(config.hidden_size, config.hidden_size)
        # Normal-initialized weights, zero bias.
        projection.weight.data.normal_(mean=0.0, std=config.initializer_range)
        projection.bias.data.zero_()
        self.dense = projection
        self.activation = nn.Tanh()
        self.config = config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states):
        """Return `tanh(dropout(dense(first_token)))`.

        `hidden_states` is indexed with `[-1]` to select the last layer, and
        position 0 along dim 1 of that layer is the token that gets pooled.
        """
        first_token = hidden_states[-1][:, 0]
        projected = self.dropout(self.dense(first_token))
        return self.activation(projected)