
import editdistance
import fire
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from genpaths import *

def _latex_row(label, reference, values):
    """Format one LaTeX table row.

    The row consists of *label*, the *reference* cell, a "-" placeholder cell,
    and then every entry of *values* rendered with two decimals. Cells are
    joined with " & " and the row ends with a LaTeX line break.
    """
    cells = [label, reference, "-"]
    cells.extend("{:.2f}".format(value) for value in values)
    return " & ".join(cells) + " \\\\"


def main(
    debug: bool = False,
) -> int:
    """Print edit-distance and semantic-similarity table rows per dataset.

    For every dataset in ``DATASET_ORDER`` the original texts are compared
    against the output of every method in ``METHOD_ORDER``:

    * mean edit distance (``editdistance.eval``) between paired texts, and
    * mean cosine similarity between their normalized
      ``all-mpnet-base-v2`` sentence embeddings.

    Each method's file paths and output column are looked up in module
    globals under the upper-cased method name (brought in by
    ``from genpaths import *``); the last element of that sequence is used
    as the column name, the others are indexed through ``DATASET_IDX``.

    Args:
        debug: When true, only the first 100 rows of every file are read.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If a method file and the original file yield a different
            number of rows, or no rows at all.
    """
    METHOD_ORDER = [
        # "Baseline",
        # "LLMOPT",
        "Prompting",
        "Paraphrasing",
        "DIPPER",
        "TinyStyler",
        "Ours",
    ]

    DATASET_ORDER = ["reddit", "amazon", "blogs"]
    # DATASET_IDX[i] is the position of DATASET_ORDER[i] inside each method's
    # path list, which uses a different dataset ordering than ORIG.
    DATASET_IDX = [2, 0, 1]
    # DATASET_ORDER = ["reddit"]
    # DATASET_IDX = [2]

    ORIG = [
        "/FOO/BAR/data/style_transfer/mtd/MTD_reddit_12000_correct.jsonl.ready",
        "/FOO/BAR/data/style_transfer/mtd/MTD_amazon_12000.jsonl.ready",
        "/FOO/BAR/data/style_transfer/mtd/MTD_blogs_7000.jsonl.ready",
    ]

    nrows = 100 if debug else None

    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    # Use the GPU when there is one instead of crashing on CPU-only hosts.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    for dataset_pos, dataset in enumerate(DATASET_ORDER):
        # The original texts depend only on the dataset, so load and embed
        # them once rather than once per method.
        df_orig = pd.read_json(ORIG[dataset_pos], lines=True, nrows=nrows)
        # NOTE(review): the same "respond_reddit" column is read for every
        # dataset, including amazon and blogs -- confirm this is intended.
        orig_text = df_orig["respond_reddit"].tolist()
        orig_emb = model.encode(
            orig_text,
            normalize_embeddings=True,
            convert_to_tensor=True,
            batch_size=32,
            show_progress_bar=True,
        )

        semantic_similarities = []
        edit_distances = []

        for method in METHOD_ORDER:
            method_spec = globals()[method.upper()]
            path = method_spec[DATASET_IDX[dataset_pos]]
            column = method_spec[-1]

            df_method = pd.read_json(path, lines=True, nrows=nrows)
            method_text = df_method[column].tolist()
            # Some methods store a list of candidates per row; keep the first.
            if method_text and isinstance(method_text[0], list):
                method_text = [candidates[0] for candidates in method_text]

            # NOTE: empty or non-string entries are not filtered out here.
            # Fail loudly on misaligned inputs: zip() would otherwise truncate
            # silently and the similarity would be computed on mismatched rows.
            if not method_text or len(method_text) != len(orig_text):
                raise ValueError(
                    "Row count mismatch for method {!r} on dataset {!r}: "
                    "{} original vs {} rewritten texts".format(
                        method, dataset, len(orig_text), len(method_text)
                    )
                )

            method_emb = model.encode(
                method_text,
                normalize_embeddings=True,
                convert_to_tensor=True,
                batch_size=32,
                show_progress_bar=True,
            )
            # .item() turns the 0-dim tensor into a plain Python float.
            semsim = F.cosine_similarity(orig_emb, method_emb).mean().item()
            editdist = np.mean(
                [editdistance.eval(s1, s2) for s1, s2 in zip(orig_text, method_text)]
            )
            semantic_similarities.append(semsim)
            edit_distances.append(editdist)

        print(dataset)
        print(_latex_row("Edit Distance", "0.0", edit_distances))
        print(_latex_row("Semantic Sim.", "1.0", semantic_similarities))
        print("*" * 50)

    return 0


# Script entry point: when run directly, hand `main` to python-fire.
if __name__ == "__main__":
    fire.Fire(main)