import argparse
import json
import os
import random
from pprint import pprint

from tqdm import tqdm
from dotenv import load_dotenv
load_dotenv()

from utils import set_seeds, read_config, openai_chat_completion, vllm_chat_completion_batch

def parse_args():
    """Parse the command-line options for this script.

    Options:
        --config      (str, required) path to the configuration file.
        --output_dir  (str, default "output/answer") base directory for results.
        --sample      (int, default None) optional number of instances to sample.

    Returns:
        argparse.Namespace with ``config``, ``output_dir`` and ``sample``.
    """
    cli = argparse.ArgumentParser()

    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_specs = (
        ("--config", {"type": str, "required": True}),
        ("--output_dir", {"type": str, "default": "output/answer"}),
        ("--sample", {"type": int, "default": None}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    return parsed


if __name__ == "__main__":
    # Script entry point: read the config, build one chat prompt per test
    # instance, generate an answer for each prompt with the configured engine,
    # and append the results as JSON lines to the output file.
    args = parse_args()
    config = read_config(args.config)
    set_seeds(config["seed"])
    pprint(config)

    # Fail fast on an unsupported engine instead of hitting an undefined
    # `answers` variable after the dataset has been loaded and formatted.
    engine = config["engine"]
    if engine not in ("openai", "vllm"):
        raise ValueError(
            f"Unsupported engine {engine!r}; expected 'openai' or 'vllm'"
        )

    # Results are written to
    #   <output_dir>/<name of the config's parent directory>/<config name>.jsonl
    # os.path is used so a config path without a directory component (or with
    # OS-specific separators) does not break the path derivation.
    config_dir_name = os.path.basename(os.path.dirname(args.config))
    args.output_dir = os.path.join(args.output_dir, config_dir_name)
    output_path = os.path.join(
        args.output_dir,
        os.path.basename(args.config).replace(".yaml", ".jsonl"),
    )
    # Create the directory up front so a missing path is detected before any
    # generation work is done rather than when saving the results.
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the test set: one JSON object per line; blank lines are skipped.
    with open(config["data_path"], "r", encoding="utf-8") as f:
        test_dataset = [json.loads(line) for line in f if line.strip()]

    if args.sample is not None:
        # random.sample raises ValueError when asked for more items than
        # exist, so cap the request at the dataset size.
        sample_size = min(args.sample, len(test_dataset))
        if sample_size < args.sample:
            print(
                f"Requested sample of {args.sample} exceeds dataset size; "
                f"using all {len(test_dataset)} instances"
            )
        test_dataset = random.sample(test_dataset, sample_size)
    print(f"Number of test instances: {len(test_dataset)}")

    # format prompts
    # How the system prompt is delivered depends only on the model name, so
    # decide once instead of re-checking for every instance.
    model_name = config["model_name"]
    system_prompt = config["system_prompt"]
    # For these models the system prompt is prepended to the user message
    # (presumably because they do not take a separate system message).
    inline_system_prompt = any(
        tag in model_name for tag in ("gemma", "o1-mini", "DeepSeek-R1")
    )
    # Model names containing "o3" get the "developer" role; all others "system".
    system_role = "developer" if "o3" in model_name else "system"

    prompts = []
    for instance in test_dataset:
        user_content = config["prompt_template"].format(
            question=instance["question"]
        )
        if inline_system_prompt:
            messages = [
                {"role": "user", "content": system_prompt + "\n" + user_content}
            ]
        else:
            messages = [
                {"role": system_role, "content": system_prompt},
                {"role": "user", "content": user_content},
            ]
        prompts.append(messages)

    # generate answers
    if engine == "openai":
        answers = [
            openai_chat_completion(prompt, config)
            for prompt in tqdm(prompts)
        ]
    else:  # engine == "vllm" (validated above)
        # Imported here so vllm is only needed when this engine is selected.
        from vllm import LLM
        llm = LLM(
            model=config["model_name"],
            tensor_parallel_size=config["tensor_parallel_size"],
            max_model_len=config["max_model_len"],
            enable_prefix_caching=config["enable_prefix_caching"],
            disable_sliding_window=config.get("disable_sliding_window", False)
        )
        answers = vllm_chat_completion_batch(llm, prompts, config)

    # save results
    # Open the output file once (append mode, as before) rather than once per
    # instance; each record is the original instance plus the answer and prompt.
    with open(output_path, "a", encoding="utf-8") as f:
        for answer, prompt, instance in zip(answers, prompts, test_dataset):
            output = {**instance, "pred_answer": answer, "prompt": prompt}
            f.write(json.dumps(output) + "\n")