from os import listdir
from os.path import isfile, join

import numpy as np
import pickle

import matplotlib.pyplot as plt

import seaborn as sns
#sns.set_theme()


# Directory holding the pickled result files. Each file name is expected to
# end in "_<dim>.<ext>"; the integer <dim> is used as the grouping key below.
dir_path = "data_nd/"
# Skip macOS ".DS_Store" (anything containing "Store") and sub-directories.
# Sorting makes the per-dimension accumulation order -- and thus the float
# means computed later -- reproducible across machines/filesystems.
onlyfiles = [
    (join(dir_path, f), f)
    for f in sorted(listdir(dir_path))
    if isfile(join(dir_path, f)) and "Store" not in f
]

# dim -> list with one entry per loaded file of that dimension.
linear_regions_dict = {}  # entries taken from data_arr[1]
linear_dists_dict = {}  # entries taken from data_arr[2]
# Recorded epochs 0, 5, ..., 145 (30 points), used as the plot's x-axis.
epochs = list(range(0, 150, 5))
for f_path, f_name in onlyfiles:
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only run this script on result files you produced yourself.
    with open(f_path, "rb") as f_in:
        data_arr = pickle.load(f_in)
    # "<prefix>_<dim>.<ext>" -> int(<dim>)
    dim = int(f_name.split("_")[-1].split(".")[0])
    linear_regions_dict.setdefault(dim, []).append(data_arr[1])
    linear_dists_dict.setdefault(dim, []).append(data_arr[2])

# For each dimension 5, 10, ..., 25, plot the element-wise mean (axis=0, i.e.
# across all files of that dimension) of the loaded distance arrays.
# x values are epochs + 1. NOTE(review): the purpose of the +1 offset is not
# evident from this script -- confirm before changing it.
x_epochs = np.array(epochs) + 1
for key in range(5, 30, 5):
    value = linear_dists_dict[key]  # KeyError if no file for this dim loaded
    linear_boundaries_mean = np.mean(np.array(value), axis=0)
    plt.plot(x_epochs, linear_boundaries_mean, label=str(key))

plt.xlabel("Epochs")
plt.ylabel("Average Distance to Linear Boundary")
# y-range tuned for the distance plot. For the linear-regions plot, use
# ylim (0, 60) with the ylabel "Number of Linear Regions" instead.
plt.ylim((0.0, 0.0065))
plt.grid()
plt.legend()
# Both axis labels must exist before tight_layout: it computes the margins
# once, at call time, so a label added afterwards is not accounted for and
# can be clipped in the saved image.
plt.tight_layout(pad=1.5)
plt.savefig("nd_dists.png", dpi=600)
