"""
This module provides utility functions for natural language processing.
"""

from typing import List, Optional, Tuple

import torch
from torch.nn import functional as F

def get_modified_input_tokens(input_tokens: List[str], removed_indices: List[int]) -> List[str]:
    """
    Return a copy of ``input_tokens`` without the elements at ``removed_indices``.

    Args:
        input_tokens (List[str]): The list of input tokens.
        removed_indices (List[int]): Positions to drop. Each entry is converted
            with ``int()`` (so 0-d tensors / numpy ints work). Indices outside
            ``range(len(input_tokens))`` (including negative ones) are ignored,
            and duplicates have no extra effect.

    Returns:
        List[str]: The remaining tokens, in their original order.
    """
    # A set gives O(1) membership tests instead of scanning the index list
    # once per token (O(n * m) overall).
    removed = {int(index) for index in removed_indices}
    return [token for position, token in enumerate(input_tokens) if position not in removed]


def get_special_tokens(input_tokens: List[str], tokenizer) -> Tuple[List[str], List[int]]:
    """
    Get special tokens and their positions from input tokens using a tokenizer.

    Parameters:
    - input_tokens (List[str]): A list of input tokens.
    - tokenizer: The tokenizer object; only its ``special_tokens_map`` is read.
      Values of that mapping may be a single token or a list/tuple of tokens
      (e.g. ``additional_special_tokens``); both forms are recognized.

    Returns:
    - Tuple[List[str], List[int]]: A tuple containing:
    the list of special tokens (in input order, duplicates kept) and their positions.
    """
    special_token_set = set()
    for value in tokenizer.special_tokens_map.values():
        # Flatten list-valued entries; a plain ``token in values()`` check would
        # compare a string against the whole list and never match.
        entries = value if isinstance(value, (list, tuple)) else [value]
        special_token_set.update(str(entry) for entry in entries if entry is not None)

    special_tokens: List[str] = []
    special_token_positions: List[int] = []
    for position, token in enumerate(input_tokens):
        if token in special_token_set:
            special_tokens.append(token)
            special_token_positions.append(position)
    return special_tokens, special_token_positions


class SequenceClassifier:
    """
    A class that encapsulates the functionality of a sequence classifier model.
    It provides methods for tokenizing input, running the model, and calculating
    gradients of the class scores with respect to hidden states and attentions.
    """

    def __init__(self, model, tokenizer):
        """
        Initializes a new instance of the SequenceClassifier class.

        Args:
            model (PreTrainedModel): The model to use. Its outputs are read for
                'logits' and, when present, 'hidden_states' and 'attentions';
                without the latter two the gradient lists come back empty.
            tokenizer (PreTrainedTokenizer): The tokenizer to use.
        """
        self.model = model
        self.tokenizer = tokenizer

    def get_input_tokens(self, input_str: str) -> Tuple[dict, List[str]]:
        """
        Tokenizes the input string using the given tokenizer for sequence classification.

        Args:
            input_str (str): The input string to tokenize.

        Returns:
            Tuple[dict, List[str]]: A tuple containing the tokenized inputs
            (PyTorch tensors) and the token strings of the first sequence.
        """
        tokenized_inputs = self.tokenizer(input_str, return_tensors='pt')
        input_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
        return tokenized_inputs, input_tokens

    def get_model_outputs(self, input_str: str) -> dict:
        """
        Tokenizes ``input_str`` and runs the model on it.

        Args:
            input_str (str): The input string to run the model on.

        Returns:
            dict: A dict with keys 'model_outputs', 'hidden_states' and
            'attentions' (empty lists when the model did not return them),
            'tokenized_inputs' and 'input_tokens'.
        """
        tokenized_inputs, input_tokens = self.get_input_tokens(input_str)
        model_outputs = self.model(**tokenized_inputs)

        hidden_states, attentions = model_outputs.get('hidden_states', []), model_outputs.get('attentions', [])

        model_info = {
            'model_outputs': model_outputs,
            'hidden_states': hidden_states,
            'attentions': attentions,
            'tokenized_inputs': tokenized_inputs,
            'input_tokens': input_tokens
        }
        return model_info

    @staticmethod
    def _collect_and_reset_grads(tensors) -> List[torch.Tensor]:
        """
        Snapshot the ``.grad`` of each tensor, then zero it for the next backward pass.

        A tensor whose ``.grad`` is still ``None`` (it took no part in the
        backward pass) contributes an all-zero gradient of matching shape.
        """
        snapshots = []
        for tensor in tensors:
            if tensor.grad is None:
                snapshots.append(torch.zeros_like(tensor))
            else:
                snapshots.append(tensor.grad.clone())
                # Zero in place so the next backward pass does not accumulate.
                tensor.grad.zero_()
        return snapshots

    def get_info(self, input_str: str, logits_indices: Optional[List[int]] = None, objective: str = "probs") -> dict:
        """
        Run the classifier and compute gradients of selected class scores with
        respect to every hidden state and attention tensor.

        Args:
            input_str (str): The input string to classify.
            logits_indices (List[int], optional): Indices into the flattened
                logits/probabilities to differentiate, one backward pass each.
                Defaults to ``[predicted_class_id]``.
            objective (str): "probs" differentiates the softmax probabilities,
                "logits" the raw logits. Defaults to "probs".

        Returns:
            dict: Keys 'model_outputs', 'hidden_states', 'hidden_states_grads',
            'attentions', 'attention_grads', 'class_probabilities',
            'class_logits', 'predicted_class_id', 'predicted_class',
            'tokenized_inputs' and 'input_tokens'. ``hidden_states_grads[k]``
            (and likewise ``attention_grads[k]``) holds one gradient per layer
            for ``logits_indices[k]``.

        Raises:
            ValueError: If ``objective`` is neither "probs" nor "logits".
        """
        # Validate before doing any (expensive) forward/backward work.
        if objective not in ("probs", "logits"):
            raise ValueError("objective must be either 'probs' or 'logits'")

        model_info_dict = self.get_model_outputs(input_str)

        tokenized_inputs = model_info_dict['tokenized_inputs']
        input_tokens = model_info_dict['input_tokens']
        model_outputs = model_info_dict['model_outputs']
        hidden_states = model_info_dict['hidden_states']
        attentions = model_info_dict['attentions']

        # Non-leaf tensors only keep .grad after backward if asked to.
        for tensor in (*hidden_states, *attentions):
            tensor.retain_grad()

        class_logits = model_outputs.get('logits')
        class_probabilities = class_logits.softmax(dim=-1)
        flat_target = (class_probabilities if objective == "probs" else class_logits).flatten()

        predicted_class_id = class_logits.argmax().item()
        predicted_class = self.model.config.id2label[predicted_class_id]

        if logits_indices is None:
            logits_indices = [predicted_class_id]

        hidden_states_grads = []
        attention_grads = []
        for logits_index in logits_indices:
            # retain_graph so the same graph can be reused for every index.
            flat_target[logits_index].backward(retain_graph=True)
            hidden_states_grads.append(self._collect_and_reset_grads(hidden_states))
            attention_grads.append(self._collect_and_reset_grads(attentions))

        info = {
            'model_outputs': model_outputs,
            'hidden_states': hidden_states,
            'hidden_states_grads': hidden_states_grads,
            'attentions': attentions,
            'attention_grads': attention_grads,
            'class_probabilities': class_probabilities,
            'class_logits': class_logits,
            'predicted_class_id': predicted_class_id,
            'predicted_class': predicted_class,
            'tokenized_inputs': tokenized_inputs,
            'input_tokens': input_tokens
        }
        return info
    
class QuestionAnswerer:
    """
    A class that provides methods for extractive question answering using a
    pre-trained model, including gradients of the start/end scores with respect
    to the model's hidden states and attentions.
    """

    def __init__(self, model, tokenizer, stride: int = 32, max_length: int = 512, truncation="only_second"):
        """
        Initializes a new instance of the class.

        Args:
            model: The question-answering model. Its outputs are read for
                'start_logits' and 'end_logits' and, when present,
                'hidden_states' and 'attentions'.
            tokenizer: The tokenizer used to encode (question, context) pairs.
            stride (int): Passed to the tokenizer as ``stride``. Defaults to 32.
            max_length (int): Passed to the tokenizer as ``max_length``. Defaults to 512.
            truncation: Passed to the tokenizer as ``truncation``. Defaults to "only_second".
        """
        self.model = model
        self.tokenizer = tokenizer
        self.stride = stride
        self.max_length = max_length
        self.truncation = truncation

    def get_input_tokens(self, question: str, context: str) -> Tuple[dict, List[str]]:
        """
        Tokenizes a (question, context) pair with the configured truncation settings.

        Args:
            question (str): The question to tokenize.
            context (str): The context to tokenize.

        Returns:
            Tuple[dict, List[str]]: A tuple containing the tokenized inputs
            (PyTorch tensors) and the token strings of the first sequence.
        """
        tokenized_inputs = self.tokenizer(question,
                                          context,
                                          return_tensors="pt",
                                          max_length=self.max_length,
                                          truncation=self.truncation,
                                          stride=self.stride)

        input_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
        return tokenized_inputs, input_tokens

    def get_model_outputs(self, question: str, context: str) -> dict:
        """
        Tokenizes the (question, context) pair and runs the model on it.

        Args:
            question (str): The question.
            context (str): The context in which to look for the answer.

        Returns:
            dict: A dict with keys 'model_outputs', 'hidden_states' and
            'attentions' (empty lists when the model did not return them),
            'tokenized_inputs' and 'input_tokens'.
        """
        tokenized_inputs, input_tokens = self.get_input_tokens(question, context)

        model_outputs = self.model(**tokenized_inputs)

        hidden_states, attentions = model_outputs.get('hidden_states', []), model_outputs.get('attentions', [])

        model_info = {
            'model_outputs': model_outputs,
            'hidden_states': hidden_states,
            'attentions': attentions,
            'tokenized_inputs': tokenized_inputs,
            'input_tokens': input_tokens
        }

        return model_info

    @staticmethod
    def _collect_and_reset_grads(tensors) -> List[torch.Tensor]:
        """
        Snapshot the ``.grad`` of each tensor, then zero it for the next backward pass.

        A tensor whose ``.grad`` is still ``None`` (it took no part in the
        backward pass) contributes an all-zero gradient of matching shape.
        """
        snapshots = []
        for tensor in tensors:
            if tensor.grad is None:
                snapshots.append(torch.zeros_like(tensor))
            else:
                snapshots.append(tensor.grad.clone())
                tensor.grad.zero_()
        return snapshots

    def _gradients_for(self, target, indices, hidden_states, attentions):
        """
        Backpropagate ``target.flatten()[index]`` for each index and snapshot the gradients.

        Returns:
            Tuple[list, list]: Per-index lists of hidden-state gradients and of
            attention gradients, in the order of ``indices``.
        """
        flat_target = target.flatten()
        hidden_grads, attention_grads = [], []
        for index in indices:
            # retain_graph so the same graph can be reused for every index.
            flat_target[index].backward(retain_graph=True)
            hidden_grads.append(self._collect_and_reset_grads(hidden_states))
            attention_grads.append(self._collect_and_reset_grads(attentions))
        return hidden_grads, attention_grads

    def get_info(self, question: str, context: str, start_logits_indices: List[int], end_logits_indices: List[int], objective="probs") -> dict:
        """
        Run the model, decode the predicted answer span and compute gradients of
        selected start/end scores with respect to hidden states and attentions.

        Args:
            question (str): The question.
            context (str): The context in which to look for the answer.
            start_logits_indices (List[int]): Indices into the flattened start
                scores to differentiate, one backward pass each.
            end_logits_indices (List[int]): Indices into the flattened end
                scores to differentiate, one backward pass each.
            objective (str): "probs" differentiates the softmax probabilities,
                "logits" the raw logits. Applies to both start and end scores.
                Defaults to "probs".

        Returns:
            dict: Keys 'model_outputs', 'hidden_states', 'attentions',
            'start_logits_indices', 'end_logits_indices', 'start_index',
            'end_index', 'start_index_logit', 'end_index_logit',
            'start_index_probability', 'end_index_probability',
            'start_hidden_states_grads', 'start_attention_grads',
            'end_hidden_states_grads', 'end_attention_grads',
            'predicted_answer', 'tokenized_inputs' and 'input_tokens'.

        Raises:
            ValueError: If ``objective`` is neither "logits" nor "probs".
        """
        if objective not in ("logits", "probs"):
            raise ValueError("Objective must be either 'logits' or 'probs'")

        model_info_dict = self.get_model_outputs(question, context)

        tokenized_inputs = model_info_dict['tokenized_inputs']
        input_tokens = model_info_dict['input_tokens']
        model_outputs = model_info_dict['model_outputs']
        hidden_states = model_info_dict['hidden_states']
        attentions = model_info_dict['attentions']

        # Non-leaf tensors only keep .grad after backward if asked to.
        for tensor in (*hidden_states, *attentions):
            tensor.retain_grad()

        start_index_logit = model_outputs.get('start_logits')
        end_index_logit = model_outputs.get('end_logits')

        start_index_probability = start_index_logit.softmax(dim=-1)
        end_index_probability = end_index_logit.softmax(dim=-1)

        start_index = torch.argmax(start_index_logit)
        end_index = torch.argmax(end_index_logit)

        predict_answer_tokens = tokenized_inputs.input_ids[0, start_index : end_index + 1]
        predicted_answer = self.tokenizer.decode(predict_answer_tokens)

        # The same objective applies to both ends (the end "probs" branch used to
        # backpropagate the raw end logits by mistake).
        if objective == "probs":
            start_target, end_target = start_index_probability, end_index_probability
        else:
            start_target, end_target = start_index_logit, end_index_logit

        start_hidden_states_grads, start_attention_grads = self._gradients_for(
            start_target, start_logits_indices, hidden_states, attentions)
        end_hidden_states_grads, end_attention_grads = self._gradients_for(
            end_target, end_logits_indices, hidden_states, attentions)

        info = {
            'model_outputs': model_outputs,
            'hidden_states': hidden_states,
            'attentions': attentions,
            'start_logits_indices': start_logits_indices,
            'end_logits_indices': end_logits_indices,
            'start_index': start_index,
            'end_index': end_index,
            'start_index_logit': start_index_logit,
            'end_index_logit': end_index_logit,
            'start_index_probability': start_index_probability,
            'end_index_probability': end_index_probability,
            'start_hidden_states_grads': start_hidden_states_grads,
            'start_attention_grads': start_attention_grads,
            'end_hidden_states_grads': end_hidden_states_grads,
            'end_attention_grads': end_attention_grads,
            'predicted_answer': predicted_answer,
            'tokenized_inputs': tokenized_inputs,
            'input_tokens': input_tokens
        }
        return info
    
class ZeroShotSequenceClassifier:
    """
    A class that encapsulates the functionality of a zero-shot classification model.
    The input string and each candidate label are embedded (mean pooling) and
    compared by cosine similarity. It provides methods for tokenizing input,
    running the model, and calculating gradients.
    """
    def __init__(self, model, tokenizer):
        """
        Initializes a new instance of the class.

        Args:
            model: The embedding model; ``model_outputs[0]`` is mean-pooled and
                'hidden_states'/'attentions' are read when present.
            tokenizer: The tokenizer used to encode the input and the labels.
        """
        self.model = model
        self.tokenizer = tokenizer

    def get_input_tokens(self, input_str: str, labels: List[str]) -> Tuple[dict, List[str]]:
        """
        Tokenize the input string together with the candidate labels.

        Row 0 of the batch is ``input_str`` and rows 1.. are the labels. The
        same padding/truncation settings are used as for the model forward pass
        in ``get_model_outputs``, so the returned tokens match what the model sees.

        Args:
            input_str (str): The input string to tokenize.
            labels (List[str]): The list of labels to tokenize.

        Returns:
            Tuple[dict, List[str]]: A tuple containing the tokenized batch
            (PyTorch tensors) and the token strings of ``input_str`` (row 0,
            including any padding tokens).
        """
        tokenized_inputs = self.tokenizer([input_str] + labels,
                                          padding=True,
                                          truncation=True,
                                          return_tensors='pt')

        input_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
        return tokenized_inputs, input_tokens

    def get_model_outputs(self, input_str: str, labels: List[str]) -> dict:
        """
        Generate the model outputs for zero-shot classification.

        Args:
            input_str (str): The input string to run the model on.
            labels (List[str]): The list of labels to run the model on.

        Returns:
            dict: A dict with keys 'model_outputs', 'hidden_states' and
            'attentions' (empty lists when the model did not return them),
            'tokenized_inputs' and 'input_tokens'.
        """
        tokenized_inputs, input_tokens = self.get_input_tokens(input_str, labels)

        model_outputs = self.model(**tokenized_inputs)

        hidden_states, attentions = model_outputs.get('hidden_states', []), model_outputs.get('attentions', [])

        model_info = {
            'model_outputs': model_outputs,
            'hidden_states': hidden_states,
            'attentions': attentions,
            'tokenized_inputs': tokenized_inputs,
            'input_tokens': input_tokens
        }

        return model_info

    @staticmethod
    def mean_pooling(model_outputs, attention_mask):
        """
        Compute the mean pooling of the token embeddings based on the attention mask.

        Args:
            model_outputs: The model output; element ``[0]`` is taken as the
                token embeddings (indexed as (batch, seq, hidden) below).
            attention_mask (torch.Tensor): The attention mask tensor (batch, seq).

        Returns:
            torch.Tensor: The mean pooled token embeddings, one row per sequence.
        """
        token_embeddings = model_outputs[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp avoids division by zero for an all-padding row.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    @staticmethod
    def _collect_and_reset_grads(tensors) -> List[torch.Tensor]:
        """
        Snapshot the ``.grad`` of each tensor, then zero it for the next backward pass.

        A tensor whose ``.grad`` is still ``None`` (it took no part in the
        backward pass) contributes an all-zero gradient of matching shape.
        """
        snapshots = []
        for tensor in tensors:
            if tensor.grad is None:
                snapshots.append(torch.zeros_like(tensor))
            else:
                snapshots.append(tensor.grad.clone())
                tensor.grad.zero_()
        return snapshots

    def get_info(self, input_str: str, labels: List[str], logits_indices: List[int]) -> dict:
        """
        Score each label against the input and compute gradients of the selected
        label scores with respect to hidden states and attentions.

        Args:
            input_str (str): The input string to classify.
            labels (List[str]): The candidate labels.
            logits_indices (List[int]): Label indices whose score is
                differentiated, one backward pass each.

        Returns:
            dict: Keys 'model_outputs', 'hidden_states', 'hidden_states_grads',
            'attentions', 'attention_grads', 'class_probabilities', 'closest',
            'tokenized_inputs' and 'input_tokens'. 'class_probabilities' are the
            cosine similarities L1-normalized across labels (they can be
            negative, so this is not a true probability distribution);
            'closest' holds label indices sorted by descending score.
        """
        model_info = self.get_model_outputs(input_str, labels)

        model_outputs = model_info['model_outputs']
        tokenized_inputs = model_info['tokenized_inputs']
        input_tokens = model_info['input_tokens']
        hidden_states = model_info['hidden_states']
        attentions = model_info['attentions']

        # evaluate the embeddings
        embeddings = self.mean_pooling(model_outputs, tokenized_inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=-1)

        # evaluate the class quasi-logits: row 0 is the input, rows 1.. are labels
        unnormalized_similarities = F.cosine_similarity(embeddings[0], embeddings[1:])

        # evaluate the class quasi-probabilities
        class_probabilities = F.normalize(unnormalized_similarities, p=1, dim=-1)
        closest = class_probabilities.argsort(descending=True)

        # Non-leaf tensors only keep .grad after backward if asked to.
        for tensor in (*hidden_states, *attentions):
            tensor.retain_grad()

        hidden_states_grads = []
        attention_grads = []
        flat_probabilities = class_probabilities.flatten()
        for logits_index in logits_indices:
            flat_probabilities[logits_index].backward(retain_graph=True)
            hidden_states_grads.append(self._collect_and_reset_grads(hidden_states))
            attention_grads.append(self._collect_and_reset_grads(attentions))

        info = {
            'model_outputs': model_outputs,
            'hidden_states': hidden_states,
            'hidden_states_grads': hidden_states_grads,
            'attentions': attentions,
            'attention_grads': attention_grads,
            'class_probabilities': class_probabilities,
            'closest': closest,
            'tokenized_inputs': tokenized_inputs,
            'input_tokens': input_tokens
        }

        return info
    
class MaskedLanguageModeler:
    """
    A class that wraps a masked language model and its tokenizer, providing
    tokenization, inference and gradients of the predicted-token scores with
    respect to hidden states and attentions.

    Args:
        model (object): The model object.
        tokenizer (object): The tokenizer object (must define ``mask_token_id``).
    """
    def __init__(self, model, tokenizer):
        """
        Initializes a new instance of the class.

        Parameters:
            model (object): The model to be assigned to the instance. Its
                outputs are read for 'logits' and, when present,
                'hidden_states' and 'attentions'.
            tokenizer (object): The tokenizer to be assigned to the instance.
        """
        self.model = model
        self.tokenizer = tokenizer

    def get_input_tokens(self, input_str: str) -> Tuple[dict, List[str]]:
        """
        Tokenizes the input string (which may contain mask tokens).

        Args:
            input_str (str): The input string to tokenize.

        Returns:
            Tuple[dict, List[str]]: A tuple containing the tokenized inputs
            (PyTorch tensors) and the token strings of the first sequence.
        """
        tokenized_inputs = self.tokenizer(input_str, return_tensors='pt')
        input_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
        return tokenized_inputs, input_tokens

    def get_model_outputs(self, input_str: str) -> dict:
        """
        Tokenizes ``input_str`` and runs the model on it.

        Args:
            input_str (str): The input string to run the model on.

        Returns:
            dict: A dict with keys 'model_outputs', 'hidden_states' and
            'attentions' (empty lists when the model did not return them),
            'tokenized_inputs' and 'input_tokens'.
        """
        tokenized_inputs, input_tokens = self.get_input_tokens(input_str)
        model_outputs = self.model(**tokenized_inputs)

        hidden_states, attentions = model_outputs.get('hidden_states', []), model_outputs.get('attentions', [])

        model_info = {
            'model_outputs': model_outputs,
            'hidden_states': hidden_states,
            'attentions': attentions,
            'tokenized_inputs': tokenized_inputs,
            'input_tokens': input_tokens
        }
        return model_info

    @staticmethod
    def _collect_and_reset_grads(tensors) -> List[torch.Tensor]:
        """
        Snapshot the ``.grad`` of each tensor, then zero it for the next backward pass.

        A tensor whose ``.grad`` is still ``None`` (it took no part in the
        backward pass) contributes an all-zero gradient of matching shape.
        """
        snapshots = []
        for tensor in tensors:
            if tensor.grad is None:
                snapshots.append(torch.zeros_like(tensor))
            else:
                snapshots.append(tensor.grad.clone())
                tensor.grad.zero_()
        return snapshots

    def get_info(self, input_str, logits_indices: List[int], objective="probs") -> dict:
        """
        Predict the masked tokens of ``input_str`` and compute gradients of the
        selected vocabulary scores at each mask position.

        Parameters:
            input_str (str): Text containing one or more mask tokens.
            logits_indices (List[int]): One vocabulary index per mask token, in
                order of appearance; entry ``k`` selects which score at the
                ``k``-th mask position is differentiated.
            objective (str): "probs" differentiates the softmax probabilities,
                "logits" the raw logits. Defaults to "probs".

        Returns:
            dict: Keys 'model_outputs', 'tokenized_inputs', 'attentions',
            'attention_grads', 'hidden_states', 'hidden_states_grads',
            'class_logits', 'class_probabilities', 'masked_token_index',
            'predicted_token_id', 'masked_words' (the decoded argmax tokens at
            the mask positions) and 'input_tokens'.

        Raises:
            ValueError: If ``objective`` is neither "probs" nor "logits", or if
                ``len(logits_indices)`` differs from the number of mask tokens.
        """
        if objective not in ("probs", "logits"):
            raise ValueError("objective must be either 'probs' or 'logits'")

        model_info_dict = self.get_model_outputs(input_str)

        model_outputs = model_info_dict['model_outputs']
        tokenized_inputs = model_info_dict['tokenized_inputs']
        input_tokens = model_info_dict['input_tokens']
        attentions = model_info_dict['attentions']
        hidden_states = model_info_dict['hidden_states']

        # Positions of the mask tokens in the first (only) sequence.
        mask_token_index = (tokenized_inputs.input_ids == self.tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

        # Raise instead of assert: assertions are stripped under ``python -O``.
        if len(logits_indices) != len(mask_token_index):
            raise ValueError("The length of logits_indices must be equal to the length of mask_token_index.")

        class_logits = model_outputs.get('logits')
        class_probabilities = class_logits.softmax(dim=-1)
        target = class_probabilities if objective == "probs" else class_logits

        # Non-leaf tensors only keep .grad after backward if asked to.
        for tensor in (*hidden_states, *attentions):
            tensor.retain_grad()

        hidden_states_grads = []
        attention_grads = []
        for mask_position, vocab_index in zip(mask_token_index, logits_indices):
            # retain_graph so the same graph can be reused for every mask.
            target[0, mask_position][vocab_index].backward(retain_graph=True)
            hidden_states_grads.append(self._collect_and_reset_grads(hidden_states))
            attention_grads.append(self._collect_and_reset_grads(attentions))

        predicted_token_id = class_logits[0, mask_token_index].argmax(dim=-1)
        masked_words = self.tokenizer.decode(predicted_token_id)

        info = {
            'model_outputs': model_outputs,
            'tokenized_inputs': tokenized_inputs,
            'attentions': attentions,
            'attention_grads': attention_grads,
            'hidden_states': hidden_states,
            'hidden_states_grads': hidden_states_grads,
            'class_logits': class_logits,
            'class_probabilities': class_probabilities,
            'masked_token_index': mask_token_index,
            'predicted_token_id': predicted_token_id,
            'masked_words': masked_words,
            'input_tokens': input_tokens
        }
        return info
    