from reasoner import retrieve_answer_math
from utils.utils import read_jsonl, write_jsonl
from utils.grader import math_equal
import tqdm
import itertools
import random
import json
import os
import pandas as pd

def best_of_four(model_name, dataset, method_list):
    """
    Print the oracle accuracy obtained by combining the methods in method_list.

    A problem counts as solved when at least one method in method_list has a
    truthy 'correct' flag for it. The accuracy is printed for every sub-dataset
    and then over all sub-datasets together; nothing is returned.

    Args:
        model_name: Model identifier forwarded to retrieve_answer_math.
        dataset: Benchmark name; only 'math' is supported.
        method_list: Methods to combine. The sub-datasets and problem indices
            of the first method drive the iteration.
    Raises:
        ValueError: If dataset is not 'math'.
    """
    if dataset != 'math':
        raise ValueError("Dataset not supported")
    math_answers = retrieve_answer_math(model_name)

    reference = math_answers[method_list[0]]
    total_scores = []
    for subset in list(reference.keys()):
        scores = []
        for idx in reference[subset]:
            # Build the full list (no short-circuit) so a problem missing from
            # any method still surfaces as a KeyError.
            correct_list = [math_answers[method][subset][idx]['correct'] for method in method_list]
            scores.append(any(correct_list))

        # An empty sub-dataset is reported as 0.0 instead of raising ZeroDivisionError.
        accuracy = sum(scores) / len(scores) if scores else 0.0
        print(f"Accuracy for model {model_name} on dataset {subset} is {accuracy}")

        total_scores += scores

    total_accuracy = sum(total_scores) / len(total_scores) if total_scores else 0.0
    print(f"Total accuracy for model {model_name} is {total_accuracy}")

def best_of_two(model_name, dataset, method_1, method_2):
    """
    Print the oracle accuracy obtained by combining two methods.

    A problem counts as solved when method_1 or method_2 has a truthy 'correct'
    flag for it. A single accuracy over all sub-datasets is printed; nothing is
    returned.

    Args:
        model_name: Model identifier forwarded to retrieve_answer_math.
        dataset: Benchmark name; only 'math' is supported.
        method_1: First method; its sub-datasets and problem indices drive the
            iteration.
        method_2: Second method, looked up at the same sub-dataset and index.
    Raises:
        ValueError: If dataset is not 'math'.
    """
    if dataset != 'math':
        raise ValueError("Dataset not supported")
    math_answers = retrieve_answer_math(model_name)

    total_scores = []
    # Iterate the sub-datasets of method_1: this function has no method_list,
    # so it must not depend on a module-level variable of that name.
    dataset_list = list(math_answers[method_1].keys())

    for subset in dataset_list:
        for idx in math_answers[method_1][subset]:
            correct_1 = math_answers[method_1][subset][idx]['correct']
            correct_2 = math_answers[method_2][subset][idx]['correct']
            total_scores.append(bool(correct_1 or correct_2))

    # No problems at all is reported as 0.0 instead of raising ZeroDivisionError.
    total_accuracy = sum(total_scores) / len(total_scores) if total_scores else 0.0
    print(f'total accuracy for {method_1} and {method_2} is {total_accuracy}')



def chosen_accuracy(model_name, dataset, hard_level, method_list):
    """
    Print how often the stored method decision was an acceptable choice.

    For every sub-dataset, the rows of the matching '*_metamath.jsonl' result
    file are read. Each row provides 'idx', 'decision' and 'ground_truth'. The
    answer of every method in method_list is compared to the ground truth with
    math_equal. A row scores True when no method is correct (nothing better
    could have been chosen) or when the chosen method is correct. The accuracy
    is printed per sub-dataset and overall; nothing is returned.

    Args:
        model_name: Model identifier; also selects the results directory.
        dataset: Benchmark name; only 'math' is supported.
        hard_level: Difficulty level used in the result file name.
        method_list: Candidate methods the decision is compared against.
    Raises:
        ValueError: If dataset is not 'math'.
    """
    if dataset != 'math':
        raise ValueError("Dataset not supported")
    math_answers = retrieve_answer_math(model_name)

    # Results of gpt models are read from a dedicated sub-directory.
    results_dir = 'results/gpt' if 'gpt' in model_name else 'results'

    total_scores = []
    dataset_list = list(math_answers[method_list[0]].keys())

    for subset in dataset_list:
        data_path = f'{results_dir}/{model_name}_{subset}_{hard_level}_metamath.jsonl'
        data = read_jsonl(data_path)
        scores = []

        for row in tqdm.tqdm(data):
            idx = row['idx']
            decision = row['decision']
            flag_dic = {}
            for method in method_list:
                answer = math_answers[method][subset][idx]['answer']
                try:
                    flag_dic[method] = math_equal(answer, row['ground_truth'])
                except Exception:
                    # A comparison that fails counts as an incorrect answer.
                    # Exception (not a bare except) keeps Ctrl-C working.
                    flag_dic[method] = False

            # A decision naming a method outside method_list can never be the
            # correct choice, so it is scored as wrong instead of raising KeyError.
            chosen_is_correct = flag_dic.get(decision, False)
            scores.append(bool(not any(flag_dic.values()) or chosen_is_correct))

        # An empty result file is reported as 0.0 instead of raising ZeroDivisionError.
        accuracy = sum(scores) / len(scores) if scores else 0.0
        print(f"Accuracy for model {model_name} on dataset {subset} is {accuracy}")

        total_scores += scores

    total_accuracy = sum(total_scores) / len(total_scores) if total_scores else 0.0
    print(f"Total accuracy for model {model_name} is {total_accuracy}")

def get_preferred_list(answers):
    """
    Rank the methods in answers from highest to lowest average accuracy.

    The accuracy of a method is the mean of the 'correct' flags over all of its
    sub-datasets and problems. A method without any entries gets accuracy 0.0.
    Methods with equal accuracy keep their order of appearance in answers.

    Args:
        answers (dict): Mapping method -> sub-dataset -> problem index -> entry,
            where every entry has a 'correct' flag.
    Returns:
        list: The method names sorted by descending average accuracy.
    """
    accuracy_by_method = {}
    for method, subsets in answers.items():
        correct_flags = [
            entry['correct']
            for problems in subsets.values()
            for entry in problems.values()
        ]
        # Guard the empty case: a method with no answers would divide by zero.
        accuracy_by_method[method] = sum(correct_flags) / len(correct_flags) if correct_flags else 0.0

    # sorted() is stable (also with reverse=True), so ties keep insertion order.
    return sorted(accuracy_by_method, key=accuracy_by_method.get, reverse=True)

import math

def get_softmax_weights(answers, method_list, tau=0.05):
    """
    Computes the softmax weights for each method based on its average accuracy.

    The accuracy of a method is the mean of its 'correct' flags over all of its
    sub-datasets and problems (0.0 when it has no entries). The weights are the
    softmax of accuracy / tau over the methods in method_list.

    Args:
        answers (dict): Mapping method -> sub-dataset -> problem index -> entry,
            where every entry has a 'correct' flag.
        method_list (list): A list of methods to evaluate.
        tau (float): The temperature parameter that controls the sharpness of the distribution.

    Returns:
        dict: A dictionary of methods and their corresponding softmax weights.
            Empty when method_list is empty.
    """
    # Nothing to weight; max() below would fail on an empty sequence.
    if not method_list:
        return {}

    # Compute average accuracy for each method in a single pass over its entries
    accuracy_dict = {}
    for method in method_list:
        correct_flags = [
            entry['correct']
            for problems in answers[method].values()
            for entry in problems.values()
        ]
        accuracy_dict[method] = sum(correct_flags) / len(correct_flags) if correct_flags else 0.0

    # Subtract the maximum before exponentiating for numerical stability;
    # the shift cancels out in the normalization below.
    max_accuracy = max(accuracy_dict.values())
    exp_scores = {
        method: math.exp((accuracy - max_accuracy) / tau)
        for method, accuracy in accuracy_dict.items()
    }
    total_exp = sum(exp_scores.values())

    # Normalize to get probabilities
    return {method: score / total_exp for method, score in exp_scores.items()}


def obtain_finetuned_data(model_name, dataset, method_list, hard_levels=None, mode='multi-label', is_train=False):
    """
    Obtain the labels to finetune the classifier and store them as jsonl.

    For every problem the methods in method_list that solved it are collected.
    The label then depends on mode:
        'multi-label': the list of all correct methods (possibly empty).
        'random':  the single correct method if there is exactly one; otherwise a
                   uniformly random pick among the correct methods, or among all
                   methods in method_list when no method is correct.
        'hardmax': same, but ties are broken by preference, i.e. the candidate
                   with the highest average accuracy wins.
        'softmax': same, but ties are broken by sampling the candidates with
                   their softmax weights (see get_softmax_weights).

    The shuffled records are written to
    ft_data/<model dir>/<train|test>/data_{mode}_lvl_{levels}.jsonl and the raw
    answers are dumped next to them as answers_lvl_{levels}.json. The number of
    examples and the label distribution are printed.

    Args:
        model_name: The name of the model.
        dataset: The dataset to evaluate the model; only 'math' is supported.
        method_list: The list of methods to evaluate the model.
        hard_levels: Difficulty levels forwarded to retrieve_answer_math and
            encoded in the output file names. Defaults to ['Level 5'].
        mode: One of 'multi-label', 'random', 'hardmax', 'softmax'.
        is_train: Forwarded to retrieve_answer_math; also selects the 'train/'
            or 'test/' output directory.
    Returns:
        str: The path of the written jsonl file.
    Raises:
        ValueError: If the mode, the dataset or the model is not supported.
    """
    supported_modes = ('multi-label', 'random', 'hardmax', 'softmax')
    if mode not in supported_modes:
        # An unknown mode used to leave the label unbound, or silently reuse the
        # label of the previous problem.
        raise ValueError(f"Mode not supported: {mode!r} (expected one of {supported_modes})")
    if hard_levels is None:
        hard_levels = ['Level 5']

    if dataset == 'math':
        answers = retrieve_answer_math(model_name, hard_levels=hard_levels, is_train=is_train)
    else:
        raise ValueError("Dataset not supported")

    reference = answers[method_list[0]]
    dataset_list = list(reference.keys())

    preferred_list = None
    softmax_weights = None
    if mode == 'hardmax':
        # Rank by average accuracy, then keep only the candidate methods so the
        # fallback label can never be a method outside method_list.
        preferred_list = [method for method in get_preferred_list(answers) if method in method_list]
    elif mode == 'softmax':
        softmax_weights = get_softmax_weights(answers, method_list)

    finetuned_raw_data = []
    all_labels = []

    for subset in dataset_list:
        for idx, row in reference[subset].items():
            correct_method = [method for method in method_list if answers[method][subset][idx]['correct']]

            if mode == 'multi-label':
                label = correct_method
                all_labels.extend(label)
            else:
                if len(correct_method) == 1:
                    label = correct_method[0]
                else:
                    # Tie-break among the correct methods, or among all methods
                    # when none of them solved the problem.
                    candidates = correct_method if correct_method else method_list
                    if mode == 'random':
                        label = random.choice(candidates)
                    elif mode == 'hardmax':
                        label = next(method for method in preferred_list if method in candidates)
                    else:  # 'softmax'
                        method_probs = [softmax_weights[method] for method in candidates]
                        label = random.choices(candidates, method_probs)[0]
                all_labels.append(label)

            finetuned_raw_data.append({
                'type': subset,
                'idx': idx,
                # NOTE(review): the level is hard-coded even when hard_levels
                # contains other levels - confirm whether the row carries its level.
                'level': 'Level 5',
                'question': row['question'],
                'ground_truth': row['ground_truth'],
                'label': label})

    print(f"Total number of examples: {sum(len(reference[subset]) for subset in dataset_list)}")
    print(f"Total number of valid examples: {len(finetuned_raw_data)}")
    # get the distribution of the labels
    label_distribution = {label: all_labels.count(label) for label in method_list}
    print(label_distribution)

    if 'Llama-3.1-8B' in model_name:
        stored_dic = "ft_data/llama3.1-8b/"
    elif 'Llama-3.1-70B' in model_name:
        stored_dic = "ft_data/llama3.1-70b/"
    elif 'gpt-4o-mini' in model_name:
        stored_dic = "ft_data/gpt4o-mini/"
    else:
        raise ValueError("Model not supported")

    stored_dic += 'train/' if is_train else 'test/'
    # exist_ok avoids the check-then-create race; both output files live here.
    os.makedirs(stored_dic, exist_ok=True)

    string = 'lvl_' + ''.join(hard_level.replace('Level ', '') for hard_level in hard_levels)
    stored_path = stored_dic + f"data_{mode}_{string}.jsonl"
    answers_stored_path = stored_dic + f"answers_{string}.json"

    # shuffle the data
    random.shuffle(finetuned_raw_data)
    write_jsonl(finetuned_raw_data, stored_path)
    # store answers; the context manager closes the file even if dumping fails
    with open(answers_stored_path, 'w') as answers_file:
        json.dump(answers, answers_file)

    return stored_path


def obtain_fintuned_data_multi_shots(model_name, dataset, method_list, hard_levels=None, num_shots=3, is_train=False):
    """
    Label every problem with the method that has the highest multi-shot accuracy.

    For each sub-dataset, level and method the evaluation file
    results/{model_name}_{sub-dataset}_{level}_{method}_nshots_{num_shots}[_train]_eval.jsonl
    is read; every row provides 'idx', 'question', 'ground_truth' and
    'accuracy'. The label of a problem is the method with the highest
    'accuracy'; on ties the method that comes first in method_list wins.
    The records are written, unshuffled, to
    ft_data/<model dir>/<train|test>/data_nshots_{num_shots}_lvl_{levels}.jsonl.

    Args:
        model_name: The name of the model.
        dataset: The dataset to evaluate the model; only 'math' is supported.
        method_list: The list of methods to compare. The problem indices of
            the first method drive the iteration.
        hard_levels: Difficulty levels to process. Defaults to ['Level 5'].
        num_shots: Number of shots encoded in the input and output file names.
        is_train: Selects the '_train' evaluation files and the 'train/'
            output directory instead of 'test/'.
    Returns:
        str: The path of the written jsonl file.
    Raises:
        ValueError: If the dataset or the model is not supported.
    """
    if hard_levels is None:
        hard_levels = ['Level 5']

    if dataset == 'math':
        dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra', 'precalculus', 'prealgebra']
    else:
        raise ValueError("Dataset not supported")

    file_suffix = '_train_eval.jsonl' if is_train else '_eval.jsonl'

    # row_data[sub-dataset][level][method][idx] -> record of that problem
    row_data = {}
    for subset in dataset_list:
        row_data[subset] = {}
        for hard_level in hard_levels:
            row_data[subset][hard_level] = {}
            for method in method_list:
                # NOTE(review): chosen_accuracy reads gpt results from
                # 'results/gpt/'; confirm whether these files follow that layout too.
                data_path = f'results/{model_name}_{subset}_{hard_level}_{method}_nshots_{num_shots}{file_suffix}'
                row_data[subset][hard_level][method] = {
                    row['idx']: {
                        'type': subset,
                        'idx': row['idx'],
                        'level': hard_level,
                        'question': row['question'],
                        'ground_truth': row['ground_truth'],
                        'accuracy': row['accuracy'],
                    }
                    for row in read_jsonl(data_path)
                }

    finetune_data = []
    for subset in dataset_list:
        for hard_level in hard_levels:
            per_method = row_data[subset][hard_level]
            for idx in per_method[method_list[0]]:
                # max() returns the first maximal element, so ties go to the
                # method that appears first in method_list.
                max_accuracy_method = max(method_list, key=lambda method: per_method[method][idx]['accuracy'])
                best_row = per_method[max_accuracy_method][idx]
                finetune_data.append({
                    'type': subset,
                    'idx': idx,
                    'level': hard_level,
                    'question': best_row['question'],
                    'ground_truth': best_row['ground_truth'],
                    'label': max_accuracy_method
                })

    if 'Llama-3.1-8B' in model_name:
        stored_dic = "ft_data/llama3.1-8b/"
    elif 'Llama-3.1-70B' in model_name:
        stored_dic = "ft_data/llama3.1-70b/"
    elif 'gpt-4o-mini' in model_name:
        stored_dic = "ft_data/gpt4o-mini/"
    else:
        raise ValueError("Model not supported")

    stored_dic += 'train/' if is_train else 'test/'
    # exist_ok avoids the check-then-create race.
    os.makedirs(stored_dic, exist_ok=True)

    string = 'lvl_' + ''.join(hard_level.replace('Level ', '') for hard_level in hard_levels)
    stored_path = stored_dic + f"data_nshots_{num_shots}_{string}.jsonl"

    write_jsonl(finetune_data, stored_path)

    return stored_path



def split_data(data_path, train_size=None):
    """
    Draw a random train subset from a jsonl file and store it with its indices.

    The rows of data_path are shuffled and the first train_size of them are
    kept. Two files are written next to the input, named after the input path
    without its extension:
        <root>_train_{train_size}.jsonl     the selected rows
        <root>_train_{train_size}_idx.csv   their 'type' and 'idx' columns

    Args:
        data_path: str
            Path of the jsonl file to sample from; every row needs the keys
            'type' and 'idx'.
        train_size: int
            Number of examples in the train set. Defaults to all examples.
    """
    data = read_jsonl(data_path)
    print(f"Total number of examples: {len(data)}")

    random.shuffle(data)
    if train_size is None:
        train_size = len(data)
    train_data = data[:train_size]

    train_idx = [{'type': d['type'], 'idx': d['idx']} for d in train_data]
    # Fix the columns so the csv keeps its header even when nothing is selected.
    df = pd.DataFrame(train_idx, columns=['type', 'idx'])

    # Derive the output names from the path without its extension. Replacing the
    # substring '.jsonl' would rewrite every occurrence in the path, and would
    # overwrite the input file when the path does not contain '.jsonl' at all.
    root, _ = os.path.splitext(data_path)
    df.to_csv(f'{root}_train_{train_size}_idx.csv', index=False)
    write_jsonl(train_data, f'{root}_train_{train_size}.jsonl')


if __name__ == '__main__':
    # Benchmark, candidate methods and difficulty level of the experiment.
    dataset = 'math'
    method_list = ['cot', 'pal', 'codenl_single', 'nlcode_single']
    hard_level = "Level 5"
    # dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra','precalculus', 'prealgebra']

    # Model whose results are processed. Alternatives:
    # model_name = 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo'
    # model_name = 'gpt-4o-mini'
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo'

    # True reads the '_train' evaluation files and writes below 'train/'.
    is_train = True

    # Alternative: single-shot labels over several levels.
    # hard_levels = ['Level 3', 'Level 4', 'Level 5']
    # stored_path = obtain_finetuned_data(model_name, dataset, method_list, mode='multi-label', hard_levels=hard_levels, is_train=is_train)

    # Active run: multi-shot labels for Level 5 only.
    hard_levels = ['Level 5']
    stored_path = obtain_fintuned_data_multi_shots(
        model_name,
        dataset,
        method_list,
        hard_levels=hard_levels,
        num_shots=3,
        is_train=is_train,
    )

    # Optional follow-up steps, disabled by default.
    # split_data(stored_path)
    # model_name = 'gpt-4o-mini'
    # best_of_four(model_name, dataset, method_list)

    # Run best_of_two for each pair of methods:
    # pairs = itertools.combinations(method_list, 2)
    # for method_1, method_2 in pairs:
    #     best_of_two(model_name, dataset, method_1, method_2)

    # chosen_accuracy(model_name, dataset_list, hard_level, method_list)
