import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from matplotlib.font_manager import FontProperties
from torch_geometric.data import Data
from torch_geometric.datasets import WebKB
from torch_geometric.utils import softmax, to_dense_adj

# Target device for every model/tensor below: first CUDA GPU if present, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from model import LambdaDiffusion
from read_data import read_dataset

# NOTE(review): the triple-quoted block below is an inert string literal, not
# live code. The LambdaDiffusion used in this file is the one imported from
# `model` above. It looks like a reference copy; confirm it still matches
# model.LambdaDiffusion (get_attention_matrix re-implements its Q/K/softmax
# steps) or delete it.
# NOTE(review): inside it, `scores`/`attn` have shape (E, heads), not (E,) as
# the inline comment says. Every head also aggregates the same first `dk`
# columns of X (`V[:, :, :self.dk]`) rather than a per-head value projection.
'''
class LambdaDiffusion(nn.Module):
    def __init__(self, in_dim, heads=4):
        super().__init__()
        self.heads = heads
        self.dk = in_dim // heads
        self.WQ = nn.Linear(in_dim, in_dim)
        self.WK = nn.Linear(in_dim, in_dim)
        self.linear = nn.Linear(self.dk, in_dim)
        self.initialized = True
        self.lambda_diag_initialized = True

    def forward(self, X, edge_index):
        N = X.size(0)
        Q = self.WQ(X).view(N, self.heads, self.dk)
        K = self.WK(X).view(N, self.heads, self.dk)

        row, col = edge_index
        scores = (Q[row] * K[col]).sum(dim=-1) / (self.dk ** 0.5)
        attn = softmax(scores, row)  # (E,)

        V = X[col].unsqueeze(1).expand(-1, self.heads, -1)
        agg = torch.zeros(N, self.heads, self.dk, device=X.device)
        agg.index_add_(0, row, attn.unsqueeze(-1) * V[:, :, :self.dk])
        agg = agg.mean(dim=1)
        agg = self.linear(agg)

        diffusion_term = agg - X
        return diffusion_term
'''
def get_attention_matrix(model, X, edge_index):
    """Return the dense attention-diffusion operator ``A_attn - I`` of ``model``.

    Re-runs the query/key projections of ``model`` on ``X`` and scores every
    edge with scaled dot-product attention. The scores are normalised with a
    softmax over each source node's edges and scattered into a dense
    ``N x N`` matrix (averaged over heads); the identity is then subtracted.

    The model is switched to eval mode for the computation, and its previous
    train/eval mode is restored afterwards.

    Args:
        model: module exposing ``heads``, ``dk``, ``WQ`` and ``WK``.
        X: node features; must be viewable as ``(N, heads, dk)`` after the
            ``WQ``/``WK`` projections.
        edge_index: ``(2, E)`` long tensor of edges ``(row, col)``.

    Returns:
        NumPy array of shape ``(N, N)``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            N = X.size(0)
            heads = model.heads
            dk = model.dk

            Q = model.WQ(X).view(N, heads, dk)
            K = model.WK(X).view(N, heads, dk)

            # Index tensors must live on the same device as Q/K and attn.
            edge_index = edge_index.to(X.device)
            row, col = edge_index
            scores = (Q[row] * K[col]).sum(dim=-1) / (dk ** 0.5)  # (E, heads)
            attn = softmax(scores, row)  # normalize over source node neighbors

            # Shape is (1, N, N, heads) when attn has a head axis, else (1, N, N).
            A_attn = to_dense_adj(edge_index, max_num_nodes=N, edge_attr=attn)[0]
            if A_attn.ndim == 3:
                A_attn = A_attn.mean(dim=-1)  # average the per-head matrices

            diffusion_matrix = A_attn - torch.eye(N, device=X.device)
    finally:
        # Do not leave the caller's model silently stuck in eval mode.
        model.train(was_training)

    return diffusion_matrix.cpu().numpy()


# Stability functions
class Stability:
    """Stability functions R(z) of explicit one-step time integrators.

    ``z`` stands for ``dt * mu`` and may be a Python number or a NumPy array.
    A method is absolutely stable at ``z`` when ``|R(z)| <= 1``. Every method
    here returns the exponential series truncated after the ``z**p / p!``
    term, where ``p`` is the nominal order of the scheme.
    """

    @staticmethod
    def _truncated_exp(z, order):
        # Sum 1 + z + z**2/2! + ... + z**order/order!. Terms are added in
        # increasing degree, the same left-to-right order as the written-out
        # polynomial, so floating-point results match it exactly.
        series = 1 + z
        factorial = 1
        for degree in range(2, order + 1):
            factorial *= degree
            series = series + z**degree / factorial
        return series

    @staticmethod
    def stability_euler(z):
        """Forward Euler: R(z) = 1 + z."""
        return Stability._truncated_exp(z, 1)

    @staticmethod
    def stability_rk4(z):
        """Classical RK4: R(z) = sum_{k=0}^{4} z**k / k!."""
        return Stability._truncated_exp(z, 4)

    @staticmethod
    def stability_rk5(z):
        """Degree-5 truncated exponential, used here as the RK5 stability function.

        NOTE(review): an explicit 5th-order RK scheme needs at least six
        stages. Its true R(z) therefore also carries method-dependent
        higher-degree terms; this is only their common degree-5 part.
        """
        return Stability._truncated_exp(z, 5)

def plotter(stab_rk4, stab_rk5, stab_euler, X, Y, eigs=None):
    """Plot the Euler, RK4 and RK5 absolute-stability regions and show them.

    Args:
        stab_rk4, stab_rk5, stab_euler: boolean masks on the (X, Y) grid that
            are True where the method's ``|R(z)| <= 1``.
        X, Y: meshgrid arrays holding Re(z) and Im(z).
        eigs: optional complex values (e.g. time-step-scaled eigenvalues),
            drawn as black crosses on top of the regions; skipped when None.
    """
    bold_font = FontProperties(weight='bold', size=16)

    plt.figure(figsize=(8, 6), dpi=600)

    # The masks are plotted as 0/1 fields: the [0.5, 1] band is the stable
    # set and the 0.5 iso-line is its boundary.

    # RK4 stability region (blue)
    plt.contourf(X, Y, stab_rk4, levels=[0.5, 1], colors=['lightblue'], alpha=0.5)
    plt.contour(X, Y, stab_rk4, levels=[0.5], colors=['blue'], linewidths=1)

    # RK5 stability region (cyan, dashed boundary)
    plt.contourf(X, Y, stab_rk5, levels=[0.5, 1], colors=['lightcyan'], alpha=0.5)
    plt.contour(X, Y, stab_rk5, levels=[0.5], colors=['deepskyblue'], linewidths=1, linestyles='--')

    # Euler stability region (red)
    plt.contourf(X, Y, stab_euler, levels=[0.5, 1], colors=['mistyrose'], alpha=0.5)
    plt.contour(X, Y, stab_euler, levels=[0.5], colors=['red'], linewidths=1, linestyles='-')

    # Overlay the spectrum so it can be compared against each region.
    if eigs is not None:
        eigs = np.asarray(eigs)
        plt.scatter(eigs.real, eigs.imag, s=15, c='black', marker='x', zorder=3,
                    label=r"$\Delta t \mu_{\mathbf{B}}$")

    # Formatting
    plt.xlabel(r"Re($\Delta t \mu_{\mathbf{B}}$)", fontsize=16, fontweight='bold')
    plt.ylabel(r"Im($\Delta t \mu_{\mathbf{B}}$)", fontsize=16, fontweight='bold')
    plt.title('Stability Regions: Euler (red) vs RK4 (blue) vs RK5 (cyan)', fontsize=16, fontweight='bold')

    plt.tick_params(axis='both', which='major', labelsize=16)
    plt.tick_params(axis='both', which='minor', labelsize=14)

    plt.grid(True, alpha=0.3)
    plt.axhline(0, color='black', linewidth=0.5)
    plt.axvline(0, color='black', linewidth=0.5)

    # Proxy artists so the legend shows the boundary line styles.
    plt.plot([], [], color='blue', label='RK4')
    plt.plot([], [], color='red', linestyle='-', label='Euler')
    plt.plot([], [], color='deepskyblue', linestyle='--', label='RK5')

    plt.legend(prop=bold_font)
    plt.tight_layout()
    plt.show()

def runner(data, in_dim, hidden_dim, heads, h=0.5):
    """Plot the scaled spectrum of a freshly initialised LambdaDiffusion layer.

    Node features are projected to ``hidden_dim`` by a randomly initialised
    linear layer. The attention-diffusion matrix of a new LambdaDiffusion
    model is then built on them, and its eigenvalues are scaled by ``h`` and
    drawn over the Euler/RK4/RK5 stability regions.

    Args:
        data: graph object providing ``x`` (node features) and ``edge_index``.
        in_dim: feature dimension of ``data.x``.
        hidden_dim: projected feature width; must be divisible by ``heads``.
        heads: number of attention heads.
        h: time step used to scale the eigenvalues (default 0.5).

    Returns:
        NumPy array of the scaled eigenvalues ``h * eig(A_attn - I)``.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``heads``.
    """
    if hidden_dim % heads != 0:
        raise ValueError("hidden_dim must be divisible by heads")

    # Initialize LambdaDiffusion model
    lambda_diffusion_model = LambdaDiffusion(in_dim=hidden_dim, heads=heads).to(device)
    # NOTE(review): these flags presumably mark lazy-initialised state as
    # already set up; confirm against model.LambdaDiffusion.
    lambda_diffusion_model.initialized = True
    lambda_diffusion_model.lambda_diag_initialized = True

    # Project node features to hidden_dim (inference only, no autograd graph).
    linear_proj = nn.Linear(in_dim, hidden_dim).to(device)
    with torch.no_grad():
        X_hidden = linear_proj(data.x.to(device))

    # Diffusion (attention) matrix and its spectrum
    diffusion_matrix = get_attention_matrix(lambda_diffusion_model, X_hidden, data.edge_index)
    eigenvalues = np.linalg.eigvals(diffusion_matrix)

    # Stability regions sampled on a 400x400 grid over [-5, 5] x [-5, 5]
    x = np.linspace(-5, 5, 400)
    y = np.linspace(-5, 5, 400)
    X, Y = np.meshgrid(x, y)
    Z = X + 1j * Y

    stab_euler = np.abs(Stability.stability_euler(Z)) <= 1
    stab_rk4 = np.abs(Stability.stability_rk4(Z)) <= 1
    stab_rk5 = np.abs(Stability.stability_rk5(Z)) <= 1

    scaled_eigs = h * eigenvalues

    plotter(stab_rk4, stab_rk5, stab_euler, X, Y, eigs=scaled_eigs)
    return scaled_eigs


## Synthetic Case : SBM
def SBM_generator(num_nodes, p, q, mu1, mu2, sigma, feat_dim=2):
    """Sample a two-block stochastic block model graph with Gaussian features.

    Each node gets a label drawn uniformly from {0, 1}. Its features are
    drawn from a normal distribution with mean ``mu1`` (label 0) or ``mu2``
    (label 1) in every coordinate and standard deviation ``sigma``. Each
    unordered pair of nodes is connected with probability ``p`` when the
    labels match and ``q`` otherwise. Every sampled edge is stored in both
    directions.

    Args:
        num_nodes: number of nodes.
        p: intra-class edge probability.
        q: inter-class edge probability.
        mu1, mu2: feature means for class 0 / class 1.
        sigma: feature standard deviation.
        feat_dim: number of feature columns (default 2).

    Returns:
        ``(data, in_dim)``: ``data`` is a ``Data`` object with ``x``,
        ``edge_index`` (shape ``(2, E)``, possibly ``E == 0``) and ``y``, moved
        to ``device``. ``in_dim`` is its number of node features.
    """
    labels = np.random.randint(0, 2, num_nodes)
    features = np.zeros((num_nodes, feat_dim))
    for i, label in enumerate(labels):
        mean = mu1 if label == 0 else mu2
        features[i] = np.random.normal(loc=mean, scale=sigma, size=feat_dim)

    edges = []
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            # Exactly one uniform draw per unordered pair.
            prob = p if labels[i] == labels[j] else q
            if np.random.rand() < prob:
                edges.extend(((i, j), (j, i)))

    # reshape(-1, 2) keeps a (2, 0) edge_index when no edge was sampled; a bare
    # torch.tensor([]) would be 1-D and break `row, col = edge_index` later.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    x = torch.tensor(features, dtype=torch.float)
    y = torch.tensor(labels, dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, y=y).to(device)  # Move data to device
    in_dim = data.num_features
    return data, in_dim



if __name__ == "__main__":
    # Real graph: WebKB Texas, loaded through the project's read_dataset helper.
    data, in_dim, _ = read_dataset('Texas', device)
    runner(data, in_dim, hidden_dim=32, heads=4)

    # Synthetic graph: two-block SBM with dense intra-class (p) and sparse
    # inter-class (q) connectivity.
    data, in_dim = SBM_generator(num_nodes=100, p=0.9, q=0.1, mu1=0.5, mu2=-0.5, sigma=1)
    runner(data, in_dim, hidden_dim=32, heads=4)
