import os
import random
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.patches import FancyArrowPatch

import torch
import torch.nn as nn
import torch.nn.functional as F

# cuBLAS needs a fixed workspace configuration for its kernels to be
# deterministic once torch.use_deterministic_algorithms(True) is enabled;
# it must be in the environment before any CUDA work starts.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
def set_seed(seed=42):
    """Seed every RNG this script touches and force deterministic torch kernels.

    Seeds Python's ``random``, NumPy's global generator and torch (CPU, the
    current CUDA device and all CUDA devices), turns off cuDNN autotuning and
    enables ``torch.use_deterministic_algorithms``.

    Note: ``PYTHONHASHSEED`` is read only at interpreter start-up, so setting
    it here affects child processes, not ``hash()`` in the running process.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Pin cuDNN to deterministic algorithms and disable benchmark-based selection.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    torch.use_deterministic_algorithms(True)

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"🔒 Seed set to: {seed}")
set_seed(13)

# Font sizes for all figures. The element-specific keys take precedence over
# "font.size" for titles, axis labels, tick labels and legends; "font.size"
# remains the default for any other text drawn without an explicit size.
_FONT_SIZES = {
    "font.size": 20,
    "axes.titlesize": 16,
    "axes.labelsize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
}
plt.rcParams.update(_FONT_SIZES)

class Invertible1x1Conv(nn.Module):
    """
    Learnable invertible linear map ``z = x @ W + bias`` (Glow's invertible
    1x1 convolution, with an added bias).

    ``W`` is stored in PLU form, ``W = P @ L @ (U + diag(S))``: ``P`` is a fixed
    permutation buffer, only the strict lower triangle of ``L`` is used (its
    diagonal is forced to 1), ``U`` contributes its strict upper triangle and
    ``S`` is the diagonal. Hence ``log|det W| = sum(log|S|)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Initialise from a random orthogonal matrix and factorise it as P @ L @ U.
        w_init = torch.nn.init.orthogonal_(torch.randn(dim, dim))
        lu_packed, pivots = torch.linalg.lu_factor(w_init)
        perm, lower, upper = torch.lu_unpack(lu_packed, pivots)

        self.register_buffer("P", perm)  # permutation stays fixed (buffer, not a Parameter)
        self.L = nn.Parameter(lower)
        self.S = nn.Parameter(upper.diag())
        self.U = nn.Parameter(torch.triu(upper, diagonal=1))

        self.bias = nn.Parameter(torch.zeros(dim))

    def _assemble_W(self, device):
        """Rebuild ``W = P @ L @ (U + diag(S))`` from the current parameters."""
        # Re-impose the triangular structure on every call, since the optimiser
        # is free to change entries outside the intended triangles.
        unit_lower = torch.tril(self.L, diagonal=-1) + torch.eye(self.dim, device=device)
        upper_full = torch.triu(self.U, diagonal=1) + torch.diag(self.S)
        return self.P @ unit_lower @ upper_full

    def forward(self, x):
        """Return ``(x @ W + bias, log|det W|)`` with the log-det repeated once per row of ``x``."""
        weight = self._assemble_W(x.device)
        log_abs_det = self.S.abs().log().sum()
        return x @ weight + self.bias, log_abs_det.expand(x.shape[0])

    def backward(self, z):
        """Invert ``forward``: return ``((z - bias) @ W^-1, -log|det W|)``; the log-det is a scalar."""
        weight_inv = torch.inverse(self._assemble_W(z.device))
        x = (z - self.bias) @ weight_inv
        return x, -self.S.abs().log().sum()

class Sigmoid(nn.Module):
    """
    Element-wise sigmoid activation for flows.

    Both directions return ``(output, log_det)``, where ``log_det`` is the
    log|det J| of the transform summed over the last dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply ``z = sigmoid(x)``; ``log_det = sum(log(z * (1 - z)))`` over the last dim.

        The derivative is evaluated in log space as
        ``logsigmoid(x) + logsigmoid(-x)``, which is exact for saturated inputs.
        The naive ``log(z * (1 - z) + eps)`` collapses to ``log(eps)`` there
        (e.g. -27.6 instead of -30 for x = 30 in float32), and its gradient vanishes.
        """
        z = torch.sigmoid(x)
        log_det = torch.sum(F.logsigmoid(x) + F.logsigmoid(-x), dim=-1)
        return z, log_det

    def backward(self, z):
        """Invert the sigmoid (logit); ``log_det = -sum(log(z * (1 - z)))`` over the last dim.

        ``z`` is clamped away from 0 and 1 first. The margin depends on the dtype
        because a fixed ``1 - 1e-12`` rounds to exactly 1.0 in float32: the upper
        clamp then did nothing, and ``z == 1`` produced ``x = inf``.
        """
        eps = max(1e-12, torch.finfo(z.dtype).eps)  # equals 1e-12 for float64
        z = torch.clamp(z, eps, 1 - eps)
        x = torch.log(z) - torch.log(1 - z)
        log_det = -torch.sum(torch.log(z * (1 - z)), dim=-1)
        return x, log_det

class InvGNN(nn.Module):
    """Normalising flow over node features that mixes information along a graph.

    Each of the ``n_layers`` layers does the following:
    1. Applies ``z <- exp_adj @ z`` (graph diffusion).
    2. Applies an ``Invertible1x1Conv``.
    3. Applies a ``Sigmoid``, except after the last layer.

    The base distribution is a standard normal over the ``input_dim`` features
    of each node.
    """

    def __init__(self, input_dim, n_layers, device):
        super().__init__()
        self.n_layers = n_layers
        # A torch.distributions object is not a submodule, so model.to(device)
        # will not move it; it is created directly on `device` instead.
        self.prior = torch.distributions.MultivariateNormal(
            torch.zeros(input_dim, device=device), torch.eye(input_dim, device=device)
        )
        self.fc = nn.ModuleList([Invertible1x1Conv(input_dim) for _ in range(n_layers)])
        self.act = Sigmoid()

    def f(self, x, exp_adj, logabsdet_exp_adj):
        """Push node features through the flow.

        Args:
            x: node feature matrix with one row per node (assumed
                ``(n_nodes, input_dim)``).
            exp_adj: ``(n_nodes, n_nodes)`` matrix applied as ``exp_adj @ z`` in
                every layer. It may be dense or sparse.
            logabsdet_exp_adj: log|det| added to each node's total once per
                diffusion step. It may be a Python number or a tensor
                broadcastable to ``(n_nodes,)``. A tensor is moved to
                ``x``'s device first, so a CPU tensor works with a CUDA model.

        Returns:
            ``(z, log_det_J)``: the latent features and the per-node
            log|det J|, with ``log_det_J`` of shape ``(n_nodes,)``.
        """
        if torch.is_tensor(logabsdet_exp_adj):
            # An in-place add into a CUDA tensor rejects non-scalar CPU operands.
            logabsdet_exp_adj = logabsdet_exp_adj.to(x.device)

        log_det_J = x.new_zeros(x.size(0))
        z = x
        for i in range(self.n_layers):
            z = torch.mm(exp_adj, z)  # graph diffusion; torch.mm accepts dense or sparse exp_adj
            log_det_J += logabsdet_exp_adj
            z, logdet_conv = self.fc[i](z)
            log_det_J += logdet_conv
            if i < self.n_layers - 1:  # no activation after the final layer
                z, logdet_act = self.act(z)
                log_det_J += logdet_act

        return z, log_det_J

    def log_prob(self, x, exp_adj, logabsdet_exp_adj):
        """Per-node log-likelihood by change of variables: ``prior.log_prob(z) + log|det J|``."""
        z, logdet = self.f(x, exp_adj, logabsdet_exp_adj)
        return self.prior.log_prob(z) + logdet


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training graph: G(n_nodes + 1, p) with node 0 removed, so nodes are 1..n_nodes.
n_nodes = 20
G = nx.fast_gnp_random_graph(n_nodes+1, 0.2)
G.remove_node(0)
max_deg = max(deg for _, deg in G.degree())
print('Max degree:', max_deg)
# Scale A by the maximum degree; max(..., 1) avoids 0/0 on an edgeless sample.
A = nx.to_numpy_array(G)/max(max_deg, 1)

# exp(A) via the eigendecomposition of the symmetric A. Numerically-zero
# eigenvalues and eigenvector entries are rounded to exactly 0.
L,U = np.linalg.eigh(A)
L[np.abs(L) < 1e-8] = 0
U[np.abs(U) < 1e-8] = 0
exp_adj = np.linalg.multi_dot((U, np.diag(np.exp(L)), U.T))
# Every node is credited with the full log|det exp(A)|. G has no self-loops,
# so A has a zero diagonal and det(exp(A)) = exp(trace(A)) = 1; the value is
# therefore ~0 up to round-off.
# Keep this tensor on `device`: a non-scalar CPU tensor cannot be added in
# place to the model's CUDA log-det accumulator.
logabsdet_exp_adj = (torch.ones(n_nodes)*np.linalg.slogdet(exp_adj)[1]).to(device)
exp_adj = torch.from_numpy(exp_adj).float().to(device)
x = torch.randn(n_nodes, 2).to(device)

model = InvGNN(input_dim=2, n_layers=10, device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Maximum-likelihood training: minimise the mean per-node negative log-likelihood.
model.train()
for epoch in range(1, 5001):
    optimizer.zero_grad()
    loss = -model.log_prob(x, exp_adj, logabsdet_exp_adj).mean()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch} | NLL: {loss.item():.4f}")

# Test graph: attach a new node (label n_nodes + 1) to 10 distinct random existing nodes.
G.add_node(n_nodes+1)
nbrs = np.random.choice(n_nodes, size=10, replace=False) + 1  # existing labels are 1..n_nodes
G.add_edges_from((n_nodes+1, int(nbr)) for nbr in nbrs)

# NOTE(review): max_deg is recomputed on the enlarged graph, so A is scaled
# differently than it was during training. Confirm this is intended.
max_deg = max(deg for _, deg in G.degree())
print('Max degree:', max_deg)
A = nx.to_numpy_array(G)/max_deg
L,U = np.linalg.eigh(A)
exp_adj = np.linalg.multi_dot((U, np.diag(np.exp(L)), U.T))
# Keep this tensor on `device`: a non-scalar CPU tensor cannot be added in
# place to the model's CUDA log-det accumulator.
logabsdet_exp_adj = (torch.ones(n_nodes+1)*np.linalg.slogdet(exp_adj)[1]).to(device)
exp_adj = torch.from_numpy(exp_adj).float().to(device)
x = torch.cat([x, torch.randn(1, 2).to(device)], dim=0)

# Per-node negative log-likelihood under the trained flow.
model.eval()
with torch.no_grad():
    logpz = -model.log_prob(x, exp_adj, logabsdet_exp_adj)

# Row i of A, and therefore of logpz, corresponds to the i-th node of
# G.nodes(), which is the order nx.to_numpy_array uses. Pair them explicitly
# instead of assuming that the label of row i is i + 1.
nx.set_node_attributes(G, dict(zip(G.nodes(), logpz.tolist())), "logpz")

node_color = [G.nodes[i]["logpz"] for i in G.nodes()]
pos = nx.spring_layout(G)

fig, ax = plt.subplots(figsize=(8,8))

nodes = nx.draw_networkx_nodes(
    G, pos,
    node_color=node_color,
    cmap='Blues',
    node_size=1000,
    edgecolors='k',
    alpha=0.7,
    ax=ax
)

# Approximate radius of the drawn nodes in layout (data) units, used to trim
# edges to the node rims. Presumably tuned by eye for node_size=1000 at this
# figure size; verify if either changes.
node_radius = 0.07
def draw_edge(ax, p1, p2, node_radius, color='black', lw=1.5):
    """Draw a straight, undirected edge between two node centres, trimmed to the node rims.

    Args:
        ax: Matplotlib axes to draw on.
        p1, p2: ``(x, y)`` centres of the two endpoints, in data coordinates.
        node_radius: distance in data units trimmed from each end, so the line
            starts and stops at the node boundaries rather than the centres.
        color, lw: line colour and width passed to ``FancyArrowPatch``.

    Returns:
        The added ``FancyArrowPatch``. Returns ``None`` when the endpoints
        coincide (e.g. a self-loop), because there is no direction to trim along.
    """
    x1, y1 = p1
    x2, y2 = p2
    dx, dy = x2 - x1, y2 - y1
    dist = np.hypot(dx, dy)
    if dist == 0:
        # Dividing by a zero distance would produce inf/nan patch coordinates.
        return None
    factor = node_radius / dist
    start = (x1 + dx*factor, y1 + dy*factor)
    end = (x2 - dx*factor, y2 - dy*factor)
    edge = FancyArrowPatch(start, end, arrowstyle='-', color=color, lw=lw)
    ax.add_patch(edge)
    return edge

# Edges are drawn manually with draw_edge so that they stop at the node rims.
for u, v in G.edges():
    draw_edge(ax, pos[u], pos[v], node_radius)

nx.draw_networkx_labels(G, pos, ax=ax)

# Stand-alone mappable spanning the min..max of the node values, so the colour
# bar should line up with the auto-scaled node colormap.
sm = plt.cm.ScalarMappable(
    norm=plt.Normalize(vmin=min(node_color), vmax=max(node_color)),
    cmap='Blues',
)
sm.set_array([])

cbar = fig.colorbar(
    sm,
    ax=ax,
    alpha=0.7,
    fraction=0.046,
    pad=0.12,
    shrink=0.8,
)
cbar.set_label(r"$-\log p(x)$")

ax.axis('off')
plt.show()