import ast
import copy
import importlib as ipb
import json
import os
import sys

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import scipy.io
import torch
import torch.nn.functional as F
from matplotlib.ticker import MaxNLocator
from scipy.linalg import block_diag
from scipy.linalg import sqrtm
from scipy.sparse import csr_matrix
from scipy.stats import norm
from scipy.stats import ortho_group
from sklearn.metrics import f1_score
from torch.autograd import Variable
from torch_geometric.data import Data
from torch_geometric.datasets import GNNBenchmarkDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv

import utils_gnn_VI_layer_NeurIPS as utils_layer
# Use font type 42 (TrueType) for PDF/PS output instead of matplotlib's
# default, so text in the saved figures is embedded as real fonts.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Global font sizes applied to every figure produced by this script.
plt.rcParams['font.size'] = 24
plt.rcParams['axes.titlesize'] = 24
plt.rcParams['figure.titlesize'] = 30

# Re-import the helper module so edits made to it during an interactive
# session are picked up without restarting the interpreter.
ipb.reload(sys.modules['utils_gnn_VI_layer_NeurIPS'])

# Generate graph data

# Run on the GPU when CUDA is available, otherwise on the CPU. The model
# classes below read this module-level `device` inside `forward`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class GCN_SGD(torch.nn.Module):
    """Two-layer GCN trained end-to-end by ordinary back-propagation.

    Args:
        C: number of input features per node (in-channels of the first layer).
        F_out: number of output channels of the second layer.
        H: number of hidden channels.
        splus: if True use Softplus as the hidden activation, otherwise ReLU.
        beta: `beta` argument forwarded to `torch.nn.Softplus` (only used when
            `splus` is True).
    """

    def __init__(self, C, F_out=1, H=4, splus=False, beta=1):
        super().__init__()
        print(f'{H} hidden nodes')
        # Layers are created in this order so that seeded initialisation
        # stays reproducible.
        self.conv1 = GCNConv(C, H)
        self.conv2 = GCNConv(H, F_out)
        self.splus = splus
        self.beta = beta

    def forward(self, data):
        """Return sigmoid(conv2(act(conv1(x)))) for the graph batch `data`.

        `data` must expose `.x` and `.edge_index`; both are moved to the
        module-level `device` before use.
        """
        node_feats = data.x.to(device)
        edges = data.edge_index.to(device)
        if self.splus:
            activation = torch.nn.Softplus(beta=self.beta)
        else:
            activation = F.relu
        hidden = activation(self.conv1(node_feats, edges))
        logits = self.conv2(hidden, edges)
        # A probit link (standard normal CDF) was tried here previously; the
        # sigmoid link is the one in use.
        return torch.sigmoid(logits)


class GCN_VI(torch.nn.Module):
    """Two-layer GCN whose hidden output is exposed as a detached leaf tensor.

    The forward pass stores the activated first-layer output in
    `self.layer1_x` as a fresh leaf that requires grad. After `backward()`,
    `self.layer1_x.grad` holds the gradient of the loss with respect to the
    hidden response, while the parameters of `conv1` receive no gradient from
    autograd (the graph is cut at the leaf). The training helper is expected
    to build the first-layer update from `layer1_x.grad` itself -- TODO
    confirm against `utils_layer`.

    Args:
        C: number of input features per node.
        F_out: number of output channels of the second layer.
        H: number of hidden channels.
        splus: if True use Softplus as the hidden activation, otherwise ReLU.
        beta: `beta` argument forwarded to `torch.nn.Softplus`.
    """

    def __init__(self, C, F_out=1, H=4, splus=False, beta=1):
        super().__init__()
        print(f'{H} hidden nodes')
        self.conv1 = GCNConv(C, H)
        self.conv2 = GCNConv(H, F_out)
        self.splus = splus
        self.beta = beta

    def forward(self, data):
        """Return sigmoid(conv2(layer1_x)) and cache `self.layer1_x`.

        `data` must expose `.x` and `.edge_index`; both are moved to the
        module-level `device` before use.
        """
        x, edge_index = data.x.to(device), data.edge_index.to(device)
        layer1_x = self.conv1(x, edge_index)
        func = torch.nn.Softplus(beta=self.beta) if self.splus else F.relu
        # FIX: the activation used to be applied to the raw input (`x =
        # func(x)`) and the result was never used, so the hidden layer was
        # purely linear and did not match GCN_SGD. Apply it to the first-layer
        # response instead.
        layer1_x = func(layer1_x)
        # NOTE: the gradient for the 1st layer response is read from this
        # leaf. Wrapping in Variable detaches it from conv1, which is why
        # conv1's own parameters get no autograd gradient.
        self.layer1_x = Variable(layer1_x, requires_grad=True)
        layer2_x = self.conv2(self.layer1_x, edge_index)
        # A probit link (standard normal CDF) was tried here previously; the
        # sigmoid link is the one in use.
        return torch.sigmoid(layer2_x)


class GCN_feature(torch.nn.Module):
    """Feature extractor: GCN layer (C -> H), ReLU, then GCN layer (H -> H).

    In this script the second layer is overwritten with an identity weight and
    zero bias after construction, and the model is handed to the training
    helper as a "model to feature".

    Args:
        C: number of input features per node.
        H: number of hidden (and output) channels.
        more_layers: accepted for interface compatibility; not used by the
            visible code.
    """

    def __init__(self, C, H=4, more_layers=False):
        super().__init__()
        print(f'{H} hidden nodes')
        self.conv1 = GCNConv(C, H)
        self.conv2 = GCNConv(H, H)

    def forward(self, data):
        """Return conv2(relu(conv1(x))) without any output non-linearity."""
        node_feats = data.x.to(device)
        edges = data.edge_index.to(device)
        hidden = F.relu(self.conv1(node_feats, edges))
        return self.conv2(hidden, edges)


class GCN_feature1(torch.nn.Module):
    """Single GCN layer mapping C input features to C output features.

    In this script its weight is overwritten with the identity and its bias
    with zeros after construction, and the model is handed to the training
    helper as the first "model to feature".

    Args:
        C: number of input (and output) features per node.
    """

    def __init__(self, C):
        super().__init__()
        # FIX: this used to print the global `H`, which is not an argument of
        # this class and raised NameError whenever no global `H` existed.
        # The layer is C-by-C, so report C.
        print(f'{C} hidden nodes')
        self.conv1 = GCNConv(C, C)

    def forward(self, data):
        """Return conv1(x) for the graph batch `data` (no activation)."""
        x, edge_index = data.x.to(device), data.edge_index.to(device)
        x = self.conv1(x, edge_index)
        return x


def change_mod_param(model):
    """Load the true first-layer parameters into `model` and freeze them.

    Reads the module-level arrays `W1` and `b1`, writes them into the
    `conv1.lin.weight` / `conv1.bias` entries of the model's state dict, and
    then disables gradients for every parameter of the model's first child
    module.

    Args:
        model: a module whose state dict contains `conv1.bias` and
            `conv1.lin.weight`.

    Returns:
        The same `model` instance, modified in place.
    """
    state = model.state_dict()
    state['conv1.bias'] = torch.from_numpy(b1)
    state['conv1.lin.weight'] = torch.from_numpy(W1)
    model.load_state_dict(state)
    for position, submodule in enumerate(model.children()):
        # Only the first registered child is frozen.
        if position != 0:
            continue
        for weight in submodule.parameters():
            weight.requires_grad = False
    return model


# Ground-truth setup: draw a random graph, a perturbed ("estimated") copy of
# it, and a two-layer GCN with fixed weights that generates the data.
graph_type = 'small'
# Single-iteration loop that only groups the setup statements; `pp` is unused.
for pp in [1]:
    # Number of graph nodes.
    n = 40 if graph_type == 'large' else 15
    # Input feature dimension, true hidden width and output dimension of the
    # data-generating GCN.
    C = 2
    H_true = 2
    F_out = 1  # Multiple layer
    mu = 1
    # # For task type==2 or 3 & est. neuron = 4
    sigma = 1
    # Fixed NumPy seed so the true weights are identical on every run. The
    # order of the four draws below therefore matters.
    np.random.seed(2)
    W1 = np.random.normal(mu, sigma, H_true*C).reshape((H_true, C)
                                                       ).astype(np.float32)  # H_true-by-C
    b1 = np.random.normal(mu, sigma, H_true).astype(np.float32)
    # F-by-H_true, CRUCIAL to reset shape
    # NOTE: these parameters are used so we can have more balanced one and zero to make the problem harder
    W2 = np.random.normal(
        mu, sigma, F_out*H_true).reshape((F_out, H_true)).astype(np.float32)
    b2 = np.random.normal(mu, sigma, F_out).astype(np.float32)  # F-by-1
    G = nx.fast_gnp_random_graph(n=n, p=0.15, seed=1103)
    # Edge list transposed into a 2-row long tensor.
    # NOTE(review): `G.edges` presumably lists each undirected edge once;
    # confirm whether the helpers / GCNConv expect both directions.
    edge_index = torch.tensor(list(G.edges)).T.type(torch.long)
    # Perturbation level passed to G_reformat for the "estimated" graph.
    pertub = 0.2 if graph_type == 'small' else 0.05
    # NOTE(review): no module named `utils` is imported in the visible part of
    # this file, so this line raises NameError as written. Confirm which
    # module provides `G_reformat` and import it.
    G_est = utils.G_reformat(G, percent_perturb=pertub, return_G=True)
    edge_index_est = torch.tensor(list(G_est.edges)).T.type(torch.long)
    N = 2000  # Num training data
    N1 = 200  # Num test data
    batch_size = int(N/20)
    # utils_layer.draw_graph(edge_index, edge_index_est, graph_type)
    model_get_data = GCN_SGD(C, F_out, H_true).to(device)
    # NOTE: another way to change parameters, which FORCES me to make sure parameters match the shape I want
    # (load_state_dict rejects tensors whose shapes do not match the model).
    old_dict = model_get_data.state_dict()
    old_dict['conv1.bias'] = torch.from_numpy(b1)
    old_dict['conv1.lin.weight'] = torch.from_numpy(W1)
    old_dict['conv2.bias'] = torch.from_numpy(b2)
    old_dict['conv2.lin.weight'] = torch.from_numpy(W2)
    model_get_data.load_state_dict(old_dict)

# Update all layers
# For every hidden width, task and seed: simulate data from the true model,
# train one network with plain SGD/Adam and one with VI, store the per-seed
# metrics as JSON and save a summary figure as PDF.
num_epochs = 100
seeds = [1103, 1111, 1214]
H_ls = [H_true, 4, 8, 16, 32]
loss_type = 'MSE'
Adam = False
opt_type = '_Adam' if Adam else ''
for H in H_ls:
    tasks = [2, 3]
    for task_type in tasks:
        # Task 1: first layer fixed to the true parameters (parameter
        # recovery). Task 2: first layer unknown. Task 3: train on the
        # estimated graph and additionally evaluate on the true graph.
        compute_para_err = task_type == 1
        plot_para_recovery = task_type == 1
        result_SGD1_dict = {}
        result_VI1_dict = {}
        for seed in seeds:
            ipb.reload(sys.modules['utils_gnn_VI_layer_NeurIPS'])
            # Generate Data
            model_get_data.eval()
            print(f'True model: {list(model_get_data.parameters())}')
            X_train, Y_train = utils_layer.get_simulation_data(
                model_get_data, N, edge_index, n, C, torch_seed=seed)
            X_test, Y_test = utils_layer.get_simulation_data(
                model_get_data, N1, edge_index, n, C, train=False, torch_seed=seed)
            # The true-graph loaders are always built first (and replaced for
            # task 3) to keep the original call sequence unchanged.
            train_loader, test_loader = utils_layer.get_train_test_loader(
                X_train, X_test, Y_train, Y_test, edge_index, batch_size)
            test_loader_true = None
            if task_type == 3:
                train_loader, test_loader = utils_layer.get_train_test_loader(
                    X_train, X_test, Y_train, Y_test, edge_index_est, batch_size)
                _, test_loader_true = utils_layer.get_train_test_loader(
                    X_train, X_test, Y_train, Y_test, edge_index, batch_size)
            # Estimation
            torch.manual_seed(seed)  # For reproducibility
            model_SGD1 = GCN_SGD(C, H=H).to(device)
            if task_type == 1:
                model_SGD1 = change_mod_param(model_SGD1)
            # SGD
            mod_SGD1 = utils_layer.GCN_train(model_SGD1, train_loader,
                                             test_loader, model_get_data, test_loader_true)
            para_error_vanilla1, pred_l2error_vanilla1, pred_linferror_vanilla1, pred_loss_vanilla1 = mod_SGD1.training_and_eval(
                num_epochs, compute_para_err=compute_para_err, loss_type=loss_type, Adam=Adam)
            result_SGD1_dict[f'Seed {seed}'] = [para_error_vanilla1,
                                                pred_l2error_vanilla1, pred_linferror_vanilla1, pred_loss_vanilla1]
            # NOW VI
            # Re-seed so the VI model starts from the same initialisation
            # stream as the SGD model.
            torch.manual_seed(seed)  # For reproducibility
            model_VI1 = GCN_VI(C, H=H).to(device)
            if task_type == 1:
                model_VI1 = change_mod_param(model_VI1)
            # Feature models with identity weight / zero bias in their last
            # layer; they are passed to the VI training helper.
            mod_feature1 = GCN_feature1(C).to(device)
            old_dict = mod_feature1.state_dict()
            old_dict['conv1.bias'] = torch.zeros(C)
            old_dict['conv1.lin.weight'] = torch.diag(torch.ones(C))
            mod_feature1.load_state_dict(old_dict)
            mod_feature2 = GCN_feature(C, H).to(device)
            old_dict = mod_feature2.state_dict()
            old_dict['conv2.bias'] = torch.zeros(H)
            old_dict['conv2.lin.weight'] = torch.diag(torch.ones(H))
            mod_feature2.load_state_dict(old_dict)
            model_to_feature_ls = [mod_feature1, mod_feature2]
            mod_VI1 = utils_layer.GCN_train(model_VI1, train_loader,
                                            test_loader, model_get_data, test_loader_true)
            model_to_feature = mod_feature2 if task_type == 1 else None
            para_error_VI1, pred_l2error_VI1, pred_linferror_VI1, pred_loss_VI1 = mod_VI1.training_and_eval(
                num_epochs, compute_para_err=compute_para_err, model_to_feature_ls=model_to_feature_ls, model_to_feature=model_to_feature, loss_type=loss_type, Adam=Adam)
            result_VI1_dict[f'Seed {seed}'] = [para_error_VI1,
                                               pred_l2error_VI1, pred_linferror_VI1, pred_loss_VI1]
            # Quick per-seed comparison plot (shown, not saved).
            fig, ax = plt.subplots(figsize=(10, 5))
            ax.plot(pred_linferror_vanilla1, label='SGD')
            ax.plot(pred_linferror_VI1, label='VI')
            # Raw string: '\i' is not a valid escape sequence in a normal
            # string literal.
            ax.set_title(r'SGD vs. VI on $l_{\infty}$ prediction error')
            ax.legend(loc='best')
            ax.set_yscale('log')
            ax.set_xscale('log')
            plt.show()
        # The results are stored as a JSON string holding the dict's repr;
        # the loading code in this file decodes that exact format.
        json_SGD = json.dumps(str(result_SGD1_dict))
        json_VI = json.dumps(str(result_VI1_dict))
        # One base name per task, shared by the two JSON files and the PDF.
        # Defining it for task 1 as well keeps `name`/`name1` from being
        # unbound (or stale) when task 1 is enabled.
        if task_type == 1:
            base_name = f'Simulation_a_{graph_type}_first_layer_fully_known_new_new_all_layer_{loss_type}{opt_type}'
        elif task_type == 2:
            base_name = f'Simulation_b_{graph_type}_first_layer_not_known_H={H}_new_new_all_layer_{loss_type}{opt_type}'
        else:
            base_name = f'Simulation_c_{graph_type}_est_graph_H={H}_new_new_all_layer_{loss_type}{opt_type}'
        name = f'SGD_{base_name}'
        name1 = f'VI_{base_name}'
        # Context managers close the files even if a write fails.
        with open(f"{name}.json", "w") as f:
            f.write(json_SGD)
        with open(f"{name1}.json", "w") as f:
            f.write(json_VI)
        para_error_vanilla, para_error_vanillaSE, pred_l2error_vanilla, pred_l2error_vanillaSE, pred_linferror_vanilla, pred_linferror_vanillaSE, pred_loss_vanilla, pred_loss_vanillaSE = utils_layer.get_all(
            result_SGD1_dict)
        para_error_VI, para_error_VISE, pred_l2error_VI, pred_l2error_VISE, pred_linferror_VI, pred_linferror_VISE, pred_loss_VI, pred_loss_VISE = utils_layer.get_all(
            result_VI1_dict)
        fig = utils_layer.simulation_plot(para_error_vanilla, pred_l2error_vanilla, pred_loss_vanilla, pred_linferror_vanilla, para_error_VI, pred_l2error_VI, pred_loss_VI, pred_linferror_VI,
                                          para_error_vanillaSE, pred_l2error_vanillaSE, pred_loss_vanillaSE, pred_linferror_vanillaSE, para_error_VISE, pred_l2error_VISE, pred_loss_VISE, pred_linferror_VISE, plot_para_recovery=plot_para_recovery, loss_type=loss_type, Adam=Adam)
        fig.savefig(f'{base_name}.pdf',
                    dpi=300, bbox_inches='tight', pad_inches=0)
        # # Sanity check, regarding how they differ
        # print(list(model_SGD1.parameters())[-2:])
        # for para in model_SGD1.parameters():
        #     print(para.grad)
        # print(list(model_VI1.parameters())[-2:])
        # for para in model_VI1.parameters():
        #     print(para.grad)


# Use saved JSON data:
# Plot prediction with and without knowing graph together, with label
# For each H, load the case in task_type b and c
# Also, save results in the Table below
# Table:
# rows are the numbers of hidden neurons
# columns are three metrics, each with mean and SE for four settings:
# SGD and VI-SGD under the true graph and under the estimated graph.
# So 3 metrics x 4 settings x (mean, SE) = 24 entries per row.
def _load_result_dict(path):
    """Return the result dict stored at `path`.

    The files are written as `json.dumps(str(result_dict))`, i.e. a JSON
    string that contains the dict's Python repr, so decoding takes two steps.
    `ast.literal_eval` only evaluates Python literals.
    """
    with open(path, 'r') as handle:
        return ast.literal_eval(json.loads(handle.read()))


H_true = 2
H_ls = [H_true, 4, 8, 16, 32]
loss_type = 'MSE'
Adam = False
opt_type = '_Adam' if Adam else ''
Table_dict = {}
SGD_label = 'Adam (Known)' if Adam else 'SGD (Known)'
SGD_label1 = 'Adam (Est)' if Adam else 'SGD (Est)'
opt_type_sub = 'Adam' if Adam else 'SGD'
make_fig = True  # If just get table, do not set to True
gtypes = ['small', 'large']
# Top-level column labels. The last one is a raw string because '\i' is not a
# valid escape sequence in a normal string literal.
metric_names = ['Posterior prediction--$l_2$ norm',
                f'{loss_type} loss', r'Posterior prediction--$l_{\infty}$ norm']
for graph_type in gtypes:
    # One row per hidden width, 3 metrics x 8 (mean/SE) columns.
    Table = np.zeros((len(H_ls), 2*12))
    columns = np.tile([SGD_label+' mean', SGD_label+' SE', SGD_label1+' mean', SGD_label1+' SE',
                       f'VI-{SGD_label}'+' mean', f'VI-{SGD_label}'+' SE', f'VI-{SGD_label1}'+' mean', f'VI-{SGD_label1}'+' SE'], 3)
    # Named `metric_type` rather than `type` to avoid shadowing the builtin.
    metric_type = np.repeat(metric_names, 8)
    tuples = list(zip(metric_type, columns))
    index = pd.MultiIndex.from_tuples(tuples)
    ipb.reload(sys.modules['utils_gnn_VI_layer_NeurIPS'])
    for k, H in enumerate(H_ls):
        # Case b, knowing graph completely
        name = f'SGD_Simulation_b_{graph_type}_first_layer_not_known_H={H}_new_new_all_layer_{loss_type}{opt_type}.json'
        name1 = f'VI_Simulation_b_{graph_type}_first_layer_not_known_H={H}_new_new_all_layer_{loss_type}{opt_type}.json'
        result_SGD1_dict = _load_result_dict(name)
        result_VI1_dict = _load_result_dict(name1)
        SGD_know_graph = utils_layer.get_all(result_SGD1_dict)
        VI_know_graph = utils_layer.get_all(result_VI1_dict)
        # Case c, estimate graph
        name = f'SGD_Simulation_c_{graph_type}_est_graph_H={H}_new_new_all_layer_{loss_type}{opt_type}.json'
        name1 = f'VI_Simulation_c_{graph_type}_est_graph_H={H}_new_new_all_layer_{loss_type}{opt_type}.json'
        result_SGD1_dict = _load_result_dict(name)
        result_VI1_dict = _load_result_dict(name1)
        SGD_est_graph = utils_layer.get_all(result_SGD1_dict)
        VI_est_graph = utils_layer.get_all(result_VI1_dict)
        # Only the widest network is drawn in the single-plot layout.
        single_plot = H == 32
        fig, long_ls = utils_layer.simulation_plot_know_est_graph(
            SGD_know_graph, VI_know_graph, SGD_est_graph, VI_est_graph, loss_type=loss_type, Adam=Adam, make_fig=make_fig, single_plot=single_plot)
        Table[k] = long_ls
        # The helper's first return value is compared with 0 to detect that no
        # figure was produced -- presumably when make_fig is False.
        if fig != 0:
            sing_idx = '_single' if single_plot else ''
            fig.savefig(f'{graph_type}_graph_know_est_graph_H={H}_{loss_type}{opt_type}{sing_idx}.pdf',
                        dpi=300, bbox_inches='tight', pad_inches=0)
    Table = pd.DataFrame(Table, index=H_ls, columns=index)
    Table.index.name = '# Hidden nodes'
    Table_dict[graph_type] = Table
# Condensed tables: one column per (metric, setting) instead of separate
# mean / SE columns.
columns = np.tile([SGD_label, SGD_label1,
                   f'VI-{SGD_label}', f'VI-{SGD_label1}'], 3)
metric_type = np.repeat(metric_names, 4)
tuples = list(zip(metric_type, columns))
new_colindex = pd.MultiIndex.from_tuples(tuples)
round_more = False
Table_new_small = utils_layer.concatenat_to_one(
    Table_dict['small'], new_colindex, round_more)
Table_new_large = utils_layer.concatenat_to_one(
    Table_dict['large'], new_colindex, round_more)

# escape=False keeps the LaTeX math in the column labels intact.
print(Table_new_small.to_latex(escape=False))
print(Table_new_large.to_latex(escape=False))

# Visualize dynamics of weight update
# Setup: C=2 (so easy to see), H large (like several hundred), so each column in the matrix of shape (C-by-H)
# denotes the weight on node i
# For now, still include bias, but we just visualize weights
# 1. Start with the same initialization of parameter and store them
# 2. After estimation, retrieve the parameters and compute their inner product with the initial ones

# Then plot both on the plot


def plot_dynamics(a_SGD, w_SGD, a_VI, w_VI, pred_linferror_vanilla, pred_linferror_VI):
    """Plot the error curves next to the weight movement of both optimizers.

    The left panel shows the two error sequences per epoch. The right column
    has one panel per optimizer (top: baseline, bottom: VI) where every
    hidden unit is a point (a_i, w_i); its start and end positions are joined
    by a thin line and arrows are drawn at the start positions.

    Reads the module-level names `Adam`, `H` and `num_epochs`.

    Args:
        a_SGD, w_SGD: indexable pairs whose entry 0 holds the initial values
            and entry 1 the final values for the baseline optimizer.
        a_VI, w_VI: same layout for the VI-trained model.
        pred_linferror_vanilla: error sequence for the baseline optimizer.
        pred_linferror_VI: error sequence for the VI-trained model.

    Returns:
        The matplotlib figure.
    """
    # Plot parameters distribution (Fig. 2a)
    figure = plt.figure(tight_layout=True, figsize=(14, 5))
    grid = gridspec.GridSpec(2, 5)
    split_col = 2

    # Left panel: error versus epoch.
    err_ax = figure.add_subplot(grid[:, :split_col])
    err_ax.set_xlabel("Epoch", fontsize=16)
    err_ax.set_ylabel("Error", fontsize=16)
    baseline_name = 'Adam' if Adam else 'SGD'
    err_ax.plot(pred_linferror_vanilla, label=f'{baseline_name}')
    err_ax.plot(pred_linferror_VI, label='SVI')
    err_ax.legend(loc='upper right', fontsize=14)
    err_ax.tick_params(labelsize=14)

    # Axis limits are hand-tuned per hidden width.
    if H == 50:
        y_top, y_bottom, x_half = 0.25, -0.05, 0.4
    else:
        y_top, y_bottom, x_half = 0.14, -0.02, 0.3

    # Right panels: one row per optimizer.
    for row, (a, w) in enumerate(((a_SGD, w_SGD), (a_VI, w_VI))):
        panel = figure.add_subplot(grid[row, split_col:])
        panel.set_ylim(y_bottom, y_top)
        panel.set_xlim(-x_half, x_half)
        panel.set_xlabel(r"$a_i$", fontsize=16)
        panel.set_ylabel(r"$w^\parallel_i$", fontsize=16)
        panel.tick_params(labelsize=14)
        # Thin segment from each unit's start to its end position.
        for unit, a_start in enumerate(a[0]):
            panel.plot([a_start, a[1][unit]], [w[0][unit], w[1][unit]],
                       '-', color='black', linewidth=0.5)
        panel.plot(a[0], w[0], 'o', color='grey',
                   mfc='none', mew=1.5, label="Epoch = 0")
        panel.plot(a[-1], w[-1], 'o', color='purple', mfc='none', mew=1.5,
                   label=f"Epoch = {num_epochs}")
        # Arrows at the start positions with components (w, a), scaled up.
        mult = 10
        panel.quiver(a[0], w[0], mult*w[0], mult*a[0], color='black',
                     scale=60.0, width=0.003, headwidth=3)
        # A single shared legend above the top panel.
        if row == 0:
            panel.legend(ncol=2, fontsize=14, loc='upper center',
                         bbox_to_anchor=(0.5, 1.4))
    plt.tight_layout()
    return figure


# Ground-truth setup for the weight-dynamics experiment. Same construction as
# the earlier setup section, but with N = N1 = 1000 samples.
graph_type = 'small'
# Single-iteration loop that only groups the setup statements; `pp` is unused.
for pp in [1]:
    # Number of graph nodes.
    n = 40 if graph_type == 'large' else 15
    # Input feature dimension, true hidden width and output dimension of the
    # data-generating GCN.
    C = 2
    H_true = 2
    F_out = 1  # Multiple layer
    mu = 1
    # # For task type==2 or 3 & est. neuron = 4
    sigma = 1
    # Fixed NumPy seed so the true weights are identical on every run. The
    # order of the four draws below therefore matters.
    np.random.seed(2)
    W1 = np.random.normal(mu, sigma, H_true*C).reshape((H_true, C)
                                                       ).astype(np.float32)  # H_true-by-C
    b1 = np.random.normal(mu, sigma, H_true).astype(np.float32)
    # F-by-H_true, CRUCIAL to reset shape
    # NOTE: these parameters are used so we can have more balanced one and zero to make the problem harder
    W2 = np.random.normal(
        mu, sigma, F_out*H_true).reshape((F_out, H_true)).astype(np.float32)
    b2 = np.random.normal(mu, sigma, F_out).astype(np.float32)  # F-by-1
    G = nx.fast_gnp_random_graph(n=n, p=0.15, seed=1103)
    # Edge list transposed into a 2-row long tensor.
    edge_index = torch.tensor(list(G.edges)).T.type(torch.long)
    # Perturbation level passed to G_reformat for the "estimated" graph.
    pertub = 0.2 if graph_type == 'small' else 0.05
    # NOTE(review): no module named `utils` is imported in the visible part of
    # this file, so this line raises NameError as written. Confirm which
    # module provides `G_reformat` and import it.
    G_est = utils.G_reformat(G, percent_perturb=pertub, return_G=True)
    edge_index_est = torch.tensor(list(G_est.edges)).T.type(torch.long)
    N = 1000  # Num training data
    N1 = 1000  # Num test data
    batch_size = int(N/20)
    # utils_layer.draw_graph(edge_index, edge_index_est, graph_type)
    model_get_data = GCN_SGD(C, F_out, H_true).to(device)
    # NOTE: another way to change parameters, which FORCES me to make sure parameters match the shape I want
    # (load_state_dict rejects tensors whose shapes do not match the model).
    old_dict = model_get_data.state_dict()
    old_dict['conv1.bias'] = torch.from_numpy(b1)
    old_dict['conv1.lin.weight'] = torch.from_numpy(W1)
    old_dict['conv2.bias'] = torch.from_numpy(b2)
    old_dict['conv2.lin.weight'] = torch.from_numpy(W2)
    model_get_data.load_state_dict(old_dict)
# Weight-dynamics experiment: for each activation, loss and hidden width,
# train one model with SGD/Adam and one with VI from the same seed, save the
# initial and final state dicts plus the error curves, and plot the dynamics.
seed = 1103
Adam = False
opt_type = '_Adam' if Adam else ''
num_epochs = 10
# Print progress about ten times per run. The max() keeps the interval from
# becoming 0 (a zero modulus) when num_epochs < 10.
print_every = max(1, num_epochs // 10)
for splus in [True, False]:
    beta = 2.5  # This is the best value for SGD
    for loss_type in ['Cross-Entropy', 'MSE']:
        for H in [50, 100]:
            # Train model
            result_dict = {'SGD': [], 'VI': []}
            ipb.reload(sys.modules['utils_gnn_VI_layer_NeurIPS'])
            compute_para_err = False
            plot_para_recovery = False
            result_SGD1_dict = {}
            result_VI1_dict = {}
            # Generate Data
            model_get_data.eval()
            print(f'True model: {list(model_get_data.parameters())}')
            X_train, Y_train = utils_layer.get_simulation_data(
                model_get_data, N, edge_index, n, C, torch_seed=seed)
            X_test, Y_test = utils_layer.get_simulation_data(
                model_get_data, N1, edge_index, n, C, train=False, torch_seed=seed)
            train_loader, test_loader = utils_layer.get_train_test_loader(
                X_train, X_test, Y_train, Y_test, edge_index, batch_size)
            # Estimation
            # SGD first
            torch.manual_seed(seed)  # For reproducibility
            model_SGD1 = GCN_SGD(C, H=H, splus=splus, beta=beta).to(device)
            # Deep copies: state_dict() returns references to the live
            # parameters, which training would otherwise overwrite.
            SGD_dict_ref = copy.deepcopy(model_SGD1.state_dict())
            para_error_vanilla = []
            pred_l2error_vanilla = []
            pred_linferror_vanilla = []
            pred_loss_vanilla = []
            for epoch in range(num_epochs):
                print(f'SGD epoch {epoch}')
                train_loss = utils_layer.train_revised_all_layer(train_loader,
                                                                 model_to_train=model_SGD1, output_dim=1, loss_type=loss_type, Adam=Adam)
                para_err, l2_err, linf_err, loss_true = utils_layer.evaluation_simulation(
                    test_loader, model_get_data, model_SGD1, W2=[], b2=[], data_loader_true=test_loader, loss_type=loss_type)
                para_error_vanilla.append(para_err)
                pred_l2error_vanilla.append(l2_err)
                pred_linferror_vanilla.append(linf_err)
                pred_loss_vanilla.append(loss_true)
                if epoch % print_every == 0:
                    print(
                        f'[rel Para err, rel l2 err, rel linf err, rel entropy loss] at {epoch} is \n {[para_err, l2_err, linf_err, loss_true]}')
            result_dict['SGD'] = [pred_l2error_vanilla,
                                  pred_linferror_vanilla, pred_loss_vanilla]
            SGD_dict_final = copy.deepcopy(model_SGD1.state_dict())
            # NOW VI
            # Re-seed so the VI model is initialised from the same random
            # stream as the SGD model.
            torch.manual_seed(seed)  # For reproducibility
            model_VI1 = GCN_VI(C, H=H, splus=splus, beta=beta).to(device)
            VI_dict_ref = copy.deepcopy(model_VI1.state_dict())
            # Feature models with identity weight / zero bias in their last
            # layer; they are passed to the VI training helper.
            mod_feature1 = GCN_feature1(C).to(device)
            old_dict = mod_feature1.state_dict()
            old_dict['conv1.bias'] = torch.zeros(C)
            old_dict['conv1.lin.weight'] = torch.diag(torch.ones(C))
            mod_feature1.load_state_dict(old_dict)
            mod_feature2 = GCN_feature(C, H).to(device)
            old_dict = mod_feature2.state_dict()
            old_dict['conv2.bias'] = torch.zeros(H)
            old_dict['conv2.lin.weight'] = torch.diag(torch.ones(H))
            mod_feature2.load_state_dict(old_dict)
            model_to_feature_ls = [mod_feature1, mod_feature2]
            para_error_VI = []
            pred_l2error_VI = []
            pred_linferror_VI = []
            pred_loss_VI = []
            for epoch in range(num_epochs):
                print(f'VI epoch {epoch}')
                train_loss = utils_layer.train_revised_all_layer(train_loader,
                                                                 model_to_train=model_VI1, output_dim=1, model_to_feature_ls=model_to_feature_ls, loss_type=loss_type, Adam=Adam)
                para_err, l2_err, linf_err, loss_true = utils_layer.evaluation_simulation(
                    test_loader, model_get_data, model_VI1, W2=[], b2=[], data_loader_true=test_loader, loss_type=loss_type)
                para_error_VI.append(para_err)
                pred_l2error_VI.append(l2_err)
                pred_linferror_VI.append(linf_err)
                pred_loss_VI.append(loss_true)
            result_dict['VI'] = [pred_l2error_VI,
                                 pred_linferror_VI, pred_loss_VI]
            VI_dict_final = copy.deepcopy(model_VI1.state_dict())
            splus_suff = '_splus' if splus else ''
            # Only the SGD model's initial state is saved as the reference;
            # the reload section reuses it for both optimizers.
            name1 = f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}_ref{splus_suff}.pth'
            name2 = f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}_SGD{splus_suff}.pth'
            name3 = f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}_VI{splus_suff}.pth'
            torch.save(SGD_dict_ref, name1)
            torch.save(SGD_dict_final, name2)
            torch.save(VI_dict_final, name3)
            # Stored as a JSON string holding the dict's repr; the reload
            # section decodes that exact format.
            json_linf = json.dumps(str(result_dict))
            name = f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}{splus_suff}'
            # Context manager closes the file even if the write fails.
            with open(f"{name}.json", "w") as f:
                f.write(json_linf)
            # Get parameters
            key = 'conv1.lin.weight'
            w_ref = SGD_dict_ref[key].cpu().detach().numpy()
            w_SGD = SGD_dict_final[key].cpu().detach().numpy()
            w_VI = VI_dict_final[key].cpu().detach().numpy()
            # Row-wise inner products with the initial first-layer weights:
            # squared norm at the start, projection onto the start at the end.
            w_ref_inner = np.sum(w_ref*w_ref, axis=1)
            w_SGD_final = np.sum(w_ref*w_SGD, axis=1)
            w_VI_final = np.sum(w_ref*w_VI, axis=1)
            key2 = 'conv2.lin.weight'
            a_SGD_ref = SGD_dict_ref[key2].cpu().detach().numpy().flatten()
            a_VI_ref = VI_dict_ref[key2].cpu().detach().numpy().flatten()
            a_SGD_final = SGD_dict_final[key2].cpu(
            ).detach().numpy().flatten()
            a_VI_final = VI_dict_final[key2].cpu(
            ).detach().numpy().flatten()
            # [initial, final] pairs in the layout plot_dynamics indexes.
            a_SGD = [a_SGD_ref, a_SGD_final]
            a_VI = [a_VI_ref, a_VI_final]
            w_SGD = [w_ref_inner, w_SGD_final]
            # NOTE(review): the VI trajectory is measured against the SGD
            # model's initial first-layer weights; presumably identical to the
            # VI initialisation because both use the same seed -- confirm.
            w_VI = [w_ref_inner, w_VI_final]
            fig = plot_dynamics(a_SGD, w_SGD, a_VI, w_VI,
                                pred_linferror_vanilla, pred_linferror_VI)
            fig.savefig(f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}{splus_suff}.pdf',
                        dpi=300, bbox_inches='tight', pad_inches=0)


# Load JSON files to remake plots
# Rebuilds the dynamics figures from the saved state dicts (.pth) and error
# curves (.json) without retraining. Relies on the globals `C` and `beta`
# set by the earlier sections of this script.
Adam = False
opt_type = '_Adam' if Adam else ''
# plot_dynamics reads this global for the "Epoch = ..." legend label.
# NOTE(review): the training section above uses num_epochs = 10; confirm this
# value matches the run that produced the saved files.
num_epochs = 200
for splus in [True, False]:
    splus_suff = '_splus' if splus else ''
    for loss_type in ['Cross-Entropy', 'MSE']:
        for H in [50, 100]:
            name1 = f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}_ref{splus_suff}.pth'
            name2 = f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}_SGD{splus_suff}.pth'
            name3 = f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}_VI{splus_suff}.pth'
            # One model instance is reused as a container: each saved state
            # dict is loaded into it and then deep-copied out again.
            model = GCN_VI(C, H=H, splus=splus, beta=beta).to(device)
            model.load_state_dict(torch.load(
                name1, map_location=torch.device('cpu')))
            SGD_dict_ref = copy.deepcopy(model.state_dict())
            # Only one reference (initial) state is stored on disk, so the VI
            # reference is a copy of the SGD reference.
            VI_dict_ref = copy.deepcopy(SGD_dict_ref)
            model.load_state_dict(torch.load(
                name2, map_location=torch.device('cpu')))
            SGD_dict_final = copy.deepcopy(model.state_dict())
            model.load_state_dict(torch.load(
                name3, map_location=torch.device('cpu')))
            VI_dict_final = copy.deepcopy(model.state_dict())
            name = f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}{splus_suff}.json'
            # The file holds a JSON string containing the dict's Python repr,
            # hence the two-step decode.
            with open(name, 'r') as j:
                result_dict = json.loads(j.read())
                result_dict = ast.literal_eval(result_dict)
            # Entry 1 of each list is the l-infinity error sequence (the
            # training section stores [l2, linf, loss]).
            pred_linferror_vanilla = result_dict['SGD'][1]
            pred_linferror_VI = result_dict['VI'][1]
            # Single-iteration loop used only to group the statements below.
            for j in [1]:
                key = 'conv1.lin.weight'
                w_ref = SGD_dict_ref[key].cpu().detach().numpy()
                w_SGD = SGD_dict_final[key].cpu().detach().numpy()
                w_VI = VI_dict_final[key].cpu().detach().numpy()
                # Row-wise inner products with the initial first-layer weights.
                w_ref_inner = np.sum(w_ref*w_ref, axis=1)
                w_SGD_final = np.sum(w_ref*w_SGD, axis=1)
                w_VI_final = np.sum(w_ref*w_VI, axis=1)
                key2 = 'conv2.lin.weight'
                a_SGD_ref = SGD_dict_ref[key2].cpu().detach().numpy().flatten()
                a_VI_ref = VI_dict_ref[key2].cpu().detach().numpy().flatten()
                a_SGD_final = SGD_dict_final[key2].cpu(
                ).detach().numpy().flatten()
                a_VI_final = VI_dict_final[key2].cpu(
                ).detach().numpy().flatten()
                # [initial, final] pairs in the layout plot_dynamics indexes.
                a_SGD = [a_SGD_ref, a_SGD_final]
                a_VI = [a_VI_ref, a_VI_final]
                w_SGD = [w_ref_inner, w_SGD_final]
                w_VI = [w_ref_inner, w_VI_final]
            fig = plot_dynamics(a_SGD, w_SGD, a_VI, w_VI,
                                pred_linferror_vanilla, pred_linferror_VI)
            fig.savefig(f'SGD_VI_dynamics_H={H}_{loss_type}{opt_type}{splus_suff}.pdf',
                        dpi=300, bbox_inches='tight', pad_inches=0)
