import os
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from utils import load_json, save_json
from prompts import (
    build_prompt_med,
    build_reflection_prompt_med,
    build_prompt_cs,
    build_reflection_prompt_cs,
)

# Unconditionally set CUDA_VISIBLE_DEVICES to "1" for this process at import
# time, replacing any value inherited from the calling environment.
# NOTE(review): presumably pins the run to the GPU with index 1 — confirm this
# hard-coded index is right for the target machine.
os.environ.update(CUDA_VISIBLE_DEVICES="1")


def _generate_reply(model, tokenizer, prompt, *, max_new_tokens, do_sample):
    """Run one generation pass and return ``(transcript, reply)``.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Fully built prompt string.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Whether ``generate`` samples instead of decoding greedily.

    Returns:
        A tuple ``(transcript, reply)`` where ``transcript`` is the decoded
        full output sequence and ``reply`` is the decoded, stripped text of
        only the tokens that follow the prompt.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=do_sample
    )[0]
    transcript = tokenizer.decode(output_ids, skip_special_tokens=True)

    # For a decoder-only model the returned sequence starts with the prompt
    # tokens, so slicing at the prompt length isolates the new text. Splitting
    # the decoded string on the word "assistant" instead would truncate any
    # reply that itself contains that word.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(output_ids[prompt_len:], skip_special_tokens=True).strip()
    return transcript, reply


def main():
    """Answer every question in the input file, then generate a reflection.

    Command-line arguments:
        --input:  path to the input JSON file (required).
        --output: path for the results (default ``reflection_output.json``).
        --model:  model name or path passed to ``from_pretrained``.
        --task:   ``med`` or ``csqa``; selects the prompt builders.

    Each input item is read through its ``"question"``, ``"id"`` and
    ``"answerKey"`` keys. One record per item is written with the keys
    ``id``, ``question``, ``correct_answer``, ``raw_answer_response`` and
    ``reflection``.

    Raises:
        KeyError: if an input item lacks one of the keys listed above.
    """
    # === Argument parser ===
    parser = argparse.ArgumentParser(description="Run QA with reflection.")
    parser.add_argument("--input", type=str, required=True, help="Path to input JSON file.")
    parser.add_argument("--output", type=str, default="reflection_output.json", help="Path to save results.")
    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.1-8B-Instruct", help="Model name or path.")
    parser.add_argument("--task", type=str, choices=["med", "csqa"], required=True, help="Task type: 'med' or 'csqa'")
    args = parser.parse_args()

    # === Load model and tokenizer ===
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(
        args.model, device_map="auto", torch_dtype=torch.float16
    )

    # === Select task ===
    if args.task == "med":
        prompt_fn = build_prompt_med
        reflection_fn = build_reflection_prompt_med
    else:  # csqa
        prompt_fn = build_prompt_cs
        reflection_fn = build_reflection_prompt_cs

    # === Load dataset ===
    questions = load_json(args.input)
    results = []

    for item in tqdm(questions, desc="Running reflection inference..."):
        question = item["question"]
        qid = item["id"]
        correct_answer = item["answerKey"]

        # First generation: greedy decoding of the initial answer.
        transcript_1, assistant_reply_1 = _generate_reply(
            model,
            tokenizer,
            prompt_fn(question),
            max_new_tokens=512,
            do_sample=False,
        )

        # Second generation (reflection): sampled decoding.
        # NOTE(review): the reflection prompt is built from the full decoded
        # transcript (prompt + answer), not just the extracted answer. Kept
        # as-is to preserve existing behaviour — confirm this is intended.
        prompt_2 = reflection_fn(question, transcript_1)
        _, assistant_reply_2 = _generate_reply(
            model,
            tokenizer,
            prompt_2,
            max_new_tokens=1024,
            do_sample=True,
        )

        results.append({
            "id": qid,
            "question": question,
            "correct_answer": correct_answer,
            "raw_answer_response": assistant_reply_1,
            "reflection": assistant_reply_2,
        })

    # === Save results ===
    save_json(results, args.output)
    print(f"Output saved to {args.output}")


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()