import os
import json
from vllm import LLM, SamplingParams
from tqdm import tqdm
from transformers import AutoTokenizer

# Restrict the process to GPUs 0 and 1; main() loads the model with
# tensor_parallel_size=2, which needs two visible devices.
# NOTE(review): this is assigned after `vllm` has been imported above; it only
# takes effect if nothing has initialised CUDA yet — consider setting it before
# the imports (or in the launching shell) to be safe.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Input directory scanned for *.json datasets, and output directory for results.
DATA_DIR = "en_data1"
OUTPUT_DIR = "results_gpt_oss_20b"
# Created at import time so later writes into OUTPUT_DIR do not fail on a missing dir.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Hugging Face model id used for both the tokenizer below and the vLLM engine.
MODEL_NAME = "openai/gpt-oss-20b"

# Tokenizer used for chat-template rendering and prompt token counting.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Id of the "<|return|>" special token, passed to vLLM as an explicit stop token.
# NOTE(review): for tokenizers that define an unk token, convert_tokens_to_ids
# returns the unk id (not None) for an unknown token, so the None check below may
# not catch a missing token — confirm "<|return|>" is in this vocabulary.
RETURN_ID = tokenizer.convert_tokens_to_ids("<|return|>")
# Fall back to no extra stop ids if the token could not be resolved.
STOP_IDS = [RETURN_ID] if RETURN_ID is not None else None

# Decoding settings shared by every generate() call. max_tokens also determines
# how much of the context window is reserved for the completion when prompts are
# trimmed (prompt budget = max_model_len - max_tokens).
sampling_params = SamplingParams(
    temperature=0.6,
    top_p=0.9,
    max_tokens=3072,
    stop_token_ids=STOP_IDS
)

# Number of prompts submitted to llm.generate() per call.
BATCH_SIZE = 8

def get_max_ctx(llm):
    """Return the maximum context length configured for a vLLM ``LLM``.

    The value is read from ``llm.llm_engine.model_config.max_model_len`` and is
    used by the caller as the total token budget (prompt + completion).
    """
    model_config = llm.llm_engine.model_config
    return model_config.max_model_len

def toks_len(messages):
    """Return the token count of ``messages`` rendered with the chat template.

    The count includes the generation prompt (``add_generation_prompt=True``),
    matching how the final prompt string is rendered before generation.
    """
    return len(
        tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
        )
    )

def trim_to_fit(messages, max_ctx, max_new, count_tokens=None):
    """Drop the oldest turns of ``messages`` until the prompt fits the context.

    The prompt budget is ``max_ctx - max_new`` so that ``max_new`` tokens stay
    available for generation. The oldest messages are removed first:

    * a leading ``user``/``assistant`` pair is dropped as a unit;
    * otherwise the single leading message is dropped (e.g. a ``user`` turn
      with no answer, or an orphaned ``assistant`` turn).

    The last message (the current query) is never removed, so the result can
    still exceed the budget if that message alone is too long.

    Args:
        messages: Chat messages (dicts with a ``"role"`` key); trimmed in place.
        max_ctx: Maximum context length of the model, in tokens.
        max_new: Tokens reserved for the completion.
        count_tokens: Optional callable mapping a message list to its token
            count. Defaults to ``toks_len``.

    Returns:
        The same ``messages`` list object, after trimming.
    """
    if count_tokens is None:
        count_tokens = toks_len
    budget = max_ctx - max_new

    while len(messages) > 1 and count_tokens(messages) > budget:
        if (
            len(messages) > 2
            and messages[0]["role"] == "user"
            and messages[1]["role"] == "assistant"
        ):
            del messages[:2]
        else:
            # Always remove something so the loop makes progress. Without this
            # branch a history such as [user, user, assistant, ...] (a dataset
            # item lacking "output") never shrinks and the loop spins forever.
            del messages[0]

    return messages

def build_prompt_with_gold(dataset, current_idx, max_ctx, max_new):
    """Render the chat prompt for ``dataset[current_idx]`` with gold history.

    Each earlier item contributes its ``"input"`` as a user turn, followed by
    its reference ``"output"`` as an assistant turn when that key is present.
    The history therefore holds the gold answers, not the model's own. The
    current item's ``"input"`` is appended as the final user turn. The
    conversation is trimmed with ``trim_to_fit`` so ``max_new`` tokens remain
    free within ``max_ctx``. It is then rendered to a string with the
    generation prompt appended.
    """
    conversation = []
    for position in range(current_idx):
        item = dataset[position]
        conversation.append({"role": "user", "content": item["input"]})
        if "output" in item:
            conversation.append({"role": "assistant", "content": item["output"]})
    conversation.append({"role": "user", "content": dataset[current_idx]["input"]})

    fitted = trim_to_fit(conversation, max_ctx, max_new)
    return tokenizer.apply_chat_template(
        fitted, tokenize=False, add_generation_prompt=True
    )

def run_inference(llm, model_name, dataset, file_name, max_ctx):
    """Generate one completion per ``dataset`` item, ``BATCH_SIZE`` at a time.

    Prompts come from ``build_prompt_with_gold``, which uses the reference
    answers of preceding items as history. They are generated with the
    module-level ``sampling_params``.

    Returns:
        A list aligned with ``dataset``. Entry ``i`` is a dict with these keys:

        * ``turn``: 1-based index.
        * ``input``: the item's input.
        * ``reference_output``: the item's output, or ``None`` if absent.
        * ``model_output``: stripped text of the first completion.
        * ``prompt_used``: the rendered prompt string.
    """
    print(f"Running inference with {model_name} on {file_name} ...")
    total = len(dataset)
    results = [None] * total
    reserved = sampling_params.max_tokens

    for batch_start in tqdm(range(0, total, BATCH_SIZE)):
        batch_stop = min(batch_start + BATCH_SIZE, total)
        prompts = [
            build_prompt_with_gold(dataset, i, max_ctx, reserved)
            for i in range(batch_start, batch_stop)
        ]

        generations = llm.generate(prompts, sampling_params)

        for offset, generation in enumerate(generations):
            text = generation.outputs[0].text.strip()
            row = batch_start + offset
            item = dataset[row]
            results[row] = {
                "turn": row + 1,
                "input": item["input"],
                "reference_output": item.get("output", None),
                "model_output": text,
                "prompt_used": prompts[offset],
            }
    return results

def main():
    """Run inference over every ``*.json`` file in ``DATA_DIR``.

    The model is loaded once with vLLM (bfloat16, tensor parallel size 2). For
    each dataset file, results are written to ``<name>_<model_tag>.json`` in
    ``OUTPUT_DIR``. Files whose output already exists are skipped, so the script
    can be re-run to resume after an interruption.
    """
    # os.listdir order is arbitrary; sort for a reproducible processing order.
    files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".json"))

    print(f"Loading model: {MODEL_NAME}")
    llm = LLM(
        model=MODEL_NAME,
        dtype="bfloat16",
        tensor_parallel_size=2,
        enable_prefix_caching=True,
        enable_chunked_prefill=True
    )

    max_ctx = get_max_ctx(llm)
    print(f"[budget] prompt_budget = {max_ctx - sampling_params.max_tokens} (= {max_ctx} - {sampling_params.max_tokens})")

    model_tag = MODEL_NAME.split("/")[-1]

    for file in files:

        save_path = os.path.join(
            OUTPUT_DIR,
            f"{file.replace('.json', '')}_{model_tag}.json"
        )

        # The mere existence of the output file marks a dataset as done, which
        # is why the write below must be atomic.
        if os.path.exists(save_path):
            print(f"Skip (already done): {save_path}")
            continue

        with open(os.path.join(DATA_DIR, file), "r", encoding="utf-8") as f:
            dataset = json.load(f)

        results = run_inference(llm, MODEL_NAME, dataset, file, max_ctx)

        # Dump to a temporary file and rename it into place. Writing straight to
        # save_path risks a crash mid-dump leaving a truncated JSON there, which
        # later runs would wrongly skip as "already done".
        tmp_path = save_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, save_path)

        print(f"Results saved to {save_path}")

if __name__ == "__main__":
    main()
