"""
Graph Construction Module

This module contains classes for building and manipulating the k-NN graph 
and Graph Laplacian used in LapBoost.

Classes:
    GraphConstructor: Builds k-NN graph with Gaussian similarity weights
    GraphLaplacian: Constructs and manipulates graph Laplacian matrices
"""

import numpy as np
from typing import Optional, Tuple, Dict, Any, Union
from sklearn.neighbors import kneighbors_graph, NearestNeighbors
from scipy.sparse import csr_matrix, diags, identity
from scipy.sparse.linalg import spsolve


class GraphConstructor:
    """
    Construct k-NN graphs with Gaussian similarity weights.

    This class builds a weighted graph where vertices correspond to data
    points and edge weights encode pairwise similarities using a Gaussian
    kernel, ``w = exp(-0.5 * (distance / sigma) ** 2)``.

    Parameters
    ----------
    k_neighbors : int, default=10
        Number of neighbors requested from the nearest-neighbor search
        (passed to ``NearestNeighbors`` as ``n_neighbors``)

    sigma : float, default=1.0
        Bandwidth parameter for Gaussian similarity weights. Must be
        strictly positive.

    symmetrize : bool, default=True
        Whether to symmetrize the graph (make it undirected) by averaging
        the adjacency matrix with its transpose

    random_state : int, default=None
        Random seed for reproducibility. It is stored on the instance but
        not used by any method of this class.

    Attributes
    ----------
    X_ : np.ndarray
        The data passed to ``fit``

    knn_ : sklearn.neighbors.NearestNeighbors
        Fitted nearest-neighbor index

    distances_, indices_ : np.ndarray
        Neighbor distances and neighbor indices returned by the
        nearest-neighbor query on the training data

    adjacency_matrix_ : scipy.sparse.csr_matrix
        Adjacency matrix of the k-NN graph

    degree_matrix_ : scipy sparse diagonal matrix
        Diagonal matrix with node degrees (row sums of the adjacency matrix)

    laplacian_ : scipy.sparse.csr_matrix
        Unnormalized graph Laplacian matrix ``D - A``

    normalized_laplacian_ : scipy.sparse.csr_matrix
        Normalized graph Laplacian with unit diagonal
    """

    def __init__(
        self,
        k_neighbors: int = 10,
        sigma: float = 1.0,
        symmetrize: bool = True,
        random_state: Optional[int] = None
    ):
        self.k_neighbors = k_neighbors
        self.sigma = sigma
        self.symmetrize = symmetrize
        self.random_state = random_state

    def fit(self, X: np.ndarray) -> 'GraphConstructor':
        """
        Build the k-NN graph from data.

        Parameters
        ----------
        X : np.ndarray
            Input features, one row per sample

        Returns
        -------
        self : GraphConstructor
            Fitted graph constructor

        Raises
        ------
        ValueError
            If ``sigma`` is not strictly positive (a zero or NaN bandwidth
            would silently produce NaN/inf edge weights).
        """
        # Validate up front; `not (sigma > 0)` also rejects NaN.
        if not self.sigma > 0:
            raise ValueError(
                f"sigma must be a positive number, got {self.sigma!r}"
            )

        # Store data for possible future use
        self.X_ = X

        # Build the nearest-neighbor index
        self.knn_ = NearestNeighbors(
            n_neighbors=self.k_neighbors,
            algorithm='auto',
            n_jobs=-1
        )
        self.knn_.fit(X)

        # Get distances and indices of k nearest neighbors.
        # NOTE(review): the training points themselves are used as the query,
        # so each point is presumably returned as its own neighbor (distance
        # 0, weight 1), i.e. the graph likely contains self-loops that count
        # toward the degrees -- confirm this is intended.
        self.distances_, self.indices_ = self.knn_.kneighbors(X)

        # Build the (row, col, weight) triplets for all edges at once:
        # row i is repeated once per neighbor, columns are the neighbor ids.
        n_samples, n_neighbors = self.indices_.shape
        rows = np.repeat(np.arange(n_samples), n_neighbors)
        cols = self.indices_.ravel()
        weights = np.exp(-0.5 * (self.distances_.ravel() / self.sigma) ** 2)

        # Create sparse adjacency matrix
        self.adjacency_matrix_ = csr_matrix(
            (weights, (rows, cols)),
            shape=(n_samples, n_samples)
        )

        # Symmetrize the graph (make it undirected)
        if self.symmetrize:
            self.adjacency_matrix_ = 0.5 * (
                self.adjacency_matrix_ + self.adjacency_matrix_.T
            )

        # Node degree = row sum of the adjacency matrix
        degrees = np.asarray(self.adjacency_matrix_.sum(axis=1)).ravel()
        self.degree_matrix_ = diags(degrees)

        # Compute Laplacian matrices
        self._compute_laplacian()

        return self

    def _compute_laplacian(self) -> None:
        """
        Compute the graph Laplacian matrices.

        Reads ``adjacency_matrix_`` and ``degree_matrix_`` and sets both
        ``laplacian_`` (unnormalized, ``D - A``) and
        ``normalized_laplacian_``.

        The normalized Laplacian is assembled explicitly so that every
        diagonal element is exactly 1.0: off-diagonal entries are
        ``-A[i, j] / sqrt(d_i * d_j)``, diagonal entries of ``A`` are
        ignored, and edges touching a node whose degree is not positive
        are dropped.
        """
        # Unnormalized Laplacian: L = D - A
        self.laplacian_ = self.degree_matrix_ - self.adjacency_matrix_

        degrees = self.degree_matrix_.diagonal()
        n_samples = self.adjacency_matrix_.shape[0]

        # Work on the COO triplets so the whole computation is vectorized
        # instead of doing one sparse scalar lookup per edge.
        edges = self.adjacency_matrix_.tocoo()
        edge_rows, edge_cols, edge_weights = edges.row, edges.col, edges.data

        # Keep real off-diagonal edges between nodes with positive degree
        # (the degree test is what avoids division by zero).
        keep = (
            (edge_rows != edge_cols)
            & (edge_weights != 0)
            & (degrees[edge_rows] > 0)
            & (degrees[edge_cols] > 0)
        )
        edge_rows = edge_rows[keep]
        edge_cols = edge_cols[keep]
        off_diagonal = -edge_weights[keep] / np.sqrt(
            degrees[edge_rows] * degrees[edge_cols]
        )

        # Unit diagonal followed by the normalized off-diagonal entries
        diagonal_idx = np.arange(n_samples)
        rows = np.concatenate([diagonal_idx, edge_rows])
        cols = np.concatenate([diagonal_idx, edge_cols])
        data = np.concatenate([np.ones(n_samples), off_diagonal])

        self.normalized_laplacian_ = csr_matrix(
            (data, (rows, cols)),
            shape=(n_samples, n_samples)
        )

    def get_laplacian(self, normalized: bool = True) -> csr_matrix:
        """
        Get the graph Laplacian matrix.

        Parameters
        ----------
        normalized : bool, default=True
            Whether to return the normalized Laplacian

        Returns
        -------
        scipy.sparse.csr_matrix
            Graph Laplacian matrix

        Raises
        ------
        RuntimeError
            If the Laplacian has not been computed yet
        """
        if not hasattr(self, 'laplacian_'):
            raise RuntimeError("Graph must be built with fit() before accessing the Laplacian")

        return self.normalized_laplacian_ if normalized else self.laplacian_

    def get_neighbor_distances(self, X_query: np.ndarray) -> np.ndarray:
        """
        Get distances to nearest neighbors for query points.

        Parameters
        ----------
        X_query : np.ndarray
            Query points

        Returns
        -------
        np.ndarray
            Mean distance from each query point to its nearest neighbors
            in the fitted data

        Raises
        ------
        RuntimeError
            If the nearest-neighbor index has not been built yet
        """
        if not hasattr(self, 'knn_'):
            raise RuntimeError("Graph must be built with fit() before querying neighbors")

        # Get distances to k nearest neighbors
        distances, _ = self.knn_.kneighbors(X_query)

        # Return mean distance to neighbors as a confidence proxy
        return np.mean(distances, axis=1)


class GraphLaplacian:
    """
    Utility class for Graph Laplacian operations.

    This class provides Laplacian-regularized smoothing of target values
    over a graph.

    Parameters
    ----------
    gamma : float, default=0.1
        Graph regularization strength

    Attributes
    ----------
    gamma : float
        Graph regularization strength
    """

    def __init__(self, gamma: float = 0.1):
        self.gamma = gamma

    def smooth_targets(
        self,
        y: np.ndarray,
        laplacian: csr_matrix,
        mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        Apply Laplacian smoothing to target values.

        Solves the sparse linear system ``(M + gamma * L) y_s = M y0``,
        where ``M`` is the 0/1 diagonal matrix of known labels, ``L`` is
        the graph Laplacian and ``y0`` holds the targets with unknown
        entries set to 0.

        Parameters
        ----------
        y : np.ndarray
            Target values, with possible missing values (NaN). The input
            is not modified.
        laplacian : scipy.sparse.csr_matrix
            Graph Laplacian matrix
        mask : np.ndarray, optional
            Mask indicating known labels (True) vs unknown (False). Any
            array that can be converted to booleans is accepted (e.g. 0/1
            integers). Defaults to the non-NaN entries of ``y``.

        Returns
        -------
        np.ndarray
            Smoothed target values

        Raises
        ------
        ValueError
            If ``mask`` does not have the same shape as ``y``

        Notes
        -----
        The system matrix can be singular, for instance when ``gamma`` is 0
        and some targets are unknown; the result of the sparse solver is
        returned as-is in that case.
        """
        # Work on a float copy so the caller's array is never mutated and
        # integer targets are handled too.
        targets = np.array(y, dtype=float)

        if mask is None:
            known = ~np.isnan(targets)
        else:
            # Coerce to bool: `~` on an integer mask is a bitwise NOT and
            # would turn the assignment below into negative-index lookups.
            known = np.asarray(mask, dtype=bool)
            if known.shape != targets.shape:
                raise ValueError(
                    f"mask shape {known.shape} does not match "
                    f"y shape {targets.shape}"
                )

        # Nothing to smooth
        if targets.size == 0:
            return targets

        # Unknown targets contribute nothing to the right-hand side
        targets[~known] = 0

        # Diagonal 0/1 matrix selecting the labeled points
        mask_diag = diags(known.astype(float))

        # Stationarity condition of
        #     (y - y0)' M (y - y0) + gamma * y' L y
        # is (M + gamma * L) y = M y0. Rows of unlabeled points reduce to
        # gamma * (L y)_i = 0, i.e. their values are determined purely by
        # graph smoothness.
        A = csr_matrix(mask_diag + self.gamma * laplacian)
        b = mask_diag @ targets

        # Solve the linear system
        return spsolve(A, b)
