"""TODO: 
    - Extend beyond Reddit
    - Watermarking (maybe not needed)
"""
import json
import os

import fire

from utils import DATA_PATH, clean_generation
from nicks_dpo.prompts import *

from termcolor import colored
from tqdm.auto import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

def load_data(
    data_path: str
) -> list[dict]:
    """Read a JSON Lines file and return one parsed record per line.

    Args:
        data_path: Path to a file holding one JSON document per line.

    Returns:
        The parsed records in file order. Lines that are empty or contain
        only whitespace are skipped instead of being handed to the parser.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    # JSON text is UTF-8 by specification; do not rely on the locale default.
    with open(data_path, "r", encoding="utf-8") as fin:
        for line in fin:
            # A trailing newline or stray blank line must not crash the load.
            if not line.strip():
                continue
            data.append(json.loads(line))
    return data

def build_respond_prompts(
    texts: list[str],
    dataset_name: str,
) -> list[str]:
    """Fill the dataset-specific response template for every input text.

    Each template is formatted positionally with two values: the text itself
    and the number of pieces obtained by splitting the text on single spaces.

    Args:
        texts: Source texts to build prompts for.
        dataset_name: One of "reddit", "amazon" or "blogs".

    Returns:
        One formatted prompt per entry of ``texts``, in the same order.

    Raises:
        KeyError: If ``texts`` is non-empty and ``dataset_name`` is not one
            of the supported keys.
    """
    templates = {
        "reddit": RESPOND_REDDIT_PROMPT,
        "amazon": RESPOND_AMAZON_PROMPT,
        "blogs": RESPOND_BLOG_PROMPT,
    }
    # The lookup stays inside the comprehension so that an empty ``texts``
    # never touches the table, exactly like the loop form.
    return [
        templates[dataset_name].format(passage, len(passage.split(" ")))
        for passage in texts
    ]

def _infer_dataset_name(data_name: str) -> str:
    """Return the dataset key ("reddit", "amazon" or "blogs") contained in ``data_name``."""
    # Checked in this order so a name containing several keys resolves to the first.
    for candidate in ("reddit", "amazon", "blogs"):
        if candidate in data_name:
            return candidate
    raise ValueError(
        f"Cannot infer dataset from data_name={data_name!r}; "
        "expected it to contain 'reddit', 'amazon' or 'blogs'"
    )


def main(
    data_name: str = "MTD_reddit_preference_10000_correct.jsonl",
    model_name: str = "mistralai/Mistral-7B-Instruct-v0.3",
    batch_size: int = 16,
    max_new_tokens: int = 128+32,
    temperature: float = 0.7,
    top_p: float = 0.9,
    num_generations: int = 2,
    debug: bool = False,
):
    """Generate model responses for every record of a dataset and save them as JSONL.

    Args:
        data_name: File name of the JSONL dataset inside ``DATA_PATH/mtd``. It
            must contain "reddit", "amazon" or "blogs" so the prompt template
            can be selected.
        model_name: Model identifier or local directory passed to vLLM.
        batch_size: Number of prompts sent to the model per ``generate`` call.
        max_new_tokens: Upper bound on generated tokens per sample.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        num_generations: Number of samples drawn per prompt.
        debug: If true, only the first 10 records are processed and the output
            file name gets a ".debug" suffix.

    Returns:
        0 after the output file has been written.

    Raises:
        ValueError: If no supported dataset key occurs in ``data_name``.
        RuntimeError: If the number of generation lists differs from the
            number of records.
    """
    # Must stay the first statement: at this point locals() holds only the arguments.
    arguments = locals()
    for key, value in arguments.items():
        print(colored(key, "green"), "=", colored(value, "yellow"))

    mtd_path = os.path.join(DATA_PATH, "mtd")
    data = load_data(os.path.join(mtd_path, data_name))
    if debug:
        data = data[:10]

    dataset_name = _infer_dataset_name(data_name)
    prompts = build_respond_prompts([d["content_text"] for d in data], dataset_name)
    if "qwen" in model_name.lower():
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): add_generation_prompt=False leaves the rendered prompt
        # without an assistant-turn header; confirm this is intended for generation.
        prompts = [tokenizer.apply_chat_template([{"role":"user", "content":p}], tokenize=False, add_generation_prompt=False) for p in prompts]
    # An empty dataset has no sample prompt to show.
    if prompts:
        print(prompts[0])

    savedir = "./nicks_dpo/generations_neurips"
    os.makedirs(savedir, exist_ok=True)
    tmpstr = data_name.replace(".jsonl", "").replace("_", "-")

    # A local directory is tagged with its own name plus the parent directory
    # name minus its first "-"-separated token, and marked "-preference" below.
    is_local_model = os.path.isdir(model_name)
    if not is_local_model:
        modelstr = os.path.basename(model_name)
    else:
        modelstr = os.path.basename(model_name) + "-" + "-".join(os.path.basename(os.path.dirname(model_name)).split("-")[1:])

    savename = os.path.join(
        savedir,
        f"{tmpstr}_{modelstr}_temperature={temperature}_top-p={top_p}_ng={num_generations}"
    )
    if is_local_model:
        savename += "-preference"
    savename += ".jsonl.rebuttal"
    if debug:
        savename += ".debug"
    print(colored(f"savename={savename}", "yellow"))

    model = LLM(
        model_name,
        max_model_len=4096,
        gpu_memory_utilization=0.75,
    )
    sampling_params = SamplingParams(
        n=num_generations,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    generations = []
    for i in tqdm(range(0, len(prompts), batch_size)):
        batch = prompts[i:i+batch_size]
        batch_transfer = model.generate(
            batch,
            sampling_params,
        )
        # One list of stripped sample texts per prompt.
        batch_transfer = [[o.text.strip() for o in out.outputs] for out in batch_transfer]
        generations.extend(batch_transfer)

    # Explicit check (not an assert) so a mismatch cannot be silently
    # truncated by zip() when Python runs with -O.
    if len(data) != len(generations):
        raise RuntimeError(
            f"Got {len(generations)} generation lists for {len(data)} records"
        )
    is_reddit = dataset_name == "reddit"
    with open(savename, "w", encoding="utf-8") as fout:
        for d, glst in zip(data, generations):
            # The key is "respond_reddit" for every dataset, not only reddit.
            d["respond_reddit"] = [clean_generation(g, is_reddit=is_reddit) for g in glst]
            fout.write(json.dumps(d) + "\n")

    return 0

# Script entry point: hand `main` to python-fire when the file is run directly.
if __name__ == "__main__":
    fire.Fire(main)
