
import matplotlib.pyplot as plt
import numpy as np
import torch

import data as dat
import models as mod


# Debug aid: autograd anomaly detection checks the backward pass for NaNs and
# records forward traces, which makes every step noticeably slower.
torch.autograd.set_detect_anomaly(True)

# Start from a clean set of figures.
plt.close('all')

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

#%% param

num_epochs = 1000  # training epochs, per configuration


#%% data
# Dataset choice: 'Cora', 'Pubmed', 'Citeseer' or 'synthetic'.
data_name = 'Cora'
# Density switch for the synthetic graph.
dense = True

# SGD step size (Pubmed uses a slightly larger one).
if data_name == 'Pubmed':
    lr = 4e-4
else:
    lr = 3e-4

# Depth of the network: only the dense synthetic graph uses the shallower one.
if dense and data_name == 'synthetic':
    num_layers = 40
else:
    num_layers = 100

# Edge probabilities handed to the loader; presumably only used when the
# synthetic graph is generated -- confirm in data.load_data.
if dense:
    p_intra, p_inter = 0.05, 0.01
else:
    p_intra, p_inter = 0.01, 0.005

data, num_classes, num_features = dat.load_data(
    data_name,
    p_intra=p_intra,
    p_inter=p_inter,
)
# Row count of the feature matrix (presumably the number of nodes).
n = data.x.shape[0]


# Treat the task as regression: the labels become a float tensor with an
# extra axis inserted at position 1 (a column vector if y is 1-D -- assumed),
# the model gets a single output channel, and training minimises the MSE.
data.y = data.y.float().unsqueeze(1)
num_classes = 1
lossfn = torch.nn.MSELoss()

# Train the same deep GNN twice -- with and without skip connections -- and
# record, for every epoch, the training loss and the weight gradient of each
# linear layer in ``GNNB.lin_layers``.
#   gradientss[config][epoch][layer] -> numpy copy of weight.grad
#   lossess[config][epoch]           -> float loss
# NOTE(review): this keeps num_layers weight-gradient matrices per epoch for
# both configurations in host memory; that can get large for deep/long runs.
gradientss = []
lossess = []

# The graph does not change between configurations, so move it once.
data.to(device)

for skip in [True, False]:
    # Scale factor forwarded to the model; presumably it weights the residual
    # branch inside mod.GNNBiSto -- confirm there.
    scale = 5/num_layers if skip else 1

    # GNN
    GNNB = mod.GNNBiSto(num_node_features=num_features, num_classes=num_classes,
                        rec_intermediate_grad=True, num_layers=num_layers, num_units=32,
                        activation='leaky_relu', is_id=True, is_MLP=False, bistochastic=False,
                        std=0.01, skip=skip, scale=scale)
    GNNB.to(device)

    ################### train
    optim = torch.optim.SGD(GNNB.parameters(), lr=lr)

    gradients = []
    losses = []

    for epoch in range(num_epochs):
        optim.zero_grad()
        out = GNNB(data)

        loss = lossfn(out, data.y)
        loss.backward()
        optim.step()

        loss_value = loss.item()  # single host sync per epoch
        print(epoch, loss_value)
        losses.append(loss_value)

        # Snapshot every layer's weight gradient. For CPU tensors ``.cpu()``
        # returns the tensor itself and ``.numpy()`` shares its memory, so
        # without ``.copy()`` each stored array could alias the live ``.grad``
        # buffer and be overwritten when gradients are zeroed/accumulated in
        # place on later epochs.
        gradients.append([
            GNNB.lin_layers[layer].weight.grad.detach().cpu().numpy().copy()
            for layer in range(num_layers)
        ])

    gradientss.append(gradients)
    lossess.append(losses)
    

##### plot gradient norm
# One thin curve per layer and configuration: the mean squared entry of the
# layer's weight gradient (squared Frobenius norm / number of entries) at
# every epoch, on a log scale. The last linear layer (index num_layers-1) is
# left out of the plot.
c = ['r', 'b']
labels = ['with skip', 'without skip']
plt.figure(figsize=(6, 3))
for run_grads, colour, label in zip(gradientss, c, labels):
    epochs = np.arange(len(run_grads)) + 1
    for layer in reversed(range(num_layers - 1)):
        # Stack the per-epoch snapshots: (num_epochs, d_out, d_in).
        layer_grads = np.stack([epoch_grads[layer] for epoch_grads in run_grads])
        # Label only the first curve drawn for each configuration so the
        # legend has exactly one entry per configuration.
        curve_label = label if layer == num_layers - 2 else '_nolegend_'
        plt.semilogy(epochs, (layer_grads**2).mean(axis=(1, 2)), colour,
                     linewidth=.075, label=curve_label)

legend = plt.legend(fontsize=14)
# The data lines are hair-thin; draw the legend swatches thick enough to see.
for handle in legend.get_lines():
    handle.set_linewidth(3)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('Gradient norm', fontsize=14)


# Training-loss curves, one thick line per configuration, on a log scale.
plt.figure(figsize=(6, 3))
for run_losses, label, colour in zip(lossess, labels, c):
    plt.semilogy(run_losses, linewidth=3, label=label, c=colour)

plt.legend(fontsize=14)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('Loss', fontsize=14)
