import pickle
import os
import numpy as np
import matplotlib.pyplot as plt

# Root of the pickled per-pair error dumps. The plotting loop reads
# <dump_path>/<method>/<test-set>/<animal1>_<animal2>.txt for every method.
dump_path = "experiment_data/shape_correspondence/"
# Test-set names; each names a pair-list file "<name>.txt" and a result sub-directory.
test_set = ["test-set1", "test-set2", "test-set3", "test-set4"]

# Baseline methods. Entry i is compared against entry i of `ours`, `ours_meyer`
# and `ours_simple` (index 0: balanced variants, index 1: unbalanced variants).
competitors = ["gw", "ugw"]

ours = ["wot-b", "wot-u"]
ours_meyer = ["wot-b-meyer", "wot-u-meyer"]
ours_simple = ["wot-b-simple", "wot-u-simple"]

# Legend labels, keyed by method directory name.
competitor_naming = {"gw": "Gromov-Wasserstein (GW)", "ugw": "Unbalanced GW"}
our_naming = {"wot-b": "E-WOT (Heat Kernel)", "wot-u": "Unbalanced E-WOT (Heat Kernel)"}
our_meyer_naming = {"wot-b-meyer": "E-WOT (Meyer)", "wot-u-meyer": "Unbalanced E-WOT (Meyer)"}
our_simple_naming = {"wot-b-simple": "E-WOT (Simple Tight)", "wot-u-simple": "Unbalanced E-WOT (Simple Tight)"}
our_learned_naming = {"wot-b-learned": "L-WOT (Heat Kernel)", "wot-u-learned": "Unbalanced L-WOT (Heat Kernel)"}
# Backward-compatible alias for the original (misspelled) name.
out_learned_naming = our_learned_naming

# Root directory for the generated figures.
out_path = "figures/shape_correspondence/"

# Directory containing the "<test-set>.txt" shape-pair lists.
path = "data/shape_data/test-sets"

# Histogram resolution used for every cumulative-error curve.
N_BINS = 200


def read_pairs(pairs_file):
    """Read the shape pairs listed in *pairs_file*.

    Each non-blank line must hold exactly two comma-separated shape names
    (``animal1,animal2``); the line's surrounding whitespace is stripped.

    Returns:
        A list of ``(animal1, animal2)`` tuples in file order.

    Raises:
        ValueError: if a non-blank line does not split into exactly two names.
    """
    pairs = []
    with open(pairs_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. trailing empty lines) instead of
                # failing on the two-name unpack below.
                continue
            animal1, animal2 = line.split(",")
            pairs.append((animal1, animal2))
    return pairs


def load_errors(method, test_file, animal_pairs, root=None):
    """Concatenate the pickled error dumps of *method* over *animal_pairs*.

    Reads ``<root>/<method>/<test_file>/<animal1>_<animal2>.txt`` for every
    pair and extends one list with each unpickled object (presumably a list or
    array of per-point errors — confirm against the producer of the dumps).

    Args:
        method: result-directory name of the method (e.g. ``"gw"``).
        test_file: test-set name (sub-directory of the method directory).
        animal_pairs: iterable of ``(animal1, animal2)`` name pairs.
        root: dump root directory; defaults to the module-level ``dump_path``.

    Returns:
        A flat list with the per-pair errors appended in pair order.
    """
    root = dump_path if root is None else root
    errors = []
    for animal1, animal2 in animal_pairs:
        dump_file = os.path.join(root, method, test_file, f"{animal1}_{animal2}.txt")
        with open(dump_file, "rb") as f:
            # pickle can execute arbitrary code: only load locally produced dumps.
            errors += pickle.load(f)
    return errors


def cumulative_curve(errors, bins=N_BINS):
    """Return ``(x, y)`` for the cumulative-error curve of *errors*.

    The errors are binned with ``np.histogram`` (range = data min..max).
    ``x`` holds all ``bins + 1`` bin edges and ``y[k]`` is the percentage of
    errors in the bins to the left of ``x[k]``, so the curve starts at 0 %
    and ends at 100 %. Pairing each cumulative count with the *right* edge of
    its bin avoids shifting the curve one bin width to the left.

    An empty *errors* yields an all-zero curve instead of a 0/0 (NaN) curve.
    """
    values, edges = np.histogram(errors, bins=bins)
    counts = np.cumsum(values)
    total = counts[-1]
    if total == 0:
        return edges, np.zeros(len(edges))
    return edges, np.concatenate(([0.0], counts / total * 100))


def main():
    """Draw and save one cumulative-error figure per (competitor, test set).

    Figures are written to ``<out_path>/<competitor>/<test-set>/cum_plot.png``
    (directories are created as needed) and then shown.
    """
    for i, (competitor, our, our_meyer, our_simple) in enumerate(
        zip(competitors, ours, ours_meyer, ours_simple)
    ):
        learned = f"{our}-learned"
        # (method directory, line colour, legend label), in legend order.
        series = [
            (learned, "fuchsia", out_learned_naming[learned]),
            (our, "cyan", our_naming[our]),
            (our_meyer, "purple", our_meyer_naming[our_meyer]),
            (our_simple, "blue", our_simple_naming[our_simple]),
            (competitor, "orange", competitor_naming[competitor]),
        ]
        if i == 0:
            # UnionCom and Pamona are only drawn in the first comparison.
            # NOTE(review): Pamona is not drawn in the unbalanced (i == 1)
            # figure; confirm that is intended.
            series += [
                ("unioncom", "olive", "UnionCom"),
                ("pamona", "darkred", "Pamona"),
            ]

        for test_file in test_set:
            animal_pairs = read_pairs(os.path.join(path, test_file + ".txt"))

            plt.figure()
            plt.title("Cumulative Relative Geodesic Error (SHREC20)")
            for method, color, label in series:
                x, y = cumulative_curve(load_errors(method, test_file, animal_pairs))
                plt.plot(x, y, c=color, label=label)
            plt.xlabel("Relative Geodesic Error")
            plt.ylabel("% Matches")

            plt.xlim(0, 1)
            plt.ylim(0, 100)
            plt.xticks(np.arange(0, 1.1, 0.1))
            plt.yticks(np.arange(0, 101, 20))
            plt.legend()
            plt.grid(True)

            out_dir = os.path.join(out_path, competitor, test_file)
            os.makedirs(out_dir, exist_ok=True)
            # Save before show(): once an interactive window is closed the
            # figure is destroyed and savefig would write an empty canvas.
            plt.savefig(os.path.join(out_dir, "cum_plot.png"))
            plt.show()
            plt.close()


if __name__ == "__main__":
    main()