"""
Unit tests for LapBoost graph construction components.
"""

import unittest
import numpy as np
from scipy.sparse import csr_matrix

from lapboost.core.graph import GraphConstructor, GraphLaplacian


class TestGraphConstructor(unittest.TestCase):
    """Test cases for GraphConstructor.

    Each test fits the constructor created in ``setUp`` on the same seeded
    random data set (50 samples, 5 features) and inspects the result.
    """

    def setUp(self):
        """Create seeded synthetic data and an unfitted graph constructor."""
        # Seed the global NumPy RNG so the data set is reproducible.
        np.random.seed(42)
        self.X = np.random.rand(50, 5)
        self.gc = GraphConstructor(k_neighbors=5, sigma=1.0, random_state=42)

    def test_fit(self):
        """Fitting must expose square (n_samples, n_samples) graph matrices."""
        self.gc.fit(self.X)

        n_samples = self.X.shape[0]
        fitted_attributes = ('adjacency_matrix_', 'degree_matrix_', 'laplacian_')

        # Check that all matrices have been created before inspecting shapes.
        for attr in fitted_attributes:
            self.assertTrue(
                hasattr(self.gc, attr),
                msg="fit() did not set attribute %r" % attr,
            )

        # Check matrix dimensions.
        for attr in fitted_attributes:
            self.assertEqual(
                getattr(self.gc, attr).shape,
                (n_samples, n_samples),
                msg="unexpected shape for %r" % attr,
            )

    def test_get_laplacian(self):
        """Normalized and unnormalized Laplacians have the expected structure."""
        self.gc.fit(self.X)

        L_norm = self.gc.get_laplacian(normalized=True)
        L = self.gc.get_laplacian(normalized=False)

        # Check dimensions.
        expected_shape = (self.X.shape[0], self.X.shape[0])
        self.assertEqual(L_norm.shape, expected_shape)
        self.assertEqual(L.shape, expected_shape)

        # The normalized Laplacian is expected to have a unit diagonal.
        # NOTE(review): this holds for I - D^-1/2 A D^-1/2 only when the graph
        # has no self-loops and no isolated nodes -- confirm that
        # GraphConstructor guarantees both.
        # np.asarray(...).ravel() flattens the diagonal whether it comes back
        # as an ndarray or as an np.matrix.
        diag_norm = np.asarray(L_norm.diagonal()).ravel()
        # Tolerances mirror np.allclose defaults (rtol=1e-5, atol=1e-8).
        np.testing.assert_allclose(
            diag_norm, 1.0, rtol=1e-5, atol=1e-8,
            err_msg="normalized Laplacian diagonal is not all ones",
        )

        # Every row of the unnormalized Laplacian should sum to zero.
        # Avoid `.A1`: it exists only on np.matrix, so it would raise
        # AttributeError if sum() returned a plain ndarray.
        row_sums = np.asarray(L.sum(axis=1)).ravel()
        np.testing.assert_allclose(
            row_sums, 0.0, rtol=0.0, atol=1e-10,
            err_msg="unnormalized Laplacian rows do not sum to zero",
        )

    def test_get_neighbor_distances(self):
        """Neighbor distances for the training points are non-negative."""
        self.gc.fit(self.X)

        # Query with the same points the graph was fitted on.
        distances = self.gc.get_neighbor_distances(self.X)

        # One row of distances per query sample.
        self.assertEqual(distances.shape[0], self.X.shape[0])

        # Distances should be non-negative.
        self.assertTrue(
            np.all(distances >= 0),
            msg="get_neighbor_distances returned a negative or NaN distance",
        )


class TestGraphLaplacian(unittest.TestCase):
    """Exercise GraphLaplacian on a small random symmetric graph."""

    def setUp(self):
        """Build a 20-node weighted graph, its Laplacian and partial targets."""
        node_count = 20
        np.random.seed(42)

        # Draw 100 random weighted edges between the nodes.
        edge_rows = np.random.randint(0, node_count, 100)
        edge_cols = np.random.randint(0, node_count, 100)
        edge_weights = np.random.rand(100)

        # Average the matrix with its transpose to obtain a symmetric adjacency.
        directed = csr_matrix(
            (edge_weights, (edge_rows, edge_cols)),
            shape=(node_count, node_count),
        )
        self.A = 0.5 * (directed + directed.T)

        # Diagonal degree matrix holding the row sums of the adjacency.
        node_degrees = self.A.sum(axis=1).A1
        node_ids = np.arange(node_count)
        self.D = csr_matrix(
            (node_degrees, (node_ids, node_ids)),
            shape=(node_count, node_count),
        )

        # Unnormalized Laplacian L = D - A.
        self.L = self.D - self.A

        # Object under test.
        self.gl = GraphLaplacian(gamma=0.1)

        # Targets are 0.0 where the mask is True and NaN (missing) elsewhere.
        known = np.random.rand(node_count) > 0.5
        self.y = np.where(known, 0.0, np.nan)
        self.mask = known

    def test_smooth_targets(self):
        """Smoothing fills missing targets and keeps the known ones."""
        smoothed = self.gl.smooth_targets(self.y, self.L, self.mask)

        # The output has one entry per input target.
        self.assertEqual(smoothed.shape[0], self.y.shape[0])

        # Every missing value must have been filled in.
        has_nan = np.any(np.isnan(smoothed))
        self.assertFalse(has_nan)

        # Known targets should come back (approximately) unchanged.
        known_before = self.y[self.mask]
        known_after = smoothed[self.mask]
        deviation = np.abs(known_after - known_before)
        self.assertTrue(np.all(deviation < 1e-10))


# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
