
import matplotlib.pyplot as plt
import torch

import data as dat

import models as mod
from torch_geometric.utils import degree

from utils import smoothing_metric

#%%
# Close any figures left over from a previous run of this script.
plt.close('all')

# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'
print(device)

num_expe = 10  # number of independent repetitions (fresh model each time)

data_name = 'Cora'  # one of 'Cora', 'Pubmed', 'Citeseer', 'synthetic'
dense = True  # for synthetic data: selects the denser edge probabilities below

# Depth of the GNN: 40 layers for the dense synthetic graph, 100 otherwise.
if dense and data_name == 'synthetic':
    num_layers = 40
else:
    num_layers = 100

# NOTE(review): p_intra / p_inter are presumably intra-/inter-community edge
# probabilities consumed only by the synthetic generator — confirm in data.py.
data, num_classes, num_features = dat.load_data(
    data_name,
    p_intra=0.05 if dense else 0.01,
    p_inter=0.01 if dense else 0.005,
)

# Number of nodes, then per-node count of edges whose target (edge_index[1])
# is that node, in the same dtype as the node features.
n = data.x.size(0)
deg = degree(data.edge_index[1], n, dtype=data.x.dtype)

#%% GNN
# The labels are used as a single real-valued regression target under MSE.
lossfn = torch.nn.MSELoss()
num_classes = 1  # the model outputs one scalar per node
# Cast labels to float and add a trailing singleton axis so they line up with
# the model output column (same result as indexing with [:, None]).
data.y = data.y.float().unsqueeze(1)
data.to(device)

# Repeat a single forward/backward pass `num_expe` times, each with a freshly
# initialised model, and record one smoothing value per layer for
#   * the forward signals GNNB.Xs (uniform node weights), and
#   * the backward signals, i.e. the gradients stored on GNNB.Hs
#     (weights pi proportional to node degree).

# Degree-proportional node weights (the random-walk stationary distribution
# when the graph is undirected). They do not depend on the run, so compute once.
pi = deg / deg.sum()
# Uniform node weights for the forward metric; also run-independent.
uniform = torch.ones(n)

dirichxs, dirichbs = [], []
for _ in range(num_expe):
    # Fresh model for each repetition.
    GNNB = mod.GNNBiSto(num_node_features=num_features, num_classes=num_classes,
                        rec_intermediate_grad=True, num_layers=num_layers, num_units=32,
                        activation='leaky_relu', is_id=True, is_MLP=False, bistochastic=False,
                        std=0.01)
    GNNB.to(device)

    # --- one training step -------------------------------------------------
    # NOTE(review): the recorded signals are read from GNNB.Xs / GNNB.Hs below;
    # presumably they are captured during this forward/backward pass, so the
    # optimizer step itself does not influence them — confirm in models.py.
    optim = torch.optim.SGD(GNNB.parameters())
    optim.zero_grad()
    out = GNNB(data)
    # NOTE(review): MSELoss already averages over all elements; dividing by n
    # only rescales every gradient by 1/n. Kept as-is — confirm it is intended.
    loss = lossfn(out, data.y)/n
    loss.backward()
    optim.step()

    # --- record per-layer smoothing of forward and backward signals --------
    dirichx = [smoothing_metric(GNNB.Xs[layer].detach().cpu(), uniform)
               for layer in range(num_layers)]

    dirichb = []
    for layer in range(num_layers):
        grad = GNNB.Hs[layer].grad
        if grad is None:
            # Fail with a clear message instead of AttributeError on None.
            raise RuntimeError(f'no gradient recorded for layer {layer}; '
                               'was rec_intermediate_grad honoured?')
        dirichb.append(smoothing_metric(grad.cpu(), pi))

    dirichxs.append(dirichx)
    dirichbs.append(dirichb)
    
# Turn the per-run lists into tensors (one row per repetition, one column per
# layer, assuming smoothing_metric returns a scalar) and summarise each layer
# across repetitions.
dirichxs = torch.tensor(dirichxs)
dirichbs = torch.tensor(dirichbs)

xmean, xstd = dirichxs.mean(dim=0), dirichxs.std(dim=0)
bmean, bstd = dirichbs.mean(dim=0), dirichbs.std(dim=0)


# Left panel: forward oversmoothing E(H) per layer.
# Right panel: backward oversmoothing E_pi(B) of the gradients per layer.
# Each curve is the mean over repetitions; the band is mean +/- one std.
fig, (ax_fwd, ax_bwd) = plt.subplots(1, 2, figsize=(6, 3))
layers = range(num_layers)
panels = (
    (ax_fwd, xmean, xstd, r'$E(H)$'),
    (ax_bwd, bmean, bstd, r'$E_\pi(B)$'),  # mathtext: subscript pi, not a literal "_pi"
)
for ax, mean, std, title in panels:
    ax.semilogy(mean)
    # On the log axis a non-positive lower bound (mean - std <= 0) cannot be
    # drawn and is clipped/masked by matplotlib, so the band may look one-sided.
    ax.fill_between(layers, mean - std, mean + std, alpha=0.2)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Layer', fontsize=14)
# Keep the large titles/labels from being cut off in the small 6x3 figure.
fig.tight_layout()
