"""
Comprehensive experimental runner for QISK framework.
Runs experiments on all datasets with proper seed control and statistical analysis.
"""

import json
import os
import time
from pathlib import Path
import numpy as np
from datetime import datetime
from typing import Dict, List, Tuple

import sys
sys.path.append('..')

from qisk_implementation import QISK
from real_world_datasets import get_real_world_datasets
from baselines.streaming_baselines import StreamingBaseline
from enhanced_baselines import KTATunedBaseline, evaluate_baseline_on_window


class ComprehensiveExperimentRunner:
    """Main experimental runner with full reproducibility."""
    
    def __init__(self, seeds: List[int] = None, output_dir: str = None):
        self.seeds = seeds or [42, 123, 456, 789, 1011]
        self.output_dir = output_dir or f"experimental_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        Path(self.output_dir).mkdir(exist_ok=True)
        
    def run_all_experiments(self) -> Dict:
        """Run comprehensive experiments on all datasets."""
        
        # Get all datasets - both simple synthetic and realistic surrogates
        datasets = {
            "sea": {"type": "sea", "n_samples": 3000, "drift_positions": [1000, 2000]},
            "rotating_hyperplane": {"type": "rotating_hyperplane", "n_samples": 3000},
            "electricity_surrogate": {"type": "electricity", "n_samples": 5000},
            "airlines_surrogate": {"type": "airlines", "n_samples": 8000}, 
            "covertype_surrogate": {"type": "covertype", "n_samples": 6000},
            "sensoring_surrogate": {"type": "sensoring", "n_samples": 4000},
            "poker_surrogate": {"type": "poker", "n_samples": 5000}
        }
        
        methods = {
            "rbf_svm_standard": {"type": "baseline", "method": "svm"},
            "rbf_svm_dro_lite": {"type": "baseline", "method": "svm_dro"},
            "kmm_baseline": {"type": "baseline", "method": "kmm_svm"},
            "ulsif_baseline": {"type": "baseline", "method": "ulsif_svm"},
            "adwin_drift_aware": {"type": "baseline", "method": "adwin"},
            "adaptive_random_forest": {"type": "baseline", "method": "arf"},
            "hoeffding_adaptive_tree": {"type": "baseline", "method": "hat"},
            "periodic_kernel_kta": {"type": "enhanced_baseline", "method": "periodic"},
            "cosine_kernel_kta": {"type": "enhanced_baseline", "method": "cosine"},
            "rff_kernel_kta": {"type": "enhanced_baseline", "method": "rff"},
            "qisk": {"type": "qisk", "method": "full"}
        }
        
        all_results = {}
        
        for dataset_name, dataset_config in datasets.items():
            print(f"\n{'='*60}")
            print(f"Running experiments on {dataset_name}")
            print(f"{'='*60}")
            
            dataset_results = {}
            
            for method_name, method_config in methods.items():
                print(f"\nTesting {method_name} on {dataset_name}...")
                
                method_results = []
                
                for seed_idx, seed in enumerate(self.seeds):
                    print(f"  Seed {seed_idx+1}/{len(self.seeds)}: {seed}")
                    
                    try:
                        result = self._run_single_experiment(
                            dataset_name, dataset_config,
                            method_name, method_config,
                            seed
                        )
                        method_results.append(result)
                        
                    except Exception as e:
                        print(f"    Error with seed {seed}: {e}")
                        continue
                
                if method_results:
                    # Aggregate results across seeds
                    dataset_results[method_name] = self._aggregate_results(method_results)
                    
            all_results[dataset_name] = dataset_results
            
        # Save comprehensive results
        output_file = os.path.join(self.output_dir, "comprehensive_results.json")
        with open(output_file, 'w') as f:
            json.dump(all_results, f, indent=2)
            
        print(f"\n{'='*60}")
        print(f"All experiments completed!")
        print(f"Results saved to: {output_file}")
        print(f"{'='*60}")
        
        return all_results
    
    def _run_single_experiment(self, dataset_name: str, dataset_config: Dict,
                               method_name: str, method_config: Dict,
                               seed: int) -> Dict:
        """Build one dataset and one method, then evaluate the method on it.

        Args:
            dataset_name: Display name of the dataset (not used by the body;
                part of the call signature).
            dataset_config: Must contain ``"type"``. ``"sea"`` and
                ``"rotating_hyperplane"`` are generated locally; the
                surrogate types are looked up by position in the list
                returned by ``get_real_world_datasets()``. An optional
                ``"n_samples"`` caps the number of surrogate samples.
            method_name: Display name of the method (not used by the body;
                part of the call signature).
            method_config: ``"type"`` selects QISK (``"qisk"``), a KTA-tuned
                kernel baseline (``"enhanced_baseline"``), or otherwise a
                ``StreamingBaseline``; ``"method"`` is forwarded to the
                baseline constructors.
            seed: Seed for NumPy's global random state.

        Returns:
            The metrics dict produced by ``_window_based_evaluate`` (QISK and
            enhanced baselines) or ``_prequential_evaluate`` (everything else).

        Raises:
            ValueError: If ``dataset_config["type"]`` is not a known dataset.
        """
        # Seed the global NumPy RNG so the run is repeatable.
        np.random.seed(seed)

        dataset_type = dataset_config["type"]
        if dataset_type in ("sea", "rotating_hyperplane"):
            X, y = self._generate_synthetic_data(dataset_config, seed)
        else:
            # Position of each surrogate in get_real_world_datasets().
            surrogate_index = {
                "electricity": 0,
                "airlines": 1,
                "covertype": 2,
                "sensoring": 3,
                "poker": 4,
            }
            if dataset_type not in surrogate_index:
                raise ValueError(f"Unknown dataset type: {dataset_config['type']}")

            dataset_gen = get_real_world_datasets()[surrogate_index[dataset_type]]
            X, y = dataset_gen.load_data()

            n_samples = dataset_config.get("n_samples", len(X))
            if n_samples < len(X):
                # Re-seed so the subsample does not depend on how much
                # randomness load_data() consumed.
                np.random.seed(seed)
                indices = np.random.choice(len(X), size=n_samples, replace=False)
                # choice() returns the indices in random order; sort them so
                # the subsample keeps the original stream order. Both
                # evaluation protocols are order-sensitive (test-then-train
                # and consecutive windows).
                indices.sort()
                X, y = X[indices], y[indices]

        method_type = method_config["type"]
        if method_type == "qisk":
            method = QISK(
                n_qubits=min(4, X.shape[1]),
                n_anchors=16,
                spsa_params={'a': 0.1, 'c': 0.01, 'iterations': 10}
            )
        elif method_type == "enhanced_baseline":
            # Local import kept so this method does not rely on module-level state.
            from enhanced_baselines import KTATunedBaseline
            method = KTATunedBaseline(
                kernel_type=method_config["method"],
                n_features=min(4, X.shape[1]),
                spsa_iterations=25
            )
        else:
            method = StreamingBaseline(method_config["method"])

        if method_type in ("enhanced_baseline", "qisk"):
            # Batch learners: fit/predict on successive windows.
            return self._window_based_evaluate(X, y, method, seed)
        # Incremental learners: test-then-train on every sample.
        return self._prequential_evaluate(X, y, method, seed)
    
    def _generate_synthetic_data(self, config: Dict, seed: int) -> Tuple[np.ndarray, np.ndarray]:
        """Generate synthetic data for SEA and Rotating Hyperplane."""
        np.random.seed(seed)
        n_samples = config["n_samples"]
        
        if config["type"] == "sea":
            # SEA generator with concept drift
            X = np.random.uniform(0, 10, (n_samples, 3))
            y = np.zeros(n_samples, dtype=int)
            
            drift_positions = config.get("drift_positions", [n_samples // 3, 2 * n_samples // 3])
            
            for i in range(n_samples):
                # Determine current concept
                if i < drift_positions[0]:
                    threshold = 8.0  # Concept 1
                elif i < drift_positions[1]:
                    threshold = 9.0  # Concept 2  
                else:
                    threshold = 7.0  # Concept 3
                    
                y[i] = 1 if (X[i, 0] + X[i, 1]) <= threshold else 0
                
        elif config["type"] == "rotating_hyperplane":
            # Rotating hyperplane
            X = np.random.uniform(-1, 1, (n_samples, 4))
            y = np.zeros(n_samples, dtype=int)
            
            for i in range(n_samples):
                # Rotate hyperplane over time
                angle = 2 * np.pi * i / n_samples
                w = np.array([np.cos(angle), np.sin(angle), 0.1, 0.1])
                y[i] = 1 if np.dot(X[i], w) > 0 else 0
                
        # Convert to binary classification
        y = y.astype(int)
        return X, y
    
    def _prequential_evaluate(self, X: np.ndarray, y: np.ndarray, 
                            method, seed: int) -> Dict:
        """
        Proper prequential evaluation protocol (test-then-train).
        
        NOTE: This is true prequential evaluation where each sample is:
        1. First used for prediction (testing)  
        2. Then used for training (partial_fit)
        
        This differs from window-based evaluation used in other scripts.
        """
        
        n_samples = len(X)
        window_size = 200
        predictions = []
        true_labels = []
        window_accuracies = []
        
        # Initialize method
        method.reset() if hasattr(method, 'reset') else None
        
        for i in range(n_samples):
            # Predict first (test-then-train)
            if i > 0:  # Skip prediction for first sample
                pred = method.predict(X[i:i+1])[0]
                predictions.append(pred)
                true_labels.append(y[i])
            
            # Then train on current sample
            method.partial_fit(X[i:i+1], y[i:i+1])
            
            # Calculate window-based metrics
            if len(predictions) >= window_size:
                window_start = max(0, len(predictions) - window_size)
                window_preds = predictions[window_start:]
                window_true = true_labels[window_start:]
                
                window_acc = np.mean(np.array(window_preds) == np.array(window_true))
                window_accuracies.append(window_acc)
        
        # Calculate comprehensive metrics
        if len(predictions) == 0:
            return {"error": "No predictions made"}
            
        predictions = np.array(predictions)
        true_labels = np.array(true_labels)
        
        # Overall metrics
        mean_accuracy = np.mean(predictions == true_labels)
        worst_window_accuracy = np.min(window_accuracies) if window_accuracies else mean_accuracy
        
        # Macro F1
        from sklearn.metrics import f1_score
        macro_f1 = f1_score(true_labels, predictions, average='macro', zero_division=0)
        
        return {
            "mean_accuracy": mean_accuracy,
            "worst_window_accuracy": worst_window_accuracy,
            "macro_f1": macro_f1,
            "n_predictions": len(predictions),
            "n_windows": len(window_accuracies),
            "seed": seed
        }
    
    def _window_based_evaluate(self, X: np.ndarray, y: np.ndarray, 
                             method, seed: int) -> Dict:
        """Window-based evaluation protocol for batch methods."""
        
        n_samples = len(X)
        window_size = 200
        window_results = []
        
        # Process data in sliding windows
        n_windows = (n_samples - window_size) // (window_size // 2) + 1
        
        for i in range(n_windows):
            start_idx = i * (window_size // 2)
            end_idx = min(start_idx + window_size, n_samples)
            
            if end_idx - start_idx < 50:  # Skip too small windows
                break
                
            X_window = X[start_idx:end_idx]
            y_window = y[start_idx:end_idx]
            
            # Skip windows with insufficient class diversity
            if len(np.unique(y_window)) < 2:
                continue
            
            # Split window into train/test (80/20)
            split_idx = int(0.8 * len(X_window))
            X_train = X_window[:split_idx]
            y_train = y_window[:split_idx]
            X_test = X_window[split_idx:]
            y_test = y_window[split_idx:]
            
            if len(X_test) == 0:
                continue
                
            try:
                # Fit method on training data
                if hasattr(method, 'fit'):
                    method.fit(X_train, y_train)
                    predictions = method.predict(X_test)
                else:
                    # Handle methods without standard fit interface
                    continue
                
                # Compute metrics
                accuracy = np.mean(predictions == y_test)
                balanced_acc = balanced_accuracy_score(y_test, predictions)
                macro_f1 = f1_score(y_test, predictions, average='macro', zero_division=0)
                
                window_results.append({
                    'window': i,
                    'accuracy': accuracy,
                    'balanced_accuracy': balanced_acc,
                    'macro_f1': macro_f1
                })
                
            except Exception as e:
                print(f"Warning: Window {i} failed: {e}")
                continue
        
        if len(window_results) == 0:
            return {"error": "No valid windows processed"}
        
        # Aggregate window results
        accuracies = [r['accuracy'] for r in window_results]
        f1_scores = [r['macro_f1'] for r in window_results]
        
        return {
            "mean_accuracy": np.mean(accuracies),
            "worst_window_accuracy": np.min(accuracies),
            "macro_f1": np.mean(f1_scores),
            "n_predictions": sum(len(window_results)),
            "n_windows": len(window_results),
            "seed": seed
        }
    
    def _aggregate_results(self, results: List[Dict]) -> Dict:
        """Summarise per-seed metric dicts into mean / std / standard error.

        For each of ``mean_accuracy``, ``worst_window_accuracy`` and
        ``macro_f1`` the values present in ``results`` (string values are
        ignored) are collected; a metric with no usable value is omitted.
        ``std`` is the sample standard deviation (``ddof=1``) and ``se`` is
        ``std / sqrt(n)``; both are ``0.0`` when only one value is available.

        Args:
            results: One metrics dict per seed.

        Returns:
            ``{}`` for empty input; otherwise the per-metric summaries plus
            ``n_seeds`` (number of input dicts) and ``seeds_used`` (the
            ``"seed"`` entry of every input dict, ``None`` when missing).
        """
        if not results:
            return {}

        summary = {}
        for metric_name in ("mean_accuracy", "worst_window_accuracy", "macro_f1"):
            samples = []
            for run in results:
                if metric_name not in run:
                    continue
                observed = run[metric_name]
                if isinstance(observed, str):
                    continue
                samples.append(observed)

            if not samples:
                continue

            if len(samples) > 1:
                spread = np.std(samples, ddof=1)
                std_err = np.std(samples, ddof=1) / np.sqrt(len(samples))
            else:
                # A single observation carries no spread information.
                spread = 0.0
                std_err = 0.0

            summary[metric_name] = {
                "mean": np.mean(samples),
                "std": spread,
                "se": std_err,
                "values": samples,
            }

        # Bookkeeping about which runs went into the summary.
        summary["n_seeds"] = len(results)
        summary["seeds_used"] = [run.get("seed", None) for run in results]

        return summary


def main():
    """Run the full experiment sweep and print a per-dataset summary.

    Results go to ``../data/experimental_results`` when a ``../data``
    directory exists, otherwise to ``data/experimental_results`` under the
    current working directory. The summary lists the mean and standard
    error of the worst-window accuracy for every method that reported it.
    """
    print("QISK Comprehensive Experimental Runner")
    print("=" * 50)

    output_dir = (
        "../data/experimental_results"
        if Path("../data").exists()
        else "data/experimental_results"
    )
    # The target is nested; create missing parents up front so the runner
    # can always write into it.
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Create runner with the standard seeds.
    runner = ComprehensiveExperimentRunner(
        seeds=[42, 123, 456, 789, 1011],
        output_dir=output_dir
    )

    results = runner.run_all_experiments()

    print("\nEXPERIMENT SUMMARY:")
    for dataset_name, dataset_results in results.items():
        print(f"\n{dataset_name}:")
        for method_name, method_results in dataset_results.items():
            if "worst_window_accuracy" in method_results:
                wwa = method_results["worst_window_accuracy"]
                print(f"  {method_name}: {wwa['mean']:.3f} ± {wwa['se']:.3f}")


# Run the experiment sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()