import argparse
import json

import re
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv
load_dotenv()

import torch
from vllm import LLM
from utils import set_seeds, read_config, vllm_chat_completion_batch


def format_prompt(assitent_1_instance, assitent_2_instance, system_prompt, prompt_template, model_name):
    """Build the chat messages that ask the judge to compare two answers.

    The question is taken from ``assitent_1_instance``; the two answers are the
    ``pred_answer`` fields of the first and second instance, in that order.

    Args:
        assitent_1_instance: Dict with ``"question"`` and ``"pred_answer"`` keys;
            its answer fills the ``answer_1`` slot.
        assitent_2_instance: Dict with a ``"pred_answer"`` key; fills ``answer_2``.
        system_prompt: Judge instructions.
        prompt_template: ``str.format`` template with ``{question_1}``,
            ``{answer_1}`` and ``{answer_2}`` placeholders.
        model_name: Judge model identifier; selects how the system prompt is sent.

    Returns:
        A list of ``{"role", "content"}`` message dicts.
    """
    user_text = prompt_template.format(
        question_1=assitent_1_instance["question"],
        answer_1=assitent_1_instance["pred_answer"],
        answer_2=assitent_2_instance["pred_answer"],
    )

    # For Gemma models the system prompt is folded into the user turn instead of
    # being sent as a separate "system" message (presumably because their chat
    # template has no system role -- confirm).
    if "gemma" not in model_name:
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_text},
        ]
    return [{"role": "user", "content": system_prompt + "\n" + user_text}]

def compute_spb_score(top_logprobs, tokenizer, labels=("A", "T", "B")):
    """Convert the judge's top log-probs into a distribution over verdict labels.

    For every label, the log-prob of the label's first token id is looked up in
    ``top_logprobs``. The collected values are then renormalised with a softmax,
    so the returned probabilities sum to 1 over ``labels``.

    Args:
        top_logprobs: Mapping from token id to an object with a ``.logprob``
            attribute. Presumably this is vLLM's top-k log-probs for the first
            generated token; confirm against ``vllm_chat_completion_batch``.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a list of token ids.
        labels: Verdict labels, in output order. Defaults to ``("A", "T", "B")``.

    Returns:
        List of Python floats, one per label, in the order of ``labels``.
    """
    candidate_logits = []
    for label in labels:
        try:
            candidate_logits.append(top_logprobs[tokenizer.encode(label, add_special_tokens=False)[0]].logprob)
        except (KeyError, IndexError):
            # The label's token is not among the returned top-k log-probs (or the
            # label encodes to no tokens at all). A very low log-prob gives it
            # ~0 mass after the softmax instead of aborting the whole run.
            print("Warning: {} not found. Setting log prob to -100.".format(label))
            candidate_logits.append(-100)

    # The tensor is built on the CPU, so no detach/device round-trip is needed.
    logits = torch.tensor(candidate_logits, dtype=torch.float32)
    return torch.nn.functional.softmax(logits, dim=0).tolist()

def parse_args():
    """Parse the judge script's command-line options from ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with ``config``, ``output_dir``,
        ``assistent_1_answer_path``, ``assistent_2_answer_path`` and ``sample``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--output_dir", type=str, default="output/judge")
    # Both answer files are mandatory JSONL inputs.
    for answer_flag in ("--assistent_1_answer_path", "--assistent_2_answer_path"):
        cli.add_argument(answer_flag, type=str, required=True)
    # Optional: number of instances to subsample (None means use all).
    cli.add_argument("--sample", type=int, default=None)

    parsed = cli.parse_args()
    return parsed


if __name__ == "__main__":
    # Pairwise judging script. The judge model compares assistant 1's and
    # assistant 2's answers twice: game 1 shows assistant 1 first, game 2 swaps
    # the order. The per-game verdict-label probabilities (see
    # compute_spb_score) are appended to a JSONL file under args.output_dir.
    import os  # local import: only needed when run as a script

    args = parse_args()
    config = read_config(args.config)
    set_seeds(config["seed"])
    print(config)
    # Results are grouped by the config's parent directory name,
    # e.g. configs/<exp>/judge.yaml -> <output_dir>/<exp>.
    args.output_dir = f"{args.output_dir}/{args.config.split('/')[-2]}"

    llm = LLM(
        model=config["model_name"],
        tensor_parallel_size=config["tensor_parallel_size"],
        max_model_len=config["max_model_len"],
        enable_prefix_caching=config["enable_prefix_caching"]
    )

    with open(args.assistent_1_answer_path, "r") as f:
        assistent_1_answers = [json.loads(line) for line in f]

    with open(args.assistent_2_answer_path, "r") as f:
        assistent_2_answers = [json.loads(line) for line in f]

    # NOTE(review): the two files are paired by line position; this assumes
    # both list the same instances in the same order -- confirm.
    if args.sample:
        random_indices = np.random.choice(len(assistent_1_answers), args.sample, replace=False)
        assistent_1_answers = [assistent_1_answers[i] for i in random_indices]
        assistent_2_answers = [assistent_2_answers[i] for i in random_indices]

    # Remove excluded instances from BOTH answer lists before building prompts.
    # Previously they were skipped only when building prompts, so the answers,
    # prompts and model outputs drifted out of alignment in the output loop.
    # NOTE(review): the reason unique_id 493 is excluded is not visible here.
    excluded_unique_ids = {493}
    kept_pairs = [
        (a1, a2)
        for a1, a2 in zip(assistent_1_answers, assistent_2_answers)
        if a1["unique_id"] not in excluded_unique_ids
    ]
    assistent_1_answers = [a1 for a1, _ in kept_pairs]
    assistent_2_answers = [a2 for _, a2 in kept_pairs]

    print(f"Number of test instances: {len(assistent_1_answers)}")

    system_prompt = config["system_prompt"]
    prompt_template = config["prompt_template"]
    model_name = config["model_name"]
    # Game 1: assistent 1's answer shown first; game 2: the order is swapped.
    game_1_prompts = [
        format_prompt(a1, a2, system_prompt, prompt_template, model_name)
        for a1, a2 in kept_pairs
    ]
    game_2_prompts = [
        format_prompt(a2, a1, system_prompt, prompt_template, model_name)
        for a1, a2 in kept_pairs
    ]

    # Each output is unpacked below as (generated_text, top_logprobs).
    game_1_outputs = vllm_chat_completion_batch(llm, game_1_prompts, config)
    game_2_outputs = vllm_chat_completion_batch(llm, game_2_prompts, config)

    tokenizer = llm.get_tokenizer()
    game_1_spb_scores = [compute_spb_score(top_logprobs, tokenizer) for _, top_logprobs in game_1_outputs]
    game_2_spb_scores = [compute_spb_score(top_logprobs, tokenizer) for _, top_logprobs in game_2_outputs]

    # NOTE: scalar aggregation is left to downstream analysis. An earlier
    # version used 3*P(A) + 2*P(T) + 1*P(B) for game 1 and
    # 1*P(A) + 2*P(T) + 3*P(B) for game 2, then averaged the two games.

    os.makedirs(args.output_dir, exist_ok=True)
    config_name = args.config.split('/')[-1].replace('.yaml', '')
    assistent_2_name = args.assistent_2_answer_path.split('/')[-1].replace('.jsonl', '')
    output_path = f"{args.output_dir}/{config_name}_{assistent_2_name}.jsonl"

    # Append mode: re-running with the same config/answer files adds duplicate records.
    with open(output_path, "a") as f:
        for assistent_1_instance, assistent_2_instance, game_1_spb_score, game_2_spb_score, game_1_output, game_2_output in zip(
            assistent_1_answers, assistent_2_answers, game_1_spb_scores, game_2_spb_scores, game_1_outputs, game_2_outputs
        ):
            f.write(json.dumps(
                {
                    "unique_id": assistent_1_instance["unique_id"],
                    "problem": assistent_1_instance["question"],
                    "assistent_1_answer": assistent_1_instance["pred_answer"],
                    "assistent_2_answer": assistent_2_instance["pred_answer"],
                    "assistent_1_is_correct": assistent_1_instance["is_correct"],
                    "assistent_2_is_correct": assistent_2_instance["is_correct"],
                    "game_1_spb_score": game_1_spb_score,
                    "game_2_spb_score": game_2_spb_score,
                    "game_1_output": game_1_output[0],
                    "game_2_output": game_2_output[0]
                }) + "\n")
