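"""Token-attribution helpers for sequence-classification, question-answering and zero-shot models, built on transformers_interpret explainers."""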
from typing import List
import numpy as np
from transformers_interpret import SequenceClassificationExplainer, QuestionAnsweringExplainer, ZeroShotClassificationExplainer


class SequenceClassifierAttributions:
    """
    A class that encapsulates the functionality of a sequence classifier model.
    It provides methods for tokenizing input, running the model, and calculating attributions.
    """
    def __init__(self, model, tokenizer, attribution_type="lig"):
        """
        Initializes a new instance of the SequenceClassifierAttributions class.

        Args:
            model (PreTrainedModel): The sequence-classification model to explain.
            tokenizer (PreTrainedTokenizer): The tokenizer matching ``model``.
            attribution_type (str, optional): Attribution method passed to
                ``SequenceClassificationExplainer``. Defaults to "lig".
        """
        self.model = model
        self.tokenizer = tokenizer
        self.attribution_type = attribution_type

    def get_sc_attributions(self, input_str: str, logits_indices: List[int]):
        """
        Computes per-token attributions of ``input_str`` towards each requested output logit.

        Args:
            input_str (str): The input text to explain.
            logits_indices (List[int]): Indices of the model output logits (classes) to
                compute attributions for, e.g. ``[0, 2]``. Any subset of the model's
                labels, in any order, is accepted.

        Returns:
            dict: A dictionary with the keys:
                - 'model_attributions_info': maps 'attributions_label_{k}' to the raw
                  attribution array for logit ``k``.
                - 'model_normalized_attributions_info': maps
                  'normalized_attributions_label_{k}' to the min-max normalized array
                  (see ``get_normalized_attribution``).
                - 'input_tokens': tokens produced by ``self.tokenizer`` for ``input_str``.
                - 'input_str', 'logits_indices', 'attribution_type': the inputs echoed back.
        """
        tokenized_inputs = self.tokenizer(input_str, return_tensors='pt')
        input_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])

        # transformers_interpret maps custom_labels[j] to output logit j (and requires
        # one label per model label). Naming each logit by its own index makes
        # class_name=str(k) select logit k. Deriving the labels from logits_indices
        # instead mislabels logits unless logits_indices is exactly 0..n-1 in order.
        num_labels = len(self.model.config.label2id)
        custom_labels = [str(i) for i in range(num_labels)]

        sc_explainer = SequenceClassificationExplainer(self.model, self.tokenizer, self.attribution_type, custom_labels=custom_labels)

        model_attributions_info = {}
        model_normalized_attributions_info = {}

        for logits_index in logits_indices:
            # The explainer returns a sequence of (token, score) pairs.
            token_attributions = sc_explainer(text=input_str, class_name=str(logits_index))
            input_attributions = np.array([token_attribution[1] for token_attribution in token_attributions])
            normalized_input_attributions = get_normalized_attribution(input_attributions)

            model_attributions_info[f'attributions_label_{logits_index}'] = input_attributions
            model_normalized_attributions_info[f'normalized_attributions_label_{logits_index}'] = normalized_input_attributions

        attribution_info = {
            'model_attributions_info': model_attributions_info,
            'model_normalized_attributions_info': model_normalized_attributions_info,
            'input_tokens': input_tokens,
            'input_str': input_str,
            'logits_indices': logits_indices,
            'attribution_type': self.attribution_type,
        }

        return attribution_info


class QuestionAnswererAttributions:
    """
    A class that encapsulates the functionality of a question answerer model.
    It provides methods for tokenizing input, running the model, and calculating attributions.
    """
    def __init__(self, model, tokenizer, attribution_type="lig"):
        """
        Initializes a new instance of the class.

        Args:
            model (object): The question-answering model to explain.
            tokenizer (object): The tokenizer matching ``model``.
            attribution_type (str, optional): Attribution method passed to
                ``QuestionAnsweringExplainer``. Defaults to "lig".

        Returns:
            None
        """
        self.model = model
        self.tokenizer = tokenizer
        self.attribution_type = attribution_type

    @staticmethod
    def _split_token_attributions(token_attributions):
        """Splits (token, score) pairs into (tokens, raw score array, normalized score array)."""
        tokens = [token_attribution[0] for token_attribution in token_attributions]
        scores = np.array([token_attribution[1] for token_attribution in token_attributions])
        return tokens, scores, get_normalized_attribution(scores)

    def get_qa_attributions(self, question, context):
        """
        Computes token attributions for the answer-start and answer-end predictions.

        Args:
            question (str): The question to be answered.
            context (str): The passage in which the answer is searched.

        Returns:
            dict: A dictionary with the keys:
                - 'question', 'context': the inputs echoed back.
                - 'input_tokens' (List[str]): tokens reported by the explainer
                  (taken from the start attributions).
                - 'model_attributions_info': {'attributions_start', 'attributions_end'}
                  raw attribution arrays.
                - 'model_normalized_attributions_info': same keys, min-max normalized
                  arrays (see ``get_normalized_attribution``).
                - 'attribution_type': the attribution method used.

        Raises:
            AssertionError: If the start and end attributions cover a different number of tokens.
        """
        qa_explainer = QuestionAnsweringExplainer(self.model, self.tokenizer, self.attribution_type)
        # Expected to be a mapping with "start" and "end" entries, each a sequence of
        # (token, score) pairs — presumably one pair per input token.
        token_attributions = qa_explainer(question, context)

        (input_tokens_start,
         input_attributions_start,
         normalized_input_attributions_start) = self._split_token_attributions(token_attributions.get("start"))

        (input_tokens_end,
         input_attributions_end,
         normalized_input_attributions_end) = self._split_token_attributions(token_attributions.get("end"))

        assert len(input_tokens_start) == len(input_tokens_end) , "The number of tokens in the start and end attributions should be the same"
        input_tokens = input_tokens_start

        model_attributions_info = {
            'attributions_start': input_attributions_start,
            'attributions_end': input_attributions_end
        }

        model_normalized_attributions_info = {
            'attributions_start': normalized_input_attributions_start,
            'attributions_end': normalized_input_attributions_end
        }
        attribution_info = {
            'question': question,
            'context': context,
            'input_tokens': input_tokens,
            'model_attributions_info': model_attributions_info,
            'model_normalized_attributions_info': model_normalized_attributions_info,
            'attribution_type': self.attribution_type
        }

        return attribution_info


class ZeroShotClassifierAttributions:
    """
    A class that encapsulates the functionality of a zero-shot classification (ZSC) model.
    It provides methods for tokenizing input, running the model, and calculating attributions.
    """
    def __init__(self, model, tokenizer, attribution_type="lig"):
        """
        Initializes a new instance of the class.

        Args:
            model (object): The zero-shot classification model to explain.
            tokenizer (object): The tokenizer matching ``model``.
            attribution_type (str, optional): Attribution method passed to
                ``ZeroShotClassificationExplainer``. Defaults to "lig".

        Returns:
            None
        """
        self.model = model
        self.tokenizer = tokenizer
        self.attribution_type = attribution_type

    def get_zsc_attributions(self, input_str: str, labels: List[str], logits_indices: List[int]):
        """
        Computes per-token attributions of ``input_str`` towards candidate labels of a
        zero-shot classifier.

        Args:
            input_str (str): The input text to explain.
            labels (List[str]): Candidate labels handed to the zero-shot explainer.
            logits_indices (List[int]): For each position ``i`` in this list, attributions
                are collected for ``labels[i]``; the list itself is stored as-is in the result.

        Returns:
            dict: A dictionary with the keys:
                - 'input_tokens' (List[str]): explainer tokens with the "▁" marker removed
                  (only present when ``logits_indices`` is non-empty).
                - 'model_attributions_info': maps 'attributions_label_{i}' to the raw
                  attribution array for ``labels[i]``.
                - 'model_normalized_attributions_info': maps
                  'normalized_attributions_label_{i}' to the min-max normalized array.
                - 'input_str', 'labels', 'logits_indices', 'attribution_type': inputs echoed back.
        """
        zsc_explainer = ZeroShotClassificationExplainer(self.model, self.tokenizer, self.attribution_type)

        attribution_info = {}

        model_attributions_info = {}
        model_normalized_attributions_info = {}

        # len() rather than truthiness so numpy arrays are accepted as logits_indices.
        if len(logits_indices) > 0:
            # One explainer call yields attributions for every label in `labels`;
            # compute it once instead of re-running it for each requested index.
            attributions_by_label = zsc_explainer(input_str, labels)

            for i in range(len(logits_indices)):
                token_attributions = attributions_by_label.get(labels[i])
                input_tokens = [token_attribution[0].replace("▁", "") for token_attribution in token_attributions]
                input_attributions = np.array([token_attribution[1] for token_attribution in token_attributions])
                normalized_input_attributions = get_normalized_attribution(input_attributions)

                if i == 0:
                    attribution_info["input_tokens"] = input_tokens

                # NOTE: keys use the loop position i, not the value logits_indices[i].
                model_attributions_info[f"attributions_label_{i}"] = input_attributions
                model_normalized_attributions_info[f"normalized_attributions_label_{i}"] = normalized_input_attributions

        attribution_info["model_attributions_info"] = model_attributions_info
        attribution_info["model_normalized_attributions_info"] = model_normalized_attributions_info
        attribution_info["input_str"] = input_str
        attribution_info["labels"] = labels
        attribution_info["logits_indices"] = logits_indices
        attribution_info['attribution_type'] = self.attribution_type

        return attribution_info



def get_normalized_attribution(input_attributions, epsilon=1e-8, use_absolute=False):
    """
    Min-max normalize per-token attribution scores.

    Each value ``x`` is mapped to ``(x - min + epsilon) / (max - min + epsilon)``, so the
    largest score becomes 1.0 and the smallest becomes a value just above 0.

    Parameters:
    - input_attributions: Attribution scores, one per token. A NumPy array is expected;
      lists and tuples are converted with ``np.asarray``.
    - epsilon: Small constant added to numerator and denominator so the division stays
      finite when all scores are equal (every value then normalizes to 1.0).
    - use_absolute: If True, normalize the magnitudes ``|x|`` instead of the signed scores.

    Returns:
    - normalized_input_attributions: A NumPy array of the same length holding the
      normalized scores. An empty input yields an empty float array instead of raising.
    """
    attributions = input_attributions
    if not isinstance(attributions, np.ndarray):
        attributions = np.asarray(attributions)
    if attributions.size == 0:
        # ndarray.min()/max() raise on zero-size arrays; there is nothing to normalize.
        return attributions.astype(float)
    if use_absolute:
        attributions = np.absolute(attributions)
    lowest = attributions.min()
    span = attributions.max() - lowest
    return (attributions - lowest + epsilon) / (span + epsilon)
