#!/usr/bin/env python
"""
Simple baseline methods for comparison that need no dependencies beyond NumPy and scikit-learn.
Implements basic streaming classification methods built on scikit-learn estimators.
"""

from typing import Dict, List, Tuple, Any

import numpy as np
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier


class SimpleStreamingBaseline:
    """Simple baseline that retrains periodically on sliding window.

    Every call to ``partial_fit`` appends the new samples to a buffer holding
    at most the ``window_size`` most recent samples. Once enough samples are
    buffered, it refits a fresh ``StandardScaler`` and a fresh clone of the
    base estimator on the whole buffer.

    Args:
        method_name: ``"svm"``, ``"random_forest"`` or ``"decision_tree"``;
            any other value falls back to an SVM.
        window_size: Maximum number of recent samples kept for training.
        min_train_size: Minimum number of buffered samples required before a
            (re)fit is attempted. It is capped at ``window_size`` so that
            windows smaller than this value can still train.
    """

    def __init__(self, method_name: str, window_size: int = 200,
                 min_train_size: int = 20):
        self.method_name = method_name
        self.window_size = window_size
        self.min_train_size = min_train_size
        self.data_buffer = []
        self.label_buffer = []
        self.model = None
        self.scaler = StandardScaler()
        self.fitted = False
        # Exception raised by the most recent failed refit; reset to None after
        # a successful fit. Kept so training failures are inspectable.
        self.last_fit_error = None

        # Unfitted template estimator. Each refit trains a clone of it, so a
        # failed fit never disturbs the template or the currently used model.
        if method_name == "svm":
            self.base_model = SVC(random_state=42)
        elif method_name == "random_forest":
            self.base_model = RandomForestClassifier(n_estimators=10, random_state=42)
        elif method_name == "decision_tree":
            self.base_model = DecisionTreeClassifier(random_state=42)
        else:
            self.base_model = SVC(random_state=42)  # Default fallback

    def partial_fit(self, X: np.ndarray, y: np.ndarray):
        """Add new data to the sliding window and retrain if enough is buffered.

        Training failures (for example an SVM given a single-class window) are
        not raised. The previously fitted scaler/model pair is kept and the
        exception is stored in ``self.last_fit_error``.

        Args:
            X: Samples to add, one row per sample.
            y: Labels for ``X``; must have the same length.

        Raises:
            ValueError: If ``X`` and ``y`` have different lengths.
        """
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.data_buffer.extend(X)
        self.label_buffer.extend(y)

        # Keep only the most recent ``window_size`` samples.
        excess = len(self.data_buffer) - self.window_size
        if excess > 0:
            self.data_buffer = self.data_buffer[excess:]
            self.label_buffer = self.label_buffer[excess:]

        # Capping at window_size keeps small windows trainable. The floor of 1
        # avoids fitting on an empty buffer.
        required = max(1, min(self.min_train_size, self.window_size))
        if len(self.data_buffer) < required:
            return

        try:
            X_train = np.array(self.data_buffer)
            y_train = np.array(self.label_buffer)
            # Fit into locals and commit only on success, so that the scaler
            # and the model are always fitted on the same window.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            model = clone(self.base_model)
            model.fit(X_train_scaled, y_train)
        except Exception as err:  # deliberate best-effort retraining
            self.last_fit_error = err
            return

        self.scaler = scaler
        self.model = model
        self.fitted = True
        self.last_fit_error = None

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions for ``X``.

        Before the first successful fit, or if scaling/prediction fails, this
        returns uniformly random labels instead. They are drawn from the labels
        currently in the window, or from ``{0, 1}`` when no labels have been
        seen. The fallback draws from NumPy's global random state.
        """
        if not self.fitted or self.model is None:
            return self._random_predictions(len(X))

        try:
            X_scaled = self.scaler.transform(X)
            return self.model.predict(X_scaled)
        except Exception:
            # Best-effort: degrade to random labels rather than crash the stream.
            return self._random_predictions(len(X))

    def _random_predictions(self, n: int) -> np.ndarray:
        """Draw ``n`` random labels from the observed label set (or {0, 1})."""
        labels = np.unique(self.label_buffer) if self.label_buffer else [0, 1]
        return np.random.choice(labels, size=n)

    def reset(self):
        """Reset the baseline to its freshly constructed, unfitted state."""
        self.data_buffer = []
        self.label_buffer = []
        self.model = None
        self.scaler = StandardScaler()
        self.fitted = False
        self.last_fit_error = None


def get_simple_baselines() -> List[SimpleStreamingBaseline]:
    """Build one fresh baseline per supported model family.

    Returns:
        New ``SimpleStreamingBaseline`` instances with default settings, for
        an SVM, a random forest and a decision tree, in that order.
    """
    method_names = ("svm", "random_forest", "decision_tree")
    return [SimpleStreamingBaseline(name) for name in method_names]


def _sliding_window_accuracies(correct: np.ndarray, window_size: int) -> np.ndarray:
    """Return the accuracy of every contiguous ``window_size`` slice of ``correct``."""
    if len(correct) < window_size:
        return np.empty(0)
    # Prefix sums give every window's hit count in O(n) rather than O(n * w).
    hits = np.concatenate(([0], np.cumsum(correct, dtype=np.int64)))
    return (hits[window_size:] - hits[:-window_size]) / window_size


def evaluate_simple_baseline(baseline: "SimpleStreamingBaseline", X: np.ndarray, y: np.ndarray,
                             window_size: int = 200,
                             min_train_size: int = 50) -> Dict[str, Any]:
    """
    Evaluate a simple baseline using prequential protocol.

    The baseline is reset, then each sample is first predicted (once
    ``min_train_size`` samples have been seen) and then used for training.

    Args:
        baseline: The baseline method to evaluate
        X: Feature matrix
        y: Label vector
        window_size: Window size (in predictions) for the sliding-window accuracy
        min_train_size: Number of leading samples used for training only,
            before predictions start being scored

    Returns:
        Evaluation results with keys ``method``, ``mean_accuracy``,
        ``worst_window_accuracy``, ``macro_f1``, ``n_predictions`` and
        ``n_windows``. If no prediction was made, placeholder metric values
        are returned with ``n_predictions == n_windows == 0``.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    baseline.reset()

    predictions = []
    true_labels = []
    # Prequential evaluation: test-then-train.
    for i in range(len(X)):
        x_i, y_i = X[i:i + 1], y[i:i + 1]
        if i >= min_train_size:
            predictions.append(baseline.predict(x_i)[0])
            true_labels.append(y[i])
        baseline.partial_fit(x_i, y_i)

    if not predictions:
        return {
            'method': baseline.method_name,
            'mean_accuracy': 0.5,
            'worst_window_accuracy': 0.5,
            'macro_f1': 0.33,
            'n_predictions': 0,
            'n_windows': 0
        }

    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    correct = predictions == true_labels

    # One accuracy per contiguous window of `window_size` consecutive predictions.
    window_accuracies = _sliding_window_accuracies(correct, window_size)

    mean_accuracy = np.mean(correct)
    worst_window_accuracy = (
        np.min(window_accuracies) if window_accuracies.size else mean_accuracy
    )
    macro_f1 = f1_score(true_labels, predictions, average='macro', zero_division=0)

    return {
        'method': baseline.method_name,
        'mean_accuracy': mean_accuracy,
        'worst_window_accuracy': worst_window_accuracy,
        'macro_f1': macro_f1,
        'n_predictions': len(predictions),
        'n_windows': len(window_accuracies)
    }


if __name__ == "__main__":
    # Smoke test: run every baseline prequentially on a small synthetic stream
    # whose label is a linear function of the first two features.
    print("Testing simple baselines...")

    # Seed first so the generated data (and any random fallbacks) are reproducible.
    np.random.seed(42)
    n_samples = 500
    X = np.random.randn(n_samples, 4)
    y = (X[:, 0] + X[:, 1] > 0).astype(int)

    for candidate in get_simple_baselines():
        print(f"\nTesting {candidate.method_name}...")
        metrics = evaluate_simple_baseline(candidate, X, y)

        print(f"  Mean accuracy: {metrics['mean_accuracy']:.3f}")
        print(f"  Worst window accuracy: {metrics['worst_window_accuracy']:.3f}")
        print(f"  Macro F1: {metrics['macro_f1']:.3f}")
        print(f"  Predictions made: {metrics['n_predictions']}")

    print("\n✅ Simple baselines tested successfully!")