import os
from utils.utils import read_jsonl, write_jsonl
import random
import json

def retrieve_answer_math(model_name,
                         hard_levels=('Level 5',),
                         methods=('cot', 'pal', 'codenl_single', 'nlcode_single', 'codenl', 'nlcode'),
                         is_train=False):
    """
    Load stored greedy-decoding results for the MATH dataset.

    For every method in ``methods`` and every MATH subject, the per-level
    result files are read and merged into a single dict keyed by each row's
    ``idx``. Missing files are reported with a print and skipped.

    Args:
        model_name: str, the name of the model. Must contain ``'gpt'`` or
            ``'meta-llama'``; this selects the results directory.
        hard_levels: iterable of str, difficulty levels to load
            (e.g. ``'Level 5'``).
        methods: iterable of str, solving methods to load. Each must be, compared
            case-insensitively, one of ``cot``, ``pal``, ``codenl_single``,
            ``nlcode_single`` (rows with one ``reasoning_path``) or ``codenl``,
            ``nlcode`` (rows with ``reasoning_path_1`` and ``reasoning_path_2``).
        is_train: bool, if True read the ``_train`` variant of each results file.

    Returns:
        dict: ``answers[method][subject][idx]`` -> dict with keys ``question``,
        ``answer`` (predicted answer), ``gold_reasoning_path``, ``level``,
        ``reasoning_path``, ``ground_truth`` and ``correct``.

    Raises:
        ValueError: if a method name or ``model_name`` is not recognised.
    """
    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra', 'precalculus', 'prealgebra']
    single_path_methods = {'cot', 'pal', 'codenl_single', 'nlcode_single'}
    two_path_methods = {'codenl', 'nlcode'}

    answers = {}
    for method_name in methods:
        method_key = method_name.lower()
        # Validate up front so an unknown method fails even when no result files exist.
        if method_key not in single_path_methods and method_key not in two_path_methods:
            raise ValueError(f"Invalid method name: {method_name}")

        answers[method_name] = {}
        for data_type in dataset_list:
            data = []
            for hard_level in hard_levels:
                if 'gpt' in model_name:
                    file_path = f'new_results/gpt/{model_name}_{data_type}_{hard_level}_{method_name}_greedy.jsonl'
                elif 'meta-llama' in model_name:
                    file_path = f'new_results/{model_name}_{data_type}_{hard_level}_{method_name}_greedy.jsonl'
                else:
                    raise ValueError(f"Invalid model name: {model_name}")
                if is_train:
                    file_path = file_path.replace('.jsonl', '_train.jsonl')
                # NOTE: '_eval' is appended for train files too ('..._greedy_train_eval.jsonl').
                file_path = file_path.replace('.jsonl', '_eval.jsonl')

                if not os.path.exists(file_path):
                    print(f"File {file_path} does not exist")
                    continue
                data.extend(read_jsonl(file_path))

            records = {}
            for row in data:
                if method_key in two_path_methods:
                    # Two-stage methods store both stages; join them into one path.
                    reasoning_path = row['reasoning_path_1'] + "\n" + row['reasoning_path_2']
                else:
                    reasoning_path = row['reasoning_path']
                # Later levels overwrite earlier ones if an idx repeats across level files.
                records[row['idx']] = {'question': row['question'],
                                       'answer': row['pred_answer'],
                                       'gold_reasoning_path': row['answer'],
                                       'level': row['level'],
                                       'reasoning_path': reasoning_path,
                                       'ground_truth': row['ground_truth'],
                                       'correct': row['correct']}
            answers[method_name][data_type] = records
        print(f"Method {method_name} done")
    return answers
              
def obtain_finetuned_data_greedy(model_name, dataset, method_list, hard_levels=('Level 5',), is_train=False):
    """
    Build classifier fine-tuning data from stored greedy results and write it to disk.

    Each example is labelled with the list of methods from ``method_list``
    that solved it correctly (in ``method_list`` order). The list is empty
    when no method succeeded; no tie-breaking is applied.

    Args:
        model_name: The name of the model. It selects the output directory and
            must contain 'Llama-3.1-8B', 'Llama-3.1-70B' or 'gpt-4o-mini'.
        dataset: The dataset to evaluate; only 'math' is supported.
        method_list: Non-empty list of methods to compare. The first method's
            results define the set of examples; every other method is looked up
            by the same subject/idx.
        hard_levels: Difficulty levels to include, e.g. ['Level 5', 'Level 4'].
        is_train: If True, read train-split results and write under 'train/';
            otherwise write under 'test/'.

    Returns:
        str: Path of the written (shuffled) fine-tuning JSONL file. The raw
        answers are also dumped as JSON in the same directory.

    Raises:
        ValueError: If ``method_list`` is empty, or ``dataset`` / ``model_name``
            is not supported.
        KeyError: If a method lacks results for an example that the first
            method has.
    """
    if not method_list:
        raise ValueError("method_list must contain at least one method")
    if dataset != 'math':
        raise ValueError("Dataset not supported")

    # Resolve the output directory before the (slow) result loading so an
    # unsupported model fails fast. The first matching key wins.
    model_dirs = (
        ('Llama-3.1-8B', "ft_data/llama3.1-8b/"),
        ('Llama-3.1-70B', "ft_data/llama3.1-70b/"),
        ('gpt-4o-mini', "ft_data/gpt4o-mini/"),
    )
    stored_dic = next((out_dir for key, out_dir in model_dirs if key in model_name), None)
    if stored_dic is None:
        raise ValueError("Model not supported")
    stored_dic += 'train/' if is_train else 'test/'

    answers = retrieve_answer_math(model_name, methods=method_list, hard_levels=hard_levels, is_train=is_train)

    reference = answers[method_list[0]]
    subjects = list(reference.keys())

    finetuned_raw_data = []
    label_distribution = dict.fromkeys(method_list, 0)
    for subject in subjects:
        for idx, row in reference[subject].items():
            labels = [method for method in method_list if answers[method][subject][idx]['correct']]
            for label in labels:
                label_distribution[label] += 1
            finetuned_raw_data.append({
                'type': subject,
                'idx': idx,
                'level': row['level'],
                'question': row['question'],
                'ground_truth': row['ground_truth'],
                'label': labels})

    print(f"Total number of examples: {sum(len(reference[subject]) for subject in subjects)}")
    print(f"Total number of valid examples: {len(finetuned_raw_data)}")
    print(label_distribution)

    os.makedirs(stored_dic, exist_ok=True)

    # Encode the selected levels in the file name, e.g. ['Level 5', 'Level 4'] -> 'lvl_54'.
    level_tag = 'lvl_' + ''.join(level.replace('Level ', '') for level in hard_levels)
    stored_path = stored_dic + f"data_{level_tag}_greedy.jsonl"
    answers_stored_path = stored_dic + f"answers_{level_tag}_greedy.json"

    random.shuffle(finetuned_raw_data)
    write_jsonl(finetuned_raw_data, stored_path)
    with open(answers_stored_path, 'w') as f:
        json.dump(answers, f)
    return stored_path

if __name__ == "__main__":
    # Build the test-split fine-tuning data for Llama-3.1-70B on MATH,
    # covering every difficulty level from 5 down to 1.
    # Use 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo' to target the 8B model instead.
    stored_path = obtain_finetuned_data_greedy(
        'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo',
        'math',
        ['cot', 'pal', 'nlcode_single', 'codenl', 'nlcode'],
        hard_levels=[f'Level {level}' for level in range(5, 0, -1)],
        is_train=False,
    )