import os
import json
from vllm import LLM, SamplingParams
from tqdm import tqdm

# Restrict this process to GPUs 0-3 (four devices, matching the
# tensor_parallel_size=4 used when the model is loaded).
# NOTE(review): this is assigned after `from vllm import ...`; it only takes
# effect if that import has not already initialised CUDA. Consider moving it
# above the third-party imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

DATA_DIR = "en_data"  # input directory; its *.json files are processed
OUTPUT_DIR = "results_qwen3_32b"  # one result JSON per input file goes here
os.makedirs(OUTPUT_DIR, exist_ok=True)

MODEL_NAME = "Qwen/Qwen3-32B"

# Prompts are plain "User: ...\nAssistant: ..." transcripts, so without a
# stop string the model may keep generating past its answer and fabricate
# the next "User:" turn itself. Stop at the start of a new user turn; vLLM
# excludes the stop string from the returned text by default.
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=512,
    stop=["\nUser:"],
)

def build_prompt_with_gold(dataset, current_idx):
    """Build a plain-text multi-turn prompt for turn ``current_idx``.

    Every earlier turn is rendered with its gold (reference) answer, so the
    prompt never depends on the model's own previous outputs. The layout is
    a simple ``User:`` / ``Assistant:`` transcript (not the model's chat
    template) ending with an open ``Assistant:`` for the model to complete.

    Args:
        dataset: Sequence of turn dicts. Each needs an ``"input"`` key and
            may carry an ``"output"`` key holding the gold answer.
        current_idx: Index of the turn to build the prompt for.

    Returns:
        The prompt string.
    """
    parts = []
    for i in range(current_idx):
        turn = dataset[i]
        gold_out = turn.get("output")
        # Treat a null gold answer like a missing one, instead of leaking
        # the literal text "None" into the prompt.
        if gold_out is None:
            gold_out = ""
        parts.append(f"User: {turn['input']}\nAssistant: {gold_out}\n")
    parts.append(f"User: {dataset[current_idx]['input']}\nAssistant:")
    return "".join(parts)

def run_inference(llm, model_name, dataset, file_name):
    """Generate a model answer for every turn of ``dataset``.

    Each turn's prompt is built by ``build_prompt_with_gold`` from gold
    history only, so the prompts are independent of each other and are
    submitted to vLLM in a single batched ``generate`` call.

    Args:
        llm: A loaded ``vllm.LLM`` instance.
        model_name: Model name, used only in the progress message.
        dataset: Sequence of turn dicts with an ``"input"`` key and an
            optional ``"output"`` (reference answer) key.
        file_name: Source file name, used only in the progress message.

    Returns:
        One dict per turn with keys ``turn`` (1-based), ``input``,
        ``reference_output``, ``model_output`` and ``prompt_used``.
    """
    print(f"Running inference with {model_name} on {file_name} ...")
    if not dataset:
        return []

    # One request per turn would serialise generation; a single batched call
    # lets vLLM schedule all turns together.
    prompts = [build_prompt_with_gold(dataset, i) for i in range(len(dataset))]
    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

    results = []
    # vLLM returns outputs in the same order as the input prompts.
    for turn_no, (turn, prompt, output) in enumerate(
        zip(dataset, prompts, outputs), start=1
    ):
        results.append({
            "turn": turn_no,
            "input": turn["input"],
            "reference_output": turn.get("output", None),
            "model_output": output.outputs[0].text.strip(),
            "prompt_used": prompt,
        })
    return results

def _pending_jobs(model_tag):
    """Return ``(input_file, save_path)`` pairs whose result file does not exist yet."""
    # Sorted for a deterministic processing order; os.listdir order is arbitrary.
    files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".json"))
    jobs = []
    for file in files:
        # Strip only the trailing extension; str.replace would also delete
        # ".json" occurring elsewhere in the file name.
        stem = file[: -len(".json")]
        save_path = os.path.join(OUTPUT_DIR, f"{stem}_{model_tag}.json")
        if os.path.exists(save_path):
            print(f"Skip (already done): {save_path}")
            continue
        jobs.append((file, save_path))
    return jobs


def _write_json_atomic(path, obj):
    """Write ``obj`` as JSON to ``path`` via a temp file plus an atomic rename."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def main():
    """Run inference on every ``*.json`` file in ``DATA_DIR`` not yet processed.

    Results for ``<name>.json`` are written to
    ``OUTPUT_DIR/<name>_<model_tag>.json``. Files whose result already exists
    are skipped, so an interrupted run can be resumed.
    """
    model_tag = MODEL_NAME.split("/")[-1]

    jobs = _pending_jobs(model_tag)
    if not jobs:
        print("Nothing to do: all result files already exist.")
        return

    # Load the (large) model only once we know there is work for it.
    print(f"Loading model: {MODEL_NAME}")
    llm = LLM(
        model=MODEL_NAME,
        dtype="bfloat16",
        tensor_parallel_size=4,
    )

    for file, save_path in jobs:
        with open(os.path.join(DATA_DIR, file), "r", encoding="utf-8") as f:
            dataset = json.load(f)

        results = run_inference(llm, MODEL_NAME, dataset, file)

        # An existing result file counts as "done", so never leave a
        # truncated one behind if the process dies mid-write.
        _write_json_atomic(save_path, results)

        print(f"Results saved to {save_path}")


if __name__ == "__main__":
    main()
