import os
import logging

import numpy as np
import pandas as pd
import torch
from prettytable import PrettyTable
from scipy import stats
from sklearn import model_selection, preprocessing

from .data_utils import FastTensorDataLoader, OriginalData
from .imputers import lgbm_impute, miceforest_impute, impute_naive, impute_cdf, impute_zero, impute_noise


def set_seeds(seed, cuda_deterministic=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Seed value handed to every generator. ``None`` leaves all
            generators untouched.
        cuda_deterministic: When True and CUDA is available, additionally
            switch cuDNN to deterministic mode and turn off its benchmark
            mode.
    """
    import random

    if seed is None:
        return

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    # Seed the current device first, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if not cuda_deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

import json
import pickle

import warnings
# NOTE: importing this module installs a process-wide "ignore" filter, so
# warnings of every category (including deprecation warnings raised by
# third-party libraries) are silenced from this point on.
warnings.filterwarnings('ignore')

class DataWrangler(object):
    """
    Data preparation for pre-processing data and post-processing generated data.

    """

    def __init__(self, dataset, data_path, sample_path, logdir, config, val_prop, test_prop, seed, preproc, beta='0p7', use_log='switch', skip_preprocess_in_init=False):
        self.data_name = dataset
        self.data_path = data_path
        self.preproc = preproc
        self.sample_path = sample_path
        self.config = config  # Store config for later access (e.g., n_cached_imputations)
        self.beta = beta  # beta parameter for split-normal shrinkage (format: '0p7' for 0.7)
        self.use_log = use_log  # use_log mode: 'switch', 'lognormal', or 'splitnormal'
        
        path = self.data_path
        info_path = f'{path}/info.json'
        with open(info_path, 'r') as f:
            info = json.load(f)
            
        col_names = info['column_names'] 
        self.target = col_names[info['target_col_idx'][0]]
            
        self.task = info['task_type']

        self.target_is_int = False

        self.cat_features = []
        self.cont_features = [] 
        

        self.val_prop = val_prop
        self.test_prop = test_prop
        
        self.cont_scaler = config.data.cont_scaler
        self.standardize_data = config.data.standardize_data
        self.cat_encoding = config.data.cat_encoding
        self.num_y_classes = None
        
        # y_cond removed - y is always included in X_cat (classification) or X_cont (regression)
        self.drop_cont_missings = True

        if logdir is not None:
            self.logdir = os.path.join(logdir, "data")
            if not os.path.exists(self.logdir):
                os.makedirs(self.logdir)

        if skip_preprocess_in_init:
            self.data = None
        else:
            self.data = self.preprocess_data(seed)

    def _load_raw_data(self, path, info):
        """Load raw data from numpy files."""
        X_cont_train = np.load(f'{path}/X_num_train.npy')
        X_cont_test = np.load(f'{path}/X_num_test.npy')
        X_cat_train = np.load(f'{path}/X_cat_train.npy', allow_pickle=True)
        X_cat_test = np.load(f'{path}/X_cat_test.npy', allow_pickle=True)
        y_train = np.load(f'{path}/y_train.npy', allow_pickle=True)
        y_test = np.load(f'{path}/y_test.npy', allow_pickle=True)
        
        mask_num_train = np.load(f'{path}/mask_num_train.npy')
        mask_num_test = np.load(f'{path}/mask_num_test.npy')
        mask_cat_train = np.load(f'{path}/mask_cat_train.npy')
        mask_cat_test = np.load(f'{path}/mask_cat_test.npy')
        y_mask_train = np.load(f'{path}/y_mask_train.npy')
        y_mask_test = np.load(f'{path}/y_mask_test.npy')
        
        return (X_cont_train, X_cont_test, X_cat_train, X_cat_test, 
                y_train, y_test, mask_num_train, mask_num_test, 
                mask_cat_train, mask_cat_test, y_mask_train, y_mask_test)

    def _prepare_missing_masks(self, X_cat_train, X_cat_test, mask_cat_train, 
                               mask_cat_test, y_train, y_test, y_mask_train, 
                               y_mask_test, info):
        """Prepare DataFrames with missing values marked as NaN. Matches parent (cdtd) flow."""
        X_cat_train_df = pd.DataFrame(X_cat_train)
        X_cat_test_df = pd.DataFrame(X_cat_test)
        
        if info['task_type'] != 'regression':
            # mask_cat has shape (n, 1 + n_cat): first column is y mask, rest are cat features
            mask_cat_train_feat = mask_cat_train[:, 1:]
            mask_cat_test_feat = mask_cat_test[:, 1:]
            assert mask_cat_train_feat.shape == X_cat_train_df.shape, (
                "mask_cat_train[:, 1:] shape %s must match X_cat_train shape %s" % (
                    mask_cat_train_feat.shape, X_cat_train_df.shape))
            assert mask_cat_test_feat.shape == X_cat_test_df.shape, (
                "mask_cat_test[:, 1:] shape %s must match X_cat_test shape %s" % (
                    mask_cat_test_feat.shape, X_cat_test_df.shape))
            X_cat_train_df = X_cat_train_df.where(mask_cat_train_feat != 0, np.nan)
            X_cat_test_df = X_cat_test_df.where(mask_cat_test_feat != 0, np.nan)
        else:
            assert mask_cat_train.shape == X_cat_train_df.shape
            assert mask_cat_test.shape == X_cat_test_df.shape
            X_cat_train_df = X_cat_train_df.where(mask_cat_train != 0, np.nan)
            X_cat_test_df = X_cat_test_df.where(mask_cat_test != 0, np.nan)
            
        y_train_df = pd.DataFrame(y_train)
        y_train_df = y_train_df.where(y_mask_train != 0, np.nan)
        y_test_df = pd.DataFrame(y_test)
        y_test_df = y_test_df.where(y_mask_test != 0, np.nan)
        
        nan_cat_tr = X_cat_train_df.isna().sum().sum() if hasattr(X_cat_train_df, 'isna') else 0
        nan_y_tr = y_train_df.isna().sum().sum() if hasattr(y_train_df, 'isna') else 0
        print(f"[PREP] _prepare_missing_masks: X_cat NaN count train={nan_cat_tr}, y NaN count train={nan_y_tr}")
        return X_cat_train_df, X_cat_test_df, y_train_df, y_test_df

    def _impute_categorical(self, X_cat_train_df, X_cat_test_df, y_train_df, 
                          y_test_df, info):
        """Impute categorical features and target (for convenience, impute first)."""
        if self.preproc == 'm':
            X_cat_train = impute_naive(X_cat_train_df).to_numpy().astype("str")
            X_cat_test = impute_naive(X_cat_test_df).to_numpy().astype("str")
            
            if info['task_type'] != 'regression':
                y_train = impute_naive(y_train_df).to_numpy().astype("str")
                y_test = impute_naive(y_test_df).to_numpy().astype("str")
            else:
                y_train = y_train_df.to_numpy()
                y_test = y_test_df.to_numpy()
                
        elif self.preproc == 'r':
            X_cat_train = impute_cdf(X_cat_train_df).to_numpy().astype("str") if X_cat_train_df.shape[1] > 0 else X_cat_train_df.to_numpy()
            X_cat_test = impute_cdf(X_cat_test_df).to_numpy().astype("str") if X_cat_test_df.shape[1] > 0 else X_cat_test_df.to_numpy()
            
            if info['task_type'] != 'regression':
                y_train = impute_cdf(y_train_df).to_numpy().astype("str")
                y_test = impute_cdf(y_test_df).to_numpy().astype("str")
            else:
                y_train = y_train_df.to_numpy()
                y_test = y_test_df.to_numpy()
        else:
            # For LGB_D/LGB_S, categorical imputation happens later
            X_cat_train = X_cat_train_df.to_numpy()
            X_cat_test = X_cat_test_df.to_numpy()
            y_train = y_train_df.to_numpy()
            y_test = y_test_df.to_numpy()
            
        return X_cat_train, X_cat_test, y_train, y_test

    def _impute_continuous_features(self, X_continuous, split, X_continuous_train=None):
        """Impute missing values in continuous features.
        For 'm'/'r': uses training statistics to avoid data leakage.
        """
        if self.preproc == 'm':
            if X_continuous_train is not None and split != "train":
                df_train = pd.DataFrame(X_continuous_train)
                df_test = pd.DataFrame(X_continuous)
                for col in df_test.columns:
                    if df_test[col].isna().any():
                        if pd.api.types.is_numeric_dtype(df_train[col]):
                            fill_value = df_train[col].mean()
                            if pd.isna(fill_value):
                                fill_value = 0.0
                        else:
                            mode_result = df_train[col].mode()
                            fill_value = mode_result.iloc[0] if len(mode_result) > 0 else (df_train[col].iloc[0] if len(df_train[col]) > 0 and not pd.isna(df_train[col].iloc[0]) else 0)
                        df_test[col] = df_test[col].fillna(fill_value)
                X_imputed = df_test.to_numpy().astype("float")
            else:
                X_imputed = impute_naive(pd.DataFrame(X_continuous)).to_numpy().astype("float")
        elif self.preproc == 'r':
            if X_continuous_train is not None and split != "train":
                from feature_engine.imputation import RandomSampleImputer
                rsi = RandomSampleImputer(random_state=42)
                rsi.fit(pd.DataFrame(X_continuous_train))
                X_imputed = rsi.transform(pd.DataFrame(X_continuous)).to_numpy().astype("float")
            else:
                X_imputed = impute_cdf(pd.DataFrame(X_continuous), random_state=42).to_numpy().astype("float")
        else:
            X_imputed = X_continuous.astype("float")
        return X_imputed

    def _rearrange_data_with_y(self, X_cat_train_df, X_cat_test_df, X_cont_train, X_cont_test, 
                               y_train_df, y_test_df, info):
        """
        Rearrange data by prepending y to X_cont (regression) or X_cat (classification).
        This ensures consistent indices throughout the pipeline.
        
        Returns:
            X_cat_train_rearranged, X_cat_test_rearranged: Categorical data (DataFrame with y prepended for classification)
            X_cont_train_rearranged, X_cont_test_rearranged: Continuous data (numpy array with y prepended for regression)
            y_train, y_test: Separate y arrays (for reference, but y is also in X_cat or X_cont)
        """
        if self.task == "regression":
            y_train_vals = y_train_df.values.ravel() if y_train_df.values.ndim > 1 else y_train_df.values
            y_test_vals = y_test_df.values.ravel() if y_test_df.values.ndim > 1 else y_test_df.values
            X_cont_train_rearranged = np.column_stack((y_train_vals, X_cont_train))
            X_cont_test_rearranged = np.column_stack((y_test_vals, X_cont_test))
            X_cat_train_rearranged = X_cat_train_df
            X_cat_test_rearranged = X_cat_test_df
        else:
            X_cat_train_rearranged = pd.concat([y_train_df, X_cat_train_df], axis=1)
            X_cat_train_rearranged.columns = range(X_cat_train_rearranged.columns.size)
            X_cat_test_rearranged = pd.concat([y_test_df, X_cat_test_df], axis=1)
            X_cat_test_rearranged.columns = range(X_cat_test_rearranged.columns.size)
            X_cont_train_rearranged = X_cont_train
            X_cont_test_rearranged = X_cont_test
        
        y_train = y_train_df.values
        y_test = y_test_df.values
        
        return (X_cat_train_rearranged, X_cat_test_rearranged, 
                X_cont_train_rearranged, X_cont_test_rearranged, y_train, y_test)

    def _transform_observed_data(self, X_cat_train_df, X_cat_test_df, X_cont_train, X_cont_test, 
                                 info):
        """
        Transform observed data: quantile transform for continuous, categorical encoding.
        Assumes data is already rearranged (y prepended to X_cont for regression, X_cat for classification).
        Only transforms observed values (preserves NaN for missing values).

        The categorical block and the continuous block are concatenated column-wise
        (categorical first) and the combined columns are relabelled 0..k-1. The
        QuantileTransformer is fitted on the training continuous columns only and
        then applied to both train and test.

        Side effects on ``self``:
            - ``cont_feature_indices_for_transformer`` is always set (an empty array
              when there are no continuous columns).
            - ``X_cont_mean`` / ``X_cont_std`` are set from the transformed training
              data when ``self.cont_scaler == "quantile"``, and set to None when there
              are no continuous columns. When continuous columns exist and
              ``self.cont_scaler`` has any other value, this method does not assign them.

        ``info`` is accepted but not read in this method.

        Returns:
            X_combined_train_transformed: Combined DataFrame with transformed data (NaN preserved)
            X_combined_test_transformed: Combined DataFrame with transformed data (NaN preserved)
            cont_enc: Fitted QuantileTransformer, or None when there are no continuous columns
            cat_enc: None (no encoding before imputation)
        """
        # NOTE: n_cat and n_num are not used below; the same values are recomputed
        # as n_cat_in_combined / n_cont_in_combined.
        n_cat, n_num = X_cat_train_df.shape[1], X_cont_train.shape[1]
        
        X_cont_train_df = pd.DataFrame(X_cont_train)
        X_cont_test_df = pd.DataFrame(X_cont_test)
        # Categorical columns come first, continuous columns after; columns are
        # relabelled so positions and labels coincide.
        X_combined_train = pd.concat([X_cat_train_df, X_cont_train_df], axis=1)
        X_combined_train.columns = range(X_combined_train.columns.size)
        X_combined_test = pd.concat([X_cat_test_df, X_cont_test_df], axis=1)
        X_combined_test.columns = range(X_combined_test.columns.size)
        
        # Column roles by dtype of the TRAIN frame (not by position).
        # numerical_cols is computed but not used further down.
        dtypes = X_combined_train.dtypes.tolist()
        numerical_cols = [idx for idx, col in enumerate(dtypes) if pd.api.types.is_numeric_dtype(X_combined_train.iloc[:, idx])]
        categorical_cols = [idx for idx, col in enumerate(dtypes) if not pd.api.types.is_numeric_dtype(X_combined_train.iloc[:, idx])]
        
        X_combined_train_transformed = X_combined_train.copy()
        X_combined_test_transformed = X_combined_test.copy()
        
        # Positions of the continuous block inside the combined frame.
        n_cat_in_combined = X_cat_train_df.shape[1]
        n_cont_in_combined = X_cont_train.shape[1]
        cont_indices_in_combined = list(range(n_cat_in_combined, n_cat_in_combined + n_cont_in_combined))
        
        cont_enc = None
        if len(cont_indices_in_combined) > 0:
            cont_enc = preprocessing.QuantileTransformer(
                output_distribution="normal",
                # n_rows // 30, clamped to the range [10, 1000].
                n_quantiles=max(min(X_combined_train.shape[0] // 30, 1000), 10),
                subsample=int(1e9),
                random_state=42,
            )
            # Fit on training data only; the same fitted transformer is applied to test.
            cont_data_train = X_combined_train.iloc[:, cont_indices_in_combined]
            cont_enc.fit(cont_data_train)
            
            # Indices relative to the continuous block, i.e. 0..n_cont-1.
            self.cont_feature_indices_for_transformer = np.array([idx - n_cat_in_combined for idx in cont_indices_in_combined])
            
            transformed_train = cont_enc.transform(cont_data_train)
            for i, col_idx in enumerate(cont_indices_in_combined):
                X_combined_train_transformed.iloc[:, col_idx] = transformed_train[:, i]
            
            cont_data_test = X_combined_test.iloc[:, cont_indices_in_combined]
            transformed_test = cont_enc.transform(cont_data_test)
            for i, col_idx in enumerate(cont_indices_in_combined):
                X_combined_test_transformed.iloc[:, col_idx] = transformed_test[:, i]
        
            # Apply standardization BEFORE imputation (using nanmean/nanstd on observed values only)
            # This helps imputers work better on normalized data
            if self.cont_scaler == "quantile":
                # Fit standardization on observed values only (using nanmean/nanstd)
                cont_data_train_transformed = X_combined_train_transformed.iloc[:, cont_indices_in_combined]
                self.X_cont_mean = np.nanmean(cont_data_train_transformed.values, axis=0)
                self.X_cont_std = np.nanstd(cont_data_train_transformed.values, axis=0)
                # Avoid division by zero
                self.X_cont_std = np.where(self.X_cont_std == 0, 1.0, self.X_cont_std)
                
                # Apply standardization to train and test (preserving NaN)
                for col_idx in cont_indices_in_combined:
                    col_data_train = X_combined_train_transformed.iloc[:, col_idx].copy()
                    col_data_test = X_combined_test_transformed.iloc[:, col_idx].copy()
                    # Position of this column inside the mean/std vectors.
                    col_idx_in_cont = col_idx - n_cat_in_combined
                    
                    # Standardize observed values only (NaN preserved)
                    mask_train = ~pd.isna(col_data_train)
                    mask_test = ~pd.isna(col_data_test)
                    if mask_train.any():
                        col_data_train.loc[mask_train] = (
                            col_data_train.loc[mask_train] - self.X_cont_mean[col_idx_in_cont]
                        ) / self.X_cont_std[col_idx_in_cont]
                        X_combined_train_transformed.iloc[:, col_idx] = col_data_train
                    if mask_test.any():
                        # Test data is scaled with the TRAIN mean/std.
                        col_data_test.loc[mask_test] = (
                            col_data_test.loc[mask_test] - self.X_cont_mean[col_idx_in_cont]
                        ) / self.X_cont_std[col_idx_in_cont]
                        X_combined_test_transformed.iloc[:, col_idx] = col_data_test
        else:
            # No continuous columns at all: nothing to transform or standardize.
            self.cont_feature_indices_for_transformer = np.array([])
            self.X_cont_mean = None
            self.X_cont_std = None
        
        # cat_enc is never assigned anything else; it is returned as None.
        cat_enc = None
        if len(categorical_cols) > 0:
            # For kept preprocs (LGB_D, LGB_S, r, m, miceforest, zero, noise), use category dtype.
            # NOTE(review): the cast is written back through .iloc; whether the column
            # dtype actually becomes 'category' may depend on the pandas version - confirm.
            for col_idx in categorical_cols:
                if not pd.api.types.is_categorical_dtype(X_combined_train_transformed.iloc[:, col_idx]):
                    X_combined_train_transformed.iloc[:, col_idx] = X_combined_train_transformed.iloc[:, col_idx].astype('category')
                if not pd.api.types.is_categorical_dtype(X_combined_test_transformed.iloc[:, col_idx]):
                    X_combined_test_transformed.iloc[:, col_idx] = X_combined_test_transformed.iloc[:, col_idx].astype('category')
        
        return X_combined_train_transformed, X_combined_test_transformed, cont_enc, cat_enc

    def _extract_imputed_data(self, X_imputed_tr, X_imputed_test, n_cat, n_num, info):
        """
        Extract and convert imputed data to proper format.
        Assumes data is already rearranged (y prepended to X_cont for regression, X_cat for classification).
        Keeps y in the rearranged structure (in X_cont for regression, X_cat for classification).
        Also returns y separately for stratification purposes.
        
        Args:
            n_cat: Number of categorical features (before rearrangement, so excludes y for classification)
            n_num: Number of continuous features (before rearrangement, so excludes y for regression)
        
        For regression: X_imputed has shape (N, n_cat + n_num + 1) where:
          - First n_cat columns: X_cat
          - Next n_num+1 columns: X_cont (with y at index 0 of this block)
        
        For classification: X_imputed has shape (N, n_cat + 1 + n_num) where:
          - First n_cat+1 columns: X_cat (with y at index 0 of this block)
          - Next n_num columns: X_cont
        """
        if self.task == "regression":
            X_cat_train = X_imputed_tr[:, :n_cat]
            X_cat_test = X_imputed_test[:, :n_cat]
            X_cont_train = np.array(X_imputed_tr[:, n_cat:n_cat + n_num + 1], dtype=np.float32)
            X_cont_test = np.array(X_imputed_test[:, n_cat:n_cat + n_num + 1], dtype=np.float32)
            y_train = X_cont_train[:, 0:1].astype("float")
            y_test = X_cont_test[:, 0:1].astype("float")
        else:
            X_cat_train = X_imputed_tr[:, :n_cat + 1]
            X_cat_test = X_imputed_test[:, :n_cat + 1]
            X_cont_train = np.array(X_imputed_tr[:, n_cat + 1:n_cat + 1 + n_num], dtype=np.float32)
            X_cont_test = np.array(X_imputed_test[:, n_cat + 1:n_cat + 1 + n_num], dtype=np.float32)
            y_train = X_cat_train[:, 0:1].astype("str")
            y_test = X_cat_test[:, 0:1].astype("str")
        
        return X_cat_train, X_cat_test, X_cont_train, X_cont_test, y_train, y_test

    def _impute_with_lightgbm(self, path, X_combined_train_transformed, X_combined_test_transformed,
                              n_cat_orig, n_num_orig, info=None):
        """Impute all features using LightGBM (offline mode).

        With ``self.preproc == 'LGB_S'`` the imputer is called with
        ``sample=True`` plus the ``beta`` / ``use_log`` settings, and the
        cache file name encodes both. For any other value it is called with
        ``sample=False`` and the fixed ``LGB_D`` cache file.

        Args:
            path: Directory in which the imputation cache file lives.
            X_combined_train_transformed: Combined training frame to impute.
            X_combined_test_transformed: Not used; the returned "test" arrays
                are derived from a copy of the imputed training matrix.
            n_cat_orig: Number of categorical features, excluding the target.
            n_num_orig: Number of continuous features, excluding the target.
            info: Optional dataset info dict forwarded to
                ``_extract_imputed_data``. Previously this method referenced
                an undefined name ``info`` and failed with NameError.

        Returns:
            The 6-tuple produced by ``_extract_imputed_data``.
        """
        use_sampling = (self.preproc == 'LGB_S')

        impute_kwargs = {
            "sample": use_sampling,
            "random_state": 42,
            "std_multiplier": 1.0,
        }
        if use_sampling:
            cache_path_tr = f"{path}/X_imputed_tr_{self.preproc}_beta{self.beta}_use_log{self.use_log}.pkl"
            # beta is stored as e.g. '0p7'; decode it to 0.7.
            impute_kwargs["beta"] = float(self.beta.replace('p', '.'))
            impute_kwargs["use_log"] = self.use_log
        else:
            cache_path_tr = f"{path}/X_imputed_tr_LGB_D.pkl"
        impute_kwargs["cache_path"] = cache_path_tr

        X_imputed_tr_df = lgbm_impute(X_combined_train_transformed, **impute_kwargs)
        X_imputed_tr = X_imputed_tr_df.to_numpy()
        # The "test" matrix is a copy of the imputed training matrix.
        X_imputed_test = X_imputed_tr.copy()
        return self._extract_imputed_data(X_imputed_tr, X_imputed_test, n_cat_orig, n_num_orig, info)

    def _impute_with_miceforest(self, path, X_combined_train_transformed, X_combined_test_transformed,
                                n_cat_orig, n_num_orig, info=None):
        """Impute all features using miceforest (MICE with LightGBM).

        Args:
            path: Directory in which the imputation cache file lives.
            X_combined_train_transformed: Combined training frame to impute.
            X_combined_test_transformed: Not used; the returned "test" arrays
                are derived from a copy of the imputed training matrix.
            n_cat_orig: Number of categorical features, excluding the target.
            n_num_orig: Number of continuous features, excluding the target.
            info: Optional dataset info dict forwarded to
                ``_extract_imputed_data``. Previously this method referenced
                an undefined name ``info`` and failed with NameError.

        Returns:
            The 6-tuple produced by ``_extract_imputed_data``.
        """
        cache_path_tr = f"{path}/X_imputed_tr_miceforest.pkl"
        X_imputed_tr_df = miceforest_impute(
            X_combined_train_transformed,
            random_state=42,
            cache_path=cache_path_tr
        )
        X_imputed_tr = X_imputed_tr_df.to_numpy()
        # The "test" matrix is a copy of the imputed training matrix.
        X_imputed_test = X_imputed_tr.copy()
        return self._extract_imputed_data(X_imputed_tr, X_imputed_test, n_cat_orig, n_num_orig, info)

    def get_imputation_inputs(self, seed):
        """
        Run the pipeline up to, but not including, imputation.

        Intended for imputation benchmarks: the returned state is what each
        imputer needs when it is run with cache_path=None. ``self.data`` is
        not assigned here, but feature-name lists, ``task_type``, ``target``,
        ``cont_feature_indices`` and ``cont_enc`` on ``self`` are updated.

        Returns:
            dict with keys:
                - combined_train, combined_test: X_combined_*_transformed (DataFrame with NaNs)
                - n_cat_orig, n_num_orig: int
                - info: dict from info.json
                - task: 'regression' | 'classification'
                - X_cat_train_rearr, X_cat_test_rearr: for m/r _impute_categorical
                - y_train_df, y_test_df: for m/r _impute_categorical
                - X_cont_train_transformed, X_cont_test_transformed: for m/r _impute_continuous
                - cont_indices_in_combined: list of column indices (in combined) for continuous
        """
        set_seeds(seed)
        path = self.data_path
        with open(f'{path}/info.json', 'r') as f:
            info = json.load(f)

        self.task_type = info['task_type']
        names = info['column_names']
        self.cont_features = [names[idx] for idx in info['num_col_idx']]
        self.cat_features = [names[idx] for idx in info['cat_col_idx']]
        self.target = names[info['target_col_idx'][0]]
        self.cont_feature_indices = np.arange(len(self.cont_features))

        # The numeric masks are loaded but not needed for this stage.
        (raw_cont_train, raw_cont_test, raw_cat_train, raw_cat_test,
         raw_y_train, raw_y_test, _mask_num_train, _mask_num_test,
         mask_cat_train, mask_cat_test, y_mask_train, y_mask_test) = self._load_raw_data(path, info)

        masked = self._prepare_missing_masks(
            raw_cat_train, raw_cat_test, mask_cat_train, mask_cat_test,
            raw_y_train, raw_y_test, y_mask_train, y_mask_test, info
        )
        cat_train_nan, cat_test_nan, y_train_df, y_test_df = masked

        rearranged = self._rearrange_data_with_y(
            cat_train_nan, cat_test_nan, raw_cont_train, raw_cont_test,
            y_train_df, y_test_df, info
        )
        (cat_train_rearr, cat_test_rearr,
         cont_train_rearr, cont_test_rearr, _y_train, _y_test) = rearranged

        # Mirror the rearrangement in the feature-name bookkeeping: the target
        # now leads the block it was prepended to.
        if self.task == "regression":
            self.cont_features = [self.target, *self.cont_features]
            self.cont_feature_indices = self.cont_feature_indices + 1
        else:
            self.cat_features = [self.target, *self.cat_features]

        transformed = self._transform_observed_data(
            cat_train_rearr, cat_test_rearr, cont_train_rearr, cont_test_rearr, info
        )
        combined_train, combined_test, cont_enc, _cat_enc = transformed
        self.cont_enc = cont_enc

        # Feature counts BEFORE the target was prepended.
        n_cat_orig = 0 if raw_cat_train is None else raw_cat_train.shape[1]
        n_num_orig = 0 if raw_cont_train is None else raw_cont_train.shape[1]

        first_cont = cat_train_rearr.shape[1]
        cont_indices_in_combined = list(range(first_cont, first_cont + cont_train_rearr.shape[1]))

        cont_train_view = combined_train.iloc[:, cont_indices_in_combined].copy()
        cont_test_view = combined_test.iloc[:, cont_indices_in_combined].copy()

        return {
            "combined_train": combined_train,
            "combined_test": combined_test,
            "n_cat_orig": n_cat_orig,
            "n_num_orig": n_num_orig,
            "info": info,
            "task": self.task,
            "X_cat_train_rearr": cat_train_rearr,
            "X_cat_test_rearr": cat_test_rearr,
            "y_train_df": y_train_df,
            "y_test_df": y_test_df,
            "X_cont_train_transformed": cont_train_view,
            "X_cont_test_transformed": cont_test_view,
            "cont_indices_in_combined": cont_indices_in_combined,
        }

    def run_m_r_imputation_benchmark(self, state, preproc):
        """
        Run m (mean/mode) or r (random sample) imputation only, using state from
        get_imputation_inputs. Used for wall-clock benchmarking. Modifies nothing.
        """
        prev = self.preproc
        self.preproc = preproc
        info = state["info"]
        task = state["task"]
        X_cat_train_rearr = state["X_cat_train_rearr"]
        X_cat_test_rearr = state["X_cat_test_rearr"]
        y_train_df = state["y_train_df"]
        y_test_df = state["y_test_df"]
        X_cont_train_transformed = state["X_cont_train_transformed"]
        X_cont_test_transformed = state["X_cont_test_transformed"]
        try:
            if task == "regression":
                self._impute_categorical(
                    X_cat_train_rearr, X_cat_test_rearr, y_train_df, y_test_df, info
                )
                self._impute_continuous_features(X_cont_train_transformed, "train")
                self._impute_continuous_features(
                    X_cont_test_transformed, "test",
                    X_continuous_train=X_cont_train_transformed
                )
            else:
                y_train_temp = pd.DataFrame(X_cat_train_rearr.iloc[:, 0])
                y_test_temp = pd.DataFrame(X_cat_test_rearr.iloc[:, 0])
                X_cat_train_no_y = X_cat_train_rearr.iloc[:, 1:]
                X_cat_test_no_y = X_cat_test_rearr.iloc[:, 1:]
                self._impute_categorical(
                    X_cat_train_no_y, X_cat_test_no_y, y_train_temp, y_test_temp, info
                )
                self._impute_continuous_features(X_cont_train_transformed, "train")
                self._impute_continuous_features(
                    X_cont_test_transformed, "test",
                    X_continuous_train=X_cont_train_transformed
                )
        finally:
            self.preproc = prev

    def preprocess_data(self, seed):
        """Unified preprocessing pipeline."""
        set_seeds(seed)
        path = self.data_path
        info_path = f'{path}/info.json'
        with open(info_path, 'r') as f:
            info = json.load(f)
        
        self.task_type = info['task_type']
    
        (X_cont_train, X_cont_test, X_cat_train, X_cat_test, 
         y_train, y_test, mask_num_train, mask_num_test, 
         mask_cat_train, mask_cat_test, y_mask_train, y_mask_test) = self._load_raw_data(path, info)
        
        X_cat_train_df, X_cat_test_df, y_train_df, y_test_df = self._prepare_missing_masks(
            X_cat_train, X_cat_test, mask_cat_train, mask_cat_test,
            y_train, y_test, y_mask_train, y_mask_test, info
        )
        
        dataset_len = y_train_df.shape[0]
        
        y_train_original_array = y_train_df.values.ravel() if y_train_df.values.ndim > 1 else y_train_df.values.ravel()
        
        if y_train_original_array.dtype.kind in ['f', 'i', 'u']:  # float, int, uint
            has_nan = np.isnan(y_train_original_array).any()
        else:
            has_nan = pd.isna(y_train_original_array).any()
        
        if has_nan and self.task != "regression":
            y_train_original = None
        else:
            y_train_original = y_train_original_array
        
        col_names = info['column_names'] 
        self.cont_features = [col_names[i] for i in info['num_col_idx']] 
        self.cat_features = [col_names[i] for i in info['cat_col_idx']]
        self.target = col_names[info['target_col_idx'][0]]
        
        self.cont_feature_indices = np.arange(len(self.cont_features))
        
        # Rearrangement: put y into X_cont for regression, X_cat for classification
        (X_cat_train_rearr, X_cat_test_rearr,
         X_cont_train_rearr, X_cont_test_rearr, y_train, y_test) = self._rearrange_data_with_y(
            X_cat_train_df, X_cat_test_df, X_cont_train, X_cont_test,
                y_train_df, y_test_df, info
            )
        if self.task == "regression":
            self.cont_features = [self.target] + self.cont_features
            self.cont_feature_indices = self.cont_feature_indices + 1
        else:
            self.cat_features = [self.target] + self.cat_features
        
        X_combined_train_transformed, X_combined_test_transformed, cont_enc, cat_enc = self._transform_observed_data(
            X_cat_train_rearr, X_cat_test_rearr, X_cont_train_rearr, X_cont_test_rearr, info
        )
        self.cont_enc = cont_enc
        nan_comb = X_combined_train_transformed.isna().sum().sum()
        print(f"[PREP] _transform_observed_data: X_combined_train_transformed {X_combined_train_transformed.shape}, dtypes={list(X_combined_train_transformed.dtypes)}, NaN count={nan_comb}")

        n_cat_orig = X_cat_train.shape[1] if X_cat_train is not None else 0
        n_num_orig = X_cont_train.shape[1] if X_cont_train is not None else 0
        
        if self.preproc in ['lgb', 'lgbs2']:
            if self.preproc == 'lgbs2':
                cache_path = f"{path}/X_imputed_tr_{self.preproc}_beta{self.beta}_use_log{self.use_log}.pkl"
            else:
                cache_path = f"{path}/X_imputed_tr_{self.preproc}.pkl"
            
            if self.preproc == 'lgbs2':
                beta_float = float(self.beta.replace('p', '.'))
                # Save predictions cache for validation set variance computation
                predictions_cache_path = None
                if self.val_prop > 0:
                    predictions_cache_path = f"{path}/val_mu_std_hat_{self.preproc}_beta{self.beta}_use_log{self.use_log}.pkl"
                print(f"[PREP] lgbs2: calling lgbm_impute2(sample=True, beta={beta_float}, use_log={self.use_log}, cache_path={cache_path})")
                X_imputed = lgbm_impute2(
                    X_combined_train_transformed,
                    sample=True,
                    random_state=42,
                    std_multiplier=1.0,
                    beta=beta_float,
                    use_log=self.use_log,
                    cache_path=cache_path,
                    save_predictions=(self.val_prop > 0),
                    predictions_cache_path=predictions_cache_path
                ).to_numpy()
                print(f"[PREP] lgbs2: after lgbm_impute2 X_imputed shape={X_imputed.shape}, sample first col[:5]={X_imputed[:5, 0] if X_imputed.size else None}")
            else:
                cache_path = f"{path}/X_imputed_tr_lgb.pkl"
                X_imputed = lgbm_impute(
                    X_combined_train_transformed,
                    sample=False,
                    random_state=42,
                    std_multiplier=1.0,
                    cache_path=cache_path
                ).to_numpy()
            X_cat_train, X_cat_test, X_cont_train, X_cont_test, y_train, y_test = self._extract_imputed_data(
                X_imputed, X_imputed, n_cat_orig, n_num_orig, info
            )
        elif self.preproc == 'miceforest':
            cache_path = f"{path}/X_imputed_tr_miceforest.pkl"
            X_imputed = miceforest_impute(
                X_combined_train_transformed,
                random_state=42,
                cache_path=cache_path
            ).to_numpy()
            X_cat_train, X_cat_test, X_cont_train, X_cont_test, y_train, y_test = self._extract_imputed_data(
                X_imputed, X_imputed, n_cat_orig, n_num_orig, info
            )
        elif self.preproc == 'zero':
            cache_path = f"{path}/X_imputed_tr_zero.pkl"
            X_imputed = impute_zero(
                X_combined_train_transformed,
                cache_path=cache_path
            ).to_numpy()
            X_cat_train, X_cat_test, X_cont_train, X_cont_test, y_train, y_test = self._extract_imputed_data(
                X_imputed, X_imputed, n_cat_orig, n_num_orig, info
            )
        elif self.preproc == 'noise':
            cache_path = f"{path}/X_imputed_tr_noise.pkl"
            X_imputed = impute_noise(
                X_combined_train_transformed,
                random_state=42,
                cache_path=cache_path
            ).to_numpy()
            X_cat_train, X_cat_test, X_cont_train, X_cont_test, y_train, y_test = self._extract_imputed_data(
                X_imputed, X_imputed, n_cat_orig, n_num_orig, info
            )
        else:
            # For 'm' and 'r': Extract transformed continuous features from X_combined_train_transformed
            # Similar to LGB_D/LGB_S/miceforest, we impute on transformed data
            n_cat_in_combined = X_cat_train_rearr.shape[1]
            n_cont_in_combined = X_cont_train_rearr.shape[1]
            cont_indices_in_combined = list(range(n_cat_in_combined, n_cat_in_combined + n_cont_in_combined))
            
            # Extract transformed continuous features
            X_cont_train_transformed = X_combined_train_transformed.iloc[:, cont_indices_in_combined].to_numpy()
            X_cont_test_transformed = X_combined_test_transformed.iloc[:, cont_indices_in_combined].to_numpy()
            
            if self.task == "regression":
                X_cat_train, X_cat_test, y_train, y_test = self._impute_categorical(
                    X_cat_train_rearr, X_cat_test_rearr, y_train_df, y_test_df, info
                )
                X_cont_train = self._impute_continuous_features(
                    X_cont_train_transformed, "train"
                )
                X_cont_test = self._impute_continuous_features(
                    X_cont_test_transformed, "test", X_continuous_train=X_cont_train_transformed
                )
            else:
                y_train_temp = pd.DataFrame(X_cat_train_rearr.iloc[:, 0])
                y_test_temp = pd.DataFrame(X_cat_test_rearr.iloc[:, 0])
                X_cat_train_no_y = X_cat_train_rearr.iloc[:, 1:]
                X_cat_test_no_y = X_cat_test_rearr.iloc[:, 1:]

                X_cat_train, X_cat_test, y_train, y_test = self._impute_categorical(
                    X_cat_train_no_y, X_cat_test_no_y, y_train_temp, y_test_temp, info
                )
                X_cat_train = np.column_stack((y_train, X_cat_train))
                X_cat_test = np.column_stack((y_test, X_cat_test))
                X_cont_train = self._impute_continuous_features(
                    X_cont_train_transformed, "train"
                )
                X_cont_test = self._impute_continuous_features(
                    X_cont_test_transformed, "test", X_continuous_train=X_cont_train_transformed
                )

        if self.val_prop > 0:
            prop = self.val_prop / (1 - self.test_prop)
            if self.task == "regression" or y_train_original is None:
                stratify = None
            else:
                if y_train_original.dtype.kind in ['f', 'i', 'u']:
                    has_nan = np.isnan(y_train_original).any()
                else:
                    has_nan = pd.isna(y_train_original).any()
                stratify = None if has_nan else y_train_original 

            X_cat_train, X_cat_val, X_cont_train, X_cont_val, M_cat_train, M_cat_val, M_cont_train, M_cont_val, y_train, y_val = (
                model_selection.train_test_split(
                    X_cat_train,
                    X_cont_train,
                    mask_cat_train,
                    mask_num_train,
                    y_train,
                    stratify=stratify,
                    test_size=prop,
                    random_state=42,
                )
            )
        else:
            M_cat_train, M_cont_train = mask_cat_train, mask_num_train
            M_cat_val, M_cont_val = None, None
            X_cat_val, X_cont_val, y_val = None, None, None
        
        # Data is already rearranged (y is in X_cont for regression, X_cat for classification)
        # No need to rearrange again - it was done at the beginning
        
        X_cat = {"train": X_cat_train, "val": X_cat_val, "test": X_cat_test}
        X_cont = {"train": X_cont_train, "val": X_cont_val, "test": X_cont_test}
        M_cat = {"train": M_cat_train, "val": M_cat_val, "test": mask_cat_test}
        M_cont = {"train": M_cont_train, "val": M_cont_val, "test": mask_num_test}
        y = {"train": y_train, "val": y_val, "test": y_test}

        # Store number of features to construct model later on
        self.num_cat_features = (
            X_cat["train"].shape[1] if X_cat["train"] is not None else 0
        )
        self.num_cont_features = (
            X_cont["train"].shape[1] if X_cont["train"] is not None else 0
        )
        self.num_total_features = self.num_cat_features + self.num_cont_features

        # CRITICAL: Round categorical values if they are numeric (floats)
        # After imputation, categoricals may contain decimal values from some imputers.
        # Round them to integers and convert to strings BEFORE cat_int_enc.fit/transform
        # so OrdinalEncoder does not treat float values as strings ('4.2', '6.4'), creating "unknown categories"
        if self.num_cat_features > 0:
            for k in X_cat.keys():
                if X_cat[k] is not None:
                    # Check if categorical data is numeric (contains floats)
                    if X_cat[k].dtype.kind == 'f':  # float
                        # Round to nearest integer and convert to string
                        # This ensures we have discrete category values, not continuous floats
                        X_cat[k] = np.round(X_cat[k]).astype(int).astype(str)
                    elif X_cat[k].dtype.kind in ['i', 'u']:  # integer, unsigned integer
                        # Already integers, convert to strings for OrdinalEncoder
                        X_cat[k] = X_cat[k].astype(str)
                    elif X_cat[k].dtype.kind == 'O':  # object (strings)
                        # Already strings, but check if any are numeric strings that should be rounded
                        # Convert numeric strings to integers, then back to strings
                        X_cat_copy = X_cat[k].copy()
                        for col_idx in range(X_cat_copy.shape[1]):
                            col = X_cat_copy[:, col_idx]
                            # Try to convert to float, if successful, round and convert back to string
                            try:
                                numeric_col = pd.to_numeric(col, errors='coerce')
                                if not numeric_col.isna().all():
                                    # Has numeric values, round them
                                    rounded_col = np.round(numeric_col.fillna(0)).astype(int).astype(str)
                                    # Preserve NaN as string 'nan' or keep original
                                    nan_mask = numeric_col.isna()
                                    if nan_mask.any():
                                        rounded_col[nan_mask] = col[nan_mask]  # Keep original for NaN
                                    X_cat_copy[:, col_idx] = rounded_col
                            except:
                                pass  # If conversion fails, keep original
                        X_cat[k] = X_cat_copy
                    # Now X_cat contains strings, ready for OrdinalEncoder

        # Preprocess categorical classes and convert to integers
        if self.num_cat_features > 0:
            
            # Fit cat_int_enc on ALL data (including imputed values) to handle all categories
            # This ensures we can encode all categories that appear in the data, including
            # those from imputation. This matches the reference implementation.
            self.cat_int_enc = preprocessing.OrdinalEncoder()
            # For 'zero' preproc, we need to handle NaN values that will become UNK tokens
            # Replace NaN with a placeholder before fitting/transforming, then map to UNK_INDEX
            if self.preproc == 'zero':
                # Create a placeholder value that's unlikely to appear in real data
                placeholder = "__UNK_PLACEHOLDER__"
                # Prepare data for fitting: replace NaN with placeholder
                all_cat_data = []
                for k in X_cat.keys():
                    if X_cat[k] is not None:
                        X_cat_copy = X_cat[k].copy().astype(object)
                        # Replace NaN with placeholder (handle NaN in object arrays)
                        nan_mask = pd.isna(X_cat_copy)
                        if nan_mask.any():
                            X_cat_copy[nan_mask] = placeholder
                        all_cat_data.append(X_cat_copy)
                
                if all_cat_data:
                    self.cat_int_enc.fit(np.concatenate(all_cat_data))
                    # Find the encoded value for placeholder (will be UNK_INDEX)
                    # Create a test array with placeholder for each feature
                    test_array = np.array([[placeholder] * X_cat["train"].shape[1]], dtype=object)
                    placeholder_encoded = self.cat_int_enc.transform(test_array)[0]
                    self.unk_placeholder_encoded = placeholder_encoded
                else:
                    self.cat_int_enc.fit(
                        np.concatenate(
                            list(X_cat[k] for k in X_cat.keys() if X_cat[k] is not None)
                        )
                    )
                    self.unk_placeholder_encoded = None
                
                # Transform with placeholder replacement
                X_cat_transformed = {}
                for k, v in X_cat.items():
                    if v is not None:
                        v_copy = v.copy().astype(object)
                        # Replace NaN with placeholder
                        nan_mask = pd.isna(v_copy)
                        if nan_mask.any():
                            v_copy[nan_mask] = placeholder
                        X_cat_transformed[k] = self.cat_int_enc.transform(v_copy)
                    else:
                        X_cat_transformed[k] = None
                X_cat = X_cat_transformed
            else:
                self.cat_int_enc.fit(
                    np.concatenate(
                        list(X_cat[k] for k in X_cat.keys() if X_cat[k] is not None)
                    )
                )
                
                X_cat = {
                    k: self.cat_int_enc.transform(v) if v is not None else None
                    for k, v in X_cat.items()
                }
            # Compute num_cats from unique values in training data
            # Since cat_int_enc already sees all data, we use it directly (single transformation)
            # Add 1 to each category to account for UNK token (TabDDPM expects [0, n_cats] range)
            self.num_cats = None
            self.most_frequent_cat_indices = None  # Cache most frequent category for UNK_INDEX mapping
            if self.cat_encoding is None:
                num_cats = []
                most_frequent_indices = []
                for i in range(self.num_cat_features):
                    train_vals = X_cat["train"][:, i]
                    # For 'zero' preproc, exclude placeholder-encoded values when computing max
                    if self.preproc == 'zero' and hasattr(self, 'unk_placeholder_encoded') and self.unk_placeholder_encoded is not None:
                        placeholder_val = self.unk_placeholder_encoded[i]
                        train_vals_no_unk = train_vals[train_vals != placeholder_val]
                        if len(train_vals_no_unk) > 0:
                            max_val = train_vals_no_unk.max()
                        else:
                            max_val = -1  # All values are UNK
                    else:
                        max_val = train_vals.max()
                    # num_cats should be max value + 1 to cover range [0, max]
                    # Since cat_int_enc is fitted on all data, values are sequential starting from 0
                    # TabDDPM will add 1 to this for UNK token, so model expects [0, num_cats] range
                    num_cats.append(int(max_val) + 1)
                    # Compute most frequent category index for UNK_INDEX mapping (excluding UNK)
                    if self.preproc == 'zero' and hasattr(self, 'unk_placeholder_encoded') and self.unk_placeholder_encoded is not None:
                        placeholder_val = self.unk_placeholder_encoded[i]
                        train_vals_no_unk = train_vals[train_vals != placeholder_val]
                        if len(train_vals_no_unk) > 0:
                            mode_result = stats.mode(train_vals_no_unk, keepdims=True)
                            most_frequent_indices.append(int(mode_result.mode[0]))
                        else:
                            most_frequent_indices.append(0)  # Default if all are UNK
                    else:
                        mode_result = stats.mode(train_vals, keepdims=True)
                        most_frequent_indices.append(int(mode_result.mode[0]))
                    
                # For 'zero' preproc, now map placeholder-encoded values to UNK_INDEX (n_cats)
                if self.preproc == 'zero' and hasattr(self, 'unk_placeholder_encoded') and self.unk_placeholder_encoded is not None:
                    for k in X_cat.keys():
                        if X_cat[k] is not None:
                            for feat_idx in range(X_cat[k].shape[1]):
                                placeholder_val = self.unk_placeholder_encoded[feat_idx]
                                unk_index = num_cats[feat_idx]  # UNK_INDEX = n_cats
                                mask = X_cat[k][:, feat_idx] == placeholder_val
                                if mask.any():
                                    X_cat[k][mask, feat_idx] = unk_index
                
                self.num_cats = num_cats
                self.most_frequent_cat_indices = most_frequent_indices
                
                # Clamp categorical values to valid range before feeding to model
                # This ensures the model receives properly bounded values during training/generation
                # For 'zero' preproc: valid range is [0, n_cats] where n_cats is UNK_INDEX
                # For other preprocs: valid range is [0, n_cats-1]
                for k in X_cat.keys():
                    if X_cat[k] is not None:
                        for feat_idx in range(self.num_cat_features):
                            n_cats = self.num_cats[feat_idx]
                            # Clamp negative values to 0
                            neg_mask = X_cat[k][:, feat_idx] < 0
                            if neg_mask.any():
                                X_cat[k][neg_mask, feat_idx] = 0
                            
                            # Clamp out-of-range values
                            if self.preproc == 'zero':
                                # For 'zero' preproc, n_cats is UNK_INDEX, so valid range is [0, n_cats]
                                # Clamp values > n_cats to n_cats
                                unk_mask = X_cat[k][:, feat_idx] > n_cats
                                if unk_mask.any():
                                    X_cat[k][unk_mask, feat_idx] = n_cats
                            else:
                                # For other preprocs, valid range is [0, n_cats-1]
                                # Clamp values >= n_cats to n_cats-1
                                unk_mask = X_cat[k][:, feat_idx] >= n_cats
                                if unk_mask.any():
                                    X_cat[k][unk_mask, feat_idx] = n_cats - 1
        else:
            self.num_cats = None
            X_cat = {k: None for k in y.keys()}

        if self.num_cont_features == 0:
            X_cont = {k: None for k in y.keys()}
            
        # SIMPLIFIED: Use cat_int_enc directly (single transformation)
        # No need for second OrdinalEncoder since cat_int_enc already sees all data
        # X_cat already contains integer-encoded values from cat_int_enc

        # Continuous features are already transformed and standardized in _transform_observed_data
        # Standardization was applied BEFORE imputation using nanmean/nanstd on observed values
        # This ensures imputers work on normalized data
        # No additional standardization needed here (data is already standardized)
        
        return OriginalData(X_cat, X_cont, M_cat, M_cont, y)


    def get_train_loader(self, batch_size, partition="train"):
        """Build a ``FastTensorDataLoader`` for one partition of the prepared data.

        The requested batch size is capped at ``self.data.get_train_obs()``.
        Categorical values are converted to ``long`` tensors; continuous values
        and both missingness masks are converted to ``float`` tensors. A tensor
        is replaced by ``None`` when the corresponding feature count is zero.

        Only the ``"train"`` partition is shuffled and drops its last batch;
        every other partition is iterated in order without dropping anything.
        """
        batch_size = min(batch_size, self.data.get_train_obs())
        cat_by_split, cont_by_split, cat_masks, cont_masks, _targets = self.data.get_data()

        has_cat = self.num_cat_features > 0
        has_cont = self.num_cont_features > 0

        cat_tensor = None
        if has_cat:
            cat_tensor = torch.tensor(cat_by_split[partition]).long()

        cont_tensor = None
        if has_cont:
            cont_tensor = torch.tensor(cont_by_split[partition]).float()

        cat_mask_tensor = None
        if has_cat:
            cat_mask_tensor = torch.tensor(cat_masks[partition]).float()

        cont_mask_tensor = None
        if has_cont:
            cont_mask_tensor = torch.tensor(cont_masks[partition]).float()

        # Training batches are shuffled and incomplete batches discarded;
        # evaluation partitions are served as-is.
        is_train = partition == "train"

        return FastTensorDataLoader(
            cat_tensor,
            cont_tensor,
            cat_mask_tensor,
            cont_mask_tensor,
            None,  # no separate target tensor: y is already included in X_cat or X_cont
            batch_size=batch_size,
            shuffle=is_train,
            drop_last=is_train,
        )

    def postprocess_gen_data(self, X_cat_gen, X_cont_gen, y_gen):
        """Apply inverse transformation to generated data points.

        Categorical part (skipped when ``X_cat_gen`` is ``None`` or there are
        no categorical features): when ``self.num_cats`` is available the codes
        are rounded to integers, negative codes are clamped to 0 and codes at
        or above ``num_cats[i]`` (the UNK index) are replaced by the cached
        most frequent category (index 0 if nothing is cached). The codes are
        then decoded with ``self.cat_int_enc.inverse_transform``.

        Continuous part (only when ``self.cont_scaler == "quantile"``, there
        are continuous feature indices and ``X_cont_gen`` is not ``None``): the
        columns listed in ``self.cont_feature_indices_for_transformer`` are
        first de-standardized with ``X_cont_std``/``X_cont_mean`` (if both
        attributes exist) and then passed through
        ``self.cont_enc.inverse_transform``.

        The input arrays are never modified; transformed copies are returned.

        Args:
            X_cat_gen: Generated categorical codes, or ``None``.
            X_cont_gen: Generated continuous values, or ``None``.
            y_gen: Passed through unchanged.

        Returns:
            Tuple ``(X_cat_gen, X_cont_gen, y_gen)`` after inverse transforms.

        Raises:
            ValueError: If the quantile inverse transform is required but
                ``cont_enc`` or ``cont_feature_indices_for_transformer`` is
                missing, or the stored indices do not fit ``X_cont_gen``.
        """
        # Only touch the categorical encoder when there is something to decode;
        # it is not fitted when the dataset has no categorical features.
        if X_cat_gen is not None and self.num_cat_features > 0:
            if self.num_cats is not None:
                # np.round + astype produce a new array, so the caller's input
                # stays untouched. Rounding is a safety net for float outputs.
                X_cat_gen = np.round(X_cat_gen).astype(int)

                for feat_idx in range(self.num_cat_features):
                    n_cats = self.num_cats[feat_idx]

                    # Clamp negative values to 0
                    neg_mask = X_cat_gen[:, feat_idx] < 0
                    if neg_mask.any():
                        X_cat_gen[neg_mask, feat_idx] = 0

                    # Values >= n_cats (UNK_INDEX) are unknown to cat_int_enc;
                    # map them to the most frequent category instead.
                    unk_mask = X_cat_gen[:, feat_idx] >= n_cats
                    n_replaced = unk_mask.sum()
                    if n_replaced > 0:
                        cached = self.most_frequent_cat_indices
                        if cached is not None and feat_idx < len(cached):
                            most_frequent = cached[feat_idx]
                        else:
                            # Fallback: first category if nothing was cached
                            most_frequent = 0
                        X_cat_gen[unk_mask, feat_idx] = most_frequent
                        logging.warning(
                            f"Clamped {n_replaced} UNK_INDEX values to most frequent category for feature {feat_idx}"
                        )

            X_cat_gen = self.cat_int_enc.inverse_transform(X_cat_gen)

        needs_cont_inverse = (
            X_cont_gen is not None
            and len(self.cont_feature_indices) > 0
            and self.cont_scaler == "quantile"
        )
        if not needs_cont_inverse:
            return X_cat_gen, X_cont_gen, y_gen

        if not hasattr(self, 'cont_enc') or self.cont_enc is None:
            raise ValueError("cont_enc is not set. Cannot perform inverse transform. This should have been set during preprocessing.")

        if not hasattr(self, 'cont_feature_indices_for_transformer'):
            raise ValueError(
                "cont_feature_indices_for_transformer not set. "
                "This should have been set during preprocessing."
            )

        indices_to_transform = np.asarray(self.cont_feature_indices_for_transformer, dtype=int)
        if indices_to_transform.size == 0:
            return X_cat_gen, X_cont_gen, y_gen
        if np.any(indices_to_transform < 0) or np.any(indices_to_transform >= X_cont_gen.shape[1]):
            raise ValueError(
                f"cont_feature_indices_for_transformer {indices_to_transform} "
                f"are invalid for X_cont_gen with shape {X_cont_gen.shape}."
            )

        # Work on a copy so the caller's array is not modified in place
        # (mirrors the handling of the categorical input above).
        X_cont_gen = X_cont_gen.copy()

        # Inverse transformations are applied in reverse order of preprocessing:
        # 1. undo standardization (if it was applied)
        if hasattr(self, 'X_cont_mean') and hasattr(self, 'X_cont_std'):
            X_cont_gen[:, indices_to_transform] = (
                X_cont_gen[:, indices_to_transform] * self.X_cont_std
            ) + self.X_cont_mean

        # 2. undo the quantile transform
        X_cont_gen[:, indices_to_transform] = self.cont_enc.inverse_transform(
            X_cont_gen[:, indices_to_transform]
        )

        return X_cat_gen, X_cont_gen, y_gen

    def check_bounds(self, X_cont_gen):
        """Write a per-feature range report for generated continuous data.

        For every continuous column the report lists the minimum and maximum of
        the training data (taken from ``self.data.get_train_data()``) together
        with how many generated values fall below the minimum or above the
        maximum. The table is written to ``bound_checks.txt`` in ``self.logdir``.
        """
        _cat, train_cont, _m_cat, _m_cont, _y = self.data.get_train_data()

        col_min = train_cont.min(axis=0)
        col_max = train_cont.max(axis=0)

        report = PrettyTable()
        header = ["constraint"]
        header.extend(f"feature {j}" for j in range(train_cont.shape[1]))
        report.field_names = header

        # Row order is part of the written report: min, below-min count,
        # max, above-max count.
        rows = (
            ("minimum", col_min),
            ("< minimum", np.sum(X_cont_gen < col_min, axis=0)),
            ("maximum", col_max),
            ("> maximum", np.sum(X_cont_gen > col_max, axis=0)),
        )
        for label, values in rows:
            report.add_row([label, *values])

        out_file = os.path.join(self.logdir, "bound_checks.txt")
        with open(out_file, "w") as f:
            f.write(str(report))

    def save_data(self, X_cat_gen, X_cont_gen, y_gen, sample_path = None):
        """Assemble generated data into a DataFrame and write it as CSV.

        The incoming ``y_gen`` is not used: the target is taken from the first
        column of ``X_cont_gen`` for regression tasks and from the first column
        of ``X_cat_gen`` otherwise, and that column is removed from its block.
        Columns are put back in order by ``recover_data`` and renamed using
        ``idx_name_mapping`` from ``info.json`` when present, otherwise using
        ``column_names`` (or the stringified position as a last resort).

        The CSV is written without the index to ``sample_path`` if given, else
        to ``self.sample_path``.
        """
        with open(f'{self.data_path}/info.json', 'r') as f:
            info = json.load(f)

        # Split the target column off the block that carries it.
        if self.task == 'regression':
            y_gen, X_cont_gen = X_cont_gen[:, 0:1], X_cont_gen[:, 1:]
        else:
            y_gen, X_cat_gen = X_cat_gen[:, 0:1], X_cat_gen[:, 1:]

        syn_df = recover_data(X_cont_gen, X_cat_gen, y_gen, info)

        if 'idx_name_mapping' in info:
            rename_map = {
                int(position): label
                for position, label in info['idx_name_mapping'].items()
            }
        else:
            positional_names = [str(i) for i in range(syn_df.shape[1])]
            col_names = info.get('column_names', positional_names)
            n_named = min(len(col_names), syn_df.shape[1])
            rename_map = {i: str(col_names[i]) for i in range(n_named)}
        syn_df.rename(columns=rename_map, inplace=True)

        destination = self.sample_path if sample_path is None else sample_path
        syn_df.to_csv(destination, index=False)



def recover_data(syn_num, syn_cat, syn_target, info):
    """Reassemble generated blocks into one DataFrame in original column order.

    Every original column index ``i`` is looked up in ``idx_mapping`` (taken
    from ``info['idx_mapping']`` with keys cast to ``int``, or the identity
    mapping when absent). The mapped position is interpreted relative to the
    layout ``[numerical | categorical | target]``: numerical columns index
    ``syn_num`` directly, categorical columns index ``syn_cat`` after
    subtracting the number of numerical columns, and all remaining columns
    index ``syn_target`` after subtracting both counts.

    The function is best-effort: a column whose source block is ``None`` is
    skipped with a warning, and a column whose mapped position exceeds the
    block width is skipped with an error log entry. Negative mapped positions
    are not rejected and follow normal NumPy negative indexing.

    ``info['task_type']`` is not consulted; the column layout is the same for
    every task type.

    Args:
        syn_num: 2-D array of generated numerical columns, or ``None``.
        syn_cat: 2-D array of generated categorical columns, or ``None``.
        syn_target: 2-D array of generated target columns, or ``None``.
        info: Dataset description with ``num_col_idx``, ``cat_col_idx``,
            ``target_col_idx`` and optionally ``idx_mapping``.

    Returns:
        ``pandas.DataFrame`` whose column labels are the original integer
        column indices, in increasing order, for every recoverable column.
    """
    num_col_idx = info['num_col_idx']
    cat_col_idx = info['cat_col_idx']
    target_col_idx = info['target_col_idx']

    n_num = len(num_col_idx)
    n_cat = len(cat_col_idx)
    n_cols = n_num + n_cat + len(target_col_idx)

    if 'idx_mapping' in info:
        idx_mapping = {int(key): value for key, value in info['idx_mapping'].items()}
    else:
        idx_mapping = {i: i for i in range(n_cols)}

    # Build the membership sets once instead of on every loop iteration.
    num_cols = set(num_col_idx)
    cat_cols = set(cat_col_idx)

    syn_df = pd.DataFrame()

    for i in range(n_cols):
        # Pick the source block and the offset of that block in the
        # [num | cat | target] layout.
        if i in num_cols:
            kind, block, offset = 'num', syn_num, 0
        elif i in cat_cols:
            kind, block, offset = 'cat', syn_cat, n_num
        else:
            kind, block, offset = 'target', syn_target, n_num + n_cat

        mapped_idx = idx_mapping[i] - offset

        if block is None:
            logging.warning("recover_data: index %s is %s but syn_%s is None", i, kind, kind)
        elif mapped_idx >= block.shape[1]:
            logging.error(
                f"recover_data: index {i} ({kind}) maps to {mapped_idx} "
                f"but syn_{kind} has {block.shape[1]} columns"
            )
        else:
            syn_df[i] = block[:, mapped_idx]

    return syn_df