import numpy as np
import pandas as pd
from typing import Tuple, Optional, Dict, List, Union, Any
from dataclasses import dataclass
import openml
from sklearn.preprocessing import StandardScaler
import requests
from scipy.io import arff
from io import StringIO
import torch
from torchvision.datasets import EMNIST
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import QuantileTransformer, OneHotEncoder, FunctionTransformer
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.compose import TransformedTargetRegressor

# Import Grinsztajn preprocessing utilities
import sys
sys.path.append('.')  # Adjust if grinsztajn_preprocessing.py is in a different location

def _as_dataframe(X: Union[np.ndarray, pd.DataFrame]) -> pd.DataFrame:
    """
    Return the input as a fresh pandas DataFrame.

    A DataFrame input is copied, so later in-place edits do not touch the
    caller's object. Any other input is converted with ``np.asarray``; its
    ``shape[1]`` is read to generate the column names ``x0``, ``x1``, ...,
    so it is expected to be 2-D.
    """
    if not isinstance(X, pd.DataFrame):
        values = np.asarray(X)
        generated_names = [f"x{j}" for j in range(values.shape[1])]
        return pd.DataFrame(values, columns=generated_names)
    return X.copy()

def _is_categorical_series(s: pd.Series) -> bool:
    """
    Tell whether a column should be treated as categorical.

    A series counts as categorical when its dtype is object, pandas
    ``category``, a pandas string dtype, or boolean.

    The dtype is inspected with ``isinstance`` rather than the deprecated
    ``pd.api.types.is_categorical_dtype`` helper. Dedicated string dtypes are
    included so that text columns are classified the same way whether they
    are stored as ``object`` or as a pandas string dtype.
    """
    dtype = s.dtype
    return bool(
        isinstance(dtype, (pd.CategoricalDtype, pd.StringDtype))
        or pd.api.types.is_object_dtype(dtype)
        or pd.api.types.is_bool_dtype(dtype)
    )

def grinsztajn_remove_side_issues(
    X: Union[np.ndarray, pd.DataFrame],
    y: np.ndarray,
    task: str = "regression",
    *,
    # Paper says "columns containing many missing data" but does not give a threshold -> keep it parametric
    missing_col_threshold: float = 0.5,
    max_train_size: int = 10_000,
    random_state: int = 0,
    # Categorical / numerical filtering
    max_categorical_cardinality: int = 20,
    min_unique_numerical: int = 10,
    convert_binary_numerical_to_categorical: bool = True,
) -> Tuple[pd.DataFrame, np.ndarray, Dict]:
    """
    Replicates Grinsztajn et al. "Removing side issues" block:
      - Remove missing data (drop columns with many missing, then rows with any missing)
      - (Classification) Binarize to top-2 classes and balance 50/50
      - Drop categorical features with > 20 categories
      - Drop numerical features with < 10 unique values
      - Convert numerical binary features (2 unique) to categorical

    Parameters
    ----------
    X : np.ndarray or pd.DataFrame
        Feature matrix; it is copied/converted via ``_as_dataframe`` so the
        caller's object is not modified.
    y : np.ndarray
        Targets, one entry (or row) per row of ``X``.
    task : str
        Only the value "classification" triggers the binarize/balance step;
        any other value skips it.
    missing_col_threshold : float
        Columns whose fraction of missing values is strictly greater than
        this are dropped.
    max_train_size : int
        Accepted for signature compatibility; not used in this function.
    random_state : int
        Seed for the generator used by the class-balancing subsample.
    max_categorical_cardinality : int
        Categorical columns with more distinct values than this are dropped.
    min_unique_numerical : int
        Numerical columns with fewer distinct values than this are dropped.
    convert_binary_numerical_to_categorical : bool
        If True, numerical columns with exactly 2 distinct values are cast
        to ``category`` and counted as categorical.

    Returns
    -------
    Xdf : pd.DataFrame
        Filtered features with a fresh 0..n-1 index.
    y : np.ndarray
        Targets aligned with ``Xdf`` (flattened to 1-D for classification).
    info : dict
        Bookkeeping: "dropped_cols_missing", "n_after_missing",
        "binary_num_to_cat", "dropped_cat_high_cardinality",
        "dropped_num_low_unique", "cat_cols", "num_cols",
        "p_after_feature_filters", and for classification also
        "binarized_classes" and "n_after_balance".

    Raises
    ------
    ValueError
        If ``task == "classification"`` and fewer than two classes remain
        after missing-data removal.

    Notes:
      * The paper does not specify the exact missing_col_threshold; keep it configurable.
      * Truncation to max_train_size is *not* applied here (needs the train split); see `truncate_train_split`.
    """
    rng = np.random.default_rng(random_state)
    Xdf = _as_dataframe(X)
    y = np.asarray(y)
    info: Dict = {}

    # -----------------------------
    # 1) Missing data removal
    # -----------------------------
    miss_rate = Xdf.isna().mean(axis=0)
    drop_cols_missing = miss_rate[miss_rate > missing_col_threshold].index.tolist()
    if drop_cols_missing:
        Xdf = Xdf.drop(columns=drop_cols_missing)
    # drop rows with any missing
    row_ok = ~Xdf.isna().any(axis=1)
    Xdf = Xdf.loc[row_ok].reset_index(drop=True)
    y = y[row_ok.to_numpy()]

    info["dropped_cols_missing"] = drop_cols_missing
    info["n_after_missing"] = int(len(Xdf))

    # -----------------------------
    # 2) Type inference (cat vs num) + conversions
    # -----------------------------
    cat_cols: List[str] = []
    num_cols: List[str] = []

    for c in Xdf.columns:
        s = Xdf[c]
        if _is_categorical_series(s):
            cat_cols.append(c)
            continue
        # Rows with missing values are already gone, so any NaN here was
        # introduced by the coercion itself, i.e. the column is not numeric.
        coerced = pd.to_numeric(s, errors="coerce")
        if coerced.isna().any():
            # Keep the ORIGINAL values: writing the coerced series back would
            # replace every non-numeric entry with NaN.
            cat_cols.append(c)
        else:
            Xdf[c] = coerced
            num_cols.append(c)

    # Convert binary numerical to categorical (2 unique values)
    moved = []
    if convert_binary_numerical_to_categorical:
        for c in list(num_cols):
            if Xdf[c].nunique(dropna=False) == 2:
                Xdf[c] = Xdf[c].astype("category")
                num_cols.remove(c)
                cat_cols.append(c)
                moved.append(c)
    # Always present, so callers can read the key without checking for it.
    info["binary_num_to_cat"] = moved

    # Drop categorical with too many categories (>20)
    drop_cat = [
        c for c in cat_cols
        if Xdf[c].nunique(dropna=False) > max_categorical_cardinality
    ]
    if drop_cat:
        Xdf = Xdf.drop(columns=drop_cat)
        cat_cols = [c for c in cat_cols if c not in drop_cat]
    info["dropped_cat_high_cardinality"] = drop_cat

    # Drop numerical with too few unique values (<10)
    drop_num = [
        c for c in num_cols
        if Xdf[c].nunique(dropna=False) < min_unique_numerical
    ]
    if drop_num:
        Xdf = Xdf.drop(columns=drop_num)
        num_cols = [c for c in num_cols if c not in drop_num]
    info["dropped_num_low_unique"] = drop_num

    info["cat_cols"] = cat_cols
    info["num_cols"] = num_cols
    info["p_after_feature_filters"] = int(Xdf.shape[1])

    # -----------------------------
    # 3) Classification binarization + balancing
    # -----------------------------
    if task == "classification":
        y1d = np.asarray(y).ravel()
        # top-2 most frequent classes
        vals, counts = np.unique(y1d, return_counts=True)
        if len(vals) < 2:
            raise ValueError(
                f"Classification binarization needs at least 2 classes, found {len(vals)}."
            )
        top2 = vals[np.argsort(-counts)[:2]]
        mask = np.isin(y1d, top2)
        Xdf = Xdf.loc[mask].reset_index(drop=True)
        y1d = y1d[mask]

        # balance 50/50: subsample both classes down to the smaller one
        idx0 = np.where(y1d == top2[0])[0]
        idx1 = np.where(y1d == top2[1])[0]
        m = min(len(idx0), len(idx1))
        idx0 = rng.choice(idx0, size=m, replace=False)
        idx1 = rng.choice(idx1, size=m, replace=False)
        keep = np.concatenate([idx0, idx1])
        rng.shuffle(keep)

        Xdf = Xdf.iloc[keep].reset_index(drop=True)
        y = y1d[keep]

        info["binarized_classes"] = top2.tolist()
        info["n_after_balance"] = int(len(Xdf))

    return Xdf, y, info


def _load_scm20d_raw() -> Tuple[pd.DataFrame, np.ndarray]:
    """
    Load SCM20D (multi-target regression) from OpenML - RAW version.

    The dataset with OpenML id 41486 is requested as a single DataFrame
    (``target=None``). The last 16 columns are taken as the regression
    targets and every preceding column as a feature. No cleaning is done
    here, so the result can be fed to the Grinsztajn pipeline.

    Returns
    -------
    X : pd.DataFrame
        Feature columns (all but the last 16).
    Y : np.ndarray
        The last 16 columns converted to a float32 array.
    """
    scm20d_id = 41486
    dataset = openml.datasets.get_dataset(scm20d_id)
    data, _, _, _ = dataset.get_data(
        dataset_format="dataframe",
        target=None
    )

    # Targets are identified purely by position: the trailing n_targets columns.
    n_targets = 16
    Y = data.iloc[:, -n_targets:].to_numpy(dtype=np.float32)
    X = data.iloc[:, :-n_targets]

    return X, Y


def _load_sgemm_raw() -> Tuple[pd.DataFrame, np.ndarray]:
    """
    Load SGEMM GPU kernel performance (multi-target regression) from OpenML - RAW version.

    The dataset with OpenML id 44069 is requested as a single DataFrame
    (``target=None``). The last 4 columns are taken as the regression targets
    and every preceding column as a feature. No cleaning is done here, so the
    result can be fed to the Grinsztajn pipeline.

    Returns
    -------
    X : pd.DataFrame
        Feature columns (all but the last 4).
    Y : np.ndarray
        The last 4 columns converted to a float32 array.
    """
    sgemm_id = 44069
    dataset = openml.datasets.get_dataset(sgemm_id)
    data, _, _, _ = dataset.get_data(
        dataset_format="dataframe",
        target=None
    )

    # Targets are identified purely by position: the trailing n_targets columns.
    n_targets = 4
    Y = data.iloc[:, -n_targets:].to_numpy(dtype=np.float32)
    X = data.iloc[:, :-n_targets]

    return X, Y


def _load_scm1d_raw() -> Tuple[pd.DataFrame, np.ndarray]:
    """
    Load SCM1D (multi-target regression) from OpenML - RAW version.

    The dataset with OpenML id 41485 is requested as a single DataFrame
    (``target=None``). The last 16 columns are taken as the regression
    targets and every preceding column as a feature. No cleaning is done
    here, so the result can be fed to the Grinsztajn pipeline.

    Returns
    -------
    X : pd.DataFrame
        Feature columns (all but the last 16).
    Y : np.ndarray
        The last 16 columns converted to a float32 array.
    """
    scm1d_id = 41485
    dataset = openml.datasets.get_dataset(scm1d_id)
    data, _, _, _ = dataset.get_data(
        dataset_format="dataframe",
        target=None
    )

    # Targets are identified purely by position: the trailing n_targets columns.
    n_targets = 16
    Y = data.iloc[:, -n_targets:].to_numpy(dtype=np.float32)
    X = data.iloc[:, :-n_targets]

    return X, Y


def _load_rf2_raw() -> Tuple[pd.DataFrame, np.ndarray]:
    """
    Load RF2 (multi-target regression) from OpenML.

    The dataset with OpenML id 41484 is requested as a single DataFrame
    (``target=None``). The last 8 columns are taken as the regression targets
    and every preceding column as a feature.

    Unlike the other ``*_raw`` loaders, this one is not entirely raw: rows
    containing NaN or +/-Inf in any feature or target are removed, and a
    message is printed when that happens. The features are cast to float for
    that check, so non-numeric feature columns make the cast raise.

    Returns
    -------
    X : pd.DataFrame
        Feature columns (all but the last 8), restricted to fully finite rows.
    Y : np.ndarray
        The last 8 columns as a float32 array, restricted to the same rows.
    """
    rf2_id = 41484
    dataset = openml.datasets.get_dataset(rf2_id)
    data, _, _, _ = dataset.get_data(
        dataset_format="dataframe",
        target=None
    )

    # Targets are identified purely by position: the trailing n_targets columns.
    n_targets = 8
    Y = data.iloc[:, -n_targets:].to_numpy(dtype=np.float32)
    X = data.iloc[:, :-n_targets]

    # Remove rows with NaN/Inf in either the features or the targets
    X_np = X.values.astype(float)
    valid = np.isfinite(X_np).all(axis=1) & np.isfinite(Y).all(axis=1)
    if not valid.all():
        n_invalid = (~valid).sum()
        print(f"  [RF2] Removing {n_invalid}/{len(X)} samples with NaN/Inf")
        # Explicit row selection with the boolean mask, then renumber rows.
        X = X.loc[valid].reset_index(drop=True)
        Y = Y[valid]

    return X, Y


def _apply_grinsztajn_preprocessing(
    X: pd.DataFrame,
    y: np.ndarray,
    dataset_name: str,
    random_state: int = 42
) -> Tuple[pd.DataFrame, np.ndarray, dict]:
    """
    Run the Grinsztajn et al. "removing side issues" step on a regression
    dataset and print a short report.

    The work is delegated to ``grinsztajn_remove_side_issues`` with fixed
    settings: regression task, missing-column threshold 0.5, at most 20
    categories per categorical column, at least 10 distinct values per
    numerical column, and binary numerical columns converted to categorical.

    Parameters
    ----------
    X : pd.DataFrame
        Raw feature matrix
    y : np.ndarray
        Raw target array
    dataset_name : str
        Name of dataset (only used in the printed messages)
    random_state : int
        Random seed forwarded to the preprocessing function

    Returns
    -------
    X_clean : pd.DataFrame
        Cleaned feature matrix
    y_clean : np.ndarray
        Cleaned target array
    info : dict
        Bookkeeping dictionary returned by the preprocessing function
    """
    print(f"\n[{dataset_name}] Applying Grinsztajn preprocessing...")
    print(f"  Original shape: X={X.shape}, y={y.shape}")

    # max_train_size is passed through for completeness; truncation happens
    # later, at split time.
    cleaned = grinsztajn_remove_side_issues(
        X=X,
        y=y,
        task="regression",
        missing_col_threshold=0.5,
        max_train_size=10_000,
        random_state=random_state,
        max_categorical_cardinality=20,
        min_unique_numerical=10,
        convert_binary_numerical_to_categorical=True,
    )
    X_clean, y_clean, info = cleaned

    print(f"  After preprocessing: X={X_clean.shape}, y={y_clean.shape}")
    # (label, info key) pairs, reported in this order as "<label>: <count>".
    report_rows = (
        ("Categorical columns", "cat_cols"),
        ("Numerical columns", "num_cols"),
        ("Dropped columns (missing)", "dropped_cols_missing"),
        ("Dropped columns (high cardinality)", "dropped_cat_high_cardinality"),
        ("Dropped columns (low unique)", "dropped_num_low_unique"),
    )
    for label, key in report_rows:
        print(f"  {label}: {len(info[key])}")

    return X_clean, y_clean, info


def _read_arff_url(url: str) -> pd.DataFrame:
    """
    Download an ARFF file from a URL and parse it into a DataFrame.

    The request uses a 120 s timeout and raises on HTTP error statuses. A
    ``RuntimeError`` is raised when the body has no ``@DATA`` marker
    (case-insensitive), which is taken to mean something other than an ARFF
    file was downloaded. Object columns whose first value is a byte string
    are decoded to ``str`` as UTF-8.
    """
    response = requests.get(url, timeout=120)
    response.raise_for_status()
    payload = response.text

    looks_like_arff = "@DATA" in payload.upper()
    if not looks_like_arff:
        raise RuntimeError("Le contenu téléchargé ne ressemble pas à un .arff (probable HTML/redirect).")

    records, _meta = arff.loadarff(StringIO(payload))
    frame = pd.DataFrame(records)

    # Decode bytes -> str where needed; an empty frame has no first value to inspect.
    if len(frame):
        for col in frame.columns:
            column = frame[col]
            if column.dtype == object and isinstance(column.iloc[0], (bytes, bytearray)):
                frame[col] = column.str.decode("utf-8")
    return frame


def _load_rf1() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load the pre-split RF1 multi-target regression data from Mulan.

    The train and test ARFF files are downloaded separately. In each, the
    last 8 columns are the targets and the rest are features. Processing:

    1. Drop rows with NaN/Inf in the targets, then rows with NaN/Inf in the
       features (train and test handled independently).
    2. Drop feature columns that are constant on the train split; the same
       columns are removed from the test split.
    3. Standardize features with a ``StandardScaler`` fitted on train only.
    4. Print diagnostics comparing the train and test distributions.

    Returns
    -------
    X_train, Y_train, X_test, Y_test : np.ndarray
        float32 arrays (already split).

    Raises
    ------
    ValueError
        If either feature matrix contains NaN/Inf after standardization.
    """
    base_url = "https://sourceforge.net/projects/mulan/files/datasets/multi-target%20regression%20datasets"
    n_outputs = 8

    def _drop_flagged(features, targets, keep, split, where):
        """Drop rows whose `keep` flag is False, announcing the count first."""
        if keep.all():
            return features, targets
        n_bad = (~keep).sum()
        print(f"  WARNING: Removing {n_bad}/{len(keep)} {split} samples with NaN/Inf in {where}")
        return features[keep], targets[keep]

    print("  Downloading RF1 train set...")
    train_frame = _read_arff_url(f"{base_url}/rf1-train.arff/download")
    print("  Downloading RF1 test set...")
    test_frame = _read_arff_url(f"{base_url}/rf1-test.arff/download")

    X_train_df, Y_train_df = train_frame.iloc[:, :-n_outputs], train_frame.iloc[:, -n_outputs:]
    X_test_df, Y_test_df = test_frame.iloc[:, :-n_outputs], test_frame.iloc[:, -n_outputs:]

    # Check and clean the targets
    print("  Checking for NaN/Inf in Y...")
    Y_train = Y_train_df.to_numpy(dtype=np.float32)
    Y_test = Y_test_df.to_numpy(dtype=np.float32)

    X_train_df, Y_train = _drop_flagged(
        X_train_df, Y_train, np.isfinite(Y_train).all(axis=1), "train", "Y"
    )
    X_test_df, Y_test = _drop_flagged(
        X_test_df, Y_test, np.isfinite(Y_test).all(axis=1), "test", "Y"
    )

    # Check the features before standardization
    print("  Checking for NaN/Inf in X...")
    train_values = X_train_df.values.astype(float)
    test_values = X_test_df.values.astype(float)

    X_train_df, Y_train = _drop_flagged(
        X_train_df, Y_train, np.isfinite(train_values).all(axis=1), "train", "X"
    )
    X_test_df, Y_test = _drop_flagged(
        X_test_df, Y_test, np.isfinite(test_values).all(axis=1), "test", "X"
    )

    # Find constant columns on the train split only, and apply the same
    # column selection to the test split.
    distinct_counts = X_train_df.nunique(dropna=False)
    varying_cols = [col for col, n_distinct in distinct_counts.items() if n_distinct > 1]
    X_train_df = X_train_df[varying_cols]
    X_test_df = X_test_df[varying_cols]

    # Standardize: fit on train, transform both train and test
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train_df.values.astype(float)).astype(np.float32)
    X_test = scaler.transform(X_test_df.values.astype(float)).astype(np.float32)

    # Re-check after standardization
    train_finite = np.isfinite(X_train)
    if not train_finite.all():
        print("  ERROR: X_train contains NaN/Inf after standardization!")
        print(f"    NaN count: {np.isnan(X_train).sum()}, Inf count: {np.isinf(X_train).sum()}")
        # Report which columns are affected
        bad_cols = np.where(~train_finite.all(axis=0))[0]
        print(f"    Problematic columns: {bad_cols}")
        raise ValueError("X_train contains NaN/Inf after standardization")

    if not np.isfinite(X_test).all():
        print("  ERROR: X_test contains NaN/Inf after standardization!")
        raise ValueError("X_test contains NaN/Inf after standardization")

    # DIAGNOSTIC: compare the train and test distributions
    print("\n  [DIAGNOSTIC] Checking train/test distribution:")
    for label, targets in (("Y_train:", Y_train), ("Y_test:", Y_test)):
        # Labels are left-padded to 8 characters so both lines align.
        print(f"    {label:<8} mean={targets.mean(axis=0).mean():.4f}, std={targets.std(axis=0).mean():.4f}, "
              f"min={targets.min():.4f}, max={targets.max():.4f}")
    for label, features in (("X_train:", X_train), ("X_test:", X_test)):
        print(f"    {label:<8} mean={features.mean():.4f}, std={features.std():.4f}")

    # Flag large differences between the train and test target distributions
    train_mean, train_std = Y_train.mean(), Y_train.std()
    y_mean_ratio = Y_test.mean() / train_mean if train_mean != 0 else np.inf
    y_std_ratio = Y_test.std() / train_std if train_std != 0 else np.inf

    if abs(y_mean_ratio - 1.0) > 0.5 or abs(y_std_ratio - 1.0) > 0.5:
        print("  WARNING: Train/Test distributions differ significantly!")
        print(f"    Y mean ratio (test/train): {y_mean_ratio:.2f}")
        print(f"    Y std ratio (test/train): {y_std_ratio:.2f}")

    print(f"\n  Final shapes: X_train={X_train.shape}, Y_train={Y_train.shape}, X_test={X_test.shape}, Y_test={Y_test.shape}")

    return X_train, Y_train, X_test, Y_test


def _load_rf1_resplit(iqr_multiplier: float = 5.0) -> Tuple[pd.DataFrame, np.ndarray]:
    """
    Load RF1 and combine train+test, then return for custom splitting.
    Use this if pre-split has distribution mismatch issues.

    Steps: download both Mulan ARFF files, concatenate them, take the last 8
    columns as targets, drop rows with NaN/Inf in features or targets, then
    drop rows that are extreme outliers in any target column.

    Parameters
    ----------
    iqr_multiplier : float
        A target value is an outlier when it lies more than
        ``iqr_multiplier`` times the interquartile range below Q1 or above
        Q3 of its column. The default of 5 is deliberately wider than the
        usual 1.5 so that only extreme values are removed.

    Returns
    -------
    X_df : pd.DataFrame
        Combined features (not split, not standardized).
    Y : np.ndarray
        Combined float32 targets aligned with ``X_df``.
    """
    MULAN_BASE = "https://sourceforge.net/projects/mulan/files/datasets/multi-target%20regression%20datasets"
    train_url = f"{MULAN_BASE}/rf1-train.arff/download"
    test_url = f"{MULAN_BASE}/rf1-test.arff/download"
    n_targets = 8

    print("  Downloading RF1 train set...")
    df_train = _read_arff_url(train_url)
    print("  Downloading RF1 test set...")
    df_test = _read_arff_url(test_url)

    # Combine train and test
    df_combined = pd.concat([df_train, df_test], ignore_index=True)

    X_df = df_combined.iloc[:, :-n_targets]
    Y_df = df_combined.iloc[:, -n_targets:]

    # Clean NaN/Inf
    X = X_df.values.astype(float)
    Y = Y_df.to_numpy(dtype=np.float32)

    valid = np.isfinite(X).all(axis=1) & np.isfinite(Y).all(axis=1)
    if not valid.all():
        n_invalid = (~valid).sum()
        print(f"  Removing {n_invalid}/{len(X)} samples with NaN/Inf")
        X_df = X_df[valid].reset_index(drop=True)
        Y = Y[valid]

    # Remove extreme outliers using IQR method (per column)
    print(f"  Detecting and removing extreme outliers in Y...")
    outlier_mask = np.zeros(len(Y), dtype=bool)

    for col_idx in range(Y.shape[1]):
        y_col = Y[:, col_idx]
        q1 = np.percentile(y_col, 25)
        q3 = np.percentile(y_col, 75)
        iqr = q3 - q1

        if iqr == 0:
            # With a zero IQR both bounds collapse onto a single value and
            # every row that differs from it would be flagged. The rule is
            # meaningless for such a column, so leave it unfiltered.
            print(f"    Column {col_idx}: zero IQR, outlier filter skipped")
            continue

        lower_bound = q1 - iqr_multiplier * iqr
        upper_bound = q3 + iqr_multiplier * iqr

        col_outliers = (y_col < lower_bound) | (y_col > upper_bound)
        if col_outliers.any():
            n_outliers = col_outliers.sum()
            outlier_values = y_col[col_outliers]
            print(f"    Column {col_idx}: {n_outliers} outliers detected (values: {outlier_values[:5]}...)")
            print(f"      IQR bounds: [{lower_bound:.2f}, {upper_bound:.2f}]")

        # A row is removed if it is an outlier in ANY target column.
        outlier_mask |= col_outliers

    if outlier_mask.any():
        n_outliers_total = outlier_mask.sum()
        print(f"  Removing {n_outliers_total}/{len(Y)} samples with extreme outliers")
        X_df = X_df[~outlier_mask].reset_index(drop=True)
        Y = Y[~outlier_mask]

    print(f"  Combined dataset: {X_df.shape[0]} samples, {X_df.shape[1]} features, {Y.shape[1]} outputs")

    return X_df, Y

def load_dataset(
    name: str, 
    use_rf1_resplit: bool = False,
    use_grinsztajn_preprocessing: bool = False,
    random_state: int = 42,
    **kwargs
) -> Tuple:
    """
    Load one of the benchmarks with optional Grinsztajn preprocessing.

    Accepted names (case-insensitive):
        - 'scm20d'  -> multi-target regression, 16 targets
        - 'scm1d'   -> multi-target regression, 16 targets
        - 'sgemm'   -> multi-target regression, 4 targets
        - 'rf2'     -> multi-target regression, 8 targets
        - 'rf1'     -> multi-target regression, 8 targets
        - 'emnist' or 'emnist_balanced' -> EMNIST "balanced" split

    Parameters
    ----------
    name : str
        Dataset name
    use_rf1_resplit : bool
        If True and name=='rf1', combine train+test for custom splitting
    use_grinsztajn_preprocessing : bool
        If True, apply Grinsztajn et al. preprocessing pipeline
        (removing missing data, filtering features, etc.). Has no effect
        for the pre-split 'rf1' variant or for EMNIST.
    random_state : int
        Random seed for preprocessing
    **kwargs : dict
        EMNIST options: 'root' (default "./data_cache/emnist"), 'train'
        (True), 'download' (False), 'max_rows' (None), 'seed' (42).
        Other keys are ignored.

    Returns
    -------
    For scm20d/scm1d/sgemm/rf2, and rf1 with use_rf1_resplit=True:
        X, y as returned by the loader, passed through the preprocessing
        step when it is enabled.

    For rf1 with use_rf1_resplit=False:
        X_train, Y_train, X_test, Y_test

    For EMNIST:
        X_df, y, classes

    Raises
    ------
    ValueError
        If the name is not one of the accepted names.
    """
    name = name.lower()

    if name in ("scm20d", "sgemm", "scm1d", "rf2"):
        # These four share the same flow: raw load, then optional cleaning.
        raw_loaders = {
            "scm20d": _load_scm20d_raw,
            "sgemm": _load_sgemm_raw,
            "scm1d": _load_scm1d_raw,
            "rf2": _load_rf2_raw,
        }
        X, y = raw_loaders[name]()
        if use_grinsztajn_preprocessing:
            X, y, _ = _apply_grinsztajn_preprocessing(X, y, name, random_state)
        return X, y

    if name == "rf1":
        if not use_rf1_resplit:
            return _load_rf1()
        X, y = _load_rf1_resplit()
        if use_grinsztajn_preprocessing:
            X, y, _ = _apply_grinsztajn_preprocessing(X, y, "rf1", random_state)
        return X, y

    if name in ("emnist", "emnist_balanced"):
        return load_emnist_balanced_as_df(
            root=kwargs.get("root", "./data_cache/emnist"),
            train=kwargs.get("train", True),
            download=kwargs.get("download", False),
            max_rows=kwargs.get("max_rows", None),
            seed=kwargs.get("seed", 42),
        )

    raise ValueError(
        f"Unknown dataset name '{name}'. "
        f"Expected one of: 'scm20d', 'scm1d', 'sgemm', 'rf1', 'rf2', 'emnist'."
    )


def load_emnist_balanced_as_df(
    root: str = "./data_cache",
    train: bool = True,
    download: bool = False,
    max_rows: int | None = None,
    seed: int = 0,
) -> tuple[pd.DataFrame, pd.Series, list[str]]:
    """
    Load the EMNIST "balanced" split and return it in tabular form.

    Parameters
    ----------
    root : str
        Directory passed to the ``EMNIST`` dataset constructor.
    train : bool
        If True, load the training set; otherwise the test set.
    download : bool
        Forwarded to ``EMNIST``; whether it may download missing data.
    max_rows : int or None
        If given and smaller than the dataset size, keep a random subset of
        this many rows.
    seed : int
        Seed of the torch generator that draws the subset.

    Returns
    -------
    X_df : pd.DataFrame
        One row per image, pixels flattened into columns ``px_0``, ``px_1``,
        ..., stored as uint8.
    y : pd.Series
        int64 labels, named "label".
    classes : list[str]
        The dataset's ``classes`` attribute.

    Raises
    ------
    RuntimeError
        If constructing the dataset fails for any reason; the original
        exception is chained as the cause.
    """
    try:
        dataset = EMNIST(root=root, split="balanced", train=train, download=download)
    except Exception as e:
        raise RuntimeError(
            f"EMNIST introuvable dans {root}. "
            f"Si tu veux rester sans réseau, copie/monte le cache localement. "
            f"Sinon, lance une fois avec download=True."
        ) from e

    images, labels, classes = dataset.data, dataset.targets, dataset.classes

    n_total = images.shape[0]
    wants_subsample = max_rows is not None and max_rows < n_total
    if wants_subsample:
        # The first max_rows entries of a seeded permutation give a
        # reproducible sample without replacement.
        generator = torch.Generator().manual_seed(seed)
        picked = torch.randperm(n_total, generator=generator)[:max_rows]
        images, labels = images[picked], labels[picked]

    # Flatten every image into a single row of pixels.
    pixels = images.reshape(images.shape[0], -1).cpu().numpy().astype(np.uint8)
    label_values = labels.cpu().numpy().astype(np.int64)

    pixel_names = [f"px_{i}" for i in range(pixels.shape[1])]
    X_df = pd.DataFrame(pixels, columns=pixel_names)
    y_s = pd.Series(label_values, name="label")
    return X_df, y_s, classes


def summarize_classif_df(name: str, X: pd.DataFrame, y: pd.Series, classes: list[str], topk: int = 10) -> None:
    """
    Print summary statistics for a classification dataset.

    The report contains the shapes, the first few distinct feature dtypes,
    observed vs. catalogued class counts, approximate memory use, missing
    value counts, and the ``topk`` most frequent labels with their counts
    and shares of all rows.

    Parameters
    ----------
    name : str
        Dataset name for display.
    X : pd.DataFrame
        Feature matrix.
    y : pd.Series
        Labels.
    classes : list[str]
        Class names. A label that is a valid integer index into this list is
        shown with the corresponding name; any other label is shown as-is.
    topk : int
        Number of top classes to display in distribution.
    """
    print("=" * 90)
    print(name)
    print(f"X: shape={X.shape}, dtypes={X.dtypes.unique()[:5]}")
    print(f"y: shape={y.shape}, n_classes_observees={y.nunique()}, n_classes_catalogue={len(classes)}")
    mem_mb = (X.memory_usage(deep=True).sum() + y.memory_usage(deep=True)) / (1024**2)
    print(f"memoire approx (X+y): {mem_mb:.1f} MB")
    print(f"missing: X={int(X.isnull().sum().sum())}, y={int(y.isnull().sum())}")
    print("y distribution (top classes):")

    def _display_name(label) -> str:
        """Map a label to its class name when it is a valid index into `classes`."""
        try:
            idx = int(label)
        except (TypeError, ValueError):
            return str(label)
        # Require an exact, non-negative match so that e.g. 1.5 or -1 never
        # silently picks an unrelated entry.
        if idx == label and 0 <= idx < len(classes):
            return str(classes[idx])
        return str(label)

    n_total = len(y)
    # value_counts() is sorted by decreasing frequency and ignores missing labels.
    top_counts = y.value_counts().head(max(topk, 0))
    for label, count in top_counts.items():
        print(f"  {label} ({_display_name(label)}): {count} ({count / n_total:.2%})")


@dataclass
class PrepArtifacts:
    """
    Everything you may want to reuse later (transformers, column names, etc.).

    Plain data container: the dataclass defines fields only, with no
    validation or methods of its own.
    """
    # Task label. NOTE(review): presumably "classification" or "regression",
    # the values used for `task` elsewhere in this module -- confirm at the
    # construction site.
    task: str
    # Model family label. NOTE(review): presumably "nn", "tree" or "linear" -- confirm.
    model_family: str
    # Whether the downstream model consumes categorical columns directly.
    handles_categoricals_natively: bool
    feature_preprocessor: Optional[Any]          # ColumnTransformer or custom callable
    y_transformer: Optional[Any]                 # transformer with transform/inverse_transform
    # Column name lists. NOTE(review): presumably the categorical / numerical
    # feature columns kept by preprocessing -- confirm at the construction site.
    cat_cols: list[str]
    num_cols: list[str]


def _as_dataframe(X: Union[np.ndarray, pd.DataFrame]) -> pd.DataFrame:
    """
    Coerce the features to a DataFrame the caller is free to modify.

    DataFrames are returned as a copy. Other inputs go through
    ``np.asarray`` and get positional column names ``x0``, ``x1``, ...; the
    number of columns is read from ``shape[1]``, so a 2-D input is expected.
    """
    if isinstance(X, pd.DataFrame):
        frame = X.copy()
    else:
        matrix = np.asarray(X)
        n_features = matrix.shape[1]
        frame = pd.DataFrame(matrix, columns=["x" + str(j) for j in range(n_features)])
    return frame


def _safe_onehot_encoder() -> OneHotEncoder:
    """
    Build a dense-output OneHotEncoder that ignores unknown categories.

    The keyword that requests dense output was renamed across scikit-learn
    versions (``sparse_output`` in >=1.2, ``sparse`` before). The newer name
    is tried first, and the older one is used if it is rejected with a
    ``TypeError``.
    """
    shared_options = {"handle_unknown": "ignore"}
    try:
        encoder = OneHotEncoder(sparse_output=False, **shared_options)
    except TypeError:
        encoder = OneHotEncoder(sparse=False, **shared_options)
    return encoder


def _truncate_rows(Xdf: pd.DataFrame, y: np.ndarray, max_rows: int, rng: np.random.Generator):
    """
    Randomly subsample rows down to at most ``max_rows``.

    When ``max_rows`` is None or the frame already fits, the inputs are
    returned untouched (same objects). Otherwise ``max_rows`` distinct row
    positions are drawn from ``rng``; the frame is returned with a fresh
    0..n-1 index and ``y`` as an array indexed by the same positions.
    """
    n_rows = len(Xdf)
    needs_sampling = max_rows is not None and n_rows > max_rows
    if not needs_sampling:
        return Xdf, y

    chosen = rng.choice(n_rows, size=max_rows, replace=False)
    sampled_X = Xdf.iloc[chosen].reset_index(drop=True)
    sampled_y = np.asarray(y)[chosen]
    return sampled_X, sampled_y


def _is_heavy_tailed_auto(y: np.ndarray) -> bool:
    """
    Heuristic check for 'house-prices-like' heavy-tailed targets.

    This rule is NOT specified in the paper. It answers True only when every
    value is finite and strictly positive and the ratio of the 99th
    percentile to the median, averaged over the output columns, exceeds 50.
    For strict control, set y_transform_mode explicitly instead of relying
    on this.
    """
    values = np.asarray(y)
    # Treat a 1-D target as a single output column.
    columns = values.reshape(-1, 1) if values.ndim == 1 else values

    has_non_finite = not np.all(np.isfinite(columns))
    if has_non_finite or np.any(columns <= 0):
        return False

    upper = np.quantile(columns, 0.99, axis=0)
    median = np.quantile(columns, 0.50, axis=0)
    # The floor on the median guards the division against tiny denominators.
    tail_ratio = np.mean(upper / np.maximum(median, 1e-12))
    return tail_ratio > 50.0


def build_regression_target_transformer(
    y: np.ndarray,
    mode: str = "none",
    random_state: int = 0
):
    """
    Build the (optional) transformer applied to regression targets.

    mode:
      - "none": identity
      - "log": log1p/expm1 (requires y>=0 elementwise)
      - "gauss": QuantileTransformer(output_distribution="normal") on y
      - "log+gauss": log1p then gauss
      - "auto": apply "log" iff heavy-tailed heuristic triggers, else "none"

    Parameters
    ----------
    y : np.ndarray
        Targets; only inspected when ``mode == "auto"``.
    mode : str
        One of the modes listed above.
    random_state : int
        Seed passed to the QuantileTransformer.

    Returns
    -------
    sklearn Pipeline with fit/transform/inverse_transform, or None when the
    resolved mode is "none".

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values. Without this check a
        misspelled mode would silently disable the transform.
    """
    valid_modes = ("none", "log", "gauss", "log+gauss", "auto")
    if mode not in valid_modes:
        raise ValueError(
            f"Unknown target transform mode {mode!r}; expected one of {valid_modes}."
        )

    if mode == "auto":
        mode = "log" if _is_heavy_tailed_auto(y) else "none"

    steps = []
    if mode in ("log", "log+gauss"):
        # log1p works for y>=0; if your data can be negative, don't use this.
        steps.append(("log1p", FunctionTransformer(np.log1p, inverse_func=np.expm1, validate=False)))

    if mode in ("gauss", "log+gauss"):
        qt = QuantileTransformer(
            output_distribution="normal",
            random_state=random_state
        )
        steps.append(("qt", qt))

    if not steps:
        return None

    return Pipeline(steps)


def wrap_regressor_with_target_transform(
    base_regressor,
    y_transformer: Optional[Any]
):
    """
    Optionally wrap a regressor so that it is trained on transformed targets.

    Paper uses TransformedTargetRegressor + QuantileTransformer as an
    *hyperparameter option*. With a transformer, the regressor is wrapped in
    a ``TransformedTargetRegressor``; with ``None`` the regressor is handed
    back unchanged (same object).
    """
    if y_transformer is not None:
        return TransformedTargetRegressor(regressor=base_regressor, transformer=y_transformer)
    return base_regressor


def grinsztajn_full_paper_preparation(
    X: Union[np.ndarray, pd.DataFrame],
    y: np.ndarray,
    *,
    task: str,  # "classification" or "regression"
    model_family: str,  # "nn" or "tree" or "linear" (only used to pick feature transforms)
    handles_categoricals_natively: bool,
    # medium regime: 10_000 ; large regime (appendix): 50_000
    max_train_size: int = 10_000,
    max_val_size: int = 50_000,
    max_test_size: int = 50_000,
    random_state: int = 0,
    # --- "Removing side issues" params (3.2) ---
    missing_col_threshold: float = 0.5,   # paper does not give a numeric threshold
    max_categorical_cardinality: int = 20,
    min_unique_numerical: int = 10,
    convert_binary_numerical_to_categorical: bool = True,
    # --- "Data preparation" params (3.5) ---
    gaussianize_features_for_nn: bool = True,
    y_transform_mode: str = "none",  # "none"|"log"|"gauss"|"log+gauss"|"auto"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, PrepArtifacts]:
    """
    Implements the paper's end-to-end *data* preparation:
      1) Remove side issues (Sec 3.2): missing, balance/binarize classif, low/high-card filters
      2) Split (App A.3): train = 70% OR capped by max_train_size; remaining -> val 30% / test 70%
         + truncate val/test to 50k
      3) Data preparation (Sec 3.5):
         - For NN: Gaussianize numerical features with QuantileTransformer
         - OneHotEncoder for models that don't handle categoricals natively
         - Regression target transform option (log heavy-tailed; optional gaussianized target)

    All fitted objects (feature transformers, category vocabularies, target
    transformer) are learned on the training split only.

    Returns
    -------
    (X_train, y_train, X_val, y_val, X_test, y_test, artifacts)
        Numpy arrays for each split plus a PrepArtifacts holding the fitted
        transformers and the categorical / numerical column lists. When
        categoricals are handled natively, ``artifacts.feature_preprocessor``
        is the tuple ``(num_ct, cat_fn, num_cols, cat_cols)``; otherwise it is
        the fitted ColumnTransformer.

    Raises
    ------
    ValueError
        If fewer than 10 samples remain after the side-issue removal step.
    """
    rng = np.random.default_rng(random_state)
    Xdf = _as_dataframe(X)
    y = np.asarray(y)

    # 1) Removing side issues (delegated to grinsztajn_remove_side_issues)
    X_clean, y_clean, info = grinsztajn_remove_side_issues(
        X=Xdf,
        y=y,
        task=task,
        missing_col_threshold=missing_col_threshold,
        max_train_size=max_train_size,  # the train cap is also enforced explicitly below
        random_state=random_state,
        max_categorical_cardinality=max_categorical_cardinality,
        min_unique_numerical=min_unique_numerical,
        convert_binary_numerical_to_categorical=convert_binary_numerical_to_categorical,
    )
    cat_cols = list(info.get("cat_cols", []))
    num_cols = list(info.get("num_cols", []))

    n = len(X_clean)
    if n < 10:
        raise ValueError(f"Too few samples after preprocessing: n={n}")

    # 2) Train / Val / Test split (A.3)
    # Train size: 70% of the data unless that exceeds max_train_size.
    # An integer count is passed to train_test_split: converting the cap to a
    # float fraction (max_train_size / n) can floor to max_train_size - 1.
    n_train = min(int(np.floor(0.70 * n)), int(max_train_size))

    # Split off train
    strat = y_clean if task == "classification" and np.asarray(y_clean).ndim == 1 else None
    X_train, X_tmp, y_train, y_tmp = train_test_split(
        X_clean, y_clean,
        train_size=n_train,
        random_state=random_state,
        shuffle=True,
        stratify=strat,
    )

    # Remaining rows: val = 30% of the remainder, test = 70% of the remainder
    strat_tmp = y_tmp if task == "classification" and np.asarray(y_tmp).ndim == 1 else None
    X_val, X_test, y_val, y_test = train_test_split(
        X_tmp, y_tmp,
        train_size=0.30,
        random_state=random_state + 1,
        shuffle=True,
        stratify=strat_tmp,
    )

    # Enforce explicit truncations (paper: val/test truncated to 50k; train capped by max_train_size)
    X_train, y_train = _truncate_rows(X_train, y_train, max_train_size, rng)
    X_val, y_val = _truncate_rows(X_val, y_val, max_val_size, rng)
    X_test, y_test = _truncate_rows(X_test, y_test, max_test_size, rng)

    # 3) Data preparation (3.5): build feature preprocessor fitted on train only
    transformers = []
    if num_cols:
        if model_family.lower() == "nn" and gaussianize_features_for_nn:
            qtX = QuantileTransformer(
                output_distribution="normal",
                random_state=random_state
            )
            transformers.append(("num", qtX, num_cols))
        else:
            # Paper does not mandate scaling for non-NN models; keep passthrough.
            transformers.append(("num", "passthrough", num_cols))

    if cat_cols and handles_categoricals_natively:
        # Keep categoricals as integer codes (compact + compatible with many libs).
        # The category vocabulary is frozen on the training split so that a given
        # category maps to the same integer in train, val and test. Re-deriving
        # the categories per split would shift codes whenever a split happens to
        # miss a category. Values unseen in train are encoded as -1.
        train_categories = {}
        for c in cat_cols:
            s = X_train[c]
            if not isinstance(s.dtype, pd.CategoricalDtype):
                s = s.astype("category")
            train_categories[c] = s.cat.categories

        def _cat_to_codes(df: pd.DataFrame) -> np.ndarray:
            out = [
                np.asarray(
                    pd.Categorical(df[c], categories=train_categories[c]).codes
                ).reshape(-1, 1)
                for c in cat_cols
            ]
            return np.concatenate(out, axis=1).astype(np.int64)

        # Numerical columns still go through a ColumnTransformer; the codes are
        # appended manually because _cat_to_codes is a plain function.
        num_ct = ColumnTransformer(transformers=transformers, remainder="drop", sparse_threshold=0.0)
        num_ct.fit(X_train)

        def _numeric_block(df: pd.DataFrame) -> np.ndarray:
            if num_cols:
                return np.asarray(num_ct.transform(df))
            return np.zeros((len(df), 0), dtype=np.float32)

        X_train_np = np.concatenate([_numeric_block(X_train), _cat_to_codes(X_train)], axis=1)
        X_val_np   = np.concatenate([_numeric_block(X_val), _cat_to_codes(X_val)], axis=1)
        X_test_np  = np.concatenate([_numeric_block(X_test), _cat_to_codes(X_test)], axis=1)

        fitted_feature_preprocessor = (num_ct, _cat_to_codes, num_cols, cat_cols)
    else:
        if cat_cols:
            transformers.append(("cat", _safe_onehot_encoder(), cat_cols))
        feature_preprocessor = ColumnTransformer(transformers=transformers, remainder="drop", sparse_threshold=0.0)
        feature_preprocessor.fit(X_train)
        X_train_np = np.asarray(feature_preprocessor.transform(X_train))
        X_val_np   = np.asarray(feature_preprocessor.transform(X_val))
        X_test_np  = np.asarray(feature_preprocessor.transform(X_test))
        fitted_feature_preprocessor = feature_preprocessor

    # Target transform (regression only) — paper: log for heavy-tailed, + optional gaussianized y via TTR+QT
    y_transformer = None
    y_train_fit, y_val_fit, y_test_fit = y_train, y_val, y_test
    if task == "regression":
        y_transformer = build_regression_target_transformer(
            y_train,
            mode=y_transform_mode,
            random_state=random_state
        )
        if y_transformer is not None:
            # QuantileTransformer only accepts 2-D input, so 1-D targets are
            # presented as a single column and flattened back afterwards.
            def _as_column(values) -> np.ndarray:
                arr = np.asarray(values)
                return arr.reshape(-1, 1) if arr.ndim == 1 else arr

            def _transform_target(values) -> np.ndarray:
                out = np.asarray(y_transformer.transform(_as_column(values)))
                return out.ravel() if np.asarray(values).ndim == 1 else out

            y_transformer.fit(_as_column(y_train))
            y_train_fit = _transform_target(y_train)
            y_val_fit   = _transform_target(y_val)
            y_test_fit  = _transform_target(y_test)

    artifacts = PrepArtifacts(
        task=task,
        model_family=model_family,
        handles_categoricals_natively=handles_categoricals_natively,
        feature_preprocessor=fitted_feature_preprocessor,
        y_transformer=y_transformer,
        cat_cols=cat_cols,
        num_cols=num_cols,
    )

    return (
        X_train_np, np.asarray(y_train_fit),
        X_val_np,   np.asarray(y_val_fit),
        X_test_np,  np.asarray(y_test_fit),
        artifacts
    )

from dataclasses import dataclass, asdict
from typing import Dict, Tuple, Optional, Literal
import numpy as np


# Names of the mean-function families that generate_synthetic_multioutput
# dispatches on via SyntheticConfig.kind.
SyntheticKind = Literal["linear_lowrank", "nonlinear_mlp", "regime_switch"]


@dataclass
class SyntheticConfig:
    """
    Settings consumed by generate_synthetic_multioutput.

    Features are standard-normal; the target is a mean function selected by
    ``kind`` plus noise that can be correlated across outputs, heteroskedastic,
    heavy-tailed, and contaminated with outliers.
    """
    # Which mean function to generate (see SyntheticKind).
    kind: SyntheticKind = "linear_lowrank"
    n_samples: int = 12000
    n_features: int = 30
    n_outputs: int = 20

    # Structure (signal)
    # The generator clamps rank to [1, min(n_features, n_outputs)].
    rank: int = 5                 # low-rank structure for outputs
    # Blend weight between the linear latent and tanh(latent); applied only when > 0.
    nonlinear_strength: float = 0.0  # used in linear_lowrank (tanh on latent)

    # Noise
    noise_std: float = 0.5
    # Equicorrelation across outputs; no correlation is applied when <= 0,
    # and the value is clipped to at most 0.999 when building the covariance.
    output_noise_corr: float = 0.25   # rho in [0, 1)
    heteroskedastic: bool = True
    hetero_strength: float = 0.75     # multiplies noise by (1 + hetero_strength * |x·v|)
    heavy_tailed: bool = False        # Student-t noise
    t_df: float = 3.0                 # degrees of freedom of the Student-t draw

    # Outliers
    # floor(outlier_frac * n_samples) rows receive extra Gaussian noise
    # scaled by outlier_scale on every output.
    outlier_frac: float = 0.0
    outlier_scale: float = 8.0

    # Nonlinear MLP specifics
    # Hidden width of the random ReLU network; the generator uses at least 8.
    hidden_dim: int = 64

    # Regime switch specifics
    regime_margin: float = 0.0        # threshold for x[:,0] + margin > 0

    # Reproducibility
    seed: int = 42
    # dtype that the returned X and y are cast to.
    dtype: str = "float32"


def _rng(seed: int) -> np.random.Generator:
    """Create a fresh NumPy random Generator seeded with ``seed``."""
    generator = np.random.default_rng(seed)
    return generator


def _make_output_noise_cov(d: int, rho: float, eps: float = 1e-6) -> np.ndarray:
    """
    Build a (d, d) equicorrelation matrix with a small diagonal jitter.

    ``rho`` is clipped to [0, 0.999]. Off-diagonal entries equal the clipped
    rho; ``eps`` is added on the diagonal.
    """
    clipped = float(np.clip(rho, 0.0, 0.999))
    identity = np.eye(d)
    equicorr = (1.0 - clipped) * identity + clipped * np.ones((d, d))
    return equicorr + eps * identity


def _sample_noise(
    rng: np.random.Generator,
    n: int,
    d: int,
    std: np.ndarray,                 # shape (n, 1) or (n, d)
    cov: Optional[np.ndarray],
    heavy_tailed: bool,
    t_df: float,
) -> np.ndarray:
    """
    Draw an (n, d) noise matrix.

    Entries are first drawn i.i.d. (Student-t with ``t_df`` degrees of freedom
    when ``heavy_tailed``, standard normal otherwise). If ``cov`` is given they
    are mixed across outputs with its Cholesky factor. The result is finally
    multiplied elementwise by ``std``.
    """
    shape = (n, d)
    draws = rng.standard_t(df=t_df, size=shape) if heavy_tailed else rng.standard_normal(size=shape)

    if cov is None:
        return draws * std

    # Right-multiplying by the transposed lower Cholesky factor correlates the columns.
    chol = np.linalg.cholesky(cov)
    return (draws @ chol.T) * std


def generate_synthetic_multioutput(
    cfg: SyntheticConfig,
) -> Tuple[np.ndarray, np.ndarray, Dict]:
    """
    Draw a synthetic multi-output regression dataset described by ``cfg``.

    Standard-normal features are pushed through the mean function selected by
    ``cfg.kind``, then noise (optionally heteroskedastic, correlated across
    outputs and heavy-tailed) and optional outliers are added to the targets.

    Returns
    -------
    X : (n_samples, n_features)
    y : (n_samples, n_outputs)
    meta : dict with config + some ground-truth params

    Raises
    ------
    ValueError
        If ``cfg.kind`` is not one of the supported generators.
    """
    gen = _rng(cfg.seed)
    n_rows = int(cfg.n_samples)
    n_feat = int(cfg.n_features)
    n_out = int(cfg.n_outputs)
    # Latent rank is clamped to [1, min(n_features, n_outputs)].
    latent_rank = int(max(1, min(cfg.rank, n_feat, n_out)))

    # Features
    X = gen.standard_normal(size=(n_rows, n_feat))

    # Per-sample noise multiplier: 1 + hetero_strength * |x . v| for a random unit vector v
    hetero_dir = None
    if cfg.heteroskedastic:
        hetero_dir = gen.standard_normal(size=(n_feat,))
        hetero_dir /= (np.linalg.norm(hetero_dir) + 1e-12)
        projection = np.abs(X @ hetero_dir)  # (n,)
        noise_scale = (1.0 + cfg.hetero_strength * projection).reshape(n_rows, 1)
    else:
        noise_scale = np.ones((n_rows, 1), dtype=float)

    # Cross-output covariance is only built for a strictly positive correlation
    noise_cov = None
    if cfg.output_noise_corr > 0:
        noise_cov = _make_output_noise_cov(n_out, cfg.output_noise_corr)

    meta: Dict = {"config": asdict(cfg)}

    kind = cfg.kind
    if kind == "regime_switch":
        # Two independent low-rank linear maps; the first feature selects the active one
        proj_a = gen.standard_normal(size=(n_feat, latent_rank)) / np.sqrt(n_feat)
        load_a = gen.standard_normal(size=(latent_rank, n_out)) / np.sqrt(latent_rank)
        proj_b = gen.standard_normal(size=(n_feat, latent_rank)) / np.sqrt(n_feat)
        load_b = gen.standard_normal(size=(latent_rank, n_out)) / np.sqrt(latent_rank)

        in_first_regime = (X[:, 0] + cfg.regime_margin > 0).astype(float).reshape(n_rows, 1)
        mean_a = (X @ proj_a) @ load_a
        mean_b = (X @ proj_b) @ load_b
        y_mean = in_first_regime * mean_a + (1.0 - in_first_regime) * mean_b

        meta["W1"] = proj_a
        meta["A1"] = load_a
        meta["W2"] = proj_b
        meta["A2"] = load_b
        meta["gate_margin"] = cfg.regime_margin
        meta["v_hetero"] = hetero_dir

    elif kind == "nonlinear_mlp":
        # Fixed random network: linear -> ReLU -> linear
        width = int(max(8, cfg.hidden_dim))
        w_in = gen.standard_normal(size=(n_feat, width)) / np.sqrt(n_feat)
        b_in = gen.standard_normal(size=(width,))
        w_out = gen.standard_normal(size=(width, n_out)) / np.sqrt(width)
        b_out = gen.standard_normal(size=(n_out,))

        hidden = np.maximum(X @ w_in + b_in, 0.0)
        y_mean = hidden @ w_out + b_out

        meta["W1"] = w_in
        meta["b1"] = b_in
        meta["W2"] = w_out
        meta["b2"] = b_out
        meta["v_hetero"] = hetero_dir

    elif kind == "linear_lowrank":
        # y_mean = X @ W @ A with W:(p,r), A:(r,d), optionally bending the latent with tanh
        proj = gen.standard_normal(size=(n_feat, latent_rank)) / np.sqrt(n_feat)
        loadings = gen.standard_normal(size=(latent_rank, n_out)) / np.sqrt(latent_rank)

        latent = X @ proj  # (n,r)
        if cfg.nonlinear_strength > 0:
            latent = (1.0 - cfg.nonlinear_strength) * latent + cfg.nonlinear_strength * np.tanh(latent)

        y_mean = latent @ loadings

        meta["W"] = proj
        meta["A"] = loadings
        meta["v_hetero"] = hetero_dir

    else:
        raise ValueError(f"Unknown kind={cfg.kind}")

    # Noise
    per_row_std = (cfg.noise_std * noise_scale).astype(float)  # (n,1)
    noise = _sample_noise(
        rng=gen,
        n=n_rows,
        d=n_out,
        std=per_row_std,
        cov=noise_cov,
        heavy_tailed=cfg.heavy_tailed,
        t_df=cfg.t_df,
    )
    y = y_mean + noise

    # Optional outliers: a random subset of rows gets extra Gaussian noise on every output
    n_outliers = int(np.floor(cfg.outlier_frac * n_rows)) if cfg.outlier_frac > 0 else 0
    if n_outliers > 0:
        rows = gen.choice(n_rows, size=n_outliers, replace=False)
        y[rows] += cfg.outlier_scale * gen.standard_normal(size=(n_outliers, n_out))

    X = X.astype(cfg.dtype, copy=False)
    y = y.astype(cfg.dtype, copy=False)
    meta["noise_cov"] = noise_cov
    return X, y, meta


def make_dataset(
    n_outputs: int,
    kind: SyntheticKind = "linear_lowrank",
    n_samples: int = 12000,
    n_features: int = 30,
    seed: int = 42,
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Shortcut around generate_synthetic_multioutput that returns only (X, y).

    The named arguments and any extra ``kwargs`` are forwarded as
    SyntheticConfig fields; the generator's metadata dict is discarded.
    """
    config_fields = dict(kwargs)
    config_fields.update(
        kind=kind,
        n_samples=n_samples,
        n_features=n_features,
        n_outputs=n_outputs,
        seed=seed,
    )
    features, targets, _meta = generate_synthetic_multioutput(SyntheticConfig(**config_fields))
    return features, targets

# linear_lowrank: correlated multi-output targets, easily controlled via `rank` and `output_noise_corr`.
# nonlinear_mlp: nonlinear signal (often harder for baseline models).
# regime_switch: non-stationarity (two regimes), useful for probing how sensitive methods are.