"""
Dataset loading and preprocessing functions for experiments.

This module handles:
- Loading datasets from various sources (OpenML datasets and tasks, synthetic Friedman data, config-embedded arrays, local .npy files)
- Data preprocessing and categorical encoding
- Dataset validation
"""

import numpy as np
import openml
from sklearn.datasets import make_friedman1
from category_encoders import TargetEncoder
from sklearn.compose import ColumnTransformer
import pandas as pd


def preprocess_openml_data(X, y):
    """
    Turn an OpenML feature table into a purely numeric matrix.

    Columns with ``category`` or ``object`` dtype are target-encoded with
    ``TargetEncoder`` (fitted against ``y``); numeric and boolean columns are
    passed through. When there is nothing to encode, the table is returned
    with booleans cast to ``int`` and every column kept.

    Args:
        X: Feature data (DataFrame or array-like)
        y: Target data (Series or array-like)

    Returns:
        Tuple of (X_processed, y_processed) as numpy arrays, with y_processed
        cast to float.

    Raises:
        ValueError: If the encoded matrix contains NaN values. This check only
            runs when at least one categorical column was encoded.
    """
    frame = X if isinstance(X, pd.DataFrame) else pd.DataFrame(X)
    target = y if isinstance(y, pd.Series) else pd.Series(y)

    def _columns_of(kinds):
        # Column labels whose dtype matches any of ``kinds`` (pandas select_dtypes).
        return list(frame.select_dtypes(include=kinds).columns)

    cat_cols = _columns_of(["category", "object"])
    num_cols = _columns_of(["number"])
    bool_cols = _columns_of(["bool"])

    if not cat_cols:
        # Nothing to encode: keep every column, mapping booleans to 0/1 ints.
        # Work on a copy so the caller's DataFrame is left untouched.
        plain = frame.copy()
        for col in bool_cols:
            plain[col] = plain[col].astype(int)
        return plain.values, target.astype(float).values

    # Output column order: numeric, boolean, then target-encoded categoricals.
    # NOTE(review): ColumnTransformer drops any column not listed here
    # (remainder="drop", the sklearn default), e.g. datetime columns, whereas
    # the branch above keeps them. Confirm this asymmetry is intended.
    candidate_steps = (
        ("num", "passthrough", num_cols),
        ("bool", "passthrough", bool_cols),
        ("cat", TargetEncoder(), cat_cols),
    )
    steps = [step for step in candidate_steps if step[2]]
    encoded = ColumnTransformer(transformers=steps, remainder="drop").fit_transform(
        frame, target
    )

    # ColumnTransformer may hand back a sparse matrix; densify it.
    if hasattr(encoded, "toarray"):
        encoded = encoded.toarray()

    if np.isnan(encoded).any():
        raise ValueError(
            "Input data contains NaN values after preprocessing. Please clean the data before proceeding."
        )

    return encoded, target.astype(float).values


def load_dataset(dataset_config):
    """
    Load a dataset described by a configuration dictionary.

    Dispatches on ``dataset_config["type"]`` (default ``"friedman"``) to the
    matching private ``_load_*`` helper.

    Args:
        dataset_config: Dictionary with dataset configuration including:
            - type: one of "friedman", "openml", "openml_task",
              "openml_task_collection", "embedded" or "local_npy"
            - Other type-specific parameters (see the matching loader)

    Returns:
        Dictionary describing the dataset. Loaders that materialize data
        return keys name, X, y and type, plus id / task_id where applicable.
        "openml_task_collection" returns metadata only (no X or y): name,
        type, task_ids, suite_id and is_multi_dataset.

    Raises:
        ValueError: If the type is unknown, or if the selected loader finds
            required type-specific parameters missing.
    """
    dataset_type = dataset_config.get("type", "friedman")

    if dataset_type == "friedman":
        return _load_friedman_dataset(dataset_config)
    elif dataset_type == "openml":
        return _load_openml_dataset(dataset_config)
    elif dataset_type == "openml_task":
        return _load_openml_task_dataset(dataset_config)
    elif dataset_type == "openml_task_collection":
        return _load_task_collection_metadata(dataset_config)
    elif dataset_type == "embedded":
        return _load_embedded_dataset(dataset_config)
    elif dataset_type == "local_npy":
        return _load_local_npy_dataset(dataset_config)
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")


def _load_friedman_dataset(dataset_config):
    """Generate the synthetic Friedman #1 regression problem via ``make_friedman1``.

    Recognized config keys (defaults in parentheses): ``n_samples`` (1000),
    ``n_features`` (10), ``noise`` (1.0) and ``random_state`` (42).

    Returns:
        Dictionary with keys name, X, y and type ("friedman"). The name encodes
        only the sample and feature counts.
    """
    params = {
        key: dataset_config.get(key, default)
        for key, default in (
            ("n_samples", 1000),
            ("n_features", 10),
            ("noise", 1.0),
            ("random_state", 42),
        )
    }

    features, target = make_friedman1(**params)

    return {
        "name": f"friedman1_n{params['n_samples']}_f{params['n_features']}",
        "X": features,
        "y": target,
        "type": "friedman",
    }


def _load_openml_dataset(dataset_config):
    """Load an OpenML dataset by ID and preprocess it with ``preprocess_openml_data``.

    Args:
        dataset_config: Mapping with a required ``dataset_id`` key.

    Returns:
        Dictionary with keys name, X, y, type ("openml") and id.

    Raises:
        ValueError: If ``dataset_id`` is missing, or the dataset declares no
            default target attribute.
    """
    dataset_id = dataset_config.get("dataset_id")
    if not dataset_id:
        raise ValueError("dataset_id required for OpenML dataset type")

    dataset = openml.datasets.get_dataset(dataset_id)
    target_name = dataset.default_target_attribute
    if not target_name:
        # Without a target, get_data() yields no y. preprocess_openml_data
        # would turn that into an empty target array that silently mismatches
        # X, so fail loudly instead.
        raise ValueError(
            f"OpenML dataset {dataset_id} ({dataset.name}) has no default target attribute"
        )
    X, y, _, _ = dataset.get_data(target=target_name)

    # Encode categoricals / cast to numeric arrays.
    X_processed, y_processed = preprocess_openml_data(X, y)

    return {
        "name": dataset.name,
        "X": X_processed,
        "y": y_processed,
        "type": "openml",
        "id": dataset_id,
    }


def _load_openml_task_dataset(dataset_config):
    """Load the data of an OpenML task and preprocess it with ``preprocess_openml_data``.

    The target column is the one defined by the task (``task.target_name``),
    which may differ from the underlying dataset's default target. For task
    types without a ``target_name``, the dataset's default target is used.

    Args:
        dataset_config: Mapping with a required ``task_id`` key.

    Returns:
        Dictionary with keys name, X, y, type ("openml_task"), id (the
        dataset ID) and task_id.

    Raises:
        ValueError: If ``task_id`` is missing, or neither the task nor the
            dataset specifies a target attribute.
    """
    task_id = dataset_config.get("task_id")
    if not task_id:
        raise ValueError("task_id required for OpenML task type")

    task = openml.tasks.get_task(task_id)
    dataset = task.get_dataset()

    # Supervised OpenML tasks carry their own target column; honouring it keeps
    # results comparable with other runs on the same task.
    target_name = getattr(task, "target_name", None) or dataset.default_target_attribute
    if not target_name:
        raise ValueError(
            f"OpenML task {task_id} (dataset {dataset.name}) has no target attribute"
        )
    X, y, _, _ = dataset.get_data(target=target_name)

    # Encode categoricals / cast to numeric arrays.
    X_processed, y_processed = preprocess_openml_data(X, y)

    return {
        "name": dataset.name,
        "X": X_processed,
        "y": y_processed,
        "type": "openml_task",
        "id": dataset.id,
        "task_id": task_id,
    }


def _load_task_collection_metadata(dataset_config):
    """Load metadata for task collection (does not load actual data)."""
    task_ids = dataset_config.get("task_ids", [])
    suite_id = dataset_config.get("suite_id")

    if not task_ids:
        raise ValueError("No task IDs provided for task collection")
    if not suite_id:
        raise ValueError("No suite ID provided for task collection")

    # For task collections, return metadata that will be used to create
    # separate experiments for each task
    return {
        "name": f"task_collection_suite_{suite_id}_{len(task_ids)}_tasks",
        "type": "openml_task_collection",
        "task_ids": task_ids,
        "suite_id": suite_id,
        "is_multi_dataset": True,
    }


def _load_embedded_dataset(dataset_config):
    """Load dataset that is embedded in config (e.g., from task collection)."""
    if "X" not in dataset_config or "y" not in dataset_config:
        raise ValueError("Embedded dataset must have X and y")

    # Convert lists back to numpy arrays if needed
    X = np.array(dataset_config["X"])
    y = np.array(dataset_config["y"])

    dataset_name = dataset_config.get("name", "embedded_dataset")

    return {
        "name": dataset_name,
        "X": X,
        "y": y,
        "type": "embedded",
        "id": dataset_config.get("id"),
        "task_id": dataset_config.get("task_id"),
    }


def _load_local_npy_dataset(dataset_config):
    """Load dataset from local .npy files."""
    x_path = dataset_config.get("data_path_x")
    y_path = dataset_config.get("data_path_y")

    if not x_path or not y_path:
        raise ValueError("Missing data paths for local_npy dataset")

    X = np.load(x_path)
    y = np.load(y_path)

    return {
        "name": dataset_config.get("name", "local_dataset"),
        "X": X,
        "y": y,
        "type": "local_npy",
        "id": dataset_config.get("id"),
        "task_id": dataset_config.get("task_id"),
    }
