#!/usr/bin/env python
"""
Unified evaluation protocols for QISK framework.
Clarifies and implements both prequential and window-based evaluation.
"""

import numpy as np
from typing import Dict, List, Any, Optional, Tuple, Union
from sklearn.metrics import balanced_accuracy_score, f1_score
from abc import ABC, abstractmethod


class StreamingEvaluator(ABC):
    """Abstract interface shared by every streaming evaluation protocol.

    Concrete protocols override :meth:`evaluate`; the base class cannot be
    instantiated directly because ``evaluate`` is abstract.
    """

    @abstractmethod
    def evaluate(self, X: np.ndarray, y: np.ndarray, method) -> Dict[str, Any]:
        """Run ``method`` over the stream ``(X, y)`` and return a metrics dict."""
        ...


class PrequentialEvaluator(StreamingEvaluator):
    """
    Prequential (test-then-train) evaluation protocol.

    Each sample is:
    1. First used for testing (prediction)
    2. Then used for training (``partial_fit``, or ``learn_one`` for
       River-style models)

    This is the gold standard for streaming evaluation.
    """

    def __init__(self, window_size: int = 200, min_train_size: int = 50):
        """
        Initialize prequential evaluator.

        Args:
            window_size: Number of consecutive predictions in each sliding
                window used for the windowed accuracy metrics
            min_train_size: Number of leading samples that are only trained on
                (not predicted) before predictions start

        Raises:
            ValueError: If ``window_size`` is smaller than 1 (a zero-length
                window would produce NaN accuracies).
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size = window_size
        self.min_train_size = min_train_size

    def evaluate(self, X: np.ndarray, y: np.ndarray, method) -> Dict[str, Any]:
        """
        Perform prequential evaluation.

        Args:
            X: Feature matrix (n_samples, n_features)
            y: Label vector (n_samples,)
            method: Streaming method exposing ``predict``/``partial_fit``
                (sklearn-style) or ``predict_one``/``learn_one`` (River-style).
                If it has ``reset``, it is called before evaluation starts.

        Returns:
            Dictionary with evaluation metrics:
            ``mean_accuracy``, ``balanced_accuracy`` and ``macro_f1`` over all
            scored predictions; ``worst_window_accuracy``, the minimum accuracy
            over every run of ``window_size`` consecutive predictions (equal to
            ``mean_accuracy`` when fewer predictions were made);
            ``window_accuracies``; ``n_predictions``; ``n_windows``; and
            ``n_prediction_failures`` / ``n_training_failures``, the number of
            samples for which the method raised (or, for prediction, returned
            None). A failed prediction is scored with the most frequent label
            seen so far. If no label has been seen yet, it is not scored.

        Raises:
            ValueError: If ``X`` and ``y`` have different lengths.
        """
        n_samples = len(X)
        if len(y) != n_samples:
            raise ValueError(
                f"X and y must have the same length, got {n_samples} and {len(y)}"
            )

        predictions = []
        true_labels = []
        label_counts = {}  # label -> count over y[:i]; feeds the fallback prediction
        n_prediction_failures = 0
        n_training_failures = 0

        # Reset method state
        if hasattr(method, 'reset'):
            method.reset()

        # Prequential protocol: test-then-train
        for i in range(n_samples):
            # Predict first (only after the warm-up period)
            if i >= self.min_train_size:
                pred = self._predict_sample(method, X, i)
                if pred is None:
                    n_prediction_failures += 1
                    # Deterministic no-information baseline instead of a
                    # hard-coded or randomly drawn label.
                    pred = self._majority_label(label_counts)
                if pred is not None:
                    predictions.append(pred)
                    true_labels.append(y[i])

            # Then train on current sample
            if not self._train_sample(method, X, y, i):
                n_training_failures += 1
            label_counts[y[i]] = label_counts.get(y[i], 0) + 1

        if not predictions:
            return {
                'mean_accuracy': 0.0,
                'worst_window_accuracy': 0.0,
                'macro_f1': 0.0,
                'balanced_accuracy': 0.0,
                'n_predictions': 0,
                'n_windows': 0,
                'window_accuracies': [],
                'n_prediction_failures': n_prediction_failures,
                'n_training_failures': n_training_failures,
                'evaluation_protocol': 'prequential'
            }

        predictions = np.array(predictions)
        true_labels = np.array(true_labels)
        hits = predictions == true_labels

        # Overall metrics
        window_accuracies = self._sliding_window_accuracies(hits, self.window_size)
        mean_accuracy = np.mean(hits)
        macro_f1 = f1_score(true_labels, predictions, average='macro', zero_division=0)
        balanced_acc = balanced_accuracy_score(true_labels, predictions)
        worst_window_accuracy = np.min(window_accuracies) if window_accuracies else mean_accuracy

        return {
            'mean_accuracy': mean_accuracy,
            'worst_window_accuracy': worst_window_accuracy,
            'macro_f1': macro_f1,
            'n_predictions': len(predictions),
            'n_windows': len(window_accuracies),
            'window_accuracies': window_accuracies,
            'balanced_accuracy': balanced_acc,
            'n_prediction_failures': n_prediction_failures,
            'n_training_failures': n_training_failures,
            'evaluation_protocol': 'prequential'
        }

    @staticmethod
    def _predict_sample(method, X: np.ndarray, i: int):
        """Predict the label of sample ``i``; return None if the method cannot."""
        try:
            if hasattr(method, 'predict'):
                return method.predict(X[i:i+1])[0]
            if hasattr(method, 'predict_one'):  # River-style: dict of features
                return method.predict_one(dict(enumerate(X[i])))
        except Exception:
            return None  # counted as a prediction failure by the caller
        return None

    @staticmethod
    def _train_sample(method, X: np.ndarray, y: np.ndarray, i: int) -> bool:
        """Train on sample ``i``; return False if the method raised."""
        try:
            if hasattr(method, 'partial_fit'):
                method.partial_fit(X[i:i+1], y[i:i+1])
            elif hasattr(method, 'learn_one'):  # River-style
                method.learn_one(dict(enumerate(X[i])), y[i])
        except Exception:
            return False
        return True

    @staticmethod
    def _majority_label(label_counts: Dict[Any, int]):
        """Most frequent label seen so far (earliest-seen wins ties), or None."""
        if not label_counts:
            return None
        return max(label_counts, key=label_counts.get)

    @staticmethod
    def _sliding_window_accuracies(hits: np.ndarray, window_size: int) -> List[float]:
        """Accuracy of every run of ``window_size`` consecutive predictions.

        A cumulative sum gives each window's hit count in O(1), so the whole
        pass is O(n) instead of re-scanning every window.
        """
        if len(hits) < window_size:
            return []
        cumulative = np.concatenate(([0], np.cumsum(hits, dtype=np.int64)))
        return list((cumulative[window_size:] - cumulative[:-window_size]) / window_size)


class WindowBasedEvaluator(StreamingEvaluator):
    """
    Window-based evaluation with train/test splits within each window.

    Each window is:
    1. Split chronologically into a train portion (the first ``train_ratio``
       of the window, 80% by default) and a test portion (the remainder)
    2. Method is trained on train portion
    3. Evaluated on test portion
    4. Results aggregated across windows

    This is more suitable for batch-oriented methods like QISK.
    """

    def __init__(self, window_size: int = 200, train_ratio: float = 0.8,
                 overlap_ratio: float = 0.0, min_window_samples: int = 50,
                 max_history_windows: int = 3):
        """
        Initialize window-based evaluator.

        Args:
            window_size: Number of samples per window
            train_ratio: Fraction of window used for training (rest for testing)
            overlap_ratio: Overlap between consecutive windows (0.0 = no overlap)
            min_window_samples: Minimum samples required for valid window
            max_history_windows: Number of most recent training portions kept
                in the history passed to ``fit_window``-style methods

        Raises:
            ValueError: If ``window_size`` < 1, ``train_ratio`` is not in
                (0, 1), ``overlap_ratio`` is not in [0, 1), or
                ``max_history_windows`` < 0.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if not 0.0 < train_ratio < 1.0:
            raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")
        if not 0.0 <= overlap_ratio < 1.0:
            raise ValueError(f"overlap_ratio must be in [0, 1), got {overlap_ratio}")
        if max_history_windows < 0:
            raise ValueError(
                f"max_history_windows must be >= 0, got {max_history_windows}"
            )
        self.window_size = window_size
        self.train_ratio = train_ratio
        self.overlap_ratio = overlap_ratio
        self.min_window_samples = min_window_samples
        self.max_history_windows = max_history_windows

    def evaluate(self, X: np.ndarray, y: np.ndarray, method) -> Dict[str, Any]:
        """
        Perform window-based evaluation.

        Args:
            X: Feature matrix (n_samples, n_features)
            y: Label vector (n_samples,)
            method: Method exposing ``fit_window``/``evaluate_window``
                (QISK-style), ``fit``, ``partial_fit`` or ``learn_one`` for
                training, and ``predict`` or ``predict_one`` for prediction

        Returns:
            Dictionary with evaluation metrics aggregated over all windows that
            were trained and scored successfully. Windows with fewer than
            ``min_window_samples`` samples, a single class, or an empty test
            split are skipped. A window whose training or scoring raises is
            reported with a printed warning and excluded.
        """
        n_samples = len(X)
        # A zero step (overlap_ratio close to 1 with a small window) would
        # re-evaluate the same window forever; always advance by at least 1.
        step_size = max(1, int(self.window_size * (1 - self.overlap_ratio)))

        window_results = []
        method_history = []  # Recent training portions for history-aware methods
        window_idx = 0

        last_start = n_samples - self.min_window_samples
        for start_idx in range(0, last_start + 1, step_size):
            end_idx = min(start_idx + self.window_size, n_samples)
            X_window = X[start_idx:end_idx]
            y_window = y[start_idx:end_idx]

            # Skip short windows and windows with insufficient diversity
            if len(y_window) < self.min_window_samples or len(np.unique(y_window)) < 2:
                continue

            # Split window chronologically into train/test
            split_idx = int(len(X_window) * self.train_ratio)
            X_train, y_train = X_window[:split_idx], y_window[:split_idx]
            X_test, y_test = X_window[split_idx:], y_window[split_idx:]
            if len(X_test) == 0:  # Ensure we have test data
                continue

            try:
                window_result = self._train_and_score(
                    method, window_idx, X_train, y_train, X_test, y_test,
                    method_history
                )
            except Exception as e:
                print(f"Warning: Window {window_idx} failed: {e}")
            else:
                window_results.append(window_result)

            window_idx += 1

        return self._aggregate(window_results)

    def _train_and_score(self, method, window_idx, X_train, y_train,
                         X_test, y_test, method_history) -> Dict[str, Any]:
        """Train ``method`` on one window's train split and score its test split."""
        if hasattr(method, 'fit_window'):
            # QISK-style training with history
            train_results, classifier = method.fit_window(
                X_train, y_train, method_history
            )
            eval_results = method.evaluate_window(
                X_test, y_test, classifier, X_train
            )

            window_result = {
                'window': window_idx,
                'train_size': len(X_train),
                'test_size': len(X_test),
                'accuracy': eval_results['accuracy'],
                'balanced_accuracy': eval_results.get('balanced_accuracy', eval_results['accuracy']),
                'macro_f1': eval_results['macro_f1']
            }

            # Add training metrics if available
            if 'best_kta' in train_results:
                window_result['kta'] = train_results['best_kta']

            # Update the history in place so the caller's list stays shared
            method_history.append(X_train)
            while len(method_history) > self.max_history_windows:
                method_history.pop(0)
            return window_result

        self._fit_generic(method, X_train, y_train)
        predictions = self._predict_generic(method, X_test)
        return {
            'window': window_idx,
            'train_size': len(X_train),
            'test_size': len(X_test),
            'accuracy': np.mean(predictions == y_test),
            'balanced_accuracy': balanced_accuracy_score(y_test, predictions),
            'macro_f1': f1_score(y_test, predictions, average='macro', zero_division=0)
        }

    @staticmethod
    def _fit_generic(method, X_train, y_train) -> None:
        """Train a non-QISK method with the most specific training API it has."""
        if hasattr(method, 'fit'):
            # Standard sklearn-style method: retrained from scratch per window
            method.fit(X_train, y_train)
        elif hasattr(method, 'partial_fit'):
            # Incremental method: its state carries over between windows
            method.partial_fit(X_train, y_train)
        elif hasattr(method, 'learn_one'):
            # River-style method: one dict-encoded sample at a time
            for x_row, label in zip(X_train, y_train):
                method.learn_one(dict(enumerate(x_row)), label)
        else:
            raise ValueError(
                f"Method {type(method)} must have fit_window, fit, partial_fit "
                f"or learn_one method"
            )

    @staticmethod
    def _predict_generic(method, X_test) -> np.ndarray:
        """Predict the test split via ``predict`` or, failing that, ``predict_one``."""
        if hasattr(method, 'predict'):
            return np.asarray(method.predict(X_test))
        return np.array([method.predict_one(dict(enumerate(x_row))) for x_row in X_test])

    @staticmethod
    def _safe_correlation(a, b) -> float:
        """Pearson correlation of ``a`` and ``b``; 0.0 when undefined.

        Undefined means fewer than two points or a constant sequence (which
        would make ``np.corrcoef`` return NaN with a warning).
        """
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        if len(a) < 2 or np.std(a) == 0 or np.std(b) == 0:
            return 0.0
        return np.corrcoef(a, b)[0, 1]

    @classmethod
    def _aggregate(cls, window_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Combine per-window results into the evaluator's summary dict."""
        if not window_results:
            return {
                'mean_accuracy': 0.0,
                'worst_window_accuracy': 0.0,
                'macro_f1': 0.0,
                'n_windows': 0,
                'window_accuracies': [],
                'per_window_results': [],
                'evaluation_protocol': 'window_based'
            }

        # Aggregate results across windows
        accuracies = [r['accuracy'] for r in window_results]
        balanced_accs = [r['balanced_accuracy'] for r in window_results if 'balanced_accuracy' in r]
        f1_scores = [r['macro_f1'] for r in window_results]

        results = {
            'mean_accuracy': np.mean(accuracies),
            'worst_window_accuracy': np.min(accuracies),
            'best_window_accuracy': np.max(accuracies),
            'macro_f1': np.mean(f1_scores),
            'n_windows': len(window_results),
            'window_accuracies': accuracies,
            'per_window_results': window_results,
            'evaluation_protocol': 'window_based'
        }

        if balanced_accs:
            results['balanced_accuracy'] = np.mean(balanced_accs)
            results['worst_balanced_accuracy'] = np.min(balanced_accs)

        # Correlate KTA with accuracy only over windows that report a KTA, so
        # both sequences have the same length.
        kta_pairs = [(r['kta'], r['accuracy']) for r in window_results if 'kta' in r]
        if kta_pairs:
            kta_scores = [kta for kta, _ in kta_pairs]
            kta_accuracies = [acc for _, acc in kta_pairs]
            results['kta_correlation'] = cls._safe_correlation(kta_scores, kta_accuracies)
            results['mean_kta'] = np.mean(kta_scores)

        return results


def choose_evaluation_protocol(method, data_size: int, prefer_prequential: bool = True,
                               max_prequential_size: int = 10000) -> StreamingEvaluator:
    """
    Choose appropriate evaluation protocol based on method characteristics.

    Args:
        method: The method to be evaluated
        data_size: Size of the dataset
        prefer_prequential: Whether to prefer prequential when possible
        max_prequential_size: Largest ``data_size`` evaluated prequentially;
            larger incremental workloads fall back to window-based evaluation
            (for efficiency)

    Returns:
        A default-configured ``WindowBasedEvaluator`` for batch/QISK-style
        methods (those with ``fit_window`` or ``spsa_iterations``) and for
        non-incremental or too-large workloads; otherwise a default-configured
        ``PrequentialEvaluator`` for incremental methods (``partial_fit`` or
        ``learn_one``) when ``prefer_prequential`` is set.
    """
    # Methods requiring batch training (like QISK) need window-based evaluation
    if hasattr(method, 'fit_window') or hasattr(method, 'spsa_iterations'):
        return WindowBasedEvaluator()

    # Methods with incremental learning can use prequential
    is_incremental = hasattr(method, 'partial_fit') or hasattr(method, 'learn_one')
    if is_incremental and prefer_prequential and data_size <= max_prequential_size:
        return PrequentialEvaluator()

    # Default to window-based for batch methods and large incremental workloads
    return WindowBasedEvaluator()


def evaluate_with_protocol(X: np.ndarray, y: np.ndarray, method,
                          protocol: Optional[str] = None, **kwargs) -> Dict[str, Any]:
    """
    Convenience function to evaluate using specified or auto-selected protocol.

    Args:
        X: Feature matrix
        y: Label vector
        method: Method to evaluate
        protocol: 'prequential', 'window_based', or None for auto-selection
        **kwargs: Additional constructor arguments for the evaluator. With
            auto-selection, only the arguments accepted by the chosen
            evaluator's constructor are applied; the rest are ignored.

    Returns:
        Evaluation results. ``evaluation_protocol`` is the name reported by the
        evaluator itself ('prequential' or 'window_based'), matching the values
        accepted by ``protocol``.

    Raises:
        ValueError: If ``protocol`` is not one of the supported names.
    """
    if protocol == 'prequential':
        evaluator = PrequentialEvaluator(**kwargs)
    elif protocol == 'window_based':
        evaluator = WindowBasedEvaluator(**kwargs)
    elif protocol is None:
        evaluator = choose_evaluation_protocol(method, len(X))
        if kwargs:
            # The two evaluators accept different constructor arguments, so
            # forward only those the auto-selected one understands.
            import inspect
            accepted = inspect.signature(type(evaluator)).parameters
            applicable = {key: value for key, value in kwargs.items() if key in accepted}
            if applicable:
                evaluator = type(evaluator)(**applicable)
    else:
        raise ValueError(f"Unknown protocol: {protocol}")

    results = evaluator.evaluate(X, y, method)
    # Keep the evaluator's own tag (e.g. 'window_based'); deriving it from the
    # class name would yield 'windowbased'. Only fill it in when missing.
    results.setdefault(
        'evaluation_protocol',
        evaluator.__class__.__name__.lower().replace('evaluator', '')
    )

    return results


if __name__ == "__main__":
    # Example usage and smoke test of both evaluation protocols
    print("Testing evaluation protocols...")

    # Generate test data: a binary problem separated by x0 + x1 > 0
    np.random.seed(42)
    n_samples = 1000
    X = np.random.randn(n_samples, 4)
    y = (X[:, 0] + X[:, 1] > 0).astype(int)

    # Dummy method exposing both the incremental (partial_fit) and the batch
    # (fit) interface, so each protocol exercises a real training path:
    # prequential uses partial_fit, window-based uses fit.
    class DummyStreamingMethod:
        def __init__(self):
            self.classes_ = [0, 1]

        def fit(self, X, y):
            pass

        def partial_fit(self, X, y):
            pass

        def predict(self, X):
            return np.random.choice([0, 1], size=len(X))

    method = DummyStreamingMethod()

    # Test prequential evaluation
    print("\nTesting prequential evaluation...")
    results_preq = evaluate_with_protocol(X, y, method, protocol='prequential')
    print(f"Mean accuracy: {results_preq['mean_accuracy']:.3f}")
    print(f"Worst window: {results_preq['worst_window_accuracy']:.3f}")
    print(f"Protocol: {results_preq['evaluation_protocol']}")

    # Test window-based evaluation
    print("\nTesting window-based evaluation...")
    results_window = evaluate_with_protocol(X, y, method, protocol='window_based')
    print(f"Mean accuracy: {results_window['mean_accuracy']:.3f}")
    print(f"Worst window: {results_window['worst_window_accuracy']:.3f}")
    print(f"Protocol: {results_window['evaluation_protocol']}")

    # Only claim success if both protocols actually scored something
    if results_preq['n_predictions'] == 0 or results_window['n_windows'] == 0:
        raise SystemExit("\n❌ Evaluation protocol smoke test produced no results")

    print("\n✅ Evaluation protocols tested successfully!")