import math
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence

from transformers import BertPreTrainedModel
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    BaseModelOutputWithPoolingAndCrossAttentions,
)

class LSTMNonlinearClassificationHead(nn.Module):
    """Two-layer classification head applied to a pooled sentence vector.

    The head computes ``out_proj(dropout(tanh(dense(dropout(features)))))``.
    The dropout probability is ``config.classifier_dropout`` when that is set,
    otherwise ``config.hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # `dense` is created before `out_proj` so parameter initialisation
        # consumes the random stream in the same order as before.
        self.dense = nn.Linear(width, width)

        drop_prob = config.classifier_dropout
        if drop_prob is None:
            drop_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(drop_prob)

        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Map pooled ``features`` to ``config.num_labels`` logits.

        Extra keyword arguments are accepted and ignored.
        """
        projected = self.dense(self.dropout(features))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)

class LSTMEmbeddings(nn.Module):
    """Word embeddings followed by LayerNorm and dropout.

    Only a word-embedding table is used. ``token_type_ids``, ``position_ids``
    and ``past_key_values_length`` are accepted by ``forward`` for signature
    compatibility but are ignored.
    """

    def __init__(self, config):
        super().__init__()
        # The embedding table is frozen when config.update_embeddings is falsy.
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        ).requires_grad_(config.update_embeddings)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Return normalised, dropped-out embeddings.

        Args:
            input_ids: Token ids looked up in the word-embedding table. Only
                used when ``inputs_embeds`` is not given.
            token_type_ids: Ignored.
            position_ids: Ignored.
            inputs_embeds: Pre-computed embeddings; when given they are used
                instead of looking up ``input_ids``.
            past_key_values_length: Ignored.

        Returns:
            The embeddings after LayerNorm and dropout.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = self.LayerNorm(inputs_embeds)
        return self.dropout(embeddings)
    
class Attention(nn.Module):
    """
    Attention pooling over a sequence of RNN states.

    Scores each time step with a single linear unit applied to ``tanh(H)``,
    normalises the scores with a softmax over the sequence dimension and
    returns the weighted sum of the states.

    Parameters
    ----------
    rnn_size : int
        Feature size of the states passed to ``forward``
    """
    def __init__(self, rnn_size: int) -> None:
        super(Attention, self).__init__()
        self.w = nn.Linear(rnn_size, 1)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

    def forward(
        self, H: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters
        ----------
        H : torch.Tensor (batch_size, word_pad_len, rnn_size)
            Sequence of states to pool
        mask : torch.Tensor (batch_size, word_pad_len), optional
            Non-zero / True marks positions that may receive attention.
            When omitted (the default) every position takes part in the
            softmax, which is the original behaviour.

        Returns
        -------
        r : torch.Tensor (batch_size, rnn_size)
            Sentence representation
        alpha : torch.Tensor (batch_size, word_pad_len)
            Attention weights
        """
        # eq.9: M = tanh(H)
        M = self.tanh(H)  # (batch_size, word_pad_len, rnn_size)

        # eq.10: α = softmax(w^T M)
        scores = self.w(M).squeeze(2)  # (batch_size, word_pad_len)
        if mask is not None:
            keep = mask.to(device=scores.device, dtype=torch.bool)
            # Use the most negative finite value rather than -inf so that a
            # fully masked row yields uniform weights instead of NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        alpha = self.softmax(scores)  # (batch_size, word_pad_len)

        # eq.11: r = H α^T (attention-weighted sum of the states)
        r = H * alpha.unsqueeze(2)  # (batch_size, word_pad_len, rnn_size)
        r = r.sum(dim=1)  # (batch_size, rnn_size)

        return r, alpha
    
class LSTMModel(nn.Module):
    """LSTM encoder with attention pooling and an optional interchange intervention.

    Pipeline of ``forward``: embeddings -> packed (Bi)LSTM -> (for the
    bidirectional case) sum of the two direction halves -> attention pooling
    -> optional replacement of slices of the pooled vector with values from
    ``source_hidden_states`` -> tanh.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.embeddings = LSTMEmbeddings(config)

        # Encoder RNN
        self.rnn_size = self.config.hidden_size
        self.BiLSTM = nn.LSTM(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_hidden_layers,
            bidirectional=self.config.bidirectional,
            # Dropout is forced to 0 for a single-layer LSTM.
            dropout=(0 if self.config.num_hidden_layers == 1 else self.config.hidden_dropout_prob),
            batch_first=True,
        )
        self.tanh = nn.Tanh()

        self.attention = Attention(self.config.hidden_size)

        # Width of one intervention slot inside the pooled vector.
        self.intervention_h_dim = config.intervention_h_dim

    def apply_attention(self, rnn_output, final_hidden_state):
        '''
        Apply dot-product attention on RNN output, using the final hidden state as the query.

        Input:
            rnn_output (batch_size, seq_len, num_directions * hidden_size): tensor representing hidden state for every word in the sentence
            final_hidden_state (batch_size, num_directions * hidden_size): final hidden state of the RNN

        Returns:
            attention_output(batch_size, num_directions * hidden_size): attention output vector for the batch
        '''
        hidden_state = final_hidden_state.unsqueeze(2)
        attention_scores = torch.bmm(rnn_output, hidden_state).squeeze(2)
        # torch.softmax is used because `F` (torch.nn.functional) is never
        # imported in this module; the previous `F.softmax` raised NameError.
        soft_attention_weights = torch.softmax(attention_scores, dim=1).unsqueeze(2)  # shape = (batch_size, seq_len, 1)
        attention_output = torch.bmm(rnn_output.permute(0, 2, 1), soft_attention_weights).squeeze(2)
        return attention_output

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        # counterfactual arguments
        source_hidden_states=None,
        base_intervention_corr=None,
        source_intervention_corr=None,
        all_layers=None,
        cls_hidden_reprs=None,
        # counterfactual arguments
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> "Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]":
        """Encode a batch and return the pooled representation.

        Only ``input_ids`` / ``inputs_embeds``, ``attention_mask``,
        ``source_hidden_states``, ``base_intervention_corr`` and
        ``return_dict`` are read; the remaining arguments are accepted for
        signature compatibility and ignored.

        Args:
            input_ids: Token ids. Takes precedence over ``inputs_embeds``
                when both are given.
            attention_mask: Per-example mask whose row sums are used as the
                sequence lengths for packing. When ``None`` every sequence is
                treated as full length.
            inputs_embeds: Pre-computed embeddings, used only when
                ``input_ids`` is ``None``.
            source_hidden_states: Values written into the pooled vector for
                examples selected by ``base_intervention_corr``.
            base_intervention_corr: Per-example slot index; ``-1`` means no
                intervention for that example.
            return_dict: When truthy a ``BaseModelOutputWithPoolingAndCrossAttentions``
                is returned, otherwise a tuple.

        Returns:
            Either the model-output object with ``last_hidden_state`` and
            ``hidden_states`` set to the pooled vector before tanh and
            ``pooler_output`` set to the vector after tanh, or the tuple
            ``(pre_tanh, post_tanh, pre_tanh)`` carrying the same values in
            the same order.
        """
        # The annotation above is a string so it is not evaluated when the
        # class is defined.

        # input_ids keeps priority (as before); inputs_embeds is forwarded
        # only when no ids are supplied, instead of being silently dropped.
        embedding_output = self.embeddings(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds if input_ids is None else None,
        )

        if attention_mask is None:
            # No mask: treat every sequence as unpadded.
            lengths = [embedding_output.size(1)] * embedding_output.size(0)
        else:
            lengths = attention_mask.sum(dim=-1).tolist()

        # pack sequences (remove word-pads, SENTENCES -> WORDS)
        packed_words = pack_padded_sequence(
            embedding_output,
            lengths=lengths,
            batch_first=True,
            enforce_sorted=False,
        )  # a PackedSequence object, where 'data' is the flattened words (n_words, emb_size)

        self.BiLSTM.flatten_parameters()
        # run through the LSTM (PyTorch automatically applies it on the PackedSequence)
        rnn_out, _ = self.BiLSTM(packed_words)

        # unpack sequences (re-pad with 0s, WORDS -> SENTENCES)
        rnn_out, _ = pad_packed_sequence(rnn_out, batch_first=True)

        if self.BiLSTM.bidirectional:
            # Sum the forward and backward halves of the LSTM output.
            # H = {h_1, h_2, ..., h_T}
            H = rnn_out[:, :, :self.rnn_size] + rnn_out[:, :, self.rnn_size:]  # (batch_size, word_pad_len, rnn_size)
        else:
            # A unidirectional LSTM has no backward half; the slice-and-add
            # above would try to add an empty tensor and fail.
            H = rnn_out

        # attention module
        hidden_states, alphas = self.attention(H)  # (batch_size, rnn_size), (batch_size, word_pad_len)

        # INT_POINT: only the last layer!
        if base_intervention_corr is not None and source_hidden_states is not None:
            slot_width = self.intervention_h_dim
            for b in range(0, hidden_states.shape[0]):
                if base_intervention_corr[b] != -1:
                    start_idx = base_intervention_corr[b] * slot_width
                    end_idx = (base_intervention_corr[b] + 1) * slot_width
                    # we support where the pass in source_hidden_states
                    # is a partial one only for the interchanging aspect.
                    if not isinstance(source_hidden_states, tuple) and hidden_states.shape[-1] != source_hidden_states.shape[-1]:
                        hidden_states[b][start_idx:end_idx] = source_hidden_states[b]
                    else:
                        hidden_states[b][start_idx:end_idx] = \
                            source_hidden_states[b][start_idx:end_idx]

        # eq.12: h* = tanh(pooled_output)
        final_out = self.tanh(hidden_states)  # (batch_size, rnn_size)

        if not return_dict:
            # Mirror the field order of the dict-style return below. The old
            # expression added a tuple to a tensor and raised TypeError.
            return (hidden_states, final_out, hidden_states)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=final_out,
            hidden_states=hidden_states,
        )

class MultiTaskClassificationHead(nn.Module):
    """Two-layer classification head operating on one intervention slot.

    Input and hidden width are both ``config.intervention_h_dim``; the output
    width is ``num_labels``. The dropout probability is
    ``config.classifier_dropout`` when that is set, otherwise
    ``config.hidden_dropout_prob``.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        slot_width = config.intervention_h_dim
        # `dense` is built before `out_proj` so parameter initialisation
        # consumes the random stream in the same order as before.
        self.dense = nn.Linear(slot_width, slot_width)

        drop_prob = config.classifier_dropout
        if drop_prob is None:
            drop_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(drop_prob)

        self.out_proj = nn.Linear(slot_width, num_labels)

    def forward(self, x):
        """Return ``out_proj(dropout(tanh(dense(dropout(x)))))``."""
        projected = self.dense(self.dropout(x))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)

    
class IITLSTMForSequenceClassification(BertPreTrainedModel): # fake overhead loading.
    """Sequence classifier on top of ``LSTMModel`` with auxiliary per-slot heads.

    Besides the main classification logits, a shared
    ``MultiTaskClassificationHead`` is applied to four consecutive slices of
    width ``config.intervention_h_dim`` taken from the encoder's pre-tanh
    pooled vector.
    """

    def __init__(
        self, config,
        num_aspect_labels=3,
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.lstm = LSTMModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = LSTMNonlinearClassificationHead(config)
        self.intervention_h_dim = config.intervention_h_dim
        self.interchange_hidden_layer = config.interchange_hidden_layer
        self.multitask_classifier = MultiTaskClassificationHead(
            config, num_aspect_labels
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        # counterfactual arguments
        source_hidden_states=None,
        base_intervention_corr=None,
        source_intervention_corr=None,
        all_layers=None,
        cls_hidden_reprs=None,
        # counterfactual arguments
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        cls_hidden_reprs (*optional*):
            When given, the encoder is skipped and this value is used directly as the pooled output. No pre-tanh
            encoder state exists in that case, so the four auxiliary logits are returned as `None`.

        Returns a `SequenceClassifierOutput` whose `logits` field is the 5-tuple
        `(logits, mul_logits_0, mul_logits_1, mul_logits_2, mul_logits_3)` when `return_dict` is truthy, otherwise a
        tuple `(loss?, logits, *extra_encoder_outputs)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Pre-bind both names so the cls_hidden_reprs path no longer hits a
        # NameError further down.
        outputs = None
        before_pooled_output = None

        if cls_hidden_reprs is not None:
            # we also need to all the pooler once if configured.
            pooled_output = cls_hidden_reprs
        else:
            outputs = self.lstm(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                # counterfactual arguments
                source_hidden_states=source_hidden_states,
                base_intervention_corr=base_intervention_corr,
                source_intervention_corr=source_intervention_corr,
                all_layers=all_layers,
                # counterfactual arguments
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            before_pooled_output = outputs[0]
            pooled_output = outputs[1]

        #####
        # We can simply do IIT here as well as the multitask objective!
        num_slots = 4  # number of consecutive intervention slots scored by the shared head
        if before_pooled_output is not None:
            slot_width = self.intervention_h_dim
            # Slots are evaluated in order 0..3, same as the unrolled calls.
            aspect_logits = tuple(
                self.multitask_classifier(
                    before_pooled_output[:, slot * slot_width:(slot + 1) * slot_width]
                )
                for slot in range(num_slots)
            )
        else:
            aspect_logits = (None,) * num_slots
        #####

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            # No encoder outputs exist when cls_hidden_reprs bypassed the LSTM.
            extra_outputs = tuple(outputs[2:]) if outputs is not None else ()
            output = (logits,) + extra_outputs
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=(logits,) + aspect_logits,
            hidden_states=before_pooled_output,
            attentions=None,
        )