import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor
from torch_geometric.utils import get_laplacian, to_dense_adj
import numpy as np

class Matern_Kernel_Module(nn.Module):
    def __init__(self, hidden_dim=64, polynomial_type="rational"):
        super(Matern_Kernel_Module, self).__init__()
        # Networks to compute k and nu parameters based on graph features H
        self.k_net = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Softplus()  # Ensures k is positive
        )
        
        self.nu_net = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Softplus()  # Ensures nu is positive
        )
        
        # Type of polynomial approximation to use
        self.polynomial_type = polynomial_type

    def forward(self, H, L=None, edge_index=None, num_nodes=None):
        """
        Compute Matérn kernel given graph data
        Args:
            H: Node features/embeddings (N x hidden_dim)
            L: Graph Laplacian matrix (optional)
            edge_index: Edge indices (optional)
            num_nodes: Number of nodes (optional)
        Returns:
            K: Matérn kernel matrix
        """
        # Compute adaptive parameters based on graph features
        k = self.k_net(H.mean(dim=0, keepdim=True))  # Global graph representation
        nu = self.nu_net(H.mean(dim=0, keepdim=True))
        
        # Remove extra dimensions but keep as tensors
        k = k.squeeze()
        nu = nu.squeeze()
        
        # Get Laplacian if not provided
        if L is None:
            assert edge_index is not None and num_nodes is not None
            edge_index, edge_weight = get_laplacian(edge_index, normalization='sym', 
                                                  num_nodes=num_nodes)
            L = to_dense_adj(edge_index, edge_attr=edge_weight)[0]
        
        try:
            # Try standard eigendecomposition approach first
            eigvals, eigvecs = torch.linalg.eigh(L)
            
            # Compute Matérn kernel function on eigenvalues
            # Ensure proper broadcasting by reshaping scalar tensors
            denom = 2 * nu / (k**2) + eigvals
            kernel_eigvals = denom.pow(-nu)
            
            # Reconstruct kernel matrix
            K = eigvecs @ torch.diag(kernel_eigvals) @ eigvecs.T
        except RuntimeError:
            # Fallback to polynomial approximation
            if self.polynomial_type == "chebyshev":
                K = self.chebyshev_matern(L, k, nu)
            else:  # Default to rational approximation
                K = self.rational_matern(L, k, nu)
        
        return K, k, nu
    
    def chebyshev_matern(self, L, k, nu, order=20):
        """
        Approximate Matérn kernel using Chebyshev polynomial approximation
        
        Args:
            L: Graph Laplacian matrix
            k: Length scale parameter
            nu: Smoothness parameter
            order: Order of Chebyshev approximation
            
        Returns:
            K: Approximated Matérn kernel matrix
        """
        import math
        n = L.shape[0]
        I = torch.eye(n, device=L.device)
        
        # Scale L to have eigenvalues in [-1, 1]
        # We use a simple estimate based on Gershgorin circle theorem
        # which states that eigenvalues are bounded by row sums
        row_sums = torch.sum(torch.abs(L), dim=1)
        max_eigenval = torch.max(row_sums)
        scaled_L = 2 * L / max_eigenval - I
        
        # Set up the operator A = (2*nu/(k^2))*I + L
        term = 2 * nu / (k**2)
        scaled_A = (2/max_eigenval) * term * I + scaled_L
        
        # Approximate (A)^{-nu} using truncated Chebyshev series
        # Initialize with identity and first Chebyshev polynomial
        T_prev = I
        T_curr = scaled_A
        
        # For negative power, use a direct polynomial approximation
        # Initialize result with appropriate coefficient
        c0 = self._compute_chebyshev_coeff(-nu, 0)
        result = c0 * T_prev
        
        c1 = self._compute_chebyshev_coeff(-nu, 1)
        result += c1 * T_curr
        
        # Recurrence relation for higher order terms
        for i in range(2, order):
            T_next = 2 * scaled_A @ T_curr - T_prev
            T_prev = T_curr
            T_curr = T_next
            
            ci = self._compute_chebyshev_coeff(-nu, i)
            result += ci * T_curr
        
        # Final scaling to account for domain transformation
        scaling_factor = (max_eigenval / 2) ** nu
        K = scaling_factor * result
        
        # Add jitter to ensure positive definiteness
        jitter = 1e-6 * torch.eye(n, device=L.device)
        K = K + jitter
        
        return K
        
    def rational_matern(self, L, k, nu, order=15):
        """
        Approximate Matérn kernel using rational approximation that ensures positive definiteness
        
        Args:
            L: Graph Laplacian matrix
            k: Length scale parameter
            nu: Smoothness parameter
            order: Order of approximation
            
        Returns:
            K: Approximated Matérn kernel matrix
        """
        n = L.shape[0]
        I = torch.eye(n, device=L.device)
        
        # For Matérn kernel, we can ensure positive definiteness by:
        # 1. Adding jitter to the diagonal
        # 2. Using an alternative approach with rational approximation
        
        # Method 1: Using modified rational approximation
        # Set up the operator A = (2*nu/(k^2))*I + L
        A = (2 * nu / (k**2)) * I + L
        
        # Approximate (A)^{-nu} using a rational approximation
        # For fractional powers, we can use (A^{-1})^{nu} approximation
        # First compute an approximation of A^{-1} via Neumann series
        
        # Estimate largest eigenvalue for scaling
        row_sums = torch.sum(torch.abs(L), dim=1)
        spectral_radius = torch.max(row_sums)
        
        # Scale A to ensure convergence of the series
        scaling = 1.0 / (spectral_radius + 2*nu/(k**2))
        scaled_A = scaling * A
        
        # Compute approximate inverse using Neumann series: 
        # (I - (I - scaled_A))^{-1} = I + (I-scaled_A) + (I-scaled_A)^2 + ...
        B = I - scaled_A
        A_inv_approx = I.clone()
        B_power = I.clone()
        
        for i in range(1, order):  # default 15 terms should be sufficient
            B_power = B_power @ B
            A_inv_approx += B_power
        
        # Now we have A_inv_approx ≈ A^{-1}
        # Undo scaling
        A_inv_approx = scaling * A_inv_approx
        
        # For non-integer nu, we need to approximate A^{-nu} = (A^{-1})^nu
        # We'll use a rational approximation based on continued fractions
        
        # For nu ≤ 1, we can directly use A_inv_approx^nu
        if nu <= 1:
            # Compute eigendecomposition of A_inv_approx safely
            try:
                eigvals, eigvecs = torch.linalg.eigh(A_inv_approx)
                # Ensure all eigenvalues are positive
                eigvals = torch.clamp(eigvals, min=1e-10)
                # Compute power
                K = eigvecs @ torch.diag(eigvals.pow(nu)) @ eigvecs.T
            except RuntimeError:
                # If eigendecomposition fails, use a different approach
                # Add jitter to ensure positive definiteness
                jitter = 1e-6 * torch.eye(n, device=L.device)
                K = A_inv_approx + jitter
                # For nu close to 1, A_inv_approx is already close to A^{-nu}
                if nu < 0.9:
                    # For smaller nu, use weighted combination
                    alpha = nu  # Weight factor
                    K = alpha * A_inv_approx + (1 - alpha) * I
        else:
            # For nu > 1, we use a recursive approach
            # Compute A^{-1} first, then recursively compute powers
            integer_part = int(nu)
            fractional_part = nu - integer_part
            
            # Compute A^{-integer_part}
            K = A_inv_approx.clone()
            for _ in range(1, integer_part):
                K = K @ A_inv_approx
            
            # If there's a fractional part, compute it
            if fractional_part > 0:
                try:
                    # Try eigendecomposition on the current K
                    eigvals, eigvecs = torch.linalg.eigh(K)
                    # Ensure all eigenvalues are positive
                    eigvals = torch.clamp(eigvals, min=1e-10)
                    # Apply fractional power
                    K = eigvecs @ torch.diag(eigvals.pow(fractional_part)) @ eigvecs.T
                except RuntimeError:
                    # If eigendecomposition fails, approximate the fractional part
                    alpha = fractional_part  # Weight factor
                    K_frac = alpha * A_inv_approx + (1 - alpha) * I
                    K = K @ K_frac
        
        # Add small jitter to diagonal to ensure positive definiteness
        jitter = 1e-6 * torch.eye(n, device=L.device)
        K = K + jitter
        
        return K

    def _compute_chebyshev_coeff(self, alpha, k):
        """
        Compute coefficients for Chebyshev approximation of x^alpha
        """
        import math
        if k == 0:
            return 2**alpha * math.gamma(alpha + 1/2) / (math.sqrt(math.pi) * math.gamma(alpha + 1))
        else:
            num = (-1)**k * math.gamma(alpha + 1/2) * math.gamma(alpha - k + 1)
            denom = math.sqrt(math.pi) * math.factorial(k) * math.gamma(alpha + 1) * (1 - 2*k)
            return 2**alpha * num / denom

class Diffusion_Kernel_Module(nn.Module):
    def __init__(self, hidden_dim=64):
        super(Diffusion_Kernel_Module, self).__init__()
        # Network to compute k parameter based on graph features H
        self.k_net = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Softplus()  # Ensures k is positive
        )

    def forward(self, H, L=None, edge_index=None, num_nodes=None):
        """
        Compute diffusion kernel given graph data
        Args:
            H: Node features/embeddings (N x hidden_dim)
            L: Graph Laplacian matrix (optional)
            edge_index: Edge indices (optional)
            num_nodes: Number of nodes (optional)
        Returns:
            K: Diffusion kernel matrix
        """
        # Compute adaptive parameter based on graph features
        k = self.k_net(H.mean(dim=0, keepdim=True))  # Global graph representation
        
        # Remove extra dimensions but keep as tensor
        k = k.squeeze()
        
        # Get Laplacian if not provided
        if L is None:
            assert edge_index is not None and num_nodes is not None
            edge_index, edge_weight = get_laplacian(edge_index, normalization='sym', 
                                                  num_nodes=num_nodes)
            L = to_dense_adj(edge_index, edge_attr=edge_weight)[0]
        
        try:
            # Try standard eigendecomposition approach first
            eigvals, eigvecs = torch.linalg.eigh(L)
            
            # Compute diffusion kernel function on eigenvalues
            kernel_eigvals = torch.exp(-k**2/2 * eigvals)
            
            # Reconstruct kernel matrix
            K = eigvecs @ torch.diag(kernel_eigvals) @ eigvecs.T
        except RuntimeError:
            # Fallback: Use Taylor series approximation for exp(-t*L)
            K = self.taylor_diffusion(L, k)
        
        return K, k
    
    def taylor_diffusion(self, L, k, order=10):
        """
        Approximate diffusion kernel using Taylor series expansion
        
        Args:
            L: Graph Laplacian matrix
            k: Diffusion parameter
            order: Order of Taylor approximation
            
        Returns:
            K: Approximated diffusion kernel matrix
        """
        n = L.shape[0]
        I = torch.eye(n, device=L.device)
        
        # Compute exp(-k^2/2 * L) using Taylor series
        scaled_L = -k**2/2 * L
        K = I.clone()  # First term: I
        L_power = I.clone()
        factorial = 1.0
        
        for i in range(1, order):
            factorial *= i
            L_power = L_power @ scaled_L  # L^i
            K += L_power / factorial
        
        return K
    
    @staticmethod
    def get_sparse_kernel(edge_index, k, num_nodes):
        """
        Compute sparse diffusion kernel for large graphs
        Args:
            edge_index: Edge indices
            k: Diffusion parameter
            num_nodes: Number of nodes
        Returns:
            K: Sparse diffusion kernel
        """
        # Get sparse Laplacian
        edge_index, edge_weight = get_laplacian(edge_index, normalization='sym', 
                                              num_nodes=num_nodes)
        L = SparseTensor(row=edge_index[0], col=edge_index[1], 
                        value=edge_weight, sparse_sizes=(num_nodes, num_nodes))
        
        # Ensure k is properly shaped for computation
        if torch.is_tensor(k):
            k = k.reshape(())  # Make sure it's a scalar tensor
            
        # Approximate matrix exponential using truncated series
        K = SparseTensor.eye(num_nodes)
        L_power = SparseTensor.eye(num_nodes)
        coeff = torch.tensor(1.0, device=k.device) if torch.is_tensor(k) else 1.0
        
        for i in range(1, 10):  # Use 10 terms in the series
            coeff = coeff * (-k**2/4) / i
            L_power = L_power @ L
            K = K + coeff * L_power
        
        return K
