#!/usr/bin/env python
"""
Real-world and synthetic streaming datasets for QISK evaluation.
Provides consistent streaming data generators with concept drift.
"""

import numpy as np
from typing import Iterator, Tuple, List, Dict, Any
from sklearn.datasets import make_classification
import warnings

# Silence warnings for cleaner output. NOTE: with no category/module argument
# this 'ignore' filter applies to every warning in the process, not only sklearn's.
warnings.filterwarnings('ignore')


class StreamingDataset:
    """Base class for streaming datasets with concept drift.

    Subclasses override :meth:`stream`; :meth:`get_batch` is implemented on
    top of it. The constructor stores its arguments under attributes of the
    same name and creates ``rng``, a ``numpy.random.RandomState`` seeded with
    ``random_state``.
    """

    def __init__(self, name: str, n_samples: int = 5000, n_features: int = 4,
                 n_classes: int = 2, random_state: int = 42):
        """Store the dataset description and seed the random generator.

        Args:
            name: Identifier of the dataset.
            n_samples: Number of samples in the stream.
            n_features: Number of features per sample.
            n_classes: Number of distinct labels.
            random_state: Seed used to build ``self.rng``.
        """
        self.name = name
        self.n_samples = n_samples
        self.n_features = n_features
        self.n_classes = n_classes
        self.random_state = random_state
        self.rng = np.random.RandomState(random_state)

    def stream(self) -> Iterator[Tuple[np.ndarray, int]]:
        """Yield ``(features, label)`` pairs.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_batch(self, batch_size: int) -> Tuple[np.ndarray, np.ndarray]:
        """Collect up to ``batch_size`` points from a new ``stream()`` iterator.

        Every call starts a new iterator from ``stream()``; it does not resume
        where a previous batch stopped.

        Args:
            batch_size: Maximum number of points to collect. Fewer are
                returned when the stream ends early.

        Returns:
            ``(X, y)`` arrays holding the stacked features and the labels.
            When no point was collected, ``X`` has shape ``(0, n_features)``
            and ``y`` is an empty integer array, so downstream code such as
            ``np.bincount(y)`` or ``X.shape[1]`` keeps working.
        """
        X, y = [], []
        # zip stops as soon as either the counter or the stream is exhausted,
        # which replaces the manual next()/StopIteration handling.
        for _, (x, label) in zip(range(batch_size), self.stream()):
            X.append(x)
            y.append(label)

        if not X:
            # np.array([]) would be 1-D float64 for both outputs; keep the
            # feature axis and an integer label dtype instead.
            return np.empty((0, self.n_features)), np.empty(0, dtype=int)
        return np.array(X), np.array(y)


class SEADataset(StreamingDataset):
    """
    SEA dataset with concept drift.
    Based on the classic concept drift benchmark.

    Each sample has three features drawn uniformly from [0, 1). The label is
    1 when ``x[0] + x[1]`` is at or below the active concept's threshold; the
    third feature does not influence the label.
    """

    # Threshold on x[0] + x[1] for each concept id
    # (0: original concept, 1: shifted threshold, 2: different shift).
    _THRESHOLDS = {0: 0.7, 1: 0.8, 2: 0.6}

    def __init__(self, n_samples: int = 5000, noise_level: float = 0.1,
                 drift_points: List[int] = None, random_state: int = 42):
        """Configure the stream.

        Args:
            n_samples: Number of samples produced by ``stream()``.
            noise_level: Probability of flipping each label.
            drift_points: Sample indices at which the concept changes,
                expected in ascending order. ``None`` selects one third and
                two thirds of ``n_samples``; an empty list means no drift.
            random_state: Seed for the random generator.
        """
        super().__init__("sea", n_samples, n_features=3, n_classes=2, random_state=random_state)
        self.noise_level = noise_level
        # Test against None explicitly: `drift_points or default` would also
        # discard an intentionally empty list (a stream without drift).
        if drift_points is None:
            drift_points = [n_samples // 3, 2 * n_samples // 3]
        self.drift_points = drift_points

    def _sea_concept(self, x: np.ndarray, concept_id: int) -> int:
        """Return 1 if ``x[0] + x[1]`` is within the concept's threshold, else 0.

        Raises:
            KeyError: If ``concept_id`` is not 0, 1 or 2.
        """
        return int((x[0] + x[1]) <= self._THRESHOLDS[concept_id])

    def stream(self) -> Iterator[Tuple[np.ndarray, int]]:
        """Generate SEA concept drift stream.

        Yields:
            ``(x, y)`` where ``x`` is a length-3 array and ``y`` is 0 or 1.
        """
        n_concepts = len(self._THRESHOLDS)
        for i in range(self.n_samples):
            # The active concept is one plus the position of the last listed
            # drift point that step i has reached (0 before any of them).
            concept_id = 0
            for position, drift_point in enumerate(self.drift_points, start=1):
                if i >= drift_point:
                    concept_id = position

            x = self.rng.uniform(0, 1, 3)
            # Wrap around so more drift points than concepts cycles through them.
            y = self._sea_concept(x, concept_id % n_concepts)

            # Label noise: flip with probability noise_level.
            if self.rng.random() < self.noise_level:
                y = 1 - y

            yield x, y


class RotatingHyperplaneDataset(StreamingDataset):
    """
    Gradual-drift stream whose linear decision boundary turns over time.

    A unit-norm weight vector is drawn once at construction; at step ``t`` it
    is rotated by ``rotation_speed * t`` radians in the plane of the first two
    features, and a sample is labelled 1 when its dot product with the
    rotated weights is positive.
    """

    def __init__(self, n_samples: int = 5000, n_features: int = 4,
                 rotation_speed: float = 0.001, noise_level: float = 0.05,
                 random_state: int = 42):
        """Set up the stream and draw the starting hyperplane.

        Args:
            n_samples: Number of samples produced by ``stream()``.
            n_features: Dimensionality of each sample.
            rotation_speed: Rotation angle (radians) added per time step.
            noise_level: Probability of flipping each label.
            random_state: Seed for the random generator.
        """
        super().__init__("rotating_hyperplane", n_samples, n_features, 2, random_state)
        self.rotation_speed = rotation_speed
        self.noise_level = noise_level

        # Starting hyperplane: a Gaussian direction scaled to unit length.
        direction = self.rng.normal(0, 1, n_features)
        self.initial_weights = direction / np.linalg.norm(direction)

    def _get_weights(self, time_step: int) -> np.ndarray:
        """Return the initial weights rotated to their position at ``time_step``."""
        return self._rotation_matrix(self.rotation_speed * time_step) @ self.initial_weights

    def _rotation_matrix(self, angle: float) -> np.ndarray:
        """Return an identity matrix whose top-left 2x2 block rotates by ``angle``.

        With fewer than two features the plain identity is returned.
        """
        dim = self.n_features
        matrix = np.eye(dim)
        if dim < 2:
            return matrix
        cos_a = np.cos(angle)
        sin_a = np.sin(angle)
        matrix[:2, :2] = [[cos_a, -sin_a],
                          [sin_a, cos_a]]
        return matrix

    def stream(self) -> Iterator[Tuple[np.ndarray, int]]:
        """Yield ``(x, y)`` pairs labelled by the current hyperplane.

        ``x`` is drawn from a standard Gaussian; ``y`` is 0 or 1 and is
        flipped with probability ``noise_level``.
        """
        for step in range(self.n_samples):
            point = self.rng.normal(0, 1, self.n_features)

            # Which side of the (rotated) hyperplane the point falls on.
            on_positive_side = np.dot(point, self._get_weights(step)) > 0
            label = int(on_positive_side)

            # Label noise.
            flip = self.rng.random() < self.noise_level
            yield point, (1 - label if flip else label)


class SineDataset(StreamingDataset):
    """
    Periodic-drift stream with a sine-modulated decision threshold.

    A sample is labelled 1 when its first feature exceeds
    ``0.5 + 0.3 * sin(2 * pi * frequency * i)``, where ``i`` is the sample
    index.
    """

    def __init__(self, n_samples: int = 5000, n_features: int = 2,
                 frequency: float = 0.01, noise_level: float = 0.1,
                 random_state: int = 42):
        """Configure the stream.

        Args:
            n_samples: Number of samples produced by ``stream()``.
            n_features: Dimensionality of each sample.
            frequency: Threshold oscillations per sample index.
            noise_level: Probability of flipping each label.
            random_state: Seed for the random generator.
        """
        super().__init__("sine", n_samples, n_features, 2, random_state)
        self.frequency = frequency
        self.noise_level = noise_level

    def stream(self) -> Iterator[Tuple[np.ndarray, int]]:
        """Yield ``(x, y)`` pairs; ``x`` is uniform on [-1, 1), ``y`` is 0 or 1."""
        # Angular step per sample index; constant for the whole stream.
        angular_step = 2 * np.pi * self.frequency

        for step in range(self.n_samples):
            point = self.rng.uniform(-1, 1, self.n_features)

            # The boundary on the first feature oscillates with the step index.
            boundary = 0.5 + 0.3 * np.sin(angular_step * step)
            label = int(point[0] > boundary)

            # Label noise.
            flip = self.rng.random() < self.noise_level
            yield point, (1 - label if flip else label)


class MixedDataset(StreamingDataset):
    """
    Mixed synthetic dataset combining multiple concept types.

    The stream is the concatenation of four chunks, each produced by
    ``sklearn.datasets.make_classification`` with a different seed, label
    noise and class separation.
    """

    def __init__(self, n_samples: int = 5000, n_features: int = 4,
                 n_informative: int = 3, n_clusters_per_class: int = 2,
                 random_state: int = 42):
        """Configure the stream.

        Args:
            n_samples: Total number of samples produced by ``stream()``.
            n_features: Number of features per sample.
            n_informative: Passed to ``make_classification``.
            n_clusters_per_class: Passed to ``make_classification``.
            random_state: Base seed; chunk ``k`` uses ``random_state + k``.
        """
        super().__init__("mixed", n_samples, n_features, 2, random_state)
        self.n_informative = n_informative
        self.n_clusters_per_class = n_clusters_per_class

    def stream(self) -> Iterator[Tuple[np.ndarray, int]]:
        """Generate mixed concept stream with sklearn make_classification.

        Yields:
            ``(x, y)`` where ``x`` is one generated row and ``y`` its label
            as a plain ``int``. Exactly ``n_samples`` pairs are produced.
        """
        n_chunks = 4
        # Integer division alone would drop up to n_chunks - 1 samples; the
        # leftover goes to the last chunk so the earlier ones keep their size.
        base_size, leftover = divmod(self.n_samples, n_chunks)

        # make_classification defaults to n_redundant=2 and rejects
        # n_informative + n_redundant + n_repeated > n_features, which the
        # class defaults (4 features, 3 informative) would violate. Use the
        # default of 2 when there is room and fewer otherwise.
        n_redundant = max(0, min(2, self.n_features - self.n_informative))

        for chunk in range(n_chunks):
            chunk_size = base_size + (leftover if chunk == n_chunks - 1 else 0)
            if chunk_size == 0:
                continue

            # Vary parameters slightly for each chunk (concept drift)
            flip_y = 0.05 + 0.02 * chunk
            class_sep = 0.8 + 0.1 * np.sin(chunk)

            X_chunk, y_chunk = make_classification(
                n_samples=chunk_size,
                n_features=self.n_features,
                n_informative=self.n_informative,
                n_redundant=n_redundant,
                n_clusters_per_class=self.n_clusters_per_class,
                flip_y=flip_y,
                class_sep=class_sep,
                random_state=self.random_state + chunk,
            )

            for features, label in zip(X_chunk, y_chunk):
                # Plain int keeps the label type consistent with the declared
                # return annotation.
                yield features, int(label)


def get_real_world_datasets(include_large: bool = False) -> List[StreamingDataset]:
    """
    Build the list of streaming datasets used for evaluation.

    Args:
        include_large: When true, append the sine and mixed datasets plus
            10000-sample SEA and rotating-hyperplane variants (seed 123) to
            the two default 5000-sample datasets.

    Returns:
        List of StreamingDataset instances
    """
    core = [
        SEADataset(n_samples=5000, random_state=42),
        RotatingHyperplaneDataset(n_samples=5000, random_state=42),
    ]
    if not include_large:
        return core

    extended = [
        SineDataset(n_samples=5000, random_state=42),
        MixedDataset(n_samples=5000, random_state=42),
        # Bigger variants for the thorough evaluation run.
        SEADataset(n_samples=10000, random_state=123),
        RotatingHyperplaneDataset(n_samples=10000, random_state=123),
    ]
    return core + extended


def get_dataset_by_name(name: str, n_samples: int = 5000, **kwargs) -> StreamingDataset:
    """Instantiate the dataset registered under ``name``.

    Args:
        name: One of ``'sea'``, ``'rotating_hyperplane'``, ``'sine'``, ``'mixed'``.
        n_samples: Forwarded to the dataset constructor.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Raises:
        ValueError: If ``name`` is not a registered dataset.
    """
    registry = {
        'sea': SEADataset,
        'rotating_hyperplane': RotatingHyperplaneDataset,
        'sine': SineDataset,
        'mixed': MixedDataset,
    }

    dataset_cls = registry.get(name)
    if dataset_cls is None:
        available = list(registry)
        raise ValueError(f"Unknown dataset: {name}. Available: {available}")

    return dataset_cls(n_samples=n_samples, **kwargs)


if __name__ == "__main__":
    # Demo: print a short summary and a 10-point sample for each default dataset.
    print("QISK Streaming Datasets Demo")
    print("=" * 40)

    demo_sets = get_real_world_datasets()

    for ds in demo_sets:
        print(f"\n📊 {ds.name.upper()} Dataset")
        summary = (
            ("Samples", ds.n_samples),
            ("Features", ds.n_features),
            ("Classes", ds.n_classes),
        )
        for field, value in summary:
            print(f"   {field}: {value}")

        # Pull a handful of points to show the array shape and label balance.
        features, labels = ds.get_batch(10)
        print(f"   Sample shape: {features.shape}")
        print(f"   Label distribution: {np.bincount(labels)}")

    print(f"\n✅ All {len(demo_sets)} datasets ready for streaming evaluation!")