from utils.utils import *
from pebble import ProcessPool
from utils.grader import *
from tqdm import tqdm
from concurrent.futures import TimeoutError as FutureTimeoutError
import signal
import threading
import concurrent.futures

def eval_file(result_file, rewrite=False):
    """
    Evaluate a JSONL result file and cache per-record correctness.

    Each record must provide 'pred_answer', 'ground_truth' and 'idx'. Every
    (prediction, ground truth, idx) triple is scored by ``math_equal_process``
    in a pebble process pool with a 3-second per-sample timeout. Timeouts and
    worker errors count as incorrect. The records, each augmented with a
    'correct' key, are written to ``<name>_eval.jsonl`` next to the input.

    Args:
        result_file: Path to the JSONL file with model predictions.
        rewrite: If False and the ``_eval.jsonl`` file already exists, return
            the cached 'correct' values instead of re-evaluating.

    Returns:
        list: One score per record, in file order.

    Raises:
        SystemExit: If ``result_file`` does not exist.
    """
    if not os.path.exists(result_file):
        print("{} does not exist.".format(result_file))
        raise SystemExit(1)

    # Only strip a trailing ".jsonl". str.replace would also rewrite ".jsonl"
    # elsewhere in the path, and for a non-.jsonl input would return the input
    # path itself, so the write below would overwrite the predictions.
    stem = result_file[:-len(".jsonl")] if result_file.endswith(".jsonl") else result_file
    new_result_file = stem + "_eval.jsonl"
    if os.path.exists(new_result_file) and not rewrite:
        return [sample['correct'] for sample in read_jsonl(new_result_file)]

    results = read_jsonl(result_file)
    samples = [(sample['pred_answer'], sample['ground_truth'], sample['idx']) for sample in results]

    scores = []
    timeout_cnt = 0
    error_cnt = 0
    with ProcessPool() as pool:
        future = pool.map(math_equal_process, samples, timeout=3)
        iterator = future.result()
        with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
            while True:
                try:
                    score = next(iterator)
                except StopIteration:
                    break
                except FutureTimeoutError:
                    score = False
                    timeout_cnt += 1
                except Exception as error:
                    # pebble keeps the result iterator usable after a task
                    # error, so mark this sample wrong and carry on instead of
                    # discarding every score computed so far.
                    print(error)
                    score = False
                    error_cnt += 1
                # Exactly one score per sample keeps scores aligned with
                # results; update the bar on every path, including failures.
                scores.append(score)
                progress_bar.update(1)

    if timeout_cnt or error_cnt:
        print(f"{timeout_cnt} timeouts and {error_cnt} errors while evaluating {result_file}")

    # Add the key 'correct' to every record and write the evaluated copy.
    for sample, score in zip(results, scores):
        sample['correct'] = score
    write_jsonl(results, new_result_file)

    accuracy = sum(scores) / len(scores) if scores else 0.0
    print(f"Accuracy for dataset {result_file} is {accuracy}")
    return scores

def eval_file_folio(result_file, rewrite=False):
    """
    Evaluate a FOLIO-style JSONL result file and cache per-record correctness.

    Records are expected to carry 'label' and 'example_id'. Only records that
    contain 'pred_answer' are scored, via ``math_equal_process`` in a pebble
    process pool with a 3-second timeout per record. Records without a
    prediction, timeouts and worker errors are all recorded as incorrect.
    Every record gets a 'correct' key, and the records are written to
    ``<name>_eval.jsonl`` next to the input.

    Args:
        result_file: Path to the JSONL file with model predictions.
        rewrite: If False and the ``_eval.jsonl`` file already exists, return
            the cached 'correct' values instead of re-evaluating.

    Returns:
        list: One score per record, in file order. This is the same shape the
        cached path returns.

    Raises:
        SystemExit: If ``result_file`` does not exist.
    """
    if not os.path.exists(result_file):
        print("{} does not exist.".format(result_file))
        raise SystemExit(1)

    # Only strip a trailing ".jsonl" so a non-.jsonl input is never overwritten.
    stem = result_file[:-len(".jsonl")] if result_file.endswith(".jsonl") else result_file
    new_result_file = stem + "_eval.jsonl"
    if os.path.exists(new_result_file) and not rewrite:
        return [sample['correct'] for sample in read_jsonl(new_result_file)]

    results = read_jsonl(result_file)
    # Remember which records were actually scored so each score maps back to
    # its own record. Indexing scores by record position misaligns (and
    # overflows) as soon as one record lacks a prediction.
    scored_positions = [i for i, sample in enumerate(results) if 'pred_answer' in sample]
    samples = [
        (results[i]['pred_answer'], results[i]['label'], results[i]['example_id'])
        for i in scored_positions
    ]

    pool_scores = []
    timeout_cnt = 0
    error_cnt = 0
    with ProcessPool() as pool:
        future = pool.map(math_equal_process, samples, timeout=3)
        iterator = future.result()
        with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
            while True:
                try:
                    score = next(iterator)
                except StopIteration:
                    break
                except FutureTimeoutError:
                    score = False
                    timeout_cnt += 1
                except Exception as error:
                    # pebble keeps the result iterator usable after a task
                    # error; record a failure and continue.
                    print(error)
                    score = False
                    error_cnt += 1
                pool_scores.append(score)
                progress_bar.update(1)

    missing = len(results) - len(scored_positions)
    if timeout_cnt or error_cnt or missing:
        print(f"{timeout_cnt} timeouts, {error_cnt} errors and {missing} records without "
              f"'pred_answer' while evaluating {result_file}")

    # Records without a prediction count as incorrect.
    scores = [False] * len(results)
    for position, score in zip(scored_positions, pool_scores):
        scores[position] = score
    for sample, score in zip(results, scores):
        sample['correct'] = score
    write_jsonl(results, new_result_file)

    accuracy = sum(scores) / len(scores) if scores else 0.0
    print(f"Accuracy for dataset {result_file} is {accuracy}")
    return scores


def eval_model(model, method, hard_level="", rewrite=False, is_train=False, datasets=None):
    """
    Evaluate one model/method pair across every dataset and print the average.

    For each dataset, the file
    ``results/{model}_{data_type}[_{hard_level}]_{method}_{new|train}.jsonl``
    is scored with ``eval_file``. The tag is ``train`` when ``is_train`` is set.

    Args:
        model: Model name as used in the result file names.
        method: Prompting method used to produce the results (e.g. 'cot').
        hard_level: Optional difficulty tag inserted into the file name.
        rewrite: Forwarded to ``eval_file``; re-evaluate even if cached.
        is_train: Evaluate the ``_train`` files instead of the ``_new`` files.
        datasets: Iterable of dataset names. Defaults to the module-level
            ``dataset_list`` (defined in this script's ``__main__`` block).
    """
    if datasets is None:
        datasets = dataset_list
    # Choose the suffix directly instead of str.replace('new', 'train'), which
    # would also rewrite any "new" inside the model or dataset name.
    tag = "train" if is_train else "new"
    level_part = f"_{hard_level}" if hard_level else ""

    scores = []
    for data_type in datasets:
        stored_path = f"results/{model}_{data_type}{level_part}_{method}_{tag}.jsonl"
        scores += eval_file(stored_path, rewrite)

    accuracy = sum(scores) / len(scores) if scores else 0.0
    print(f"Average accuracy for model {model} using method {method} is {accuracy}")

def eval_model_greedy(model, method, hard_level="", rewrite=False, is_train=False, datasets=None):
    """
    Evaluate greedy-decoding results for one model/method across all datasets.

    With a ``hard_level``, files are read from ``new_results/gpt/`` for models
    whose name contains 'gpt', and from ``new_results/`` otherwise. Without a
    level, files are read from ``new_results/``. File names end in
    ``_greedy.jsonl``, or ``_greedy_train.jsonl`` when ``is_train`` is set.

    Args:
        model: Model name as used in the result file names.
        method: Prompting method used to produce the results (e.g. 'cot').
        hard_level: Optional difficulty tag inserted into the file name.
        rewrite: Forwarded to ``eval_file``; re-evaluate even if cached.
        is_train: Evaluate the ``_train`` variant of each file.
        datasets: Iterable of dataset names. Defaults to the module-level
            ``dataset_list`` (defined in this script's ``__main__`` block).
    """
    if datasets is None:
        datasets = dataset_list
    suffix = "_greedy_train.jsonl" if is_train else "_greedy.jsonl"

    scores = []
    for data_type in datasets:
        if hard_level:
            folder = "new_results/gpt" if 'gpt' in model else "new_results"
            stored_path = f"{folder}/{model}_{data_type}_{hard_level}_{method}{suffix}"
        else:
            # Was "mew_results" (typo). NOTE(review): unlike the hard-level
            # branch, GPT models are not routed to new_results/gpt here;
            # confirm this matches how the greedy files are written.
            stored_path = f"new_results/{model}_{data_type}_{method}{suffix}"
        scores += eval_file(stored_path, rewrite)

    accuracy = sum(scores) / len(scores) if scores else 0.0
    print(f"Average accuracy for model {model} using method {method} on {hard_level} problems is {accuracy}")



def eval_file_multi_shots_best(result_file, rewrite=False, num_shots=3):
    """
    Evaluate a multi-shot result file with "best-of-N" scoring.

    Each record must provide 'pred_answers' (a list of predictions),
    'ground_truth' and 'idx'. Every prediction is scored by
    ``math_equal_process`` in a pebble process pool with a 3-second timeout.
    Timeouts and worker errors count as incorrect. A record counts as correct
    when any of its shots is correct. Each record's per-shot scores are stored
    under 'correct' and written to ``<name>_eval.jsonl``.

    Args:
        result_file: Path to the JSONL file with multi-shot predictions.
        rewrite: If False and the ``_eval.jsonl`` file already exists, derive
            the result from the cached per-shot scores.
        num_shots: Expected number of shots per record. Scores are grouped by
            each record's actual ``len(pred_answers)``, so a record with a
            different count no longer shifts the scores of the records after
            it. Mismatches are reported.

    Returns:
        list[bool]: One best-of-N correctness flag per record, in file order.

    Raises:
        SystemExit: If ``result_file`` does not exist.
    """
    if not os.path.exists(result_file):
        print("{} does not exist.".format(result_file))
        raise SystemExit(1)

    # Only strip a trailing ".jsonl" so a non-.jsonl input is never overwritten.
    stem = result_file[:-len(".jsonl")] if result_file.endswith(".jsonl") else result_file
    new_result_file = stem + "_eval.jsonl"
    if os.path.exists(new_result_file) and not rewrite:
        return [any(record['correct']) for record in read_jsonl(new_result_file)]

    results = read_jsonl(result_file)
    # Flatten every shot into one sample list, remembering how many shots each
    # record contributed so the scores can be regrouped exactly.
    samples = []
    shot_counts = []
    for record in results:
        pred_answers = record['pred_answers']
        shot_counts.append(len(pred_answers))
        samples.extend((pred, record['ground_truth'], record['idx']) for pred in pred_answers)

    mismatched = sum(1 for count in shot_counts if count != num_shots)
    if mismatched:
        print(f"{mismatched} records in {result_file} do not have exactly {num_shots} shots")

    scores = []
    timeout_cnt = 0
    error_cnt = 0
    with ProcessPool() as pool:
        future = pool.map(math_equal_process, samples, timeout=3)
        iterator = future.result()
        with tqdm(total=len(samples), desc="Evaluate") as progress_bar:
            while True:
                try:
                    score = next(iterator)
                except StopIteration:
                    break
                except FutureTimeoutError:
                    score = False
                    timeout_cnt += 1
                except Exception as error:
                    # pebble keeps the result iterator usable after a task
                    # error; record a failure and continue.
                    print(error)
                    score = False
                    error_cnt += 1
                scores.append(score)
                progress_bar.update(1)

    if timeout_cnt or error_cnt:
        print(f"{timeout_cnt} timeouts and {error_cnt} errors while evaluating {result_file}")

    acc_scores = []
    offset = 0
    for record, count in zip(results, shot_counts):
        record['correct'] = scores[offset:offset + count]
        offset += count
        acc_scores.append(any(record['correct']))
    write_jsonl(results, new_result_file)

    accuracy = sum(acc_scores) / len(acc_scores) if acc_scores else 0.0
    print(f"Accuracy for dataset {result_file} is {accuracy}")
    return acc_scores

def eval_model_multi_shots_best(model, method, hard_level="", rewrite=False, num_shots=3, is_train=False, datasets=None):
    """
    Evaluate best-of-N multi-shot results for one model/method across datasets.

    Reads ``new_results/[gpt/]{model}_{data_type}_{hard_level}_{method}_nshots_{num_shots}.jsonl``.
    The ``gpt/`` subfolder is used when the model name contains 'gpt', and
    ``_train`` is inserted before ``.jsonl`` when ``is_train`` is set. Each
    file is scored with ``eval_file_multi_shots_best``.

    Args:
        model: Model name as used in the result file names.
        method: Prompting method used to produce the results (e.g. 'cot').
        hard_level: Difficulty tag in the file name. It is always inserted, so
            an empty value yields a double underscore, matching the original
            naming.
        rewrite: Forwarded to the per-file evaluator; re-evaluate even if cached.
        num_shots: Number of shots per record, used in the file name.
        is_train: Evaluate the ``_train`` variant of each file.
        datasets: Iterable of dataset names. Defaults to the module-level
            ``dataset_list`` (defined in this script's ``__main__`` block).
    """
    if datasets is None:
        datasets = dataset_list
    folder = "new_results/gpt" if 'gpt' in model else "new_results"
    suffix = "_train.jsonl" if is_train else ".jsonl"

    scores = []
    for data_type in datasets:
        stored_path = f"{folder}/{model}_{data_type}_{hard_level}_{method}_nshots_{num_shots}{suffix}"
        scores += eval_file_multi_shots_best(stored_path, rewrite, num_shots)

    accuracy = sum(scores) / len(scores) if scores else 0.0
    print(f"Average accuracy for model {model} using method {method} is {accuracy}")


if __name__ == "__main__":
    # MATH subject categories. The eval_model* helpers read this module-level
    # name directly, so it must stay a global assigned here.
    dataset_list = ['algebra', 'counting & probability', 'geometry', 'number theory', 'intermediate algebra', 'precalculus', 'prealgebra']

    # Other configurations used previously:
    #   models:  gpt-4o-mini, meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo,
    #            meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo
    #   methods: pal, codenl, nlcode
    #   runners: eval_model / eval_model_greedy instead of the multi-shot one
    model = 'gpt-4o'
    methods = ['cot']
    levels = ['Level 5', 'Level 4', 'Level 3', 'Level 2', 'Level 1']

    # Best-of-4 multi-shot evaluation for every (method, difficulty) pair,
    # hardest level first, with a separator line after each run.
    runs = [(m, lvl) for m in methods for lvl in levels]
    for method, level in runs:
        eval_model_multi_shots_best(model, method, hard_level=level, rewrite=True, num_shots=4, is_train=False)
        print('====================')

