import os
import json
from vllm import LLM, SamplingParams
from tqdm import tqdm
from transformers import AutoTokenizer

# Expose four GPUs to the process; main() shards the model across them with
# tensor_parallel_size=4. NOTE(review): this runs after vllm/transformers are
# imported, so it only takes effect if CUDA has not been initialised by those
# imports — confirm, or move it above the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

DATA_DIR = "en_data"  # directory scanned for *.json input conversations
OUTPUT_DIR = "results_gpt_oss_20b_3"  # per-file result JSONs are written here
os.makedirs(OUTPUT_DIR, exist_ok=True)

MODEL_NAME = "openai/gpt-oss-20b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Stop generation at the "<|return|>" token. For a token missing from the
# vocabulary, convert_tokens_to_ids falls back to the unknown-token id (or None
# when the tokenizer has no unk token). Treat both as "not found" so we never
# accidentally stop on <unk>.
RETURN_ID = tokenizer.convert_tokens_to_ids("<|return|>")
if RETURN_ID is None or RETURN_ID == tokenizer.unk_token_id:
    STOP_IDS = None
else:
    STOP_IDS = [RETURN_ID]

# Shared sampling configuration for every generate() call.
sampling_params = SamplingParams(
    temperature=0.6,
    top_p=0.9,
    max_tokens=3072,
    stop_token_ids=STOP_IDS
)

def build_prompt_with_gold(dataset, current_idx):
    """Render the chat prompt for turn ``current_idx`` using gold history.

    Every earlier turn contributes its ``"input"`` as a user message and, when
    the turn has one, its reference ``"output"`` as the assistant reply. The
    current turn's input is appended last as the open user message.

    Args:
        dataset: Indexable sequence of turn dicts; ``"input"`` is required,
            ``"output"`` is optional.
        current_idx: Index of the turn whose prompt is being built.

    Returns:
        The prompt string produced by the tokenizer's chat template, with the
        generation prompt appended so the model continues as the assistant.
    """
    conversation = []
    for past_idx in range(current_idx):
        past_turn = dataset[past_idx]
        conversation.append({"role": "user", "content": past_turn["input"]})
        if "output" not in past_turn:
            # No gold reply for this turn: the next user message follows directly.
            continue
        conversation.append({"role": "assistant", "content": past_turn["output"]})

    conversation.append({"role": "user", "content": dataset[current_idx]["input"]})

    return tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )

def run_inference(llm, model_name, dataset, file_name):
    """Generate one model reply per turn of ``dataset`` and collect the results.

    Each turn's prompt is built from the gold (reference) conversation history
    via ``build_prompt_with_gold``, so no prompt depends on an earlier model
    output. That allows all prompts to be submitted to vLLM in a single
    batched ``generate`` call instead of one blocking call per turn.

    Args:
        llm: vLLM ``LLM`` engine used for generation.
        model_name: Model identifier; only used in the progress message.
        dataset: Sequence of turn dicts with an ``"input"`` key and an
            optional ``"output"`` (gold reply) key.
        file_name: Source file name; only used in the progress message.

    Returns:
        A list with one dict per turn, in dataset order, holding the 1-based
        turn number, the input, the reference output (or None), the stripped
        model output, and the exact prompt that was used.
    """
    print(f"Running inference with {model_name} on {file_name} ...")
    if not dataset:
        return []

    prompts = [
        build_prompt_with_gold(dataset, i)
        for i in tqdm(range(len(dataset)), desc="Building prompts")
    ]

    # One batched call; vLLM returns RequestOutputs in the same order as the
    # input prompts, so outputs line up with dataset indices.
    outputs = llm.generate(prompts, sampling_params)

    results = []
    for i, (prompt, request_output) in enumerate(zip(prompts, outputs)):
        turn = dataset[i]
        results.append({
            "turn": i + 1,
            "input": turn["input"],
            "reference_output": turn.get("output", None),
            "model_output": request_output.outputs[0].text.strip(),
            "prompt_used": prompt,
        })
    return results

def main():
    """Run gold-history inference for every ``*.json`` file in ``DATA_DIR``.

    Loads the model once. Input files are processed in sorted order, and each
    one produces ``<stem>_<model_tag>.json`` in ``OUTPUT_DIR``. Files whose
    result already exists are skipped, so an interrupted run can be resumed.
    Results are first written to a temporary file and then atomically renamed.
    A crash mid-write therefore never leaves behind a truncated result that
    the resume check would wrongly treat as done.
    """
    # os.listdir order is arbitrary; sort for a reproducible processing order.
    files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".json"))

    print(f"Loading model: {MODEL_NAME}")
    llm = LLM(
        model=MODEL_NAME,
        dtype="bfloat16",
        tensor_parallel_size=4,  # matches the four GPUs in CUDA_VISIBLE_DEVICES
    )

    model_tag = MODEL_NAME.split("/")[-1]

    for file in files:
        # Strip only the trailing extension; str.replace would also remove
        # ".json" appearing elsewhere in the name ("a.jsonl.json" -> "al").
        stem = os.path.splitext(file)[0]
        save_path = os.path.join(OUTPUT_DIR, f"{stem}_{model_tag}.json")

        if os.path.exists(save_path):
            print(f"Skip (already done): {save_path}")
            continue

        with open(os.path.join(DATA_DIR, file), "r", encoding="utf-8") as f:
            dataset = json.load(f)

        results = run_inference(llm, MODEL_NAME, dataset, file)

        # Write-then-rename so save_path only ever exists fully written.
        tmp_path = save_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, save_path)

        print(f"Results saved to {save_path}")

if __name__ == "__main__":
    # Kick off the evaluation only when this file is executed as a script.
    main()
