import os
import copy
import json
import argparse
from tqdm import tqdm, trange
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Command-line interface: every option is a simple typed flag with a default,
# so they are registered from a table instead of one call per flag.
parser = argparse.ArgumentParser("")
for flag, kind, fallback in (
    ("--tokenizer", str, "meta-llama/Llama-3.2-3B-Instruct"),  # tokenizer used for chat templating
    ("--model_path", str, ""),            # model passed to the vLLM engine
    ("--expname", str, ""),               # experiment tag; also part of the output directory name
    ("--temperature", float, 0.0),        # sampling temperature
    ("--top_p", float, 0.99),             # nucleus sampling threshold
    ("--tensor_parallel_size", int, 1),   # forwarded to the vLLM engine
    ("--datasets", str, "hotpotqa"),      # comma-separated dataset names
):
    parser.add_argument(flag, type=kind, default=fallback)
args = parser.parse_args()

# Decomposition prompt templates keyed by task type.  Brace-delimited names
# ({question}, {claim}, {table}, {passage}) are placeholders filled in by the
# caller.  The "_cold" variants additionally tell the model that it may refer
# to earlier answers as #1, #2, etc.
prompt_plan = {
    "gsm8k": (
        "Please break down the question '{question}' into multiple specific sub-questions. "
        "Use ### for each sub-question."
    ),
    "TabMWP": (
        "You have the following table:\n"
        "{table}\n"
        "Break down the question '{question}' into sub-questions. "
        "Use ### for each sub-question."
    ),
    "convfinqa": (
        "You have the following passages and table:\n"
        "Passages:\n"
        "{passage}\n"
        "Table:\n"
        "{table}\n"
        "Break down the question '{question}' into sub-questions. "
        "Use ### for each sub-question."
    ),
    "claim": 'Please break down the claim "{claim}" into multiple smaller sub-claims using ###.',
    "qa": "Please break down the question '{question}' into multiple specific sub-questions using ###.",
    "claim_cold": (
        'Break down the claim "{claim}" into smaller sub-claims using ###. '
        "Refer to earlier answers using #1, #2, etc. if needed."
    ),
    "qa_cold": (
        "Break down the question '{question}' into sub-questions using ###. "
        "Refer to earlier answers using #1, #2, etc. if needed."
    ),
}

# Maps each supported dataset name to its task type ("qa" or "claim").
# Datasets are grouped by task type; insertion order matches the listing order.
task_map = {
    **dict.fromkeys(("bamboogle", "2wikimultihopqa", "hotpotqa", "musique"), "qa"),
    **dict.fromkeys(("hover", "exfever"), "claim"),
}

# The tokenizer comes from --tokenizer, independently of the model given by
# --model_path.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

# Decoding settings: temperature and top_p come from the CLI, the repetition
# penalty and the generation length cap are fixed.
decoding_options = {
    "temperature": args.temperature,
    "top_p": args.top_p,
    "repetition_penalty": 1.05,
    "max_tokens": 2048,
}
sampling_params = SamplingParams(**decoding_options)

# Engine settings forwarded to vLLM.
engine_options = {
    "model": args.model_path,
    "tensor_parallel_size": args.tensor_parallel_size,
    "gpu_memory_utilization": 0.88,
    "trust_remote_code": True,
}
llm = LLM(**engine_options)

# Comma-separated dataset names from --datasets.  Whitespace around each name
# is stripped and empty entries (e.g. from a trailing comma) are dropped, so
# "hotpotqa, musique" and "hotpotqa," resolve to valid names instead of
# failing later on a path such as "data/ musique/test.jsonl".
datasets = [name.strip() for name in args.datasets.split(",") if name.strip()]

def _load_examples(dataset, use_cold_prompt):
    """Read one dataset's test split and build its chat prompts.

    Args:
        dataset: Dataset name; must be a key of ``task_map``.
        use_cold_prompt: When true, the "_cold" variant of the prompt template
            is used.

    Returns:
        A ``(prompts, contexts)`` pair of equal-length lists.  Each prompt is
        a single-turn chat message list; each context is a dict with the
        original ``question`` and the list of gold ``answer`` strings.

    Raises:
        KeyError: If ``dataset`` is not listed in ``task_map``.
    """
    if dataset not in task_map:
        raise KeyError(f"Unknown dataset {dataset!r}; expected one of {sorted(task_map)}")
    task = task_map[dataset]
    is_claim = task == "claim"
    template = prompt_plan[task + ("_cold" if use_cold_prompt else "")]
    # Substitute only the placeholder that belongs to this task type, so a
    # question that itself contains the text "{claim}" is left untouched.
    placeholder = "{claim}" if is_claim else "{question}"

    prompts, contexts = [], []
    with open(f"data/{dataset}/test.jsonl", "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            example = json.loads(line)
            if is_claim:
                question = example["claim"]
                answer = ["Yes"] if example["label"] in ["SUPPORT", "SUPPORTED"] else ["No"]
            else:
                question = example["question_text"]
                answer = [span for ans in example["answers_objects"] for span in ans["spans"]]
            prompt_text = template.replace(placeholder, question)
            prompts.append([{"role": "user", "content": prompt_text.strip()}])
            contexts.append({"question": question, "answer": answer})
    return prompts, contexts


def _parse_decomposition(generated_text):
    """Extract the sub-questions from one model completion.

    Only lines starting with "### Q" are considered; each is expected to look
    like "### Q<label>: <text>".

    Args:
        generated_text: Raw completion text.

    Returns:
        A list of ``{"label", "text", "needs_context"}`` dicts, or the string
        ``"Error"`` if a "### Q" line has no ":" separator.
    """
    decomposed = []
    for raw_line in generated_text.strip().split("\n"):
        line = raw_line.strip()
        if not line.startswith("### Q"):
            continue
        head, separator, tail = line.partition(":")
        if not separator:
            print(f"Error in parsing: {generated_text}")
            return "Error"
        label = "Q" + head.split("Q")[-1].strip()
        decomposed.append({"label": label, "text": tail.strip(), "needs_context": True})
    return decomposed


# Both settings are fixed for the whole run, so compute them once.
use_cold_prompt = args.expname in ("qwen3_8b", "qwen3_4b")
# With temperature 0 a single sample per question is taken; otherwise three.
samples_per_question = 1 if args.temperature == 0 else 3

for dataset in datasets:
    prompts, contexts = _load_examples(dataset, use_cold_prompt)

    output_dir = f"outputs/{dataset}/prompts_{args.expname}"
    os.makedirs(output_dir, exist_ok=True)
    if not prompts:
        # Nothing to generate; no output file is written for an empty split.
        continue

    texts = [
        tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        for messages in prompts
    ]
    # Submit every prompt (repeated once per requested sample) in a single
    # call so the engine can batch them, instead of one generate() call per
    # prompt.  Results are consumed in input order, so the flat index maps
    # back to (question_id, decompose_id).
    batch = [text for text in texts for _ in range(samples_per_question)]
    outputs = llm.generate(batch, sampling_params)

    examples = []
    for flat_index, request_output in enumerate(outputs):
        question_id, decompose_id = divmod(flat_index, samples_per_question)
        # Build a fresh record instead of mutating the cached context.
        examples.append({
            **contexts[question_id],
            "question_id": question_id,
            "decompose_id": decompose_id,
            "decomposed": _parse_decomposition(request_output.outputs[0].text),
        })

    with open(f"{output_dir}/generate.jsonl", "w", encoding="utf-8") as f:
        f.writelines(json.dumps(example) + "\n" for example in examples)
