import os
import json
import copy
import argparse
from tqdm import trange
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

parser = argparse.ArgumentParser(
    description="Sample sub-question / sub-claim decompositions for training examples with vLLM."
)
parser.add_argument("--tokenizer", type=str, default="meta-llama/Llama-3.1-8B-Instruct",
                    help="tokenizer used to render the chat template for each prompt")
parser.add_argument("--model_path", type=str, default="",
                    help="model loaded by vLLM for generation (required)")
parser.add_argument("--expname", type=str, default="",
                    help="experiment name, used in the output directory name")
parser.add_argument("--temperature", type=float, default=0.7, help="sampling temperature")
parser.add_argument("--top_p", type=float, default=0.99, help="nucleus-sampling top-p")
parser.add_argument("--tensor_parallel_size", type=int, default=1,
                    help="tensor-parallel degree passed to vLLM")
parser.add_argument("--dataset", type=str, default="hotpotqa",
                    help="dataset directory under data/; 'hover' selects the claim prompt")
parser.add_argument("--round", type=int, default=2,
                    help="round index, used in the output directory name")
parser.add_argument("--start", type=int, default=0,
                    help="first example index to process (inclusive)")
parser.add_argument("--end", type=int, default=10000,
                    help="last example index to process (exclusive; clamped to dataset size)")
args = parser.parse_args()

# Fail fast: an empty model path would otherwise only surface as an obscure
# error from the LLM constructor, after the tokenizer has already been loaded.
if not args.model_path:
    parser.error("--model_path must be set to the model to load with vLLM")

# In this script the HF tokenizer is only used to render the chat template
# for each prompt; generation itself goes through the vLLM engine below.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

# Decoding settings shared by every generate() call.
sampling_params = SamplingParams(
    max_tokens=1024,
    repetition_penalty=1.05,
    temperature=args.temperature,
    top_p=args.top_p,
)

llm = LLM(
    model=args.model_path,
    trust_remote_code=True,
    gpu_memory_utilization=0.86,
    tensor_parallel_size=args.tensor_parallel_size,
)

# Prompt templates. Each contains a literal "{question}" placeholder that is
# substituted when the dataset is loaded.
# NOTE(review): the response parser only keeps lines of the form
# "### Q<label>: <text>", but these prompts do not ask for a "Q<n>:" label --
# confirm the model was trained to emit that exact format.
PROMPT_QA = (
    "Please break down the question '{question}' into multiple specific "
    "sub-questions that address individual components. "
    "Use ### at the beginning of each sub-question."
)
PROMPT_CLAIM = (
    'Please break down the claim "{question}" into smaller sub-claims '
    "for easier verification. "
    "Use ### at the beginning of each sub-claim."
)

# Load dataset: data/<dataset>/train.jsonl, one JSON object per line.
# "hover" examples are claims to verify; every other dataset is read as QA
# with "question_text" / "answers_objects" fields.
is_claim_dataset = args.dataset == "hover"
data_path = f"data/{args.dataset}/train.jsonl"
prompts, contexts = [], []

with open(data_path, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank/trailing lines, which json.loads rejects
        example = json.loads(line)
        if is_claim_dataset:
            question = example["claim"]
            # Any label other than SUPPORT/SUPPORTED is mapped to "No".
            answer = ["Yes"] if example["label"] in ["SUPPORT", "SUPPORTED"] else ["No"]
            template = PROMPT_CLAIM
        else:
            question = example["question_text"]
            # Flatten every answer span across all answer objects.
            answer = [span for ans in example["answers_objects"] for span in ans["spans"]]
            template = PROMPT_QA

        # str.replace (not str.format) so literal braces in the question
        # text cannot break template substitution.
        prompt_text = template.replace("{question}", question)

        prompts.append([{"role": "user", "content": prompt_text.strip()}])
        contexts.append({"question": question, "answer": answer})

# All corpus shards for this run go into one directory keyed by dataset,
# experiment name and round; it is created if missing.
run_tag = f"{args.expname}_r{args.round}"
output_dir = f"processed_data/{args.dataset}/prompts_decompose_train_{run_tag}"
os.makedirs(output_dir, exist_ok=True)

def _parse_decomposition(gen_text):
    """Extract labelled sub-questions from one generated decomposition.

    Only lines that, after stripping, start with "### Q" are considered, and
    each must look like "### Q<label>: <text>". Returns a list of
    {"label", "text", "needs_context"} dicts. If a candidate line has no ":"
    separator, prints the raw text and returns the string "Error". The record
    is still written with that sentinel, as before.
    """
    decomposed = []
    for raw_line in gen_text.split("\n"):
        line = raw_line.strip()
        if not line.startswith("### Q"):
            continue
        head, sep, body = line.partition(":")
        if not sep:
            print(f"Error parsing: {gen_text}")
            return "Error"
        label = "Q" + head.split("Q")[-1].strip()
        decomposed.append({"label": label, "text": body.strip(), "needs_context": True})
    return decomposed


def _write_jsonl(path, rows):
    """Write each dict in *rows* as one JSON line to *path*, overwriting it."""
    with open(path, "w", encoding="utf-8") as out:
        for row in rows:
            out.write(json.dumps(row) + "\n")


# Index range actually processed: [args.start, stop), with --end clamped to
# the dataset size.
stop = min(args.end, len(prompts))
examples = []
for i in trange(args.start, stop):
    prompt_str = tokenizer.apply_chat_template(prompts[i], tokenize=False, add_generation_prompt=True)

    # Draw up to 8 samples, stopping early once 4 distinct outputs are seen.
    # The dict acts as an insertion-ordered set, so decompose_id follows the
    # order in which each distinct output first appeared.
    # NOTE(review): one generate() call per sample; batching the samples into
    # a single call would likely be faster.
    generated_texts = {}
    for _ in range(8):
        output = llm.generate([prompt_str], sampling_params)[0].outputs[0].text.strip()
        generated_texts[output] = 1
        if len(generated_texts) > 3:
            break

    # Keep only questions that produced at least 3 distinct decompositions.
    # This is not a `continue`, so the checkpoint below still runs on schedule.
    if len(generated_texts) > 2:
        for j, gen_text in enumerate(generated_texts):
            ctx = copy.deepcopy(contexts[i])
            ctx.update({"question_id": i, "decompose_id": j, "decomposed": _parse_decomposition(gen_text)})
            examples.append(ctx)

    # Periodic checkpoint on every multiple of 1000. corpus_{i}.txt holds the
    # records accumulated since the previous flush, up to and including index
    # i. Skipping i == args.start avoids emitting a near-empty shard whose
    # name would collide with the previous shard's final file.
    if i % 1000 == 0 and i > args.start:
        _write_jsonl(f"{output_dir}/corpus_{i}.txt", examples)
        examples = []

# Flush the remainder. Naming it after `stop` (the exclusive end) guarantees it
# cannot overwrite one of this run's checkpoints, whose indices are all < stop.
# The previous name, 1000 * ((end - start) // 1000), could collide with them.
if examples:
    _write_jsonl(f"{output_dir}/corpus_{stop}.txt", examples)
