import argparse
import json
import os
import re

import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm
load_dotenv()

from utils import set_seeds, read_config, openai_chat_completion

def parse_args():
    """Parse the command-line options for the pairwise judging script.

    Returns:
        argparse.Namespace with ``config``, ``output_dir``,
        ``assistent_1_answer_path``, ``assistent_2_answer_path`` and ``sample``.
    """
    # Declared in the same order as before so --help/usage output is unchanged.
    option_specs = (
        ("--config", {"type": str, "required": True}),
        ("--output_dir", {"type": str, "default": "output/judge"}),
        ("--assistent_1_answer_path", {"type": str, "required": True}),
        ("--assistent_2_answer_path", {"type": str, "required": True}),
        ("--sample", {"type": int, "default": None}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def format_prompt(assitent_1_instance, assitent_2_instance, system_prompt, prompt_template):
    """Build the chat messages asking the judge to compare two answers.

    The question comes from ``assitent_1_instance``. The two ``pred_answer``
    values fill the ``answer_1`` and ``answer_2`` placeholders of
    ``prompt_template`` in argument order, and the question fills ``question_1``.

    Returns:
        A two-element list: the system message, then the rendered user prompt.
    """
    question = assitent_1_instance["question"]
    first_answer = assitent_1_instance["pred_answer"]
    second_answer = assitent_2_instance["pred_answer"]
    user_content = prompt_template.format(
        question_1=question,
        answer_1=first_answer,
        answer_2=second_answer,
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]

# Score credited to assistant 1 for each verdict, indexed by game. In game 0
# assistant 1's answer is passed to format_prompt first; in game 1 the order is
# swapped, so the same verdict string credits the opposite assistant.
_GAME_SCORES = (
    {"A>B": 3, "A=B": 2, "B>A": 1},
    {"A>B": 1, "A=B": 2, "B>A": 3},
)
# Score used when no single unambiguous verdict can be parsed from the judge
# output; such cases are treated as a tie and tallied as "NaN scores".
_FALLBACK_SCORE = 2


def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _extract_judgement(output, pattern):
    """Return the single verdict that compiled regex *pattern* finds in *output*.

    Empty matches are ignored. Returns ``None`` when nothing matches or when
    the remaining matches disagree with each other.
    """
    matches = [m for m in pattern.findall(output) if m != ""]
    return matches[0].strip("\n") if len(set(matches)) == 1 else None


def _score_judgement(judgement, game):
    """Return assistant 1's score for *judgement* in *game* (0 or 1), or ``None`` if unrecognised."""
    return _GAME_SCORES[game].get(judgement)


def _strip_suffix(name, suffix):
    """Remove *suffix* from the end of *name* if present (unlike ``str.replace``, only at the end)."""
    return name[: -len(suffix)] if suffix and name.endswith(suffix) else name


def _output_path(output_dir, config_path, answer_path):
    """Build ``<output_dir>/<config name without .yaml>_<answer file without .jsonl>.jsonl``."""
    config_name = _strip_suffix(os.path.basename(config_path), ".yaml")
    answer_name = _strip_suffix(os.path.basename(answer_path), ".jsonl")
    return os.path.join(output_dir, f"{config_name}_{answer_name}.jsonl")


if __name__ == "__main__":
    args = parse_args()
    config = read_config(args.config)
    set_seeds(config["seed"])
    print(config)

    assistent_1_answers = _read_jsonl(args.assistent_1_answer_path)
    assistent_2_answers = _read_jsonl(args.assistent_2_answer_path)
    # zip() below would silently truncate mismatched files, and the sampling
    # indices are drawn from file 1's length only, so require equal lengths.
    if len(assistent_1_answers) != len(assistent_2_answers):
        raise ValueError(
            "Answer files have different lengths: "
            f"{len(assistent_1_answers)} vs {len(assistent_2_answers)}"
        )

    if args.sample:
        # np.random.choice(replace=False) raises if asked for more items than exist.
        sample_size = min(args.sample, len(assistent_1_answers))
        random_indices = np.random.choice(len(assistent_1_answers), sample_size, replace=False)
        assistent_1_answers = [assistent_1_answers[i] for i in random_indices]
        assistent_2_answers = [assistent_2_answers[i] for i in random_indices]

    print(f"Number of test instances: {len(assistent_1_answers)}")

    # Compiled once: the pattern is the same for every judge call.
    verdict_pattern = re.compile(config["regex_pattern"])

    game_1_outputs = []
    game_2_outputs = []
    all_scores = []
    num_of_nan_scores = 0
    for assistent_1_instance, assistent_2_instance in tqdm(
        zip(assistent_1_answers, assistent_2_answers), total=len(assistent_1_answers)
    ):
        # An explicit check rather than `assert`, which `python -O` strips.
        if assistent_1_instance["unique_id"] != assistent_2_instance["unique_id"]:
            raise ValueError(
                f"unique_id mismatch: {assistent_1_instance['unique_id']!r} "
                f"vs {assistent_2_instance['unique_id']!r}"
            )

        two_game_scores = []
        # Game 0 presents assistant 1 first; game 1 swaps the order
        # (presumably to cancel the judge's position bias).
        games = (
            (assistent_1_instance, assistent_2_instance, game_1_outputs),
            (assistent_2_instance, assistent_1_instance, game_2_outputs),
        )
        for game, (first, second, outputs) in enumerate(games):
            messages = format_prompt(first, second, config["system_prompt"], config["prompt_template"])
            output = openai_chat_completion(messages, config)
            outputs.append(output)

            score = _score_judgement(_extract_judgement(output, verdict_pattern), game)
            if score is None:
                score = _FALLBACK_SCORE
                num_of_nan_scores += 1
            two_game_scores.append(score)

        all_scores.append(np.mean(two_game_scores))

    print(f"Number of NaN scores: {num_of_nan_scores}")
    print(f"Average score: {np.mean(all_scores)}")

    # Create the directory up front so the final write cannot fail after all
    # the (paid) judge calls have already been made.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    output_path = _output_path(args.output_dir, args.config, args.assistent_2_answer_path)

    # NOTE(review): append mode is kept from the original, so re-running with the
    # same config/answer file adds duplicate records — confirm this is intended.
    with open(output_path, "a", encoding="utf-8") as f:
        for assistent_1_instance, assistent_2_instance, game_1_output, game_2_output, score in zip(
            assistent_1_answers, assistent_2_answers, game_1_outputs, game_2_outputs, all_scores
        ):
            record = {
                "unique_id": assistent_1_instance["unique_id"],
                "problem": assistent_1_instance["question"],
                "assistent_1_answer": assistent_1_instance["pred_answer"],
                "assistent_2_answer": assistent_2_instance["pred_answer"],
                "assistent_1_is_correct": assistent_1_instance["is_correct"],
                "assistent_2_is_correct": assistent_2_instance["is_correct"],
                "game_1_output": game_1_output,
                "game_2_output": game_2_output,
                "judgement_score": score,
            }
            f.write(json.dumps(record) + "\n")
