"""
This module provides utility functions for natural language processing.
"""
from typing import Tuple, List
import copy
import inspect
import torch

def get_mean_pooling(model_attr_embedding, eps: float = 1e-12):
    """
    Mean-pool a 3-D attribution tensor over its last dimension and L1-normalize it.

    Args:
        model_attr_embedding: The input model attribute embedding tensor.
        eps: Lower bound applied to the L1 norm before dividing, so that an
            all-zero pooled row yields zeros instead of NaN.

    Returns:
        For a 3-D input, the mean over the last dimension, rescaled so that the
        absolute values along the remaining last dimension sum to 1 (rows whose
        norm is below ``eps`` are divided by ``eps`` instead). Inputs of any
        other rank are returned unchanged.
    """
    if len(model_attr_embedding.shape) != 3:
        return model_attr_embedding

    mean_pooling = model_attr_embedding.mean(dim=-1)

    # Normalize so that the l1-norm is 1; the clamp avoids a 0 / 0 division
    # when every pooled value in a row is zero.
    l1_norm = torch.norm(mean_pooling, p=1, dim=-1, keepdim=True)
    return mean_pooling / l1_norm.clamp(min=eps)

def summarize_attributions(attributions):
    """
    Collapse attributions into a single min-max scaled row.

    The input is averaged over its last dimension, ``squeeze(0)`` is applied to
    the result, and the values are rescaled with
    ``(x - min) / (max - min + 1e-8)``.

    Returns:
        The rescaled values as a NumPy array reshaped to (1, -1).
    """
    pooled = attributions.mean(dim=-1).squeeze(0)

    lowest = pooled.min()
    highest = pooled.max()
    # The small constant keeps the division defined when all values are equal.
    scaled = (pooled - lowest) / (highest - lowest + 1e-8)

    as_numpy = scaled.detach().cpu().numpy()
    return as_numpy.reshape(1, -1)

def get_clean_attr(obj: object, attr: str) -> object:
    """
    Look up an attribute on an object without raising when it is missing.

    Args:
        obj (object): The object from which to retrieve the attribute.
        attr (str): The name of the attribute to retrieve.

    Returns:
        object: The attribute's value, or None when ``obj`` has no such attribute.
    """
    return getattr(obj, attr, None)
    
def get_model_embedding(hf_model):
    """
    Retrieve the embedding module and its sub-embeddings from a Hugging Face model.

    The embeddings are looked up on the attribute of ``hf_model`` named after
    ``hf_model.config.model_type`` when such an attribute exists; otherwise
    they are looked up on ``hf_model`` itself.

    Args:
        hf_model (object): The Hugging Face model object.

    Returns:
        tuple: ``(model_embeddings, word_embeddings, token_type_embeddings,
        position_embeddings)``. Any entry that cannot be found is None.
    """
    config = getattr(hf_model, 'config', None)
    model_type = getattr(config, 'model_type', None)

    # Attribute names must be strings: without a usable model_type the lookup
    # has to be skipped, otherwise it raises TypeError instead of falling back.
    model_info = getattr(hf_model, model_type, None) if isinstance(model_type, str) else None

    embeddings_owner = hf_model if model_info is None else model_info
    model_embeddings = getattr(embeddings_owner, 'embeddings', None)

    word_embeddings = getattr(model_embeddings, 'word_embeddings', None)
    token_type_embeddings = getattr(model_embeddings, 'token_type_embeddings', None)
    position_embeddings = getattr(model_embeddings, 'position_embeddings', None)

    return model_embeddings, word_embeddings, token_type_embeddings, position_embeddings

def get_correct_args(model, init_kwargs:dict) -> dict:
    """
    Filter keyword arguments down to those accepted by the model's `forward` method.

    Parameters:
    model (Any): The model whose `forward` signature is inspected.
    init_kwargs (Dict[str, Any]): The candidate keyword arguments.

    Returns:
    Dict[str, Any]: The entries of `init_kwargs` whose key is a named parameter
    of `model.forward` that can be passed by keyword. Entries matching only a
    local variable, `*args` or `**kwargs` of `forward` are dropped.
    """
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # Inspect the signature rather than __code__.co_varnames: the latter also
    # lists the function's local variables, which are not valid arguments.
    accepted = {
        name
        for name, parameter in inspect.signature(model.forward).parameters.items()
        if parameter.kind in keyword_kinds
    }
    return {name: value for name, value in init_kwargs.items() if name in accepted}

def get_position_id(input_id):
    """
    Build the position ids ``0 .. len(input_id) - 1`` for a token id sequence.

    Args:
        input_id: The input token ids for which position ids need to be generated.
            Only its length is used.

    Returns:
        List of position ids, one per entry of ``input_id``.
    """
    sequence_length = len(input_id)
    return list(range(sequence_length))

def get_ref_id(hf_tokenizer, input_id):
    """
    Build the reference (baseline) ids for one sequence of token ids.

    Special tokens are kept as they are; every other token is replaced by the
    tokenizer's pad token id.

    Args:
        hf_tokenizer: Tokenizer providing ``all_special_ids`` and ``pad_token_id``.
        input_id: The token ids of one sequence; a 1-D tensor or any iterable
            of values convertible with ``int()``.

    Returns:
        List of reference ids corresponding to the input token ids.
    """
    # Read the tokenizer attributes once instead of once per token, and compare
    # plain ints so that membership is a set lookup.
    special_ids = set(hf_tokenizer.all_special_ids)
    pad_token_id = hf_tokenizer.pad_token_id

    return [
        token_id if token_id in special_ids else pad_token_id
        for token_id in map(int, input_id)
    ]

def _build_ref_pair_info(hf_tokenizer, tokenized_inputs):
    """
    Derive the input and reference (baseline) tensors from tokenizer output.

    Shared by `construct_input_ref_pair` and `construct_input_ref_pair_qa`.

    Args:
        hf_tokenizer: The Hugging Face tokenizer that produced `tokenized_inputs`.
        tokenized_inputs: Tokenizer output (tensors) providing 'input_ids' and,
            optionally, 'token_type_ids' and 'attention_mask'.

    Returns:
        dict: The keys 'input_ids', 'ref_input_ids', 'token_type_ids',
        'ref_token_type_ids', 'position_ids', 'ref_position_ids' and
        'attention_mask'. Entries the tokenizer did not produce are None.

    Raises:
        ValueError: If the tokenizer's sep, cls or pad token id is not one of
            its special token ids.
    """
    input_ids = tokenized_inputs.get('input_ids')
    token_type_ids = tokenized_inputs.get('token_type_ids', None)
    attention_mask = tokenized_inputs.get('attention_mask')

    position_ids = torch.tensor([get_position_id(input_id) for input_id in input_ids], dtype=torch.long)

    # The sep, cls and pad ids must all be known special tokens before the
    # reference ids are built from them.
    all_special_token_ids = hf_tokenizer.all_special_ids
    required_token_ids = {
        hf_tokenizer.sep_token_id,
        hf_tokenizer.cls_token_id,
        hf_tokenizer.pad_token_id,
    }
    if not required_token_ids.issubset(all_special_token_ids):
        raise ValueError(f"All defined special token ids should be in the list of special token ids: {all_special_token_ids}")

    # Reference ids keep special tokens and replace everything else by padding
    ref_input_ids = torch.tensor([get_ref_id(hf_tokenizer, input_id) for input_id in input_ids])

    # Reference token type ids and position ids are all zeros
    ref_token_type_ids = torch.zeros_like(token_type_ids) if token_type_ids is not None else None
    ref_position_ids = torch.zeros_like(position_ids, dtype=torch.long)

    return {
        'input_ids': input_ids,
        'ref_input_ids': ref_input_ids,
        'token_type_ids': token_type_ids,
        'ref_token_type_ids': ref_token_type_ids,
        'position_ids': position_ids,
        'ref_position_ids': ref_position_ids,
        'attention_mask': attention_mask
    }


def construct_input_ref_pair(hf_tokenizer, input_str:str, **kwargs):
    """
    Construct input and reference pairs for a given input string using the provided Hugging Face tokenizer.

    Args:
        hf_tokenizer: The Hugging Face tokenizer used to tokenize the input string.
        input_str: The input string to be tokenized.
        **kwargs: Extra keyword arguments forwarded to the tokenizer call
            (`return_tensors` is always set to 'pt').

    Returns:
        dict: Torch tensors (or None where the tokenizer produced nothing) under
        the keys 'input_ids', 'ref_input_ids', 'token_type_ids',
        'ref_token_type_ids', 'position_ids', 'ref_position_ids' and
        'attention_mask'.

    Raises:
        ValueError: If the tokenizer's sep, cls or pad token id is not one of
            its special token ids.
    """
    tokenized_inputs = hf_tokenizer(input_str, return_tensors='pt', **kwargs)
    return _build_ref_pair_info(hf_tokenizer, tokenized_inputs)


def construct_input_ref_pair_qa(hf_tokenizer, question:str, context:str, **kwargs):
    """
    Construct input and reference pairs for a question/context pair using the provided Hugging Face tokenizer.

    Args:
        hf_tokenizer: The Hugging Face tokenizer used to tokenize the pair.
        question: The question text, passed as the first tokenizer argument.
        context: The context text, passed as the second tokenizer argument.
        **kwargs: Extra keyword arguments forwarded to the tokenizer call
            (`return_tensors` is always set to 'pt').

    Returns:
        dict: Torch tensors (or None where the tokenizer produced nothing) under
        the keys 'input_ids', 'ref_input_ids', 'token_type_ids',
        'ref_token_type_ids', 'position_ids', 'ref_position_ids' and
        'attention_mask'.

    Raises:
        ValueError: If the tokenizer's sep, cls or pad token id is not one of
            its special token ids.
    """
    tokenized_inputs = hf_tokenizer(question, context, return_tensors='pt', **kwargs)
    return _build_ref_pair_info(hf_tokenizer, tokenized_inputs)


class SequenceClassifier:
    """
    A class that encapsulates the functionality of a sequence classifier model.
    It provides methods for tokenizing input, running the model, and computing
    Captum attributions for the model's prediction.
    """
    def __init__(self, model, tokenizer):
        """
        Initializes a new instance of the SequenceClassifier class.

        Args:
            model (PreTrainedModel): The model to use.
            tokenizer (PreTrainedTokenizer): The tokenizer to use.
        """
        self.model = model
        self.tokenizer = tokenizer


    def get_input_tokens(self, input_str: str) -> Tuple[dict, List[str]]:
        """
        Tokenizes the input string using the given tokenizer for sequence classification.

        Args:
            input_str (str): The input string to tokenize.

        Returns:
            Tuple[dict, List[str]]: A tuple containing the tokenizer output
            (as tensors) and the list of token strings of the first sequence.
        """
        tokenized_inputs = self.tokenizer(input_str, return_tensors='pt')
        input_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
        return tokenized_inputs, input_tokens

    def sc_custom_forward_func(self, input_ids, token_type_ids, position_ids, attention_mask) -> torch.Tensor:
        """
        Custom forward function for the sc model.

        Args:
            input_ids (torch.Tensor): The input token IDs.
            token_type_ids (torch.Tensor): The token type IDs.
            position_ids (torch.Tensor): The position IDs.
            attention_mask (torch.Tensor): The attention mask.

        Returns:
            torch.Tensor: The class probabilities, one row per input example
            (shape ``(batch, -1)``).
        """
        init_kwargs = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'position_ids': position_ids,
            'attention_mask': attention_mask
        }

        # Only forward the arguments the wrapped model actually accepts
        kwargs = get_correct_args(self.model, init_kwargs)
        model_outputs = self.model(**kwargs)

        # Keep every row of the batch: attribution methods may evaluate several
        # examples per call and need one output row for each of them. For a
        # single example this is identical to taking row 0 and reshaping to (1, -1).
        class_logits = model_outputs.get('logits')
        class_probabilities = class_logits.softmax(dim=-1)

        return class_probabilities.reshape(class_probabilities.shape[0], -1)

    def get_model_attribution(self, input_str:str, target:int, captum_model, captum_model_kwargs, attr_model_kwargs):
        """
        Calculate the attribution of the input to the model's prediction using the Captum library.

        Args:
            input_str (str): The input string to be attributed.
            target (int): The target class for which attribution is calculated.
            captum_model: The Captum attribution class; it is instantiated here.
            captum_model_kwargs: Keyword arguments for constructing the Captum model.
            attr_model_kwargs: Keyword arguments passed to its `attribute` call.

        Returns:
            tuple: ``(model_attr_info_dict, model_attr_info)`` where the dict holds
            the mean-pooled attributions under 'model_attr_embeddings' and
            `model_attr_info` is the raw `attribute` result wrapped in a tuple.

        Raises:
            AttributeError: If `captum_model` has no `attribute` method.
        """
        ref_pair_info = construct_input_ref_pair(self.tokenizer, input_str)

        input_ids = ref_pair_info.get('input_ids', None)
        token_type_ids = ref_pair_info.get('token_type_ids', None)
        position_ids = ref_pair_info.get('position_ids', None)
        attention_mask = ref_pair_info.get('attention_mask', None)

        ref_input_ids = ref_pair_info.get('ref_input_ids', None)

        # check attribute model has attribute method
        if not hasattr(captum_model, 'attribute'):
            raise AttributeError("Attribute model must have attribute method")

        model_embeddings, _, _, _ = get_model_embedding(self.model)

        # Attribution classes taking a `layer` argument are attached to the
        # model's embedding module; the others only receive the forward function.
        captum_model_args = inspect.getfullargspec(captum_model).args

        if 'layer' in captum_model_args:
            attr_model = captum_model(self.sc_custom_forward_func, model_embeddings, **captum_model_kwargs)
        else:
            attr_model = captum_model(self.sc_custom_forward_func, **captum_model_kwargs)

        model_attr_info = attr_model.attribute(inputs=input_ids,
                                             baselines=ref_input_ids,
                                             target=target,
                                             additional_forward_args=(token_type_ids,
                                                                      position_ids,
                                                                      attention_mask,
                                                                      ),
                                             **attr_model_kwargs
                                             )

        # Handle variable number of outputs
        if not isinstance(model_attr_info, tuple):
            model_attr_info = (model_attr_info,)

        model_attr_embeddings = get_mean_pooling(model_attr_info[0])
        model_attr_info_dict = {
            'model_attr_embeddings': model_attr_embeddings
        }

        return model_attr_info_dict, model_attr_info


    def get_detailed_attribution(self, input_str, target, captum_model, captum_model_kwargs, attr_model_kwargs):
        """
        Generate detailed attribution information for the given input using the provided captum model.

        Args:
            input_str (str): The input string to generate attribution for.
            target: The target for the attribution.
            captum_model: The Captum attribution class; it is instantiated here.
            captum_model_kwargs: Additional keyword arguments for the captum model.
            attr_model_kwargs: Additional keyword arguments for its `attribute` call.

        Returns:
            tuple: ``(detailed_model_attr_info_dict, detailed_model_attr_info)``.
            The dict has the keys 'model_attr_words', 'model_attr_tokens' and
            'model_attr_positions' (entries without an attribution are None);
            `detailed_model_attr_info` is the raw `attribute` result wrapped in
            a tuple.

        Raises:
            AttributeError: If `captum_model` has no `attribute` method.
            ValueError: If the first attribution result does not hold 1, 2 or 3
                entries.
        """

        ref_pair_info = construct_input_ref_pair(self.tokenizer, input_str)

        input_ids = ref_pair_info.get('input_ids', None)
        ref_input_ids = ref_pair_info.get('ref_input_ids', None)

        token_type_ids = ref_pair_info.get('token_type_ids', None)
        ref_token_type_ids = ref_pair_info.get('ref_token_type_ids', None)

        position_ids = ref_pair_info.get('position_ids', None)
        ref_position_ids = ref_pair_info.get('ref_position_ids', None)

        attention_mask = ref_pair_info.get('attention_mask', None)

        # check attribute model has attribute method
        if not hasattr(captum_model, 'attribute'):
            raise AttributeError("Attribute model must have attribute method")

        # Attribute to each embedding sub-module the model actually has
        _, word_embeddings, token_type_embeddings, position_embeddings = get_model_embedding(self.model)
        detailed_embeddings = [word_embeddings, token_type_embeddings, position_embeddings]
        clean_detailed_embeddings = [embeddings for embeddings in detailed_embeddings if embeddings is not None]

        captum_model_args = inspect.getfullargspec(captum_model).args

        if 'layer' in captum_model_args:
            detailed_attr_model = captum_model(self.sc_custom_forward_func, clean_detailed_embeddings, **captum_model_kwargs)
        else:
            detailed_attr_model = captum_model(self.sc_custom_forward_func, **captum_model_kwargs)

        # The forward function takes token type ids positionally, so substitute
        # all-zero ids when the tokenizer did not produce any.
        if token_type_ids is None:
            updated_token_type_ids = torch.zeros_like(input_ids, dtype=torch.long)
            updated_ref_token_type_ids = torch.zeros_like(ref_input_ids, dtype=torch.long)
        else:
            updated_token_type_ids = token_type_ids
            updated_ref_token_type_ids = ref_token_type_ids

        detailed_model_attr_info = detailed_attr_model.attribute(inputs=(input_ids,
                                                                         updated_token_type_ids,
                                                                         position_ids),
                                                                baselines=(ref_input_ids,
                                                                           updated_ref_token_type_ids,
                                                                           ref_position_ids),
                                                                target=target,
                                                                additional_forward_args=(attention_mask,),
                                                                **attr_model_kwargs
                                                            )
        if not isinstance(detailed_model_attr_info, tuple):
            detailed_model_attr_info = (detailed_model_attr_info,)

        assert len(detailed_model_attr_info) <=3 , f"""
        Detailed attribution information should have at most 3 elements, 
        but got {len(detailed_model_attr_info)} elements
        """

        attr_parts = detailed_model_attr_info[0]
        part_count = len(attr_parts)

        detailed_attr_model_tokens = None
        detailed_attr_model_positions = None

        if part_count == 1:
            detailed_attr_model_words = attr_parts[0]

        elif part_count == 2:
            # NOTE(review): two entries are read as (words, positions) —
            # presumably a model without token type embeddings; confirm.
            detailed_attr_model_words = attr_parts[0]
            detailed_attr_model_positions = attr_parts[1]

        elif part_count == 3:
            detailed_attr_model_words = attr_parts[0]
            detailed_attr_model_tokens = attr_parts[1]
            detailed_attr_model_positions = attr_parts[2]

        else:
            # Fail with a clear message instead of an unbound local further down
            raise ValueError(
                f"Expected 1, 2 or 3 attribution entries, but got {part_count}"
            )

        detailed_model_attr_info_dict = {
            'model_attr_words': detailed_attr_model_words,
            'model_attr_tokens': detailed_attr_model_tokens,
            'model_attr_positions': detailed_attr_model_positions
        }

        return detailed_model_attr_info_dict, detailed_model_attr_info
        

class MaskedLanguageModeler:
    """
    A class that encapsulates the functionality of a masked language model.
    It provides methods for tokenizing input, running the model, and computing
    Captum attributions for the prediction at a masked position.
    """
    def __init__(self, model, tokenizer):
        """
        Initializes a new instance of the MaskedLanguageModeler class.

        Args:
            model (PreTrainedModel): The model to use.
            tokenizer (PreTrainedTokenizer): The tokenizer to use.
        """
        self.model = model
        self.tokenizer = tokenizer


    def get_input_tokens(self, input_str: str) -> Tuple[dict, List[str]]:
        """
        Tokenizes the input string using the given tokenizer.

        Args:
            input_str (str): The input string to tokenize.

        Returns:
            Tuple[dict, List[str]]: A tuple containing the tokenizer output
            (as tensors) and the list of token strings of the first sequence.
        """
        tokenized_inputs = self.tokenizer(input_str, return_tensors='pt')
        input_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
        return tokenized_inputs, input_tokens

    def mlm_custom_forward_func(self, input_ids, token_type_ids, position_ids, attention_mask, mask_index=0):
        """
        Custom forward function for MLM model.

        Args:
            input_ids (Tensor): The input IDs for the model.
            token_type_ids (Tensor): The token type IDs for the model.
            position_ids (Tensor): The position IDs for the model.
            attention_mask (Tensor): The attention mask for the model.
            mask_index (int, optional): Which mask token to read, counted among
                the mask tokens of the first row of `input_ids`. Defaults to 0.

        Returns:
            Tensor: The vocabulary probabilities at the selected mask position,
            one row per input example (shape ``(batch, -1)``).

        Raises:
            IndexError: If `mask_index` does not refer to a mask token of the
                first row of `input_ids`.
        """
        init_kwargs = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'position_ids': position_ids,
            'attention_mask': attention_mask
        }

        # Only forward the arguments the wrapped model actually accepts
        kwargs = get_correct_args(self.model, init_kwargs)
        model_outputs = self.model(**kwargs)

        # Mask positions are located in the first row and applied to every row
        mask_token_index = (input_ids == self.tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
        mask_count = len(mask_token_index)
        if not -mask_count <= mask_index < mask_count:
            raise IndexError(
                f"mask_index {mask_index} is out of range for {mask_count} mask token(s)"
            )

        class_logits = model_outputs.get('logits')
        class_probabilities = class_logits.softmax(dim=-1)

        # Keep every row of the batch: attribution methods may evaluate several
        # examples per call and need one output row for each of them. For a
        # single example this equals row 0 reshaped to (1, -1).
        masked_probs = class_probabilities[:, mask_token_index[mask_index]]

        return masked_probs.reshape(masked_probs.shape[0], -1)

    @staticmethod
    def _split_mask_index(attr_model_kwargs):
        """
        Separate 'mask_index' from the remaining `attribute` keyword arguments.

        Works on a shallow copy so the caller's dictionary is left untouched.

        Returns:
            tuple: ``(mask_index, remaining_kwargs)``.

        Raises:
            AttributeError: If 'mask_index' is missing or None.
        """
        remaining_kwargs = dict(attr_model_kwargs or {})
        mask_index = remaining_kwargs.pop('mask_index', None)
        if mask_index is None:
            raise AttributeError("mask index not specified")
        return mask_index, remaining_kwargs

    def get_model_attribution(self, input_str:str, target:int, captum_model, captum_model_kwargs, attr_model_kwargs):
        """
        Calculate the attribution of the input to the model's prediction using the Captum library.

        Args:
            input_str (str): The input string to be attributed.
            target (int): The target class for which attribution is calculated.
            captum_model: The Captum attribution class; it is instantiated here.
            captum_model_kwargs: Keyword arguments for constructing the Captum model.
            attr_model_kwargs: Keyword arguments passed to its `attribute` call.
                Must contain 'mask_index', which is forwarded to the forward
                function instead of to `attribute`.

        Returns:
            tuple: ``(model_attr_info_dict, model_attr_info)`` where the dict holds
            the mean-pooled attributions under 'model_attr_embeddings' and
            `model_attr_info` is the raw `attribute` result wrapped in a tuple.

        Raises:
            AttributeError: If `captum_model` has no `attribute` method or
                'mask_index' is not specified.
        """
        ref_pair_info = construct_input_ref_pair(self.tokenizer, input_str)

        input_ids = ref_pair_info.get('input_ids', None)
        token_type_ids = ref_pair_info.get('token_type_ids', None)
        position_ids = ref_pair_info.get('position_ids', None)
        attention_mask = ref_pair_info.get('attention_mask', None)

        ref_input_ids = ref_pair_info.get('ref_input_ids', None)

        # check attribute model has attribute method
        if not hasattr(captum_model, 'attribute'):
            raise AttributeError("Attribute model must have attribute method")

        model_embeddings, _, _, _ = get_model_embedding(self.model)

        # Attribution classes taking a `layer` argument are attached to the
        # model's embedding module; the others only receive the forward function.
        captum_model_args = inspect.getfullargspec(captum_model).args

        if 'layer' in captum_model_args:
            attr_model = captum_model(self.mlm_custom_forward_func, model_embeddings, **captum_model_kwargs)
        else:
            attr_model = captum_model(self.mlm_custom_forward_func, **captum_model_kwargs)

        mask_index, attr_model_kwargs_copy = self._split_mask_index(attr_model_kwargs)

        model_attr_info = attr_model.attribute(inputs=input_ids,
                                             baselines=ref_input_ids,
                                             target=target,
                                             additional_forward_args=(token_type_ids,
                                                                      position_ids,
                                                                      attention_mask,
                                                                      mask_index,
                                                                      ),
                                             **attr_model_kwargs_copy
                                             )

        # Handle variable number of outputs
        if not isinstance(model_attr_info, tuple):
            model_attr_info = (model_attr_info,)

        model_attr_embeddings = get_mean_pooling(model_attr_info[0])
        model_attr_info_dict = {
            'model_attr_embeddings': model_attr_embeddings
        }

        return model_attr_info_dict, model_attr_info


    def get_detailed_attribution(self, input_str, target, captum_model, captum_model_kwargs, attr_model_kwargs):
        """
        Generate detailed attribution information for the given input using the provided captum model.

        Args:
            input_str (str): The input string to generate attribution for.
            target: The target for the attribution.
            captum_model: The Captum attribution class; it is instantiated here.
            captum_model_kwargs: Additional keyword arguments for the captum model.
            attr_model_kwargs: Additional keyword arguments for its `attribute`
                call. Must contain 'mask_index', which is forwarded to the
                forward function instead of to `attribute`.

        Returns:
            tuple: ``(detailed_model_attr_info_dict, detailed_model_attr_info)``.
            The dict has the keys 'model_attr_words', 'model_attr_tokens' and
            'model_attr_positions' (entries without an attribution are None);
            `detailed_model_attr_info` is the raw `attribute` result wrapped in
            a tuple.

        Raises:
            AttributeError: If `captum_model` has no `attribute` method or
                'mask_index' is not specified.
            ValueError: If the first attribution result does not hold 1, 2 or 3
                entries.
        """
        ref_pair_info = construct_input_ref_pair(self.tokenizer, input_str)

        input_ids = ref_pair_info.get('input_ids', None)
        ref_input_ids = ref_pair_info.get('ref_input_ids', None)

        token_type_ids = ref_pair_info.get('token_type_ids', None)
        ref_token_type_ids = ref_pair_info.get('ref_token_type_ids', None)

        position_ids = ref_pair_info.get('position_ids', None)
        ref_position_ids = ref_pair_info.get('ref_position_ids', None)

        attention_mask = ref_pair_info.get('attention_mask', None)

        # check attribute model has attribute method
        if not hasattr(captum_model, 'attribute'):
            raise AttributeError("Attribute model must have attribute method")

        # Attribute to each embedding sub-module the model actually has
        _, word_embeddings, token_type_embeddings, position_embeddings = get_model_embedding(self.model)
        detailed_embeddings = [word_embeddings, token_type_embeddings, position_embeddings]
        clean_detailed_embeddings = [embeddings for embeddings in detailed_embeddings if embeddings is not None]

        captum_model_args = inspect.getfullargspec(captum_model).args

        if 'layer' in captum_model_args:
            detailed_attr_model = captum_model(self.mlm_custom_forward_func, clean_detailed_embeddings, **captum_model_kwargs)
        else:
            detailed_attr_model = captum_model(self.mlm_custom_forward_func, **captum_model_kwargs)

        # The forward function takes token type ids positionally, so substitute
        # all-zero ids when the tokenizer did not produce any.
        if token_type_ids is None:
            updated_token_type_ids = torch.zeros_like(input_ids, dtype=torch.long)
            updated_ref_token_type_ids = torch.zeros_like(ref_input_ids, dtype=torch.long)
        else:
            updated_token_type_ids = token_type_ids
            updated_ref_token_type_ids = ref_token_type_ids

        mask_index, attr_model_kwargs_copy = self._split_mask_index(attr_model_kwargs)

        detailed_model_attr_info = detailed_attr_model.attribute(inputs=(input_ids,
                                                                         updated_token_type_ids,
                                                                         position_ids),
                                                                baselines=(ref_input_ids,
                                                                           updated_ref_token_type_ids,
                                                                           ref_position_ids),
                                                                target=target,
                                                                additional_forward_args=(attention_mask,
                                                                                         mask_index,),
                                                                **attr_model_kwargs_copy
                                                            )
        if not isinstance(detailed_model_attr_info, tuple):
            detailed_model_attr_info = (detailed_model_attr_info,)

        assert len(detailed_model_attr_info) <=3 , f"""
        Detailed attribution information should have at most 3 elements, 
        but got {len(detailed_model_attr_info)} elements
        """

        attr_parts = detailed_model_attr_info[0]
        part_count = len(attr_parts)

        detailed_attr_model_tokens = None
        detailed_attr_model_positions = None

        if part_count == 1:
            detailed_attr_model_words = attr_parts[0]

        elif part_count == 2:
            # NOTE(review): two entries are read as (words, positions) —
            # presumably a model without token type embeddings; confirm.
            detailed_attr_model_words = attr_parts[0]
            detailed_attr_model_positions = attr_parts[1]

        elif part_count == 3:
            detailed_attr_model_words = attr_parts[0]
            detailed_attr_model_tokens = attr_parts[1]
            detailed_attr_model_positions = attr_parts[2]

        else:
            # Fail with a clear message instead of an unbound local further down
            raise ValueError(
                f"Expected 1, 2 or 3 attribution entries, but got {part_count}"
            )

        detailed_model_attr_info_dict = {
            'model_attr_words': detailed_attr_model_words,
            'model_attr_tokens': detailed_attr_model_tokens,
            'model_attr_positions': detailed_attr_model_positions
        }

        return detailed_model_attr_info_dict, detailed_model_attr_info
        

class QuestionAnswerer:
    """
    A class that encapsulates the functionality of an extractive question-answering model.
    It provides methods for tokenizing a question/context pair, running the model to obtain
    start/end probabilities, and computing attributions for those probabilities.
    """

    # Keys of the detailed-attribution dictionary, listed in the same order in which
    # get_detailed_attribution unpacks the embedding layers returned by
    # get_model_embedding (word, token type, position).
    _DETAILED_ATTR_KEYS = ('model_attr_words', 'model_attr_tokens', 'model_attr_positions')

    def __init__(self, model, tokenizer):
        """
        Initializes a new instance of the QuestionAnswerer class.

        Args:
            model (PreTrainedModel): The model to use.
            tokenizer (PreTrainedTokenizer): The tokenizer to use.
        """
        self.model = model
        self.tokenizer = tokenizer

    def get_input_tokens(self, question: str, context: str) -> Tuple[dict, List[str]]:
        """
        Tokenizes a question/context pair using the stored tokenizer.

        Args:
            question (str): The question text.
            context (str): The context text in which the answer is searched.

        Returns:
            Tuple[dict, List[str]]: A tuple containing:
            the tokenizer output (requested with return_tensors='pt') and the list of
            token strings for the first row of its input_ids.
        """
        tokenized_inputs = self.tokenizer(question, context, return_tensors='pt')
        input_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
        return tokenized_inputs, input_tokens

    def qa_custom_forward_func(self, input_ids, token_type_ids, position_ids, attention_mask, start: bool):
        """
        Custom forward function for question answering model.

        Args:
            input_ids: Tensor of input token IDs.
            token_type_ids: Tensor of token type IDs.
            position_ids: Tensor of positional IDs.
            attention_mask: Tensor of attention mask.
            start: Boolean flag; truthy selects 'start_logits', falsy selects 'end_logits'.

        Returns:
            Tensor: Softmax probabilities of the selected logits, reshaped to (1, -1).
        """
        candidate_kwargs = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'position_ids': position_ids,
            'attention_mask': attention_mask
        }

        # NOTE(review): get_correct_args is defined elsewhere in this file; presumably it
        # keeps only the arguments that the model's forward accepts -- confirm.
        model_kwargs = get_correct_args(self.model, candidate_kwargs)
        model_output = self.model(**model_kwargs)

        logits_key = 'start_logits' if start else 'end_logits'
        logits = model_output.get(logits_key).reshape(1, -1)
        return logits.softmax(dim=-1)

    @staticmethod
    def _require_attribute_method(captum_model):
        """Raise AttributeError unless captum_model exposes an `attribute` member."""
        if not hasattr(captum_model, 'attribute'):
            raise AttributeError("Attribute model must have attribute method")

    def _make_attr_model(self, captum_model, layer, captum_model_kwargs):
        """
        Instantiate captum_model around qa_custom_forward_func.

        `layer` is passed as the second positional argument only when the
        constructor of captum_model declares a parameter named 'layer'.

        Returns:
            tuple: (attr_model, takes_layer) where takes_layer tells whether `layer` was used.
        """
        takes_layer = 'layer' in inspect.getfullargspec(captum_model).args
        if takes_layer:
            attr_model = captum_model(self.qa_custom_forward_func, layer, **captum_model_kwargs)
        else:
            attr_model = captum_model(self.qa_custom_forward_func, **captum_model_kwargs)
        return attr_model, takes_layer

    @staticmethod
    def _pop_start_flag(attr_kwargs):
        """
        Remove and return the mandatory 'start' entry from attr_kwargs.

        The flag is consumed by qa_custom_forward_func, so it must not be forwarded
        to `attribute` as a keyword argument.

        Raises:
            ValueError: If 'start' is missing or None.
        """
        start = attr_kwargs.pop('start', None)
        if start is None:
            raise ValueError("start bollean varibale is not provided")
        return start

    @classmethod
    def _label_detailed_attributions(cls, attributions, layer_keys=None):
        """
        Map a sequence of per-layer attributions onto the detailed-attribution keys.

        Args:
            attributions: Indexable collection of attributions (one entry per layer).
            layer_keys: Optional keys (taken from _DETAILED_ATTR_KEYS) naming the layer
                each entry belongs to, in order. When given and of matching length, it
                decides the labelling, so a missing position layer is no longer
                confused with a missing token-type layer.

        Returns:
            dict: One entry per key in _DETAILED_ATTR_KEYS; keys without an attribution map to None.

        Raises:
            ValueError: If no labelling is known for the number of attributions.
        """
        labelled = dict.fromkeys(cls._DETAILED_ATTR_KEYS)
        count = len(attributions)

        if layer_keys and len(layer_keys) == count:
            keys = tuple(layer_keys)
        else:
            # Count-based fallback: two entries are taken to be words and positions.
            words_key, tokens_key, positions_key = cls._DETAILED_ATTR_KEYS
            by_count = {
                1: (words_key,),
                2: (words_key, positions_key),
                3: (words_key, tokens_key, positions_key),
            }
            if count not in by_count:
                raise ValueError(f"Expected 1 to 3 detailed attributions, but got {count}")
            keys = by_count[count]

        for index, key in enumerate(keys):
            labelled[key] = attributions[index]
        return labelled

    def get_model_attribution(self, question:str, context:str, target:int, captum_model, captum_model_kwargs, attr_model_kwargs):
        """
        Calculate the attribution of the question/context pair to the model's prediction.

        Args:
            question (str): The question text.
            context (str): The context text.
            target (int): Forwarded as `target` to the attribution call; presumably the
                index into the start/end probabilities being explained -- confirm with callers.
            captum_model: The attribution class; it is instantiated here and must expose `attribute`.
            captum_model_kwargs: Keyword arguments for constructing the attribution model.
            attr_model_kwargs: Keyword arguments for the `attribute` call. Must contain a
                non-None 'start' entry, which is removed and handed to the forward function.
                The caller's dictionary is not modified (a deep copy is used).

        Returns:
            tuple: (model_attr_info_dict, model_attr_info) where the dictionary holds
            'model_attr_embeddings' (the result of get_mean_pooling on the first attribution
            output) and model_attr_info is the attribution output wrapped in a tuple.

        Raises:
            AttributeError: If captum_model has no `attribute` member.
            ValueError: If 'start' is not provided in attr_model_kwargs.
        """
        attr_model_kwargs_copy = copy.deepcopy(attr_model_kwargs)

        ref_pair_info = construct_input_ref_pair_qa(self.tokenizer, question, context)

        input_ids = ref_pair_info.get('input_ids', None)
        token_type_ids = ref_pair_info.get('token_type_ids', None)
        position_ids = ref_pair_info.get('position_ids', None)
        attention_mask = ref_pair_info.get('attention_mask', None)

        ref_input_ids = ref_pair_info.get('ref_input_ids', None)

        self._require_attribute_method(captum_model)

        model_embeddings, _, _, _ = get_model_embedding(self.model)
        attr_model, _ = self._make_attr_model(captum_model, model_embeddings, captum_model_kwargs)

        start = self._pop_start_flag(attr_model_kwargs_copy)

        model_attr_info = attr_model.attribute(inputs=input_ids,
                                               baselines=ref_input_ids,
                                               target=target,
                                               additional_forward_args=(token_type_ids,
                                                                        position_ids,
                                                                        attention_mask,
                                                                        start,
                                                                        ),
                                               **attr_model_kwargs_copy
                                               )

        # Handle variable number of outputs
        if not isinstance(model_attr_info, tuple):
            model_attr_info = (model_attr_info,)

        model_attr_info_dict = {
            'model_attr_embeddings': get_mean_pooling(model_attr_info[0])
        }

        return model_attr_info_dict, model_attr_info

    def get_detailed_attribution(self, question:str, context:str, target, captum_model, captum_model_kwargs, attr_model_kwargs):
        """
        Generate per-embedding (word, token type, position) attribution information.

        Args:
            question (str): The question text.
            context (str): The context text.
            target: Forwarded as `target` to the attribution call.
            captum_model: The attribution class; it is instantiated here and must expose `attribute`.
            captum_model_kwargs: Additional keyword arguments for constructing the attribution model.
            attr_model_kwargs: Additional keyword arguments for the `attribute` call. Must
                contain a non-None 'start' entry, which is removed and handed to the forward
                function. The caller's dictionary is not modified (a deep copy is used).

        Returns:
            tuple: (detailed_model_attr_info_dict, detailed_model_attr_info) where the
            dictionary has the keys 'model_attr_words', 'model_attr_tokens' and
            'model_attr_positions' (None where no attribution is available) and
            detailed_model_attr_info is the attribution output wrapped in a tuple.

        Raises:
            AttributeError: If captum_model has no `attribute` member.
            ValueError: If 'start' is not provided, or the number of attributions cannot be labelled.
        """
        attr_model_kwargs_copy = copy.deepcopy(attr_model_kwargs)

        ref_pair_info = construct_input_ref_pair_qa(self.tokenizer, question, context)

        input_ids = ref_pair_info.get('input_ids', None)
        ref_input_ids = ref_pair_info.get('ref_input_ids', None)

        token_type_ids = ref_pair_info.get('token_type_ids', None)
        ref_token_type_ids = ref_pair_info.get('ref_token_type_ids', None)

        position_ids = ref_pair_info.get('position_ids', None)
        ref_position_ids = ref_pair_info.get('ref_position_ids', None)

        attention_mask = ref_pair_info.get('attention_mask', None)

        self._require_attribute_method(captum_model)

        _, word_embeddings, token_type_embeddings, position_embeddings = get_model_embedding(self.model)
        detailed_embeddings = (word_embeddings, token_type_embeddings, position_embeddings)

        # Remember which embedding layers exist so each attribution can be labelled by
        # the layer it came from rather than guessed from the number of results.
        present = [(key, embeddings)
                   for key, embeddings in zip(self._DETAILED_ATTR_KEYS, detailed_embeddings)
                   if embeddings is not None]
        layer_keys = [key for key, _ in present]
        clean_detailed_embeddings = [embeddings for _, embeddings in present]

        detailed_attr_model, takes_layer = self._make_attr_model(captum_model,
                                                                 clean_detailed_embeddings,
                                                                 captum_model_kwargs)

        # Without token type ids, all-zero tensors shaped like the input ids stand in for them.
        if token_type_ids is None:
            updated_token_type_ids = torch.zeros_like(input_ids, dtype=torch.long)
            updated_ref_token_type_ids = torch.zeros_like(ref_input_ids, dtype=torch.long)
        else:
            updated_token_type_ids = token_type_ids
            updated_ref_token_type_ids = ref_token_type_ids

        start = self._pop_start_flag(attr_model_kwargs_copy)

        detailed_model_attr_info = detailed_attr_model.attribute(inputs=(input_ids,
                                                                         updated_token_type_ids,
                                                                         position_ids),
                                                                 baselines=(ref_input_ids,
                                                                            updated_ref_token_type_ids,
                                                                            ref_position_ids),
                                                                 target=target,
                                                                 additional_forward_args=(attention_mask,
                                                                                          start,),
                                                                 **attr_model_kwargs_copy
                                                                 )
        if not isinstance(detailed_model_attr_info, tuple):
            detailed_model_attr_info = (detailed_model_attr_info,)

        assert len(detailed_model_attr_info) <=3 , f"""
        Detailed attribution information should have at most 3 elements, 
        but got {len(detailed_model_attr_info)} elements
        """

        # Layer names are only meaningful when the layers were actually handed to the model.
        detailed_model_attr_info_dict = self._label_detailed_attributions(
            detailed_model_attr_info[0],
            layer_keys if takes_layer else None,
        )

        return detailed_model_attr_info_dict, detailed_model_attr_info
    



def get_all_info_sc(sc_model:SequenceClassifier, captum_models, input_str:str, logit_indices:List[int], 
                    captum_models_kwargs:List[dict], attr_models_kwargs:List[dict]):
    """
    Collect plain and detailed attributions of a sequence classifier for several attribution methods.

    Args:
        sc_model (SequenceClassifier): Object providing get_model_attribution and get_detailed_attribution.
        captum_models: The attribution classes to run; each must have a __name__.
        input_str (str): The input string to generate attribution for.
        logit_indices (List[int]): Targets to attribute; each is handled in turn.
        captum_models_kwargs (List[dict]): Constructor keyword arguments, one dict per attribution class.
        attr_models_kwargs (List[dict]): Attribution keyword arguments, one dict per attribution class.

    Returns:
        dict: Maps '<captum_model.__name__>_attr_info' to a dictionary that, for the i-th
        entry of logit_indices, holds the keys 'attr_info_logit_index_i',
        'model_attr_embeddings_logit_index_i', 'detailed_attr_info_logit_index_i',
        'model_attr_words_logit_index_i', 'model_attr_tokens_logit_index_i' and
        'model_attr_positions_logit_index_i'.
    """
    assert len(captum_models) == len(captum_models_kwargs) == len(attr_models_kwargs), f"""
    The length of captum_models, captum_models_kwargs, and attr_models_kwargs should be the same, 
    but got {len(captum_models)} and {len(captum_models_kwargs)} and {len(attr_models_kwargs)}
    """

    results_by_method = {}
    method_configs = zip(captum_models, captum_models_kwargs, attr_models_kwargs)

    for method, method_kwargs, attribute_kwargs in method_configs:
        collected = {}

        for position, target_logit in enumerate(logit_indices):
            # Attribution with respect to the model's embedding layer as a whole.
            summary, raw_attr = sc_model.get_model_attribution(
                input_str, target_logit, method, method_kwargs, attribute_kwargs
            )
            collected.update({
                f"attr_info_logit_index_{position}": raw_attr,
                f'model_attr_embeddings_logit_index_{position}': summary.get('model_attr_embeddings', None),
            })

            # Attribution split by word / token type / position embeddings.
            detailed_summary, detailed_raw_attr = sc_model.get_detailed_attribution(
                input_str, logit_indices[position], method, method_kwargs, attribute_kwargs
            )
            collected.update({
                f'detailed_attr_info_logit_index_{position}': detailed_raw_attr,
                f'model_attr_words_logit_index_{position}': detailed_summary.get('model_attr_words', None),
                f'model_attr_tokens_logit_index_{position}': detailed_summary.get('model_attr_tokens', None),
                f'model_attr_positions_logit_index_{position}': detailed_summary.get('model_attr_positions', None),
            })

        results_by_method[f'{method.__name__}_attr_info'] = collected

    return results_by_method

def get_all_info_mlm(mlm_model:MaskedLanguageModeler, captum_models, input_str:str, logit_indices:List[int], 
                    captum_models_kwargs:List[dict], attr_models_kwargs:List[dict]):
    """
    Collect plain and detailed attributions of a masked language model for several attribution methods.

    Args:
        mlm_model (MaskedLanguageModeler): Object providing get_model_attribution and get_detailed_attribution.
        captum_models: The attribution classes to run; each must have a __name__.
        input_str (str): The input string to generate attribution for.
        logit_indices (List[int]): Targets to attribute. The position i of each target in this
            list is passed along as 'mask_index' (presumably the i-th masked token -- confirm
            against MaskedLanguageModeler).
        captum_models_kwargs (List[dict]): Constructor keyword arguments, one dict per attribution class.
        attr_models_kwargs (List[dict]): Attribution keyword arguments, one dict per attribution
            class. These dictionaries are not modified; 'mask_index' is added to a per-call copy.

    Returns:
        dict: Maps '<captum_model.__name__>_attr_info' to a dictionary that, for the i-th
        entry of logit_indices, holds the keys 'attr_info_logit_index_i',
        'model_attr_embeddings_logit_index_i', 'detailed_attr_info_logit_index_i',
        'model_attr_words_logit_index_i', 'model_attr_tokens_logit_index_i' and
        'model_attr_positions_logit_index_i'.
    """
    captum_models_attr_info = {}

    assert len(captum_models) == len(captum_models_kwargs) == len(attr_models_kwargs), f"""
    The length of captum_models, captum_models_kwargs, and attr_models_kwargs should be the same, 
    but got {len(captum_models)} and {len(captum_models_kwargs)} and {len(attr_models_kwargs)}
    """

    for captum_model, captum_model_kwargs, attr_model_kwargs in zip(captum_models, captum_models_kwargs, attr_models_kwargs):
        model_attr_info = {}

        for i, logit_index in enumerate(logit_indices):
            # Add 'mask_index' to a copy so the caller's dictionary does not keep a stale
            # entry that would leak into later, unrelated attribution calls.
            call_kwargs = {**attr_model_kwargs, 'mask_index': i}

            mlm_model_attr_info_dict, mlm_model_attr_info = mlm_model.get_model_attribution(
                input_str, logit_index, captum_model, captum_model_kwargs, call_kwargs
            )
            model_attr_info[f"attr_info_logit_index_{i}"] = mlm_model_attr_info
            model_attr_info[f'model_attr_embeddings_logit_index_{i}'] = mlm_model_attr_info_dict.get('model_attr_embeddings', None)

            detailed_attr_info_dict, detailed_attr_info = mlm_model.get_detailed_attribution(
                input_str, logit_index, captum_model, captum_model_kwargs, call_kwargs
            )
            model_attr_info[f'detailed_attr_info_logit_index_{i}'] = detailed_attr_info
            model_attr_info[f'model_attr_words_logit_index_{i}'] = detailed_attr_info_dict.get('model_attr_words', None)
            model_attr_info[f'model_attr_tokens_logit_index_{i}'] = detailed_attr_info_dict.get('model_attr_tokens', None)
            model_attr_info[f'model_attr_positions_logit_index_{i}'] = detailed_attr_info_dict.get('model_attr_positions', None)

        captum_models_attr_info[f'{captum_model.__name__}_attr_info'] = model_attr_info

    return captum_models_attr_info

def get_all_info_qa(qa_model:QuestionAnswerer, captum_models, question:str, context:str, 
                    start_logit_indices:List[int], end_logit_indices:List[int], 
                    captum_models_kwargs:List[dict], attr_models_kwargs:List[dict]):
    """
    Collect plain and detailed attributions of a question-answering model for several attribution methods.

    For every attribution class, the start targets are attributed with 'start' set to True
    and the end targets with 'start' set to False.

    Args:
        qa_model (QuestionAnswerer): Object providing get_model_attribution and get_detailed_attribution.
        captum_models: The attribution classes to run; each must have a __name__.
        question (str): The question text.
        context (str): The context text.
        start_logit_indices (List[int]): Targets attributed against the start output.
        end_logit_indices (List[int]): Targets attributed against the end output.
        captum_models_kwargs (List[dict]): Constructor keyword arguments, one dict per attribution class.
        attr_models_kwargs (List[dict]): Attribution keyword arguments, one dict per attribution
            class. These dictionaries are not modified; 'start' is set on a copy.

    Returns:
        dict: Maps '<captum_model.__name__>_attr_info_start' and
        '<captum_model.__name__>_attr_info_end' to dictionaries that, for the i-th target,
        hold the keys 'attr_info_logit_index_i', 'model_attr_embeddings_logit_index_i',
        'detailed_attr_info_logit_index_i', 'model_attr_words_logit_index_i',
        'model_attr_tokens_logit_index_i' and 'model_attr_positions_logit_index_i'.
    """
    captum_models_attr_info = {}

    assert len(captum_models) == len(captum_models_kwargs) == len(attr_models_kwargs), f"""
    The length of captum_models, captum_models_kwargs, and attr_models_kwargs should be the same, 
    but got {len(captum_models)} and {len(captum_models_kwargs)} and {len(attr_models_kwargs)}
    """

    for captum_model, captum_model_kwargs, attr_model_kwargs in zip(captum_models, captum_models_kwargs, attr_models_kwargs):
        # Each pass gets its own copy carrying the 'start' flag, so the caller's
        # dictionary is not left with a stale 'start' entry afterwards.
        start_kwargs = {**attr_model_kwargs, 'start': True}
        start_model_attr_info = _collect_qa_attr_info(qa_model, captum_model, question, context,
                                                      start_logit_indices, captum_model_kwargs, start_kwargs)
        captum_models_attr_info[f'{captum_model.__name__}_attr_info_start'] = start_model_attr_info

        end_kwargs = {**attr_model_kwargs, 'start': False}
        end_model_attr_info = _collect_qa_attr_info(qa_model, captum_model, question, context,
                                                    end_logit_indices, captum_model_kwargs, end_kwargs)
        captum_models_attr_info[f'{captum_model.__name__}_attr_info_end'] = end_model_attr_info

    return captum_models_attr_info


def _collect_qa_attr_info(qa_model, captum_model, question, context, logit_indices,
                          captum_model_kwargs, attr_model_kwargs):
    """Run plain and detailed attribution for every target in logit_indices and gather the results in one dict."""
    model_attr_info = {}

    for i, logit_index in enumerate(logit_indices):
        qa_model_attr_info_dict, qa_model_attr_info = qa_model.get_model_attribution(
            question, context, logit_index, captum_model, captum_model_kwargs, attr_model_kwargs
        )
        model_attr_info[f"attr_info_logit_index_{i}"] = qa_model_attr_info
        model_attr_info[f'model_attr_embeddings_logit_index_{i}'] = qa_model_attr_info_dict.get('model_attr_embeddings', None)

        detailed_attr_info_dict, detailed_attr_info = qa_model.get_detailed_attribution(
            question, context, logit_index, captum_model, captum_model_kwargs, attr_model_kwargs
        )
        model_attr_info[f'detailed_attr_info_logit_index_{i}'] = detailed_attr_info
        model_attr_info[f'model_attr_words_logit_index_{i}'] = detailed_attr_info_dict.get('model_attr_words', None)
        model_attr_info[f'model_attr_tokens_logit_index_{i}'] = detailed_attr_info_dict.get('model_attr_tokens', None)
        model_attr_info[f'model_attr_positions_logit_index_{i}'] = detailed_attr_info_dict.get('model_attr_positions', None)

    return model_attr_info
