import os
import pickle
import numpy as np
import pandas as pd
from feature_engine.imputation import RandomSampleImputer
import lightgbm as lgb
import miceforest as mf

from .data_utils import replace_nan_with_mean_or_mode


def impute_naive(df: pd.DataFrame) -> pd.DataFrame:
    """Fill missing values with the naive per-column baseline.

    This is a thin wrapper: all work is delegated to
    ``replace_nan_with_mean_or_mode`` from ``data_utils``
    (presumably mean for numerical and mode for categorical columns, as the
    helper's name suggests -- confirm against its implementation).

    Args:
        df: DataFrame to impute.

    Returns:
        Whatever ``replace_nan_with_mean_or_mode`` returns for ``df``.
    """
    filled = replace_nan_with_mean_or_mode(df)
    return filled


def impute_cdf(df: pd.DataFrame, random_state: int = 42) -> pd.DataFrame:
    """Impute missing values by random sampling (empirical CDF imputation).

    Fits a feature_engine ``RandomSampleImputer`` on ``df`` and applies it to
    the same frame in one step.

    Args:
        df: DataFrame to impute.
        random_state: Seed forwarded to ``RandomSampleImputer``.

    Returns:
        The frame produced by the imputer's ``fit_transform``.
    """
    sampler = RandomSampleImputer(random_state=random_state)
    imputed = sampler.fit_transform(df)
    return imputed


def lgbm_impute(df: pd.DataFrame,
                 sample: bool = False,
                 random_state: int = 42,
                 n_estimators: int = 50,
                 std_multiplier: float = 1.0,
                 beta: float = 1.0,
                 cache_path: str = None,
                 use_log: bool = True,
                 save_predictions: bool = True,
                 predictions_cache_path: str = None) -> pd.DataFrame:
    """
    LGBM imputation with label smoothing for categorical sampling.
    Works on already-transformed data (quantile-transformed numerical, categorical-encoded categorical).
    
    Args:
        df: DataFrame to impute
        sample: Whether to use sampling for imputation
        random_state: Random seed
        n_estimators: Number of estimators for LightGBM
        std_multiplier: Standard deviation multiplier for sampling
        beta: Beta parameter for shrinkage (std_hat = sigma_test * beta)
        cache_path: Optional path to cache file. If provided and file exists, loads cached result.
                    If file doesn't exist, performs imputation and saves to cache.
        use_log: Whether to use log-space variance estimation
        save_predictions: Whether to save mu_test and std_hat predictions for missing values
        predictions_cache_path: Path to save predictions cache (only used if save_predictions=True)
    
    Returns:
        Imputed DataFrame
    """
    if cache_path is not None and os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    
    np.random.seed(random_state)
    nan_idx = df.isna()
    org_df = df.copy(deep=True)
    imputed_df = df.copy(deep=True)

    dtypes = df.dtypes.tolist()
    categorical_cols = [idx for idx, col in enumerate(dtypes)
                        if pd.api.types.is_categorical_dtype(col) or col == 'object']
    df[categorical_cols] = df[categorical_cols].astype('category')
    
    # Dictionary to store predictions for missing continuous values
    predictions_dict = {}

    for col_idx, col in enumerate(df.columns):
        train_idx = df[col].notna()
        test_idx = df[col].isna()

        X_train = df.loc[train_idx, df.columns != col]
        y_train = df.loc[train_idx, col]
        X_test = df.loc[test_idx, df.columns != col]

        if col in categorical_cols:
            categories = df[col].astype("category").cat.categories
            model = lgb.LGBMClassifier(n_estimators=n_estimators,
                                       random_state=random_state,
                                       verbosity=-1)
            model.fit(X_train, y_train.cat.codes)
            pred = model.predict(X_test)

            if not sample:
                det_cats = pd.Categorical.from_codes(pred, categories=categories)
                imputed_df.loc[test_idx, col] = det_cats.astype(object)
                imputed_df[col] = imputed_df[col].astype("category").cat.set_categories(categories)
            else:
                probas = model.predict_proba(X_test)
                # Label smoothing
                alpha = 0.05  # Fixed heuristic: 5% uniform prior
                K = probas.shape[1]
                probas = (1 - alpha) * probas + (alpha / K)
                # --------------------------------------------

                samples = [np.random.choice(len(p), p=p) for p in probas]
                sampled_cats = pd.Categorical.from_codes(samples, categories=categories)
                imputed_df.loc[test_idx, col] = sampled_cats.astype(object)
                imputed_df[col] = imputed_df[col].astype("category").cat.set_categories(categories)
        else:
            model_mu = lgb.LGBMRegressor(n_estimators=n_estimators,
                                        random_state=random_state,
                                        verbosity=-1)
            model_mu.fit(X_train, y_train)
            pred = model_mu.predict(X_test)

            if not sample:
                imputed_df.loc[test_idx, col] = pred
            else:
                # 1. Compute training residuals
                mu_train = model_mu.predict(X_train)
                # Use a small epsilon to avoid log(0)
                epsilon = 1e-6
                log_res_sq = np.log((y_train - mu_train)**2 + epsilon)

                # 2. Fit variance model on log-space
                model_log_sigma = lgb.LGBMRegressor(n_estimators=n_estimators,
                                                     random_state=random_state,
                                                     verbosity=-1)
                model_log_sigma.fit(X_train, log_res_sq)

                # 3. Predict and transform back
                mu_test = model_mu.predict(X_test)
                log_sigma_sq_test = model_log_sigma.predict(X_test)
                sigma_test = np.sqrt(np.exp(log_sigma_sq_test))

                std_hat = sigma_test * beta
                samples = np.random.normal(mu_test, std_hat)
                imputed_df.loc[test_idx, col] = samples
                
                # Save predictions if requested
                if save_predictions:
                    # Get row indices as positional indices (0, 1, 2, ...)
                    # test_idx is a boolean mask, so we convert to integer indices
                    row_indices = np.where(test_idx)[0]
                    predictions_dict[col_idx] = {
                        'row_indices': row_indices,
                        'mu_test': mu_test.copy(),
                        'std_hat': std_hat.copy()
                    }

    org_df[nan_idx] = imputed_df[nan_idx]
    if cache_path is not None:
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(org_df, f)
    if save_predictions and predictions_cache_path is not None and len(predictions_dict) > 0:
        cache_dir = os.path.dirname(predictions_cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(predictions_cache_path, 'wb') as f:
            pickle.dump(predictions_dict, f)
    return org_df


def miceforest_impute(df: pd.DataFrame, random_state: int = 42, cache_path: str = None) -> pd.DataFrame:
    """
    MICE imputation using miceforest library.
    Works on already-transformed data (quantile-transformed numerical, categorical-encoded categorical).

    Categorical/object columns are converted to float category codes (missing
    values become NaN) before being handed to miceforest, and the imputed codes
    are rounded, clipped to the valid code range and mapped back to the
    original categories afterwards. The input ``df`` is not modified.

    Args:
        df: DataFrame to impute
        random_state: Random seed (seeds NumPy's global RNG and the miceforest kernel)
        cache_path: Optional path to cache file. If provided and file exists, loads cached result.
                    If file doesn't exist, performs imputation and saves to cache.

    Returns:
        Imputed DataFrame (a copy of ``df`` with only the missing cells replaced)
    """
    # Check for cached result
    if cache_path is not None and os.path.exists(cache_path):
        print(f'⚠️  Loading cached miceforest imputation from {cache_path}')
        # NOTE: pickle can execute arbitrary code on load; cache_path must
        # point to a trusted file.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    np.random.seed(random_state)
    nan_idx = df.isna()
    org_df = df.copy(deep=True)
    df_work = df.copy(deep=True)

    # Replace each categorical column by its float codes and remember the
    # categories so the imputed codes can be decoded later.
    cat_categories = {}
    for col_name, dtype in df.dtypes.items():
        if not (isinstance(dtype, pd.CategoricalDtype) or dtype == 'object'):
            continue
        cat_series = df[col_name].astype("category")
        codes = cat_series.cat.codes.astype(float)
        # Code -1 marks a missing value; expose it as NaN to the imputer.
        df_work[col_name] = codes.replace(-1, np.nan)
        cat_categories[col_name] = cat_series.cat.categories

    # The kernel is given string column names; originals are restored below.
    df_work.columns = df_work.columns.astype(str)
    kernel = mf.ImputationKernel(
        df_work,
        num_datasets=1,
        random_state=random_state,
    )
    kernel.mice(5)
    imputed = kernel.complete_data(0)

    imputed.columns = org_df.columns
    imputed_converted = imputed.copy(deep=True)
    for col_name, cats in cat_categories.items():
        if len(cats) == 0:
            # No observed category to decode into; leave the column as-is.
            continue
        col_vals = imputed_converted[col_name].to_numpy()
        col_int = pd.Series(col_vals).round().astype(int).to_numpy()
        # Guard against codes outside the valid range.
        col_int = np.clip(col_int, 0, len(cats) - 1)
        mapped_cat = pd.Categorical.from_codes(col_int, categories=cats)
        imputed_converted[col_name] = mapped_cat.astype(object)

    # Only the originally-missing cells are taken from the imputed frame.
    org_df[nan_idx] = imputed_converted[nan_idx]

    # Save to cache if cache_path is provided
    if cache_path is not None:
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(org_df, f)
        print(f'✅ Saved miceforest imputation to cache: {cache_path}')

    return org_df


def impute_zero(df: pd.DataFrame, cache_path: str = None) -> pd.DataFrame:
    """
    Baseline imputation: zeros for numerical, preserve NaN for categorical (will become UNK tokens).
    Works on already-transformed data (quantile-transformed numerical, categorical-encoded categorical).

    For categorical columns, missing values are preserved as NaN so they can be encoded as UNK tokens.
    During training, UNK tokens are ignored in loss computation. During sampling, UNK tokens are
    replaced with the most frequent category in postprocessing.

    The input ``df`` is not modified.

    Args:
        df: DataFrame to impute
        cache_path: Optional path to cache file. If provided and file exists, loads cached result.
                    If file doesn't exist, performs imputation and saves to cache.

    Returns:
        Imputed DataFrame (with NaN preserved for non-numerical columns)
    """
    # Check for cached result
    if cache_path is not None and os.path.exists(cache_path):
        print(f'⚠️  Loading cached zero imputation from {cache_path}')
        # NOTE: pickle can execute arbitrary code on load; cache_path must
        # point to a trusted file.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    org_df = df.copy(deep=True)

    # Only numerical columns are touched; every other column (categorical,
    # object, ...) keeps its NaN so it can later be encoded as UNK.
    for col_name, dtype in df.dtypes.items():
        if not pd.api.types.is_numeric_dtype(dtype):
            continue
        missing_mask = org_df[col_name].isna()
        if missing_mask.any():
            org_df.loc[missing_mask, col_name] = 0.0

    # Save to cache if cache_path is provided
    if cache_path is not None:
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(org_df, f)
        print(f'✅ Saved zero imputation to cache: {cache_path}')
        print(f'   Note: Categorical NaN preserved (will become UNK tokens)')

    return org_df


def impute_noise(df: pd.DataFrame, random_state: int = 42, cache_path: str = None) -> pd.DataFrame:
    """
    Baseline imputation: random Gaussian noise for numerical, uniform random category for categorical.
    Works on already-transformed data (quantile-transformed numerical, categorical-encoded categorical).

    For numerical columns, uses Gaussian noise with mean=0 and std equal to the observed (non-missing)
    values' standard deviation. Falls back to std=1.0 if all values are missing or std is 0/NaN.

    For categorical/object columns, each missing cell receives a category drawn
    uniformly from the categories observed in that column. A categorical column
    that is entirely missing has nothing to draw from and is left as NaN.

    The input ``df`` is not modified.

    Args:
        df: DataFrame to impute
        random_state: Random seed (seeds NumPy's global RNG)
        cache_path: Optional path to cache file. If provided and file exists, loads cached result.
                    If file doesn't exist, performs imputation and saves to cache.

    Returns:
        Imputed DataFrame (a copy of ``df`` with only the missing cells replaced)
    """
    # Check for cached result
    if cache_path is not None and os.path.exists(cache_path):
        print(f'⚠️  Loading cached noise imputation from {cache_path}')
        # NOTE: pickle can execute arbitrary code on load; cache_path must
        # point to a trusted file.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    np.random.seed(random_state)
    org_df = df.copy(deep=True)
    dtypes = df.dtypes

    # Pass 1: categorical columns. They are processed before the numerical
    # ones so the order of random draws stays stable for a given seed.
    for col_name, dtype in dtypes.items():
        if not (isinstance(dtype, pd.CategoricalDtype) or dtype == 'object'):
            continue
        cat_series = df[col_name].astype("category")
        codes = cat_series.cat.codes
        missing_mask = codes == -1  # code -1 marks a missing value
        if not missing_mask.any():
            continue
        cats = cat_series.cat.categories
        if len(cats) == 0:
            # Entirely-missing column: no category to sample, keep NaN.
            continue
        # Sample uniformly from the codes that actually occur (in order of
        # first appearance), not from every declared category.
        valid_codes = codes[~missing_mask].unique()
        random_codes = np.random.choice(valid_codes, size=int(missing_mask.sum()))
        org_df.loc[missing_mask, col_name] = np.asarray(cats.take(random_codes), dtype=object)

    # Pass 2: numerical columns get Gaussian noise (mean=0, std=observed std).
    for col_name, dtype in dtypes.items():
        if not pd.api.types.is_numeric_dtype(dtype):
            continue
        col_vals = df[col_name]
        missing_mask = col_vals.isna()
        if not missing_mask.any():
            continue
        # std is NaN with fewer than two observed values and 0 for a constant
        # column; both fall back to unit variance.
        col_std = col_vals[~missing_mask].std()
        if pd.isna(col_std) or col_std == 0.0:
            col_std = 1.0
        noise = np.random.normal(0.0, col_std, size=int(missing_mask.sum()))
        org_df.loc[missing_mask, col_name] = noise

    # Save to cache if cache_path is provided
    if cache_path is not None:
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(org_df, f)
        print(f'✅ Saved noise imputation to cache: {cache_path}')

    return org_df

