
import editdistance
import fire
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from genpaths import *

def _latex_row(label, reference, cells):
    """Join a row label, its reference value, a "-" placeholder and the cells into one LaTeX table row."""
    return " & ".join([label, reference, "-", *cells]) + " \\\\"


def _mean_std_cells(stats):
    """Render (mean, std) pairs as "mean (std)" strings with two decimals."""
    return [f"{mean:.2f} ({std:.2f})" for mean, std in stats]


def main(
    debug: bool = False,
):
    """Print LaTeX rows comparing each method's outputs with the original texts.

    For every dataset and every method, the original texts and the method's
    texts are compared pairwise with character-level edit distance and with
    cosine similarity of sentence embeddings. One pair of rows (edit distance,
    semantic similarity) is printed per dataset, followed by a pair of rows
    averaged over datasets.

    Args:
        debug: When True, only the first 100 rows of each file are read.

    Returns:
        Always 0.

    Raises:
        ValueError: If a method file and its original file yield a different
            number of texts, since the metrics are computed pairwise.
    """
    METHOD_ORDER = [
        # "Baseline",
        # "LLMOPT",
        "Prompting",
        "Paraphrasing",
        "DIPPER",
        "TinyStyler",
        "Ours",
    ]

    # DATASET_ORDER, DATASET_IDX and ORIG are parallel lists: DATASET_IDX[i]
    # is the position of dataset i inside each method's path list, while ORIG
    # is already laid out in DATASET_ORDER order.
    DATASET_ORDER = ["reddit", "amazon", "blogs"]
    DATASET_IDX = [2, 0, 1]

    ORIG = [
        "/data1/yubnub/data/style_transfer/mtd/MTD_reddit_12000_correct.jsonl.ready",
        "/data1/yubnub/data/style_transfer/mtd/MTD_amazon_12000.jsonl.ready",
        "/data1/yubnub/data/style_transfer/mtd/MTD_blogs_7000.jsonl.ready",
    ]

    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    # Prefer the GPU, but keep the script usable (e.g. with --debug) without one.
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    nrows = 100 if debug else None
    encode_kwargs = dict(
        normalize_embeddings=True,
        convert_to_tensor=True,
        batch_size=32,
        show_progress_bar=True,
    )

    all_results = {
        method: {'ed_mean': [], 'ed_std': [], 'sim_mean': [], 'sim_std': []}
        for method in METHOD_ORDER
    }

    for dataset, method_idx, orig_path in zip(DATASET_ORDER, DATASET_IDX, ORIG):
        # The original texts depend only on the dataset, so load and embed
        # them once here instead of once per method.
        df_orig = pd.read_json(orig_path, lines=True, nrows=nrows)
        # NOTE(review): the same "respond_reddit" column is read for every
        # dataset, not only reddit - confirm the amazon/blogs files use it too.
        orig_text = df_orig["respond_reddit"].tolist()
        orig_emb = model.encode(orig_text, **encode_kwargs)

        semantic_similarities = []
        edit_distances = []

        for method in METHOD_ORDER:
            # Each method is described by a module-level list named after the
            # upper-cased method (presumably provided by `genpaths`): entries
            # indexed by DATASET_IDX are file paths, the last entry is the
            # name of the column holding the method's output text.
            method_spec = globals()[method.upper()]
            df_method = pd.read_json(method_spec[method_idx], lines=True, nrows=nrows)
            method_text = df_method[method_spec[-1]].tolist()
            # Some methods store a list of candidates per row; keep the first.
            if method_text and isinstance(method_text[0], list):
                method_text = [candidates[0] for candidates in method_text]

            # Both metrics are pairwise; zip() would otherwise silently drop
            # the unmatched tail and report numbers over a subset.
            if len(method_text) != len(orig_text):
                raise ValueError(
                    f"{method} on {dataset}: {len(method_text)} method texts "
                    f"but {len(orig_text)} original texts"
                )

            method_emb = model.encode(method_text, **encode_kwargs)
            semsim_t = F.cosine_similarity(orig_emb, method_emb)
            semsim = semsim_t.mean().item()
            # NOTE: Tensor.std() is the sample (unbiased) estimator whereas
            # np.std() below is the population one; kept as-is so previously
            # reported numbers do not change.
            semsim_std = semsim_t.std().item()

            edit_dists = [
                editdistance.eval(orig_str, method_str)
                for orig_str, method_str in zip(orig_text, method_text)
            ]
            editdist = np.mean(edit_dists)
            editdist_std = np.std(edit_dists)

            semantic_similarities.append((semsim, semsim_std))
            edit_distances.append((editdist, editdist_std))

            all_results[method]['sim_mean'].append(semsim)
            all_results[method]['sim_std'].append(semsim_std)
            all_results[method]['ed_mean'].append(editdist)
            all_results[method]['ed_std'].append(editdist_std)

        print(dataset)
        print(_latex_row("Edit Distance", "0.0", _mean_std_cells(edit_distances)))
        print(_latex_row("Semantic Sim.", "1.0", _mean_std_cells(semantic_similarities)))
        print("*"*50)

    # Across-dataset summary: both the means and the stds are plain averages
    # of the per-dataset values (not pooled statistics).
    print("Average")
    avg_edit_distances = [
        (np.mean(all_results[method]['ed_mean']), np.mean(all_results[method]['ed_std']))
        for method in METHOD_ORDER
    ]
    print(_latex_row("Edit Distance", "0.0", _mean_std_cells(avg_edit_distances)))

    avg_similarities = [
        (np.mean(all_results[method]['sim_mean']), np.mean(all_results[method]['sim_std']))
        for method in METHOD_ORDER
    ]
    print(_latex_row("Semantic Sim.", "1.0", _mean_std_cells(avg_similarities)))
    print("*"*50)

    return 0


# Script entry point: hand `main` to python-fire so it can be invoked from the
# command line; does nothing when the module is merely imported.
if __name__ == "__main__":
    fire.Fire(main)
