import numpy as np
from numba import jit
import torch

import matplotlib
matplotlib.rcParams['image.cmap'] = 'jet'
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from nnet import Sheaf_NNet
from data_loader import generate_graph_weights_and_features
from sheaf_calculator import compute_neural_network_parameters

def learn_sheaf_parameters(dimy=6, dimx=2, dimh=3, numv=1000, nepoch=2048, lr=1.0e-3,
                           gamma_ref=10.0):
    """Train a ``Sheaf_NNet`` on generated graph data and persist the results.

    Generates graph weights/features with
    ``generate_graph_weights_and_features``, fits a double-precision
    ``Sheaf_NNet(dimy, dimh)`` via ``compute_neural_network_parameters``, then
    saves the whole network to ``./nnet_folder/nnet_sheaf.ptr`` and the
    returned loss history to ``./results/loss_data.npy``. Both output
    directories are created if they do not exist yet.

    Args:
        dimy, dimx, dimh: dimension arguments forwarded to the data generator;
            ``dimy`` and ``dimh`` also size the network.
        numv: number of vertices, forwarded to the generator and the trainer.
        nepoch: number of training epochs, forwarded to the trainer.
        lr: learning rate, forwarded to the trainer.
        gamma_ref: forwarded to the data generator (default matches the
            previously hard-coded value).

    Returns:
        0 (script-style status code).
    """
    import os

    print('learning_the_sheaf')
    w, x, f = generate_graph_weights_and_features(numv, dimy, dimx, dimh, gamma_ref=gamma_ref)
    nnet = Sheaf_NNet(dimy, dimh).double()
    # NOTE(review): performance_test() reads columns 0..4 of this array, so it
    # presumably has shape (num_iterations, >=5) — confirm in sheaf_calculator.
    loss_data = compute_neural_network_parameters(nnet, nepoch, numv, lr, w, x, f)

    nnet_folder = './nnet_folder'
    results_folder = './results'
    # Create the output folders up front; otherwise torch.save / np.save fail
    # only after the (potentially long) training run has already finished.
    os.makedirs(nnet_folder, exist_ok=True)
    os.makedirs(results_folder, exist_ok=True)

    # torch.save on the module pickles the entire object (class reference
    # included), so loading it requires the nnet module to be importable.
    torch.save(nnet, os.path.join(nnet_folder, 'nnet_sheaf.ptr'))
    np.save(os.path.join(results_folder, 'loss_data.npy'), loss_data)
    return 0


def performance_test():
    """Plot the saved training-loss curves on log2/log2 axes.

    Loads ``./results/loss_data.npy`` and draws one line per loss column
    (0: orth, 1: cons, 2: diff, 3: labelled 'mse', 4: comb) against the
    1-based iteration index, both axes in log base 2. The figure is written
    to ``loss_data_cube.pdf`` and then shown interactively.

    Returns:
        0 (script-style status code).
    """
    curves = np.load('./results/loss_data.npy')

    # (column index, colour, legend label); column 3 is shown as 'mse'.
    series = (
        (0, 'r', 'orth'),
        (1, 'g', 'cons'),
        (2, 'b', 'diff'),
        (3, 'C0', 'mse'),
        (4, 'C1', 'comb'),
    )
    # Pull the columns out before opening a figure, so a malformed file
    # fails before any plotting state is touched.
    columns = [curves[:, idx] for idx, _, _ in series]

    n_iter = curves.shape[0]
    iterations = np.linspace(1, n_iter, n_iter)
    log2_iter = np.log(iterations) / np.log(2.0)

    plt.figure(figsize=(8, 6))
    for column, (_, colour, label) in zip(columns, series):
        log2_loss = np.log(column) / np.log(2.0)
        plt.plot(log2_iter, log2_loss, c=colour, label=label, linewidth=4.0)

    plt.xlabel(r'$\log_2(iteration)$', fontsize=20)
    plt.ylabel(r'$\log_2(loss)$', fontsize=20)
    plt.xlim(0, 12)
    plt.ylim(-25, 10)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.legend(fontsize=20)
    plt.grid()
    plt.tight_layout()
    plt.savefig('loss_data_cube.pdf')
    plt.show()
    return 0

def data_expolration(numv=100, dimy=3, dimx=2, dimh=3, gamma_ref=10.0):
    """Generate a small graph dataset and print the shapes of its tensors.

    (The name's spelling is kept for backward compatibility.)

    The generator is called with the same argument list as in
    ``learn_sheaf_parameters`` and unpacked into three values there. The
    previous two-argument call with a two-value unpack could not succeed
    against that return shape.

    Args:
        numv, dimy: vertex count and ``dimy`` size (defaults keep the
            original hard-coded 100 and 3).
        dimx, dimh, gamma_ref: forwarded to the generator. Defaults mirror
            ``learn_sheaf_parameters``; NOTE(review): confirm they are valid
            for ``dimy=3``.

    Returns:
        0 (script-style status code).
    """
    w, x, f = generate_graph_weights_and_features(numv, dimy, dimx, dimh, gamma_ref=gamma_ref)
    # Move to host memory as numpy arrays for inspection.
    w = w.detach().cpu().numpy()
    x = x.detach().cpu().numpy()

    print('w.shape = ' + str(w.shape))
    print('x.shape = ' + str(x.shape))
    return 0

def main():
    """Script entry point: train the sheaf network, then plot its loss curves.

    Returns:
        0 (script-style status code).
    """
    print('inside the main function')
    learn_sheaf_parameters(dimy=6, dimx=3, dimh=3, numv=300, nepoch=2048, lr=1.0e-3)
    performance_test()
    # data_expolration()
    return 0


# Guard the entry point so importing this module (e.g. to reuse a helper)
# does not kick off a full training run and block on plt.show().
if __name__ == '__main__':
    main()




















