import os
from typing import Union, Tuple, Optional, List

import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig, ListConfig
import numpy as np
from sklearn.neighbors import KernelDensity
from sklearn.utils import resample
from scipy.stats import zscore

def train_test_split(
        x: Union[torch.Tensor, pd.DataFrame],
        ratio: float = 0.7,
        random_state: int = 42
) -> Tuple[Union[torch.Tensor, pd.DataFrame], Union[torch.Tensor, pd.DataFrame]]:
    """
    Split a dataset into train and test subsets based on a ratio.

    Rows are shuffled with a random permutation drawn from a PyTorch Generator
    seeded with `random_state`; the first `int(len(x) * ratio)` permuted rows
    form the train subset and the rest form the test subset.

    Args:
        x (Union[torch.Tensor, pd.DataFrame]):
            The dataset to be split (rows along the first axis).
        ratio (float, optional):
            Fraction of data to go into the train set (e.g., 0.7 means 70% train, 30% test).
            Defaults to 0.7.
        random_state (int, optional):
            Seed for reproducibility. Defaults to 42.

    Returns:
        Tuple[Union[torch.Tensor, pd.DataFrame], Union[torch.Tensor, pd.DataFrame]]:
            (train_subset, test_subset) in the same type as `x`. Tensor subsets
            stay on the device of `x`.
    """
    device = x.device if isinstance(x, torch.Tensor) else 'cpu'

    generator = torch.Generator(device=device).manual_seed(random_state)
    n_train = int(len(x) * ratio)
    # The permutation must live on the generator's device: randperm defaults to
    # CPU output, which a CUDA generator cannot produce.
    perm = torch.randperm(len(x), generator=generator, device=device)
    train_idx, test_idx = perm[:n_train], perm[n_train:]

    if isinstance(x, pd.DataFrame):
        # Hand pandas plain NumPy positions rather than a torch tensor.
        return x.iloc[train_idx.numpy()], x.iloc[test_idx.numpy()]
    return x[train_idx], x[test_idx]


def setup_dataframes(cfg: DictConfig) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load the train and test splits whose CSV paths are given in ``cfg.dataset``.

    Args:
        cfg (DictConfig):
            Configuration whose ``dataset`` section provides ``train_path`` and
            ``test_path``; each is passed directly to ``pd.read_csv``.

    Returns:
        (pd.DataFrame, pd.DataFrame): ``(train_df, test_df)``, read in that order.
    """
    train_df, test_df = (
        pd.read_csv(getattr(cfg.dataset, path_key))
        for path_key in ("train_path", "test_path")
    )
    return train_df, test_df


def drop_target(cfg: DictConfig) -> DictConfig:
    """
    Return a copy of a configuration where the target column is removed
    from the categorical and numerical lists.

    This is useful if you want to ensure the target is never treated as a feature.
    The input ``cfg`` is left unmodified.

    Args:
        cfg (DictConfig):
            Configuration object containing:
              - cfg.categorical_columns (list of column names, or None)
              - cfg.numerical_columns (list of "<col>!!<dtype>" entries, or None)
              - cfg.target_column

    Returns:
        DictConfig: A deep copy of the input config with every entry for the
                    target column removed from the column lists. A list that is
                    None is left as None.
    """
    import copy

    # A deep copy is required: cfg.copy() is shallow, so nested column lists
    # would be shared and editing them would silently alter the caller's cfg.
    cfg_copy = copy.deepcopy(cfg)
    target = cfg_copy.target_column

    # Remove target from the categorical list (all occurrences)
    if cfg_copy.categorical_columns is not None:
        cfg_copy.categorical_columns = ListConfig(
            [col for col in cfg_copy.categorical_columns if col != target]
        )

    # Remove target from the numerical list; entries are "<col>!!<dtype>", so
    # compare on the column-name part only.
    if cfg_copy.numerical_columns is not None:
        cfg_copy.numerical_columns = ListConfig(
            [
                entry for entry in cfg_copy.numerical_columns
                if entry.split("!!")[0] != target
            ]
        )
    return cfg_copy


def _yaml_quote(value) -> str:
    """Render ``str(value)`` as a double-quoted YAML scalar.

    A JSON string is a valid YAML double-quoted scalar, so ``json.dumps`` escapes
    embedded quotes, backslashes and control characters that would otherwise
    produce an unparsable file.
    """
    import json

    return json.dumps(str(value), ensure_ascii=False)


def _read_csv_or_empty(path: Optional[str], columns: List[str]) -> pd.DataFrame:
    """Read ``path`` if it is set and exists; otherwise return an empty frame with ``columns``."""
    if path and os.path.exists(path):
        return pd.read_csv(path)
    return pd.DataFrame(columns=columns)


def _n_rows_with_missing(df: pd.DataFrame, columns: List[str]) -> int:
    """Count the rows of ``df`` that have at least one NaN among ``columns``."""
    return int(df[columns].isna().any(axis=1).sum())


def _infer_numeric_dtype(values: pd.Series) -> str:
    """Return ``'int'`` if more than 99% of the finite ``values`` are integral, else ``'float'``."""
    # Drop +/-inf: casting them to int raises, and they say nothing about integrality.
    finite = values[np.isfinite(values)]
    if finite.empty:
        # Nothing to judge (all-NaN / all-inf column).
        return 'float'
    return 'int' if np.mean(finite.astype(int) == finite) > 0.99 else 'float'


def create_dataset_config(
        save_path: str,
        name: str,
        train_csv_path: str,
        test_csv_path: Optional[str] = None,
        holdout_csv_path: Optional[str] = None,
        target_column: Optional[str] = None,
        categorical_columns: Optional[List[str]] = None,
        auto: bool = False
) -> None:
    """
    Generate and save a YAML configuration file describing a dataset, including:
      - Paths to train/test/holdout splits
      - Counts of rows/columns
      - Which columns are categorical vs. numerical
      - Basic missing-value statistics (computed on the train split)

    Column typing:
      - Numeric-dtype columns not listed in `categorical_columns` are numerical and
        are written as "<col>!!int" or "<col>!!float" ("int" when more than 99% of
        the finite, non-missing values are integral).
      - Every other column is categorical, except columns with more than 95%
        distinct values (likely IDs), which are reported and left out. Numerical
        columns are not subject to this ID check.
      - Numerical columns with fewer than min(10% of non-missing rows, 20) distinct
        values (NaN counted once) are reported as looking categorical/binary; with
        `auto=True` they are moved to the categorical list.
    The target column stays in whichever list it falls into.

    Args:
        save_path (str):
            File path where the YAML configuration is written (UTF-8).
        name (str):
            Name of the dataset (written under "name").
        train_csv_path (str):
            Path to the training CSV file.
        test_csv_path (str, optional):
            Path to the test CSV file. If not given or the file doesn't exist, an
            empty DataFrame is used. Written as the string "None" when not given.
            Defaults to None.
        holdout_csv_path (str, optional):
            Path to a holdout CSV file for e.g. DPIMIA usage. Same fallback as
            `test_csv_path`. Defaults to None.
        target_column (str, optional):
            Target column name in the training CSV. If None, defaults to the last column.
            Defaults to None.
        categorical_columns (list of str, optional):
            Columns to treat as categorical even if numeric. Defaults to None, in
            which case columns are inferred from data.
        auto (bool, optional):
            If True, move low-cardinality numeric columns to the categorical list.
            Defaults to False.

    Returns:
        None

    Raises:
        KeyError: If a loaded test/holdout CSV lacks one of the selected columns.
    """
    train_csv_path = train_csv_path.replace(os.sep, '/')
    test_csv_path = test_csv_path.replace(os.sep, '/') if test_csv_path is not None else None
    holdout_csv_path = holdout_csv_path.replace(os.sep, '/') if holdout_csv_path is not None else None

    train_df = pd.read_csv(train_csv_path)
    n_train = len(train_df)
    # If target not provided, use the last column
    if target_column is None:
        target_column = train_df.columns[-1]

    forced_categorical = set(categorical_columns or [])

    # Numerical = numeric dtype and not forced categorical
    numeric_candidates = train_df.select_dtypes(include=['number']).columns.tolist()
    numerical_columns = [col for col in numeric_candidates if col not in forced_categorical]
    numerical_set = set(numerical_columns)

    # Everything else is categorical, minus near-unique (likely ID) columns
    categorical_columns = []
    for col in train_df.columns:
        if col in numerical_set:
            continue
        if train_df[col].nunique(dropna=False) > n_train * 0.95:
            print(f"\t{name}: column {col} excluded due to high distinct count")
            continue
        categorical_columns.append(col)

    # Columns used for missing-value statistics (fixed before any auto reclassification)
    columns = categorical_columns + numerical_columns

    test_df = _read_csv_or_empty(test_csv_path, columns)
    holdout_df = _read_csv_or_empty(holdout_csv_path, columns)

    # (column, dtype) pairs, so names containing "!!" are never mis-split
    typed_numerical = []
    auto_categorical_columns = []
    for col in numerical_columns:
        col_data = train_df[col].dropna()
        dtype = _infer_numeric_dtype(col_data)

        # Very few distinct values suggests a categorical/binary column
        if train_df[col].nunique(dropna=False) < min(len(col_data) * 0.1, 20):
            print(
                f"\t{name}: column \"{col}\" looks like categorical/binary. "
                f"Check DataFrame. Unique values: {train_df[col].unique().tolist()}"
            )
            auto_categorical_columns.append(col)
        typed_numerical.append((col, dtype))

    if auto:
        moved = set(auto_categorical_columns)
        categorical_columns = categorical_columns + auto_categorical_columns
        typed_numerical = [(col, dtype) for col, dtype in typed_numerical if col not in moved]

    numerical_columns_with_dtypes = [f"{col}!!{dtype}" for col, dtype in typed_numerical]

    train_missing = train_df[columns].isna().to_numpy()
    config_lines = [
        f"name: {_yaml_quote(name)}",
        f"train_path: {_yaml_quote(train_csv_path)}",
        f"test_path: {_yaml_quote(test_csv_path)}",
        f"holdout_path: {_yaml_quote(holdout_csv_path)}",
        f"n_rows_train: {n_train}",
        f"n_rows_test: {len(test_df)}",
        f"n_rows_holdout: {len(holdout_df)}",
        f"n_cols: {len(categorical_columns) + len(numerical_columns_with_dtypes)}",
        f"n_categorical_cols: {len(categorical_columns)}",
        f"n_numerical_cols: {len(numerical_columns_with_dtypes)}",
        f"has_missing_values: {train_missing.sum() > 0}",
        f"missing_element_ratio: {train_missing.mean()}",
        f"n_rows_missing_train: {_n_rows_with_missing(train_df, columns)}",
        f"n_rows_missing_test: {_n_rows_with_missing(test_df, columns)}",
        f"n_rows_missing_holdout: {_n_rows_with_missing(holdout_df, columns)}",
        "missing_values:",
        "  - ''",
        "categorical_columns:",
    ]
    config_lines += [f"  - {_yaml_quote(col)}" for col in categorical_columns]
    config_lines.append("numerical_columns:")
    config_lines += [f"  - {_yaml_quote(col)}" for col in numerical_columns_with_dtypes]
    config_lines.append(f"target_column: {_yaml_quote(target_column)}")

    with open(save_path, 'w', encoding='utf-8') as file:
        file.write("\n".join(config_lines) + "\n")






def get_conditional_y_from_rare_targets(
        y,
        rare_thr=0.8,
        method="kde",
        bandwidth=1.0,
):
    """
    Extract rare-target y values to be used as conditional input
    for a generative model (e.g., tabular diffusion).

    Every value receives a relevance score φ(y) in [0, 1]; values whose score
    is strictly greater than `rare_thr` are returned.

    Parameters
    ----------
    y : array-like
        Target values; flattened to 1-D.
    rare_thr : float, default=0.8
        Threshold above which φ(y) is considered rare.
    method : {"kde", "zscore", "ecdf"}, default="kde"
        Method to compute relevance score φ(y):
          - "kde":    1 - p(y) / max p(y), p a Gaussian KDE fitted on y.
          - "zscore": |z| / max |z|, z the z-score of y (scipy default ddof=0).
          - "ecdf":   2 * |F(y) - 0.5|, F(y) the fraction of samples <= y.
    bandwidth : float, default=1.0
        Kernel bandwidth for the "kde" method, in the units of y.
        Ignored by the other methods.

    Returns
    -------
    y_rare : np.ndarray of shape (n_rare,)
        Target values from the rare region (for conditional generation),
        in their original order.

    Raises
    ------
    ValueError
        If `y` is empty, `method` is unknown, or no value exceeds `rare_thr`.
    """
    y = np.asarray(y).reshape(-1, 1)
    if y.shape[0] == 0:
        raise ValueError("y must contain at least one value.")
    y_flat = y.ravel()

    # Step 1: Compute φ(y)
    if method == "kde":
        kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(y)
        log_p = kde.score_samples(y)
        # p / max(p) evaluated in log space so tiny densities cannot underflow to 0/0.
        relevance = 1 - np.exp(log_p - np.max(log_p))

    elif method == "zscore":
        z = np.abs(zscore(y_flat))
        relevance = (z / np.max(z)).clip(0, 1)

    elif method == "ecdf":
        # Right-continuous empirical CDF at each sample: share of samples <= value.
        sorted_y = np.sort(y_flat)
        F = np.searchsorted(sorted_y, y_flat, side="right") / len(y_flat)
        relevance = 2 * np.abs(F - 0.5)

    else:
        raise ValueError(f"Unknown method: {method}")

    # Step 2: Filter rare targets
    rare_mask = np.ravel(relevance) > rare_thr
    y_rare = y_flat[rare_mask]

    if len(y_rare) == 0:
        raise ValueError("No rare target samples found. Try lowering `rare_thr`.")

    return y_rare