from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import evaluate
import numpy as np
import scipy as sp
import torch
import os


class metric_preference():
    """Preference metric comparing an original response with an infilled one.

    A T5 preference model is prompted with the post and both responses
    (the original as ``RESPONSE A``, the infilled one as ``RESPONSE B``) and
    asked to generate a single token.  The score is the difference between
    the probabilities assigned to the two vocabulary indices below.
    """

    # Vocabulary indices whose next-token probabilities are read as the
    # preference for RESPONSE A and RESPONSE B respectively.
    # NOTE(review): these ids belong to the tokenizer of the default
    # checkpoint -- confirm them before pointing model_path elsewhere.
    _TOKEN_ID_A = 71
    _TOKEN_ID_B = 272

    def __init__(self, model_path='stanfordnlp/SteamSHP-flan-t5-large', device='cuda'):
        """
        Initialize SHP metric object for generation.

        Args:
            model_path: Name or path handed to ``from_pretrained`` for both
                the tokenizer and the model.
            device: Device the model is moved to and on which the prompt
                token ids are placed.
        """
        self._tokenizer = T5Tokenizer.from_pretrained(model_path)
        self._model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)
        self._device = device

    def metric(self, input_text, output_text, input_text_infilled='', output_text_infilled='', input_label=0, info=False):
        """Return the preference gap between the original and infilled response.

        Args:
            input_text: Prompt shown to the preference model as the POST.
            output_text: Original response (RESPONSE A).
            input_text_infilled: Unused; kept so all metric classes share one
                signature.
            output_text_infilled: Response produced from the modified prompt
                (RESPONSE B).
            input_label: Unused; kept for signature parity.
            info: Unused; kept for signature parity.

        Returns:
            ``(score, label_contrast)`` where ``score`` is a Python float,
            ``P(A) - P(B)``, or ``0.0`` when both responses are identical,
            and ``label_contrast`` is always ``0``.
        """
        label_contrast = 0

        # Identical responses cannot be preferred over one another, so skip
        # the model call entirely.
        if output_text == output_text_infilled:
            return (0.0, label_contrast)

        prompt = (
            'POST: ' + input_text
            + '\n\n RESPONSE A: ' + output_text
            + '. \n\n RESPONSE B: ' + output_text_infilled
            + '. \n\n Which response is better? RESPONSE'
        )
        # remove special tokens that are due to model being explained
        for special_token in ('<pad>', '</s>'):
            prompt = prompt.replace(special_token, '')

        token_ids = self._tokenizer([prompt], return_tensors='pt').input_ids.to(self._device)
        generated = self._model.generate(
            token_ids,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=1,
        )

        # torch.softmax subtracts the row maximum internally, so large logits
        # no longer overflow to inf/inf = nan the way a manual
        # exp(x) / exp(x).sum() does.  Upcast in case the model runs in half
        # precision.
        probabilities = torch.softmax(generated.scores[0].float(), dim=-1)

        # A single prompt is scored, hence row 0.  The difference measures how
        # much the preference shifts for a response generated from the
        # changed prompt.
        score = (probabilities[0, self._TOKEN_ID_A] - probabilities[0, self._TOKEN_ID_B]).item()

        return (score, label_contrast)

class metric_nli():
    """NLI metric measuring how the predicted class changes between responses.

    The prompt is paired once with the original response and once with the
    infilled response; both pairs are classified by a sequence-classification
    model in a single batch.
    """

    # Class names printed in the ``info`` branch, indexed by logit position.
    _LABEL_NAMES = ('contradiction', 'entailment', 'neutral')

    def __init__(self, model_path='cross-encoder/nli-roberta-base', device='cuda'):
        """
        Initialize NLI metric object for generation.

        Args:
            model_path: Name or path handed to ``from_pretrained`` for both
                the tokenizer and the model.
            device: Device the model and its inputs are placed on.  When a
                CUDA device is requested but CUDA is unavailable, the CPU is
                used instead.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self._device = self._resolve_device(device)
        self._model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self._device)

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` as a ``torch.device``, using CPU if CUDA is missing."""
        resolved = torch.device(device)
        if resolved.type == 'cuda' and not torch.cuda.is_available():
            # Previously the device argument was ignored and everything ran
            # on CPU; keep CPU-only machines working with the default.
            return torch.device('cpu')
        return resolved

    def metric(self, input_text, output_text, input_text_infilled='', output_text_infilled='', input_label=0, info=False):
        """Score the change in NLI prediction between the two responses.

        Args:
            input_text: First sentence of both pairs.
            output_text: Original response (second sentence of pair 0).
            input_text_infilled: Unused; kept so all metric classes share one
                signature.
            output_text_infilled: Modified response (second sentence of
                pair 1).
            input_label: Unused; kept for signature parity.
            info: When truthy, print the predicted class of each pair and
                return a zero score instead of computing the metric.

        Returns:
            ``(score, label_contrast)``.  With ``info`` falsy, ``score`` is a
            Python float: the probability of pair 0's predicted class minus
            the probability of that same class for pair 1.
            ``label_contrast`` is always ``0``.
        """
        # run nli model with two outputs
        encoded = self._tokenizer(
            [input_text, input_text],
            [output_text, output_text_infilled],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        # Inputs must live on the same device as the model.
        encoded = {name: tensor.to(self._device) for name, tensor in encoded.items()}
        self._model.eval()

        with torch.no_grad():
            logits = self._model(**encoded).logits

        label_contrast = 0

        if info:
            predicted = [self._LABEL_NAMES[index] for index in logits.argmax(dim=1).tolist()]
            print('NLI initial prediction: ' + predicted[0])
            print('NLI modified prediction: ' + predicted[1])
            return (0.0, label_contrast)

        # Softmax in torch works wherever the logits live; .item() then
        # yields a plain Python float like the other metric classes return.
        probabilities = torch.softmax(logits.float(), dim=1)
        initial_class = int(probabilities[0].argmax())
        # change in class
        score = (probabilities[0, initial_class] - probabilities[1, initial_class]).item()

        return (score, label_contrast)

class metric_contradiction():
    """NLI metric measuring how much the infilled response contradicts the original.

    The original response is paired with itself and with the infilled
    response; both pairs are classified by a sequence-classification model in
    a single batch and the contradiction probabilities are compared.
    """

    # Class names printed in the ``info`` branch, indexed by logit position.
    _LABEL_NAMES = ('contradiction', 'entailment', 'neutral')
    # Logit position read as the contradiction class.
    _CONTRADICTION_INDEX = 0

    def __init__(self, model_path='cross-encoder/nli-roberta-base', device='cuda'):
        """
        Initialize NLI metric object for generation.

        Args:
            model_path: Name or path handed to ``from_pretrained`` for both
                the tokenizer and the model.
            device: Device the model and its inputs are placed on.  When a
                CUDA device is requested but CUDA is unavailable, the CPU is
                used instead.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self._device = self._resolve_device(device)
        self._model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self._device)

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` as a ``torch.device``, using CPU if CUDA is missing."""
        resolved = torch.device(device)
        if resolved.type == 'cuda' and not torch.cuda.is_available():
            # Previously the device argument was ignored and everything ran
            # on CPU; keep CPU-only machines working with the default.
            return torch.device('cpu')
        return resolved

    def metric(self, input_text, output_text, input_text_infilled='', output_text_infilled='', input_label=0, info=False):
        """Score the change in contradiction probability between the two pairs.

        Args:
            input_text: Unused; kept so all metric classes share one
                signature.
            output_text: Original response; first sentence of both pairs and
                second sentence of pair 0.
            input_text_infilled: Unused; kept for signature parity.
            output_text_infilled: Modified response (second sentence of
                pair 1).
            input_label: Unused; kept for signature parity.
            info: When truthy, print the predicted class of each pair and
                return a zero score instead of computing the metric.

        Returns:
            ``(score, label_contrast)``.  With ``info`` falsy, ``score`` is a
            Python float: the contradiction probability of pair 1 minus that
            of pair 0.  ``label_contrast`` is always ``0``.
        """
        # run nli model with two outputs
        encoded = self._tokenizer(
            [output_text, output_text],
            [output_text, output_text_infilled],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        # Inputs must live on the same device as the model.
        encoded = {name: tensor.to(self._device) for name, tensor in encoded.items()}
        self._model.eval()

        with torch.no_grad():
            logits = self._model(**encoded).logits

        label_contrast = 0

        if info:
            predicted = [self._LABEL_NAMES[index] for index in logits.argmax(dim=1).tolist()]
            print('NLI initial prediction: ' + predicted[0])
            print('NLI modified prediction: ' + predicted[1])
            return (0.0, label_contrast)

        # Softmax in torch works wherever the logits live; .item() then
        # yields a plain Python float like the other metric classes return.
        probabilities = torch.softmax(logits.float(), dim=1)
        column = self._CONTRADICTION_INDEX
        # change in contradiction class
        score = (probabilities[1, column] - probabilities[0, column]).item()

        return (score, label_contrast)

class metric_bleu():
    """BLEU-based metric rewarding a changed response and an unchanged prompt.

    Two BLEU scores are combined: the dissimilarity between the original and
    infilled responses, and the similarity between the original and infilled
    prompts.
    """

    # Weights of the response-dissimilarity and prompt-similarity terms.
    _RESPONSE_WEIGHT = 0.8
    _PROMPT_WEIGHT = 0.2

    def __init__(self, model_path='cross-encoder/nli-roberta-base', device='cuda', experiment_id='id'):
        """
        Initialize bleu metric object.

        Args:
            model_path: Unused; kept so all metric classes share one
                constructor signature.
            device: Stored on the instance; not used by the BLEU computation.
            experiment_id: Forwarded to ``evaluate.load``.
        """
        self._bleu = evaluate.load("bleu", experiment_id=experiment_id)
        self._device = device

    def _similarity(self, prediction, reference):
        """Return the BLEU score of ``prediction`` against ``reference``."""
        has_prediction = bool(prediction.strip())
        has_reference = bool(reference.strip())
        if not (has_prediction and has_reference):
            # BLEU is undefined when either side has no tokens.
            # NOTE(review): the reference BLEU implementation is expected to
            # raise ZeroDivisionError for an empty side -- confirm.  Two empty
            # texts are treated as identical, one empty text as no overlap.
            return 1.0 if has_prediction == has_reference else 0.0
        result = self._bleu.compute(predictions=[prediction], references=[[reference]])
        return result['bleu']

    def metric(self, input_text, output_text, input_text_infilled='', output_text_infilled='', input_label=0, info=False):
        """Combine response dissimilarity and prompt similarity into one score.

        Args:
            input_text: Original prompt.
            output_text: Original response.
            input_text_infilled: Modified prompt, used as BLEU reference for
                ``input_text``.
            output_text_infilled: Modified response, used as BLEU reference
                for ``output_text``.
            input_label: Unused; kept so all metric classes share one
                signature.
            info: Unused; kept for signature parity.

        Returns:
            ``(score, label_contrast)`` where ``score`` is
            ``0.8 * (1 - BLEU(responses)) + 0.2 * BLEU(prompts)`` and
            ``label_contrast`` is always ``0``.
        """
        # subtract from 1. because we want higher score to mean more dissimilar
        response_dissimilarity = 1. - self._similarity(output_text, output_text_infilled)
        # we want prompt to change as little as possible
        prompt_similarity = self._similarity(input_text, input_text_infilled)

        # compute weighting of both bleu scores
        score = (
            self._RESPONSE_WEIGHT * response_dissimilarity
            + self._PROMPT_WEIGHT * prompt_similarity
        )
        label_contrast = 0

        return (score, label_contrast)