import argparse
import json

import re
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv
load_dotenv()

import torch
from vllm import LLM
from utils import set_seeds, read_config, vllm_chat_completion_batch


def format_prompt(assitent_1_instance, assitent_2_instance, system_prompt, prompt_template, model_name):
    """Build the chat messages for one pairwise-judging query.

    The question is taken from the first instance; each instance's
    ``pred_answer`` is reduced to the text after its last ``</think>`` tag
    and stripped.

    Args:
        assitent_1_instance: record with ``question`` and ``pred_answer``;
            fills ``question_1`` and ``answer_1`` in the template.
        assitent_2_instance: record with ``pred_answer``; fills ``answer_2``.
        system_prompt: judge system instructions.
        prompt_template: ``str.format`` template with ``question_1``,
            ``answer_1`` and ``answer_2`` fields.
        model_name: judge model identifier, used to choose the message layout.

    Returns:
        A list of ``{"role", "content"}`` message dicts.
    """
    def _final_answer(instance):
        # Keep only the text after the last "</think>" (the whole string if absent).
        return instance["pred_answer"].split("</think>")[-1].strip()

    user_text = prompt_template.format(
        question_1=assitent_1_instance["question"],
        answer_1=_final_answer(assitent_1_instance),
        answer_2=_final_answer(assitent_2_instance),
    )

    # For gemma / DeepSeek-R1 the system prompt is prepended to the user turn
    # instead of being sent as a separate "system" message (presumably because
    # their chat templates do not take a system role — confirm).
    merge_system_into_user = any(tag in model_name for tag in ("gemma", "DeepSeek-R1"))
    if merge_system_into_user:
        return [{"role": "user", "content": system_prompt + "\n" + user_text}]

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_text},
    ]

def compute_spb_score(output):

    choices = ["1", "T", "2"]
    for choice in choices:
        escaped_symbol = re.escape(choice)
        patterns = [
            rf'\$\${escaped_symbol}\$\$',
            rf'\$\$\({escaped_symbol}\)\$\$',
            rf'\$\$\({escaped_symbol}\$\$',
            rf'\$\${escaped_symbol}\)\$\$',
            rf'\$\$\\text{{\({escaped_symbol}\)}}\$\$',
            rf'\$\$\\text{{{escaped_symbol}}}\$\$',
            rf'My final verdict is {escaped_symbol}',
            rf'My final verdict is \({escaped_symbol}\)',
            rf'My final answer is \\boxed{{{escaped_symbol}}}',
            rf'The final answer is \\boxed{{{escaped_symbol}}}',
        ]
        for pattern in patterns:
            if re.search(pattern, output):
                if choice == "1":
                    probs = [1, 0, 0]
                elif choice == "T":
                    probs = [0, 1, 0]
                elif choice == "2":
                    probs = [0, 0, 1]
                
                return probs
            
    print(f"Invalid output")
    probs = [0, 1, 0]

    return probs

def parse_args():
    """Parse the judge script's command-line options.

    Returns:
        argparse.Namespace with ``config``, ``output_dir``,
        ``assistent_1_answer_path``, ``assistent_2_answer_path`` and ``sample``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--output_dir", type=str, default="output/judge")
    # The two answer files to compare; both are mandatory.
    for flag in ("--assistent_1_answer_path", "--assistent_2_answer_path"):
        cli.add_argument(flag, type=str, required=True)
    # Optional number of instances to sample at random (None = use all).
    cli.add_argument("--sample", type=int, default=None)

    parsed = cli.parse_args()
    return parsed


if __name__ == "__main__":
    import os  # only the script entry point needs filesystem helpers

    # unique_ids excluded from judging.
    # NOTE(review): the reason for skipping 493 is not visible here — confirm and document.
    skipped_unique_ids = {493}

    args = parse_args()
    config = read_config(args.config)
    set_seeds(config["seed"])
    print(config)
    # Group results by the config file's parent directory name, e.g.
    # "configs/<group>/<judge>.yaml" -> "<output_dir>/<group>".
    # Unlike split('/')[-2], this does not crash on a bare filename.
    config_group = os.path.basename(os.path.dirname(args.config))
    args.output_dir = f"{args.output_dir}/{config_group}"

    llm = LLM(
        model=config["model_name"],
        tensor_parallel_size=config["tensor_parallel_size"],
        max_model_len=config["max_model_len"],
        enable_prefix_caching=config["enable_prefix_caching"]
    )

    # One JSON object per line; blank lines (e.g. a trailing newline) are ignored.
    with open(args.assistent_1_answer_path, "r", encoding="utf-8") as f:
        assistent_1_answers = [json.loads(line) for line in f if line.strip()]

    with open(args.assistent_2_answer_path, "r", encoding="utf-8") as f:
        assistent_2_answers = [json.loads(line) for line in f if line.strip()]

    # The two files are paired positionally, so they must have the same length.
    if len(assistent_1_answers) != len(assistent_2_answers):
        raise ValueError(
            f"Answer files are not aligned: {len(assistent_1_answers)} instances in "
            f"{args.assistent_1_answer_path} vs {len(assistent_2_answers)} in "
            f"{args.assistent_2_answer_path}"
        )

    if args.sample:
        random_indices = np.random.choice(len(assistent_1_answers), args.sample, replace=False)
        assistent_1_answers = [assistent_1_answers[i] for i in random_indices]
        assistent_2_answers = [assistent_2_answers[i] for i in random_indices]

    # Filter ONCE so prompts, outputs and the records written below stay aligned
    # (previously the writer zipped the unfiltered lists with filtered outputs).
    pairs = [
        (assistent_1_instance, assistent_2_instance)
        for assistent_1_instance, assistent_2_instance in zip(assistent_1_answers, assistent_2_answers)
        if assistent_1_instance["unique_id"] not in skipped_unique_ids
    ]

    print(f"Number of test instances: {len(pairs)}")

    # Game 1 presents (assistant 1, assistant 2); game 2 swaps their positions.
    game_1_prompts = [
        format_prompt(a1, a2, config["system_prompt"], config["prompt_template"], config["model_name"])
        for a1, a2 in pairs
    ]
    game_2_prompts = [
        format_prompt(a2, a1, config["system_prompt"], config["prompt_template"], config["model_name"])
        for a1, a2 in pairs
    ]

    game_1_outputs = vllm_chat_completion_batch(llm, game_1_prompts, config)
    game_2_outputs = vllm_chat_completion_batch(llm, game_2_prompts, config)

    game_1_spb_scores = [compute_spb_score(output) for output in game_1_outputs]
    game_2_spb_scores = [compute_spb_score(output) for output in game_2_outputs]

    # Write one JSON record per judged pair. The file is opened in append
    # mode, so re-running with the same config/answers adds to existing results.
    config_stem = os.path.basename(args.config).replace('.yaml', '')
    answers_2_stem = os.path.basename(args.assistent_2_answer_path).replace('.jsonl', '')
    output_path = f"{args.output_dir}/{config_stem}_{answers_2_stem}.jsonl"
    os.makedirs(args.output_dir, exist_ok=True)
    with open(output_path, "a", encoding="utf-8") as f:
        for (assistent_1_instance, assistent_2_instance), game_1_spb_score, game_2_spb_score, game_1_output, game_2_output in zip(
            pairs, game_1_spb_scores, game_2_spb_scores, game_1_outputs, game_2_outputs
        ):
            f.write(json.dumps(
                {
                    "unique_id": assistent_1_instance["unique_id"],
                    "problem": assistent_1_instance["question"],
                    "assistent_1_answer": assistent_1_instance["pred_answer"],
                    "assistent_2_answer": assistent_2_instance["pred_answer"],
                    "assistent_1_is_correct": assistent_1_instance["is_correct"],
                    "assistent_2_is_correct": assistent_2_instance["is_correct"],
                    "game_1_spb_score": game_1_spb_score,
                    "game_2_spb_score": game_2_spb_score,
                    "game_1_output": game_1_output,
                    "game_2_output": game_2_output
                }) + "\n")
