"""
LapTAO Optimization Module

This module implements the Graph Laplacian Tree Alternating Optimization (LapTAO)
algorithm for semi-supervised learning with graph regularization.

Classes:
    ObliqueTree: Tree implementation with linear splits for LapTAO
    LapTAO: Implementation of LapTAO algorithm
"""

import numpy as np
from typing import Optional, Tuple, Dict, Any, Union
from sklearn.tree import DecisionTreeRegressor
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import spsolve


class ObliqueTree:
    """
    Oblique regression tree used by LapTAO.

    Each internal node splits on a linear combination of the features
    (``X @ projection <= threshold``) instead of on a single feature. The
    projection at every node starts from the best axis-aligned split and is
    refined with a small number of random perturbations.

    Parameters
    ----------
    max_depth : int, default=3
        Maximum depth of the tree

    min_samples_split : int, default=2
        Minimum number of samples required to split a node

    random_state : int, default=None
        Random seed for reproducibility. It seeds both the axis-aligned
        initial split and the random perturbations.

    Attributes
    ----------
    tree : dict
        Trained tree structure (nested dictionaries), or None before fitting

    feature_weights : None
        Placeholder attribute; it is never assigned by this class
    """

    # Value sklearn stores in ``tree_.feature`` for a node it did not split.
    _SKLEARN_LEAF = -2
    # Weight mass below which one side of a split is treated as empty.
    _MIN_CHILD_WEIGHT = 1e-10

    def __init__(
        self,
        max_depth: int = 3,
        min_samples_split: int = 2,
        random_state: Optional[int] = None
    ):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.random_state = random_state
        self.tree = None
        self.feature_weights = None
        # Generator shared by all nodes of one fit; created in ``fit``.
        self._rng = None

    def _best_threshold(self, projected: np.ndarray, y_centered: np.ndarray,
                        sample_weights: np.ndarray) -> Optional[float]:
        """
        Find the threshold minimising the weighted SSE along one projection.

        Parameters
        ----------
        projected : np.ndarray
            Projected feature values, one per sample
        y_centered : np.ndarray
            Target values with their weighted mean removed
        sample_weights : np.ndarray
            Sample weights

        Returns
        -------
        float or None
            Midpoint between the two neighbouring projected values that give
            the best split, or None when no valid cut point exists
        """
        if projected.shape[0] < 2:
            return None

        order = np.argsort(projected)
        p_sorted = projected[order]
        w_sorted = sample_weights[order]

        w_cum = np.cumsum(w_sorted)
        wy_cum = np.cumsum(w_sorted * y_centered[order])

        # Entry k describes the split "first k + 1 sorted samples go left".
        w_left = w_cum[:-1]
        w_right = w_cum[-1] - w_left
        s_left = wy_cum[:-1]
        s_right = wy_cum[-1] - s_left

        # A cut between two equal projected values cannot be realised by a
        # ``<=`` threshold, and both children must carry some weight.
        valid = (
            (p_sorted[1:] > p_sorted[:-1])
            & (w_left >= self._MIN_CHILD_WEIGHT)
            & (w_right >= self._MIN_CHILD_WEIGHT)
        )
        if not np.any(valid):
            return None

        # The children's summed SSE equals sum(w * y**2) minus this gain, and
        # the first term is the same for every cut, so maximising the gain
        # minimises the SSE in a single O(n) pass.
        gain = np.full(w_left.shape, -np.inf)
        gain[valid] = (s_left[valid] ** 2 / w_left[valid]
                       + s_right[valid] ** 2 / w_right[valid])
        k = int(np.argmax(gain))

        return 0.5 * (p_sorted[k] + p_sorted[k + 1])

    def _best_oblique_split(self, X: np.ndarray, y: np.ndarray,
                           sample_weights: Optional[np.ndarray] = None) -> Tuple[np.ndarray, float, float]:
        """
        Find best oblique split using random projections.

        Parameters
        ----------
        X : np.ndarray
            Input features, shape (n_samples, n_features)
        y : np.ndarray
            Target values
        sample_weights : np.ndarray, optional
            Sample weights; defaults to all ones

        Returns
        -------
        tuple
            Best projection vector, threshold, and relative impurity
            reduction. ``(zeros, 0.0, 0.0)`` is returned when the node has no
            weight, no weighted variance, or no axis-aligned split.
        """
        n_samples, n_features = X.shape
        no_split = (np.zeros(n_features), 0.0, 0.0)

        if sample_weights is None:
            sample_weights = np.ones(n_samples)

        # Nothing to gain when the node carries no weight or no variance;
        # this also keeps the final division well defined.
        if np.sum(sample_weights) < self._MIN_CHILD_WEIGHT:
            return no_split
        weighted_mean = np.average(y, weights=sample_weights)
        total_impurity = np.sum(sample_weights * (y - weighted_mean) ** 2)
        if not np.isfinite(total_impurity) or total_impurity <= 0:
            return no_split

        # Initialize with a depth-1 decision tree to get a good axis-aligned
        # split. The stump is seeded because sklearn permutes the features at
        # each split, so ties would otherwise be broken differently per run.
        stump = DecisionTreeRegressor(max_depth=1, random_state=self.random_state)
        stump.fit(X, y, sample_weight=sample_weights)

        if stump.tree_.feature[0] == self._SKLEARN_LEAF:
            return no_split

        best_feature = stump.tree_.feature[0]
        best_threshold = stump.tree_.threshold[0]

        # Initial projection vector: one-hot on the stump's feature
        projection = np.zeros(n_features)
        projection[best_feature] = 1.0
        best_score = self._compute_split_score(X, y, projection, best_threshold, sample_weights)

        # Use the generator created by ``fit`` so that different nodes draw
        # different perturbations; fall back to a fresh one for direct calls.
        rng = getattr(self, '_rng', None)
        if rng is None:
            rng = np.random.RandomState(self.random_state)

        n_perturbations = min(10, n_features)
        max_changed = max(2, n_features // 3)
        y_centered = y - weighted_mean

        for _ in range(n_perturbations):
            # Add random noise to a few randomly chosen coordinates
            n_changed = rng.randint(1, max_changed)
            idx = rng.choice(n_features, size=n_changed, replace=False)
            candidate = projection.copy()
            candidate[idx] += rng.normal(0, 0.3, size=len(idx))

            # Normalize the projection vector
            norm = np.sqrt(np.sum(candidate ** 2))
            if norm > 0:
                candidate = candidate / norm

            candidate_threshold = self._best_threshold(X @ candidate, y_centered, sample_weights)
            if candidate_threshold is None:
                continue

            # Re-score with the same routine used for the incumbent so the
            # comparison is like for like.
            candidate_score = self._compute_split_score(
                X, y, candidate, candidate_threshold, sample_weights
            )
            if candidate_score < best_score:
                best_score = candidate_score
                projection = candidate
                best_threshold = candidate_threshold

        # Relative impurity reduction; total_impurity is strictly positive here
        impurity_reduction = (total_impurity - best_score) / total_impurity

        return projection, best_threshold, impurity_reduction

    def _compute_split_score(self, X: np.ndarray, y: np.ndarray, projection: np.ndarray,
                            threshold: float, sample_weights: np.ndarray) -> float:
        """
        Compute score (weighted sum of squared errors) for a split.

        Parameters
        ----------
        X : np.ndarray
            Input features
        y : np.ndarray
            Target values
        projection : np.ndarray
            Projection vector for oblique split
        threshold : float
            Split threshold; samples with ``X @ projection <= threshold`` go left
        sample_weights : np.ndarray
            Sample weights

        Returns
        -------
        float
            Split score (lower is better); infinity when either side has
            (almost) no weight
        """
        left_mask = (X @ projection) <= threshold
        right_mask = ~left_mask

        score = 0.0
        for side in (left_mask, right_mask):
            side_weights = sample_weights[side]
            if np.sum(side_weights) < self._MIN_CHILD_WEIGHT:
                return float('inf')
            side_y = y[side]
            side_mean = np.average(side_y, weights=side_weights)
            score = score + np.sum(side_weights * (side_y - side_mean) ** 2)

        return score

    @staticmethod
    def _make_leaf(y: np.ndarray, sample_weights: np.ndarray) -> Dict:
        """
        Build a leaf node predicting the weighted mean of ``y``.

        Falls back to the unweighted mean when the weights sum to zero,
        because ``np.average`` cannot normalise such weights.
        """
        if np.sum(sample_weights) > 0:
            value = np.average(y, weights=sample_weights)
        else:
            value = np.mean(y)

        return {
            'is_leaf': True,
            'value': value,
            'n_samples': y.shape[0]
        }

    def _build_tree(self, X: np.ndarray, y: np.ndarray, depth: int = 0,
                  sample_weights: Optional[np.ndarray] = None) -> Dict:
        """
        Recursively build the tree.

        Parameters
        ----------
        X : np.ndarray
            Input features
        y : np.ndarray
            Target values
        depth : int
            Current depth
        sample_weights : np.ndarray, optional
            Sample weights; defaults to all ones

        Returns
        -------
        dict
            Tree node. Leaves hold ``is_leaf``, ``value`` and ``n_samples``;
            split nodes hold ``is_leaf``, ``projection``, ``threshold``,
            ``impurity_reduction``, ``n_samples``, ``left`` and ``right``.
        """
        n_samples = X.shape[0]

        if sample_weights is None:
            sample_weights = np.ones(n_samples)

        # Create leaf node if stopping criteria are met
        if (depth >= self.max_depth or
                n_samples < self.min_samples_split or
                np.all(np.abs(y - y[0]) < 1e-6)):
            return self._make_leaf(y, sample_weights)

        # Find best oblique split
        projection, threshold, impurity_reduction = self._best_oblique_split(X, y, sample_weights)

        # If no good split found, create leaf node
        if impurity_reduction <= 1e-6:
            return self._make_leaf(y, sample_weights)

        # Split data
        left_mask = (X @ projection) <= threshold
        right_mask = ~left_mask

        # A split that leaves one side without samples is not a split
        if not left_mask.any() or not right_mask.any():
            return self._make_leaf(y, sample_weights)

        return {
            'is_leaf': False,
            'projection': projection,
            'threshold': threshold,
            'impurity_reduction': impurity_reduction,
            'n_samples': n_samples,
            'left': self._build_tree(
                X[left_mask], y[left_mask], depth + 1, sample_weights[left_mask]
            ),
            'right': self._build_tree(
                X[right_mask], y[right_mask], depth + 1, sample_weights[right_mask]
            )
        }

    def fit(self, X: np.ndarray, y: np.ndarray,
            sample_weight: Optional[np.ndarray] = None) -> 'ObliqueTree':
        """
        Fit the oblique tree to data.

        Parameters
        ----------
        X : np.ndarray
            Input features, shape (n_samples, n_features)
        y : np.ndarray
            Target values, one per sample
        sample_weight : np.ndarray, optional
            Sample weights

        Returns
        -------
        self : ObliqueTree
            Fitted tree

        Raises
        ------
        ValueError
            If X is not two-dimensional, is empty, or the lengths of X, y
            and sample_weight disagree
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float).ravel()

        if X.ndim != 2:
            raise ValueError("X must be a 2D array of shape (n_samples, n_features)")
        if X.shape[0] == 0:
            raise ValueError("Cannot fit a tree on zero samples")
        if y.shape[0] != X.shape[0]:
            raise ValueError("X and y must contain the same number of samples")

        if sample_weight is not None:
            sample_weight = np.asarray(sample_weight, dtype=float).ravel()
            if sample_weight.shape[0] != X.shape[0]:
                raise ValueError("sample_weight must contain one weight per sample")

        # One generator per fit: reproducible for a fixed seed, yet every
        # node draws its own perturbations.
        self._rng = np.random.RandomState(self.random_state)
        self.tree = self._build_tree(X, y, sample_weights=sample_weight)
        return self

    def _predict_sample(self, x: np.ndarray, node: Dict) -> float:
        """
        Predict for a single sample using the fitted tree.

        Parameters
        ----------
        x : np.ndarray
            Input feature vector
        node : dict
            Node at which to start the descent

        Returns
        -------
        float
            Value of the leaf reached by ``x``
        """
        # Walk down iteratively; same routing rule as used when splitting
        while not node['is_leaf']:
            if np.dot(x, node['projection']) <= node['threshold']:
                node = node['left']
            else:
                node = node['right']

        return node['value']

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the fitted tree.

        Parameters
        ----------
        X : np.ndarray
            Input features

        Returns
        -------
        np.ndarray
            Predictions, one per row of X

        Raises
        ------
        RuntimeError
            If the tree has not been fitted
        """
        if self.tree is None:
            raise RuntimeError("Tree must be fitted before prediction")

        return np.array([self._predict_sample(x, self.tree) for x in np.asarray(X)])


class LapTAO:
    """
    Graph Laplacian Tree Alternating Optimization (LapTAO) algorithm.

    This class implements the LapTAO algorithm for semi-supervised learning
    with graph regularization. It alternates between smoothing labels using
    the graph Laplacian and fitting oblique trees to the smoothed labels.

    Parameters
    ----------
    gamma : float, default=0.1
        Graph regularization strength

    max_iter : int, default=5
        Maximum number of alternating optimization iterations (at least 1)

    tree_params : dict, default=None
        Parameters for the oblique tree; the keys ``max_depth`` and
        ``min_samples_split`` are read

    verbose : bool, default=False
        Whether to print verbose output

    random_state : int, default=None
        Random seed for reproducibility

    Attributes
    ----------
    tree_ : ObliqueTree
        Fitted oblique tree

    smoothed_targets_ : np.ndarray
        Smoothed target values after alternating optimization
    """

    def __init__(
        self,
        gamma: float = 0.1,
        max_iter: int = 5,
        tree_params: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
        random_state: Optional[int] = None
    ):
        self.gamma = gamma
        self.max_iter = max_iter
        self.tree_params = tree_params or {}
        self.verbose = verbose
        self.random_state = random_state

    def _smooth_targets(self, y: np.ndarray, mask: np.ndarray,
                       laplacian: csr_matrix) -> np.ndarray:
        """
        Apply Laplacian smoothing to target values.

        Solves ``(M + gamma * L) f = M y`` where ``M`` is the diagonal 0/1
        matrix built from ``mask``. Entries of ``y`` at unlabeled positions
        are multiplied by zero and therefore do not influence the solution.

        Parameters
        ----------
        y : np.ndarray
            Target values
        mask : np.ndarray
            Boolean mask for labeled samples (True) vs unlabeled (False)
        laplacian : scipy.sparse.csr_matrix
            Graph Laplacian matrix

        Returns
        -------
        np.ndarray
            Smoothed target values, or a copy of ``y`` when the system cannot
            be solved (solver error or non-finite solution)
        """
        y = np.asarray(y, dtype=float)
        mask = np.asarray(mask, dtype=bool)
        n_samples = y.shape[0]

        # Diagonal matrix selecting the labeled samples
        diag_idx = np.arange(n_samples)
        M = csr_matrix((mask.astype(float), (diag_idx, diag_idx)),
                       shape=(n_samples, n_samples))

        # Set up linear system: (M + γL)y = My₀. Re-wrapping in csr_matrix
        # keeps the system sparse even if a dense Laplacian was passed.
        A = csr_matrix(M + self.gamma * laplacian)
        b = M @ y

        # Best-effort solve: any solver failure falls back to the input.
        failure = None
        try:
            smoothed_y = np.asarray(spsolve(A, b), dtype=float).ravel()
        except Exception as e:
            failure = e
        else:
            # spsolve reports an exactly singular matrix with a warning and a
            # NaN-filled result instead of raising, so check the values too.
            if not np.all(np.isfinite(smoothed_y)):
                failure = "solution contains non-finite values (singular system?)"

        if failure is not None:
            if self.verbose:
                print(f"Warning: Error in solving linear system: {failure}")
                print("Falling back to simple masking for unlabeled data")
            smoothed_y = y.copy()

        return smoothed_y

    def fit(self, X: np.ndarray, y: np.ndarray,
            sample_weight: Optional[np.ndarray] = None,
            laplacian: Optional[csr_matrix] = None) -> 'LapTAO':
        """
        Fit the LapTAO model to data.

        Parameters
        ----------
        X : np.ndarray
            Input features
        y : np.ndarray
            Target values, with -1 for unlabeled samples
        sample_weight : np.ndarray, optional
            Sample weights
        laplacian : scipy.sparse.csr_matrix, optional
            Graph Laplacian matrix of shape (n_samples, n_samples); required
            despite the default

        Returns
        -------
        self : LapTAO
            Fitted model

        Raises
        ------
        ValueError
            If the Laplacian is missing or has the wrong shape, ``max_iter``
            is below 1, the input lengths disagree, or no sample is labeled
        """
        if laplacian is None:
            raise ValueError("Graph Laplacian matrix must be provided")
        if self.max_iter < 1:
            # Without an iteration there would be no tree to store.
            raise ValueError("max_iter must be at least 1")

        X = np.asarray(X)
        y = np.asarray(y)
        n_samples = X.shape[0]

        if y.shape[0] != n_samples:
            raise ValueError("X and y must contain the same number of samples")
        if laplacian.shape != (n_samples, n_samples):
            raise ValueError("Graph Laplacian must have shape (n_samples, n_samples)")

        # Create mask for labeled samples
        labeled_mask = y != -1
        if not np.any(labeled_mask):
            raise ValueError("At least one labeled sample is required")

        # Work in floating point: integer label arrays would otherwise
        # truncate the labeled mean and the tree predictions written below.
        y = y.astype(float)

        # Initialize target values
        # For unlabeled data, start with mean of labeled data
        current_targets = y.copy()
        current_targets[~labeled_mask] = np.mean(y[labeled_mask])

        # Tree weights: caller-provided weights (default 1), with unlabeled
        # samples zeroed so they do not drive the tree fit
        if sample_weight is None:
            tree_weights = np.ones(n_samples)
        else:
            tree_weights = np.array(sample_weight, dtype=float)
            if tree_weights.shape[0] != n_samples:
                raise ValueError("sample_weight must contain one weight per sample")
        tree_weights[~labeled_mask] = 0.0

        # Alternating optimization
        # NOTE(review): _smooth_targets multiplies unlabeled entries by zero
        # and labeled entries are reset to the original labels below, so the
        # tree predictions stored in current_targets never reach the next
        # label step and every iteration solves the same linear system.
        # Confirm whether a term coupling labels and tree output was intended.
        for iter_idx in range(self.max_iter):
            if self.verbose:
                print(f"LapTAO iteration {iter_idx + 1}/{self.max_iter}")

            # Label step: smooth targets using graph Laplacian
            smoothed_targets = self._smooth_targets(current_targets, labeled_mask, laplacian)

            # Tree step: fit oblique tree to smoothed targets
            tree = ObliqueTree(
                max_depth=self.tree_params.get('max_depth', 3),
                min_samples_split=self.tree_params.get('min_samples_split', 2),
                random_state=self.random_state
            )
            tree.fit(X, smoothed_targets, sample_weight=tree_weights)

            # Update current targets with tree predictions for unlabeled data
            tree_preds = tree.predict(X)
            current_targets[~labeled_mask] = tree_preds[~labeled_mask]

            # Keep original labels for labeled data
            current_targets[labeled_mask] = y[labeled_mask]

        # Store final tree and smoothed targets
        self.tree_ = tree
        self.smoothed_targets_ = smoothed_targets

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the fitted tree.

        Parameters
        ----------
        X : np.ndarray
            Input features

        Returns
        -------
        np.ndarray
            Predictions

        Raises
        ------
        RuntimeError
            If the model has not been fitted
        """
        if not hasattr(self, 'tree_'):
            raise RuntimeError("Model must be fitted before prediction")

        return self.tree_.predict(X)

    def get_smoothed_targets(self) -> np.ndarray:
        """
        Get smoothed target values after alternating optimization.

        Returns
        -------
        np.ndarray
            Smoothed target values

        Raises
        ------
        RuntimeError
            If the model has not been fitted
        """
        if not hasattr(self, 'smoothed_targets_'):
            raise RuntimeError("Model must be fitted before accessing smoothed targets")

        return self.smoothed_targets_
