import json
# NOTE: the Pearson and Kendall correlation functions (pearsonr, kendalltau) are imported from scipy.stats below.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, kendalltau
def _comparison_record(data, record_type, comparison, answer_a, answer_b):
    """Build the output record for one pairwise comparison of a dataset row.

    `comparison` is the key suffix (e.g. ``human_a_vs_model_b``) under which
    the row stores each automatic judge's verdict for that pair.
    """
    # (key prefix in the dataset row, key in the output record), in output order.
    judge_keys = (
        ('autoJ', 'autoj'),
        ('Critique', 'critique'),
        ('GPT2', 'gpt2'),
        ('TIGER', 'tiger'),
        ('BERT', 'bert'),
        ('ROUGE', 'rouge'),
        ('BLEU', 'bleu'),
        ('BART', 'bart'),
        ('BLEURT', 'bleurt'),
        ('UNIEVAL', 'unieval'),
        ('GPT-4o', 'gpt4o'),
        ('ChatGPT', 'chatgpt'),
    )
    record = {
        "type": record_type,
        "context": data['context'],
        "question": data['Question'],
        "reference": data['Reference'],
        "key_information": data['Concise_Reference'],
        "answer_a": answer_a,
        "answer_b": answer_b,
    }
    for source_prefix, target_key in judge_keys:
        record[target_key] = data[f'{source_prefix}_{comparison}']
    return record


def _annotation_decision(vote, label_a, label_b):
    """Map a raw annotation value to `label_a`, `label_b` or ``'tie'``.

    The check is a containment test, with 'a' tested before 'b'; a value
    containing neither is treated as a tie.
    """
    if 'a' in vote:
        return label_a
    if 'b' in vote:
        return label_b
    return 'tie'


def separate_modes(annotation_path, dataset_path, output_path='english/geo_final.json'):
    """Expand every dataset row into four pairwise-comparison records.

    Both inputs are JSON lists that are paired positionally; if their lengths
    differ, the extra entries of the longer list are ignored (``zip``).
    For each pair the function emits, in this order, the comparisons
    human_a vs human_b, human_a vs model_a, human_a vs model_b and
    model_a vs model_b, each with the judges' verdicts copied from the dataset
    row and a ``decision`` derived from the annotation.

    Args:
        annotation_path: JSON file with one annotation dict per row.
        dataset_path: JSON file with one dataset dict per row.
        output_path: where the flat list of records is written (UTF-8 JSON).
            Defaults to the previously hard-coded ``english/geo_final.json``.
    """
    with open(annotation_path, 'r', encoding='utf-8') as f:
        annotations = json.load(f)
    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    info = []
    for annot, data in zip(annotations, dataset):
        student_answer_a = data['refined_student_answer_a']
        student_answer_b = data['refined_student_answer_b']
        model_answer_a = data['model_answer_a']
        model_answer_b = data['model_answer_b']
        # (record type, dataset key suffix, annotation key,
        #  answer_a, answer_b, decision for an 'a' vote, decision for a 'b' vote)
        comparisons = (
            ("human_vs_human", "human_a_vs_human_b",
             'student_answer_a vs student_answer_b',
             student_answer_a, student_answer_b, 'human_a', 'human_b'),
            # The answers must be the student's first answer and model A's
            # answer, matching the scores and the human_a/model_a decision.
            ("human_vs_model", "human_a_vs_model_a",
             'stud_answer_a vs model_answer_a',
             student_answer_a, model_answer_a, 'human_a', 'model_a'),
            ("human_vs_model", "human_a_vs_model_b",
             'stud_answer_a vs model_answer_b',
             student_answer_a, model_answer_b, 'human_a', 'model_b'),
            ("model_vs_model", "model_a_vs_model_b",
             'model_answer_a vs model_answer_b',
             model_answer_a, model_answer_b, 'model_a', 'model_b'),
        )
        for (record_type, comparison, annotation_key,
             answer_a, answer_b, label_a, label_b) in comparisons:
            record = _comparison_record(data, record_type, comparison, answer_a, answer_b)
            record['decision'] = _annotation_decision(annot[annotation_key], label_a, label_b)
            info.append(record)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=4, ensure_ascii=False)

        
def one_mode(input_path, annotation_path, output_path='english/psy_final.json'):
    """Convert a model-vs-model-only dataset into flat comparison records.

    The dataset and annotation lists are paired positionally; if their lengths
    differ, the extra entries of the longer list are ignored (``zip``). Each
    pair yields one ``model_vs_model`` record holding the judges' verdicts for
    ``model_a_vs_model_b`` and a ``decision`` taken from the annotation field
    ``model_a_vs_model_b_new_new``.

    Args:
        input_path: JSON file with one dataset dict per row.
        annotation_path: JSON file with one annotation dict per row.
        output_path: where the list of records is written (UTF-8 JSON).
            Defaults to the previously hard-coded ``english/psy_final.json``.
    """
    # Read as UTF-8 so the round trip matches the UTF-8 file written below.
    with open(input_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    with open(annotation_path, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    # (key prefix in the dataset row, key in the output record), in output order.
    judge_keys = (
        ('autoJ', 'autoj'),
        ('Critique', 'critique'),
        ('GPT2', 'gpt2'),
        ('TIGER', 'tiger'),
        ('BERT', 'bert'),
        ('ROUGE', 'rouge'),
        ('BLEU', 'bleu'),
        ('BART', 'bart'),
        ('BLEURT', 'bleurt'),
        ('UNIEVAL', 'unieval'),
        ('GPT-4o', 'gpt4o'),
        ('ChatGPT', 'chatgpt'),
    )

    info = []
    for anno, data in zip(annotations, dataset):
        record = {
            "type": "model_vs_model",
            "context": data['context'],
            "question": data['Question'],
            "reference": data['Reference'],
            "key_information": data['Concise_Reference'],
            "answer_a": data['model_answer_a'],
            "answer_b": data['model_answer_b'],
        }
        for source_prefix, target_key in judge_keys:
            record[target_key] = data[f'{source_prefix}_model_a_vs_model_b']

        # Containment test, 'a' checked before 'b'; anything else is a tie.
        vote = anno['model_a_vs_model_b_new_new']
        if 'a' in vote:
            record['decision'] = 'model_a'
        elif 'b' in vote:
            record['decision'] = 'model_b'
        else:
            record['decision'] = 'tie'
        info.append(record)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=4, ensure_ascii=False)
        
    
def compare_acc(input_path):
    """Print each baseline's agreement with the human decision, per comparison type.

    The input is a JSON list of records with a ``type`` field, a ``decision``
    field and one field per baseline, keyed by the lower-cased baseline name.
    For every comparison type the function prints a ``<type>:`` header
    followed by ``<baseline>: <fraction of matching records>``.

    A type with no records prints a short notice instead of raising
    ZeroDivisionError, so files that contain only one comparison type can be
    processed.
    """
    baseline_names = ['autoJ', 'Critique', 'GPT2', 'TIGER', 'BERT', 'ROUGE', 'BLEU', 'BART', 'BLEURT', 'UNIEVAL', 'GPT4o', 'ChatGPT']
    modes = ["human_vs_human", "human_vs_model", "model_vs_model"]
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for mode in modes:
        # Select the records of this type once instead of re-scanning per baseline.
        items = [item for item in data if item['type'] == mode]
        print(f'{mode}:')
        if not items:
            print('no items of this type')
            continue
        for baseline in baseline_names:
            key = baseline.lower()
            acc = sum(1 for item in items if item[key] == item['decision'])
            print(f'{baseline}: {acc / len(items)}')


def overall_acc(input_path):
    """Print each baseline's accuracy against the human preference, ignoring ties.

    The input is a JSON list of items with a ``Human_preference`` field and a
    ``<baseline>_preference`` field per baseline. Items whose human preference
    is ``'tie'`` are excluded from both numerator and denominator. One line,
    ``<baseline>: <accuracy>``, is printed per baseline.

    When no non-tie item exists, an ``n/a`` line is printed for every baseline
    instead of raising ZeroDivisionError.
    """
    baseline_names = ['ROUGE', 'BLEU', 'GPT2', 'BLEURT', 'UNIEVAL', 'BERT', 'BART', 'AutoJ', 'Critique', 'Tiger', 'ChatGPT', 'ChatGPT_CoT', 'ChatGPT_Eval', 'GPT-4o', 'GPT_4o_CoT', 'GPT-4o_Eval', 'ReFeR']
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # The tie filter does not depend on the baseline, so apply it once.
    decided = [item for item in data if item['Human_preference'] != 'tie']
    for baseline in baseline_names:
        if not decided:
            print(f'{baseline}: n/a (no non-tie items)')
            continue
        key = f"{baseline}_preference"
        acc = sum(1 for item in decided if item[key] == item['Human_preference'])
        print(f'{baseline}: {acc / len(decided)}')
    
def mode_acc(input_path):
    """Print each baseline's agreement with the human preference per answer-type pair.

    Items are grouped by ``(answer_a_type, answer_b_type)`` into human/human,
    human/model and model/model; items with any other combination are ignored.
    Items whose human preference is ``'tie'`` are not filtered out here.
    For every baseline a ``Baseline:`` header is printed, followed by one
    agreement line per group.

    A group with no items prints ``n/a (no items)`` instead of raising
    ZeroDivisionError, so single-mode datasets can be processed.
    """
    baseline_names = ['ROUGE', 'BLEU', 'GPT2', 'BLEURT', 'UNIEVAL', 'BERT', 'BART', 'AutoJ', 'Critique', 'Tiger', 'ChatGPT', 'ChatGPT_CoT', 'ChatGPT_Eval', 'GPT-4o', 'GPT_4o_CoT', 'GPT-4o_Eval']
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # (printed label, required answer_a_type, required answer_b_type)
    groups_spec = (
        ("Human vs Human", 'human', 'human'),
        ("Human vs Model", 'human', 'model'),
        ("Model vs Model", 'model', 'model'),
    )
    # Grouping is independent of the baseline, so do it once up front.
    groups = {label: [] for label, _, _ in groups_spec}
    for item in data:
        for label, type_a, type_b in groups_spec:
            if item['answer_a_type'] == type_a and item['answer_b_type'] == type_b:
                groups[label].append(item)
                break

    for baseline in baseline_names:
        key = f"{baseline}_preference"
        print("Baseline:", baseline)
        for label, _, _ in groups_spec:
            items = groups[label]
            if not items:
                print(f"{label}:", "n/a (no items)")
                continue
            hits = sum(1 for item in items if item[key] == item['Human_preference'])
            print(f"{label}:", hits / len(items))
        
# Module-level call: executes whenever this file is run or imported and prints
# the per-answer-type accuracies for final_version/geo.json.
# NOTE(review): consider moving this under an `if __name__ == "__main__":` guard.
mode_acc('final_version/geo.json')
        
def length_preference(input_path):
    """Add a ``Length_preference`` field to every item and rewrite the file in place.

    The preference is ``'a'`` when ``student_answer_a`` has strictly more
    whitespace-separated tokens than ``student_answer_b``; otherwise it is
    ``'b'``. Equal lengths therefore also yield ``'b'``; no tie is produced.

    The file is read and written as UTF-8 so that non-ASCII text survives the
    in-place rewrite regardless of the platform's default encoding.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for item in data:
        words_a = len(item['student_answer_a'].split())
        words_b = len(item['student_answer_b'].split())
        item['Length_preference'] = 'a' if words_a > words_b else 'b'
    with open(input_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
        
def baseline_correlation_length(input_path):
    """Print how strongly each judge's preference tracks the length preference.

    Preferences are encoded as 0 for 'a', 1 for 'b' and 2 for any other value.
    For every judge the function prints the Pearson and Kendall correlation
    between its encoded preferences and the encoded ``Length_preference``,
    plus the fraction of items on which the two encodings are equal.
    """
    def encode(choice):
        # 'a' -> 0, 'b' -> 1, everything else -> 2
        if choice == 'a':
            return 0
        return 1 if choice == 'b' else 2

    judges = ['ROUGE', 'BLEU', 'GPT2', 'BLEURT', 'UNIEVAL', 'BERT', 'BART', 'AutoJ', 'Critique', 'Tiger', 'ChatGPT', 'GPT-4o', 'Human']
    with open(input_path, 'r') as handle:
        records = json.load(handle)

    for judge in judges:
        judge_codes = [encode(record[f"{judge}_preference"]) for record in records]
        length_codes = [encode(record['Length_preference']) for record in records]
        pearson_corr, _ = pearsonr(judge_codes, length_codes)
        kendall_corr, _ = kendalltau(judge_codes, length_codes)
        matches = sum(1 for left, right in zip(judge_codes, length_codes) if left == right)
        acc = matches / len(judge_codes)
        print(f'{judge}: Pearson: {pearson_corr}, Kendall: {kendall_corr}', f'Accuracy: {acc}')
        
# Draw a heatmap of pairwise agreement (consistency) between every pair of baselines.
def consistency_matrix_between_baselines(input_path):
    """Plot and save a heatmap of pairwise agreement between baselines.

    Cell (i, j) is the fraction of items in the file on which baselines i and
    j gave the same preference, counting only the values 'a', 'b' and 'tie'.
    A missing or different value never counts as agreement, so a diagonal
    cell can be below 1. The figure is saved as
    ``<subject>_consistency_matrix.png`` in the working directory and shown.

    Args:
        input_path: JSON file holding a list of items with
            ``<baseline>_preference`` fields.

    Raises:
        ValueError: if the file contains no items (previously a bare
            ZeroDivisionError).
    """
    # The subject tag only names the output file; it is the first known tag
    # that occurs anywhere in the path string. 'his' is included so that the
    # tag list matches consistency_matrix_between_baselines_modes.
    subject = 'unknown'
    for tag in ('geo', 'psy', 'pol', 'med', 'law', 'his'):
        if tag in input_path:
            subject = tag
            break

    # Display labels (axis ticks) and the matching key prefixes in the data.
    baselines = ['Length', 'ROUGE', 'BLEU', 'GPT2', 'BLEURT', 'UNIEVAL', 'BERTScore', 'BARTScore', 'Auto-J-13B', 'CritiqueLLM-6B', 'TIGERScore-13B', 'ChatGPT', 'ChatGPT-CoT', 'G-Eval(ChatGPT)', 'GPT-4o', 'GPT-4o-CoT', 'G-Eval(GPT-4o)']
    baseline_names = ['Length', 'ROUGE', 'BLEU', 'GPT2', 'BLEURT', 'UNIEVAL', 'BERT', 'BART', 'AutoJ', 'Critique', 'Tiger', 'ChatGPT', 'ChatGPT_CoT', 'ChatGPT_Eval', 'GPT-4o', 'GPT_4o_CoT', 'GPT-4o_Eval']

    assert len(baselines) == len(baseline_names), "Length of baselines and baseline_names should be the same"

    with open(input_path, 'r') as f:
        data = json.load(f)
    if not data:
        raise ValueError(f"no items found in {input_path!r}; cannot compute agreement rates")

    # Count agreements directly into the array that will be plotted.
    preference_keys = [f"{name}_preference" for name in baseline_names]
    counted_values = ('a', 'b', 'tie')
    matrix_array = np.zeros((len(baseline_names), len(baseline_names)))
    for item in data:
        # Look every preference up once per item instead of once per pair.
        preferences = [item.get(key) for key in preference_keys]
        for i, first in enumerate(preferences):
            if first not in counted_values:
                continue
            for j, second in enumerate(preferences):
                if first == second:
                    matrix_array[i, j] += 1
    matrix_array /= len(data)

    # Plotting the matrix
    plt.figure(figsize=(12, 9))
    ax = sns.heatmap(matrix_array, annot=True, fmt=".1f", xticklabels=baselines, yticklabels=baselines, cmap='Blues', cbar=False)

    # Re-apply the tick labels to set their rotation and alignment.
    ax.set_xticklabels(baselines, rotation=45, ha='right')
    ax.set_yticklabels(baselines, rotation=0)

    # Remove axis labels
    ax.set_xlabel('')
    ax.set_ylabel('')

    plt.savefig(f'''{subject}_consistency_matrix.png''')
    plt.show()
    
# Path prefix (trailing slash included) that is prepended to the dataset file
# name in the module-level call further down. Despite the name, the value is a
# relative directory, not a URL.
BASE_URL = 'final_version/'

def consistency_matrix_between_baselines_modes(input_path, mode):
    """Plot and save the baseline agreement heatmap for one answer-type pair.

    Only items whose ``(answer_a_type, answer_b_type)`` matches `mode` are
    used. Cell (i, j) is the fraction of those items on which baselines i and
    j gave the same preference, counting only the values 'a', 'b' and 'tie'.
    A missing or different value never counts as agreement. The figure is
    saved as ``<subject>_<mode>_consistency_matrix.png`` and shown.

    Args:
        input_path: JSON file holding a list of items.
        mode: ``'human vs human'`` or ``'human vs model'``; any other value
            selects model-vs-model items.

    Raises:
        ValueError: if no item matches `mode` (previously a bare
            ZeroDivisionError).
    """
    if mode == 'human vs human':
        type_a, type_b = 'human', 'human'
    elif mode == 'human vs model':
        type_a, type_b = 'human', 'model'
    else:
        type_a, type_b = 'model', 'model'

    # The subject tag only names the output file; it is the first known tag
    # that occurs anywhere in the path string.
    subject = 'unknown'
    for tag in ('geo', 'psy', 'pol', 'med', 'law', 'his'):
        if tag in input_path:
            subject = tag
            break

    # Display labels (axis ticks) and the matching key prefixes in the data.
    baselines = ['Length', 'ROUGE', 'BLEU', 'GPT2', 'BLEURT', 'UNIEVAL', 'BERTScore', 'BARTScore', 'Auto-J-13B', 'CritiqueLLM-6B', 'TIGERScore-13B', 'ChatGPT', 'ChatGPT-CoT', 'G-Eval(ChatGPT)', 'GPT-4o', 'GPT-4o-CoT', 'G-Eval(GPT-4o)']
    baseline_names = ['Length', 'ROUGE', 'BLEU', 'GPT2', 'BLEURT', 'UNIEVAL', 'BERT', 'BART', 'AutoJ', 'Critique', 'Tiger', 'ChatGPT', 'ChatGPT_CoT', 'ChatGPT_Eval', 'GPT-4o', 'GPT_4o_CoT', 'GPT-4o_Eval']

    with open(input_path, 'r') as f:
        data = json.load(f)

    # Count agreements directly into the array that will be plotted.
    preference_keys = [f"{name}_preference" for name in baseline_names]
    counted_values = ('a', 'b', 'tie')
    matrix_array = np.zeros((len(baseline_names), len(baseline_names)))
    cnt = 0
    for item in data:
        if not (item['answer_a_type'] == type_a and item['answer_b_type'] == type_b):
            continue
        cnt += 1
        # Look every preference up once per item instead of once per pair.
        preferences = [item.get(key) for key in preference_keys]
        for i, first in enumerate(preferences):
            if first not in counted_values:
                continue
            for j, second in enumerate(preferences):
                if first == second:
                    matrix_array[i, j] += 1

    if cnt == 0:
        raise ValueError(
            f"no items with answer types ({type_a!r}, {type_b!r}) in {input_path!r}; "
            "cannot compute agreement rates"
        )
    matrix_array /= cnt

    # Plotting the matrix
    plt.figure(figsize=(11, 9))
    ax = sns.heatmap(matrix_array, annot=True, fmt=".1f", xticklabels=baselines, yticklabels=baselines, cmap='Blues', cbar=False)
    # Re-apply the tick labels to set their rotation and alignment.
    ax.set_xticklabels(baselines, rotation=45, ha='right')
    ax.set_yticklabels(baselines, rotation=0)
    # Remove axis labels
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.savefig(f'''{subject}_{mode}_consistency_matrix.png''')
    plt.show()
    
# Module-level call: executes whenever this file is run or imported and draws
# the 'model vs model' agreement heatmap for final_version/final.json.
# NOTE(review): consider moving this under an `if __name__ == "__main__":` guard.
consistency_matrix_between_baselines_modes(f'{BASE_URL}final.json', 'model vs model')
def test(input_path):
    """Print how often the ChatGPT_Eval / GPT-4o_Eval fields agree with the human label.

    First printed line: the fraction of records whose ``Human_preference`` is
    contained in ``ChatGPT_Eval_votes``, then the same for
    ``GPT-4o_Eval_votes``. Second printed line: the fraction of records whose
    ``ChatGPT_Eval_preference`` equals ``Human_preference``, then the same
    for ``GPT-4o_Eval_preference``.
    """
    with open(input_path, 'r') as handle:
        records = json.load(handle)

    chatgpt_vote_hits = 0
    gpt4o_vote_hits = 0
    chatgpt_choice_hits = 0
    gpt4o_choice_hits = 0
    for record in records:
        chatgpt_votes = record['ChatGPT_Eval_votes']
        gpt4o_votes = record['GPT-4o_Eval_votes']
        chatgpt_choice = record['ChatGPT_Eval_preference']
        gpt4o_choice = record['GPT-4o_Eval_preference']
        human_choice = record['Human_preference']

        # "Votes" hits: the human label appears among the judge's votes.
        chatgpt_vote_hits += 1 if human_choice in chatgpt_votes else 0
        gpt4o_vote_hits += 1 if human_choice in gpt4o_votes else 0
        # "Preference" hits: the judge's final preference equals the human label.
        chatgpt_choice_hits += 1 if chatgpt_choice == human_choice else 0
        gpt4o_choice_hits += 1 if gpt4o_choice == human_choice else 0

    total = len(records)
    print(chatgpt_vote_hits / total, gpt4o_vote_hits / total)
    print(chatgpt_choice_hits / total, gpt4o_choice_hits / total)
    
def check(input_path):
    """Print the ``Human_preference`` of every fourth item in `input_path`.

    The file is a JSON list; items at indices 0, 4, 8, ... are printed, one
    value per line.
    NOTE(review): the stride of 4 presumably corresponds to four comparison
    records per source question; confirm against the data layout.

    Args:
        input_path: path of the JSON file to inspect. Previously this argument
            was ignored and ``final_version/geo.json`` was always opened.
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    human_preference = [d['Human_preference'] for d in data]
    for preference in human_preference[::4]:
        print(preference)
    
# Module-level call: executes whenever this file is run or imported and prints
# every fourth human preference from final_version/geo.json.
# NOTE(review): consider moving this under an `if __name__ == "__main__":` guard.
check('final_version/geo.json')