import argparse
import os
from pathlib import Path

import jsonl
from transformers import AutoTokenizer, PreTrainedTokenizer
from vllm import LLM, SamplingParams


def main():
    """Generate a greedy response for every query in a JSONL file and save them.

    Command-line arguments:
        --model       Hugging Face model id or local path, used both for the
                      tokenizer (chat template) and for the vLLM engine.
        --data        Input JSONL file; every record must have a ``"query"`` key.
        --output      Output JSONL path (required). An existing file at this
                      path is first moved to ``<output>.bak``.
        --max-tokens  Maximum number of generated tokens per response
                      (default 128).

    Each input record is written back unchanged except for a new
    ``"response"`` key holding the generated text.
    """
    parser = argparse.ArgumentParser(
        description="Generate greedy responses for queries in a JSONL file with vLLM."
    )
    parser.add_argument("--model", type=str, default="open-unlearning/tofu_Llama-3.1-8B-Instruct_full")
    parser.add_argument("--data", type=str, default="data/src/tofu_forget10.jsonl")
    # Required: without it, Path(None) below would crash with an opaque TypeError.
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=128,
        help="maximum number of new tokens per response",
    )
    args = parser.parse_args()

    model_name = args.model
    data_path = Path(args.data)
    output_path = Path(args.output)

    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_name)

    data = list(jsonl.load(data_path))
    # Wrap each query as a single user turn and render it to a prompt string,
    # with the generation prompt appended so the model answers as the assistant.
    # NOTE(review): the rendered template may already contain the BOS token; if
    # vLLM adds another when tokenizing the string prompt, the model sees a
    # duplicated BOS -- confirm and consider passing token ids instead.
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": item["query"]}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for item in data
    ]

    outputs = []
    if prompts:
        # Only load the (large) model when there is something to generate.
        model = LLM(model=model_name)
        # temperature=0 -> greedy decoding; n=1 -> one completion per prompt.
        generations = model.generate(
            prompts,
            sampling_params=SamplingParams(n=1, temperature=0, max_tokens=args.max_tokens),
        )
        # Pairing by position assumes generate() returns results in prompt order.
        for generation, item in zip(generations, data):
            item["response"] = generation.outputs[0].text
            outputs.append(item)
    else:
        print(f"No records found in {data_path}; nothing to generate")

    if output_path.exists():
        # Append ".bak" to the full name (out.jsonl -> out.jsonl.bak) so outputs
        # that differ only by extension do not collide on one backup file.
        backup_path = output_path.with_name(output_path.name + ".bak")
        print(f"Overwrite {output_path}, move the original to {backup_path}")
        # Path.replace overwrites a stale backup on every platform, whereas
        # os.rename raises FileExistsError on Windows if the target exists.
        output_path.replace(backup_path)
    output_path.parent.mkdir(exist_ok=True, parents=True)

    print(f"Output path: {output_path}")
    jsonl.dump(outputs, output_path)
    print(f"Saved {len(outputs)} generated responses to {output_path}")


if __name__ == "__main__":
    # Run the generation pipeline only when executed as a script, not on import.
    main()
