import datasets
import mobypy
from itertools import chain
import numpy as np
import pickle

def _normalize_tacred_name(name):
    """Drop the ``org:``/``per:`` markers and replace underscores with spaces."""
    return name.replace("org:", "").replace("per:", "").replace("_", " ")


def _load_label_to_name_mapping(dataset):
    """Return the ``{label id: class name}`` mapping for ``dataset``.

    For ``"tacred"`` the mapping is read from ``data/tacred/0/labels_map.pkl``
    (a pickled ``{name: label id}`` dict). Every name mapped to id 0 is
    collected into a list stored under key 0; all other ids map to a single
    normalized name.

    Raises:
        ValueError: if ``dataset`` is not one of ``"emotion"``, ``"trec10"``,
            ``"agnews"`` or ``"tacred"``.
    """
    if dataset == "emotion":
        # NOTE(review): id 2 is absent from this mapping -- confirm intentional.
        return {
            0: "sadness",
            1: "joy",
            3: "anger",
            4: "fear",
        }
    if dataset == "trec10":
        # NOTE(review): id 2 is absent from this mapping -- confirm intentional.
        return {
            0: "description",
            1: "entity",
            3: "human",
            4: "number",
            5: "location",
        }
    if dataset == "agnews":
        return {
            0: "world",
            1: "sports",
            2: "business",
            3: "sci/tech",
        }
    if dataset == "tacred":
        labels_path = "data/tacred/0/labels_map.pkl"
        # SECURITY: pickle.load can execute arbitrary code; only use this with
        # a labels file that comes from a trusted source.
        with open(labels_path, "rb") as f:
            name_to_label_mapping = pickle.load(f)
        label_to_name_mapping = {0: []}
        for name, label in name_to_label_mapping.items():
            normalized = _normalize_tacred_name(name)
            if label == 0:
                label_to_name_mapping[0].append(normalized)
            else:
                label_to_name_mapping[label] = normalized
        return label_to_name_mapping
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of "
        "'emotion', 'trec10', 'agnews', 'tacred'"
    )


def _flatten_names(names):
    """Flatten an iterable whose items are either a name or a list of names."""
    flat = []
    for name in names:
        if isinstance(name, list):
            flat.extend(name)
        else:
            flat.append(name)
    return flat


def main(dataset):
    """Filter the generated datasets of ``dataset`` by class-name synonyms.

    Each label id of the dataset is treated in turn as the held-out
    (``ood_class``) class. For that class the dataset stored at
    ``{dataset}/{ood_class}/generations_100k/`` is loaded, every example whose
    ``"label"`` is returned by ``mobypy.synonyms`` for one of the *other*
    class names is dropped, and the result is saved to
    ``{dataset}/{ood_class}/generations_100k_filtered/``. Progress and the
    mean fraction of removed distinct labels are printed.

    Args:
        dataset: one of ``"emotion"``, ``"trec10"``, ``"agnews"``, ``"tacred"``.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    label_to_name_mapping = _load_label_to_name_mapping(dataset)

    fraction_deleted = []

    for ood_class in label_to_name_mapping:
        input_dataset = f"{dataset}/{ood_class}/generations_100k/"
        output_dataset = f"{dataset}/{ood_class}/generations_100k_filtered/"
        # A mapping value may be a list of names (tacred label 0); flatten so
        # that mobypy.synonyms is always called with a single string.
        id_labels = _flatten_names(
            name for i, name in label_to_name_mapping.items() if i != ood_class
        )
        d = datasets.load_from_disk(input_dataset)
        # Sets give O(1) membership tests for the comparisons and filter below.
        labels = set(d["label"])
        synonyms = set(
            chain.from_iterable(mobypy.synonyms(label) for label in id_labels)
        )
        keep_labels = labels - synonyms
        removed_labels = list(labels & synonyms)
        print(f"Went from {len(labels)} labels to {len(keep_labels)}")
        # Guard against a dataset without any labels (would divide by zero).
        fraction_deleted.append(
            len(removed_labels) / len(labels) if labels else 0.0
        )
        print("Deleted labels", removed_labels)
        new_d = d.filter(lambda ex: ex["label"] in keep_labels)
        print(f"Went from {len(d)} to {len(new_d)} examples for split {ood_class}")
        new_d.save_to_disk(output_dataset)

    print("Average deleted", np.mean(fraction_deleted))

# Script entry point: when executed directly (not imported), run the
# filtering pipeline on the "trec10" dataset.
if __name__ == "__main__":
    main(dataset="trec10")
