import argparse
import re
import json
import os
# Opt in to execution of model-generated code: the Hugging Face `code_eval`
# metric loaded in eval_code requires HF_ALLOW_CODE_EVAL to be "1".
# NOTE(review): the candidate programs run on this machine with this process's
# privileges — only run this script in a sandboxed/trusted environment.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"

def eval_code(test_cases, candidates):
    """Score candidate programs against their test code with HF ``code_eval``.

    Args:
        test_cases: one test-code string per task, passed as ``references``.
        candidates: one list of candidate program strings per task, passed as
            ``predictions``.

    Returns:
        The ``(pass_at_k, results)`` pair produced by ``code_eval.compute``
        with ``k=[1]``.
    """
    # Imported inside the function, so importing this module does not itself
    # require the `evaluate` package.
    from evaluate import load

    metric = load("code_eval")
    outcome = metric.compute(
        references=test_cases,
        predictions=candidates,
        k=[1],
    )
    pass_at_k, results = outcome
    return pass_at_k, results

def preproc(pred_answer):
    """Extract the runnable program from a raw model answer.

    The code is chosen in this order:

    1. If both ``[PYTHON]`` and ``[/PYTHON]`` tags are present, use the first
       ``[PYTHON]...[/PYTHON]`` block. Any Markdown code fences inside it are
       dropped.
    2. Otherwise, use the first ```` ```python ... ``` ```` fence.
    3. If neither pattern matches (including an unclosed fence or misordered
       tags), use the whole answer with the fence markers removed.

    Everything after the end of the last line containing the word ``return``
    is then cut off, and the result is stripped of surrounding whitespace.

    Args:
        pred_answer: raw model output (str).

    Returns:
        The cleaned code string.
    """

    def split_at_last_return(text):
        # Keep everything up to the end of the last `return` line. `\Z` also
        # accepts a return on the final line: extracted code has been
        # .strip()-ed, so that line usually has no trailing newline, and
        # requiring one made the code get truncated at an *earlier* return.
        # `\b...\b` additionally recognises a bare `return` statement.
        matches = list(re.finditer(r"\breturn\b.*(?:\n|\Z)", text))
        if matches:
            return text[: matches[-1].end()]
        return text

    def strip_fences(text):
        return text.replace("```python", "").replace("```", "")

    match = None
    if "[PYTHON]" in pred_answer and "[/PYTHON]" in pred_answer:
        pred_answer = strip_fences(pred_answer)
        match = re.search(r"\[PYTHON\](.*?)\[/PYTHON\]", pred_answer, re.DOTALL)
    elif "```python" in pred_answer:
        # A closing fence is not guaranteed (truncated generations), so the
        # search may fail; that case falls through to fence stripping below
        # instead of crashing on `None.group`.
        match = re.search(r"```python(.*?)```", pred_answer, re.DOTALL)

    if match is not None:
        pred_answer = match.group(1).strip()
    else:
        pred_answer = strip_fences(pred_answer)

    return split_at_last_return(pred_answer).strip()

def parse_args():
    """Read the evaluation script's command-line options.

    Returns:
        An ``argparse.Namespace`` with ``answer_path`` (required) and
        ``output_dir`` (default ``"output/answer"``), both strings.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--answer_path", required=True, type=str)
    cli.add_argument("--output_dir", default="output/answer", type=str)
    parsed = cli.parse_args()
    return parsed


if __name__ == "__main__":
    args = parse_args()

    # Results go to <output_dir>/<name of the directory holding the answer file>/.
    # os.path handles answer paths without a directory component (the old
    # split('/')[-2] raised IndexError there) and platform path separators.
    answer_dir = os.path.basename(os.path.dirname(os.path.abspath(args.answer_path)))
    args.output_dir = os.path.join(args.output_dir, answer_dir)
    # open() does not create missing directories.
    os.makedirs(args.output_dir, exist_ok=True)

    # Load answers: one JSON object per line; blank lines are skipped.
    with open(args.answer_path, "r", encoding="utf-8") as f:
        answers = [json.loads(line) for line in f if line.strip()]

    # Evaluate answers: "gt_answer" is each task's test code and the cleaned
    # "pred_answer" its single candidate program (hence pass@1).
    test_cases = [instance["gt_answer"] for instance in answers]
    processed = [preproc(instance["pred_answer"]) for instance in answers]
    candidates = [[code] for code in processed]
    pass_at_k, results = eval_code(test_cases, candidates)
    print(f"Pass@1: {pass_at_k['pass@1']}")

    # results is indexed by task position; each entry is a list of
    # (completion_id, result_dict) pairs and there is exactly one completion
    # per task, so entry 0 is the one we submitted.
    is_corrects = [results[idx][0][1]["passed"] for idx in range(len(answers))]

    # Save results. The file is opened once in "w" mode so that re-running the
    # script overwrites earlier output instead of appending duplicate records.
    print("Saving results...")
    answer_name = os.path.basename(args.answer_path).split(".jsonl")[0]
    output_path = os.path.join(args.output_dir, f"{answer_name}_eval.jsonl")
    with open(output_path, "w", encoding="utf-8") as f:
        for instance, code, is_correct in zip(answers, processed, is_corrects):
            instance["is_correct"] = is_correct
            instance["pred_answer"] = code
            f.write(json.dumps(instance) + "\n")
