

'''Uncomment the commands below if the required packages are not installed yet.'''
# !pip install torch_geometric
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "import torch; print(torch.__version__)").html

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.data import Data
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from torch_geometric.nn import GATConv, SAGEConv, GCNConv
from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties

"""#hyperparameters"""

# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Number of nodes in the synthetic graph.
num_nodes = 100

# Edge probabilities: p for node pairs sharing a label, q for pairs with different labels.
p = 0.9
q = 0.1

# Feature means for label 0 (mu1) and label 1 (mu2); sigma is the common standard deviation.
mu1 = 0.5
mu2 = -0.5
sigma = 1

# Depth values (layer counts / integration steps) swept by the experiments.
ll = list(range(2, 100))

# Number of independent repetitions averaged for every depth value.
monte_carlo_runs = 10

"""#data label and features"""

# One binary label per node, then 2-D Gaussian features whose mean depends on the label.
labels = np.random.randint(0, 2, num_nodes)
features = np.zeros((num_nodes, 2))
for i in range(num_nodes):
    mean = mu2 if labels[i] != 0 else mu1
    features[i] = np.random.normal(loc=mean, scale=sigma, size=2)

# Visit every unordered node pair once; a sampled edge is stored in both directions.
edges = []
for i in range(num_nodes):
    for j in range(i + 1, num_nodes):
        # Same-label pairs connect with probability p, different-label pairs with q.
        edge_prob = p if labels[i] == labels[j] else q
        if np.random.rand() < edge_prob:
            edges.extend([(i, j), (j, i)])

# Transpose the (num_edges, 2) pair list into the (2, num_edges) layout.
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
x = torch.tensor(features, dtype=torch.float)
y = torch.tensor(labels, dtype=torch.long)

# Bundle features, connectivity and labels, then move the graph to the compute device.
data = Data(x=x, edge_index=edge_index, y=y)
data = data.to(device)

# Random train/test split of the nodes; test_mask is the exact complement of train_mask.
train_mask = torch.rand(num_nodes) < 0.48
test_mask = torch.logical_not(train_mask)

"""#models"""

class GCN(nn.Module):
    """Stack of ``GCNConv`` layers with ReLU between them.

    The stack is ``input_dim -> hidden_dim``, then ``num_layers - 2`` layers of
    ``hidden_dim -> hidden_dim``, then ``hidden_dim -> output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        # A negative repeat count yields an empty list, so at least two layers are built.
        widths = [input_dim, hidden_dim] + [hidden_dim] * (num_layers - 2) + [output_dim]
        self.layers = nn.ModuleList(
            [GCNConv(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, data):
        """Run all layers on ``data.x`` / ``data.edge_index`` and return raw logits."""
        h = data.x
        edge_index = data.edge_index
        *inner_convs, output_conv = self.layers
        for conv in inner_convs:
            h = F.relu(conv(h, edge_index))
        # No activation after the last layer.
        return output_conv(h, edge_index)

class GAT(nn.Module):
    """Stack of 4-head ``GATConv`` layers with ELU between them.

    The stack is ``input_dim -> hidden_dim``, then ``num_layers - 2`` hidden
    layers, then a layer producing ``output_dim`` channels per head, i.e.
    ``num_layers`` layers in total for ``num_layers >= 2``.

    Args:
        input_dim: Width of the input node features.
        hidden_dim: Per-head width of the hidden layers.
        output_dim: Per-head width of the last layer.
        num_layers: Number of attention layers to stack.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(GAT, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GATConv(input_dim, hidden_dim, heads=4))
        # Exactly one hidden layer per iteration. Appending two here would make
        # the model roughly twice as deep as requested, and the extra layers
        # would be skipped by the activation test in ``forward``.
        for _ in range(num_layers - 2):
            self.layers.append(GATConv(hidden_dim * 4, hidden_dim, heads=4))
        # NOTE(review): this layer also uses heads=4. If the heads are
        # concatenated (believed to be the GATConv default -- confirm), the
        # returned width is 4 * output_dim rather than output_dim.
        self.layers.append(GATConv(hidden_dim * 4, output_dim, heads=4))
        self.dropout = nn.Dropout(0)
        self.num_layers = num_layers

    def forward(self, data):
        """Apply every layer to ``data.x``; ELU and dropout follow all but the last."""
        x, edge_index = data.x, data.edge_index
        for i, layer in enumerate(self.layers):
            x = layer(x, edge_index)
            if i < self.num_layers - 1:
                x = F.elu(x)
                x = self.dropout(x)
        return x


class SAGE(nn.Module):
    """Stack of ``SAGEConv`` layers with ReLU and (zero-rate) dropout between them."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        # A negative repeat count yields an empty list, so at least two layers are built.
        widths = [input_dim, hidden_dim] + [hidden_dim] * (num_layers - 2) + [output_dim]
        self.layers = nn.ModuleList(
            [SAGEConv(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )
        self.num_layers = num_layers
        self.dropout = nn.Dropout(0)

    def forward(self, data):
        """Run all layers on ``data.x``; layers with index below ``num_layers - 1`` are activated."""
        h = data.x
        edge_index = data.edge_index
        first_unactivated = self.num_layers - 1
        for depth, conv in enumerate(self.layers):
            h = conv(h, edge_index)
            if depth >= first_unactivated:
                continue
            h = self.dropout(F.relu(h))
        return h

class AttentionDiffusion(nn.Module):
    """Attention-weighted neighbour aggregation returned as a derivative ``agg - X``.

    Args:
        in_dim: Feature width of the node state; must be divisible by ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``in_dim`` is not divisible by ``heads``.
    """

    def __init__(self, in_dim, heads=4):
        super().__init__()
        if in_dim % heads != 0:
            # Fail early: forward() reshapes in_dim into (heads, in_dim // heads).
            raise ValueError("in_dim must be divisible by heads")
        self.heads = heads
        self.dk = in_dim // heads
        self.WQ = nn.Linear(in_dim, in_dim)
        self.WK = nn.Linear(in_dim, in_dim)
        self.linear = nn.Linear(self.dk, in_dim)

    def forward(self, X, edge_index):
        """Return ``linear(mean over heads of the attention aggregate) - X``.

        Args:
            X: Node states; the first dimension indexes nodes.
            edge_index: Pair of index tensors ``(row, col)``; messages from
                ``col`` are accumulated into ``row``.
        """
        N = X.size(0)
        Q = self.WQ(X).view(N, self.heads, self.dk)
        K = self.WK(X).view(N, self.heads, self.dk)

        row, col = edge_index
        # Scaled dot-product score per edge and head, normalised over edges sharing `row`.
        scores = (Q[row] * K[col]).sum(dim=-1) / (self.dk ** 0.5)
        attn = softmax(scores, row)

        # NOTE(review): every head is fed the same first `dk` feature columns of
        # X[col]; confirm that this (rather than a per-head split) is intended.
        V = X[col].unsqueeze(1).expand(-1, self.heads, -1)
        # A single trailing axis lets the per-(edge, head) weights broadcast over
        # the dk features. Unsqueezing twice would yield a 4-D tensor that no
        # longer lines up with V.
        weighted = attn.unsqueeze(-1) * V[:, :, :self.dk]
        # Allocate the accumulator with the dtype/device of what is added to it.
        out = weighted.new_zeros(N, self.heads, self.dk)
        out.index_add_(0, row, weighted)
        out = out.mean(dim=1)
        out = self.linear(out)
        return out - X

class LambdaDiffusion(nn.Module):
    """Attention diffusion blended with a pull back towards a stored initial state.

    ``forward`` returns ``lam * (agg - X) + (1 - lam) * (X0 - X)`` where ``agg``
    is the attention-weighted neighbour aggregate, ``X0`` is the state captured
    by ``init_X0`` and ``lam = sigmoid(lambda_diag)``.

    Args:
        in_dim: Feature width of the node state; must be divisible by ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``in_dim`` is not divisible by ``heads``.
    """

    def __init__(self, in_dim, heads=4):
        super().__init__()
        if in_dim % heads != 0:
            # Fail early: forward() reshapes in_dim into (heads, in_dim // heads).
            raise ValueError("in_dim must be divisible by heads")
        self.heads = heads
        self.dk = in_dim // heads
        self.WQ = nn.Linear(in_dim, in_dim)
        self.WK = nn.Linear(in_dim, in_dim)
        self.linear = nn.Linear(self.dk, in_dim)
        self.initialized = False
        # Register the scalar gate at construction time so that it is part of
        # parameters(), state_dict() and .to() before the first forward pass
        # (an optimizer built right after construction would otherwise miss it).
        self.lambda_diag = nn.Parameter(torch.tensor(1.0))
        self.lambda_diag_initialized = True

    def init_X0(self, X):
        """Store a detached copy of ``X`` as ``X0``; later calls are ignored."""
        if not self.initialized:
            self.X0 = X.detach().clone()
            self.initialized = True

    def init_lambda_diag(self, n_nodes, device):
        """Reset the gate to a scalar 1.0 on ``device``.

        ``n_nodes`` is accepted for interface compatibility but is not used:
        the gate is a single scalar shared by all nodes.
        """
        self.lambda_diag = nn.Parameter(torch.tensor(1.0, device=device))
        self.lambda_diag_initialized = True

    def forward(self, X, edge_index):
        """Return the blended derivative for state ``X``.

        Args:
            X: Node states; the first dimension indexes nodes.
            edge_index: Pair of index tensors ``(row, col)``; messages from
                ``col`` are accumulated into ``row``.
        """
        self.init_X0(X)
        N = X.size(0)
        # Only reached if a caller cleared the flag to force re-initialisation.
        if not self.lambda_diag_initialized:
            self.init_lambda_diag(N, X.device)

        Q = self.WQ(X).view(N, self.heads, self.dk)
        K = self.WK(X).view(N, self.heads, self.dk)

        row, col = edge_index
        # Scaled dot-product score per edge and head, normalised over edges sharing `row`.
        scores = (Q[row] * K[col]).sum(dim=-1) / (self.dk ** 0.5)
        attn = softmax(scores, row)

        # NOTE(review): every head is fed the same first `dk` feature columns of
        # X[col]; confirm that this (rather than a per-head split) is intended.
        V = X[col].unsqueeze(1).expand(-1, self.heads, -1)
        weighted = attn.unsqueeze(-1) * V[:, :, :self.dk]
        # Allocate the accumulator with the dtype/device of what is added to it,
        # so the module also works when it is not running in float32.
        agg = weighted.new_zeros(N, self.heads, self.dk)
        agg.index_add_(0, row, weighted)
        agg = agg.mean(dim=1)
        agg = self.linear(agg)

        diffusion_term = agg - X
        memory_term = self.X0 - X
        lambda_weights = torch.sigmoid(self.lambda_diag)
        dXdt = lambda_weights * diffusion_term + (1 - lambda_weights) * memory_term
        return dXdt

def rk4_step_fully(func, x, t, dt, step_number=None):
    """Advance ``x`` by one classical fourth-order Runge-Kutta step of size ``dt``.

    ``func`` is called as ``func(state, step_number, t=time)`` at the start, twice
    at the midpoint and once at the end of the step; ``step_number`` is forwarded
    unchanged to all four calls.
    """
    half_dt = 0.5 * dt
    t_mid = t + half_dt

    slope_start = func(x, step_number, t=t)
    slope_mid_a = func(x + half_dt * slope_start, step_number, t=t_mid)
    slope_mid_b = func(x + half_dt * slope_mid_a, step_number, t=t_mid)
    slope_end = func(x + dt * slope_mid_b, step_number, t=t + dt)

    weighted_slopes = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return x + (dt / 6.0) * weighted_slopes

def solve_diffusion_rk4_fully(func, x0, t_span, dt):
    """Integrate from ``t_span[0]`` to ``t_span[1]`` with fixed-size RK4 steps.

    Args:
        func: Derivative callable, invoked as ``func(state, step_number, t=time)``.
        x0: Initial state; it is cloned, so the caller's object is not modified.
        t_span: Two-element sequence ``[t_start, t_end]``.
        dt: Step size.

    Returns:
        The state after ``floor((t_end - t_start) / dt)`` steps. The step index
        (starting at 0) is passed to ``func`` as ``step_number``.
    """
    x = x0.clone()
    t = t_span[0]
    # Compute the step count once. The small tolerance keeps float round-off
    # from dropping the final step (e.g. 0.3 / 0.1 == 2.9999999999999996,
    # which plain truncation would turn into 2 steps instead of 3).
    num_steps = int((t_span[1] - t_span[0]) / dt + 1e-9)
    for c in range(num_steps):
        x = rk4_step_fully(func, x, t, dt, step_number=c)
        t += dt
    return x

class GRANDRefully(nn.Module):
    """GCN encoder, RK4-integrated diffusion with one block per step, GCN decoder.

    Args:
        in_dim: Width of the input node features.
        hidden_dim: Width of the diffused node state.
        out_dim: Width of the decoder output.
        num_steps: Number of integration steps; one ``LambdaDiffusion`` is
            created per step.
        heads: Number of attention heads in every diffusion block.
        T: End time of the integration; the step size is ``T / num_steps``.
        input_dropout: Dropout rate applied to the input features.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_steps, heads, T=1.0, input_dropout=0.0):
        super().__init__()
        self.input_dropout = nn.Dropout(input_dropout)
        self.encoder = GCNConv(in_dim, hidden_dim)
        # A ModuleList (not a plain list) registers the blocks as sub-modules,
        # so parameters(), state_dict(), .to() and .train()/.eval() reach them.
        # It also removes the dependency on a module-level `device`: moving the
        # model moves the blocks with it.
        self.diffusions = nn.ModuleList(
            [LambdaDiffusion(hidden_dim, heads=heads) for _ in range(num_steps)]
        )
        self.decoder = GCNConv(hidden_dim, out_dim)
        self.num_steps = num_steps
        self.T = T
        self.dt = T / num_steps

    def forward(self, x, edge_index):
        """Encode, diffuse and decode.

        Returns:
            A pair ``(log_softmax(out, dim=1), out)`` where ``out`` is the
            decoder output.
        """
        x = self.input_dropout(x)
        x0 = F.relu(self.encoder(x, edge_index))

        # Refresh the stored initial state on every pass. init_X0 only takes
        # effect while the block's flag is cleared, so without the reset a second
        # forward pass would keep using the first pass's encoder output.
        for diffusion in self.diffusions:
            diffusion.initialized = False
            diffusion.init_X0(x0)

        def diffusion_func(X, step_number, t=None):
            # Integration step k is handled by the k-th diffusion block.
            return self.diffusions[step_number](X, edge_index.to(X.device))

        x = solve_diffusion_rk4_fully(diffusion_func, x0, [0, self.T], self.dt)
        x = self.decoder(x, edge_index)
        return F.log_softmax(x, dim=1), x

class Diffusion(nn.Module):
    """Attention-weighted neighbour aggregation returned as a derivative ``agg - X``.

    Args:
        in_dim: Feature width of the node state; must be divisible by ``heads``.
        heads: Number of attention heads.
        use_adj_mask: Stored on the instance; not read anywhere in this class.
    """

    def __init__(self, in_dim, heads=4, use_adj_mask=True):
        super().__init__()
        assert in_dim % heads == 0, "in_dim must be divisible by heads"
        self.heads = heads
        self.dk = in_dim // heads
        self.use_adj_mask = use_adj_mask
        self.WQ = nn.Linear(in_dim, in_dim)
        self.WK = nn.Linear(in_dim, in_dim)
        self.linear = nn.Linear(self.dk, in_dim)
        self.initialized = False

    def init_X0(self, X):
        """Store a detached copy of ``X`` as ``X0``; later calls are ignored."""
        if not self.initialized:
            self.X0 = X.detach().clone()
            self.initialized = True

    def forward(self, X, edge_index):
        """Return ``linear(mean over heads of the attention aggregate) - X``.

        ``X0`` is captured on the first call but does not enter the result.

        Args:
            X: Node states; the first dimension indexes nodes.
            edge_index: Pair of index tensors ``(row, col)``; messages from
                ``col`` are accumulated into ``row``.
        """
        self.init_X0(X)
        N = X.size(0)

        Q = self.WQ(X).view(N, self.heads, self.dk)
        K = self.WK(X).view(N, self.heads, self.dk)

        row, col = edge_index
        # Scaled dot-product score per edge and head, normalised over edges sharing `row`.
        scores = (Q[row] * K[col]).sum(dim=-1) / (self.dk ** 0.5)
        attn = softmax(scores, row)

        # NOTE(review): every head is fed the same first `dk` feature columns of
        # X[col]; confirm that this (rather than a per-head split) is intended.
        V = X[col].unsqueeze(1).expand(-1, self.heads, -1)
        weighted = attn.unsqueeze(-1) * V[:, :, :self.dk]
        # Allocate the accumulator with the dtype/device of what is added to it,
        # so the module also works when it is not running in float32.
        agg = weighted.new_zeros(N, self.heads, self.dk)
        agg.index_add_(0, row, weighted)
        agg = agg.mean(dim=1)
        agg = self.linear(agg)

        dXdt = agg - X
        return dXdt


class GRANDR(nn.Module):
    """Linear encoder, RK4-integrated diffusion with one shared block, linear decoder.

    The same ``Diffusion`` instance is evaluated at every integration step; the
    step size is ``T / num_steps``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_steps, heads, T=1.0, input_dropout=0.0):
        super().__init__()
        self.input_dropout = nn.Dropout(input_dropout)

        self.encoder = nn.Linear(in_dim, hidden_dim)
        # Placed on the module-level `device` at construction time.
        self.diffusions = Diffusion(hidden_dim, heads=heads).to(device)
        self.decoder = nn.Linear(hidden_dim, out_dim)

        self.num_steps, self.T = num_steps, T
        self.dt = T / num_steps

    def forward(self, x, edge_index):
        """Encode ``x``, integrate the diffusion over ``[0, T]`` and decode the result."""
        h0 = F.relu(self.encoder(self.input_dropout(x)))
        self.diffusions.init_X0(h0)

        def derivative(state, step_index, t=None):
            # The step index and time are ignored: one block serves every step.
            return self.diffusions(state, edge_index.to(state.device))

        h_final = solve_diffusion_rk4_fully(derivative, h0, [0, self.T], self.dt)
        return self.decoder(h_final)

def runner_GCN(num_layers, device, data):
    """Build a freshly initialised GCN of the given depth and return its output on ``data``.

    The model is not trained; it is evaluated once in eval mode without gradients.
    """
    net = GCN(input_dim=2, hidden_dim=16, output_dim=2, num_layers=num_layers)
    net = net.to(device)
    net.eval()
    with torch.no_grad():
        return net(data.to(device))

def runner_GAT(num_layers, device, data):
    """Build a freshly initialised GAT of the given depth and return its output on ``data``.

    The model is not trained; it is evaluated once in eval mode without gradients.
    """
    net = GAT(input_dim=2, hidden_dim=16, output_dim=2, num_layers=num_layers)
    net = net.to(device)
    net.eval()
    with torch.no_grad():
        return net(data.to(device))

def runner_SAGE(num_layers, device, data):
    """Build a freshly initialised SAGE model of the given depth and return its output on ``data``.

    The model is not trained; it is evaluated once in eval mode without gradients.
    """
    net = SAGE(input_dim=2, hidden_dim=16, output_dim=2, num_layers=num_layers)
    net = net.to(device)
    net.eval()
    with torch.no_grad():
        return net(data.to(device))

def runner_GRANDRefully(num_steps, device, data):
    """Run an untrained ``GRANDRefully`` with ``num_steps`` steps and return its raw output on the CPU.

    ``T`` is set to ``num_steps``, so the integration step size is 1.
    """
    feature_width = data.x.shape[1]
    class_count = len(data.y.unique())
    net = GRANDRefully(
        in_dim=feature_width,
        hidden_dim=16,
        out_dim=class_count,
        num_steps=num_steps,
        heads=4,
        T=num_steps,
        input_dropout=0.0,
    ).to(device)
    net.eval()
    with torch.no_grad():
        # The model returns (log-probabilities, raw output); only the latter is kept.
        _, embeddings = net(data.x.to(device), data.edge_index.to(device))
    return embeddings.cpu()

def runner_GRAND(num_steps, device, data):
    """Run an untrained ``GRANDR`` with ``num_steps`` steps and return its output on the CPU.

    ``T`` is set to ``num_steps``, so the integration step size is 1.
    """
    feature_width = data.x.shape[1]
    class_count = len(data.y.unique())
    net = GRANDR(
        in_dim=feature_width,
        hidden_dim=16,
        out_dim=class_count,
        num_steps=num_steps,
        heads=4,
        T=num_steps,
        input_dropout=0.0,
    ).to(device)
    net.eval()
    with torch.no_grad():
        embeddings = net(data.x.to(device), data.edge_index.to(device))
    return embeddings.cpu()

def dirichlet_energy(X, edge_index):
    """Return the sum over all listed edges ``(i, j)`` of ``||X[i] - X[j]||^2``.

    Every column of ``edge_index`` counts once, so an undirected edge that is
    stored in both directions contributes twice.

    Args:
        X: Node features; the first dimension indexes nodes.
        edge_index: Integer tensor whose two rows hold the edge endpoints.

    Returns:
        The energy as a Python float; ``0.0`` when there are no edges.
    """
    edge_index = edge_index.to(X.device)
    # With no edges there is nothing to index or sum.
    if edge_index.numel() == 0:
        return 0.0
    src, dst = edge_index
    # One vectorised gather and reduction instead of a Python loop over edges.
    diff = X[src] - X[dst]
    return float(diff.pow(2).sum().item())

def compute_dirichlet_energy_repeated_general(model_runner, ll, device, data, n_runs=10, runner_kwargs=None):
    """Average the Dirichlet energy of a model's output over repeated runs.

    For each of ``n_runs`` repetitions and each value in ``ll``, calls
    ``model_runner(value, device, data, **runner_kwargs)`` and measures the
    Dirichlet energy of the result on ``data.edge_index``.

    Returns:
        ``(mean_energy, relative_std)``, one entry per value in ``ll``.
        ``relative_std`` is std / mean, or 0.1 wherever the mean is not above 1e-36.
    """
    extra_kwargs = {} if runner_kwargs is None else runner_kwargs

    per_run_energies = []
    for run in range(n_runs):
        sweep = tqdm(ll, desc=f"run {run+1}/{n_runs}", leave=False)
        per_run_energies.append([
            dirichlet_energy(
                model_runner(depth, device, data, **extra_kwargs).to(device),
                data.edge_index.to(device),
            )
            for depth in sweep
        ])

    energies_all = np.array(per_run_energies)  # (n_runs, len(ll))
    mean_energy = np.mean(energies_all, axis=0)

    relative_std = np.zeros_like(mean_energy)
    for column, column_mean in enumerate(mean_energy):
        if column_mean > 1e-36:
            relative_std[column] = np.std(energies_all[:, column]) / column_mean
        else:
            # Fixed fallback that avoids dividing by a (near-)zero mean.
            relative_std[column] = 0.1
    return mean_energy, relative_std

def plot_with_relative_band(x, mean, relative_std, color, label, min_visible_std=0.05):
    """Plot the mean curve with a shaded multiplicative band around it.

    The band runs from ``mean * exp(-w)`` to ``mean * exp(w)`` with
    ``w = max(relative_std, min_visible_std)``, i.e. the relative standard
    deviation is used directly as the log-space half-width.

    Args:
        x: Horizontal coordinates.
        mean: Mean curve, one value per entry of ``x``.
        relative_std: Relative standard deviation per entry of ``x``.
        color: Colour used for both the line and the band.
        label: Legend label of the line.
        min_visible_std: Lower bound on the half-width so the band stays visible.
    """
    plt.plot(x, mean, color=color, label=label, linewidth=2)

    # Clamp from below so that very small spreads still produce a visible band.
    log_std = np.maximum(relative_std, min_visible_std)

    upper = mean * np.exp(log_std)
    lower = mean * np.exp(-log_std)

    plt.fill_between(x, lower, upper, color=color, alpha=0.3)

print("Running repeated experiments...")

# Dirichlet energy of the raw input features; the model curves are rescaled against it.
baseline_energy = dirichlet_energy(X=data.x, edge_index=data.edge_index)
print("Baseline energy:", baseline_energy)

def normalize_to_baseline(mean_curve, baseline_energy):
    """Rescale ``mean_curve`` so that its first entry maps to ``baseline_energy``.

    The first entry is clamped from below at 1e-12 before dividing, which keeps
    the scale factor finite when the curve starts at (or near) zero.
    """
    first_value = np.maximum(mean_curve[0], 1e-12)
    factor = baseline_energy / first_value
    return mean_curve * factor

def _energy_curve(message, runner):
    """Announce a model, sweep it over ``ll`` and rescale its mean curve to the baseline."""
    print(message)
    raw_mean, rel_std = compute_dirichlet_energy_repeated_general(
        runner, ll, device, data, n_runs=monte_carlo_runs
    )
    return normalize_to_baseline(raw_mean, baseline_energy), rel_std


# The models are evaluated one after another in this fixed order.
meanGCN, rel_stdGCN = _energy_curve("GCN...", runner_GCN)

meanGAT, rel_stdGAT = _energy_curve("GAT...", runner_GAT)

meanSAGE, rel_stdSAGE = _energy_curve("SAGE...", runner_SAGE)

meanGRAND, rel_stdGRAND = _energy_curve("GRAND...", runner_GRAND)

meanGRANDRefully, rel_stdGRANDRefully = _energy_curve("GRAND-ASC...", runner_GRANDRefully)

bold_font = FontProperties(weight='bold', size=16)

plt.figure(figsize=(8, 6), dpi=600)

# One (mean curve, relative std, colour, legend label) entry per model, drawn in this order.
for _curve, _spread, _colour, _name in (
    (meanGRAND, rel_stdGRAND, 'blue', 'GRAND'),
    (meanGRANDRefully, rel_stdGRANDRefully, 'green', 'GRAND-ASC'),
    (meanGCN, rel_stdGCN, 'red', 'GCN'),
    (meanGAT, rel_stdGAT, 'orange', 'GAT'),
    (meanSAGE, rel_stdSAGE, 'purple', 'SAGE'),
):
    plot_with_relative_band(ll, _curve, _spread, _colour, _name)

# Tick label sizes: major ticks first, then minor ticks.
for _which, _size in (('major', 16), ('minor', 14)):
    plt.tick_params(axis='both', which=_which, labelsize=_size)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

# Light grid plus thin reference lines at zero on both axes.
plt.grid(True, alpha=0.3)
plt.axhline(0, color='black', linewidth=0.5)
plt.axvline(0, color='black', linewidth=0.5)

plt.legend(prop=bold_font)
_axis_label_style = dict(fontsize=16, fontweight='bold')
plt.xlabel('Depth/Layers', **_axis_label_style)
plt.ylabel('Dirichlet Energy', **_axis_label_style)
# The energies are shown on a logarithmic vertical axis.
plt.yscale('log')
plt.tight_layout()
plt.show()
