"""
LapBoost Model Implementation

This module contains the core LapBoost algorithm implementations for semi-supervised
learning with Graph Laplacian Tree Alternating Optimization (LapTAO).

Classes:
    LapBoost: Base class for LapBoost models
    LapBoostClassifier: LapBoost implementation for classification tasks
    LapBoostRegressor: LapBoost implementation for regression tasks
"""

import numpy as np
from typing import Optional, Tuple, Dict, Any, Union, List
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_X_y, check_array
from sklearn.preprocessing import LabelEncoder
import xgboost as xgb

from lapboost.core.graph import GraphConstructor
from lapboost.core.optimization import LapTAO
from lapboost.utils.validation import check_inputs, validate_parameters


class LapBoost(BaseEstimator):
    """
    Base class for LapBoost models.
    
    LapBoost is a semi-supervised learning algorithm that combines XGBoost with
    Graph Laplacian Tree Alternating Optimization (LapTAO) to leverage both
    labeled and unlabeled data for improved predictive performance.
    
    Subclasses must implement ``_final_fit``, ``_generate_pseudo_labels`` and
    ``predict``.
    
    Parameters
    ----------
    k_neighbors : int, default=10
        Number of neighbors for graph construction
    
    gamma : float, default=0.1
        Graph regularization strength
        
    sigma : float, default=1.0
        Bandwidth parameter for Gaussian similarity weights
        
    confidence_threshold : float, default=0.8
        Minimum confidence a pseudo-label needs to receive a non-zero
        sample weight in the final fit
        
    xgb_params : dict, default=None
        Parameters for XGBoost model
        
    max_iter : int, default=3
        Maximum number of pseudo-labeling iterations.
        NOTE(review): ``fit`` does not currently read this value; it always
        performs exactly one pseudo-labeling round.
        
    verbose : bool, default=False
        Whether to print verbose output
        
    random_state : int, default=None
        Random seed for reproducibility
    
    Attributes
    ----------
    graph_constructor_ : GraphConstructor
        Graph construction object built by ``fit``
        
    lap_tao_ : LapTAO
        LapTAO optimization object fitted by ``fit``
        
    xgb_model_ : xgb.XGBModel
        Trained XGBoost estimator (the concrete sklearn-API class is chosen
        by the subclass's ``_final_fit``)
    """
    
    def __init__(
        self,
        k_neighbors: int = 10,
        gamma: float = 0.1,
        sigma: float = 1.0,
        confidence_threshold: float = 0.8,
        xgb_params: Optional[Dict[str, Any]] = None,
        max_iter: int = 3,
        verbose: bool = False,
        random_state: Optional[int] = None
    ):
        # Store hyper-parameters verbatim (scikit-learn convention: no
        # validation or derived state in __init__).
        self.k_neighbors = k_neighbors
        self.gamma = gamma
        self.sigma = sigma
        self.confidence_threshold = confidence_threshold
        self.xgb_params = xgb_params
        self.max_iter = max_iter
        self.verbose = verbose
        self.random_state = random_state

    def _build_graph(self, X: np.ndarray) -> None:
        """
        Build a k-NN graph over ``X`` and store it as ``graph_constructor_``.
        
        Parameters
        ----------
        X : np.ndarray
            Input features (labeled and unlabeled rows together)
        """
        self.graph_constructor_ = GraphConstructor(
            k_neighbors=self.k_neighbors,
            sigma=self.sigma,
            random_state=self.random_state
        )
        self.graph_constructor_.fit(X)
        
        if self.verbose:
            print(f"Built graph with {X.shape[0]} nodes and "
                  f"{self.graph_constructor_.adjacency_matrix_.nnz} edges")
    
    def _train_laptao(self, X: np.ndarray, y: np.ndarray, 
                     sample_weight: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Fit a LapTAO model on the graph built by ``_build_graph``.
        
        Parameters
        ----------
        X : np.ndarray
            Input features
        y : np.ndarray
            Target values, with -1 for unlabeled samples
        sample_weight : np.ndarray, optional
            Sample weights
            
        Returns
        -------
        np.ndarray
            Smoothed targets as reported by ``LapTAO.get_smoothed_targets``
        """
        self.lap_tao_ = LapTAO(
            gamma=self.gamma,
            random_state=self.random_state,
            verbose=self.verbose
        )
        
        # Requires _build_graph to have been called first.
        laplacian = self.graph_constructor_.get_laplacian()
        
        self.lap_tao_.fit(X, y, sample_weight=sample_weight, laplacian=laplacian)
        return self.lap_tao_.get_smoothed_targets()
    
    def _generate_pseudo_labels(self, X_unlabeled: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate pseudo-labels for unlabeled data based on model confidence.
        
        Must be implemented by subclasses.
        
        Parameters
        ----------
        X_unlabeled : np.ndarray
            Unlabeled feature data
            
        Returns
        -------
        tuple
            Pseudo-labels and confidence scores (one per row of X_unlabeled)
        """
        raise NotImplementedError("Subclasses must implement this method")
    
    def fit(self, X_labeled: np.ndarray, y_labeled: np.ndarray, 
            X_unlabeled: np.ndarray) -> 'LapBoost':
        """
        Fit the model to labeled and unlabeled data.
        
        Procedure:
        
        1. Train an initial model on the labeled rows only.
        2. Build a k-NN graph over labeled + unlabeled rows and fit LapTAO.
        3. Pseudo-label every unlabeled row with the initial model. Rows whose
           confidence reaches ``confidence_threshold`` get that confidence as
           sample weight; all other unlabeled rows keep weight 0.
        4. If at least one row passed the threshold, retrain on the combined
           data; otherwise keep the labeled-only model.
        
        Parameters
        ----------
        X_labeled : np.ndarray
            Labeled features
        y_labeled : np.ndarray
            Labels for X_labeled
        X_unlabeled : np.ndarray
            Unlabeled features
            
        Returns
        -------
        self : LapBoost
            Fitted model
            
        Raises
        ------
        ValueError
            If X_labeled and X_unlabeled have different numbers of features.
        """
        X_labeled, y_labeled = check_inputs(X_labeled, y_labeled, multi_output=False)
        X_unlabeled = check_array(X_unlabeled, ensure_min_features=X_labeled.shape[1])
        
        if X_labeled.shape[1] != X_unlabeled.shape[1]:
            raise ValueError("X_labeled and X_unlabeled must have the same number of features")
        
        n_labeled = X_labeled.shape[0]
        n_unlabeled = X_unlabeled.shape[0]
        
        X_combined = np.vstack([X_labeled, X_unlabeled])
        # -1 marks unlabeled rows for LapTAO.
        y_combined = np.concatenate([y_labeled, np.full(n_unlabeled, -1)])
        
        # Weight 1 for labeled rows, 0 for unlabeled rows until pseudo-labeled.
        sample_weights = np.concatenate([np.ones(n_labeled), np.zeros(n_unlabeled)])
        
        # The initial labeled-only model is what produces the pseudo-labels.
        self._final_fit(X_labeled, y_labeled)
        
        self._build_graph(X_combined)
        
        # Fits self.lap_tao_ on the combined graph.
        # NOTE(review): the smoothed targets returned here are not consumed;
        # pseudo-labels below come from the XGBoost model. Confirm intent.
        self._train_laptao(X_combined, y_combined, sample_weights)
        
        pseudo_labels, confidences = self._generate_pseudo_labels(X_unlabeled)
        high_confidence_idx = np.where(confidences >= self.confidence_threshold)[0]
        
        if len(high_confidence_idx) == 0:
            if self.verbose:
                print("No high-confidence pseudo-labels found. "
                      "Using model trained on labeled data only.")
            return self
        
        if self.verbose:
            print(f"Using {len(high_confidence_idx)} high-confidence pseudo-labels "
                  f"out of {n_unlabeled} unlabeled samples "
                  f"({len(high_confidence_idx) / n_unlabeled * 100:.2f}%)")
        
        # Only confident rows get a non-zero weight; the rest stay in the
        # training set with weight 0.
        sample_weights[n_labeled + high_confidence_idx] = confidences[high_confidence_idx]
        
        # Concatenate instead of writing into y_combined so NumPy promotes the
        # dtype: assigning float regression predictions into an integer target
        # array would silently truncate them.
        y_train = np.concatenate([y_labeled, pseudo_labels])
        
        self._final_fit(X_combined, y_train, sample_weights)
        return self
    
    def _final_fit(self, X: np.ndarray, y: np.ndarray, 
                   sample_weight: Optional[np.ndarray] = None) -> None:
        """
        Fit the underlying XGBoost model and store it as ``xgb_model_``.
        
        Must be implemented by subclasses.
        
        Parameters
        ----------
        X : np.ndarray
            Features
        y : np.ndarray
            Targets
        sample_weight : np.ndarray, optional
            Sample weights
        """
        raise NotImplementedError("Subclasses must implement this method")
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict using the fitted model.
        
        Must be implemented by subclasses.
        
        Parameters
        ----------
        X : np.ndarray
            Input features
            
        Returns
        -------
        np.ndarray
            Predictions
        """
        raise NotImplementedError("Subclasses must implement this method")


class LapBoostClassifier(ClassifierMixin, LapBoost):
    """
    LapBoost implementation for classification tasks.
    
    This class extends the base LapBoost algorithm for classification problems,
    leveraging both labeled and unlabeled data through graph-based regularization.
    
    ``ClassifierMixin`` is listed before ``LapBoost`` because scikit-learn
    requires mixins to precede ``BaseEstimator`` in the MRO; otherwise
    estimator-type detection (``is_classifier`` / estimator tags) can miss it.
    
    Parameters
    ----------
    k_neighbors : int, default=10
        Number of neighbors for graph construction
    
    gamma : float, default=0.1
        Graph regularization strength
        
    sigma : float, default=1.0
        Bandwidth parameter for Gaussian similarity weights
        
    confidence_threshold : float, default=0.8
        Minimum max-class probability for a pseudo-label to be used
        
    xgb_params : dict, default=None
        Parameters for XGBoost model; override ``default_xgb_params``
        
    max_iter : int, default=3
        Maximum number of pseudo-labeling iterations
        
    verbose : bool, default=False
        Whether to print verbose output
        
    random_state : int, default=None
        Random seed for reproducibility
    
    Attributes
    ----------
    label_encoder_ : LabelEncoder
        Encoder mapping original labels to 0..n_classes_-1
        
    classes_ : np.ndarray
        Original class labels, in encoded order
        
    n_classes_ : int
        Number of classes seen in ``y_labeled``
        
    xgb_params_ : dict
        Effective XGBoost parameters used for the last fit
    """
    
    def __init__(
        self,
        k_neighbors: int = 10,
        gamma: float = 0.1,
        sigma: float = 1.0,
        confidence_threshold: float = 0.8,
        xgb_params: Optional[Dict[str, Any]] = None,
        max_iter: int = 3,
        verbose: bool = False,
        random_state: Optional[int] = None
    ):
        super().__init__(
            k_neighbors=k_neighbors,
            gamma=gamma,
            sigma=sigma,
            confidence_threshold=confidence_threshold,
            xgb_params=xgb_params,
            max_iter=max_iter,
            verbose=verbose,
            random_state=random_state
        )
        
        # Default XGBoost parameters for classification. 'seed' is re-read
        # from random_state at fit time (see fit).
        self.default_xgb_params = {
            'objective': 'multi:softprob',
            'eval_metric': 'mlogloss',
            'seed': self.random_state,
            'nthread': -1,
            'eta': 0.1,
            'max_depth': 6,
            'min_child_weight': 1,
            'subsample': 0.8,
            'colsample_bytree': 0.8
        }
        
    def fit(self, X_labeled: np.ndarray, y_labeled: np.ndarray, 
            X_unlabeled: np.ndarray) -> 'LapBoostClassifier':
        """
        Fit the classifier on labeled and unlabeled data.
        
        Labels are encoded to 0..n_classes_-1 before delegating to
        ``LapBoost.fit``; ``predict`` decodes them back.
        
        Parameters
        ----------
        X_labeled : np.ndarray
            Labeled features
        y_labeled : np.ndarray
            Labels for X_labeled
        X_unlabeled : np.ndarray
            Unlabeled features
            
        Returns
        -------
        self : LapBoostClassifier
            Fitted classifier
        """
        self.label_encoder_ = LabelEncoder()
        y_encoded = self.label_encoder_.fit_transform(y_labeled)
        self.classes_ = self.label_encoder_.classes_
        self.n_classes_ = len(self.classes_)
        
        # Rebuild the effective params on every fit. The seed is taken from
        # the current random_state because set_params(random_state=...) does
        # not rerun __init__, so default_xgb_params['seed'] may be stale.
        # User-supplied xgb_params still take precedence over the defaults.
        self.xgb_params_ = {
            **self.default_xgb_params,
            'seed': self.random_state,
            **(self.xgb_params or {}),
        }
        self.xgb_params_['num_class'] = self.n_classes_
        
        return super().fit(X_labeled, y_encoded, X_unlabeled)
    
    def _generate_pseudo_labels(self, X_unlabeled: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Pseudo-label rows with the most probable encoded class.
        
        Parameters
        ----------
        X_unlabeled : np.ndarray
            Unlabeled feature data
            
        Returns
        -------
        tuple
            Encoded class indices (argmax of ``predict_proba``) and their
            probabilities, used as confidence scores
        """
        probas = self.predict_proba(X_unlabeled)
        pseudo_labels = np.argmax(probas, axis=1)
        confidences = np.max(probas, axis=1)
        return pseudo_labels, confidences
    
    def _final_fit(self, X: np.ndarray, y: np.ndarray, 
                   sample_weight: Optional[np.ndarray] = None) -> None:
        """
        Fit an ``xgb.XGBClassifier`` with ``xgb_params_`` (set by ``fit``).
        
        Parameters
        ----------
        X : np.ndarray
            Features
        y : np.ndarray
            Encoded targets
        sample_weight : np.ndarray, optional
            Sample weights
        """
        self.xgb_model_ = xgb.XGBClassifier(**self.xgb_params_)
        self.xgb_model_.fit(X, y, sample_weight=sample_weight)
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class labels.
        
        Parameters
        ----------
        X : np.ndarray
            Input features
            
        Returns
        -------
        np.ndarray
            Predicted class labels, in the original label space
        """
        X = check_array(X)
        probas = self.xgb_model_.predict_proba(X)
        y_pred = np.argmax(probas, axis=1)
        # Map encoded indices back to the original labels.
        return self.label_encoder_.inverse_transform(y_pred)
    
    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities.
        
        Parameters
        ----------
        X : np.ndarray
            Input features
            
        Returns
        -------
        np.ndarray
            Class probabilities, columns ordered as ``classes_``
        """
        X = check_array(X)
        return self.xgb_model_.predict_proba(X)


class LapBoostRegressor(RegressorMixin, LapBoost):
    """
    LapBoost implementation for regression tasks.
    
    This class extends the base LapBoost algorithm for regression problems,
    leveraging both labeled and unlabeled data through graph-based regularization.
    
    ``RegressorMixin`` is listed before ``LapBoost`` because scikit-learn
    requires mixins to precede ``BaseEstimator`` in the MRO; otherwise
    estimator-type detection (``is_regressor`` / estimator tags) can miss it.
    
    Parameters
    ----------
    k_neighbors : int, default=10
        Number of neighbors for graph construction
    
    gamma : float, default=0.1
        Graph regularization strength
        
    sigma : float, default=1.0
        Bandwidth parameter for Gaussian similarity weights
        
    confidence_threshold : float, default=0.2
        Threshold for pseudo-label confidence, where confidence is
        ``1 - distance / max_distance`` from the graph's neighbor distances
        
    xgb_params : dict, default=None
        Parameters for XGBoost model; override ``default_xgb_params``
        
    max_iter : int, default=3
        Maximum number of pseudo-labeling iterations
        
    verbose : bool, default=False
        Whether to print verbose output
        
    random_state : int, default=None
        Random seed for reproducibility
    """
    
    def __init__(
        self,
        k_neighbors: int = 10,
        gamma: float = 0.1,
        sigma: float = 1.0,
        confidence_threshold: float = 0.2,  # Lower threshold for regression
        xgb_params: Optional[Dict[str, Any]] = None,
        max_iter: int = 3,
        verbose: bool = False,
        random_state: Optional[int] = None
    ):
        super().__init__(
            k_neighbors=k_neighbors,
            gamma=gamma,
            sigma=sigma,
            confidence_threshold=confidence_threshold,
            xgb_params=xgb_params,
            max_iter=max_iter,
            verbose=verbose,
            random_state=random_state
        )
        
        # Default XGBoost parameters for regression. 'seed' is re-read from
        # random_state in _final_fit.
        self.default_xgb_params = {
            'objective': 'reg:squarederror',
            'eval_metric': 'rmse',
            'seed': self.random_state,
            'nthread': -1,
            'eta': 0.1,
            'max_depth': 6,
            'min_child_weight': 1,
            'subsample': 0.8,
            'colsample_bytree': 0.8
        }
        
    def _generate_pseudo_labels(self, X_unlabeled: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate pseudo-labels for unlabeled data for regression.
        
        Pseudo-labels are the model's predictions. Confidence is a distance
        heuristic: neighbor distances from the graph are scaled by their
        maximum and inverted, so the farthest point gets 0 and points with
        zero distance get 1.
        
        Parameters
        ----------
        X_unlabeled : np.ndarray
            Unlabeled feature data
            
        Returns
        -------
        tuple
            Pseudo-labels (predictions) and confidence scores
            (``1 - distance / max_distance``)
        """
        predictions = self.predict(X_unlabeled)
        
        # NOTE(review): assumes get_neighbor_distances returns one distance per
        # row of X_unlabeled (1-D) — TODO confirm against GraphConstructor.
        neighbor_dists = np.asarray(
            self.graph_constructor_.get_neighbor_distances(X_unlabeled), dtype=float
        )
        
        max_dist = np.max(neighbor_dists) if neighbor_dists.size > 0 else 1.0
        if max_dist == 0:
            # Every distance is zero: treat all rows as fully certain instead
            # of letting 0/0 turn every confidence into NaN.
            uncertainties = np.zeros_like(neighbor_dists)
        else:
            uncertainties = neighbor_dists / max_dist
        confidences = 1.0 - uncertainties
        
        return predictions, confidences
    
    def _final_fit(self, X: np.ndarray, y: np.ndarray, 
                   sample_weight: Optional[np.ndarray] = None) -> None:
        """
        Fit an ``xgb.XGBRegressor`` and store it as ``xgb_model_``.
        
        Also refreshes ``xgb_params_`` from ``default_xgb_params``, the
        current ``random_state`` and the user's ``xgb_params``.
        
        Parameters
        ----------
        X : np.ndarray
            Features
        y : np.ndarray
            Targets
        sample_weight : np.ndarray, optional
            Sample weights
        """
        # The seed comes from the current random_state because
        # set_params(random_state=...) does not rerun __init__, so
        # default_xgb_params['seed'] may be stale. User params win.
        self.xgb_params_ = {
            **self.default_xgb_params,
            'seed': self.random_state,
            **(self.xgb_params or {}),
        }
        self.xgb_model_ = xgb.XGBRegressor(**self.xgb_params_)
        self.xgb_model_.fit(X, y, sample_weight=sample_weight)
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict regression targets.
        
        Parameters
        ----------
        X : np.ndarray
            Input features
            
        Returns
        -------
        np.ndarray
            Predicted values
        """
        X = check_array(X)
        return self.xgb_model_.predict(X)
