import random
import json
import vllm
import evaluate
import argparse
from tqdm import tqdm
import torch
import os
import re
from transformers import AutoTokenizer

# Hugging Face `evaluate` exact-match metric, loaded once at import time; main()
# uses it (with ignore_case / ignore_punctuation) to score extracted answers.
exact_match = evaluate.load("exact_match")

# System instruction asking the model to reason step by step and box its answer
# so that extract_answer() can recover it.
instr = "Please think step by step to solve the following question, and put your final answer within \\boxed{}."


def chat_prompt(x):
    """Wrap question *x* in a two-turn chat message list.

    Args:
        x: The question text.

    Returns:
        A new list ``[system, user]`` of ``{"role", "content"}`` dicts; the
        system turn carries ``instr``. This is the ``messages`` argument passed
        to ``tokenizer.apply_chat_template``.
    """
    return [{"role": "system", "content": instr}, {"role": "user", "content": x}]

def extract_answer(text):
    """Return the contents of the first balanced ``\\boxed{...}`` in *text*.

    Braces are matched with a linear scan rather than a regex, which has two
    effects:

    - Arbitrarily nested LaTeX such as ``\\boxed{\\sqrt{\\frac{1}{x^{2}}}}``
      is extracted in full. A regex can only hard-code a fixed nesting depth.
    - An unterminated box (e.g. output truncated by ``max_tokens``) cannot
      trigger catastrophic regex backtracking.

    Escaped braces ``\\{`` / ``\\}`` and an escaped backslash ``\\\\`` are
    treated as literal text and do not change the nesting depth.

    Args:
        text: Model output to search.

    Returns:
        The stripped box contents (``""`` for ``\\boxed{}``), or ``None`` if
        *text* contains no balanced ``\\boxed{...}``.
    """
    opener = "\\boxed{"
    length = len(text)
    start = text.find(opener)
    while start != -1:
        content_start = start + len(opener)
        depth = 1
        i = content_start
        while i < length:
            ch = text[i]
            if ch == "\\" and i + 1 < length and text[i + 1] in "{}\\":
                # Escaped brace/backslash: literal text, skip both characters.
                i += 2
                continue
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[content_start:i].strip()
            i += 1
        # This box never closed; try a later occurrence (as re.search would).
        start = text.find(opener, start + 1)
    return None

def _load_jsonl(path):
    """Load a JSON-Lines file into a list of records, skipping blank lines.

    Skipping blank lines (e.g. a trailing newline at end of file) avoids a
    ``json.loads`` error on an otherwise valid file.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_stop_strings(args):
    """Return the vLLM stop strings requested on the command line.

    ``args.additional_stop_sequence`` is copied before appending, so the
    argparse value (whose default ``[]`` object lives on the parser) is never
    mutated. When ``--newline_stop`` is set, the double-newline flag takes
    precedence over the triple-newline flag; with neither, a single newline
    is used.
    """
    stop_strings = list(args.additional_stop_sequence)
    if args.newline_stop:
        if args.stop_at_double_newline:
            stop_strings.append("\n\n")
        elif args.stop_at_triple_newline:
            stop_strings.append("\n\n\n")
        else:
            stop_strings.append("\n")
    return stop_strings


def main(args):
    """Evaluate ``args.model_name`` on a JSONL math dataset with vLLM.

    Steps:

    1. Optionally subsample the dataset (seeded with 42).
    2. Build chat-formatted prompts from the question field selected by
       ``args.eval_lang``.
    3. Generate completions with temperature 0.
    4. Extract the ``\\boxed{...}`` answer from each completion.
    5. Score the answers against ``final_answer`` with case- and
       punctuation-insensitive exact match.

    Per-example records are written to ``<save_dir>/predictions.jsonl`` and
    the score to ``<save_dir>/metrics.json``.
    """
    random.seed(42)

    prompt = chat_prompt
    is_chat = True
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    add_generation_prompt = True

    print("Loading data...")
    os.makedirs(args.save_dir, exist_ok=True)

    # NOTE(review): 'FILEPATH' looks like a placeholder for the dataset path.
    full_data = _load_jsonl('FILEPATH')

    if args.num_instances is not None and len(full_data) >= args.num_instances:
        full_data = random.sample(full_data, args.num_instances)

    # The language variants differ only in which question field is prompted.
    question_key = 'chinese_question' if args.eval_lang == 'chn' else 'question'
    prompts, targets, rationales = [], [], []
    for d in tqdm(full_data):
        prp = prompt(d[question_key])
        if is_chat:
            prp = tokenizer.apply_chat_template(
                prp,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
        prompts.append(prp)
        targets.append(str(d['final_answer']))
        # Read here (before generation) so a missing field fails fast.
        rationales.append(d['solution'])

    # load model and tokenizer
    model = vllm.LLM(
        model=args.model_name,
        tensor_parallel_size=torch.cuda.device_count(),
        trust_remote_code=True
    )

    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=args.clm_max_length,
        stop=_build_stop_strings(args),
        skip_special_tokens=True,
    )

    # We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
    generations = model.generate(prompts, sampling_params)
    prompt_to_output = {
        g.prompt: g.outputs[0].text for g in generations
    }
    outputs = [prompt_to_output.get(p, "") for p in prompts]

    print("Calculating accuracy...")
    # extract_answer returns a str or None; a missing answer scores as "".
    predictions = [extract_answer(output) or "" for output in outputs]

    em_score = exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]
    print(f"Exact match : {em_score}")

    records = [
        {
            "prompt": prp,
            "answer": tgt,
            "prediction": pred,
            "model_output": output,
            "solution_ref": rat,
        }
        for prp, tgt, output, pred, rat in zip(prompts, targets, outputs, predictions, rationales)
    ]

    with open(os.path.join(args.save_dir, "predictions.jsonl"), "w", encoding="utf-8") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")

    with open(os.path.join(args.save_dir, "metrics.json"), "w", encoding="utf-8") as fout:
        json.dump({
            "exact_match": em_score,
        }, fout, indent=4)

# Command-line interface. The parser is defined at module level; parsing and
# running only happen when the file is executed as a script.
parser = argparse.ArgumentParser(
    description="Evaluate a HuggingFace model with vLLM, scoring the \\boxed{} answer by exact match."
)
parser.add_argument(
    '--model_name',
    type=str,
    required=True,
    help="The HuggingFace model to be evaluated."
    )
parser.add_argument(
    '--num_instances',
    type=int,
    default=None,
    help="Num of sampled instances for evaluation (default: use all)."
    )
parser.add_argument(
    "--newline_stop",
    action="store_true",
    help="If given, we will use stop token (usually newline or double newline) to stop generation."
    )
parser.add_argument(
    "--stop_at_double_newline",
    action="store_true",
    help="If given, will stop generation at double newline instead of single. Takes precedence over --stop_at_triple_newline."
    )
parser.add_argument(
    "--stop_at_triple_newline",
    action="store_true",
    help="If given, will stop generation at triple newline instead of single."
    )
parser.add_argument(
    '--additional_stop_sequence',
    type=str,
    nargs="+",
    default=[],
    help="Additional stop sequences to use when generating completions. Useful for e.g. llama-3-instruct."
    )
parser.add_argument(
    "--clm_max_length",
    type=int,
    default=256,
    help="Maximum number of tokens to generate per prompt (vLLM max_tokens)."
    )
parser.add_argument(
    "--eval_lang",
    type=str,
    choices=['chn', 'eng'],
    default='eng',
    help="'chn' prompts with the 'chinese_question' field, 'eng' with 'question'."
    )
parser.add_argument(
    "--save_dir",
    type=str,
    required=True,
    help="Directory where predictions.jsonl and metrics.json are written."
    )


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)