#!/usr/bin/env python
"""
Simplified experimental runner for QISK evaluation.
Depends only on NumPy, scikit-learn and the local baseline modules, and seeds NumPy's random generator so results are reproducible.
"""

import json
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np

# Import core modules
from enhanced_baselines import KTATunedBaseline
from simple_baselines import SimpleStreamingBaseline, evaluate_simple_baseline
from evaluation_protocols import WindowBasedEvaluator
from evaluation_protocols import WindowBasedEvaluator


class SimpleExperimentRunner:
    """Simple experiment runner with minimal dependencies.

    Evaluates a fixed set of baseline methods on synthetic concept-drift
    streams over several seeds, aggregates the per-seed metrics and writes
    them to ``<output_dir>/simple_results.json``.
    """

    def __init__(self, output_dir: Optional[str] = None):
        """Create the runner and ensure its output directory exists.

        Args:
            output_dir: Where result files are written. When omitted, a
                timestamped ``experimental_results_YYYYmmdd_HHMMSS``
                directory under the current working directory is used.
        """
        self.output_dir = output_dir or f"experimental_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        # parents=True so nested output paths (e.g. "runs/exp1") also work.
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)

    def generate_sea_data(self, n_samples: int = 1000, seed: int = 42) -> tuple:
        """Generate a SEA-style concept-drift stream.

        Features are drawn uniformly from [0, 10) in three dimensions; only
        the first two affect the label. A sample is labelled 1 when
        ``x0 + x1 <= threshold`` and 0 otherwise. The threshold is 8.0 before
        ``n_samples // 3``, 9.0 before ``2 * n_samples // 3`` and 7.0 after.

        Args:
            n_samples: Number of samples in the stream.
            seed: Seed passed to NumPy's global RNG.

        Returns:
            ``(X, y)``. ``X`` is a float array of shape ``(n_samples, 3)`` and
            ``y`` is an int array of shape ``(n_samples,)``.
        """
        # This seeds NumPy's *global* RNG rather than a local Generator.
        # NOTE(review): later unseeded np.random use (possibly inside the
        # baselines) inherits this state, so it is kept global for
        # reproducibility. Confirm before switching to a local Generator.
        np.random.seed(seed)

        X = np.random.uniform(0, 10, (n_samples, 3))

        # Drift points at 1/3 and 2/3 of the stream. np.select takes the
        # first matching condition, which mirrors an if/elif chain.
        positions = np.arange(n_samples)
        thresholds = np.select(
            [positions < n_samples // 3, positions < 2 * n_samples // 3],
            [8.0, 9.0],  # Concept 1, Concept 2
            default=7.0,  # Concept 3
        )
        y = ((X[:, 0] + X[:, 1]) <= thresholds).astype(int)

        return X, y

    def generate_hyperplane_data(self, n_samples: int = 1000, seed: int = 42) -> tuple:
        """Generate a rotating-hyperplane stream.

        Features are uniform on [-1, 1) in four dimensions. Sample ``i`` is
        labelled by the sign of ``X[i] . w_i``, with
        ``w_i = [cos a, sin a, 0.1, 0.1]`` and ``a = 2*pi*i/n_samples``.
        The hyperplane therefore makes one full rotation over the stream.

        Args:
            n_samples: Number of samples in the stream.
            seed: Seed passed to NumPy's global RNG (see generate_sea_data).

        Returns:
            ``(X, y)``. ``X`` is a float array of shape ``(n_samples, 4)`` and
            ``y`` is an int array of shape ``(n_samples,)``.
        """
        np.random.seed(seed)

        X = np.random.uniform(-1, 1, (n_samples, 4))
        y = np.zeros(n_samples, dtype=int)

        # Kept as a per-sample loop so the floating-point results (and hence
        # labels for points near the boundary) match previously generated data.
        for i in range(n_samples):
            angle = 2 * np.pi * i / n_samples
            w = np.array([np.cos(angle), np.sin(angle), 0.1, 0.1])
            y[i] = 1 if np.dot(X[i], w) > 0 else 0

        return X, y

    def evaluate_enhanced_baseline(self, baseline: "KTATunedBaseline", X: np.ndarray, y: np.ndarray) -> Dict[str, Any]:
        """Evaluate an enhanced baseline with the window-based protocol.

        Uses windows of 200 samples with an 80/20 train split per window.
        Returns whatever dict ``WindowBasedEvaluator.evaluate`` produces.
        """
        evaluator = WindowBasedEvaluator(window_size=200, train_ratio=0.8)
        return evaluator.evaluate(X, y, baseline)

    def _evaluate_method(self, method_config: Dict[str, Any], X: np.ndarray, y: np.ndarray) -> Dict[str, Any]:
        """Build the method described by ``method_config`` and evaluate it on (X, y).

        Any ``type`` other than ``'simple_baseline'`` is treated as an
        enhanced (KTA-tuned) baseline. Exceptions from construction or
        evaluation propagate to the caller.
        """
        if method_config['type'] == 'simple_baseline':
            method = SimpleStreamingBaseline(method_config['name'])
            return evaluate_simple_baseline(method, X, y)

        method = KTATunedBaseline(
            kernel_type=method_config['kernel'],
            n_features=min(4, X.shape[1]),
            spsa_iterations=10,  # Reduced for speed
        )
        result = self.evaluate_enhanced_baseline(method, X, y)
        if 'worst_window_accuracy' not in result:
            # Do NOT substitute mean accuracy (or a placeholder) here: the
            # summary reports worst-window accuracy, and a fabricated value
            # would be indistinguishable from a measured one.
            print("    Warning: evaluator did not report worst_window_accuracy")
        return result

    @staticmethod
    def _format_metric(value: Any) -> str:
        """Format a metric to 3 decimals, or 'n/a' if it is missing or non-numeric."""
        if isinstance(value, (int, float, np.integer, np.floating)):
            return f"{value:.3f}"
        return "n/a"

    def run_experiments(self, n_seeds: int = 3) -> Dict[str, Any]:
        """Run every method on every synthetic dataset for ``n_seeds`` seeds.

        Seed ``k`` generates data with ``seed=k + 42``. Data is regenerated
        (and the global RNG re-seeded) for each method, so every method
        starts from the same RNG state. A failing seed is reported and
        skipped rather than aborting the sweep.

        Returns:
            ``{dataset_name: {method_name: aggregated_metrics}}``. Methods
            with no successful seed are omitted. The same structure is
            written to ``<output_dir>/simple_results.json``.
        """
        datasets = {
            'sea': {'generator': self.generate_sea_data, 'n_samples': 1000},
            'rotating_hyperplane': {'generator': self.generate_hyperplane_data, 'n_samples': 1000}
        }

        # Define methods to test
        methods = {
            'rbf_svm': {'type': 'simple_baseline', 'name': 'svm'},
            'random_forest': {'type': 'simple_baseline', 'name': 'random_forest'},
            'periodic_kernel': {'type': 'enhanced_baseline', 'kernel': 'periodic'},
            'cosine_kernel': {'type': 'enhanced_baseline', 'kernel': 'cosine'},
            'rff_kernel': {'type': 'enhanced_baseline', 'kernel': 'rff'}
        }

        all_results = {}

        for dataset_name, dataset_config in datasets.items():
            print(f"\n{'='*50}")
            print(f"Running experiments on {dataset_name}")
            print(f"{'='*50}")

            dataset_results = {}

            for method_name, method_config in methods.items():
                print(f"\nTesting {method_name}...")

                method_results = []

                for seed in range(n_seeds):
                    print(f"  Seed {seed + 1}/{n_seeds}")

                    try:
                        X, y = dataset_config['generator'](
                            dataset_config['n_samples'], seed=seed + 42
                        )
                        result = self._evaluate_method(method_config, X, y)
                    except Exception as e:
                        # Best-effort sweep: one failing seed/method must not
                        # abort the remaining experiments.
                        print(f"    Error: {type(e).__name__}: {e}")
                        continue

                    method_results.append(result)

                    # Printed outside the try so a missing metric cannot be
                    # misreported as a failed seed.
                    print(f"    Accuracy: {self._format_metric(result.get('mean_accuracy'))}, "
                          f"Worst window: {self._format_metric(result.get('worst_window_accuracy'))}")

                # Aggregate results across seeds
                if method_results:
                    dataset_results[method_name] = self.aggregate_results(method_results)

            all_results[dataset_name] = dataset_results

        # Save results
        results_file = os.path.join(self.output_dir, 'simple_results.json')
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(all_results, f, indent=2, default=str)

        print(f"\n{'='*50}")
        print("Experiments completed!")
        print(f"Results saved to: {results_file}")
        print(f"{'='*50}")

        return all_results

    @staticmethod
    def _is_metric_value(value: Any) -> bool:
        """True for real numeric scalars (Python or NumPy), excluding booleans."""
        if isinstance(value, (bool, np.bool_)):
            return False
        return isinstance(value, (int, float, np.integer, np.floating))

    def aggregate_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate per-seed metric dicts into mean / std / standard error.

        For each of ``mean_accuracy``, ``worst_window_accuracy`` and
        ``macro_f1`` present as a numeric scalar in at least one result, this
        computes:

        - ``mean``
        - ``std``: sample standard deviation with ddof=1, or 0.0 for a single value
        - ``se``: std / sqrt(n)
        - ``values``: the raw per-seed values

        All numbers are plain Python floats, so the output is
        JSON-serializable. ``n_seeds`` counts every input result, including
        ones missing a metric. An empty input returns ``{}``.
        """
        if not results:
            return {}

        metrics = ['mean_accuracy', 'worst_window_accuracy', 'macro_f1']
        aggregated = {}

        for metric in metrics:
            values = [
                float(result[metric])
                for result in results
                if metric in result and self._is_metric_value(result[metric])
            ]
            if not values:
                continue

            n = len(values)
            std = float(np.std(values, ddof=1)) if n > 1 else 0.0
            aggregated[metric] = {
                'mean': float(np.mean(values)),
                'std': std,
                'se': float(std / np.sqrt(n)) if n > 1 else 0.0,
                'values': values,
            }

        aggregated['n_seeds'] = len(results)
        return aggregated

    def print_summary(self, results: Dict[str, Any]):
        """Print worst-window accuracy (mean ± SE across seeds) per dataset and method.

        Methods whose results lack a worst-window accuracy are listed as
        ``n/a`` rather than silently omitted.
        """
        print("\n" + "="*60)
        print("EXPERIMENTAL RESULTS SUMMARY")
        print("="*60)

        for dataset_name, dataset_results in results.items():
            print(f"\n{dataset_name.upper().replace('_', ' ')} Dataset:")
            print("-" * 40)

            for method_name, method_results in dataset_results.items():
                wwa = method_results.get('worst_window_accuracy')
                if wwa is None:
                    print(f"  {method_name:20s}: n/a (worst-window accuracy not reported)")
                elif isinstance(wwa, dict):
                    print(f"  {method_name:20s}: {wwa['mean']:.3f} ± {wwa['se']:.3f}")
                else:
                    print(f"  {method_name:20s}: {wwa:.3f}")

        print("="*60)


def main():
    """Run the synthetic-stream experiments and print their summary.

    Returns:
        The nested results dict produced by
        ``SimpleExperimentRunner.run_experiments``.
    """
    separator = "=" * 50
    print("🔬 QISK Simple Experimental Runner")
    print(separator)

    experiment_runner = SimpleExperimentRunner()

    # Only three seeds, to keep the run short.
    collected = experiment_runner.run_experiments(n_seeds=3)
    experiment_runner.print_summary(collected)

    return collected


if __name__ == "__main__":
    main()