import os
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj, degree, stochastic_blockmodel_graph, erdos_renyi_graph

# Exported at import time, before any CUDA work is issued by this script.
# NOTE(review): PyTorch's reproducibility notes list this cuBLAS workspace
# setting as a prerequisite for deterministic algorithms on CUDA - confirm
# against the installed CUDA/PyTorch versions.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def set_seed(seed=42):
    """Seed every RNG this script uses and request deterministic PyTorch kernels.

    Args:
        seed: Integer handed unchanged to Python's ``random``, NumPy and
            PyTorch (CPU and CUDA) seeding functions.

    Side effects:
        Turns cuDNN determinism on and cuDNN benchmarking off, enables
        ``torch.use_deterministic_algorithms``, exports ``PYTHONHASHSEED``
        and prints a confirmation line.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up,
    # so this presumably only influences child processes - confirm.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"🔒 Seed set to: {seed}")


set_seed(13)

# Typography applied globally to every matplotlib figure created below.
_font_sizes = (
    ("font.size", 16),        # base size for text without a specific override
    ("axes.titlesize", 16),   # subplot titles
    ("axes.labelsize", 14),   # x and y axis labels
    ("xtick.labelsize", 12),  # x tick labels
    ("ytick.labelsize", 12),  # y tick labels
    ("legend.fontsize", 12),  # legend entries
)
plt.rcParams.update(dict(_font_sizes))

def generate_sbm_graphs(n_graphs=100, n_nodes=50, feature_dim=2, outlier=False, edge_weight=1./15):
    """
    Generate synthetic graphs together with a precomputed matrix exponential
    of their (scaled) adjacency matrix.

    Despite the function name, "normal" graphs (``outlier=False``) are
    Erdos-Renyi graphs; only outliers are drawn from a two-block stochastic
    block model (within-block probability 0.5, between-block 0.05). The
    Erdos-Renyi edge probability is chosen so that, for even ``n_nodes``, the
    expected node degree equals that of the block model.

    Node features do not depend on the graph type: the first ``n_nodes // 2``
    nodes are drawn around +1 and the remaining nodes around -1, both with
    standard deviation 0.5.

    Args:
        n_graphs: Number of graphs to generate.
        n_nodes: Nodes per graph. Odd values are supported; the second
            group/block then holds the extra node.
        feature_dim: Number of features per node.
        outlier: If True, sample the block-model (anomalous) structure.
        edge_weight: Value written into the adjacency matrix for every edge
            before exponentiation (default 1/15, the previous fixed value).

    Returns:
        A list of ``Data`` objects with fields ``x``, ``edge_index``,
        ``edge_index_exp_adj`` (indices of the strictly positive entries of
        exp(A)), ``exp_adj_flat`` (those entries as a column vector) and
        ``logabsdet_exp_adj`` (log|det exp(A)| repeated once per node).
    """
    n_first = n_nodes // 2
    # Give the second group the remainder so odd n_nodes neither loses a node
    # in the block model nor breaks the feature assignment below.
    n_second = n_nodes - n_first

    data_lst = []
    for _ in range(n_graphs):
        if not outlier:
            edge_index = erdos_renyi_graph(n_nodes, edge_prob=(0.55*n_nodes-1)/(2*(n_nodes-1)))
        else:
            edge_probs = [[0.5, 0.05],
                          [0.05, 0.5]]

            block_sizes = [n_first, n_second]
            edge_index = stochastic_blockmodel_graph(block_sizes, edge_probs)

        # Dense scaled adjacency matrix A.
        A = torch.zeros((n_nodes, n_nodes), dtype=torch.float32)
        A[edge_index[0], edge_index[1]] = edge_weight
        A = A.numpy()

        # exp(A) through the eigendecomposition A = U diag(L) U^T.
        # NOTE(review): eigh assumes A is symmetric, i.e. that both graph
        # generators return undirected edge lists - confirm.
        L, U = np.linalg.eigh(A)
        # Flush numerically negligible eigenvalues / eigenvector entries.
        L[np.abs(L) < 1e-8] = 0
        U[np.abs(U) < 1e-8] = 0
        exp_adj = np.linalg.multi_dot((U, np.diag(np.exp(L)), U.T))

        # One copy of log|det exp(A)| per node so it lines up with node rows.
        logabsdet = np.linalg.slogdet(exp_adj)[1]
        logabsdet_exp_adj = torch.ones(n_nodes) * float(logabsdet)

        # Sparse (COO-style) description of exp(A): strictly positive entries only.
        row, col = np.where(exp_adj > 0)
        edge_index_exp_adj = torch.tensor(np.array([row, col]), dtype=torch.long)
        exp_adj_flat = torch.from_numpy(exp_adj[row, col]).unsqueeze(1)

        # Two feature clusters centred at +1 and -1.
        x = torch.zeros(n_nodes, feature_dim)
        x[:n_first] = torch.randn(n_first, feature_dim) * 0.5 + 1.
        x[n_first:] = torch.randn(n_second, feature_dim) * 0.5 - 1.

        data_lst.append(Data(x=x, edge_index=edge_index, edge_index_exp_adj=edge_index_exp_adj, exp_adj_flat=exp_adj_flat, logabsdet_exp_adj=logabsdet_exp_adj))

    return data_lst


class Invertible1x1Conv(nn.Module):
    """
    Invertible affine layer ``z = x @ W + bias`` in the style of Glow's
    1x1 convolution, with ``W`` kept in LU-factorised form.

    ``W = P @ L @ (U + diag(S))`` where ``P`` is a fixed permutation buffer,
    ``L`` is unit lower-triangular, ``U`` strictly upper-triangular and ``S``
    the diagonal, so ``log|det W| = sum(log|S|)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Initialise W as a random orthogonal matrix and store its LU factors.
        init_weight = torch.nn.init.orthogonal_(torch.randn(dim, dim))
        perm, lower, upper = torch.lu_unpack(*torch.linalg.lu_factor(init_weight))

        self.register_buffer("P", perm)
        self.L = nn.Parameter(lower)
        self.S = nn.Parameter(upper.diag())
        self.U = nn.Parameter(torch.triu(upper, diagonal=1))

        self.bias = nn.Parameter(torch.zeros(dim))

    def _assemble_W(self, device):
        """Rebuild the dense weight matrix from its masked LU parameters."""
        # Masks keep L unit lower-triangular and U strictly upper-triangular
        # regardless of what the optimiser wrote into the other entries.
        lower = torch.tril(self.L, diagonal=-1) + torch.eye(self.dim, device=device)
        upper = torch.triu(self.U, diagonal=1) + torch.diag(self.S)
        return self.P @ lower @ upper

    def forward(self, x):
        """Return ``(x @ W + bias, log|det W|)`` with the log-det repeated per row of ``x``."""
        weight = self._assemble_W(x.device)
        z = x @ weight + self.bias

        log_abs_det = self.S.abs().log().sum()
        return z, log_abs_det.expand(x.shape[0])

    def backward(self, z):
        """Invert ``forward``; returns ``(x, -log|det W|)`` with a scalar log-det."""
        weight_inv = torch.inverse(self._assemble_W(z.device))
        x = (z - self.bias) @ weight_inv

        return x, -self.S.abs().log().sum()

class Sigmoid(nn.Module):
    """
    Invertible element-wise sigmoid activation for flows.

    ``forward`` applies the sigmoid and ``backward`` its inverse (logit).
    Both return the transformed tensor together with the log-determinant of
    the corresponding Jacobian, summed over the last dimension.
    """

    def __init__(self):
        super().__init__()

    # ---------------- Forward ----------------
    def forward(self, x):
        """
        Forward pass of sigmoid
        x: [batch, dim] or [dim]
        Returns:
            z: output in (0, 1)
            log_det: log-determinant of the Jacobian
        """
        z = torch.sigmoid(x)
        # d sigmoid(x)/dx = sigmoid(x) * sigmoid(-x). Evaluating the log through
        # logsigmoid stays exact for large |x|, where z * (1 - z) underflows
        # and an additive epsilon would cap the result.
        log_det = torch.sum(F.logsigmoid(x) + F.logsigmoid(-x), dim=-1)
        return z, log_det

    # ---------------- Inverse / Backward ----------------
    def backward(self, z):
        """
        Inverse of sigmoid (logit)
        z: [batch, dim] or [dim], input in (0,1)
        Returns:
            x: reconstructed input
            log_det: log-determinant of the inverse
        """
        # Clamp away from 0 and 1 to avoid log(0). The margin must be
        # representable in z's dtype: in float32, 1 - 1e-12 rounds to 1.0 and
        # would leave z == 1 unclamped.
        eps = 1e-12
        if z.is_floating_point():
            eps = max(eps, torch.finfo(z.dtype).eps)
        z = torch.clamp(z, eps, 1 - eps)
        x = torch.log(z) - torch.log(1 - z)  # logit
        # log-det of inverse is -log-det of forward
        log_det = -torch.sum(torch.log(z * (1 - z)), dim=-1)
        return x, log_det

class InvGNN(nn.Module):
    """
    Normalizing flow over node features with optional graph propagation.

    Each layer optionally multiplies the features by a sparse matrix built
    from ``data.edge_index_exp_adj`` / ``data.exp_adj_flat``, then applies an
    ``Invertible1x1Conv``; every layer except the last is followed by the
    invertible ``Sigmoid``. The prior is a standard multivariate normal over
    ``input_dim`` dimensions.
    """

    def __init__(self, input_dim, n_layers, device):
        super().__init__()
        self.n_layers = n_layers
        prior_mean = torch.zeros(input_dim, device=device)
        prior_cov = torch.eye(input_dim, device=device)
        self.prior = torch.distributions.MultivariateNormal(prior_mean, prior_cov)
        self.fc = nn.ModuleList([Invertible1x1Conv(input_dim) for _ in range(n_layers)])
        self.act = Sigmoid()

    def f(self, data, use_graph):
        """
        Map node features to latent space.

        Args:
            data: object exposing ``x``, ``edge_index_exp_adj``,
                ``exp_adj_flat`` and ``logabsdet_exp_adj``.
            use_graph: if True, propagate through the sparse matrix before
                every linear layer and add ``data.logabsdet_exp_adj`` to the
                accumulated log-determinant each time.

        Returns:
            (z, log_det_J): latent features and one accumulated
            log-determinant per row of ``data.x``.
        """
        h = data.x
        n_rows = h.size(0)
        propagation = torch.sparse_coo_tensor(
            data.edge_index_exp_adj,
            data.exp_adj_flat.squeeze(),
            torch.Size([n_rows, n_rows]),
        ).to(h.device)

        total_log_det = h.new_zeros(n_rows, device=h.device)
        for layer_idx in range(self.n_layers):
            if use_graph:
                h = torch.spmm(propagation, h)
                total_log_det += data.logabsdet_exp_adj

            h, linear_log_det = self.fc[layer_idx](h)
            total_log_det += linear_log_det

            # No activation after the final layer.
            if layer_idx + 1 < self.n_layers:
                h, act_log_det = self.act(h)
                total_log_det += act_log_det

        return h, total_log_det

    def log_prob(self, data, use_graph=True):
        """Return the prior log-density of the latent plus the flow's log-determinant, per row."""
        latent, log_det = self.f(data, use_graph)
        return self.prior.log_prob(latent) + log_det


# ---------------------------------------------------------------------------
# Experiment: train one flow that uses the graph structure and one that
# ignores it, then compare per-node negative log-likelihoods on normal and
# anomalous test graphs.
# ---------------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training data contains only "normal" graphs.
train_data = generate_sbm_graphs(n_graphs=100)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)

test = generate_sbm_graphs(n_graphs=10)
test_outlier = generate_sbm_graphs(n_graphs=10, outlier=True)


def _fit(flow, flow_optimizer, loader, use_graph, n_epochs=2000, log_every=10):
    """Train ``flow`` by minimising the mean per-node NLL over ``loader``."""
    flow.train()
    for epoch in range(1, n_epochs + 1):
        total_loss = 0
        for batch in loader:
            batch = batch.to(device)
            flow_optimizer.zero_grad()

            loss = -flow.log_prob(batch, use_graph=use_graph).mean()
            loss.backward()
            flow_optimizer.step()
            total_loss += loss.item()

        if epoch % log_every == 0:
            print(f"Epoch {epoch} | NLL: {total_loss / len(loader):.4f}")


def _node_nlls(flow, graphs, use_graph):
    """Return the per-node NLL of every graph in ``graphs`` as one NumPy array."""
    flow.eval()
    per_graph = []
    with torch.no_grad():
        for graph in graphs:
            graph = graph.to(device)
            per_graph.append(-flow.log_prob(graph, use_graph=use_graph))
    return torch.cat(per_graph, dim=0).detach().cpu().numpy()


# --- Model that propagates features through the graph -----------------------
model = InvGNN(input_dim=2, n_layers=2, device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

_fit(model, optimizer, train_loader, use_graph=True)
nlls = _node_nlls(model, test, use_graph=True)
nlls_outliers = _node_nlls(model, test_outlier, use_graph=True)

# --- Baseline that treats nodes as an unstructured set ----------------------
model_set = InvGNN(input_dim=2, n_layers=2, device=device).to(device)
# The optimiser must own model_set's parameters; binding it to `model`
# would leave this baseline untrained.
optimizer_set = torch.optim.Adam(model_set.parameters(), lr=1e-3)

_fit(model_set, optimizer_set, train_loader, use_graph=False)
nlls_set = _node_nlls(model_set, test, use_graph=False)
nlls_outliers_set = _node_nlls(model_set, test_outlier, use_graph=False)

# --- Plot: NLL histograms, graph-aware (left) vs. graph-agnostic (right) ----
fig, ax = plt.subplots(1, 2, figsize=(16, 4), sharey=True)
colors = plt.get_cmap("Accent").colors  # NOTE(review): not used in the code visible here

# Shared bin edges so normal and anomalous histograms are directly comparable.
bins = np.histogram_bin_edges(np.concatenate([nlls, nlls_outliers]), bins=20)
ax[0].hist(nlls, bins=bins, alpha=0.6, color='#1f77b4', edgecolor='black', label='normal data')
ax[0].hist(nlls_outliers, bins=bins, alpha=0.6, color='#ff7f0e', edgecolor='black', label='anomalous data')
ax[0].set_xlabel(r"$-\log p(x)$")
ax[0].set_ylabel("# nodes")
ax[0].set_title("Graph structure considered")

bins = np.histogram_bin_edges(np.concatenate([nlls_set, nlls_outliers_set]), bins=20)
ax[1].hist(nlls_set, bins=bins, alpha=0.6, color='#1f77b4', edgecolor='black', label='normal data')
ax[1].hist(nlls_outliers_set, bins=bins, alpha=0.6, color='#ff7f0e', edgecolor='black', label='anomalous data')
ax[1].set_xlabel(r"$-\log p(x)$")
ax[1].legend()
ax[1].set_title("Graph structure ignored")

plt.tight_layout()
plt.show()