"""
Iterative Co-Training Module for LapBoost

This module implements iterative co-training strategies for LapBoost,
which progressively incorporates high-confidence pseudo-labels into
the training set to improve model performance.

Classes:
    IterativeLapBoostBase: Base class for iterative LapBoost models
    IterativeLapBoostClassifier: Iterative LapBoost for classification
    IterativeLapBoostRegressor: Iterative LapBoost for regression
"""

import numpy as np
from typing import Optional, Tuple, Dict, Any, Union, List
from sklearn.base import ClassifierMixin, RegressorMixin
from sklearn.preprocessing import LabelEncoder
import xgboost as xgb

from lapboost.core.model import LapBoost, LapBoostClassifier, LapBoostRegressor
from lapboost.utils.metrics import confidence_metrics


class IterativeLapBoostBase(LapBoost):
    """
    Base class for iterative co-training with LapBoost.

    Each iteration trains a fresh base model on the current training set,
    scores it on a held-out validation split, pseudo-labels the remaining
    unlabeled samples and moves the most confident ones into the training
    set. After the loop, the iteration with the best validation score is
    exposed as the fitted model.

    Subclasses must implement ``_create_base_model``,
    ``_generate_pseudo_labels`` and ``_evaluate_model``.

    Parameters
    ----------
    k_neighbors : int, default=10
        Number of neighbors for graph construction

    gamma : float, default=0.1
        Graph regularization strength

    sigma : float, default=1.0
        Bandwidth parameter for Gaussian similarity weights

    confidence_threshold : float, default=0.8
        Initial threshold for pseudo-label confidence

    confidence_decay : float, default=0.95
        Factor by which the confidence threshold is multiplied after each
        iteration

    max_iter : int, default=5
        Maximum number of co-training iterations (must be at least 1)

    min_pseudo_labels : int, default=10
        Minimum number of pseudo-labels to add each iteration. When fewer
        samples reach the threshold, the top-``min_pseudo_labels`` most
        confident samples are taken instead.

    min_improvement : float, default=0.001
        Minimum improvement in validation metric to continue iterations

    validation_fraction : float, default=0.1
        Fraction of labeled data to use for validation. When it is not
        positive, or when a split would leave no training sample, the full
        labeled set is used for both training and validation.

    early_stopping : bool, default=True
        Whether to stop iterations early if performance doesn't improve

    xgb_params : dict, default=None
        Parameters for XGBoost model

    verbose : bool, default=False
        Whether to print verbose output

    random_state : int, default=None
        Random seed for reproducibility

    Attributes
    ----------
    models_ : list
        List of trained models from each iteration

    confidence_thresholds_ : list
        Confidence thresholds used in each iteration

    pseudo_label_counts_ : list
        Number of pseudo-labels added at the end of each iteration
        (0 for an iteration that added none); aligned with ``models_``

    performance_history_ : list
        Performance metrics for each iteration
    """

    # Metric names for which a smaller value means a better model. The first
    # key of the dict returned by ``_evaluate_model`` is looked up here.
    _LOWER_IS_BETTER = frozenset(
        {'mse', 'rmse', 'mae', 'expected_calibration_error', 'ece'}
    )

    def __init__(
        self,
        k_neighbors: int = 10,
        gamma: float = 0.1,
        sigma: float = 1.0,
        confidence_threshold: float = 0.8,
        confidence_decay: float = 0.95,
        max_iter: int = 5,
        min_pseudo_labels: int = 10,
        min_improvement: float = 0.001,
        validation_fraction: float = 0.1,
        early_stopping: bool = True,
        xgb_params: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
        random_state: Optional[int] = None
    ):
        super().__init__(
            k_neighbors=k_neighbors,
            gamma=gamma,
            sigma=sigma,
            confidence_threshold=confidence_threshold,
            xgb_params=xgb_params,
            max_iter=max_iter,
            verbose=verbose,
            random_state=random_state
        )

        self.confidence_decay = confidence_decay
        self.min_pseudo_labels = min_pseudo_labels
        self.min_improvement = min_improvement
        self.validation_fraction = validation_fraction
        self.early_stopping = early_stopping

    def fit(self, X_labeled: np.ndarray, y_labeled: np.ndarray,
            X_unlabeled: np.ndarray) -> 'IterativeLapBoostBase':
        """
        Fit the iterative co-training model.

        This method implements the iterative co-training procedure:
        1. Train a model on the current training data
        2. Generate pseudo-labels for the remaining unlabeled data
        3. Add high-confidence pseudo-labels to the training set
        4. Repeat until early stopping, exhaustion of the unlabeled pool,
           or the maximum number of iterations

        Parameters
        ----------
        X_labeled : np.ndarray
            Labeled features
        y_labeled : np.ndarray
            Labels for X_labeled
        X_unlabeled : np.ndarray
            Unlabeled features

        Returns
        -------
        self : IterativeLapBoostBase
            Fitted model

        Raises
        ------
        ValueError
            If ``max_iter`` is smaller than 1, if no labeled sample is given,
            or if ``X_labeled`` and ``y_labeled`` differ in length.
        """
        if self.max_iter < 1:
            raise ValueError(
                f"max_iter must be at least 1, got {self.max_iter}"
            )

        # Convert to plain arrays so the positional indexing below behaves
        # the same for lists, pandas objects and ndarrays.
        X_labeled = np.asarray(X_labeled)
        y_labeled = np.asarray(y_labeled)

        if X_labeled.shape[0] == 0:
            raise ValueError("X_labeled must contain at least one sample")
        if X_labeled.shape[0] != y_labeled.shape[0]:
            raise ValueError(
                "X_labeled and y_labeled must have the same number of "
                f"samples, got {X_labeled.shape[0]} and {y_labeled.shape[0]}"
            )

        # Initialize tracking variables
        self.models_ = []
        self.confidence_thresholds_ = []
        self.pseudo_label_counts_ = []
        self.performance_history_ = []

        # Split labeled data into train and validation sets
        X_current, y_current, X_val, y_val = self._split_validation(
            X_labeled, y_labeled
        )

        # Pool of samples that have not been pseudo-labeled yet (own copy of
        # the caller's data).
        X_remaining = np.array(X_unlabeled)

        # Current confidence threshold
        current_threshold = self.confidence_threshold

        # Iterative co-training
        for iter_idx in range(self.max_iter):
            if self.verbose:
                print(f"\nIteration {iter_idx + 1}/{self.max_iter}")
                print(f"Training data size: {X_current.shape[0]}")
                print(f"Unlabeled data remaining: {X_remaining.shape[0]}")
                print(f"Current confidence threshold: {current_threshold:.4f}")

            # Train model on current labeled data
            model = self._create_base_model()

            # The base model is given a single training sample as its
            # "unlabeled" set because empty arrays cause errors there.
            # NOTE: this means the base model never sees X_remaining during
            # its own fit; the unlabeled pool only enters through the
            # pseudo-labels appended below.
            dummy_unlabeled = X_current[:1] if X_current.shape[0] > 0 else X_labeled[:1]
            model.fit(X_current, y_current, dummy_unlabeled)

            # Evaluate on validation set
            val_metrics = self._evaluate_model(model, X_val, y_val)

            if self.verbose:
                print(f"Validation metrics: {val_metrics}")

            # Save model and metrics
            self.models_.append(model)
            self.confidence_thresholds_.append(current_threshold)
            self.performance_history_.append(val_metrics)

            # Check early stopping
            if (iter_idx > 0 and self.early_stopping and
                    self._check_early_stopping(self.performance_history_)):
                if self.verbose:
                    print("Early stopping: No improvement in validation performance")
                self.pseudo_label_counts_.append(0)
                break

            # Stop if no unlabeled data remains
            if X_remaining.shape[0] == 0:
                if self.verbose:
                    print("No unlabeled data remaining")
                self.pseudo_label_counts_.append(0)
                break

            # Generate pseudo-labels for remaining unlabeled data
            pseudo_labels, confidences = self._generate_pseudo_labels(model, X_remaining)
            pseudo_labels = np.asarray(pseudo_labels)
            confidences = np.asarray(confidences)

            # Pick the samples to promote into the training set
            high_conf_idx = self._select_confident_indices(
                confidences, current_threshold, X_remaining.shape[0]
            )

            # Skip iteration if no high-confidence pseudo-labels
            if len(high_conf_idx) == 0:
                if self.verbose:
                    print("No high-confidence pseudo-labels found")
                self.pseudo_label_counts_.append(0)
                current_threshold *= self.confidence_decay
                continue

            # Add high-confidence pseudo-labels to training data
            X_pseudo = X_remaining[high_conf_idx]
            y_pseudo = pseudo_labels[high_conf_idx]
            conf_pseudo = confidences[high_conf_idx]

            # Update training data
            X_current = np.vstack([X_current, X_pseudo])
            y_current = np.concatenate([y_current, y_pseudo])

            # Update unlabeled data (remove used samples)
            mask = np.ones(X_remaining.shape[0], dtype=bool)
            mask[high_conf_idx] = False
            X_remaining = X_remaining[mask]

            # Save pseudo-label count
            self.pseudo_label_counts_.append(len(high_conf_idx))

            if self.verbose:
                print(f"Added {len(high_conf_idx)} high-confidence pseudo-labels")
                print(f"Average confidence: {np.mean(conf_pseudo):.4f}")

            # Update confidence threshold for next iteration
            current_threshold *= self.confidence_decay

        # Expose the best iteration's components as this estimator's state
        best_model_idx = self._get_best_model_index()
        best_model = self.models_[best_model_idx]
        self.xgb_model_ = best_model.xgb_model_
        self.graph_constructor_ = best_model.graph_constructor_

        # NOTE: this replaces any label_encoder_ already set on self with the
        # one owned by the selected base model.
        if hasattr(best_model, 'label_encoder_'):
            self.label_encoder_ = best_model.label_encoder_

        if self.verbose:
            print(f"\nSelected best model from iteration {best_model_idx + 1}")
            print(f"Final model performance: {self.performance_history_[best_model_idx]}")

        return self

    def _split_validation(
        self, X: np.ndarray, y: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Split the labeled data into training and validation parts.

        Parameters
        ----------
        X : np.ndarray
            Labeled features
        y : np.ndarray
            Labels for X

        Returns
        -------
        tuple
            ``(X_train, y_train, X_val, y_val)``. When no split is possible
            or requested, the full data is returned for both roles.
        """
        n_samples = X.shape[0]

        n_val = 0
        if self.validation_fraction > 0:
            n_val = max(1, int(n_samples * self.validation_fraction))
            # Always keep at least one sample to train on.
            n_val = min(n_val, n_samples - 1)

        if n_val <= 0:
            return X, y, X, y

        rng = np.random.RandomState(self.random_state)
        val_indices = rng.choice(n_samples, n_val, replace=False)
        train_indices = np.setdiff1d(np.arange(n_samples), val_indices)

        return X[train_indices], y[train_indices], X[val_indices], y[val_indices]

    def _select_confident_indices(self, confidences: np.ndarray,
                                  threshold: float,
                                  n_candidates: int) -> np.ndarray:
        """
        Choose which unlabeled samples to promote to the training set.

        Parameters
        ----------
        confidences : np.ndarray
            Confidence score of each remaining unlabeled sample
        threshold : float
            Current confidence threshold
        n_candidates : int
            Number of remaining unlabeled samples

        Returns
        -------
        np.ndarray
            Positions (into the remaining pool) of the selected samples
        """
        selected = np.where(confidences >= threshold)[0]

        # If not enough samples clear the threshold, take the top-k instead
        if len(selected) < self.min_pseudo_labels and n_candidates > 0:
            top_k = min(self.min_pseudo_labels, n_candidates)
            selected = np.argsort(confidences)[-top_k:]

        return selected

    def _create_base_model(self) -> LapBoost:
        """
        Create a base LapBoost model for an iteration.

        Must be implemented by subclasses.

        Returns
        -------
        LapBoost
            Base model instance
        """
        raise NotImplementedError("Subclasses must implement this method")

    def _generate_pseudo_labels(self, model: LapBoost,
                               X_unlabeled: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate pseudo-labels and confidence scores for unlabeled data.

        Must be implemented by subclasses.

        Parameters
        ----------
        model : LapBoost
            Trained model
        X_unlabeled : np.ndarray
            Unlabeled feature data

        Returns
        -------
        tuple
            Pseudo-labels and confidence scores
        """
        raise NotImplementedError("Subclasses must implement this method")

    def _evaluate_model(self, model: LapBoost, X: np.ndarray,
                       y: np.ndarray) -> Dict[str, float]:
        """
        Evaluate model on validation data.

        Must be implemented by subclasses. The first key of the returned
        dict is the metric used for early stopping and model selection.

        Parameters
        ----------
        model : LapBoost
            Model to evaluate
        X : np.ndarray
            Features
        y : np.ndarray
            Targets

        Returns
        -------
        dict
            Evaluation metrics
        """
        raise NotImplementedError("Subclasses must implement this method")

    def _check_early_stopping(self, history: List[Dict[str, float]]) -> bool:
        """
        Check if early stopping criteria are met.

        Compares the two most recent entries of ``history`` on the first
        metric key of the first entry.

        Parameters
        ----------
        history : list
            List of performance metrics for each iteration

        Returns
        -------
        bool
            True if early stopping should be triggered
        """
        # Two data points are needed to measure an improvement
        if len(history) < 2:
            return False

        # By default, use first metric key as the performance measure
        metric_key = next(iter(history[0]))

        improvement = history[-1][metric_key] - history[-2][metric_key]

        # For metrics where lower is better (error), a decrease is the gain
        if metric_key in self._LOWER_IS_BETTER:
            improvement = -improvement

        return improvement < self.min_improvement

    def _get_best_model_index(self) -> int:
        """
        Get the index of the best model based on validation performance.

        Returns
        -------
        int
            Index of best model
        """
        # By default, use first metric key as the performance measure
        metric_key = next(iter(self.performance_history_[0]))

        # Extract performance values
        performance = [metrics[metric_key] for metrics in self.performance_history_]

        # Lower is better for error metrics, higher is better otherwise
        if metric_key in self._LOWER_IS_BETTER:
            return int(np.argmin(performance))
        return int(np.argmax(performance))


class IterativeLapBoostClassifier(IterativeLapBoostBase, ClassifierMixin):
    """
    Iterative LapBoost for classification tasks.

    This class extends the iterative co-training procedure for classification
    problems, using progressive pseudo-labeling to improve performance.

    See IterativeLapBoostBase for parameters.

    Attributes
    ----------
    label_encoder_ : LabelEncoder
        Encoder fitted on the labels passed to ``fit``; maps between the
        caller's labels and the integer codes used for training

    inner_label_encoder_ : object or None
        Encoder owned by the selected base model, if that model exposes one;
        used to translate the booster's outputs back to the integer codes of
        ``label_encoder_``

    n_classes_ : int
        Number of distinct labels seen in ``fit``
    """

    def __init__(
        self,
        k_neighbors: int = 10,
        gamma: float = 0.1,
        sigma: float = 1.0,
        confidence_threshold: float = 0.8,
        confidence_decay: float = 0.95,
        max_iter: int = 5,
        min_pseudo_labels: int = 10,
        min_improvement: float = 0.001,
        validation_fraction: float = 0.1,
        early_stopping: bool = True,
        xgb_params: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
        random_state: Optional[int] = None
    ):
        super().__init__(
            k_neighbors=k_neighbors,
            gamma=gamma,
            sigma=sigma,
            confidence_threshold=confidence_threshold,
            confidence_decay=confidence_decay,
            max_iter=max_iter,
            min_pseudo_labels=min_pseudo_labels,
            min_improvement=min_improvement,
            validation_fraction=validation_fraction,
            early_stopping=early_stopping,
            xgb_params=xgb_params,
            verbose=verbose,
            random_state=random_state
        )

    def fit(self, X_labeled: np.ndarray, y_labeled: np.ndarray,
            X_unlabeled: np.ndarray) -> 'IterativeLapBoostClassifier':
        """
        Fit the iterative classifier.

        Parameters
        ----------
        X_labeled : np.ndarray
            Labeled features
        y_labeled : np.ndarray
            Labels for X_labeled
        X_unlabeled : np.ndarray
            Unlabeled features

        Returns
        -------
        self : IterativeLapBoostClassifier
            Fitted classifier
        """
        # Encode target labels
        outer_encoder = LabelEncoder()
        y_encoded = outer_encoder.fit_transform(y_labeled)
        self.label_encoder_ = outer_encoder

        # Get number of classes
        self.n_classes_ = len(outer_encoder.classes_)

        # Run the co-training loop on the encoded labels
        super().fit(X_labeled, y_encoded, X_unlabeled)

        # The parent fit copies the selected base model's label_encoder_ onto
        # self when that model has one. That encoder was fitted on the
        # integer codes, not on the caller's labels, so keep it separately
        # and restore the encoder that knows the original labels.
        copied_encoder = getattr(self, 'label_encoder_', None)
        if copied_encoder is outer_encoder:
            self.inner_label_encoder_ = None
        else:
            self.inner_label_encoder_ = copied_encoder
        self.label_encoder_ = outer_encoder

        return self

    def _create_base_model(self) -> LapBoostClassifier:
        """
        Create a base LapBoostClassifier for an iteration.

        Returns
        -------
        LapBoostClassifier
            Base classifier instance
        """
        return LapBoostClassifier(
            k_neighbors=self.k_neighbors,
            gamma=self.gamma,
            sigma=self.sigma,
            confidence_threshold=self.confidence_threshold,
            xgb_params=self.xgb_params,
            max_iter=1,  # Only one iteration within each base model
            verbose=False,  # No verbose output from base models
            random_state=self.random_state
        )

    def _generate_pseudo_labels(self, model: LapBoost,
                               X_unlabeled: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate pseudo-labels and confidence scores for classification.

        Parameters
        ----------
        model : LapBoost
            Trained model
        X_unlabeled : np.ndarray
            Unlabeled feature data

        Returns
        -------
        tuple
            Pseudo-labels (in the label space the model was trained on) and
            confidence scores (the maximum class probability per sample)
        """
        # Get probability predictions
        probas = np.asarray(model.predict_proba(X_unlabeled))

        # Most probable column and its probability
        best_columns = np.argmax(probas, axis=1)
        confidences = np.max(probas, axis=1)

        # A column index equals the training label only when the model saw
        # every class. If the model exposes its own encoder, translate the
        # column index back to the label it stands for.
        # NOTE(review): assumes predict_proba columns follow
        # model.label_encoder_.classes_ order - confirm against
        # LapBoostClassifier.
        inner_encoder = getattr(model, 'label_encoder_', None)
        inner_classes = getattr(inner_encoder, 'classes_', None)
        if inner_classes is not None and probas.shape[1] == len(inner_classes):
            pseudo_labels = np.asarray(inner_classes)[best_columns]
        else:
            pseudo_labels = best_columns

        return pseudo_labels, confidences

    def _evaluate_model(self, model: LapBoost, X: np.ndarray,
                       y: np.ndarray) -> Dict[str, float]:
        """
        Evaluate classification model on validation data.

        Parameters
        ----------
        model : LapBoost
            Model to evaluate
        X : np.ndarray
            Features
        y : np.ndarray
            Targets

        Returns
        -------
        dict
            ``'accuracy'`` first (the key used for model selection), plus
            ``'ece'`` when class probabilities are available
        """
        from sklearn.metrics import accuracy_score

        # Make predictions
        y_pred = model.predict(X)

        # Probabilities are optional: calibration metrics are best-effort.
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            y_proba = model.predict_proba(X)
        except Exception:
            y_proba = None

        # Basic accuracy metric
        metrics = {'accuracy': accuracy_score(y, y_pred)}

        # Add confidence calibration metrics if possible
        if y_proba is not None:
            confidences = np.max(y_proba, axis=1)
            conf_metrics = confidence_metrics(
                y, y_pred, confidences, task='classification'
            )
            metrics['ece'] = conf_metrics.get('expected_calibration_error', 1.0)

        return metrics

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class labels.

        Parameters
        ----------
        X : np.ndarray
            Input features

        Returns
        -------
        np.ndarray
            Predicted class labels, in the label space passed to ``fit``
        """
        y_pred = self.xgb_model_.predict(X)

        # First undo the selected base model's own encoding (if it had one),
        # then map the integer codes back to the caller's labels.
        inner_encoder = getattr(self, 'inner_label_encoder_', None)
        if inner_encoder is not None:
            y_pred = inner_encoder.inverse_transform(y_pred)

        return self.label_encoder_.inverse_transform(y_pred)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """
        Predict class probabilities.

        Parameters
        ----------
        X : np.ndarray
            Input features

        Returns
        -------
        np.ndarray
            Class probabilities. When the selected base model was trained on
            fewer classes than ``fit`` saw, the result is expanded to
            ``n_classes_`` columns with zeros for the missing classes.
        """
        probas = self.xgb_model_.predict_proba(X)

        inner_encoder = getattr(self, 'inner_label_encoder_', None)
        inner_classes = getattr(inner_encoder, 'classes_', None)
        n_classes = getattr(self, 'n_classes_', None)
        if inner_classes is None or n_classes is None:
            return probas

        probas = np.asarray(probas)
        if probas.shape[1] == n_classes or probas.shape[1] != len(inner_classes):
            return probas

        # The base model's classes are integer codes of label_encoder_, so
        # they index directly into the full set of columns.
        expanded = np.zeros((probas.shape[0], n_classes), dtype=probas.dtype)
        expanded[:, np.asarray(inner_classes).astype(int)] = probas
        return expanded


class IterativeLapBoostRegressor(IterativeLapBoostBase, RegressorMixin):
    """
    Iterative LapBoost for regression tasks.

    This class extends the iterative co-training procedure for regression
    problems, using progressive pseudo-labeling to improve performance.

    See IterativeLapBoostBase for parameters. The only difference in
    defaults is ``confidence_threshold``, which is 0.2 here.
    """

    def __init__(
        self,
        k_neighbors: int = 10,
        gamma: float = 0.1,
        sigma: float = 1.0,
        confidence_threshold: float = 0.2,  # Lower threshold for regression
        confidence_decay: float = 0.95,
        max_iter: int = 5,
        min_pseudo_labels: int = 10,
        min_improvement: float = 0.001,
        validation_fraction: float = 0.1,
        early_stopping: bool = True,
        xgb_params: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
        random_state: Optional[int] = None
    ):
        super().__init__(
            k_neighbors=k_neighbors,
            gamma=gamma,
            sigma=sigma,
            confidence_threshold=confidence_threshold,
            confidence_decay=confidence_decay,
            max_iter=max_iter,
            min_pseudo_labels=min_pseudo_labels,
            min_improvement=min_improvement,
            validation_fraction=validation_fraction,
            early_stopping=early_stopping,
            xgb_params=xgb_params,
            verbose=verbose,
            random_state=random_state
        )

    def _create_base_model(self) -> LapBoostRegressor:
        """
        Create a base LapBoostRegressor for an iteration.

        Returns
        -------
        LapBoostRegressor
            Base regressor instance
        """
        return LapBoostRegressor(
            k_neighbors=self.k_neighbors,
            gamma=self.gamma,
            sigma=self.sigma,
            confidence_threshold=self.confidence_threshold,
            xgb_params=self.xgb_params,
            max_iter=1,  # Only one iteration within each base model
            verbose=False,  # No verbose output from base models
            random_state=self.random_state
        )

    def _generate_pseudo_labels(self, model: LapBoost,
                               X_unlabeled: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate pseudo-labels and confidence scores for regression.

        For regression, confidence is based on the distance to nearest neighbors
        as a proxy for prediction uncertainty: distances are scaled by the
        largest one and subtracted from 1, so the farthest sample gets
        confidence 0.

        Parameters
        ----------
        model : LapBoost
            Trained model
        X_unlabeled : np.ndarray
            Unlabeled feature data

        Returns
        -------
        tuple
            Pseudo-labels and confidence scores
        """
        # Get predictions for unlabeled data
        predictions = model.predict(X_unlabeled)

        # Use distance to nearest neighbors as proxy for uncertainty
        neighbor_dists = np.asarray(
            model.graph_constructor_.get_neighbor_distances(X_unlabeled),
            dtype=float
        )

        # NOTE(review): get_neighbor_distances is defined elsewhere; if it
        # returns several distances per sample, collapse each row to its mean
        # so that exactly one confidence per sample is produced - confirm the
        # intended aggregation.
        if neighbor_dists.ndim > 1:
            neighbor_dists = neighbor_dists.reshape(
                neighbor_dists.shape[0], -1
            ).mean(axis=1)

        # Normalize distances to [0, 1] range and invert to get confidence
        max_dist = np.max(neighbor_dists) if neighbor_dists.size > 0 else 1.0
        if max_dist == 0:
            # Every distance is zero: dividing would yield NaN confidences,
            # so treat all samples as having no uncertainty.
            uncertainties = np.zeros_like(neighbor_dists)
        else:
            uncertainties = neighbor_dists / max_dist
        confidences = 1.0 - uncertainties

        return predictions, confidences

    def _evaluate_model(self, model: LapBoost, X: np.ndarray,
                       y: np.ndarray) -> Dict[str, float]:
        """
        Evaluate regression model on validation data.

        Parameters
        ----------
        model : LapBoost
            Model to evaluate
        X : np.ndarray
            Features
        y : np.ndarray
            Targets

        Returns
        -------
        dict
            ``'mse'`` first (the key used for model selection), then
            ``'rmse'`` and ``'r2'``
        """
        from sklearn.metrics import mean_squared_error, r2_score

        # Make predictions
        y_pred = model.predict(X)

        # Basic regression metrics
        mse = mean_squared_error(y, y_pred)
        r2 = r2_score(y, y_pred)

        metrics = {
            'mse': mse,
            'rmse': np.sqrt(mse),
            'r2': r2
        }

        return metrics

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict regression targets.

        Parameters
        ----------
        X : np.ndarray
            Input features

        Returns
        -------
        np.ndarray
            Predicted values
        """
        return self.xgb_model_.predict(X)
