import argparse 
import Levenshtein
import numpy as np 
import openai
import pandas as pd 
import pickle
import scipy
import time
import torch 

from pathlib import Path 
from torch.utils.data import DataLoader 
from typing import List 

from mi_estimation_utils import load_data, MyDataset


def _count_overlap(x: List, y: List) -> float:
    overlapped = 0
    for word in x:
        if word in y:
            overlapped += 1
    return overlapped  / len(x)

def xe_evaluation_score(dataset_name: str, ds: MyDataset, rds: MyDataset) -> pd.DataFrame:
    """
    Compute per-row similarity scores between the input X and the explanation E.

    Row ``i`` of ``ds`` supplies the vectors compared by cosine similarity. Row
    ``i`` of ``rds`` supplies the raw text, which is whitespace-tokenized for
    the word-level scores. Both datasets are expected to be row-aligned.

    Args:
        dataset_name: Currently unused; kept for interface compatibility.
        ds: Dataset whose rows expose ``"X"`` and ``"E"`` as vectors accepted by
            ``scipy.spatial.distance.cosine`` (presumably embeddings — confirm
            against ``load_data``).
        rds: Dataset whose rows expose ``"X"`` and ``"E"`` as strings.

    Returns:
        A DataFrame with one row per item of ``ds`` and the columns:
          * ``simXE``: cosine similarity of X and E.
          * ``edit_distance_ratio``: token-level Levenshtein distance from X to
            E, divided by the number of tokens in X.
          * ``type_overlap_ratio``: fraction of unique X tokens that also occur
            among the unique E tokens.
        Both ratio columns are NaN for rows whose raw X has no tokens.
    """
    # ``import scipy`` alone does not load the ``spatial`` subpackage on older
    # SciPy releases, so import the function explicitly.
    from scipy.spatial.distance import cosine

    scores = {"simXE": [], "edit_distance_ratio": [], "type_overlap_ratio": []}

    for i in range(len(ds)):
        row, raw_row = ds[i], rds[i]
        X, E = row["X"], row["E"]
        Xs_L = raw_row["X"].split()
        Es_L = raw_row["E"].split()

        # Cosine similarity: how close are X and E in the embedding space?
        scores["simXE"].append(1 - cosine(X, E))

        if not Xs_L:
            # Both ratios below are normalized by the size of X; with no tokens
            # they are undefined, so record NaN instead of dividing by zero.
            scores["edit_distance_ratio"].append(float("nan"))
            scores["type_overlap_ratio"].append(float("nan"))
            continue

        # Levenshtein distance over token lists (word-level edits needed to turn
        # X into E), normalized by the number of X tokens.
        # NOTE(review): relies on the installed Levenshtein package accepting
        # non-str sequences — confirm the pinned version supports this.
        dist = Levenshtein.distance(Xs_L, Es_L)
        scores["edit_distance_ratio"].append(dist / len(Xs_L))

        # Type (unique word) overlap.
        X_types = list(set(Xs_L))
        E_types = list(set(Es_L))
        scores["type_overlap_ratio"].append(_count_overlap(X_types, E_types))

    return pd.DataFrame(scores)


if __name__ == "__main__":
    # Command-line entry point: score the test split and write a CSV report.
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, choices=["rationale", "nle"], default="nle")
    parser.add_argument("--dataset", type=str, default="esnli")
    parser.add_argument("--embedding", type=str, default="openai")
    parser.add_argument("--downsample", type=int, default=1200)
    parser.add_argument("--report_dir", type=str, default="../data/scored/")
    args = parser.parse_args()

    print(args)

    report_dir = Path(args.report_dir)
    report_dir.mkdir(parents=True, exist_ok=True)

    # load_data yields three splits; only the last one (test) is scored here.
    # The rawtext=True variant provides the strings used for the token-level scores.
    _, _, test_ds = load_data(args, rawtext=False)
    _, _, rtest_ds = load_data(args, rawtext=True)

    for split, ds, rds in [("test", test_ds, rtest_ds)]:
        # perf_counter is monotonic and intended for measuring elapsed time.
        start_time = time.perf_counter()
        report_df = xe_evaluation_score(args.dataset, ds, rds)
        out_name = f"{args.dataset}_{args.method}_{args.downsample}_{args.embedding}_{split}_XE_similarity.csv"
        report_df.to_csv(report_dir / out_name, index=False)
        print("{} split done. Time elapsed: {:.4f} seconds".format(
            split, time.perf_counter() - start_time))