import os
import json
from vllm import LLM, SamplingParams
from tqdm import tqdm
from transformers import AutoTokenizer

# Expose only GPUs 0-3 to this process (main() uses tensor_parallel_size=4).
# NOTE(review): this is set after `vllm` has already been imported; confirm
# that nothing initialises CUDA at import time, otherwise it has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# DATA_DIR is scanned for input *.json files; OUTPUT_DIR receives one result
# file per input and is created at import time if it does not exist.
DATA_DIR = "en_data"
OUTPUT_DIR = "results_gpt_oss_120b_1"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model identifier passed to both AutoTokenizer.from_pretrained and vllm.LLM.
MODEL_NAME = "openai/gpt-oss-120b"

# The tokenizer is loaded at import time; it is used to render chat prompts
# and to look up the id of the "<|return|>" token used as a stop token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
RETURN_ID = tokenizer.convert_tokens_to_ids("<|return|>")
# Only pass a stop-token list when the lookup produced an id.
STOP_IDS = [RETURN_ID] if RETURN_ID is not None else None

# Sampling configuration shared by every llm.generate() call in this script.
sampling_params = SamplingParams(
    temperature=0.6,
    top_p=0.9,
    max_tokens=3072,
    stop_token_ids=STOP_IDS
)

# Number of prompts submitted per llm.generate() call in run_inference().
BATCH_SIZE = 32

def build_prompt_with_gold(dataset, current_idx):
    """Render the chat prompt for turn ``current_idx`` with gold history.

    Every turn before ``current_idx`` contributes its ``"input"`` as a user
    message and, when the turn has an ``"output"`` key, that reference output
    as the assistant message. The model's own earlier generations are never
    used as history. The turn at ``current_idx`` is appended as the final user
    message.

    Args:
        dataset: Indexable sequence of dict-like turns, each with an
            ``"input"`` key and optionally an ``"output"`` key.
        current_idx: Zero-based index of the turn to build a prompt for.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for the assembled
        messages with ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    conversation = []
    for past_idx in range(current_idx):
        past_turn = dataset[past_idx]
        conversation.append({"role": "user", "content": past_turn["input"]})
        # Turns without a reference answer contribute only their user message.
        if "output" in past_turn:
            conversation.append(
                {"role": "assistant", "content": past_turn["output"]}
            )

    # The turn being evaluated is always the last user message.
    conversation.append(
        {"role": "user", "content": dataset[current_idx]["input"]}
    )

    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

def run_inference(llm, model_name, dataset, file_name):
    """Generate a model answer for every turn of ``dataset`` in batches.

    Prompts are built with ``build_prompt_with_gold`` and submitted to
    ``llm.generate`` in groups of ``BATCH_SIZE`` using the module-level
    ``sampling_params``.

    Args:
        llm: Engine object exposing ``generate(prompts, sampling_params)``.
        model_name: Name printed in the progress message only.
        dataset: Indexable sequence of dict-like turns with an ``"input"``
            key and optionally an ``"output"`` key.
        file_name: Name printed in the progress message only.

    Returns:
        A list with one dict per turn, in dataset order, holding the 1-based
        ``turn`` number, the ``input``, the ``reference_output`` (``None`` when
        the turn has no ``"output"``), the stripped ``model_output`` and the
        ``prompt_used``.
    """
    print(f"Running inference with {model_name} on {file_name} ...")
    total_turns = len(dataset)
    results = [None] * total_turns

    for batch_start in tqdm(range(0, total_turns, BATCH_SIZE)):
        batch_stop = min(batch_start + BATCH_SIZE, total_turns)
        prompts = [
            build_prompt_with_gold(dataset, turn_idx)
            for turn_idx in range(batch_start, batch_stop)
        ]

        generations = llm.generate(prompts, sampling_params)

        # Assumes generate() returns one result per prompt, in prompt order.
        for turn_idx, generation in enumerate(generations, start=batch_start):
            answer = generation.outputs[0].text.strip()
            turn = dataset[turn_idx]
            results[turn_idx] = {
                "turn": turn_idx + 1,
                "input": turn["input"],
                "reference_output": turn.get("output"),
                "model_output": answer,
                "prompt_used": prompts[turn_idx - batch_start],
            }

    return results

def main():
    """Run inference on every ``*.json`` dataset in ``DATA_DIR``.

    For each input file ``<stem>.json`` the results are written to
    ``OUTPUT_DIR/<stem>_<model_tag>.json``, where ``model_tag`` is the last
    ``/``-separated component of ``MODEL_NAME``. Files whose result already
    exists are skipped, so an interrupted run can be resumed by re-running the
    script.

    The model is only loaded when at least one file still needs processing.
    Results are written to a temporary file and then moved into place, so a
    crash during writing cannot leave a truncated result file that a later
    run would treat as already done.
    """
    model_tag = MODEL_NAME.split("/")[-1]

    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".json"))

    # Decide what is left to do before paying the cost of loading the model.
    pending = []
    for file in files:
        # splitext removes only the trailing extension, unlike str.replace,
        # which would also strip ".json" occurring inside the file name.
        stem = os.path.splitext(file)[0]
        save_path = os.path.join(OUTPUT_DIR, f"{stem}_{model_tag}.json")

        if os.path.exists(save_path):
            print(f"Skip (already done): {save_path}")
            continue
        pending.append((file, save_path))

    if not pending:
        print("Nothing to do: all result files already exist.")
        return

    print(f"Loading model: {MODEL_NAME}")
    llm = LLM(
        model=MODEL_NAME,
        dtype="bfloat16",
        tensor_parallel_size=4,
        enable_prefix_caching=True,
        enable_chunked_prefill=True
    )

    for file, save_path in pending:
        with open(os.path.join(DATA_DIR, file), "r", encoding="utf-8") as f:
            dataset = json.load(f)

        results = run_inference(llm, MODEL_NAME, dataset, file)

        # Write to a sibling temp file first, then rename into place. The
        # final path therefore only ever holds a completely written result,
        # which keeps the "already done" check above trustworthy.
        tmp_path = f"{save_path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, save_path)

        print(f"Results saved to {save_path}")

# Run the full inference pipeline only when executed as a script.
if __name__ == "__main__":
    main()
