"""
Faithful streaming baselines using River library for concept drift evaluation.
Implements proper streaming algorithms with drift detection and adaptation.
"""

import numpy as np
from typing import Dict, List, Tuple, Any
from river import forest, tree, drift, compose, preprocessing, linear_model, metrics


class StreamingBaseline:
    """Common scaffolding for River-backed streaming classifiers.

    Subclasses implement ``_initialize_model`` to build ``self.model`` and,
    optionally, ``self.drift_detector``. ``fit_predict`` then performs a
    test-then-train (prequential) pass over a stream.

    Note: ``__init__`` calls ``reset()`` -> ``_initialize_model()``, so any
    attribute a subclass's ``_initialize_model`` reads must be assigned
    before that subclass calls ``super().__init__``.
    """

    def __init__(self, name: str):
        self.name = name
        # Placeholders; filled in by _initialize_model() through reset().
        self.model = None
        self.drift_detector = None
        self.scaler = None  # not read anywhere in this module
        self.reset()

    def reset(self):
        """Rebuild the model by delegating to ``_initialize_model``.

        The drift detector is refreshed only if the subclass's
        ``_initialize_model`` recreates it.
        """
        self._initialize_model()

    def _initialize_model(self):
        """Build ``self.model`` (and optionally ``self.drift_detector``).

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def fit_predict(self, X: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Run a prequential pass: predict each sample, then learn from it.

        Args:
            X: rows of features; row ``r`` becomes the dict ``{'x0': r[0], ...}``.
            y: labels aligned with ``X``; passed to ``learn_one`` as ``bool``.

        Returns:
            ``(predictions, info)`` where ``predictions`` is an array of 0/1
            and ``info`` holds ``drift_points`` (indices where the external
            detector fired) and ``n_drifts_detected``.
        """
        predicted: List[int] = []
        detected_at: List[int] = []

        for step, (row, label) in enumerate(zip(X, y)):
            # River consumes dict-shaped samples.
            features = {f'x{col}': value for col, value in enumerate(row)}

            # Test first: score the sample before the model sees its label.
            if hasattr(self.model, 'predict_proba_one'):
                proba = self.model.predict_proba_one(features)
                positive_score = proba.get(True, proba.get(1, 0))
                guess = 1 if positive_score > 0.5 else 0
            else:
                guess = 1 if self.model.predict_one(features) else 0
            predicted.append(guess)

            # ...then train on the true label.
            self.model.learn_one(features, bool(label))

            detector = self.drift_detector
            if detector is None:
                continue

            # The per-sample misclassification indicator is the drift signal.
            detector.update(abs(guess - label))
            if detector.drift_detected:
                detected_at.append(step)
                # Throw away the learned state; subclasses that build a
                # detector in _initialize_model() get a fresh one too.
                self._initialize_model()

        summary = {'drift_points': detected_at, 'n_drifts_detected': len(detected_at)}
        return np.array(predicted), summary


class AdaptiveRandomForest(StreamingBaseline):
    """Adaptive Random Forest baseline built on River's ``forest.ARFClassifier``.

    Drift handling is internal to the ensemble (ADWIN drift and warning
    detectors passed to ARF). No external ``drift_detector`` is set, so
    ``StreamingBaseline.fit_predict`` never resets this model itself.
    """

    def __init__(self, n_models: int = 10, max_features: str = 'sqrt'):
        # Hyper-parameters must exist BEFORE super().__init__(): the base
        # constructor calls reset() -> _initialize_model(), which reads them.
        # Assigning them afterwards raised AttributeError on construction.
        self.n_models = n_models
        self.max_features = max_features
        super().__init__("Adaptive Random Forest")

    def _initialize_model(self):
        """Build a fresh ARF ensemble from the stored hyper-parameters."""
        self.model = forest.ARFClassifier(
            n_models=self.n_models,
            max_features=self.max_features,
            drift_detector=drift.ADWIN(),
            warning_detector=drift.ADWIN(delta=0.01),
            seed=42,
        )


class HoeffdingAdaptiveTree(StreamingBaseline):
    """Hoeffding Adaptive Tree baseline using River's built-in ADWIN detection.

    No external drift detector is configured, so ``fit_predict`` never
    resets this model; adaptation is left to the tree's own detector.
    """

    def __init__(self):
        super().__init__(name="Hoeffding Adaptive Tree")

    def _initialize_model(self):
        """Create a fresh HAT classifier (Gini splits, ADWIN with delta=0.002)."""
        adwin = drift.ADWIN(delta=0.002)
        self.model = tree.HoeffdingAdaptiveTreeClassifier(
            seed=42,
            split_criterion="gini",
            switch_significance=0.05,
            drift_detector=adwin,
        )


class StreamingSVM(StreamingBaseline):
    """Linear classifier trained online with Passive-Aggressive updates.

    A ``StandardScaler`` feeds a PA classifier. An ADWIN detector on the
    per-sample prediction error (fed by ``fit_predict``) triggers a full
    model reset.
    """

    def __init__(self, C: float = 1.0):
        # C must be set before super().__init__(), which calls
        # reset() -> _initialize_model(); otherwise construction raises
        # AttributeError on self.C.
        self.C = C
        super().__init__("Streaming SVM")

    def _initialize_model(self):
        """Build a fresh scaler + PA pipeline and a fresh ADWIN detector."""
        self.model = compose.Pipeline(
            preprocessing.StandardScaler(),
            linear_model.PAClassifier(C=self.C, mode=1),
        )
        self.drift_detector = drift.ADWIN(delta=0.002)


class StreamingLogisticRegression(StreamingBaseline):
    """Online logistic regression trained with plain SGD.

    A ``StandardScaler`` feeds River's ``LogisticRegression``. An ADWIN
    detector on the per-sample prediction error triggers a full model reset.
    """

    def __init__(self, learning_rate: float = 0.01):
        # Must be set before super().__init__(), which calls
        # reset() -> _initialize_model() and reads self.learning_rate.
        self.learning_rate = learning_rate
        super().__init__("Streaming Logistic Regression")

    def _initialize_model(self):
        """Build a fresh scaler + SGD logistic-regression pipeline and ADWIN detector."""
        # River's LogisticRegression has no ``lr`` keyword; the step size is
        # configured through its ``optimizer`` argument.
        from river import optim

        self.model = compose.Pipeline(
            preprocessing.StandardScaler(),
            linear_model.LogisticRegression(optimizer=optim.SGD(self.learning_rate)),
        )
        self.drift_detector = drift.ADWIN(delta=0.002)


class EWMA_BasedClassifier(StreamingBaseline):
    """Passive-Aggressive classifier that is reset when a DDM detector fires.

    Despite the class name, no exponentially weighted moving average is
    computed here. ``alpha`` is stored for API compatibility but is not read
    by ``_initialize_model``.
    """

    def __init__(self, alpha: float = 0.5):
        # Set before super().__init__() (which builds the model via reset())
        # for consistency with the other baselines.
        self.alpha = alpha
        super().__init__("EWMA Classifier")

    def _initialize_model(self):
        """Build a fresh scaler + PA pipeline and a fresh DDM detector."""
        self.model = compose.Pipeline(
            preprocessing.StandardScaler(),
            linear_model.PAClassifier(C=1.0, mode=1),
        )
        # DDM is exposed as ``drift.DDM`` in older River releases and as
        # ``drift.binary.DDM`` in newer ones; accept either location.
        ddm_cls = getattr(drift, "DDM", None)
        if ddm_cls is None:
            ddm_cls = drift.binary.DDM
        self.drift_detector = ddm_cls()


def get_streaming_baselines() -> List[StreamingBaseline]:
    """Instantiate one of each streaming baseline with its evaluation settings.

    Models are constructed in a fixed order: ARF, HAT, PA "SVM", logistic
    regression, then the DDM-reset PA classifier.
    """
    factories = (
        lambda: AdaptiveRandomForest(n_models=10),
        HoeffdingAdaptiveTree,
        lambda: StreamingSVM(C=1.0),
        lambda: StreamingLogisticRegression(learning_rate=0.01),
        lambda: EWMA_BasedClassifier(alpha=0.5),
    )
    return [make() for make in factories]


def _f1_for_class(preds: np.ndarray, true: np.ndarray, positive: Any = 1) -> float:
    """Return the one-vs-rest F1 score of label ``positive`` (0.0 when undefined)."""
    pred_pos = preds == positive
    true_pos = true == positive
    tp = np.sum(pred_pos & true_pos)
    fp = np.sum(pred_pos & ~true_pos)
    fn = np.sum(~pred_pos & true_pos)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))


def evaluate_streaming_baseline(baseline: "StreamingBaseline",
                               X: np.ndarray,
                               y: np.ndarray,
                               window_size: int = 200) -> Dict[str, Any]:
    """Reset ``baseline``, run it prequentially over (X, y) and summarise results.

    Args:
        baseline: object exposing ``name``, ``reset()`` and
            ``fit_predict(X, y) -> (predictions, metadata)`` where ``metadata``
            contains ``drift_points`` and ``n_drifts_detected``.
        X: feature rows, one per stream element.
        y: labels aligned with ``X`` (0/1 for this module's generators).
        window_size: length of the non-overlapping evaluation windows. A
            trailing partial window (``len(X) % window_size`` samples) gets no
            per-window score but is included in the overall metrics.

    Returns:
        dict with:
        - ``mean_accuracy``: overall accuracy (0.0 for an empty stream).
        - ``macro_f1``: unweighted mean of per-class F1 over labels present
          in ``y`` or the predictions.
        - ``binary_f1``: F1 of class 1, the value this function previously
          reported as ``macro_f1``.
        - ``window_accuracies`` / ``window_f1_scores`` (class-1 F1) and their
          worst/best accuracy summary.
        - drift metadata from ``fit_predict`` and the raw predictions.

    Raises:
        ValueError: if ``window_size`` is not positive.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")

    baseline.reset()

    # Full stream evaluation
    predictions, metadata = baseline.fit_predict(X, y)
    predictions = np.asarray(predictions)
    y_true = np.asarray(y)

    # Non-overlapping windows; any trailing partial window is skipped here.
    n_windows = len(X) // window_size
    window_accuracies = []
    window_f1_scores = []
    for start in range(0, n_windows * window_size, window_size):
        stop = start + window_size
        window_preds = predictions[start:stop]
        window_true = y_true[start:stop]
        window_accuracies.append(np.mean(window_preds == window_true))
        window_f1_scores.append(_f1_for_class(window_preds, window_true, positive=1))

    # Overall metrics
    overall_accuracy = float(np.mean(predictions == y_true)) if y_true.size else 0.0
    binary_f1 = _f1_for_class(predictions, y_true, positive=1)
    # True macro-F1: average per-class F1 over every label that occurs in the
    # ground truth or the predictions (labels absent from both are skipped).
    labels = np.union1d(y_true, predictions)
    macro_f1 = (
        float(np.mean([_f1_for_class(predictions, y_true, label) for label in labels]))
        if labels.size else 0.0
    )

    return {
        'method': baseline.name,
        'mean_accuracy': overall_accuracy,
        'macro_f1': macro_f1,
        'binary_f1': binary_f1,
        'worst_window_accuracy': np.min(window_accuracies) if window_accuracies else 0.0,
        'best_window_accuracy': np.max(window_accuracies) if window_accuracies else 0.0,
        'window_accuracies': window_accuracies,
        'window_f1_scores': window_f1_scores,
        'n_drifts_detected': metadata['n_drifts_detected'],
        'drift_points': metadata['drift_points'],
        'predictions': predictions
    }


# River-compatible dataset generators
def generate_sea_stream(n_samples: int = 2000, drift_points: List[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Generate an SEA stream with abrupt concept drift at ``drift_points``.

    The stream starts with SEA variant 0 (seed 42). At each drift point ``p``
    a new generator takes over with the next variant (cycling through 0-3)
    and seed ``42 + p``, so every sample from index ``p`` onward follows the
    new concept.

    Args:
        n_samples: total number of samples to draw.
        drift_points: indices where the concept changes; defaults to the
            one-third and two-thirds marks. Duplicates are collapsed and
            points outside ``(0, n_samples)`` are ignored.

    Returns:
        ``(X, y)``: three features per row and 0/1 labels.
    """
    from river.datasets import synth

    if drift_points is None:
        drift_points = [n_samples // 3, 2 * n_samples // 3]

    # Each concept gets its own generator. Re-binding the dataset variable
    # while iterating ``dataset.take(...)`` would not switch concepts: the
    # running iterator keeps drawing from the original generator.
    boundaries = sorted({p for p in drift_points if 0 < p < n_samples})
    starts = [0] + boundaries
    stops = boundaries + [n_samples]

    X, y = [], []
    for segment, (start, stop) in enumerate(zip(starts, stops)):
        seed = 42 if segment == 0 else 42 + start
        dataset = synth.SEA(variant=segment % 4, seed=seed)
        for xi, yi in dataset.take(stop - start):
            # Read features in generator order rather than relying on
            # specific key names in River's sample dict.
            X.append(list(xi.values()))
            y.append(1 if yi else 0)

    return np.array(X), np.array(y)


def generate_rotating_hyperplane_stream(n_samples: int = 2000, n_features: int = 4,
                                        mag_change: float = 0.001) -> Tuple[np.ndarray, np.ndarray]:
    """Generate a rotating-hyperplane stream with gradual concept drift.

    Args:
        n_samples: number of samples to draw.
        n_features: number of features per sample.
        mag_change: magnitude of the hyperplane weight change per sample.
            ``0.0`` gives a stationary (drift-free) stream.

    Returns:
        ``(X, y)``: ``n_features`` values per row and 0/1 labels.
    """
    from river.datasets import synth

    # River's Hyperplane has no ``change_speed`` argument; the rotation rate is
    # ``mag_change``. The previous ``mag_change=0.0`` produced no drift at all.
    # ``n_drift_features`` (default 2) is capped at ``n_features``.
    dataset = synth.Hyperplane(
        seed=42,
        n_features=n_features,
        n_drift_features=min(2, n_features),
        mag_change=mag_change,
    )

    X, y = [], []
    for xi, yi in dataset.take(n_samples):
        # Read features in generator order rather than relying on specific
        # key names in River's sample dict.
        X.append(list(xi.values()))
        y.append(1 if yi else 0)

    return np.array(X), np.array(y)