
import importlib
import importlib.util
import os

import fire
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from termcolor import colored
from tqdm import tqdm
from transformers import set_seed

# Download `tinystyler.py` from the `tinystyler/tinystyler` Hub repository and
# execute it as a module named "tinystyler", so the helper functions it
# defines can be used below. NOTE: this runs at import time and needs either
# network access or a local Hugging Face cache entry for the file.
_tinystyler_path = hf_hub_download(
    repo_id="tinystyler/tinystyler", filename="tinystyler.py"
)
_tinystyler_spec = importlib.util.spec_from_file_location(
    "tinystyler", _tinystyler_path
)
# `spec_from_file_location` may return None (and a spec may lack a loader);
# fail with a clear message instead of an AttributeError further down.
if _tinystyler_spec is None or _tinystyler_spec.loader is None:
    raise ImportError(
        "Cannot load the tinystyler module from {!r}".format(_tinystyler_path)
    )
tinystyler_module = importlib.util.module_from_spec(_tinystyler_spec)
_tinystyler_spec.loader.exec_module(tinystyler_module)

# Re-export the helpers under short module-level names.
get_tinystyler_model = tinystyler_module.get_tinystyler_model
get_target_style_embeddings = tinystyler_module.get_target_style_embeddings
run_tinystyler_batch = tinystyler_module.run_tinystyler_batch

def _build_output_path(
    dataset_path: str,
    outdir: str,
    top_p: float,
    temperature: float,
    debug: bool,
) -> str:
    """Return the output file path for one input file and sampling setup.

    The file name is the input's base name with ".jsonl" removed, followed by
    the sampling parameters and a ".jsonl" extension; ".debug" is appended
    after the extension for debug runs.
    """
    stem = os.path.basename(dataset_path)
    # Shard files (e.g. the default "<name>.jsonlshard-1-4") carry the shard
    # tag right after ".jsonl"; keep a "." there so name and tag stay separated.
    stem = stem.replace(".jsonl", "." if "shard" in stem else "")
    filename = f"{stem}_top-p={top_p}_temp={temperature}.jsonl"
    if debug:
        filename += ".debug"
    return os.path.join(outdir, filename)


def main(
    dataset_path: str = "./neurips/shards/MTD_reddit_12000_correct_Mistral-7B-Instruct-v0.3_N=5.jsonlshard-1-4",
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_new_tokens: int = 128 + 32,
    batch_size: int = 256,
    outdir: str = "./neurips/tinystyler",
    debug: bool = False,
) -> None:
    """Run TinyStyler over a JSON-lines dataset and save the generations.

    Reads ``dataset_path``, feeds the ``respond_reddit`` column to the model
    as input text and the ``transfer_reference_text`` column to
    ``get_target_style_embeddings`` as the style target, samples one output
    per row, stores it in a new ``transfer_pick`` column and writes the
    resulting frame as JSON lines into ``outdir``.

    Args:
        dataset_path: Path to the input JSON-lines file.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        max_new_tokens: Generation length limit passed to ``model.generate``.
        batch_size: Number of rows processed per generation call (>= 1).
        outdir: Directory for the output file; created if missing.
        debug: If true, only the first 10 rows are read and ".debug" is
            appended to the output file name.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before the (expensive) model load; a zero or negative step
    # would otherwise only fail after the model is already on the GPU.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device("cuda")
    tokenizer, model = get_tinystyler_model(device)

    df = pd.read_json(dataset_path, lines=True, nrows=10 if debug else None)
    # Presumably each `respond_reddit` text is rewritten towards the style of
    # the row's `transfer_reference_text` -- TODO confirm the expected shape
    # of the reference entries against `get_target_style_embeddings`.
    source = df["respond_reddit"].tolist()
    target = df["transfer_reference_text"].tolist()

    generations = []
    for start in tqdm(range(0, len(df), batch_size)):
        stop = start + batch_size
        inputs = tokenizer(
            source[start:stop],
            padding="longest",
            truncation=True,
            return_tensors="pt",
        ).to(device)
        style = get_target_style_embeddings(target[start:stop], device).to(device)

        output = model.generate(
            **inputs,
            style=style,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        generations.extend(tokenizer.batch_decode(output, skip_special_tokens=True))

    df["transfer_pick"] = generations

    os.makedirs(outdir, exist_ok=True)
    savename = _build_output_path(dataset_path, outdir, top_p, temperature, debug)
    print(colored(f"Saving to: {savename}", "yellow"))
    df.to_json(savename, lines=True, orient="records")

if __name__ == "__main__":
    # Fix the random seed first so sampled generations are repeatable.
    set_seed(seed=43)
    # Expose `main` as a command-line interface; its keyword arguments
    # become command-line flags.
    fire.Fire(component=main)