from utils.utils import *
from reasoner import retrieve_answer_math
import random

def random_choose(answer_data, total_num=25, wrong_method='pal', subset_list=('algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra', 'precalculus', 'prealgebra')):
    """Sample questions that ``wrong_method`` answered incorrectly.

    Args:
        answer_data: nested mapping ``answer_data[method][subset][idx]`` whose
            leaf records carry a ``'correct'`` flag.
        total_num: maximum number of ``(subset, idx)`` pairs to return.
        wrong_method: method whose incorrect answers are sampled.
        subset_list: subsets to draw from. The default is a tuple so it can
            never be mutated and shared between calls.

    Returns:
        A list of at most ``total_num`` distinct ``(subset, idx)`` pairs. They
        are drawn uniformly, without replacement, from all incorrect answers of
        ``wrong_method`` in ``subset_list``. Fewer pairs are returned when
        there are not enough incorrect answers.

    Raises:
        ValueError: if ``wrong_method`` has no incorrect answer in any of the
            requested subsets. A rejection-sampling loop would spin forever
            in that case.
    """
    if total_num <= 0:
        return []

    method_answers = answer_data[wrong_method]
    # Enumerate the candidate pool once instead of rejection-sampling. This
    # guarantees termination and lets us draw without duplicates.
    pool = [
        (subset, idx)
        for subset in subset_list
        for idx, record in method_answers[subset].items()
        if not record['correct']
    ]
    if not pool:
        raise ValueError(
            f'No incorrect {wrong_method!r} answers found in subsets {list(subset_list)}'
        )
    return random.sample(pool, min(total_num, len(pool)))

def sample_result(model_name, wrong_method='cot'):
    """Dump every method's answer for a random sample of ``wrong_method`` failures.

    The questions are chosen with ``random_choose`` (questions that
    ``wrong_method`` got wrong). For each one, the answer, reasoning path and
    correctness of the pal, nlcode, cot and codenl methods are written to
    ``analysis/<model>_<wrong_method>_wrong.json``.

    Args:
        model_name: model identifier passed to ``retrieve_answer_math``. Only
            the part after the last ``/`` is used in the output filename.
        wrong_method: method whose incorrect answers are sampled.
    """
    import json
    import os

    answer_data = retrieve_answer_math(model_name=model_name)
    samples = random_choose(answer_data, wrong_method=wrong_method)
    _type = f'{wrong_method}_wrong'
    # Fixed order so the JSON keys appear as pal_*, nlcode_*, cot_*, codenl_*.
    methods = ('pal', 'nlcode', 'cot', 'codenl')

    results = []
    for subset, idx in samples:
        # The question text and ground truth are taken from the pal record.
        reference = answer_data['pal'][subset][idx]
        result = {
            'type': _type,
            'subset': subset,
            'idx': idx,
            'question': reference['question'],
            'ground_truth': reference['ground_truth'],
        }
        for method in methods:
            record = answer_data[method][subset][idx]
            result[f'{method}_answer'] = record['answer']
            result[f'{method}_rp'] = record['reasoning_path']
            result[f'{method}_correct'] = record['correct']
        results.append(result)

    model_name = model_name.split('/')[-1]
    print(model_name)
    # Create the output directory so a fresh checkout doesn't fail on open().
    os.makedirs('analysis', exist_ok=True)
    stored_path = f'analysis/{model_name}_{wrong_method}_wrong.json'

    with open(stored_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)


def one_correct_one_wrong(correct_method, wrong_method, model_name= 'gpt-4o-mini'):
    """Dump a sample of questions where one method is right and another wrong.

    Up to 25 distinct questions are sampled where ``correct_method`` answered
    correctly and ``wrong_method`` answered incorrectly. They are written to
    ``analysis/<model>_<correct_method>_correct_<wrong_method>_wrong.json``.

    Args:
        correct_method: method required to be correct on each sampled question.
        wrong_method: method required to be wrong on each sampled question.
        model_name: model identifier passed to ``retrieve_answer_math``. Only
            the part after the last ``/`` is used in the output filename.

    Raises:
        ValueError: if no question satisfies the condition (e.g. when
            ``correct_method == wrong_method``). A rejection-sampling loop
            would never terminate in that case.
    """
    import json
    import os

    answer_data = retrieve_answer_math(model_name)
    subset_list = ['algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra', 'precalculus', 'prealgebra']
    total_num = 25

    # Build the full pool of qualifying questions once, then draw without
    # replacement. This terminates and never returns duplicates.
    candidates = [
        (subset, idx)
        for subset in subset_list
        for idx, record in answer_data[wrong_method][subset].items()
        if not record['correct'] and answer_data[correct_method][subset][idx]['correct']
    ]
    if not candidates:
        raise ValueError(
            f'No question where {correct_method!r} is correct and {wrong_method!r} is wrong'
        )
    samples = random.sample(candidates, min(total_num, len(candidates)))

    _type = f'{correct_method}_correct_{wrong_method}_wrong'
    results = []
    for subset, idx in samples:
        correct_rec = answer_data[correct_method][subset][idx]
        wrong_rec = answer_data[wrong_method][subset][idx]
        results.append({
            'type': _type,
            'subset': subset,
            'idx': idx,
            'question': correct_rec['question'],
            'answer': correct_rec['gold_reasoning_path'],
            'ground_truth': correct_rec['ground_truth'],

            f'{correct_method}_answer': correct_rec['answer'],
            f'{correct_method}_rp': correct_rec['reasoning_path'],
            f'{correct_method}_correct': correct_rec['correct'],

            f'{wrong_method}_answer': wrong_rec['answer'],
            f'{wrong_method}_rp': wrong_rec['reasoning_path'],
            f'{wrong_method}_correct': wrong_rec['correct'],
        })

    model_name = model_name.split('/')[-1]
    print(model_name)
    # Create the output directory so a fresh checkout doesn't fail on open().
    os.makedirs('analysis', exist_ok=True)
    stored_path = f'analysis/{model_name}_{correct_method}_correct_{wrong_method}_wrong.json'

    with open(stored_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)



def plot_relation(model_name= 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', method_1='pal', method_2='nlcode'):
    """Plot a 2x2 agreement heatmap of correctness between two methods.

    For each question in every MATH subset, the function counts whether
    ``method_1`` and ``method_2`` were each correct or wrong. It prints the
    four counts and saves a heatmap to
    ``analysis/<model>_<method_1>_vs_<method_2>.png``. Rows are ``method_1``
    and columns are ``method_2``.

    Args:
        model_name: model identifier passed to ``retrieve_answer_math``. The
            full name is used in the plot title. Only the part after the last
            ``/`` is used in the filename, because a ``/`` would otherwise
            point into a non-existent sub-directory.
        method_1: first method to compare (rows of the matrix).
        method_2: second method to compare (columns of the matrix).
    """
    import os

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np

    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra', 'precalculus', 'prealgebra']

    all_answers = retrieve_answer_math(model_name)
    answers_1 = all_answers[method_1]
    answers_2 = all_answers[method_2]

    def _key(ok_1, ok_2):
        # Build the record key from the truthiness of each method's result.
        label_1 = 'correct' if ok_1 else 'wrong'
        label_2 = 'correct' if ok_2 else 'wrong'
        return f'{method_1}_{label_1}_{method_2}_{label_2}'

    # Insertion order (cc, cw, wc, ww) controls the print order below.
    records = {_key(a, b): 0 for a in (True, False) for b in (True, False)}

    for dataset in dataset_list:
        for i, record_1 in answers_1[dataset].items():
            records[_key(record_1['correct'], answers_2[dataset][i]['correct'])] += 1

    for key, value in records.items():
        print(f'{key}: {value}')

    # plot a matrix
    data = np.array([[records[_key(True, True)], records[_key(True, False)]],
                     [records[_key(False, True)], records[_key(False, False)]]])
    df_cm = pd.DataFrame(data, columns=['Correct', 'Wrong'], index=['Correct', 'Wrong'])

    fig = plt.figure(figsize=(10, 7))
    try:
        sns.heatmap(df_cm, annot=True, fmt='d', cmap='coolwarm')
        plt.xlabel(f'{method_2}')
        plt.ylabel(f'{method_1}')
        plt.title(f'{method_1} vs {method_2}, {model_name}')

        file_stem = model_name.split('/')[-1]
        os.makedirs('analysis', exist_ok=True)
        plt.savefig(f'analysis/{file_stem}_{method_1}_vs_{method_2}.png')
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)



if __name__ == "__main__":
    # Other entry points, kept for reference:
    #   sample_result('gpt-4o-mini', 'cot')
    #   sample_result('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo', 'cot')
    #   plot_relation('meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo')

    # Pair every single-format method with every mixed method, then dump the
    # disagreements in both directions (first right/second wrong and reverse).
    method_pairs = [
        (single, mixed)
        for single in ('cot', 'pal')
        for mixed in ('codenl', 'nlcode')
    ]
    for single, mixed in method_pairs:
        for right, wrong in ((single, mixed), (mixed, single)):
            one_correct_one_wrong(right, wrong)
    
