
import matplotlib.pyplot as plt
import torch

import data as dat
import models as mod
from torch_geometric.utils import degree


from utils import smoothing_metric

#%% setup: device, hyper-parameters and dataset
plt.close('all')

# Prefer the first GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

skip = True     # forwarded as `skip=` to mod.GNNBiSto (presumably skip connections)
num_expe = 10   # independent model initialisations per scale

data_name = 'Cora'  # one of 'Cora', 'Pubmed', 'Citeseer', 'synthetic'
dense = True        # only used to pick synthetic-graph parameters below

if data_name == 'Pubmed':
    lr = 4e-4
else:
    lr = 3e-4

if dense and data_name == 'synthetic':
    num_layers = 40
else:
    num_layers = 100

# Forwarded to dat.load_data; presumably the within/between-community edge
# probabilities of the synthetic graph generator (TODO confirm in data.py).
if dense:
    p_intra, p_inter = 0.05, 0.01
else:
    p_intra, p_inter = 0.01, 0.005

data, num_classes, num_features = dat.load_data(
    data_name, p_intra=p_intra, p_inter=p_inter)

# Number of nodes, and in-degree of every node (counted over edge targets).
n = data.x.shape[0]
deg = degree(data.edge_index[1], n, dtype=data.x.dtype)

#%% GNN
# For each scale, build `num_expe` freshly initialised models. Run one SGD step
# on an MSE regression of the node labels. Record per-layer smoothing metrics of
# the forward signals (GNNB.Xs) and of the gradients w.r.t. GNNB.Hs.

lossfn = torch.nn.MSELoss()
num_classes = 1                      # single regression output
data.y = data.y.float()[:, None]     # labels as a float column vector

data = data.to(device)

# Values of s*L that are swept; the model's `scale` argument is sL / num_layers.
sL_values = [1, 5, 10]

# Node weights for the metrics. Both are loop-invariant, so they are built once.
# `deg` was computed before `data` moved to `device`, so these stay on the CPU,
# matching the `.cpu()` copies of the recorded signals.
pi = deg / deg.sum()      # degree-proportional weights (backward metric)
uniform = torch.ones(n)   # uniform weights (forward metric)

dirichxss, dirichbss = [], []
for sL in sL_values:
    scale = sL / num_layers

    dirichxs, dirichbs = [], []
    for _ in range(num_expe):
        GNNB = mod.GNNBiSto(num_node_features=num_features, num_classes=num_classes,
                            rec_intermediate_grad=True, num_layers=num_layers, num_units=32,
                            activation='leaky_relu', is_id=True, is_MLP=False, bistochastic=False,
                            std=0.01, skip=skip, scale=scale)
        GNNB.to(device)

        # --- train: a single SGD step with the configured learning rate ---
        # (SGD needs `lr` explicitly on older PyTorch; it was configured above but unused.)
        optim = torch.optim.SGD(GNNB.parameters(), lr=lr)
        optim.zero_grad()
        out = GNNB(data)
        # MSELoss already averages over nodes; the extra /n is kept from the
        # original experiment (presumably deliberate gradient scaling — confirm).
        loss = lossfn(out, data.y) / n
        loss.backward()
        optim.step()

        # --- record signals ---
        # Forward: smoothing of each layer's recorded signal, uniform node weights.
        dirichx = [smoothing_metric(GNNB.Xs[layer].detach().cpu(), uniform)
                   for layer in range(num_layers)]

        # Backward: smoothing of each layer's gradient, degree-proportional weights.
        dirichb = []
        for layer in range(num_layers):
            grad = GNNB.Hs[layer].grad
            if grad is None:
                # Non-leaf tensors only keep .grad if retain_grad() was called.
                raise RuntimeError(
                    f"no gradient recorded for GNNB.Hs[{layer}]; "
                    "intermediate gradients were not retained")
            dirichb.append(smoothing_metric(grad.cpu(), pi))

        dirichxs.append(dirichx)
        dirichbs.append(dirichb)

    dirichxss.append(dirichxs)
    dirichbss.append(dirichbs)

# Nesting is [scale][experiment][layer]. Assuming smoothing_metric returns a
# scalar, the tensors are (num_scales, num_expe, num_layers). Statistics over
# dim=1 are therefore taken across the independent runs.
dirichxss = torch.tensor(dirichxss)
xmean = dirichxss.mean(dim=1)
xstd = dirichxss.std(dim=1)
dirichbss = torch.tensor(dirichbss)
bmean = dirichbss.mean(dim=1)
bstd = dirichbss.std(dim=1)


# Two panels: forward smoothing E(H) and backward smoothing E_pi(B) versus depth.
# Each panel has one curve per scale, showing the mean with a ±1 std band
# across runs.
plt.figure(figsize=(6,3))
legends=['sL=1', 'sL=5', 'sL=10']
cs=['k', 'r', 'g']

# Take the number of curves from the data instead of hard-coding it. Then fail
# loudly if a scale would be left without a label or colour.
num_scales = xmean.shape[0]
if len(legends) < num_scales or len(cs) < num_scales:
    raise ValueError(
        f"need a label and a colour for each of the {num_scales} scales")
layers = range(num_layers)

panels = [
    (xmean, xstd, 'E(H)'),      # forward oversmoothing
    (bmean, bstd, 'E_pi(B)'),   # backward oversmoothing
]
for col, (mean, std, title) in enumerate(panels, start=1):
    ax = plt.subplot(1, 2, col)
    for i in range(num_scales):
        ax.semilogy(mean[i, :], label=legends[i], c=cs[i])
        ax.fill_between(layers, mean[i, :] - std[i, :],
                        mean[i, :] + std[i, :], alpha=0.2, color=cs[i])
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Layer', fontsize=14)
    if col == 1:
        # Both panels share the same curves, so one legend is enough.
        ax.legend(fontsize=14)
