"""
Data loading and preparation utilities for LapBoost.

This module provides functions for loading and preprocessing datasets
for semi-supervised learning experiments.
"""

import numpy as np
from sklearn.datasets import make_moons, make_circles, make_classification, make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from typing import Tuple, Optional, Union, Dict, Any


def create_synthetic_dataset(
    dataset_type: str = 'moons',
    n_samples: int = 1000,
    noise: float = 0.1,
    n_features: int = 2,
    n_informative: int = 2,
    n_redundant: int = 0,
    n_classes: int = 2,
    labeled_ratio: float = 0.1,
    random_state: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate a synthetic dataset and partition it into labeled and unlabeled rows.

    Parameters
    ----------
    dataset_type : str, default='moons'
        Which scikit-learn generator to use:
        - 'moons': two interleaving half circles (``make_moons``)
        - 'circles': two concentric circles, inner radius factor 0.5 (``make_circles``)
        - 'classification': random n-class problem (``make_classification``)
        - 'regression': random regression problem (``make_regression``)
    n_samples : int, default=1000
        Total number of rows to generate.
    noise : float, default=0.1
        Noise level forwarded to the 'moons', 'circles' and 'regression'
        generators; ignored for 'classification'.
    n_features : int, default=2
        Number of features for 'classification' and 'regression'.
    n_informative : int, default=2
        Number of informative features for 'classification' and 'regression'.
    n_redundant : int, default=0
        Number of redundant features for 'classification'.
    n_classes : int, default=2
        Number of classes for 'classification'.
    labeled_ratio : float, default=0.1
        Fraction of the generated rows that end up in the labeled part.
    random_state : int, optional
        Seed used both for the generator and for the labeled/unlabeled shuffle.

    Returns
    -------
    tuple
        ``(X_labeled, y_labeled, X_unlabeled, X_all, y_all)``. The labeled part is
        the first ``int(n_samples * labeled_ratio)`` rows of a random permutation;
        the remaining rows form the unlabeled part. ``X_all`` / ``y_all`` are the
        full generated arrays in their original order.

    Raises
    ------
    ValueError
        If ``dataset_type`` is not one of the supported names.
    """
    shuffler = np.random.RandomState(random_state)

    # Each builder is a thunk so only the requested generator actually runs.
    builders = {
        'moons': lambda: make_moons(
            n_samples=n_samples, noise=noise, random_state=random_state
        ),
        'circles': lambda: make_circles(
            n_samples=n_samples, noise=noise, factor=0.5, random_state=random_state
        ),
        'classification': lambda: make_classification(
            n_samples=n_samples,
            n_features=n_features,
            n_informative=n_informative,
            n_redundant=n_redundant,
            n_classes=n_classes,
            random_state=random_state
        ),
        'regression': lambda: make_regression(
            n_samples=n_samples,
            n_features=n_features,
            n_informative=n_informative,
            noise=noise,
            random_state=random_state
        ),
    }

    build = builders.get(dataset_type)
    if build is None:
        raise ValueError(f"Unknown dataset type: {dataset_type}")
    X, y = build()

    # Random labeled/unlabeled partition of the generated rows.
    cutoff = int(n_samples * labeled_ratio)
    order = shuffler.permutation(n_samples)
    labeled_rows, unlabeled_rows = order[:cutoff], order[cutoff:]

    return X[labeled_rows], y[labeled_rows], X[unlabeled_rows], X, y


def split_dataset(
    X: np.ndarray,
    y: np.ndarray,
    labeled_ratio: float = 0.1,
    test_size: float = 0.2,
    stratify: bool = True,
    scale_features: bool = True,
    random_state: Optional[int] = None
) -> Dict[str, Any]:
    """
    Split a dataset into labeled, unlabeled, and test sets for semi-supervised learning.

    Parameters
    ----------
    X : np.ndarray
        Feature array
    y : np.ndarray
        Target array
    labeled_ratio : float, default=0.1
        Fraction of the WHOLE dataset (before the test split) to keep as labeled
        training samples. It is converted to a fraction of the training split via
        ``labeled_ratio / (1 - test_size)`` and capped at 0.99 of the training split,
        so some unlabeled samples always remain.
    test_size : float, default=0.2
        Fraction of samples held out as the test set; must lie in the open interval (0, 1).
    stratify : bool, default=True
        Whether to stratify both splits on the target. Stratification is only applied
        when ``y`` has fewer than 10 distinct values (a heuristic to skip regression
        targets; note it also skips classification problems with 10+ classes).
    scale_features : bool, default=True
        Whether to fit a ``StandardScaler`` on the training split and apply it to both
        the training and test features.
    random_state : int, optional
        Random seed for reproducibility

    Returns
    -------
    dict
        Keys ``X_train_labeled``, ``y_train_labeled``, ``X_train_unlabeled``,
        ``X_test``, ``y_test``, plus ``scaler`` (the fitted ``StandardScaler``)
        when ``scale_features`` is True.

    Raises
    ------
    ValueError
        If ``test_size`` is not in (0, 1) or ``labeled_ratio`` is negative.
    """
    # Validate before any work: test_size outside (0, 1) would make the
    # labeled_ratio / (1 - test_size) conversion below divide by zero or go negative.
    if not 0 < test_size < 1:
        raise ValueError(f"test_size must be a float in the open interval (0, 1), got {test_size!r}")
    if labeled_ratio < 0:
        raise ValueError(f"labeled_ratio must be non-negative, got {labeled_ratio!r}")

    # Heuristic: treat targets with fewer than 10 distinct values as classification.
    use_stratify = stratify and len(np.unique(y)) < 10

    # First split into train and test
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=random_state,
        stratify=y if use_stratify else None
    )

    # Scale features if requested (scaler is fit on the training split only)
    if scale_features:
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
    else:
        scaler = None

    # Convert labeled_ratio (fraction of all samples) into a labeled count within the
    # training split; the 0.99 cap guarantees at least some unlabeled samples.
    n_train = X_train.shape[0]
    train_fraction = min(0.99, labeled_ratio / (1 - test_size))
    n_labeled = int(n_train * train_fraction)

    if use_stratify:
        # Pass the labeled count as an int: a float n_labeled / n_train can round just
        # below the intended value (e.g. 1 / 49 * 49 == 0.999...), and sklearn floors
        # float train sizes, which would silently drop a labeled sample or raise.
        X_train_labeled, X_train_unlabeled, y_train_labeled, _ = train_test_split(
            X_train, y_train,
            train_size=n_labeled,
            random_state=random_state,
            stratify=y_train
        )
    else:
        # Simple random sampling
        rng = np.random.RandomState(random_state)
        indices = rng.permutation(n_train)
        labeled_idx = indices[:n_labeled]
        unlabeled_idx = indices[n_labeled:]

        X_train_labeled = X_train[labeled_idx]
        y_train_labeled = y_train[labeled_idx]
        X_train_unlabeled = X_train[unlabeled_idx]

    result = {
        'X_train_labeled': X_train_labeled,
        'y_train_labeled': y_train_labeled,
        'X_train_unlabeled': X_train_unlabeled,
        'X_test': X_test,
        'y_test': y_test
    }

    if scaler is not None:
        result['scaler'] = scaler

    return result


def _subsample(
    X: np.ndarray,
    y: np.ndarray,
    max_samples: int,
    random_state: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return at most ``max_samples`` rows of ``X`` and ``y``.

    Rows are drawn uniformly without replacement using a fresh
    ``RandomState(random_state)``; inputs with ``max_samples`` rows or fewer
    are returned unchanged.
    """
    if X.shape[0] <= max_samples:
        return X, y
    rng = np.random.RandomState(random_state)
    idx = rng.choice(X.shape[0], max_samples, replace=False)
    return X[idx], y[idx]


def load_ssl_benchmark_dataset(
    dataset_name: str,
    labeled_ratio: float = 0.1,
    random_state: Optional[int] = None
) -> Union[Dict[str, Any], Tuple[np.ndarray, ...]]:
    """
    Load benchmark datasets for semi-supervised learning.

    Imports are done lazily per dataset so that an optional dependency needed by
    one dataset (pandas, text vectorizers, network fetchers) does not prevent
    loading the others.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset:
        - 'moons': Two interleaving half circles (synthetic, 1000 samples)
        - 'circles': Two concentric circles (synthetic, 1000 samples)
        - 'digits': Handwritten digit classification
        - 'breast_cancer': Binary medical classification
        - 'wine_quality': OpenML "wine-quality"; falls back to sklearn's bundled
          3-class wine dataset if fetching fails
        - 'adult_census': OpenML "adult" (one-hot encoded, subsampled to 10000 rows);
          falls back to 'breast_cancer' on failure
        - 'isolet': OpenML "isolet" (subsampled to 5000 rows); falls back to 'digits'
        - '20newsgroups': TF-IDF (1000 features) of 20 Newsgroups (subsampled to
          5000 rows); falls back to 'digits'
        - 'boston' / 'california_housing': California housing regression. The
          'boston' name is kept for backward compatibility; it does NOT load the
          Boston housing data.
        - 'diabetes': Diabetes regression
    labeled_ratio : float, default=0.1
        Fraction of the whole dataset to use as labeled samples.
    random_state : int, optional
        Random seed for reproducibility (subsampling and splits).

    Returns
    -------
    dict or tuple
        For every dataset except 'moons' and 'circles': the dict produced by
        ``split_dataset`` (X_train_labeled, y_train_labeled, X_train_unlabeled,
        X_test, y_test, and scaler).
        For 'moons' and 'circles': the 5-tuple
        ``(X_labeled, y_labeled, X_unlabeled, X_all, y_all)`` from
        ``create_synthetic_dataset`` (no test split).
        NOTE(review): this inconsistency is preserved for existing callers;
        consider unifying on the dict layout.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not recognised.
    """
    if dataset_name in ('moons', 'circles'):
        return create_synthetic_dataset(
            dataset_name, 1000, 0.1, labeled_ratio=labeled_ratio, random_state=random_state
        )
    elif dataset_name == 'digits':
        from sklearn.datasets import load_digits
        digits = load_digits()
        return split_dataset(digits.data, digits.target, labeled_ratio=labeled_ratio,
                             random_state=random_state)
    elif dataset_name == 'breast_cancer':
        from sklearn.datasets import load_breast_cancer
        cancer = load_breast_cancer()
        return split_dataset(cancer.data, cancer.target, labeled_ratio=labeled_ratio,
                             random_state=random_state)
    elif dataset_name in ('boston', 'california_housing'):
        # California housing stands in for Boston housing (presumably because
        # load_boston was removed from scikit-learn); regression, so no stratification.
        from sklearn.datasets import fetch_california_housing
        housing = fetch_california_housing()
        return split_dataset(housing.data, housing.target, labeled_ratio=labeled_ratio,
                             stratify=False, random_state=random_state)
    elif dataset_name == 'diabetes':
        from sklearn.datasets import load_diabetes
        diabetes = load_diabetes()
        return split_dataset(diabetes.data, diabetes.target, labeled_ratio=labeled_ratio,
                             stratify=False, random_state=random_state)
    elif dataset_name == 'wine_quality':
        try:
            from sklearn.datasets import fetch_openml
            wine = fetch_openml(name="wine-quality", version=1, as_frame=True, parser="auto")
            X = wine.data.values
            y = wine.target.astype(int).values
            return split_dataset(X, y, labeled_ratio=labeled_ratio, random_state=random_state)
        except Exception as e:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.
            print(f"Error loading wine-quality dataset: {e}")
            # Fallback: sklearn's bundled wine dataset (a different, 3-class problem)
            from sklearn.datasets import load_wine
            wine = load_wine()
            return split_dataset(wine.data, wine.target, labeled_ratio=labeled_ratio,
                                 random_state=random_state)
    elif dataset_name == 'adult_census':
        try:
            import pandas as pd
            from sklearn.datasets import fetch_openml
            from sklearn.preprocessing import LabelEncoder
            adult = fetch_openml(name="adult", version=2, as_frame=True, parser="auto")

            # One-hot encode categorical features. Convert explicitly to float: the
            # mix of numeric and boolean dummy columns would otherwise yield an
            # object-dtype array.
            X = pd.get_dummies(adult.data).to_numpy(dtype=float)
            y = LabelEncoder().fit_transform(adult.target)

            # Take a subset for faster benchmarking
            X, y = _subsample(X, y, 10000, random_state)
            return split_dataset(X, y, labeled_ratio=labeled_ratio, random_state=random_state)
        except Exception as e:
            print(f"Error loading adult dataset: {e}")
            # Fallback to a simpler dataset
            return load_ssl_benchmark_dataset('breast_cancer', labeled_ratio=labeled_ratio,
                                              random_state=random_state)
    elif dataset_name == 'isolet':
        try:
            from sklearn.datasets import fetch_openml
            isolet = fetch_openml(name="isolet", version=1, as_frame=True, parser="auto")
            X = isolet.data.values
            # Convert to int and shift classes from 1-indexed to 0-indexed
            y = isolet.target.astype(int).values - 1

            X, y = _subsample(X, y, 5000, random_state)
            return split_dataset(X, y, labeled_ratio=labeled_ratio, random_state=random_state)
        except Exception as e:
            print(f"Error loading isolet dataset: {e}")
            # Fallback to digits
            return load_ssl_benchmark_dataset('digits', labeled_ratio=labeled_ratio,
                                              random_state=random_state)
    elif dataset_name == '20newsgroups':
        try:
            from sklearn.datasets import fetch_20newsgroups
            from sklearn.feature_extraction.text import TfidfVectorizer
            newsgroups = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))

            # TF-IDF vocabulary is fit on the full corpus (before the train/test split).
            vectorizer = TfidfVectorizer(max_features=1000)
            X = vectorizer.fit_transform(newsgroups.data).toarray()
            y = newsgroups.target

            X, y = _subsample(X, y, 5000, random_state)
            return split_dataset(X, y, labeled_ratio=labeled_ratio, random_state=random_state)
        except Exception as e:
            print(f"Error loading 20newsgroups dataset: {e}")
            # Fallback
            return load_ssl_benchmark_dataset('digits', labeled_ratio=labeled_ratio,
                                              random_state=random_state)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
