from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import Dataset
import numpy as np
from tqdm import tqdm
import os
from openai import AzureOpenAI, AsyncAzureOpenAI
from sklearn.metrics.pairwise import cosine_similarity
from tabulate import tabulate
from sentence_transformers import SentenceTransformer
import openai
from azure.identity import CredentialUnavailableError
import asyncio
from openai import RateLimitError
from openai import APITimeoutError, BadRequestError, AuthenticationError
import backoff
import logging
import os
import torch
import string


# Switch off parallelism in the tokenizers library via its environment flag.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# High-level response attributes that are interpolated into the LLM prompts.
attribute_list = [
    "helpfulness",
    "relevance",
    "correctness",
    "coherence",
    "complexity",
    "verbosity",
    "neutrality",
    "appropriateness",
    "assertiveness",
    "harmlessness",
    "sensitivity",
    "engagement",
    "avoid to answer",
    "informativeness",
    "clarity",
]

# One-sentence description of every attribute above (same names, same order).
attribute_descriptions = dict(
    [
        ("helpfulness", "whether or not the response addresses the points raised in the question"),
        ("relevance", "whether or not the response is in a relevant context as in the question"),
        ("correctness", "whether or not the response is factually correct"),
        ("coherence", "whether or not the contents in the response is self-contained and clear"),
        ("complexity", "the intellectual burden required by a person to understand this response"),
        ("verbosity", "how many relevant details are included in the response"),
        ("neutrality", "whether or not the response is neutral and is without biases towards certain groups"),
        ("appropriateness", "the extent to which the response is appropriate in terms of language style, politeness, and whether it contains any sarcasm"),
        ("assertiveness", "the extent to which the response sounds very certain and contains judgements"),
        ("harmlessness", "whether or not the response is relevant to any potentially unsafe, immoral or illegal behaviours"),
        ("sensitivity", "whether or not the response is relevant to any personal, sensitive, or private information"),
        ("engagement", "the extent to which the language style of the response is trying to engage with the person who wrote the question"),
        ("avoid to answer", "whether or not the response is avoiding to give direct answers to the question"),
        ("informativeness", "whether or not the response provides informative knowledge"),
        ("clarity", "whether or not the response is clear and easy to read"),
    ]
)

# NOTE(review): anonymised placeholder. As written this is a plain string, so the
# `.get_token()` call below would raise AttributeError; presumably the real value is
# a token-provider object exposing `get_token()` — confirm before running.
refreshing_token = "ANONYMISED"

# Module-level synchronous Azure OpenAI client, created at import time.
# Endpoint and API version are anonymised placeholders.
client = AzureOpenAI(
    azure_endpoint="ANONYMISED",
    api_key=refreshing_token.get_token(),
    api_version="ANONYMISED",
)
# Deployment/model name. `model` and `gptmodel` hold the same value;
# `gptmodel` is the one read by generate_for_one_test_set_pj when building file names.
model = "gpt-4o-2024-05-13"
gptmodel = "gpt-4o-2024-05-13"


class TestSet:
    """Sampled preference comparisons together with a reward model's scores.

    The constructor loads a dataset file, keeps single-turn examples, samples
    (or reloads) a fixed set of test indices and scores every comparison with
    the reward model ``rm``.  Two dataset layouts are recognised from
    ``data_name`` (case-insensitive substring match):

    * ``"helpsteer"`` - consecutive rows ``(2k, 2k + 1)`` form one comparison
      and carry integer ratings for several attributes;
    * ``"rlhf"`` - every row holds ``chosen`` / ``rejected`` transcripts.

    After construction these parallel lists hold one entry per comparison:
    ``questions``, ``pref_anss`` / ``rej_anss`` (the answer the *model* scored
    higher / lower), ``pref_scores`` / ``rej_scores`` (the matching model
    scores) and ``corrects`` (whether the model agrees with the dataset).
    """

    def __init__(self, data_path, data_name, rm, tok, num_test_points=100, load_file=True, run_num=10, device="cuda", model_type="deberta"):
        """Build the test set and score it.

        Args:
            data_path: path handed to ``Dataset.from_file``.
            data_name: dataset name; selects the helpsteer / rlhf processing
                and is part of the cached index file name.
            rm: reward model, called as ``rm(**inputs)``; ``.logits`` is read.
            tok: tokenizer matching ``rm``.
            num_test_points: number of comparisons to sample (reduced when the
                dataset is too small).
            load_file: reuse ``testsets/{data_name}_run_{run_num}.npy`` when
                that file exists instead of sampling new indices.
            run_num: run identifier used in the cached index file name.
            device: device the tokenized inputs are moved to.
            model_type: ``"deberta"`` or ``"pythia"`` style input formatting.
        """
        self.d = Dataset.from_file(data_path)
        dataset_length = self.d.num_rows
        self.data_name = data_name
        self.device = device
        self.model_type = model_type
        single_idxs = []
        test_name = f"testsets/{data_name}_run_{run_num}.npy"
        is_helpsteer = "helpsteer" in data_name.lower()
        is_rlhf = "rlhf" in data_name.lower()
        score_attr = ['helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity']

        # process helpsteer dataset, filter for single turn conversations where response preference is clear
        if is_helpsteer:
            # A prompt mentioning both speaker tags is treated as multi-turn and dropped.
            single_idxs = []
            for i in range(self.d.num_rows):
                prompt = self.d[i]['prompt']
                if not ("Assistant" in prompt and "User" in prompt):
                    single_idxs.append(i)
            self.d = self.d.select(single_idxs)

            # Keep a pair only when the rating totals differ and the response with
            # the larger total is not rated lower on any single attribute.
            clear_idxs = []
            for i in range(0, self.d.num_rows - 1, 2):
                snp1 = self._ratings(self.d[i], score_attr)
                snp2 = self._ratings(self.d[i + 1], score_attr)
                total1, total2 = np.sum(snp1), np.sum(snp2)
                if total1 == total2:
                    continue
                better, worse = (snp1, snp2) if total1 > total2 else (snp2, snp1)
                if not np.any(better < worse):
                    clear_idxs.extend((i, i + 1))
            self.d = self.d.select(clear_idxs)

            num_idxs_max = int(self.d.num_rows / 2) - 1
            if num_idxs_max < num_test_points:
                num_test_points = num_idxs_max - 1

            # randomly select single turn test points
            if load_file and os.path.isfile(test_name):
                single_idxs_test = np.load(test_name)
                self.d = self.d.select(single_idxs_test)
            else:
                pair_ids = np.random.choice(np.arange(num_idxs_max), num_test_points, replace=False)
                single_idxs_test = []
                for n in pair_ids:
                    # pair n occupies rows 2n and 2n + 1
                    single_idxs_test.extend((n * 2, n * 2 + 1))
                self.d = self.d.select(single_idxs_test)
                self._save_indices(test_name, np.array(single_idxs_test))

        if is_rlhf:
            for i in range(dataset_length):
                sample_text = self.d[i]["chosen"]
                # exactly one "n: " and one "t: " is used as the single-turn test
                # (presumably matching "Human: " / "Assistant: " - verify against the data)
                if sample_text.count("n: ") == 1 and sample_text.count("t: ") == 1:
                    single_idxs.append(i)

            if len(single_idxs) < num_test_points:
                num_test_points = len(single_idxs) - 1

            # randomly select single turn test points
            if load_file and os.path.isfile(test_name):
                single_idxs_test = np.load(test_name)
                self.d = self.d.select(single_idxs_test)
                print("Loading test sets...")
            else:
                single_idxs_test = np.random.choice(single_idxs, num_test_points, replace=False)
                self.d = self.d.select(single_idxs_test)
                self._save_indices(test_name, single_idxs_test)

                print("Constructing test sets...")

        # extract questions, answers, and model scores
        self.questions = []
        self.pref_anss = []
        self.rej_anss = []
        self.pref_scores = []
        self.rej_scores = []
        self.corrects = []

        # extract components for each comparison in the test sets
        if is_rlhf:
            for i in tqdm(range(self.d.num_rows)):
                question, preferred_ans, rej_ans, score_model_pref, score_model_rej, correct_ans = self.get_q_a_s(i, rm, tok)
                self.questions.append(question)
                self.pref_anss.append(preferred_ans)
                self.rej_anss.append(rej_ans)
                self.pref_scores.append(score_model_pref)
                self.rej_scores.append(score_model_rej)
                self.corrects.append(correct_ans)

        if is_helpsteer:
            for i in tqdm(range(self.d.num_rows)):
                if i % 2 == 1:
                    continue
                row1, row2 = self.d[i], self.d[i + 1]
                # dataset score
                total1 = np.sum(self._ratings(row1, score_attr))
                total2 = np.sum(self._ratings(row2, score_attr))

                # model score (both responses are scored against the first row's prompt)
                score1 = get_model_score(row1["prompt"], row1["response"], rm, tok, self.device, self.model_type)
                score2 = get_model_score(row1["prompt"], row2["response"], rm, tok, self.device, self.model_type)
                if score1 >= score2:
                    self.corrects.append(bool(total1 > total2))
                    self.pref_anss.append(row1["response"])
                    self.pref_scores.append(score1)
                    self.rej_anss.append(row2["response"])
                    self.rej_scores.append(score2)
                else:
                    self.corrects.append(bool(total1 < total2))
                    self.pref_anss.append(row2["response"])
                    self.pref_scores.append(score2)
                    self.rej_anss.append(row1["response"])
                    self.rej_scores.append(score1)
                self.questions.append(row1["prompt"])
        torch.cuda.empty_cache()

    @staticmethod
    def _ratings(row, attrs):
        """Return the integer ratings of ``row`` for ``attrs`` as a numpy array."""
        return np.array([int(row[attr]) for attr in attrs])

    @staticmethod
    def _save_indices(path, indices):
        """Save sampled test indices to ``path``, creating its directory if needed."""
        folder = os.path.dirname(path)
        if folder:
            os.makedirs(folder, exist_ok=True)
        np.save(path, indices)

    def get_q_a_s(self, idx, rank_model, tokenizer):
        """Score the chosen/rejected pair stored in row ``idx``.

        The question and answers are cut out of the ``chosen`` / ``rejected``
        transcripts at the first ``"Human:"`` and ``"Assistant:"`` markers.

        Returns:
            ``(question, preferred_ans, rej_ans, pref_score, rej_score,
            correct_ans)`` where *preferred* is the answer the model scored
            strictly higher; ``correct_ans`` is False when that is the
            dataset's rejected answer (ties count as incorrect).

        Raises:
            ValueError: if ``self.model_type`` names neither "deberta" nor "pythia".
        """
        row = self.d[idx]
        split_response = row["chosen"].split("Assistant:", 1)
        preferred_ans = split_response[1].strip()
        # NOTE(review): [:-2] drops two trailing characters before strip();
        # presumably the blank line preceding "Assistant:" - confirm against the data.
        question = split_response[0].split("Human:", 1)[1][:-2].strip()
        rej_ans = row["rejected"].split("Assistant:", 1)[1].strip()

        # "pythia" is checked first so that it wins when both substrings occur,
        # as it did with the original pair of independent `if` statements.
        if "pythia" in self.model_type:
            xy_pref = f"<|prompter|>{question}<|endoftext|><|assistant|>{preferred_ans}<|endoftext|>"
            xy_rej = f"<|prompter|>{question}<|endoftext|><|assistant|>{rej_ans}<|endoftext|>"
            input_dataset_pref = tokenizer(xy_pref, return_tensors='pt').to(self.device)
            input_dataset_rej = tokenizer(xy_rej, return_tensors='pt').to(self.device)
        elif "deberta" in self.model_type:
            input_dataset_pref = tokenizer(question, preferred_ans, return_tensors='pt').to(self.device)
            input_dataset_rej = tokenizer(question, rej_ans, return_tensors='pt').to(self.device)
        else:
            raise ValueError(f"unsupported model_type: {self.model_type!r}")

        # Pure inference: no autograd graph is needed for scoring.
        with torch.no_grad():
            score_dataset_pref = rank_model(**input_dataset_pref).logits[0].cpu().detach()
            score_dataset_rej = rank_model(**input_dataset_rej).logits[0].cpu().detach()
        input_dataset_pref.to("cpu")
        input_dataset_rej.to("cpu")

        # if the model ends up preferring the rejected answer, swap name of two answers.
        correct_ans = True
        if score_dataset_pref <= score_dataset_rej:
            preferred_ans, rej_ans = rej_ans, preferred_ans
            score_dataset_pref, score_dataset_rej = score_dataset_rej, score_dataset_pref
            correct_ans = False
        return question, preferred_ans, rej_ans, score_dataset_pref, score_dataset_rej, correct_ans


def get_model_score(x, y, rm, tok, device, model_type="deberta"):
    """Return the reward model's score for response ``y`` to prompt ``x``.

    Args:
        x: prompt text.
        y: response text.
        rm: reward model, called as ``rm(**inputs)``; ``.logits`` is read.
        tok: tokenizer matching ``rm``.
        device: device the tokenized inputs are moved to for the forward pass.
        model_type: a name containing "deberta" encodes ``(x, y)`` as a text
            pair; anything else (and any name containing "pythia") uses the
            ``<|prompter|>...<|assistant|>...`` single-string format.

    Returns:
        The first logit of the first batch element, on CPU and detached.
    """
    # Substring test so that names such as "deberta-v3" take the same branch
    # here as in TestSet.get_q_a_s; "pythia" keeps priority as it does there.
    pair_encoding = "deberta" in model_type and "pythia" not in model_type
    if pair_encoding:
        inputs = tok(x, y, return_tensors='pt')
    else:
        xy = f"<|prompter|>{x}<|endoftext|><|assistant|>{y}<|endoftext|>"
        inputs = tok(xy, return_tensors='pt')
    inputs = inputs.to(device)
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        score = rm(**inputs).logits[0].cpu().detach()[0]
    inputs.to("cpu")
    return score


async def query_for_attributes(question, response1, response2, score1, score2, client, refreshing_token):
    """Ask the LLM which high-level attributes may explain the score gap.

    The prompt shows the question, both responses and their model scores and
    asks for a comma-separated attribute list explaining why response A scored
    "better" (``score1 > score2``) or "worse" (otherwise) than response B.

    Returns:
        The reply text from ``request_llm``, or "" when that reply is None.
    """
    if score1 > score2:
        direction1 = "better"
    else:
        direction1 = "worse"
    prompt = f"In the task of response quality scoring, a trained deep learning model assigns real-valued scores for responses to questions. \
    The higher the score, the better the response quality. \
    The question is '{question}'. The model assigned a score {score1} for response A: '{response1}'. The model assigned a score {score2} for response B: '{response2}' \
    List out some high-level attributes which might have caused the model to assign a {direction1} score for response A than response B. Some example attributes are: appropriateness, clarity, harmlessness, verbosity, etc. Only output the attributes in a comma-separated list."
    reply = await request_llm(client, refreshing_token, prompt, "gpt-4o-2024-05-13", 0.1)
    return "" if reply is None else reply


async def get_attributes_for_one_test_set(testset, run_name, gptmodel="gpt-4o-2024-05-13", naive=False):
    """Collect the LLM-suggested attributes for every comparison in ``testset``.

    Args:
        testset: object exposing the parallel lists ``questions``,
            ``pref_anss``, ``rej_anss``, ``pref_scores`` and ``rej_scores``.
        run_name, gptmodel, naive: accepted for call compatibility; not used here.

    Returns:
        All replies concatenated into one string, each followed by ", ".
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # openai setup is anonymised
    client = AsyncAzureOpenAI(
        azure_endpoint=azure_endpoint, #required
        api_key=refreshing_token.get_token(),
        api_version=api_version,
    )
    replies = []
    for i in tqdm(range(len(testset.questions))):
        # This coroutine already runs inside an event loop, where asyncio.run()
        # raises RuntimeError; the nested coroutine has to be awaited instead.
        reply = await query_for_attributes(
            testset.questions[i],
            testset.pref_anss[i],
            testset.rej_anss[i],
            testset.pref_scores[i],
            testset.rej_scores[i],
            client,
            refreshing_token,
        )
        replies.append(reply)
    # Same layout as before: every reply is followed by ", " (including the last).
    return "".join(reply + ", " for reply in replies)


async def generate_perturbations(question, response1, response2, score1, score2, client, refreshing_token, prompt_type="center"):
    """Generate one attribute-targeted rewrite of response A per attribute.

    Step 1 asks the LLM, for every attribute in ``attribute_list``, which words
    of response A relate to it.  Step 2 asks, for every attribute in
    ``attribute_descriptions``, for a modified response A that flips the
    comparison against response B with respect to that attribute, optionally
    constrained to the words found in step 1.

    Args:
        question: the question both responses answer.
        response1: response A, the text to be modified.
        response2: response B, the reference response.
        score1, score2: model scores of response A and response B.
        client, refreshing_token: forwarded to ``request_llm``.
        prompt_type: "center" (edits centred on the step-1 words), "only"
            (edits restricted to those words); any other value adds no
            location constraint.

    Returns:
        Dict mapping attribute name to the modified response, with newlines
        replaced by spaces.  Attributes whose request returned None are absent.
    """
    direction1 = "better" if score1 > score2 else "worse"
    direction2 = "worse" if score1 > score2 else "better"
    howchangesemantic = "Negatively" if score1>score2 else "Positively"
    # step 1 prompt
    prompt = f"In the task of response quality scoring, a trained deep learning model assigns real-valued scores for responses to questions, the higher the score the better the response quality.\
             The question is '{question}'. The model assigned a score {score1} for response A: '{response1}'.  The model assigned a score {score2} for response B: '{response2}'. \
             The high-level attributes that potentially caused the model to assign a {direction1} score for response A than response B are {attribute_list}. \
             Your task: for each high-level attribute in this list, identify the words in response A that are relevant to it. \
             Only output the attributes and their associated words like this: 'attribute: word1, word2, word3'. Each line should contain a comma-separated word list for one high-level attribute. \
             It is fine to have repeated words in the words identified for each high-level attribute, but you need to keep them in their original order of occurrence in the response A."
    res = await request_llm(client, refreshing_token, prompt, "gpt-4o-2024-05-13", 0.1)
    res_lists = res.replace("\n\n", "\n") if res is not None else ""

    # Parse the "attribute: word1, word2" lines requested by the step-1 prompt.
    attribute_word_dict = {}
    response1_lower = response1.lower()
    for this_list in res_lists.split("\n"):
        attr, colon, words_part = this_list.partition(":")
        if not colon:
            # not in the requested format, nothing to extract
            continue
        this_list_words = [word.strip() for word in words_part.split(",") if word.strip()]
        # Discard the whole attribute when any reported word does not occur in response A.
        if any(word.lower() not in response1_lower for word in this_list_words):
            continue
        # Lower-cased so a reply such as "Clarity: ..." still matches the attribute names.
        attribute_word_dict[attr.strip().lower()] = this_list_words

    # step 2 prompt
    perturbation_dict = {}
    prompt_end = "\n\n ONLY output the modified response A."
    for attribute in attribute_descriptions:
        prompt_start = f"In the task of response quality scoring, a trained deep learning model assigns real-valued scores for responses to questions. \
    The higher the score, the better the response quality. \
    The question is '{question}'. The model assigned a score {score1} for response A: '{response1}'. The model assigned a score {score2} for response B: '{response2}'. \
    The potential high-level attribute that caused the model to assign a {direction1} score for response A than response B is: {attribute}. This attribute concerns {attribute_descriptions[attribute]}.\n\n\
    Your task is to modify response A. Here is a list of requirements for the modification: \n\
    - The modified response A becomes a {direction2} response to the question than response B. \n\
    - {howchangesemantic} change the semantic meaning of response A by making it {direction2} in terms of {attribute}. \n"
        words = attribute_word_dict.get(attribute)
        where_change_prompts = ""
        if words:
            if prompt_type == "center":
                where_change_prompts = f"\n- The changes made to response A should be centered around the following words: {words}."
            elif prompt_type == "only":
                where_change_prompts = f"\n- Response A can only be modified by deleting, replacing, or inserting words, at the locations of all or a subset of the following words: {words}."
        # Always rebuild from prompt_start so an attribute without words does not
        # reuse the step-1 prompt or the previous attribute's prompt.
        prompt = prompt_start + where_change_prompts + prompt_end
        ans = await request_llm(client, refreshing_token, prompt, "gpt-4o-2024-05-13", 0.1)
        if ans is None:
            continue
        lowered = ans.lower()
        if "modified response a" in lowered or "this is it:" in lowered:
            # Drop a preamble such as "Modified response A:" but keep every
            # later colon, and keep the text untouched when there is no colon.
            _, colon, remainder = ans.partition(":")
            if colon:
                ans = remainder
        ans = ans.replace("\n", " ")
        perturbation_dict[attribute] = ans
    return perturbation_dict


# random perturbation baseline
async def generate_perturbations_naively_random(question, response1, response2, score1, score2, client, refreshing_token):
    """Baseline: ask the LLM for 15 random, meaning-changing rewrites of ``response1``.

    Every prompt lists the original response and all rewrites obtained so far
    and asks for a text that differs from them.  ``question``, ``response2``,
    ``score1`` and ``score2`` are accepted so the signature parallels
    ``generate_perturbations``; they are not used.

    Returns:
        Dict mapping "perturb0" .. "perturb14" to the rewritten text with
        newlines replaced by spaces.  Requests that returned None are absent.
    """
    # directly perturb
    perturbation_dict = {}
    generated_responses = [response1]
    for i in range(15):
        attribute = f"perturb{i}"
        prompt = f"Generate one random perturbation of this piece of text, which greatly change its semantic meaning: **{response1}**. \n\n ONLY output the perturbed text, without any explanations. Do NOT output any characters other than English texts and common punctuations. The perturbed response should be different from the following responses: {generated_responses}. "
        # NOTE(review): `temperatures` is not defined in this part of the file;
        # presumably a module-level sequence with at least 15 entries - confirm.
        ans = await request_llm(client, refreshing_token, prompt, "gpt-4o-2024-05-13", temperatures[i])
        if ans is None:
            continue
        ans = ans.replace("\n", " ")
        # Record the rewrite only once it exists; appending before the request
        # read `ans` before its first assignment.
        generated_responses.append(ans)
        perturbation_dict[attribute] = ans
    return perturbation_dict


async def generate_for_one_test_set(testset, run_name, gptmodel="gpt-4o-2024-05-13", naive=False, prompt_type="center"):
    """Generate LLM perturbations for every comparison and write them to disk.

    Two text files are written under ``generated/``: one with the
    perturbations of the model-preferred answers and one with those of the
    rejected answers.  Each perturbation is a line ``$$<key>$$<text>$$`` and
    each comparison is terminated by a ``=====`` line.

    Args:
        testset: object exposing ``data_name`` and the parallel lists
            ``questions``, ``pref_anss``, ``rej_anss``, ``pref_scores``,
            ``rej_scores``.
        run_name: run identifier used in the output file names.
        gptmodel: model name used in the output file names.
        naive: use the random-perturbation baseline instead of the
            attribute-guided prompts.
        prompt_type: forwarded to ``generate_perturbations``.

    Returns:
        ``(pref_path, rej_path)``, the two files written.
    """
    print("Generating text perturbations...")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # write explanations to file for visual inspections
    os.makedirs("generated/", exist_ok=True)
    pref_path = "generated/" + testset.data_name + "_" + run_name + "_ces_pref_" + gptmodel + ".txt"
    rej_path = "generated/" + testset.data_name + "_" + run_name + "_ces_rej_" + gptmodel + ".txt"

    # setup openai anonymised
    client = AsyncAzureOpenAI(
        azure_endpoint=azure_endpoint, #required
        api_key=refreshing_token.get_token(),
        api_version=api_version,
    )

    # UTF-8 to match read_perturbations; `with` closes both files even if a request fails.
    with open(pref_path, "w", encoding="utf-8") as f_p, open(rej_path, "w", encoding="utf-8") as f_r:
        for i in tqdm(range(len(testset.questions))):
            question = testset.questions[i]
            pref_ans, rej_ans = testset.pref_anss[i], testset.rej_anss[i]
            pref_score, rej_score = testset.pref_scores[i], testset.rej_scores[i]
            # This coroutine already runs inside an event loop, where asyncio.run()
            # raises RuntimeError; the nested coroutines have to be awaited instead.
            if not naive:
                pref_ces = await generate_perturbations(question, pref_ans, rej_ans, pref_score, rej_score, client, refreshing_token, prompt_type)
                rej_ces = await generate_perturbations(question, rej_ans, pref_ans, rej_score, pref_score, client, refreshing_token, prompt_type)
            else:
                pref_ces = await generate_perturbations_naively_random(question, pref_ans, rej_ans, pref_score, rej_score, client, refreshing_token)
                rej_ces = await generate_perturbations_naively_random(question, rej_ans, pref_ans, rej_score, pref_score, client, refreshing_token)
            for key, text in pref_ces.items():
                f_p.write(f"$${key}$${text}$$\n")
            for key, text in rej_ces.items():
                f_r.write(f"$${key}$${text}$$\n")
            f_p.write("=====\n")
            f_r.write("=====\n")
    return pref_path, rej_path


def generate_for_one_test_set_pj(testset, run_name):
    """Generate perturbations with the ``pj`` perturber and write them to disk.

    The output layout matches ``generate_for_one_test_set``: one
    ``$$<key>$$<text>$$`` line per perturbation and a ``=====`` line after
    each comparison, in one file for the preferred answers and one for the
    rejected answers.  Keys are ``polyjuice0``, ``polyjuice1``, ... in the
    order returned by ``pj.perturb``, restarting at 0 for each answer.

    Args:
        testset: object exposing ``data_name``, ``questions``, ``pref_anss``
            and ``rej_anss``.
        run_name: run identifier used in the output file names.

    Returns:
        ``(pref_path, rej_path)``, the two files written.
    """
    print("Generating text perturbations...")
    os.makedirs("generated/", exist_ok=True)
    pref_path = "generated/" + testset.data_name + "_" + run_name + "_ces_pref_" + gptmodel + ".txt"
    rej_path = "generated/" + testset.data_name + "_" + run_name + "_ces_rej_" + gptmodel + ".txt"

    # UTF-8 to match read_perturbations; `with` closes both files on errors too.
    with open(pref_path, "w", encoding="utf-8") as f_p, open(rej_path, "w", encoding="utf-8") as f_r:
        for i in tqdm(range(len(testset.questions))):
            # NOTE(review): `pj` is not defined in this part of the file;
            # presumably a module-level perturber object - confirm.
            pref_perturb = pj.perturb(orig_sent=testset.pref_anss[i], num_beams=2, truncation=True, num_perturbations=15)
            rej_perturb = pj.perturb(orig_sent=testset.rej_anss[i], num_beams=2, truncation=True, num_perturbations=15)
            # enumerate() restarts numbering per answer; the previous hand-kept
            # key name leaked from the preferred loop into the first rejected key.
            for out_file, perturbed in ((f_p, pref_perturb), (f_r, rej_perturb)):
                for count, item in enumerate(perturbed):
                    # One perturbation per line: embedded newlines would break the format.
                    text = str(item).replace("\n", " ")
                    out_file.write(f"$$polyjuice{count}$${text}$$\n")
            f_p.write("=====\n")
            f_r.write("=====\n")
    return pref_path, rej_path


def read_perturbations(file_path):
    """Parse a perturbation file into per-comparison lists.

    Every data line has the form ``$$<attribute>$$<text>$$``; a line starting
    with ``===`` closes the current comparison.  One pair of matching outer
    quotes is removed from the text, and whitespace-separated tokens containing
    any character that is neither alphanumeric nor ASCII punctuation (e.g.
    emoji) are dropped.

    Args:
        file_path: path of a UTF-8 text file in the format above.

    Returns:
        ``(perturbations, attributes)``: two lists with one inner list per
        comparison, holding the cleaned texts and their attribute names.
        Lines after the last ``===`` separator are not returned.

    Raises:
        IndexError: if a data line has fewer than two ``$$`` separators.
    """
    allowed_symbols = set(string.punctuation)

    def _is_clean(token):
        # an empty token (from consecutive spaces) counts as clean
        return all(ch.isalnum() or ch in allowed_symbols for ch in token)

    perturbations = []
    attributes = []
    this_perturbations = []
    this_attributes = []
    # `with` guarantees the handle is closed even when a malformed line raises.
    with open(file_path, "r", encoding='utf-8') as f:
        for line in f:
            if line.startswith("==="):
                perturbations.append(this_perturbations)
                attributes.append(this_attributes)
                this_perturbations = []
                this_attributes = []
                continue
            fields = line.split("$$")
            this_ce = fields[2].strip()
            this_attribute = fields[1].strip()
            if (this_ce.startswith("'") and this_ce.endswith("'")) or (this_ce.startswith('"') and this_ce.endswith('"')):
                this_ce = this_ce[1:-1]
            # get rid of weird stuff, especially emoji
            clean_text = " ".join(token for token in this_ce.split(" ") if _is_clean(token)).strip()
            this_perturbations.append(clean_text)
            this_attributes.append(this_attribute)
    return perturbations, attributes


def levenshtein_distance(word1, word2):
    """Return the word-level Levenshtein distance, normalised to [0, 1].

    Both texts have newlines and the characters ``, . ? !`` deleted and are
    then split on whitespace; the edit distance counts word insertions,
    deletions and substitutions and is divided by the longer word count.

    Args:
        word1: first text.
        word2: second text.

    Returns:
        The normalised distance; 0.0 when neither text contains any word.
    """
    # NOTE: newlines are deleted (not turned into spaces), so words separated
    # only by a newline are merged - kept as in the original implementation.
    removed = str.maketrans("", "", "\n,.?!")
    word1 = word1.strip().translate(removed).split()
    word2 = word2.strip().translate(removed).split()
    m, n = len(word1), len(word2)
    if max(m, n) == 0:
        # two empty texts are identical; avoids the 0/0 that produced nan
        return 0.0
    matrix = np.zeros((m+1, n+1), dtype=int)
    matrix[:, 0] = np.arange(m+1)
    matrix[0, :] = np.arange(n+1)
    for i in range(1, m+1):
        for j in range(1, n+1):
            substitution_cost = 0 if word1[i-1] == word2[j-1] else 1
            matrix[i, j] = min(
                matrix[i-1, j] + 1,                # deletion
                matrix[i, j-1] + 1,                # insertion
                matrix[i-1, j-1] + substitution_cost    # substitution
            )
    return matrix[m, n] / max(m, n)

def sbert_dissimilarity(word1, word2, sbert):
    """Return one minus the cosine similarity of the two texts' ``sbert`` embeddings.

    Each text is embedded with ``sbert.encode``; the result is whatever
    ``cosine_similarity`` returns for the two single-row inputs, subtracted
    from 1.
    """
    first_vec = sbert.encode(word1)
    second_vec = sbert.encode(word2)
    similarity = cosine_similarity([first_vec], [second_vec])
    return 1 - similarity

def calculate_diversity(ce_list, sbert):
    """Return the mean pairwise ``sbert_dissimilarity`` over ``ce_list``.

    Every unordered pair of distinct positions is compared once.  With fewer
    than two entries there is no pair and 0 is returned.
    """
    total = 0
    pairs = 0
    for first_pos, first_text in enumerate(ce_list):
        for second_pos, second_text in enumerate(ce_list):
            # only look at pairs where the second entry comes later in the list
            if second_pos > first_pos:
                pairs += 1
                total += sbert_dissimilarity(first_text, second_text, sbert)
    if not pairs:
        return 0
    total /= pairs
    return total

def load_sbert():
    """Load and return the SentenceTransformer stored at the (anonymised) model path."""
    return SentenceTransformer("ANONYMISED")
    
# You can add the specific error you are getting to be handled here - https://pypi.org/project/backoff/
@backoff.on_exception(backoff.expo, (RateLimitError, APITimeoutError))
async def request_llm(client, refreshing_token, prompt, model_name:str, temperature:float):
    """Send ``prompt`` as a single user message and return the reply text.

    The call is retried with exponential backoff whenever it raises
    ``RateLimitError`` or ``APITimeoutError``.

    Args:
        client: async chat client exposing ``chat.completions.create``.
        refreshing_token: object whose ``get_token()`` supplies the API key.
        prompt: text of the user message.
        model_name: model to query.
        temperature: sampling temperature.

    Returns:
        The message content of the first returned choice.
    """
    # Refresh the key ahead of each request so that the token in use is valid.
    client.api_key = refreshing_token.get_token()
    messages = [{"role": "user", "content": prompt}]
    completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


def _safe_mean(total, count):
    """Return ``total / count``, or ``total`` unchanged when ``count`` is zero."""
    if count == 0:
        return total
    return total / count


def _as_scalar(value):
    """Unwrap a 2-D (1x1) array to its single element; pass anything else through."""
    if np.ndim(value) == 2:
        return value[0][0]
    return value


def _score_perturbation(question, answer, model_type, rm, tok, device):
    """Return the first logit the reward model ``rm`` assigns to ``answer`` for ``question``."""
    # "pythia" is checked first because it took precedence in the original
    # sequential-if logic when both substrings were present.
    if "pythia" in model_type:
        xy = f"<|prompter|>{question}<|endoftext|><|assistant|>{answer}<|endoftext|>"
        inputs = tok(xy, return_tensors='pt').to(device)
    elif "deberta" in model_type:
        inputs = tok(question, answer, return_tensors='pt').to(device)
    else:
        raise ValueError(
            f"Unsupported model type {model_type!r}: expected it to contain 'deberta' or 'pythia'"
        )
    # Only the forward score is needed (it is detached right away), so skip autograd.
    with torch.no_grad():
        score = rm(**inputs).logits[0].cpu().detach()[0]
    inputs.to("cpu")
    torch.cuda.empty_cache()
    return score


def _evaluate_side(t, i, side, group, ces, attributes, rm, tok, device, sbert, out_file, stats):
    """Score all perturbations of one answer of pair ``i``; return ``(cfe_list, sfe_list)``.

    ``side`` is ``"pref"`` or ``"rej"`` and ``group`` is ``"correct"`` or
    ``"wrong"``; together they select which running totals in ``stats`` are updated.
    """
    if side == "pref":
        original, threshold = t.pref_anss[i], t.rej_scores[i]
        prefix, sep = "$$ [CE FOR PREF ANSWER]: ", ":: "
    else:
        original, threshold = t.rej_anss[i], t.pref_scores[i]
        # The rej log line has two spaces after "::" (kept for identical output).
        prefix, sep = "$$ [CE FOR REJ ANSWER]: ", "::  "

    syn, sem, n = stats["syn"], stats["sem"], stats["n"]
    cfe_list, sfe_list = [], []
    for j, ce in enumerate(ces):
        attribute = attributes[j]
        # Skip placeholder attributes and over-long perturbations.
        if "none" in attribute.lower() or len(ce) > 7000:
            continue

        lvstn = levenshtein_distance(original, ce)
        sb = sbert_dissimilarity(original, ce, sbert)
        score = _score_perturbation(t.questions[i], ce, t.model_type, rm, tok, device)

        # A perturbed preferred answer is counterfactual when it no longer beats
        # the original rejected score; a perturbed rejected answer is
        # counterfactual when it reaches the original preferred score.
        flipped = score <= threshold if side == "pref" else score >= threshold
        kind = "cfe" if flipped else "sfe"

        for key in ("all", group, side, kind):
            syn[key] += lvstn
            sem[key] += sb
            n[key] += 1
        n[f"{kind}_{side}"] += 1

        (cfe_list if flipped else sfe_list).append(ce)
        sf_or_cf = "==counterfactual==" if flipped else "==semifactual=="
        out_file.write(f"{prefix}{attribute}{sep}{score} ({threshold}, {flipped}, {sf_or_cf}) $$ {ce} $$\n")
    return cfe_list, sfe_list


def _accumulate_diversity(ce_list, kind, group, side, sbert, stats):
    """Add the diversity of a non-empty ``ce_list`` to every bucket it belongs to."""
    if len(ce_list) == 0:
        return
    div_val = calculate_diversity(ce_list, sbert)
    for key in ("all", kind, group, side):
        stats["div"][key] += div_val
        stats["div_n"][key] += 1


def eval_ce_correct_wrong_tables(pref_path, rej_path, t, rm, tok, device="cuda"):
    """Evaluate perturbations of preferred/rejected answers against a reward model.

    For every question in ``t`` the perturbations read from ``pref_path`` and
    ``rej_path`` are scored by ``rm``, classified as counterfactual (the
    perturbed answer flips the original comparison) or semifactual (it does
    not), and logged to ``evaluated/<config>_ces_pref.txt`` and
    ``evaluated/<config>_ces_rej.txt``. Coverage, Levenshtein distance, SBERT
    dissimilarity and diversity are averaged overall and per bucket
    (counterfactual/semifactual, correct/wrong comparison, pref/rej side).

    Any average whose bucket is empty is reported as ``0`` instead of raising
    ``ZeroDivisionError``.

    Args:
        pref_path: Path of the perturbation file for preferred answers; its
            basename up to ``"_ces"`` names the output files.
        rej_path: Path of the perturbation file for rejected answers.
        t: Object exposing ``questions``, ``pref_anss``, ``rej_anss``,
            ``pref_scores``, ``rej_scores``, ``corrects`` and ``model_type``.
        rm: Reward model; called as ``rm(**inputs)`` and read via ``.logits``.
        tok: Tokenizer matching ``rm``.
        device: Device the tokenized inputs are moved to.

    Returns:
        A 5-tuple of lists:
        ``res_tab1``  [cov_cfe, cov_sfe, syn, sem, div] over all perturbations;
        ``res_tab2``  [syn, sem, div] for counterfactuals, then for semifactuals;
        ``res_tabcw`` [cov_cfe, cov_sfe, syn, sem, div] for correct, then wrong comparisons;
        ``res_tabpr`` [cov_cfe, cov_sfe, syn, sem, div] for pref, then rej side;
        ``res_rates`` counterfactual/semifactual rates overall, for pref, for rej.

    Raises:
        ValueError: If ``t.model_type`` contains neither "deberta" nor "pythia".
    """
    correct_idxs, wrong_idxs = [], []
    for i, c in enumerate(t.corrects):
        if c:
            correct_idxs.append(i)
        else:
            wrong_idxs.append(i)
    correct_set = set(correct_idxs)  # O(1) membership inside the main loop

    sbert = load_sbert()
    print("Starting evaluations...")
    os.makedirs("evaluated/", exist_ok=True)
    config = pref_path.split("/")[-1].split("_ces")[0]
    inspect_pref_path = "evaluated/" + config + "_ces_pref.txt"
    inspect_rej_path = "evaluated/" + config + "_ces_rej.txt"

    # Read the inputs before opening (and truncating) the output files.
    ces_pref, attributes_pref = read_perturbations(pref_path)
    ces_rej, attributes_rej = read_perturbations(rej_path)

    buckets = ("all", "cfe", "sfe", "correct", "wrong", "pref", "rej")
    stats = {
        # running sums of syntactic (Levenshtein) and semantic (SBERT) distances
        "syn": dict.fromkeys(buckets, 0),
        "sem": dict.fromkeys(buckets, 0),
        # number of perturbations per bucket, plus cfe/sfe counts per side
        "n": dict.fromkeys(buckets + ("cfe_pref", "sfe_pref", "cfe_rej", "sfe_rej"), 0),
        # running sums of diversity values and how many were added
        "div": dict.fromkeys(buckets, 0),
        "div_n": dict.fromkeys(buckets, 0),
    }
    cov = dict.fromkeys(
        ("cfe", "sfe", "cfe_correct", "sfe_correct", "cfe_wrong", "sfe_wrong",
         "cfe_pref", "cfe_rej", "sfe_pref", "sfe_rej"),
        0,
    )

    with open(inspect_pref_path, "w", encoding='utf-8') as f_insp_p, \
            open(inspect_rej_path, "w", encoding='utf-8') as f_insp_r:
        for i in tqdm(range(len(t.questions))):
            group = "correct" if i in correct_set else "wrong"

            # Both inspection files start each record with the same header.
            for f_insp in (f_insp_p, f_insp_r):
                f_insp.write(f"$$ QUESTION: {t.questions[i]} $$\n")
                f_insp.write(f"$$ ORIGINAL PREF ANSWER: {t.pref_scores[i]} $$ {t.pref_anss[i]} $$\n")
                f_insp.write(f"$$ ORIGINAL REJ ANSWER: {t.rej_scores[i]} $$ {t.rej_anss[i]} $$\n")
                f_insp.write(f"$$ CORRECT COMPARISON: {t.corrects[i]} $$\n")

            pref_ce_list, pref_se_list = _evaluate_side(
                t, i, "pref", group, ces_pref[i], attributes_pref[i],
                rm, tok, device, sbert, f_insp_p, stats,
            )
            rej_ce_list, rej_se_list = _evaluate_side(
                t, i, "rej", group, ces_rej[i], attributes_rej[i],
                rm, tok, device, sbert, f_insp_r, stats,
            )

            # update coverage: per side, and "both sides" overall / per group
            if pref_ce_list:
                cov["cfe_pref"] += 1
            if pref_se_list:
                cov["sfe_pref"] += 1
            if rej_ce_list:
                cov["cfe_rej"] += 1
            if rej_se_list:
                cov["sfe_rej"] += 1
            if pref_ce_list and rej_ce_list:
                cov["cfe"] += 1
                cov["cfe_" + group] += 1
            if pref_se_list and rej_se_list:
                cov["sfe"] += 1
                cov["sfe_" + group] += 1

            # update diversity (order kept so float sums accumulate identically)
            _accumulate_diversity(pref_ce_list, "cfe", group, "pref", sbert, stats)
            _accumulate_diversity(pref_se_list, "sfe", group, "pref", sbert, stats)
            _accumulate_diversity(rej_ce_list, "cfe", group, "rej", sbert, stats)
            _accumulate_diversity(rej_se_list, "sfe", group, "rej", sbert, stats)

            f_insp_p.write("========\n\n")
            f_insp_r.write("========\n\n")
            torch.cuda.empty_cache()

    # normalise numbers; empty buckets stay at 0 rather than dividing by zero
    n = stats["n"]
    syn = {k: _safe_mean(stats["syn"][k], n[k]) for k in buckets}
    # sem/div sums are 1x1 arrays once something was added, plain 0 otherwise
    sem = {k: _as_scalar(_safe_mean(stats["sem"][k], n[k])) for k in buckets}
    div = {k: _as_scalar(_safe_mean(stats["div"][k], stats["div_n"][k])) for k in buckets}

    n_questions = len(t.questions)
    for key in ("cfe", "sfe", "cfe_pref", "cfe_rej", "sfe_pref", "sfe_rej"):
        cov[key] = _safe_mean(cov[key], n_questions)
    for kind in ("cfe", "sfe"):
        cov[kind + "_correct"] = _safe_mean(cov[kind + "_correct"], len(correct_idxs))
        cov[kind + "_wrong"] = _safe_mean(cov[kind + "_wrong"], len(wrong_idxs))

    res_tab1 = [cov["cfe"], cov["sfe"], syn["all"], sem["all"], div["all"]]
    res_tab2 = [syn["cfe"], sem["cfe"], div["cfe"], syn["sfe"], sem["sfe"], div["sfe"]]
    res_tabcw = [cov["cfe_correct"], cov["sfe_correct"], syn["correct"], sem["correct"], div["correct"],
                 cov["cfe_wrong"], cov["sfe_wrong"], syn["wrong"], sem["wrong"], div["wrong"]]
    res_tabpr = [cov["cfe_pref"], cov["sfe_pref"], syn["pref"], sem["pref"], div["pref"],
                 cov["cfe_rej"], cov["sfe_rej"], syn["rej"], sem["rej"], div["rej"]]

    # report cfe/sfe rate
    res_rates = [_safe_mean(n["cfe"], n["all"]), _safe_mean(n["sfe"], n["all"]),
                 _safe_mean(n["cfe_pref"], n["pref"]), _safe_mean(n["sfe_pref"], n["pref"]),
                 _safe_mean(n["cfe_rej"], n["rej"]), _safe_mean(n["sfe_rej"], n["rej"])]
    return res_tab1, res_tab2, res_tabcw, res_tabpr, res_rates

