import os
import pickle

import matplotlib.pyplot as plt
import numpy as np



def get_the_directory(i, j, k):
    """Return the pickle file path for grid cell (i, j) and seed index k.

    Despite the name, this is a file path, not a directory. The seed field is
    ``k + 11``; the ``ntr_`` and ``nte_`` fields are ``i / 10`` and ``j / 10``
    rounded to one decimal, e.g. ``get_the_directory(3, 19, 9)`` ->
    ``"fig2/res-se_20-ntr_0.3-nte_1.9.pkl"``.
    """
    seed = k + 11
    train_noise = round(i / 10, 1)
    test_noise = round(j / 10, 1)
    return f"fig2/res-se_{seed}-ntr_{train_noise}-nte_{test_noise}.pkl"

def get_the_mean(i, j, n_seeds=10):
    """Return the mean result for grid cell (i, j) averaged over seed files.

    For each seed index ``k`` in ``range(n_seeds)``, the pickle at
    ``get_the_directory(i, j, k)`` is loaded. Its first value (in
    ``.values()`` order) is converted to ``float``, printed, and accumulated.

    Args:
        i: First grid index, forwarded to ``get_the_directory``.
        j: Second grid index, forwarded to ``get_the_directory``.
        n_seeds: Number of seed files to average over. Defaults to 10, the
            count previously hard-coded here.

    Returns:
        The arithmetic mean of the loaded values, as a float.

    Raises:
        ValueError: If ``n_seeds`` is less than 1 (the mean would be undefined).
    """
    if n_seeds < 1:
        raise ValueError(f"n_seeds must be >= 1, got {n_seeds}")

    total = 0.0
    for k in range(n_seeds):
        path = get_the_directory(i, j, k)
        # NOTE(review): pickle.load can execute arbitrary code; only use it on
        # trusted, locally produced result files.
        # `with` guarantees the handle is closed even if unpickling fails.
        with open(path, 'rb') as fh:
            data = pickle.load(fh)

        # Assumes each pickle holds a dict-like object whose first value is
        # the scalar metric of interest -- TODO confirm against the writer.
        value = float(list(data.values())[0])
        print(value)
        total += value

    return total / n_seeds



def restore():
    """Build the 20x20 grid of seed-averaged results.

    Cell ``[i, j]`` holds ``get_the_mean(i, j)``. Here ``i`` feeds the
    ``ntr_`` field of the result file name and ``j`` the ``nte_`` field.
    Cells are computed row by row (``i`` outer, ``j`` inner), so any
    per-file output appears in row-major order.

    Returns:
        A float64 ``numpy.ndarray`` of shape (20, 20).
    """
    grid_size = 20
    rows = [
        [get_the_mean(i, j) for j in range(grid_size)]
        for i in range(grid_size)
    ]
    return np.array(rows, dtype=float)

def run(out_path='fig2Table1/view3.pdf'):
    """Render the seed-averaged noise grid as an annotated heatmap PDF.

    The grid from ``restore()`` is transposed so that its first index (the
    ``ntr_`` / training-noise value) runs along the x-axis and its second
    index (the ``nte_`` / test-noise value) runs down the y-axis. Each cell's
    value, rounded to 2 decimals, is drawn on the image. The figure is saved
    as a PDF and then closed, so it is not left open in pyplot's figure
    registry.

    Args:
        out_path: Destination PDF path. Defaults to the original location.
            The parent directory is created if it does not exist.
    """
    # Transpose: rows -> test noise (y-axis), columns -> training noise (x-axis).
    means = restore().T
    n_rows, n_cols = means.shape

    # Tick labels are index / 10 with one decimal ("0.0", "0.1", ..., "1.9"),
    # matching the noise values encoded in the result file names.
    x = [f"{k / 10:.1f}" for k in range(n_cols)]
    y = [f"{k / 10:.1f}" for k in range(n_rows)]

    fig, ax = plt.subplots(figsize=(14, 14))
    try:
        ax.imshow(means, interpolation='nearest')
        # Show all ticks and label them with the respective list entries.
        ax.set_xticks(np.arange(n_cols), labels=x)
        ax.set_yticks(np.arange(n_rows), labels=y)

        # ax.text takes (x, y) = (column, row), i.e. the reverse of array indexing.
        for row in range(n_rows):
            for col in range(n_cols):
                ax.text(col, row, round(means[row, col], 2),
                        ha="center", va="center", color="w")

        ax.set_xlabel(r'Training Noise Scalar, $\alpha$')
        ax.set_ylabel(r'Test Noise Scalar, $\alpha$')

        # savefig does not create missing directories; do it here.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, format='pdf', bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def the_test():
    """Scratch check of ``mean(axis=0)`` on a random 3-D array.

    Prints the shape before averaging, ``(4, 5, 6)``, then after averaging
    over axis 0, ``(5, 6)``. Returns ``None``.
    """
    sample = np.random.rand(4, 5, 6)
    print(sample.shape)
    averaged = sample.mean(axis=0)
    print(averaged.shape)


if __name__ == "__main__":
    # Script entry point: build and save the heatmap.
    # (Call the_test() instead for a quick numpy shape sanity check.)
    run()