
import json
import os

import fire

from utils import DATA_PATH, clean_generation

from transformers import AutoProcessor
from termcolor import colored
from tqdm.auto import tqdm
from vllm import LLM, SamplingParams

def load_data(
    data_path: str
) -> list[dict]:
    """Read a JSON Lines file into a list of parsed records.

    Every non-blank line of ``data_path`` is decoded with ``json.loads``.
    Blank lines (for example a trailing empty line at end of file) are
    skipped rather than raising ``json.JSONDecodeError``.

    Args:
        data_path: Path to a ``.jsonl`` file, read as UTF-8.

    Returns:
        One parsed JSON value per non-blank line, in file order. Callers in
        this script index them as dicts (e.g. ``record["content_text"]``).
    """
    with open(data_path, "r", encoding="utf-8") as fin:
        return [json.loads(line) for line in fin if line.strip()]

_RESPOND_REDDIT_PROMPT = """Write a response to this Reddit comment: {}

Keep the response around {} words.

Do not include the original comment in your response.

Only output the comment, do not include any other details.

Response:
"""

_RESPOND_AMAZON_PROMPT = """Here's an Amazon review: {}

Please write another review, of about {} words, but about something different.

Do not include the original review in your response.

Only output the review, do not include any other details.

Response:
"""

_RESPOND_BLOG_PROMPT = """Here's a snippet of a Blog post: {}

Please write another snippet, of about {} words, but about something different.

Do not include the original snippet in your response.

Only output the snippet, do not include any other details.

Response:
"""

# Maps the ``dataset_name`` CLI value to its prompt template. Each template
# takes two positional fields: the source text and its word count.
_PROMPT_TEMPLATES = {
    "reddit": _RESPOND_REDDIT_PROMPT,
    "amazon": _RESPOND_AMAZON_PROMPT,
    "blogs": _RESPOND_BLOG_PROMPT,
}


def build_prompts(
    texts: list[str],
    dataset_name: str,
) -> list[str]:
    """Build one "write a response" prompt per input text.

    The prompt embeds the original text and asks for an output of roughly
    the same length, measured as the number of whitespace-separated words
    in the text (newlines and runs of spaces are treated as separators).

    Args:
        texts: Source texts, one prompt is produced for each.
        dataset_name: One of ``"reddit"``, ``"amazon"`` or ``"blogs"``;
            selects the prompt template.

    Returns:
        The formatted prompts, in the same order as ``texts``.

    Raises:
        KeyError: if ``dataset_name`` has no template.
    """
    try:
        template = _PROMPT_TEMPLATES[dataset_name]
    except KeyError:
        raise KeyError(
            f"Unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(_PROMPT_TEMPLATES)}"
        ) from None

    return [template.format(text, len(text.split())) for text in texts]

def main(
    data_name: str = "MTD_reddit.jsonl",
    model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
    dataset_name: str = "reddit",
    batch_size: int = 64,
    max_new_tokens: int = 128 + 32,
    temperature: float = 0.7,
    top_p: float = 0.9,
    num_generations: int = 1,
    debug: bool = False,
):
    """Generate responses for an MTD JSONL file with vLLM and save them.

    Reads ``<DATA_PATH>/mtd/<data_name>`` and builds one prompt per record
    from its ``content_text`` field. It samples ``num_generations``
    completions per prompt and cleans each with ``clean_generation``. Each
    record is written back out with its completions stored under the
    ``"respond_reddit"`` key. The output goes to
    ``<DATA_PATH>/mtd/generations/<stem>_<model>_temperature=..._top-p=..._ng=...-vllm.jsonl``,
    with ``.debug`` appended in debug mode.

    Args:
        data_name: JSONL file name inside ``<DATA_PATH>/mtd``.
        model_name: Hugging Face model id; must be one of the supported ids.
        dataset_name: Prompt template to use (``reddit``/``amazon``/``blogs``).
        batch_size: Number of prompts passed to each ``LLM.generate`` call.
        max_new_tokens: Maximum tokens sampled per completion.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        num_generations: Completions sampled per prompt.
        debug: If true, only process the first 10 records and write to a
            ``.debug`` file.

    Returns:
        0 on success.

    Raises:
        ValueError: if ``model_name`` is unsupported or ``batch_size`` < 1.
    """
    # Snapshot the arguments before any other local exists so that only the
    # CLI parameters are echoed below.
    arguments = dict(locals())

    supported_models = (
        "mistralai/Mistral-7B-Instruct-v0.3",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "google/gemma-3-4b-pt",
    )
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if model_name not in supported_models:
        raise ValueError(
            f"Invalid model name {model_name!r}; expected one of {supported_models}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    for key, value in arguments.items():
        print(colored(key, "green"), "=", colored(value, "yellow"))

    mtd_path = os.path.join(DATA_PATH, "mtd")
    data = load_data(os.path.join(mtd_path, data_name))
    if debug:
        data = data[:10]
    prompts = build_prompts([record["content_text"] for record in data], dataset_name)

    # NOTE(review): prompts are passed to LLM.generate as raw text, so no chat
    # template is applied for the -Instruct models (vLLM's LLM.chat does
    # that) — confirm this is intended.
    model = LLM(model_name)

    generations_dir = os.path.join(mtd_path, "generations")
    os.makedirs(generations_dir, exist_ok=True)
    stem = data_name.replace(".jsonl", "").replace("_", "-")
    filename = (
        f"{stem}_{os.path.basename(model_name)}"
        f"_temperature={temperature}_top-p={top_p}_ng={num_generations}"
        "-vllm.jsonl"
    )
    if debug:
        filename += ".debug"
    savename = os.path.join(generations_dir, filename)
    print(colored(f"savename={savename}", "yellow"))

    sampling_params = SamplingParams(
        n=num_generations,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    generations = []
    for start in tqdm(range(0, len(prompts), batch_size)):
        outputs = model.generate(
            prompts[start:start + batch_size],
            sampling_params,
        )
        # One request output per prompt; keep the stripped text of each of
        # its `n` sampled completions.
        generations.extend(
            [completion.text.strip() for completion in out.outputs]
            for out in outputs
        )

    # Guards the zip below against silently dropping records.
    assert len(data) == len(generations)
    with open(savename, "w", encoding="utf-8") as fout:
        for record, completions in zip(data, generations):
            # NOTE(review): the output key and `is_reddit=True` are hard-coded
            # even when dataset_name is "amazon" or "blogs" — confirm whether
            # clean_generation should receive `dataset_name == "reddit"`.
            record["respond_reddit"] = [
                clean_generation(g, is_reddit=True) for g in completions
            ]
            fout.write(json.dumps(record) + "\n")

    return 0

if __name__ == "__main__":
    # Turn main()'s keyword parameters into command-line flags via Python Fire.
    fire.Fire(component=main)