from typing import List, Dict, Tuple
import numpy as np
import json
import os
import functools
import warnings
warnings.filterwarnings('ignore')
# from score_functions import get_bert_score, get_bleu_score, get_cosine_similarity
from utils.scorer.score_functions import get_bert_score, get_bleu_score, get_cosine_similarity

# Working directory at import time; result files are looked up beneath it.
ROOT_PATH = os.getcwd()

# Value categories; VALUE_LIST[10] ('no') is the last entry.
VALUE_LIST = [
    'Power',
    'Spirituality',
    'Benevolence',
    'Tradition',
    'Self-Direction',
    'Achievement',
    'Stimulation',
    'Security',
    'Conformity',
    'Hedonism',
    'no',
]

# Opaque alias strings, one per VALUE_LIST entry (both lists have 11 items).
# NOTE(review): presumably matched to VALUE_LIST by index — confirm with users.
VALUE_ALIAS_LIST = [
    "k8jZw2",
    "pLm5tQ",
    "3rXo1v",
    "Hd9yG6",
    "cB7fA4",
    "e0iUyV",
    "s5TzRq",
    "Mn2oW3",
    "b6LjKp",
    "x4EgH0",
    "uF1sD8",
]


def extract_lists(dict_list: List[Dict[str, str]]) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Split parsed result records into per-field lists.

    Args:
        dict_list: Parsed records, one dict per line of a results file.

    Returns:
        A 4-tuple ``(answers, chose_answers, reasons, baseline_reasons)`` holding
        the values stored under the ``'answer'``, ``'chose_answer'``, ``'reason'``
        and ``'baseline_reason'`` keys, each in record order.

    Note:
        Each key is filtered independently: a record lacking a key is skipped
        only for that key's list. The four lists may therefore differ in length,
        and index ``i`` refers to the same record across lists only when every
        record carries every key.
    """
    keys = ('answer', 'chose_answer', 'reason', 'baseline_reason')
    return tuple([record[key] for record in dict_list if key in record] for key in keys)

def load_single_file(model_name, value_type, is_parse=True, is_exp1=False, is_overwrite=False):
    """Load one results file for a model / value type.

    The file read is
    ``{ROOT_PATH}/results/{model_name}/{value_type}_prompt_{model_name}.txt``,
    parsed as one JSON object per non-blank line.

    Args:
        model_name: Model directory/name used in the path (e.g. ``"gpt-4"``).
        value_type: Value-type prefix of the file name.
        is_parse: If True, return the 4-tuple produced by ``extract_lists``;
            otherwise return the raw list of parsed dicts.
        is_exp1: If True, skip the "``.jsonl`` already exists" check.
        is_overwrite: If True, also skip that check.

    Returns:
        The parsed data (see ``is_parse``), or ``None`` when the ``.txt`` file
        does not exist, or when a sibling ``.jsonl`` file already exists and
        neither ``is_exp1`` nor ``is_overwrite`` is set. The reason is printed.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    file_name = f"{ROOT_PATH}/results/{model_name}/{value_type}_prompt_{model_name}.txt"
    if not os.path.exists(file_name):
        print(f"{file_name} does not exist!")
        return None

    # Same path with the ".txt" suffix swapped for ".jsonl".
    jsonl_name = os.path.splitext(file_name)[0] + '.jsonl'
    if not (is_exp1 or is_overwrite) and os.path.exists(jsonl_name):
        print(f"{jsonl_name} already exists!")
        return None

    data = []
    with open(file_name, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. an extra newline at end of file);
            # json.loads('') would raise JSONDecodeError.
            if not line:
                continue
            data.append(json.loads(line))
    return extract_lists(data) if is_parse else data

if __name__ == '__main__':
    # Manual smoke test: sample NUM records from one results file and print the
    # metric score for the answer pair ("know what"), for the reason pair
    # ("know why"), and the absolute difference of the two.

    # Test setup
    PRINT_DESCRIPTION = True
    # Test experiment setup
    TEST_METRIC = functools.partial(get_cosine_similarity, model_name="all-MiniLM-L6-v2", auto=False)
    METRIC_NAME = "Sentence Transformer"
    MODEL_NAME = "gpt-4"
    VALUE_TYPE = VALUE_LIST[10]
    NUM = 3

    loaded = load_single_file(model_name=MODEL_NAME, value_type=VALUE_TYPE)
    if loaded is None:
        # load_single_file has already printed the reason (missing .txt file,
        # or an existing .jsonl file without is_exp1/is_overwrite).
        raise SystemExit(1)
    answers, chose_answers, reasons, baseline_reasons = loaded

    # extract_lists filters each key independently, so the lists may differ in
    # length; only sample indices that are valid for all four of them.
    num_records = min(len(answers), len(chose_answers), len(reasons), len(baseline_reasons))
    if num_records != max(len(answers), len(chose_answers), len(reasons), len(baseline_reasons)):
        print("Warning: some records are missing fields; only the first "
              f"{num_records} entries of each list are sampled and may be misaligned.")
    if num_records == 0:
        print("No complete records to sample.")
        raise SystemExit(1)

    # Sampling without replacement requires sample size <= population size.
    indices = np.random.choice(num_records, min(NUM, num_records), replace=False)

    print(f"Test: {MODEL_NAME} {VALUE_TYPE} {METRIC_NAME}")
    for i in indices:
        if PRINT_DESCRIPTION:
            print(f"Answer: {answers[i]}")
            print(f"Chose Answer: {chose_answers[i]}")
        know_what_score = TEST_METRIC([answers[i]], [chose_answers[i]])
        print(f"(Know what) {METRIC_NAME} Score: {know_what_score}")

        if PRINT_DESCRIPTION:
            print(f"Reason: {reasons[i]}")
            print(f"Baseline Reason: {baseline_reasons[i]}")
        know_why_score = TEST_METRIC([reasons[i]], [baseline_reasons[i]])
        print(f"(Know why) {METRIC_NAME} Score: {know_why_score}")

        print(f"Final score: {abs(know_what_score - know_why_score)}")
        print("===========================================")