from cell.algorithms.lbbe import LocalBBExplainer
from cell.algorithms.infilling_utils import BART_infiller, T5_infiller
from cell.algorithms.metrics_cem import metric_preference, metric_nli, metric_contradiction, metric_bleu

import numpy as np
import scipy as sp
import torch
import re
import os
import evaluate
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class CELL(LocalBBExplainer):
    """Contrastive explainer (CELL) for black-box text models.

    The input prompt is split into groups of words. Each candidate group is
    masked and re-filled by an infilling model, and the explained model is
    queried on every perturbed prompt. A metric scores how much the response
    (or classification) changed.

    The search is greedy. Each iteration permanently keeps the single
    modification with the largest change, or the smallest change in
    input-reduction mode. It stops when a contrast is found, progress stalls,
    or every group has been modified.
    """

    def __init__(self, model, infiller='bart', num_return_sequences=1, metric='shp', metric_type='distance', metric_path=None, generation=True, experiment_id='id'):
        """Initialize contrastive explainer.

        Args:
            model: model that we want to explain. Must have a ``model.generate``
                method that takes a str and returns a str.
            infiller (str): mask-infilling model, either 'bart' or 't5'.
            num_return_sequences (int): number of sequences returned when doing
                generation for mask infilling.
            metric (str): metric used to decide whether a contrast is found. One of
                'shp', 'nli', 'contradiction', 'bleu', 'implicit_hate', 'stigma'.
            metric_type (str): 'distance' for explaining LLM generation using
                distances, 'classifier' for explaining a classifier.
            metric_path (str): path for loading a metric model. Only used by the
                'implicit_hate' and 'stigma' metrics.
            generation (bool): True if the model being explained performs true
                generation (as opposed to having output == input).
            experiment_id (str): passed to the 'bleu' metric. This is used if several
                distributed evaluations share the same file system.

        Raises:
            Exception: if ``infiller`` is not recognized.
            ValueError: if ``metric`` is not recognized.
        """
        self._model = model
        self._num_return_sequences = num_return_sequences

        if infiller == 'bart':
            self.infiller = BART_infiller.BART_infiller()
        elif infiller == 't5':
            self.infiller = T5_infiller.T5_infiller()
        else:
            raise Exception("CELL received parameter value for infiller that is not recognized")

        self._metric_name = metric
        self._metric_type = metric_type
        self._metric_func = self._build_metric(metric, metric_path, experiment_id)
        self._generation = generation

    @staticmethod
    def _build_metric(metric, metric_path, experiment_id):
        """Instantiate and return the metric object named by ``metric``."""
        if metric == 'shp':
            return metric_preference()
        if metric == 'nli':
            return metric_nli()
        if metric == 'contradiction':
            return metric_contradiction()
        if metric == 'bleu':
            return metric_bleu(experiment_id=experiment_id)
        # These two classes were referenced but never imported at module level,
        # which made selecting them fail with NameError. They are imported lazily
        # from the module that provides the other metrics.
        # NOTE(review): assumes both classes live in metrics_cem -- confirm.
        if metric == 'implicit_hate':
            from cell.algorithms.metrics_cem import metric_implicit_hate as metric_cls
        elif metric == 'stigma':
            from cell.algorithms.metrics_cem import metric_stigma as metric_cls
        else:
            raise ValueError("CELL received parameter value for metric that is not recognized: " + repr(metric))
        if metric_path is None:
            return metric_cls()
        return metric_cls(model_path=metric_path)

    @staticmethod
    def _group_words(sentence, k):
        """Split ``sentence`` on single spaces and join consecutive runs of ``k`` words."""
        words = sentence.split(' ')
        return [' '.join(words[i: i + k]) for i in range(0, len(words), k)]

    def splitTextByK(self, str, k):
        """Split text into groups of at most ``k`` words.

        The text is first cut into sentences, each ending at one of ``.!?;``.
        Each stripped sentence is split on single spaces and chunked into groups
        of ``k`` words. Text after the last punctuation mark is kept as a final
        sentence; it is not dropped. Text with no punctuation at all is treated
        as one sentence. The result always holds at least one element, so empty
        input yields ``['']``.

        Args:
            str (str): text to split. This name shadows the builtin but is kept
                for backward compatibility with keyword callers.
            k (int): maximum number of words per group. Must be >= 1.

        Returns:
            list[str]: word groups, in order of appearance.

        Raises:
            ValueError: if ``k`` < 1.
        """
        if k < 1:
            raise ValueError('split_k must be a positive integer, got ' + repr(k))
        text = str
        grouped_words = []
        start = 0
        for match in re.finditer(r"[.!?;]", text):
            end = match.end()
            grouped_words.extend(self._group_words(text[start:end].strip(), k))
            start = end
        tail = text[start:].strip()
        # Always emit something when no punctuation was found, so callers get at
        # least one token to mask; otherwise only emit a non-empty trailing sentence.
        if tail or start == 0:
            grouped_words.extend(self._group_words(tail, k))
        return grouped_words

    def _score_candidates(self, input_text, output_text, input_tokens_curr, inds_modify, scores_input_text, label_input_text):
        """Mask and infill each candidate token, query the model, and score the change.

        Args:
            input_text (str): original prompt.
            output_text (str): model response to the original prompt.
            input_tokens_curr (list[str]): current (partially modified) tokens. Not mutated.
            inds_modify (array of int): token indices still eligible for modification.
            scores_input_text: metric score of the original input (classifier mode).
            label_input_text: label of the original input, passed to the metric.

        Returns:
            tuple: ``(scores, scores_abs, labels_contrast, responses, masks_filled)``.
                ``scores`` and ``scores_abs`` are ``(len(inds_modify), 1)`` arrays.
                ``labels_contrast`` is a ``(len(inds_modify),)`` array. ``responses``
                and ``masks_filled`` are lists aligned with ``inds_modify``.
        """
        num_input_modify = len(inds_modify)
        scores = np.zeros((num_input_modify, 1))
        scores_abs = np.zeros((num_input_modify, 1))
        labels_contrast = np.zeros((num_input_modify,))  # for classification tasks
        responses = []
        masks_filled = []
        for i, ind in enumerate(inds_modify):
            input_tokens_mask = input_tokens_curr.copy()
            input_tokens_mask[ind] = self.infiller.mask_string
            input_text_mask = ' '.join(input_tokens_mask)

            batch = self.infiller.encode(input_text_mask, add_special_tokens=True)
            (generated_ids, mask_filled) = self.infiller.generate(batch, masked_word=input_tokens_curr[ind], num_return_sequences=self._num_return_sequences, return_mask_filled=True)
            input_text_infilled = self.infiller.decode(generated_ids)

            output_infilled_text = self._model.generate(input_text_infilled)  # output from modified prompt
            responses.append(output_infilled_text)
            masks_filled.append(mask_filled)

            (score_temp, label_temp) = self._metric_func.metric(input_text, output_text, input_text_infilled, output_infilled_text, input_label=label_input_text)
            scores[i] = score_temp
            labels_contrast[i] = label_temp
            if self._metric_type == 'distance':
                scores_abs[i] = np.abs(scores[i])  # measure the absolute difference
            else:
                # A classifier always measures the signed difference from the input text score.
                scores[i] = scores_input_text - scores[i]
                scores_abs[i] = scores[i]
        return scores, scores_abs, labels_contrast, responses, masks_filled

    def _print_solution(self, input_text, output_text, prompt_contrastive, contrast_response, contrast_score, contrast_label, label_input_text, modifications_optimal, num_model_calls, ir):
        """Print a human-readable summary of the explanation found."""
        print(str(num_model_calls) + ' model calls made.')
        print('Input Reduction Solution' if ir else 'Contrastive Explanation Solution')
        print('Metric: ' + self._metric_name)
        print('Input prompt: ' + input_text)
        if self._generation:
            print('Input response: ' + output_text)
        print('Contrastive prompt: ' + prompt_contrastive)
        if self._generation:
            print('Contrastive response: ' + contrast_response)
        print('Modifications made: ' + ', '.join(modifications_optimal))
        if self._metric_name == 'shp':
            if contrast_score > 0:
                print('Preference decreased.')
            elif contrast_score < 0:
                print('Preference increased.')
            else:
                print('Preference remained the same.')
        elif self._metric_name == 'nli' or self._metric_name == 'contradiction':
            # The metric prints its own details when info=True; its return value is not needed.
            self._metric_func.metric(input_text, output_text, input_text, contrast_response, input_label=label_input_text, info=True)
        elif self._metric_name == 'bleu':
            print('BLEU score of difference in responses is larger than threshold.')
        elif self._metric_name == 'implicit_hate' or self._metric_name == 'stigma':
            id2label = self._metric_func._model.config.id2label
            print('Initial label: ' + id2label[int(label_input_text)])
            print('Contrast label: ' + id2label[int(contrast_label)])
        else:
            print('INVALID METRIC')

    def explain_instance(self, input_text, epsilon_contrastive=.5, epsilon_iter=.001, split_k=1, no_change_max_iters=3, info=True, ir=False):
        """Explain the model's behaviour on prompt ``input_text``.

        Provide a contrastive explanation by changing ``input_text`` so that the
        new prompt yields a response that differs from the original response by
        more than a given amount. For classifiers, the contrast is a changed
        label. With ``ir=True``, perform input reduction instead: repeatedly apply
        the modification that changes the response least, until the change
        exceeds ``epsilon_contrastive``.

        Args:
            input_text (str): input prompt to the model we want to explain.
            epsilon_contrastive (float): amount of change in response that counts
                as a contrastive explanation.
            epsilon_iter (float): minimum change in the best score between
                iterations to count as progress.
            split_k (int): number of words grouped into each token that is masked together.
            no_change_max_iters (int): stop after this many consecutive iterations
                without progress. Not used in input-reduction mode.
            info (bool): True to print progress and summary information.
            ir (bool): True to do input reduction, i.e., modify tokens that cause
                minimal change to the response until a large change occurs.

        Returns:
            tuple: ``(prompt_contrastive, input_tokens_curr, mask_order, masks_optimal)``.
                ``prompt_contrastive`` is the modified prompt. ``input_tokens_curr``
                is its token list. ``mask_order`` gives the token indices in the
                order they were modified. ``masks_optimal`` holds the infilled
                strings in that same order.

        Raises:
            ValueError: if ``metric_type`` is 'classifier' but the metric is not a classifier metric.
        """
        if info:
            print('Starting Input Reduction' if ir else 'Starting Contrastive Explanation Method')

        if self._metric_type == 'classifier':
            if self._metric_name not in ('implicit_hate', 'stigma'):
                raise ValueError('INVALID METRIC FOR CLASSIFICATION TASK: ' + repr(self._metric_name))
            (scores_input_text, label_input_text) = self._metric_func.metric(input_text, input_text, input_text, input_text, -1)
        else:
            scores_input_text = 0
            label_input_text = -1

        output_text = self._model.generate(input_text)  # output from input text prompt

        input_tokens = self.splitTextByK(input_text, split_k)
        num_input_tokens = len(input_tokens)

        tokens_changed = np.zeros((num_input_tokens, 1))  # which tokens have been modified
        input_tokens_curr = input_tokens.copy()
        iters = 0
        scores_max_prev = 0.0
        count_no_change = 0
        mask_order = []  # order in which tokens were masked
        masks_optimal = []  # infilled strings that replaced the masked tokens
        modifications_optimal = []  # human-readable 'old->new' modifications
        num_model_calls = 0  # counts only calls on perturbed prompts
        # Response/score/label of the currently accepted prompt. These advance only
        # when a modification is kept, so the summary stays consistent with the
        # returned prompt when input reduction reverts its final modification.
        contrast_response = output_text
        contrast_score = 0.0
        contrast_label = label_input_text

        modify_token = True
        while modify_token:
            if info:
                print('Running iteration ' + str(iters + 1))
            inds_modify = np.where(tokens_changed == 0)[0]  # tokens not yet modified

            (scores, scores_abs, labels_contrast, responses_modified, masks_filled) = self._score_candidates(
                input_text, output_text, input_tokens_curr, inds_modify, scores_input_text, label_input_text)
            num_model_calls += len(inds_modify)

            inds_max = int(np.argmin(scores_abs) if ir else np.argmax(scores_abs))
            scores_max = float(scores_abs[inds_max, 0])
            ind_token = inds_modify[inds_max]
            mask_filled = masks_filled[inds_max]

            # Tentatively apply the best modification.
            tokens_changed[ind_token] = 1
            mask_order.append(ind_token)
            token_to_modify = input_tokens_curr[ind_token]
            modifications_optimal.append(token_to_modify + '->' + mask_filled)
            input_tokens_curr[ind_token] = mask_filled
            masks_optimal.append(mask_filled)
            accepted = True

            if ir:
                if scores_max > epsilon_contrastive and iters < (num_input_tokens - 1):
                    modify_token = False
                    accepted = False
                    # Remove the modification that pushed the change over the threshold.
                    input_tokens_curr[ind_token] = token_to_modify
                    mask_order.pop()
                    masks_optimal.pop()
                    modifications_optimal.pop()
            else:
                if np.abs(scores_max - scores_max_prev) <= epsilon_iter:
                    count_no_change += 1
                else:
                    count_no_change = 0
                if self._metric_type == 'classifier':
                    contrast_found = labels_contrast[inds_max] != label_input_text
                    found_msg = 'Stopping because initial classification has changed'
                else:  # metric_type is distance
                    contrast_found = scores_max > epsilon_contrastive
                    found_msg = 'Stopping because contrastive threshold has been passed'
                if contrast_found or count_no_change >= no_change_max_iters:
                    modify_token = False
                    if info:
                        if contrast_found:
                            print(found_msg)
                        else:
                            print('Stopping because no significant change has occurred in ' + str(no_change_max_iters) + ' iterations.')

            if accepted:
                contrast_response = responses_modified[inds_max]
                contrast_score = float(scores[inds_max, 0])
                contrast_label = labels_contrast[inds_max]

            if iters >= (num_input_tokens - 1):
                modify_token = False
                if info:
                    print('Modified all tokens.')
            scores_max_prev = scores_max
            iters += 1

        prompt_contrastive = ' '.join(input_tokens_curr)

        if info:
            self._print_solution(input_text, output_text, prompt_contrastive, contrast_response, contrast_score,
                                 contrast_label, label_input_text, modifications_optimal, num_model_calls, ir)

        return prompt_contrastive, input_tokens_curr, mask_order, masks_optimal