from utils.scorer.test_score_functions import load_single_file
from utils.scorer.test_score_functions import VALUE_LIST, ROOT_PATH
from utils.scorer.score_functions import get_bert_score, get_bleu_score, get_diag_cosine_similarity, get_diag_gpt4_similarity_exp2
from matplotlib.backends.backend_pdf import PdfPages
import functools
import json
import os
from statistics import mean, stdev, median
from joblib import Parallel, delayed

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Apply seaborn's default theme to the matplotlib figures created in this module.
sns.set_theme()
# Directory holding the exp2 results. The trailing slash is relied upon:
# paths in this module are built as f'{EXP2_PATH}{metric_name}/...'.
EXP2_PATH = f"{ROOT_PATH}/results/exp2/"

# font_properties = {'size': 24, 'weight': 'regular', 'family': 'monospace'}

def normalize(lst):
    """Standardize numbers to zero mean and unit (population) standard deviation.

    Args:
        lst: Any iterable of numbers; it is materialized into a list first.

    Returns:
        A list of z-scores ``(x - mean) / std``. When the standard deviation
        is exactly zero, a notice is printed and the materialized input is
        returned unchanged.
    """
    values = list(lst)
    center, spread = np.mean(values), np.std(values)

    if spread != 0:
        return [(v - center) / spread for v in values]

    # Constant input: dividing by zero would be meaningless, so pass it through.
    print("Standard deviation is zero. Returning the original array.")
    return values

def process_and_save(model_name, value_type, scorer, is_overwrite=False, metric_name=""):
    """Score one model/value result set and write the scored records as JSONL.

    The output goes to
    ``{EXP2_PATH}{metric_name}/{model_name}/{value_type}_prompt_{model_name}.jsonl``.
    If that file already exists, nothing is recomputed.

    Args:
        model_name: Model identifier used to load inputs and to build the output path.
        value_type: Value name used to load inputs and to build the output path.
        scorer: Callable invoked as ``scorer(references, candidates)`` on two
            equally long lists; its result is indexed per item and each entry
            is converted with ``float``.
        is_overwrite: Forwarded to ``load_single_file`` only. An existing
            output file is skipped regardless of this flag.
        metric_name: Sub-directory of ``EXP2_PATH`` for this metric.

    Returns:
        None. Each record is written with added ``answer_score``,
        ``reason_score`` and ``score`` (absolute difference of the two) fields.
    """
    save_path = f'{EXP2_PATH}{metric_name}/{model_name}'
    output_file = f'{save_path}/{value_type}_prompt_{model_name}.jsonl'
    print(output_file)
    if os.path.exists(output_file):
        print(f"{output_file} already exists!")
        return None

    print(f'Processing {model_name} {value_type}...')
    input_list = load_single_file(model_name, value_type, is_parse=False, is_overwrite=is_overwrite)
    if input_list is None:
        return None

    # Score in fixed-size batches so a single scorer call stays bounded.
    batch_size = 100
    for start in range(0, len(input_list), batch_size):
        batch = input_list[start:start + batch_size]

        # Missing fields fall back to 0, as in the stored-record convention
        # used throughout this module.
        answer_refs = [item.get('answer', 0) for item in batch]
        answer_cands = [item.get('chose_answer', 0) for item in batch]
        reason_refs = [item.get('reason', 0) for item in batch]
        reason_cands = [item.get('baseline_reason', 0) for item in batch]

        answer_scores = scorer(answer_refs, answer_cands)
        reason_scores = scorer(reason_refs, reason_cands)

        # Records are updated in place, so input_list carries the scores.
        for j, item in enumerate(batch):
            item['answer_score'] = float(answer_scores[j])
            item['reason_score'] = float(reason_scores[j])
            item['score'] = abs(float(answer_scores[j] - reason_scores[j]))

    # exist_ok avoids a FileExistsError when parallel workers handling the
    # same model race to create the shared directory.
    os.makedirs(save_path, exist_ok=True)
    with open(output_file, 'w') as f:
        for item in input_list:
            f.write(json.dumps(item) + '\n')

def calculate_result(model_list, value_list, type='score', metric_name=""):
    """Reduce stored per-item scores to one median per (model, value) pair.

    For every model, the scored JSONL file of each value is read, the
    requested quantity is taken from every record, and its median is stored.
    An ``'Average'`` entry (mean of that model's per-value medians) is added
    last. One JSON object per model is written, in ``model_list`` order, to
    ``model_{type}.jsonl`` under the metric directory.

    Args:
        model_list: Model identifiers; also the output line order.
        value_list: Value names to aggregate for each model.
        type: Record field to aggregate. The special name ``'nagap'`` means
            ``answer_score - reason_score`` (signed). Missing fields count as 0.
        metric_name: Sub-directory of ``EXP2_PATH`` holding the scored files.
    """
    def pick(record):
        # 'nagap' is derived from two fields; any other name is read directly.
        if type == 'nagap':
            return record.get('answer_score', 0) - record.get('reason_score', 0)
        return record.get(type, 0)

    summaries = {}
    for model in model_list:
        per_value = dict.fromkeys(value_list, 0)
        for value in value_list:
            with open(f'{EXP2_PATH}{metric_name}/{model}/{value}_prompt_{model}.jsonl', 'r') as src:
                per_value[value] = median([pick(json.loads(raw)) for raw in src])
        # Computed before the key is inserted, so it averages the medians only.
        per_value['Average'] = sum(per_value.values()) / len(per_value)
        summaries[model] = per_value

    save_path = f'{EXP2_PATH}{metric_name}/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    with open(f'{save_path}/model_{type}.jsonl', 'w') as out:
        out.writelines(json.dumps(summaries[model]) + '\n' for model in model_list)

def check_consistency(model_list, value_list):
    """Compare the ``.txt`` and ``.jsonl`` result files of each model/value pair.

    Both files are read as JSON lines from
    ``{ROOT_PATH}/results/{model}/{value}_prompt_{model}.{jsonl,txt}`` and
    compared record by record on the ``answer``, ``chose_answer``, ``reason``
    and ``baseline_reason`` fields. Findings are printed; nothing is returned.

    A differing record count is reported explicitly, because the pairwise
    comparison only covers the records both files have in common.

    Args:
        model_list: Model identifiers to check.
        value_list: Value names to check for each model.

    Raises:
        KeyError: If a compared record lacks one of the four fields.
    """
    compared_keys = ('answer', 'chose_answer', 'reason', 'baseline_reason')

    for model in model_list:
        for value in value_list:
            base = f'{ROOT_PATH}/results/{model}/{value}_prompt_{model}'
            with open(f'{base}.jsonl', 'r') as f:
                jsonl_data = [json.loads(line) for line in f]
            with open(f'{base}.txt', 'r') as f:
                txt_data = [json.loads(line) for line in f]

            # zip() below stops at the shorter file; surface missing or extra
            # records instead of letting them pass unnoticed.
            if len(txt_data) != len(jsonl_data):
                print(f'Record count mismatch in {model} {value}: '
                      f'{len(txt_data)} (.txt) vs {len(jsonl_data)} (.jsonl)!')

            # One message per inconsistent record pair.
            for item_txt, item_jsonl in zip(txt_data, jsonl_data):
                if any(item_txt[key] != item_jsonl[key] for key in compared_keys):
                    print(f'Inconsistent data found in {model} {value}!')

def plot_result(type, model_list, metric_name):
    """Draw a radar chart of per-value scores for each model and save it as PDF.

    Reads ``{EXP2_PATH}{metric_name}/model_{type}.jsonl`` (one JSON object per
    model), drops the last entry of each object (the ``'Average'`` written by
    the aggregation step), and plots the remaining values on a polar axis.
    The chart is saved next to the input as ``model_{type}.pdf``.

    Args:
        type: Score name; selects the input/output file names and the radial
            range (``0..0.25`` for ``'score'``, ``0..0.9`` otherwise).
        model_list: Labels for the plotted lines, in the same order as the
            lines of the input file.
        metric_name: Sub-directory of ``EXP2_PATH`` for this metric.

    Raises:
        IndexError: If the input file is empty or has more lines than
            ``model_list`` has entries.
    """
    data_dicts = []
    with open(f'{EXP2_PATH}{metric_name}/model_{type}.jsonl', 'r') as f:
        for line in f:
            data_dicts.append(json.loads(line.strip()))

    colors = sns.color_palette("husl", len(data_dicts))
    labels = list(data_dicts[0].keys())[:-1]  # drop the trailing 'Average'
    # The 11th axis is relabelled for display. Guard the index so charts with
    # fewer axes (e.g. a reduced value list) do not crash.
    if len(labels) > 10:
        labels[10] = 'No induction'
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    # Repeat the first angle so every polygon is closed.
    closed_angles = angles + angles[:1]

    fig = plt.figure(figsize=(20, 20))
    try:
        ax = plt.subplot(polar=True)

        for i, data_dict in enumerate(data_dicts):
            values = list(data_dict.values())[:-1]
            values += values[:1]
            ax.fill(closed_angles, values, color=colors[i], alpha=0.25)
            ax.plot(closed_angles, values, color=colors[i], linewidth=3, label=model_list[i])

        ax.set_yticklabels([])
        if type == 'score':
            ax.set_ylim(0, 0.25)
        else:
            ax.set_ylim(0, 0.9)
        ax.set_xticks(angles)
        ax.set_xticklabels(labels, fontdict=None, size=48)
        plt.tight_layout(pad=0.4)
        # The context manager finalizes the PDF even if savefig raises.
        with PdfPages(f'{EXP2_PATH}{metric_name}/model_{type}.pdf') as pdf:
            pdf.savefig(fig)
    finally:
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)

def get_header(type, model_list, metric_name):
    """Render a radar chart with a model legend and save it as ``exp2_header.pdf``.

    Reads ``{EXP2_PATH}{metric_name}/model_{type}.jsonl`` (one JSON object per
    model) and plots every entry of each object, adding a legend whose lines
    are thickened for readability.

    Args:
        type: Score name; selects the input file and the radial range
            (``0..0.25`` for ``'score'``, ``0..1`` otherwise).
        model_list: Legend labels, in the same order as the lines of the
            input file.
        metric_name: Sub-directory of ``EXP2_PATH`` for this metric.

    Raises:
        IndexError: If the input file is empty or has more lines than
            ``model_list`` has entries.
    """
    data_dicts = []
    with open(f'{EXP2_PATH}{metric_name}/model_{type}.jsonl', 'r') as f:
        for line in f:
            data_dicts.append(json.loads(line.strip()))

    colors = sns.color_palette("husl", len(data_dicts))
    labels = list(data_dicts[0].keys())
    # The 11th axis is relabelled for display. Guard the index so inputs with
    # fewer axes do not crash.
    if len(labels) > 10:
        labels[10] = 'No'
    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    # Repeat the first angle so every polygon is closed.
    closed_angles = angles + angles[:1]

    fig = plt.figure(figsize=(20, 20))
    try:
        ax = plt.subplot(polar=True)

        for i, data_dict in enumerate(data_dicts):
            values = list(data_dict.values())
            values += values[:1]
            ax.fill(closed_angles, values, color=colors[i], alpha=0.25)
            ax.plot(closed_angles, values, color=colors[i], linewidth=3, label=model_list[i])

        ax.set_yticklabels([])
        if type == 'score':
            ax.set_ylim(0, 0.25)
        else:
            ax.set_ylim(0, 1)
        ax.set_xticks(angles)
        ax.set_xticklabels(labels, fontdict=None, size=24)
        res = plt.legend(loc='upper center', ncol=6, markerscale=4.0, handlelength=2.5, labelspacing=0.4, mode=None,
                         borderaxespad=-4.5, prop={'size': 25, 'weight': 'regular'},
                         fancybox=False, edgecolor='#FFFFFF', facecolor='#FFFFFF')
        # Thicken every legend line; iterating the handles avoids indexing
        # past the legend when model_list is longer than the plotted data.
        for legend_line in res.get_lines():
            legend_line.set_linewidth(5)
        # The context manager finalizes the PDF even if savefig raises.
        with PdfPages(f'{EXP2_PATH}{metric_name}/exp2_header.pdf') as pdf:
            pdf.savefig(fig)
    finally:
        # Release the figure so it does not stay registered with pyplot.
        plt.close(fig)
    

if __name__ == '__main__':
    # Display names (only referenced by the disabled get_header call below)
    # and the identifiers used in result paths, listed in matching order.
    MODEL_NAME_LIST = ['GPT-3.5-Turbo', 'GPT-4', 'Llama2-7B', 'Llama2-13B', 'Vicuna-33B']
    MODEL_LIST = ['gpt-3.5-turbo', 'gpt-4', 'Llama2-7B-chat', 'Llama2-13B-chat', 'vicuna-33B']
    # TEST_METRIC = functools.partial(get_diag_gpt4_similarity_exp2, auto=False)
    TEST_METRIC = functools.partial(get_diag_cosine_similarity, auto=False)
    # NOTE(review): the active scorer is cosine similarity but results are
    # stored under "gpt4_sim" -- confirm the directory name is intended.
    METRIC_NAME = "gpt4_sim"

    # Stage switches.
    calculate_score = True
    generate_latex = True  # aggregates scores and draws the radar charts
    check = False
    # Overrides the VALUE_LIST imported at the top of the file.
    VALUE_LIST = ['Power', 'Spirituality']
    if calculate_score:
        tasks = []
        for model in MODEL_LIST:
            for value in VALUE_LIST:
                tasks.append((model, value, TEST_METRIC, True, METRIC_NAME))

        Parallel(n_jobs=4)(delayed(process_and_save)(*task) for task in tasks)

    if generate_latex:
        TYPE_LIST = ['answer_score', 'reason_score', 'score']

        # Named score_type so the builtin `type` is not shadowed at module scope.
        for score_type in TYPE_LIST:
            calculate_result(MODEL_LIST, VALUE_LIST, score_type, metric_name=METRIC_NAME)
            plot_result(score_type, MODEL_LIST, metric_name=METRIC_NAME)
        # 'nagap' is aggregated only; no chart is drawn for it.
        calculate_result(MODEL_LIST, VALUE_LIST, 'nagap', metric_name=METRIC_NAME)
        # TODO: legend/header figure is currently disabled.
        # get_header('score', MODEL_NAME_LIST, METRIC_NAME)

    if check:
        check_consistency(MODEL_LIST, VALUE_LIST)
