"""
Enhanced datasets for concept drift evaluation.
Includes challenging synthetic generators and real-world-inspired scenarios.
"""

import numpy as np
from typing import Tuple, Dict, List, Any
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification, make_blobs
import warnings
# Import-time side effect: installs a global "ignore" filter, so every warning
# (including NumPy/scikit-learn runtime warnings) is silenced for the whole
# process once this module has been imported.
warnings.filterwarnings('ignore')


class AdvancedSyntheticGenerator:
    """Synthetic stream generators whose labelling rule changes over time.

    All methods are static, draw from NumPy's global random state
    (``np.random``) and return ``(X, y)``: a float matrix of shape
    ``(n_samples, n_features)`` and an integer 0/1 vector of length
    ``n_samples``. Row index doubles as the time step of the stream.
    """

    @staticmethod
    def generate_sine_clusters(n_samples: int = 3000, n_features: int = 4,
                               drift_points: List[int] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build a stream with three successive sine-shaped decision boundaries.

        Args:
            n_samples: Number of rows to generate.
            n_features: Columns per row. The labelling rules read columns
                0-3, so fewer than four columns raises ``IndexError``.
            drift_points: Row indices at which the concept switches. Only
                entries 0 and 1 are read. Defaults to one third and two
                thirds of ``n_samples``.

        Returns:
            ``(X, y)`` with features drawn uniformly from [-2, 2].
        """
        if drift_points is None:
            drift_points = [n_samples // 3, 2 * n_samples // 3]

        features = np.zeros((n_samples, n_features))
        labels = np.zeros(n_samples, dtype=int)

        for step in range(n_samples):
            row = np.random.uniform(-2, 2, n_features)
            features[step] = row

            if step < drift_points[0]:
                # Concept 1: fixed sine/cosine surface over columns 0 and 1,
                # compared against column 2.
                surface = np.sin(row[0]) + 0.5 * np.cos(row[1])
                above = row[2] > surface
            elif step < drift_points[1]:
                # Concept 2: the sine axis is rotated by an angle that grows
                # with the step index and reaches pi at the second drift point.
                theta = np.pi * step / drift_points[1]
                turned = row[0] * np.cos(theta) - row[1] * np.sin(theta)
                surface = np.sin(2 * turned) + 0.3 * row[3]
                above = row[2] > surface
            else:
                # Concept 3: equal blend of a sine and a cosine term, compared
                # against the sum of columns 0 and 1.
                sine_part = np.sin(row[0] + row[1])
                cosine_part = np.cos(row[2] - row[3])
                blended = 0.5 * sine_part + 0.5 * cosine_part
                above = row[0] + row[1] > blended

            labels[step] = int(above)

        return features, labels

    @staticmethod
    def generate_multi_modal_drift(n_samples: int = 3000, n_features: int = 6,
                                   n_modes: int = 4) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build a stream sampled around cluster centres that move every step.

        Samples are assigned to clusters round-robin (``step % n_modes``),
        not by proximity. Centres advance by their velocity on every step and
        the velocities receive a small uniform perturbation every 100 steps.

        Args:
            n_samples: Number of rows to generate.
            n_features: Columns per row.
            n_modes: Number of cluster centres.

        Returns:
            ``(X, y)`` where each row is its cluster centre plus Gaussian
            noise with standard deviation 0.5.
        """
        features = np.zeros((n_samples, n_features))
        labels = np.zeros(n_samples, dtype=int)

        centers = np.random.uniform(-3, 3, (n_modes, n_features))
        velocities = np.random.uniform(-0.01, 0.01, (n_modes, n_features))
        distance_rule_limit = n_modes // 2

        for step in range(n_samples):
            # Drift: move every centre, then occasionally jitter the velocities.
            centers += velocities
            if step % 100 == 0:
                velocities += np.random.uniform(-0.005, 0.005, (n_modes, n_features))

            mode = step % n_modes
            point = centers[mode] + np.random.normal(0, 0.5, n_features)
            features[step] = point

            if mode < distance_rule_limit:
                # Lower half of the clusters: compare the nearest-centre
                # distance with the mean distance to all centres.
                # NOTE(review): a minimum is below the mean unless all
                # distances are equal, so this branch yields 1 for practically
                # every sample - confirm that is the intended rule.
                gaps = np.linalg.norm(point - centers, axis=1)
                labels[step] = int(np.min(gaps) < np.mean(gaps))
            else:
                # Upper half: columns 0-1 summed versus columns 2-3 summed.
                labels[step] = int(np.sum(point[:2]) > np.sum(point[2:4]))

        return features, labels

    @staticmethod
    def generate_gradual_concept_drift(n_samples: int = 3000, n_features: int = 4,
                                       drift_rate: float = 0.001) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build a stream whose decision boundary rotates and shifts smoothly.

        Args:
            n_samples: Number of rows to generate.
            n_features: Columns per row. Columns 0-3 are read by the rule.
            drift_rate: Rotation, in radians per step, applied to columns 0-1
                before scoring.

        Returns:
            ``(X, y)`` with features drawn uniformly from [-2, 2].
        """
        features = np.random.uniform(-2, 2, (n_samples, n_features))
        labels = np.zeros(n_samples, dtype=int)

        for step, row in enumerate(features):
            # Fraction of the stream elapsed, in [0, 1).
            progress = step / n_samples

            theta = drift_rate * step
            cos_t, sin_t = np.cos(theta), np.sin(theta)
            rotation = np.array([
                [cos_t, -sin_t],
                [sin_t, cos_t],
            ])
            u, v = rotation @ row[:2]

            # Threshold: one sine period over the stream plus a linear ramp.
            cutoff = 0.5 * np.sin(2 * np.pi * progress) + 0.3 * progress

            # Score: time-weighted rotated coordinates plus an interaction
            # term between columns 2 and 3.
            score = (u * (1 + 0.5 * progress) +
                     v * (1 - 0.3 * progress) +
                     0.2 * row[2] * row[3])

            labels[step] = int(score > cutoff)

        return features, labels


class RealWorldDataGenerator:
    """
    Real-world inspired data generators based on actual concept drift scenarios.

    All methods are static, draw from NumPy's global random state
    (``np.random``) and return ``(X, y)``: a float matrix of shape
    ``(n_samples, n_features)`` and an integer 0/1 vector of length
    ``n_samples``. Row index is the time step of the stream.
    """

    @staticmethod
    def generate_financial_market_drift(n_samples: int = 4000,
                                       n_features: int = 8) -> Tuple[np.ndarray, np.ndarray]:
        """
        Simulate financial market data with regime changes.
        Features represent market indicators, target is market direction.

        The stream is split into four consecutive regimes of equal length
        (bull, bear, volatile, stable); any remainder rows belong to the last
        regime. Each regime has its own trend/volatility distribution and its
        own labelling rule.

        Args:
            n_samples: Number of rows to generate. Streams shorter than four
                rows get one row per regime, in order.
            n_features: Columns per row. The volatile-regime rule averages
                columns 4 onward; with four or fewer columns that slice is
                empty, the mean is NaN and the label is always 0.

        Returns:
            ``(X, y)`` as described on the class.
        """
        X = np.zeros((n_samples, n_features))
        y = np.zeros(n_samples, dtype=int)

        # One entry per regime, in stream order:
        # (trend mean, trend std, volatility mean, volatility std, label rule)
        regimes = (
            # Bull market: positive trends
            (0.02, 0.1, 0.8, 0.2,
             lambda row: np.mean(row[:4]) > 0),
            # Bear market: negative trends
            (-0.02, 0.1, 1.2, 0.3,
             lambda row: np.mean(row[:4]) > -0.5),
            # Volatile market: high variance
            (0, 0.05, 2.0, 0.5,
             lambda row: np.std(row[:4]) < np.mean(np.abs(row[4:]))),
            # Stable market: low variance
            (0.005, 0.02, 0.3, 0.1,
             lambda row: np.sum(row[:2]) > np.sum(row[2:4])),
        )
        last_regime = len(regimes) - 1

        # n_samples // 4 is 0 for streams shorter than four rows, which would
        # make the regime lookup below divide by zero; use at least one row.
        regime_length = max(1, n_samples // 4)

        for i in range(n_samples):
            regime_idx = min(i // regime_length, last_regime)
            trend_mean, trend_std, vol_mean, vol_std, label_rule = regimes[regime_idx]

            # Draw order (trend, volatility, row) is part of the seeded
            # output and must not be rearranged.
            trend = np.random.normal(trend_mean, trend_std, n_features)
            # +0.1 keeps every standard deviation strictly positive.
            volatility = np.abs(np.random.normal(vol_mean, vol_std, n_features)) + 0.1
            X[i] = np.random.normal(trend, volatility)
            y[i] = 1 if label_rule(X[i]) else 0

        return X, y

    @staticmethod
    def generate_sensor_network_drift(n_samples: int = 3500,
                                    n_features: int = 6) -> Tuple[np.ndarray, np.ndarray]:
        """
        Simulate sensor network data with gradual sensor degradation.

        Each column gets its own randomly drawn gain-growth rate and
        noise-growth rate. A row is labelled 1 when its Euclidean distance
        from the mean of the most recent rows (up to 101, current row
        included) exceeds a threshold that falls linearly from 1.5 towards
        1.0 over the stream. The first row is compared with itself, so its
        label is always 0.

        Args:
            n_samples: Number of rows to generate.
            n_features: Number of simulated sensors (columns).

        Returns:
            ``(X, y)`` as described on the class.
        """
        X = np.zeros((n_samples, n_features))
        y = np.zeros(n_samples, dtype=int)

        # Sensor degradation parameters (one rate per sensor)
        degradation_rate = np.random.uniform(0.0001, 0.001, n_features)
        noise_increase_rate = np.random.uniform(0.0002, 0.002, n_features)

        for i in range(n_samples):
            # Time factor: fraction of the stream elapsed
            t = i / n_samples

            # Base sensor readings
            base_readings = np.random.normal(0, 1, n_features)

            # Add degradation effects: both the gain on the reading and the
            # additive noise scale grow linearly with the step index
            degradation_factor = 1 + degradation_rate * i
            noise_factor = 1 + noise_increase_rate * i

            X[i] = base_readings * degradation_factor + np.random.normal(0, noise_factor, n_features)

            # Detection based on anomaly patterns
            # As sensors degrade, normal patterns become harder to detect
            anomaly_threshold = 1.5 - 0.5 * t  # Threshold decreases over time
            # Window covers rows max(0, i-100) .. i inclusive
            anomaly_score = np.linalg.norm(X[i] - np.mean(X[max(0, i-100):i+1], axis=0))

            y[i] = 1 if anomaly_score > anomaly_threshold else 0

        return X, y

    @staticmethod
    def generate_user_behavior_drift(n_samples: int = 4500,
                                   n_features: int = 10) -> Tuple[np.ndarray, np.ndarray]:
        """
        Simulate user behavior data with changing preferences over time.

        Per-feature preference weights drift linearly (clipped to
        [0.01, 2.0]) and scale exponentially distributed base behaviour,
        modulated by a 365-step and a 7-step sinusoid. The label comes from a
        fixed linear score over the first five features, minus a fatigue
        penalty that restarts every 100 steps, plus a random +/-0.2 term.

        Args:
            n_samples: Number of rows to generate.
            n_features: Columns per row. The decision score takes a dot
                product of the first five columns with five fixed weights, so
                fewer than five columns makes ``np.dot`` raise ``ValueError``.

        Returns:
            ``(X, y)`` as described on the class; all features are
            non-negative.
        """
        X = np.zeros((n_samples, n_features))
        y = np.zeros(n_samples, dtype=int)

        # User preference evolution
        preference_weights = np.random.uniform(0.1, 1.0, n_features)
        preference_trends = np.random.uniform(-0.0005, 0.0005, n_features)

        for i in range(n_samples):
            # Update preferences gradually
            preference_weights += preference_trends
            preference_weights = np.clip(preference_weights, 0.01, 2.0)

            # Generate user behavior features
            base_behavior = np.random.exponential(1.0, n_features)
            seasonal_factor = 1 + 0.3 * np.sin(2 * np.pi * i / 365)  # Yearly seasonality
            weekly_factor = 1 + 0.1 * np.sin(2 * np.pi * i / 7)      # Weekly patterns

            X[i] = base_behavior * preference_weights * seasonal_factor * weekly_factor

            # Binary outcome based on complex user decision model
            decision_score = np.dot(X[i][:5], [1.0, -0.5, 0.8, -0.3, 0.6])
            time_penalty = 0.1 * (i % 100) / 100  # Fatigue effect
            social_influence = 0.2 * np.random.choice([-1, 1])  # Random social factor

            final_score = decision_score - time_penalty + social_influence
            y[i] = 1 if final_score > 0 else 0

        return X, y


class ConceptDriftDatasetGenerator:
    """
    Main interface for generating various concept drift datasets.

    The generators all draw from NumPy's global random state, which this
    class seeds with ``random_seed``.
    """

    # Number of feature columns every preprocessed dataset is brought to
    # (the original comment cites quantum kernel compatibility).
    _TARGET_DIM = 4

    def __init__(self, random_seed: int = 42):
        """Store the seed and seed NumPy's global random state with it."""
        self.random_seed = random_seed
        np.random.seed(random_seed)

    def get_all_datasets(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """
        Generate all available datasets for comprehensive evaluation.

        The global NumPy random state is re-seeded with ``self.random_seed``
        before anything is drawn, so the result depends only on the seed and
        not on random numbers consumed since construction; repeated calls
        return identical data.

        Every feature matrix is standardised and then brought to exactly four
        columns (see ``_preprocess``). Labels are returned unchanged.

        Returns:
            Dictionary mapping dataset names to (X, y) tuples
        """
        # Seeding only in __init__ left the output dependent on whatever used
        # np.random between construction and this call.
        np.random.seed(self.random_seed)

        # NOTE: generation order defines how the shared random stream is
        # consumed; reordering these calls changes every dataset after it.
        datasets = {}

        # Advanced synthetic datasets
        datasets['sine_clusters'] = AdvancedSyntheticGenerator.generate_sine_clusters(
            n_samples=3000, n_features=4
        )

        datasets['multi_modal_drift'] = AdvancedSyntheticGenerator.generate_multi_modal_drift(
            n_samples=3000, n_features=6, n_modes=4
        )

        datasets['gradual_concept_drift'] = AdvancedSyntheticGenerator.generate_gradual_concept_drift(
            n_samples=3000, n_features=4, drift_rate=0.002
        )

        # Real-world inspired datasets
        datasets['financial_market'] = RealWorldDataGenerator.generate_financial_market_drift(
            n_samples=4000, n_features=8
        )

        datasets['sensor_network'] = RealWorldDataGenerator.generate_sensor_network_drift(
            n_samples=3500, n_features=6
        )

        datasets['user_behavior'] = RealWorldDataGenerator.generate_user_behavior_drift(
            n_samples=4500, n_features=10
        )

        # Enhanced versions of classic datasets
        datasets['enhanced_sea'] = self._generate_enhanced_sea()
        datasets['enhanced_rotating_hyperplane'] = self._generate_enhanced_rotating_hyperplane()

        # Preprocess all datasets
        return {name: (self._preprocess(X), y) for name, (X, y) in datasets.items()}

    def _preprocess(self, X: np.ndarray) -> np.ndarray:
        """Standardise ``X`` and bring it to ``_TARGET_DIM`` columns.

        Wider matrices are projected with PCA, narrower ones are padded with
        zero columns, and matrices already at the target width are only
        scaled. Scaler and PCA are fitted on the whole stream at once.
        """
        target = self._TARGET_DIM
        X_scaled = StandardScaler().fit_transform(X)
        n_columns = X_scaled.shape[1]

        if n_columns > target:
            # Imported here, as in the original, so PCA is only loaded when
            # a dataset actually needs reducing.
            from sklearn.decomposition import PCA
            pca = PCA(n_components=target, random_state=self.random_seed)
            X_scaled = pca.fit_transform(X_scaled)
        elif n_columns < target:
            padding = np.zeros((X_scaled.shape[0], target - n_columns))
            X_scaled = np.hstack([X_scaled, padding])

        return X_scaled

    def _generate_enhanced_sea(self) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a 3000-row, 4-feature SEA-style stream with growing noise.

        Features start uniform on [0, 10] and receive Gaussian noise whose
        standard deviation rises from 0.1 to about 0.3 over the stream. The
        labelling rule changes at rows 1000 and 2000. Labels in the last
        concept are then flipped at random with probability 0.3, which is
        label noise rather than a change in class balance.

        Returns:
            ``(X, y)`` where ``X`` already includes the added noise.
        """
        n_samples = 3000
        X = np.random.uniform(0, 10, (n_samples, 4))
        y = np.zeros(n_samples, dtype=int)

        drift_points = [1000, 2000]

        for i in range(n_samples):
            # Add noise that increases over time
            noise_level = 0.1 + 0.2 * (i / n_samples)
            noise = np.random.normal(0, noise_level, 4)
            X[i] += noise

            # Concept-dependent decision boundaries
            if i < drift_points[0]:
                threshold = 8.0 + np.random.normal(0, 0.5)  # Concept 1 with noise
                y[i] = 1 if (X[i, 0] + X[i, 1]) <= threshold else 0
            elif i < drift_points[1]:
                threshold = 9.5 + np.sin(i / 100) * 0.5  # Concept 2 with oscillation
                y[i] = 1 if (X[i, 0] + X[i, 1] + 0.3 * X[i, 2]) <= threshold else 0
            else:
                # Concept 3 with feature interactions
                decision_value = (X[i, 0] + X[i, 1] - 0.2 * X[i, 2] * X[i, 3])
                threshold = 7.5 + np.random.normal(0, 0.3)
                y[i] = 1 if decision_value <= threshold else 0

            # Random label flips (probability 0.3) in the last concept.
            # NOTE(review): the strict '>' means row drift_points[1] itself,
            # the first row of concept 3, is never flipped. Left as is
            # because changing it would shift the random stream and every
            # value generated afterwards.
            if i > drift_points[1] and np.random.random() < 0.3:
                y[i] = 1 - y[i]

        return X, y

    def _generate_enhanced_rotating_hyperplane(self) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a 3000-row, 4-feature stream with two rotating hyperplanes.

        One weight vector rotates a full turn in the plane of features 0-1,
        another rotates half a turn in the plane of features 2-3. The
        decision value blends the two, moving linearly from the first to the
        second over the stream, with Gaussian noise whose amplitude follows
        a sinusoid in time.

        Returns:
            ``(X, y)`` with features drawn uniformly from [-1, 1].
        """
        n_samples = 3000
        X = np.random.uniform(-1, 1, (n_samples, 4))
        y = np.zeros(n_samples, dtype=int)

        for i in range(n_samples):
            # Multiple rotation axes
            angle1 = 2 * np.pi * i / n_samples
            angle2 = np.pi * i / n_samples

            # Primary rotation
            w1 = np.array([np.cos(angle1), np.sin(angle1), 0.1, 0.1])

            # Secondary rotation (slower)
            w2 = np.array([0.1, 0.1, np.cos(angle2), np.sin(angle2)])

            # Combined decision boundary
            decision_value1 = np.dot(X[i], w1)
            decision_value2 = np.dot(X[i], w2)

            # Time-dependent combination of decision boundaries
            t = i / n_samples
            combined_value = (1 - t) * decision_value1 + t * decision_value2

            # Add temporal noise (amplitude is zero at t = 0)
            noise = 0.1 * np.sin(10 * np.pi * t) * np.random.normal(0, 1)
            y[i] = 1 if combined_value + noise > 0 else 0

        return X, y


def get_enhanced_datasets(random_seed: int = 42) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Build every enhanced drift dataset in a single call.

    Creates a ``ConceptDriftDatasetGenerator`` with the given seed and
    returns the result of its ``get_all_datasets()`` method.

    Args:
        random_seed: Seed handed to the generator for reproducibility

    Returns:
        Dictionary of dataset name -> (X, y) pairs
    """
    return ConceptDriftDatasetGenerator(random_seed=random_seed).get_all_datasets()