from re import M
import _pickle as cPickle
import matplotlib.pyplot as plt
import numpy as np
import math

def plot_compare(m_norm, m_norm1, m_norm2, v_norm, v_norm1, v_norm2, m_em_loss, v_em_loss, m_true_loss, v_true_loss, m_test_loss, v_test_loss, value, color, alpha = 0.5, sum_weight = True, fontsize = 15):
    """Plot losses and weight norms side by side for several regularisation strengths.

    Draws a 1x3 figure: (a) empirical loss, (b) Jacobian loss and (c) norm of
    the weight matrices. Every curve is a per-epoch mean drawn together with a
    shaded band of mean +/- the matching ``v_*`` array, one colour per entry of
    ``value``. The figure is saved to ``training_behavior_compare.png`` and
    then shown.

    Args:
        m_norm, v_norm: per-lambda mean / band half-width of the combined weight
            norm; plotted when ``sum_weight`` is True.
        m_norm1, v_norm1, m_norm2, v_norm2: per-lambda mean / band half-width of
            the two individual weight norms; plotted when ``sum_weight`` is
            False. ``len(m_norm1[0])`` also fixes the number of epochs on the
            x-axis.
        m_em_loss, v_em_loss: per-lambda mean / band half-width of the
            empirical loss.
        m_true_loss, v_true_loss: per-lambda mean / band half-width of the
            Jacobian loss.
        m_test_loss, v_test_loss: accepted for interface compatibility; not
            plotted.
        value: regularisation strengths, used for the legend labels.
        color: one matplotlib colour per entry of ``value``.
        alpha: opacity of the shaded bands.
        sum_weight: plot the combined norm (True) or both individual norms
            (False).
        fontsize: font size of titles and axis labels; legends use
            ``fontsize - 2``.
    """
    fig = plt.figure(figsize=(19, 5))
    # Figure-level title: calling plt.title() before plt.subplot() would attach
    # the text to a stray full-figure Axes instead of the figure itself.
    fig.suptitle("Training of non-closed fixed support neural network", fontsize=fontsize)
    x = np.arange(1, len(m_norm1[0]) + 1)
    legend_prop = {'size': fontsize - 2}
    labels = [r'$\lambda = $' + str(lamda) for lamda in value]

    def band(ax, mean, spread, c, **plot_kwargs):
        # Mean curve plus a shaded mean +/- spread envelope in the same colour.
        ax.plot(x, mean, color=c, **plot_kwargs)
        ax.fill_between(x, mean - spread, mean + spread, color=c, alpha=alpha)

    # a) empirical loss (thin lines).
    ax1 = fig.add_subplot(1, 3, 1)
    for index, label in enumerate(labels):
        band(ax1, m_em_loss[index], v_em_loss[index], color[index], linewidth=0.5, label=label)
    ax1.set_title('a)', fontsize=fontsize)
    ax1.set_xlabel('epochs', fontsize=fontsize)
    ax1.set_ylabel('Empirical loss', fontsize=fontsize)
    ax1.legend(prop=legend_prop)

    # b) Jacobian loss.
    ax2 = fig.add_subplot(1, 3, 2)
    for index, label in enumerate(labels):
        band(ax2, m_true_loss[index], v_true_loss[index], color[index], label=label)
    ax2.set_xlabel('epochs', fontsize=fontsize)
    ax2.set_ylabel('Jacobian loss', fontsize=fontsize)
    ax2.set_title('b)', fontsize=fontsize)
    ax2.legend(prop=legend_prop)

    # c) weight norms: either the combined norm or both individual norms.
    ax3 = fig.add_subplot(1, 3, 3)
    for index, label in enumerate(labels):
        if sum_weight:
            band(ax3, m_norm[index], v_norm[index], color[index], label=label)
        else:
            # Both norms share the lambda's colour; only the first is labelled
            # so each lambda appears once in the legend.
            band(ax3, m_norm1[index], v_norm1[index], color[index], label=label)
            band(ax3, m_norm2[index], v_norm2[index], color[index])
    ax3.set_xlabel('epochs', fontsize=fontsize)
    ax3.set_ylabel('Norm of weight matrices', fontsize=fontsize)
    ax3.set_title('c)', fontsize=fontsize)
    ax3.legend(loc='upper left', prop=legend_prop)

    fig.savefig('training_behavior_compare.png', dpi=200)
    plt.show()

def plot_norm_vs_loss(m_norm, m_em_loss, value, color, fontsize = 15):
    """Plot empirical loss against total weight norm, one curve per lambda.

    For each index ``idx`` of ``value``, the curve has x = ``m_norm[idx]``
    and y = ``m_em_loss[idx]``, drawn in ``color[idx]`` and labelled with the
    corresponding lambda. The figure is saved to ``weightvsloss.png``; it is
    not shown.

    Args:
        m_norm: per-lambda sequences of weight norms (x values).
        m_em_loss: per-lambda sequences of empirical loss (y values).
        value: regularisation strengths, used for the legend labels.
        color: one matplotlib colour per entry of ``value``.
        fontsize: the legend uses ``fontsize - 2``.
    """
    plt.figure(figsize=(10, 5))
    plt.title("Total norm of weight matrices vs training loss")
    for idx, lam in enumerate(value):
        curve_label = r'$\lambda = $' + str(lam)
        norms, losses, curve_color = m_norm[idx], m_em_loss[idx], color[idx]
        plt.plot(norms, losses, label=curve_label, color=curve_color)
    plt.xlabel("Norms of weight matrices")
    plt.ylabel("Empirical loss")
    legend_font = {'size': fontsize - 2}
    plt.legend(loc="upper right", prop=legend_font)
    plt.savefig(fname="weightvsloss.png", dpi=200)

def plot_test_loss(m_test_loss, m_true_loss, v_test_loss, v_true_loss, color, alpha = 0.5, fontsize = 15, value = None):
    """Plot validation loss and Jacobian loss side by side.

    Two panels, each with one mean curve and a shaded mean +/- ``v_*`` band per
    entry of ``m_test_loss``. The figure is saved to ``valloss_jacobloss.png``;
    it is not shown.

    Args:
        m_test_loss, v_test_loss: per-curve mean / band half-width of the
            validation loss; ``len(m_test_loss[0])`` fixes the number of epochs
            on the x-axis.
        m_true_loss, v_true_loss: per-curve mean / band half-width of the
            Jacobian loss.
        color: one matplotlib colour per curve.
        alpha: opacity of the shaded bands.
        fontsize: legends use ``fontsize - 2``.
        value: optional regularisation strengths used as legend labels
            (``lambda = v``). When omitted, curves are labelled
            ``curve 0``, ``curve 1``, ...
    """
    # Labels and x-axis come from the arguments themselves rather than from
    # module-level globals, so the function also works when imported.
    if value is None:
        labels = [f'curve {index}' for index in range(len(m_test_loss))]
    else:
        labels = [r'$\lambda = $' + str(lamda) for lamda in value]
    x = np.arange(1, len(m_test_loss[0]) + 1)

    fig = plt.figure(figsize=(11, 5))
    # Figure-level title: plt.title() before plt.subplot() would attach it to
    # a stray full-figure Axes behind the panels.
    fig.suptitle("Validation loss and Jacobian loss")
    legend_prop = {'size': fontsize - 2}

    panels = (
        (1, m_test_loss, v_test_loss, "Validation loss"),
        (2, m_true_loss, v_true_loss, "Jacobian loss"),
    )
    for position, means, spreads, name in panels:
        ax = fig.add_subplot(1, 2, position)
        for index, label in enumerate(labels):
            mean, spread = means[index], spreads[index]
            ax.plot(x, mean, label=label, color=color[index])
            ax.fill_between(x, mean - spread, mean + spread, color=color[index], alpha=alpha)
        ax.set_title(name)
        ax.set_xlabel("Epochs")
        ax.set_ylabel(name)
        ax.legend(loc="upper right", prop=legend_prop)

    fig.savefig("valloss_jacobloss.png", dpi=200)

def extract_info_epoch(data, data_size = 100000, batch_size = 3000, n_epoch = 1000, n_run = 10):
    """Aggregate per-batch values into per-epoch, sample-weighted averages.

    ``data[i]`` is read as a flat sequence of per-batch values for run ``i``,
    laid out epoch after epoch with ``ceil(data_size / batch_size)`` batches
    per epoch. Each batch is weighted by the number of samples it holds:
    ``batch_size`` for every batch except the last one of an epoch, which
    holds the remaining ``data_size - (n_batches - 1) * batch_size`` samples.
    Each epoch's result is the weighted sum divided by ``data_size``.

    Args:
        data: indexable of at least ``n_run`` flat per-batch sequences.
        data_size: number of samples seen per epoch.
        batch_size: nominal batch size.
        n_epoch: number of epochs to read from each run.
        n_run: number of runs to read from ``data``.

    Returns:
        np.ndarray with one row per run and one column per epoch
        (shape ``(n_run, n_epoch)`` when ``n_run >= 1``).

    Raises:
        ValueError: if ``data_size`` or ``batch_size`` is not positive.
        IndexError: if a run holds fewer than ``n_epoch`` epochs of batches.
    """
    if data_size <= 0 or batch_size <= 0:
        raise ValueError("data_size and batch_size must be positive")

    nb_iter_per_epoch = math.ceil(data_size / batch_size)
    # Every batch is full except possibly the last one of each epoch. When
    # data_size is an exact multiple of batch_size this weight equals
    # batch_size, so the last batch is always included in the sum.
    last_batch_size = data_size - (nb_iter_per_epoch - 1) * batch_size
    weights = [batch_size] * (nb_iter_per_epoch - 1) + [last_batch_size]

    result = []
    for i in range(n_run):
        run = data[i]
        line = []
        for j in range(n_epoch):
            start = j * nb_iter_per_epoch
            cumulative_loss = 0
            for k, weight in enumerate(weights):
                cumulative_loss += weight * run[start + k]
            line.append(cumulative_loss / data_size)
        result.append(line)
    return np.array(result)
# result = np.load("training_evo_full.npz")
# normfc1 = result['arr_0']
# normfc2 = result['arr_1']
# emp_loss = result['arr_2']
# true_loss = result['arr_3']

# plot(normfc1, normfc2, emp_loss, true_loss)
if __name__ == "__main__":
    # Settings used to split the flat per-batch logs into epochs; they are
    # assumed to match the runs that produced the pickles — TODO confirm.
    data_size = 100000
    batch_size = 3000
    n_epoch = 1000
    n_run = 10

    # Regularisation strengths to compare, one plot colour per strength.
    value = [0.0, 0.0001, 0.0005, 0.001]
    color = ['red', 'blue', 'green', 'orange']

    # Load one pickled training log per lambda. The `with` block closes the file.
    # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
    all_dicts = {}
    for lamda in value:
        with open(f'training_evo_regularisation_{lamda}_LU.pickle', 'rb') as handle:
            all_dicts[lamda] = cPickle.load(handle)

    # Per-lambda statistics across runs: m_* hold means, v_* hold standard deviations.
    m_norm, m_norm1, m_norm2 = [], [], []
    v_norm, v_norm1, v_norm2 = [], [], []
    m_em_loss, m_true_loss, m_test_loss = [], [], []
    v_em_loss, v_true_loss, v_test_loss = [], [], []

    epoch_kwargs = dict(data_size=data_size, batch_size=batch_size, n_epoch=n_epoch, n_run=n_run)

    for lamda in value:
        logs = all_dicts[lamda]
        normfc1 = np.array(logs['normfc1'])
        normfc2 = np.array(logs['normfc2'])
        # Combined norm: sqrt(normfc1**2 + normfc2**2), stored back into the log dict.
        logs['normfc'] = np.sqrt(np.square(normfc1) + np.square(normfc2))

        # Mean and std of each norm across runs (axis 0).
        norm_targets = (
            (normfc1, m_norm1, v_norm1),
            (normfc2, m_norm2, v_norm2),
            (logs['normfc'], m_norm, v_norm),
        )
        for norms, means, stds in norm_targets:
            means.append(np.mean(norms, axis=0))
            stds.append(np.std(norms, axis=0))

        # Convert per-batch losses into per-epoch values, then take mean and std across runs.
        loss_targets = (
            ('emp_loss', m_em_loss, v_em_loss),
            ('true_loss', m_true_loss, v_true_loss),
            ('test_loss', m_test_loss, v_test_loss),
        )
        for key, means, stds in loss_targets:
            per_epoch = extract_info_epoch(logs[key], **epoch_kwargs)
            means.append(np.mean(per_epoch, axis=0))
            stds.append(np.std(per_epoch, axis=0))

    plot_compare(m_norm, m_norm1, m_norm2, v_norm, v_norm1, v_norm2, m_em_loss, v_em_loss, m_true_loss, v_true_loss, m_test_loss, v_test_loss, value, color)
    #plot_norm_vs_loss(m_norm, m_em_loss, value, color)
    #plot_test_loss(m_test_loss, m_true_loss, v_test_loss, v_true_loss, color)