from utils.scorer.test_score_functions import load_single_file
from utils.scorer.test_score_functions import VALUE_LIST, VALUE_ALIAS_LIST, ROOT_PATH
from utils.scorer.score_functions import get_bert_score, get_bleu_score, get_diag_cosine_similarity, get_diag_gpt4_similarity_exp2
from matplotlib.backends.backend_pdf import PdfPages
import functools
import json
import os
from statistics import mean, stdev, median
from joblib import Parallel, delayed
import random
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Apply seaborn's default theme globally; it affects every figure created below.
sns.set_theme()
# Base directory (with trailing slash) under which experiment-2 files are read
# and written; callers append "<metric_name>/..." directly to this prefix.
EXP2_PATH = f"{ROOT_PATH}/results/exp2/"

# font_properties = {'size': 24, 'weight': 'regular', 'family': 'monospace'}

def normalize(lst):
    """Standardize numbers to zero mean and unit standard deviation.

    The input is first materialized with ``list``, so any iterable is
    accepted. The spread is ``np.std`` with its default settings.

    Args:
        lst: Iterable of numeric values.

    Returns:
        A new list of ``(x - mean) / std`` values. If the standard deviation
        is exactly zero, a notice is printed and the unscaled list copy is
        returned instead.
    """
    values = list(lst)
    center, spread = np.mean(values), np.std(values)

    if spread != 0:
        return [(value - center) / spread for value in values]

    print("Standard deviation is zero. Returning the original array.")
    return values

def process_and_save(model_name, value_type, scorer, is_overwrite=False, metric_name=""):
    """Score one model/value result file and save the scored records as JSONL.

    Records are loaded with ``load_single_file`` and scored in batches of 100.
    Each record gains three fields:

    * ``answer_score``: ``scorer`` applied to ``answer`` vs ``chose_answer``
    * ``reason_score``: ``scorer`` applied to ``reason`` vs ``baseline_reason``
    * ``score``: absolute difference of the two scores

    Args:
        model_name: Model identifier used in input and output paths.
        value_type: Value identifier used in input and output paths.
        scorer: Callable ``scorer(references, candidates)`` returning one
            score per pair, indexable by position.
        is_overwrite: Forwarded to ``load_single_file`` only. An existing
            output file is never recomputed, whatever this flag says.
        metric_name: Sub-directory of ``EXP2_PATH`` to write into.

    Returns:
        None. Nothing is written when the output file already exists or when
        ``load_single_file`` returns ``None``.
    """
    save_path = f'{EXP2_PATH}{metric_name}/{model_name}'
    output_file = f'{save_path}/{value_type}_prompt_{model_name}.jsonl'
    if os.path.exists(output_file):
        print(f"{output_file} already exists!")
        return None

    print(f'Processing {model_name} {value_type}...')
    input_list = load_single_file(model_name, value_type, is_parse=False, is_overwrite=is_overwrite)
    if input_list is None:
        return None

    batch_size = 100
    for start in range(0, len(input_list), batch_size):
        batch = input_list[start:start + batch_size]

        # Missing fields fall back to 0, exactly as the scorer has always seen them.
        answer_refs = [item.get('answer', 0) for item in batch]
        answer_cands = [item.get('chose_answer', 0) for item in batch]
        reason_refs = [item.get('reason', 0) for item in batch]
        reason_cands = [item.get('baseline_reason', 0) for item in batch]

        answer_scores = scorer(answer_refs, answer_cands)
        reason_scores = scorer(reason_refs, reason_cands)

        # Index explicitly so a scorer that returns too few scores fails loudly
        # instead of leaving records silently unscored.
        for j, item in enumerate(batch):
            item['answer_score'] = float(answer_scores[j])
            item['reason_score'] = float(reason_scores[j])
            item['score'] = abs(float(answer_scores[j] - reason_scores[j]))

    # exist_ok avoids a check-then-create race: parallel workers handling
    # different values of the same model share this directory.
    os.makedirs(save_path, exist_ok=True)
    with open(output_file, 'w') as f:
        for item in input_list:
            f.write(json.dumps(item) + '\n')

def calculate_result(model_list, value_list, type='score', metric_name=""):
    """Aggregate per-record scores into one summary row per model.

    For every model/value pair the scored JSONL file is read and the median of
    the selected field is stored under the value's name. Each model row then
    gets ``Average``, ``Max`` and ``Min`` computed over those per-value
    medians. Rows are written, in ``model_list`` order, to
    ``<EXP2_PATH><metric_name>/model_<type>.jsonl``.

    Args:
        model_list: Model identifiers; one output row per model.
        value_list: Value identifiers; one median per value.
        type: Record field to aggregate (missing fields count as 0). The
            special name ``'nagap'`` aggregates the signed difference
            ``answer_score - reason_score`` instead.
        metric_name: Sub-directory of ``EXP2_PATH`` holding the scored files.

    Raises:
        statistics.StatisticsError: If a scored file contains no records.
        ZeroDivisionError: If ``value_list`` is empty.
    """
    metric_dir = f'{EXP2_PATH}{metric_name}/'
    summaries = {}

    for model in model_list:
        model_scores = {}
        for value in value_list:
            score_list = []
            with open(f'{metric_dir}{model}/{value}_prompt_{model}.jsonl', 'r') as f:
                for line in f:
                    data = json.loads(line)
                    if type == 'nagap':
                        score_list.append(data.get('answer_score', 0) - data.get('reason_score', 0))
                    else:
                        score_list.append(data.get(type, 0))
            model_scores[value] = median(score_list)

        # Snapshot the medians before adding summary keys. Otherwise a rounded
        # Average (e.g. sum([0.1] * 3) / 3 > 0.1) could be reported as Max/Min.
        per_value = list(model_scores.values())
        model_scores['Average'] = sum(per_value) / len(per_value)
        model_scores['Max'] = max(per_value)
        model_scores['Min'] = min(per_value)
        summaries[model] = model_scores

    os.makedirs(metric_dir, exist_ok=True)
    with open(os.path.join(metric_dir, f'model_{type}.jsonl'), 'w') as f:
        for model in model_list:
            f.write(json.dumps(summaries[model]) + '\n')

def check_consistency(model_list, value_list):
    """Compare the ``.jsonl`` and ``.txt`` result files of every model/value pair.

    Both files are parsed as one JSON object per line. Records are compared
    pairwise, in file order, on the ``answer``, ``chose_answer``, ``reason``
    and ``baseline_reason`` fields. A message is printed for every differing
    pair, and also when the two files hold a different number of records
    (pairwise comparison alone would silently ignore the surplus).

    Args:
        model_list: Model identifiers to check.
        value_list: Value identifiers to check.

    Raises:
        KeyError: If a record lacks one of the compared fields.
    """
    compared_keys = ('answer', 'chose_answer', 'reason', 'baseline_reason')

    for model in model_list:
        for value in value_list:
            base = f'{ROOT_PATH}/results/{model}/{value}_prompt_{model}'
            with open(f'{base}.jsonl', 'r') as f:
                jsonl_data = [json.loads(line) for line in f]
            with open(f'{base}.txt', 'r') as f:
                txt_data = [json.loads(line) for line in f]

            # zip() below stops at the shorter file, so report a count mismatch explicitly.
            if len(txt_data) != len(jsonl_data):
                print(f'Record count mismatch in {model} {value}: '
                      f'{len(txt_data)} (.txt) vs {len(jsonl_data)} (.jsonl)!')

            for item_txt, item_jsonl in zip(txt_data, jsonl_data):
                if any(item_txt[key] != item_jsonl[key] for key in compared_keys):
                    print(f'Inconsistent data found in {model} {value}!')

def plot_result(type, model_list, metric_name, length=11):
    """Draw a radar chart of per-value scores, one polygon per model, as a PDF.

    Reads ``<EXP2_PATH><metric_name>/model_<type>.jsonl`` (one JSON object per
    model) and plots the first ``length`` entries of every row on a polar
    axis. The figure is saved next to the input as ``model_<type>.pdf``.

    Args:
        type: Summary name; selects the input file and the radial limits
            (``'score'`` -> 0..0.4, ``'answer_score'`` -> 0.3..0.92,
            anything else -> 0.8..1).
        model_list: Legend labels, in the same order as the rows of the file.
        metric_name: Sub-directory of ``EXP2_PATH`` holding the summary file.
        length: Number of leading entries of each row to plot.

    Raises:
        IndexError: If the summary file is empty or has more rows than
            ``model_list`` has entries.
    """
    with open(f'{EXP2_PATH}{metric_name}/model_{type}.jsonl', 'r') as f:
        data_dicts = [json.loads(line.strip()) for line in f]

    colors = sns.color_palette("husl", len(data_dicts))
    labels = list(data_dicts[0].keys())[:length]
    # The 11th axis gets a display name (presumably the no-induction baseline;
    # confirm against VALUE_LIST). Guarded so a shorter `length` does not crash.
    if len(labels) > 10:
        labels[10] = 'No induction'

    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    # Repeat the first angle so every polygon is drawn closed.
    closed_angles = angles + angles[:1]

    fig = plt.figure(figsize=(20, 20))
    try:
        ax = plt.subplot(polar=True)

        for i, data_dict in enumerate(data_dicts):
            values = list(data_dict.values())[:length]
            values += values[:1]
            ax.fill(closed_angles, values, color=colors[i], alpha=0.25)
            ax.plot(closed_angles, values, color=colors[i], linewidth=3, label=model_list[i])

        ax.set_yticklabels([])
        if type == 'score':
            ax.set_ylim(0, 0.4)
        elif type == 'answer_score':
            ax.set_ylim(0.3, 0.92)
        else:
            ax.set_ylim(0.8, 1)
        # One tick per label; the closing angle is deliberately excluded.
        ax.set_xticks(angles)
        ax.set_xticklabels(labels, fontdict=None, size=48)
        plt.tight_layout(pad=0.4)

        with PdfPages(f'{EXP2_PATH}{metric_name}/model_{type}.pdf') as pdf:
            pdf.savefig(fig)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def get_header(type, model_list, metric_name):
    """Render a radar chart with a legend and save it as ``exp2_header.pdf``.

    Reads ``<EXP2_PATH><metric_name>/model_<type>.jsonl`` and plots every
    entry of every row (no truncation) on a polar axis. A wide legend is
    placed above the plot. Judging by the output name, the file is meant as a
    shared legend/header for the other figures.

    Args:
        type: Summary name; selects the input file and the radial limits
            (``'score'`` -> 0..0.25, anything else -> 0..1).
        model_list: Legend labels, in the same order as the rows of the file.
        metric_name: Sub-directory of ``EXP2_PATH`` holding the summary file.

    Raises:
        IndexError: If the summary file is empty or has more rows than
            ``model_list`` has entries.
    """
    with open(f'{EXP2_PATH}{metric_name}/model_{type}.jsonl', 'r') as f:
        data_dicts = [json.loads(line.strip()) for line in f]

    colors = sns.color_palette("husl", len(data_dicts))
    labels = list(data_dicts[0].keys())
    # Short display name for the 11th axis; guarded so shorter rows do not crash.
    if len(labels) > 10:
        labels[10] = 'No'

    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
    # Repeat the first angle so every polygon is drawn closed.
    closed_angles = angles + angles[:1]

    fig = plt.figure(figsize=(20, 20))
    try:
        ax = plt.subplot(polar=True)

        for i, data_dict in enumerate(data_dicts):
            values = list(data_dict.values())
            values += values[:1]
            ax.fill(closed_angles, values, color=colors[i], alpha=0.25)
            ax.plot(closed_angles, values, color=colors[i], linewidth=3, label=model_list[i])

        ax.set_yticklabels([])
        if type == 'score':
            ax.set_ylim(0, 0.25)
        else:
            ax.set_ylim(0, 1)
        ax.set_xticks(angles)
        ax.set_xticklabels(labels, fontdict=None, size=24)

        legend = plt.legend(loc='upper center', ncol=6, markerscale=4.0, handlelength=2.5, labelspacing=0.4, mode=None,
                            borderaxespad=-4.5, prop={'size': 25, 'weight': 'regular'},
                            fancybox=False, edgecolor='#FFFFFF', facecolor='#FFFFFF')
        # Thicken the legend's own line samples; iterate what the legend
        # actually contains rather than assuming one line per model name.
        for legend_line in legend.get_lines():
            legend_line.set_linewidth(5)

        with PdfPages(f'{EXP2_PATH}{metric_name}/exp2_header.pdf') as pdf:
            pdf.savefig(fig)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    
def plot_bar_result(type, model_list, metric_name):
    """Plot each model's Max / Average / Min summary as overlaid bars (PDF).

    Reads ``<EXP2_PATH><metric_name>/model_<type>.jsonl`` and, for every row,
    draws the ``Max``, ``Average`` and ``Min`` entries as three bars at the
    same x position (later bars are painted over earlier ones). The figure is
    saved next to the input as ``bar_<type>.pdf``.

    Args:
        type: Summary name; selects the input file and the y-axis upper limit
            (0.45 for ``'score'``, 1.2 otherwise).
        model_list: Bar labels, in the same order as the rows of the file.
        metric_name: Sub-directory of ``EXP2_PATH`` holding the summary file.

    Raises:
        KeyError: If a row lacks ``Max``, ``Average`` or ``Min``.
    """
    # Load data
    with open(f'{EXP2_PATH}{metric_name}/model_{type}.jsonl', 'r') as f:
        data_dicts = [json.loads(line.strip()) for line in f]

    # Extract data
    maxs = [item['Max'] for item in data_dicts]
    avgs = [item['Average'] for item in data_dicts]
    mins = [item['Min'] for item in data_dicts]

    # Colors
    colors = ["#A8E6CF", "#DCE775", "#FFD3B6"]

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        # Plot
        ax.bar(model_list, maxs, color=colors[0], width=0.6, label='Max')
        ax.bar(model_list, avgs, color=colors[1], width=0.6, label='Average')
        ax.bar(model_list, mins, color=colors[2], width=0.6, label='Min')

        # Set axis labels
        font_size = 28
        ax.set_xlabel("Models", fontsize=font_size + 8)
        ax.set_ylabel("Value", fontsize=font_size + 8)
        ax.set_xticks(model_list)
        ax.set_xticklabels(model_list, rotation=30, fontsize=font_size)

        # Fix the y-range BEFORE labelling ticks. Labels set on auto-located
        # ticks stop matching once the limits (and so the ticks) change.
        y_max = 0.45 if type == "score" else 1.2
        ax.set_ylim(0, y_max)
        y_ticks = ax.get_yticks()
        ax.set_yticks(y_ticks)  # pin tick positions so the labels below stay aligned
        ax.set_yticklabels(np.around(y_ticks, 2), fontsize=font_size)
        ax.set_ylim(0, y_max)  # pinning ticks may widen the view; restore it

        ax.legend(fontsize=font_size, loc='upper right')
        plt.tight_layout()

        # Save the figure
        with PdfPages(f'{EXP2_PATH}{metric_name}/bar_{type}.pdf') as pdf:
            pdf.savefig(fig)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def generate_question(metric_name, seed=42, save_questions=False):
    """Export random samples of scored records for manual analysis.

    For each of two fixed models (``gpt-4`` -> sheet ``A``, ``Llama2-7B-chat``
    -> sheet ``B``), five entries of ``VALUE_LIST`` are drawn and ten records
    are drawn from each value's scored file. The samples are written to
    ``<EXP2_PATH>/analysis/answer_<sheet>_<alias>_<value>.jsonl``.

    Both draws use ``np.random.randint``, i.e. sampling with replacement. A
    value drawn twice for one model therefore overwrites its own file, and a
    record can appear more than once in a sheet.

    Args:
        metric_name: Sub-directory of ``EXP2_PATH`` holding the scored files.
        seed: Seed for ``random`` and ``np.random``, for reproducible samples.
        save_questions: When true, also write a blinded copy of each sample
            (``answer_score`` and ``reason_score`` set to -1) to
            ``question_<sheet>_<alias>.jsonl``. Off by default.

    Raises:
        IndexError: If ``VALUE_LIST`` has fewer than 11 entries or a scored
            file has fewer records than a sampled index requires.
    """
    random.seed(seed)
    np.random.seed(seed)

    sampled_models = ['gpt-4', 'Llama2-7B-chat']
    sheet_names = ['A', 'B']
    copied_fields = ('question', 'answer', 'chose_answer', 'reason', 'baseline_reason')

    # Create the output directory up front; writing would otherwise fail on a fresh checkout.
    analysis_dir = f'{EXP2_PATH}/analysis'
    os.makedirs(analysis_dir, exist_ok=True)

    for model, sname in zip(sampled_models, sheet_names):
        # Upper bound is exclusive: value indices 0..10.
        value_ids = np.random.randint(0, 11, 5)
        values, aliases = np.array(VALUE_LIST)[value_ids], np.array(VALUE_ALIAS_LIST)[value_ids]

        for value, alias in zip(values, aliases):
            with open(f'{ROOT_PATH}/results/exp2/{metric_name}/{model}/{value}_prompt_{model}.jsonl', 'r') as f:
                value_data = [json.loads(line.strip()) for line in f]

            # Sample records. Upper bound is exclusive, so indices 0..98 are
            # drawn; the bounds are kept as-is so existing seeds reproduce
            # the same samples.
            index = np.random.randint(0, 99, 10)

            # Question and answer sheets
            question_sheet, answer_sheet = [], []
            for idx in index:
                record = value_data[idx]
                shared = {field: record[field] for field in copied_fields}
                question_sheet.append({**shared, 'answer_score': -1, 'reason_score': -1})
                answer_sheet.append({
                    **shared,
                    'answer_score': record['answer_score'],
                    'reason_score': record['reason_score'],
                })

            # Save blinded questions (optional)
            if save_questions:
                with open(f'{analysis_dir}/question_{sname}_{alias}.jsonl', 'w') as f:
                    for item in question_sheet:
                        f.write(json.dumps(item) + '\n')

            # Save scores
            with open(f'{analysis_dir}/answer_{sname}_{alias}_{value}.jsonl', 'w') as f:
                for item in answer_sheet:
                    f.write(json.dumps(item) + '\n')


if __name__ == '__main__':
    # Display names (used as bar-chart labels) and the identifiers used in
    # result paths. The two lists are index-aligned.
    MODEL_NAME_LIST = ['GPT-3.5-Turbo', 'GPT-4', 'Llama2-7B', 'Llama2-13B', 'Vicuna-33B']
    MODEL_LIST = ['gpt-3.5-turbo', 'gpt-4', 'Llama2-7B-chat', 'Llama2-13B-chat', 'vicuna-33B']
    TEST_METRIC = functools.partial(get_diag_gpt4_similarity_exp2, auto=False)
    METRIC_NAME = "gpt4_sim"
    # Summary files / plots produced per run.
    TYPE_LIST = ['answer_score', 'reason_score', 'score']

    # Stage switches: enable the pipeline steps to run.
    calculate_score = False
    generate_latex = False
    generate_bar = False
    check = False
    analysis = True

    if calculate_score:
        # One task per (model, value) pair; is_overwrite is passed as True.
        tasks = [
            (model, value, TEST_METRIC, True, METRIC_NAME)
            for model in MODEL_LIST
            for value in VALUE_LIST
        ]
        Parallel(n_jobs=4)(delayed(process_and_save)(*task) for task in tasks)

    if generate_latex:
        # `score_type`, not `type`: a module-level loop variable named `type`
        # would rebind the builtin for the rest of the module.
        for score_type in TYPE_LIST:
            calculate_result(MODEL_LIST, VALUE_LIST, score_type, metric_name=METRIC_NAME)
            plot_result(score_type, MODEL_LIST, metric_name=METRIC_NAME)
        calculate_result(MODEL_LIST, VALUE_LIST, 'nagap', metric_name=METRIC_NAME)
        # TODO: header/legend figure is currently disabled:
        # get_header('score', MODEL_NAME_LIST, METRIC_NAME)

    if generate_bar:
        for score_type in TYPE_LIST:
            plot_bar_result(score_type, MODEL_NAME_LIST, metric_name=METRIC_NAME)

    if check:
        check_consistency(MODEL_LIST, VALUE_LIST)

    if analysis:
        generate_question(METRIC_NAME)
