import argparse
import json
import pdb
import os

from evaluation.eval.eval_script import eval_math 
from evaluation.data_processing.answer_extraction import extract_math_answer

import sys
# Effectively "unbounded"; used as the default group count of
# test_hendrycks_math.
MAX_INT = sys.maxsize

# Module-level log of completions from which no answer could be extracted;
# test_hendrycks_math appends one dict per such completion.
invalid_outputs = list()

def test_hendrycks_math(data_path, remainder=0, n_groups=MAX_INT, args=None):
    """Score model completions on Hendrycks MATH and print the accuracy.

    Reads a JSON list of records carrying ``prompt``, ``answer`` and
    ``completion`` keys. A final answer is extracted from each completion with
    ``extract_math_answer(..., task='cot')`` and compared against the reference
    answer with ``eval_math``. Completions with no extractable answer count as
    wrong.

    Args:
        data_path: Path to the JSON completion file.
        remainder: Start index of the shard to evaluate; the records at
            ``remainder, remainder + n_groups, ...`` are scored.
        n_groups: Shard stride. With the default ``MAX_INT`` only the single
            record at index ``remainder`` is selected.
        args: Parsed CLI namespace. ``args.save_path`` names the JSON file that
            receives per-record results. Saving is skipped when the path is
            empty, missing, or ``args`` is None. ``args.rep`` is accepted for
            CLI compatibility but is not needed, because every record already
            carries its own reference answer.

    Returns:
        The accuracy (fraction of correct records) for the selected shard, or
        0.0 when the shard is empty.
    """
    with open(data_path, "r", encoding="utf-8") as infile:
        item_lst = json.load(infile)

    # Shard the records themselves so each prompt, completion and reference
    # answer stay paired. Slicing (and repeating) only the answer list caused
    # answers to drift out of step with completions whenever remainder > 0 or
    # n_groups > 1.
    shard = item_lst[remainder::n_groups]

    # getattr on None also yields the fallback, so args=None is tolerated.
    save_path = getattr(args, "save_path", "")

    run_invalid = []  # invalid outputs produced by this call only
    to_save_list = []
    results = []
    for record in shard:
        prompt = record['prompt']
        completion = record['completion']
        prompt_answer = record['answer']

        # For a single-element pmatrix reference answer, collapse doubled
        # backslashes in both the reference and the completion (presumably an
        # escaping artifact in the stored data; confirm). Build a new list
        # rather than mutating the loaded record.
        if ("The answer is:" in completion
                and isinstance(prompt_answer, list)
                and len(prompt_answer) == 1
                and "\\begin{pmatrix}" in prompt_answer[0]):
            prompt_answer = [prompt_answer[0].replace("\\\\", "\\")]
            completion = completion.replace("\\\\", "\\")

        prediction = extract_math_answer(prompt, completion, task='cot')

        if not prediction:
            # Nothing extractable: score as wrong and log the reference answer.
            # The old code logged the empty prediction under 'answer'.
            invalid = {'question': prompt, 'output': completion, 'answer': prompt_answer}
            run_invalid.append(invalid)
            # Keep feeding the module-level log for existing readers of it.
            invalid_outputs.append(invalid)
            res = False
            extract_ans = None
        else:
            extract_ans = prediction
            res = eval_math({
                'question': prompt,
                'model_output': completion,
                'prediction': prediction,
                # eval_math is given the reference answer wrapped in a list.
                'answer': prompt_answer if isinstance(prompt_answer, list) else [prompt_answer],
            })

        results.append(res)
        to_save_list.append({
            'prompt': prompt,
            'completion': completion,
            'extract_answer': extract_ans,
            'answer': prompt_answer,
            'result': res,
        })

    if save_path:
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(to_save_list, f, indent=4)

    # Avoid ZeroDivisionError when the selected shard is empty.
    acc = sum(results) / len(results) if results else 0.0
    print('len invalid outputs ====', len(run_invalid))
    print('n_groups===', n_groups, ', remainder====', remainder)
    print('length====', len(results), ', acc====', acc)
    return acc

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Options (all optional):
        --data_file   path to the JSON completion file (default: '').
        --save_path   output path for per-record results (default: '').
        --remainder   shard index, used as the slice start (default: 0).
        --n_groups    number of shards, used as the slice step (default: 1).
        --rep         repetition count, read as ``args.rep`` by the
                      evaluator (default: 1).

    Returns:
        argparse.Namespace with one attribute per option.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--data_file", str, ''),
        ("--save_path", str, ''),
        ("--remainder", int, 0),
        ("--n_groups", int, 1),
        ("--rep", int, 1),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: parse CLI options and evaluate the requested shard.
    args = parse_args()
    test_hendrycks_math(
        args.data_file,
        remainder=args.remainder,
        n_groups=args.n_groups,
        args=args,
    )
