"""
Refine fully machine-generated text (MGT) into human-like text.
"""

import os
import random; random.seed(43)

import fire
import pandas as pd
import torch
from vllm import LLM, SamplingParams

from embedding_utils import get_instance_embeddings, load_sbert_model
from utils import MODEL_PATH, build_style_transfer_prompts
from termcolor import colored
from tqdm import tqdm

def read_data(dataset_path, debug):
    """Load a JSON-lines dataset and normalise its column names.

    Args:
        dataset_path: Path to a JSON-lines file readable by ``pd.read_json``.
        debug: When truthy, only the first 10 rows are read.

    Returns:
        A DataFrame in which ``respond_reddit`` is renamed to ``generation``
        and ``paraphrase_respond_reddit`` to ``paraphrase_generation``
        (columns that are absent are simply left alone by ``rename``).
    """
    row_limit = None
    if debug:
        row_limit = 10

    column_aliases = {
        "respond_reddit": "generation",
        "paraphrase_respond_reddit": "paraphrase_generation",
    }
    frame = pd.read_json(dataset_path, lines=True, nrows=row_limit)
    return frame.rename(columns=column_aliases)

def choose_best_content(
    df: pd.DataFrame,
    top_k: int = 5,
) -> pd.DataFrame:
    """Pick, per row, the candidate transfers closest in meaning to the original.

    Every candidate in ``df["transfer_text"]`` is embedded with the SBERT model
    and compared (cosine similarity) against the embedding of the same row's
    ``df["generation"]``. The ``top_k`` most similar candidates are stored in a
    new ``transfer_pick`` column.

    Args:
        df: Frame with a ``generation`` column (one text per row) and a
            ``transfer_text`` column (a list of candidate texts per row).
        top_k: Maximum number of candidates to keep per row. Rows with fewer
            candidates keep all of them.

    Returns:
        The same ``df`` object, modified in place with a ``transfer_pick``
        column. Each entry is ordered from least to most similar among the
        kept candidates (ascending ``argsort`` order).

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        # A slice of ``[-0:]`` would silently keep *every* candidate.
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    model = load_sbert_model()
    if torch.cuda.is_available():
        model.cuda()

    orig_generations = df["generation"].tolist()
    candidates_per_row = df["transfer_text"].tolist()
    # Embed all candidates in one flat batch; rows are recovered below by offset.
    flat_candidates = [text for candidates in candidates_per_row for text in candidates]

    function_kwargs = {
        "model": model,
        "progress_bar": True,
    }

    # NOTE(review): assumes get_instance_embeddings returns one embedding per
    # input text, in input order, as a sliceable torch tensor -- confirm.
    orig_emb = get_instance_embeddings(orig_generations, function_kwargs=function_kwargs, model_name="sbert")
    transfer_emb = get_instance_embeddings(flat_candidates, function_kwargs=function_kwargs, model_name="sbert")

    cossim = torch.nn.CosineSimilarity(dim=-1)
    best_transfer_text = []
    start = 0  # running offset of the current row's candidates in the flat batch
    for row_idx, candidates in enumerate(candidates_per_row):
        end = start + len(candidates)
        sims = cossim(orig_emb[row_idx:row_idx + 1], transfer_emb[start:end])
        # argsort is ascending, so the tail holds the most similar candidates.
        best_indices = sims.flatten().cpu().numpy().argsort()[-top_k:]
        best_transfer_text.append([candidates[idx] for idx in best_indices])
        start = end

    df["transfer_pick"] = best_transfer_text
    return df

_ITER_TAG = ".iter="


def _split_iter_suffix(dataset_path):
    """Return ``(path_without_iter_suffix, iter_num)``; ``iter_num`` is None if the file name has no ``.iter=`` tag."""
    # Only the file name decides whether this is an iteration output; a
    # directory (or dataset name) that merely contains "iter" must not count.
    if _ITER_TAG not in os.path.basename(dataset_path):
        return dataset_path, None
    # The basename is a suffix of the path, so the last occurrence of the tag
    # in the full path is the one inside the file name.
    tag_pos = dataset_path.rindex(_ITER_TAG)
    return dataset_path[:tag_pos], int(dataset_path[tag_pos + len(_ITER_TAG):])


def main(
    dataset_path: str = "./data/MTD_reddit_12000_correct_Mistral-7B-Instruct-v0.3_N=5.jsonl",
    model_name: str = "MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3",
    dirname: str = "",
    checkpoint_num: int = 1_629, # preference
    num_exemplars: int = None,
    temperature: float = 0.6,
    top_p: float = 1.0,
    num_generations: int = 10,
    batch_size: int = 50,
    debug: bool = False,
):
    """Run one style-transfer refinement iteration and save the result.

    The dataset is loaded, prompts are built from source paraphrases and the
    target reference texts, candidates are sampled with vLLM, the best
    content-preserving candidates are selected with ``choose_best_content``,
    and the frame is written to ``<dataset_path>[.ne=K][.debug].iter=<n+1>``.

    Args:
        dataset_path: JSON-lines input. A file name ending in ``.iter=N`` is
            treated as the output of iteration ``N`` and read as-is; any other
            file is read through ``read_data`` and counts as iteration 0.
        model_name: Model directory under ``MODEL_PATH``.
        dirname: Optional sub-directory inside the model's run directory.
        checkpoint_num: Checkpoint to load (``checkpoint-<n>_merged``).
        num_exemplars: If an int, keep only the first ``num_exemplars`` target
            texts/paraphrases per row; also recorded in the output name.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        num_generations: Number of samples requested per prompt.
        batch_size: Number of prompts sent to ``generate`` per call.
        debug: Passed to ``read_data`` (limits rows) and tags the output name
            with ``.debug``.

    Returns:
        0 on completion.

    Raises:
        ValueError: If the text after ``.iter=`` in the file name is not an
            integer.
    """
    base_path, iter_num = _split_iter_suffix(dataset_path)
    if iter_num is None:
        df = read_data(dataset_path, debug)
        iter_num = 0
    else:
        df = pd.read_json(dataset_path, lines=True)
    next_iter_num = iter_num + 1

    picked_best = "transfer_text" in df.columns and df["transfer_text"].apply(len).max() > 1
    if picked_best:
        print(colored("Found `transfer_text` column with more than one generation, choosing best content preserving generation", "yellow"))
        # Archive the previous iteration's columns under iteration-specific
        # names so this run can write fresh `transfer_text`/`transfer_pick`.
        df.rename(columns={
            "transfer_pick": f"transfer_pick_iter={iter_num}",
            "transfer_text": f"transfer_text_iter={iter_num}",
        }, inplace=True)
        column = f"transfer_pick_iter={iter_num}"
        # source "paraphrases" come from the best transfers from previous iteration
        print(colored(f"Using the `{column}` column as our source paraphrases", "yellow"))
        source_paraphrases = df[column].tolist()
    else:
        # otherwise they're just base Mistral paraphrases
        print(colored("Using the `paraphrase_generation` column as our source paraphrases", "yellow"))
        source_paraphrases = df["paraphrase_generation"].tolist()
    target_texts = df["transfer_reference_text"].tolist()
    target_paraphrases = df["paraphrase_transfer_reference_text"].tolist()

    print(colored("Creating prompts...", "yellow"))
    if isinstance(num_exemplars, int):
        target_texts = [t[:num_exemplars] for t in target_texts]
        target_paraphrases = [t[:num_exemplars] for t in target_paraphrases]

    prompts = build_style_transfer_prompts(
        source_paraphrases=source_paraphrases,
        target_texts=target_texts,
        target_paraphrases=target_paraphrases,
        progress_bar=True,
    )

    load_name = os.path.join(MODEL_PATH, model_name, "r=32_alpha=64_dropout=0.1_perc=1.0", dirname, f"checkpoint-{checkpoint_num}_merged")
    model = LLM(
        load_name,
        gpu_memory_utilization=0.75,
        max_model_len=10_000,
    )
    sampling_params = SamplingParams(
        n=num_generations,
        max_tokens=128+64,
        temperature=temperature,
        top_p=top_p,
        stop="\n#####\n",
    )

    transfer_text = []
    for batch_start in tqdm(range(0, len(prompts), batch_size)):
        batch_prompts = prompts[batch_start:batch_start + batch_size]
        batch_outputs = model.generate(
            batch_prompts,
            sampling_params,
        )
        # Drop duplicate samples per prompt; dict.fromkeys keeps first-seen
        # order, so the candidate order does not depend on string hashing.
        transfer_text.extend(
            list(dict.fromkeys(o.text.strip() for o in out.outputs))
            for out in batch_outputs
        )

    df["transfer_text"] = transfer_text
    df = choose_best_content(df)

    savename = base_path
    if num_exemplars is not None:
        savename += ".ne={}".format(num_exemplars)
    if debug and "debug" not in savename:
        savename += ".debug"
    savename += ".iter={}".format(next_iter_num)
    print(colored("Saving to: {}".format(savename), "yellow"))
    df.to_json(savename, lines=True, orient="records")

    return 0

# Script entry point: hand `main` to python-fire, which builds the
# command-line interface from it when the file is run directly.
if __name__ == "__main__":
    fire.Fire(main)