
import json
import os

import fire
import pandas as pd
import torch
from termcolor import colored
from tqdm import tqdm

from embedding_utils import (
    load_cisr_model,
    load_sd_model,
    load_luar_model_and_tokenizer,
    load_sbert_model,
    get_instance_embeddings,
    get_author_embeddings,
)

def flatten(lst):
    """Concatenate the items of every sub-iterable in ``lst`` into one flat list.

    Only one level of nesting is removed; the relative order of items is kept.
    """
    flat = []
    for sublist in lst:
        flat.extend(sublist)
    return flat

def _embed_authors(author_texts, function_kwargs, model_name):
    """Embed every entry of ``author_texts`` and concatenate the results along dim 0."""
    return torch.cat([
        # A fresh dict per call, in case the embedding helper mutates its kwargs.
        get_author_embeddings(text, dict(function_kwargs), model_name)
        for text in tqdm(author_texts)
    ], dim=0)

def _mean_author_similarity(cossim, reference_texts, transfer_texts, function_kwargs, model_name):
    """Return the mean cosine similarity between reference and transfer author embeddings."""
    reference_embeddings = _embed_authors(reference_texts, function_kwargs, model_name)
    transfer_embeddings = _embed_authors(transfer_texts, function_kwargs, model_name)
    return cossim(reference_embeddings, transfer_embeddings).mean().item()

def calculate_metrics(
    df: pd.DataFrame,
    evaluation_key,
    luar,
    luar_tok,
    sbert,
    cisr,
    sd,
    transfer_key = "transfer_pick",
):
    """Compute, print and return the mean embedding similarities for ``df``.

    Four scores are produced:

    * ``sbert_sim``: cosine similarity between the SBERT instance embeddings
      of the flattened ``content_text`` column and the flattened
      ``transfer_key`` column.
    * ``luar_sim``, ``cisr_sim``, ``sd_sim``: cosine similarity between the
      per-row author embeddings of the ``evaluation_key`` column and those of
      the ``transfer_key`` column, for the LUAR, CISR and SD models.

    Args:
        df: Frame holding the ``content_text``, ``evaluation_key`` and
            ``transfer_key`` columns. ``content_text`` and ``transfer_key``
            cells are flattened, so each is presumably a list of texts
            (TODO confirm against the data producer).
        evaluation_key: Column with the reference texts.
        luar: LUAR model, passed to ``get_author_embeddings``.
        luar_tok: Tokenizer belonging to ``luar``.
        sbert: SBERT model, passed to ``get_instance_embeddings``.
        cisr: CISR model, passed to ``get_author_embeddings``.
        sd: SD model, passed to ``get_author_embeddings``.
        transfer_key: Column with the texts being evaluated.

    Returns:
        Dict with the float entries ``luar_sim``, ``sbert_sim``, ``cisr_sim``
        and ``sd_sim`` (in that order).

    Raises:
        KeyError: If a required column is absent from ``df``.
        ValueError: If ``df`` has no rows.
    """
    # Validate before any embedding work so that a bad key does not surface
    # only after the SBERT pass has already been paid for.
    required_columns = ("content_text", evaluation_key, transfer_key)
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise KeyError(
            f"Missing column(s) {missing_columns}; available: {list(df.columns)}"
        )
    if len(df) == 0:
        # torch.cat rejects an empty list of tensors; fail with a clearer message.
        raise ValueError("Cannot calculate metrics on an empty DataFrame.")

    cossim = torch.nn.CosineSimilarity(dim=-1)
    reference_texts = df[evaluation_key].tolist()
    transfer_texts = df[transfer_key].tolist()

    # SBERT (Source Text, Transfer Text):
    sbert_content_embeddings = get_instance_embeddings(
        flatten(df["content_text"].tolist()),
        function_kwargs={"model": sbert},
        model_name="sbert",
    )
    sbert_transfer_embeddings = get_instance_embeddings(
        flatten(transfer_texts),
        function_kwargs={"model": sbert},
        model_name="sbert",
    )
    sbert_sim = cossim(sbert_content_embeddings, sbert_transfer_embeddings).mean().item()

    # LUAR (Reference, Transfer):
    luar_sim = _mean_author_similarity(
        cossim, reference_texts, transfer_texts,
        {"luar": luar, "luar_tok": luar_tok}, "mud",
    )

    # CISR (Reference, Transfer):
    cisr_sim = _mean_author_similarity(
        cossim, reference_texts, transfer_texts, {"model": cisr}, "cisr",
    )

    # SD (Reference, Transfer):
    sd_sim = _mean_author_similarity(
        cossim, reference_texts, transfer_texts, {"model": sd}, "sd",
    )

    metrics = {
        "luar_sim": luar_sim,
        "sbert_sim": sbert_sim,
        "cisr_sim": cisr_sim,
        "sd_sim": sd_sim,
    }
    for k, v in metrics.items():
        print(colored(f"{k}: ", "green"), f"{v:.3f}")
    return metrics

def main(
    data_path: str = "./outputs/transfer_pick/stylemc_new_alternate_paraphrase_Mistral-7B-Instruct-v0.3_N=5_transfer_N=100_temp=0.7_top-p=0.9_model=MUD_paraphrase_Mistral-7B-Instruct-v0.3_N=5_small=False_8-3average_style=1.0_content=1.0_clean=False_key=evaluation_text.jsonl",
    evaluation_key: str = "reference_text",
    transfer_key: str = "transfer_pick",
    get_base: bool = False,
    get_gold: bool = False,
    outdir: str = "./neurips/transfer_metrics",
):
    """Score a JSON-lines transfer file and write the metrics as JSON.

    Args:
        data_path: JSON-lines file to evaluate.
        evaluation_key: Column holding the reference texts. The
            ``reference_text`` column is renamed to ``reference`` on load;
            either name is accepted here.
        transfer_key: Column holding the texts to score. For
            ``"transfer_text"`` only the first element of every inner entry
            is kept (presumably the top candidate; confirm with the producer).
        get_base: Only print the metrics of ``content_text`` against
            ``reference``; nothing is saved.
        get_gold: Only print the metrics of ``reference`` against
            ``evaluation_text``; nothing is saved. Ignored when ``get_base``
            is set.
        outdir: Directory receiving ``<basename of data_path>`` with
            ``.jsonl`` replaced by ``.json``. Created if missing.

    Returns:
        Always ``0``.
    """
    df = pd.read_json(data_path, lines=True)

    # The rename removes the "reference_text" column, so keys that still use
    # the old name (including the default evaluation_key) are mapped through
    # the same table; otherwise the column lookup raises KeyError.
    renamed_columns = {"reference_text": "reference"}
    df.rename(columns=renamed_columns, inplace=True)
    evaluation_key = renamed_columns.get(evaluation_key, evaluation_key)
    transfer_key = renamed_columns.get(transfer_key, transfer_key)

    if transfer_key == "transfer_text":
        df[transfer_key] = df[transfer_key].apply(lambda x: [y[0] for y in x])

    luar, luar_tok = load_luar_model_and_tokenizer("rrivera1849/LUAR-MUD")
    luar.to("cuda")
    sbert = load_sbert_model()
    sbert.to("cuda")
    cisr = load_cisr_model()
    cisr.to("cuda")
    sd = load_sd_model()
    sd.to("cuda")

    if get_base:
        calculate_metrics(df, "reference", luar, luar_tok, sbert, cisr, sd, transfer_key="content_text")
        return 0
    if get_gold:
        calculate_metrics(df, "evaluation_text", luar, luar_tok, sbert, cisr, sd, transfer_key="reference")
        return 0

    # Build the output location before the expensive metric computation so an
    # unwritable directory fails early. An empty outdir means the current
    # directory, which os.makedirs cannot be asked to create.
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    # Swap the extension on the file name only, so that an outdir containing
    # ".jsonl" is left untouched.
    filename = os.path.basename(data_path).replace(".jsonl", ".json")
    savepath = os.path.join(outdir, filename)

    # Calculate & Save Metrics:
    metrics = calculate_metrics(df, evaluation_key, luar, luar_tok, sbert, cisr, sd, transfer_key=transfer_key)
    with open(savepath, "w") as fout:
        json.dump(metrics, fout, indent=4)

    return 0

# Script entry point: when run directly (not imported), hand `main` to
# python-fire so it is invoked from the command line.
if __name__ == "__main__":
    fire.Fire(main)