import os
import random
from configparser import ConfigParser
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype

def _write_stats_item(
    f,
    key: str,
    value: Any,
    prefix: str = "# ",
    indent: str = "",
) -> None:
    """
    Emit one (key, value) entry of a statistics mapping as comment lines.

    Containers are expanded recursively; every nesting level adds two spaces
    after ``prefix``:
      - dict        -> "<key>:" header, then one nested entry per item
      - list/tuple  -> "<key>:" header, then "- <elem>" for each scalar element
                       (int, float, bool, str or None) and a nested
                       "item_<idx>" entry for any other element
      - anything else -> a single "<key>: <value>" line

    Parameters
    ----------
    f : writable text stream
        Any object with a ``write(str)`` method.
    key : str
        Label printed for this entry.
    value : Any
        Value to print (possibly nested).
    prefix : str
        Comment marker placed at the very start of every emitted line.
    indent : str
        Current nesting indentation, inserted right after ``prefix``.
    """
    lead = prefix + indent
    child_indent = indent + "  "

    # Leaf value: one line and done.
    if not isinstance(value, (dict, list, tuple)):
        f.write(f"{lead}{key}: {value}\n")
        return

    # Containers share the same "<key>:" header line.
    f.write(f"{lead}{key}:\n")

    if isinstance(value, dict):
        for child_key, child_value in value.items():
            _write_stats_item(
                f, str(child_key), child_value, prefix=prefix, indent=child_indent
            )
        return

    scalar_types = (int, float, bool, str)
    for position, element in enumerate(value):
        if element is None or isinstance(element, scalar_types):
            f.write(f"{lead}  - {element}\n")
        else:
            # Non-scalar element: label it by position and recurse.
            _write_stats_item(
                f,
                key=f"item_{position}",
                value=element,
                prefix=prefix,
                indent=child_indent,
            )


def write_processed_df_with_stats_to_csv(
    df: pd.DataFrame,
    stats: Dict[str, Any],
    filename: str,
    comment_prefix: str = "# ",
    mapping: Optional[Dict[str, str]] = None,
) -> None:
    """
    Write a processed DataFrame to CSV, preceded by a commented stats header.

    File layout (every header line starts with ``comment_prefix``):
      1. a rule line of 80 '#' characters and a "Dataset statistics" title;
      2. every entry of ``stats``, rendered by ``_write_stats_item``
         (scalars as "key: value", dicts as nested "key:" blocks,
         lists/tuples as "key:" followed by "- item" lines, recursively);
      3. if ``mapping`` is given, an "Original column -> feature name
         mapping" block with one "orig -> feature" line per entry;
      4. an empty comment line and a closing rule line;
      5. the CSV itself (header row + data rows, no index) via
         ``DataFrame.to_csv``.

    New (even nested) entries in ``stats`` show up automatically; no change
    to this function is needed.

    Parameters
    ----------
    df : pd.DataFrame
        Data to write.
    stats : Dict[str, Any]
        Arbitrary statistics dictionary.
    filename : str
        Path of the CSV file to create/overwrite.
    comment_prefix : str
        Prefix for every comment line (default "# ").
    mapping : Dict[str, str], optional
        Original column name -> feature name, written as its own block.
    """
    rule_line = f"{comment_prefix}{'#' * 80}\n"
    empty_comment = f"{comment_prefix}\n"

    # newline="" lets the csv writer inside to_csv control line endings.
    with open(filename, "w", newline="") as out:
        out.write(rule_line)
        out.write(f"{comment_prefix}Dataset statistics\n")
        for stat_key, stat_value in stats.items():
            _write_stats_item(
                out, str(stat_key), stat_value, prefix=comment_prefix, indent=""
            )

        if mapping is not None:
            out.write(empty_comment)
            out.write(f"{comment_prefix}Original column -> feature name mapping\n")
            out.writelines(
                f"{comment_prefix}{source} -> {target}\n"
                for source, target in mapping.items()
            )

        # Close the comment header, then append the actual CSV payload.
        out.write(empty_comment)
        out.write(rule_line)
        df.to_csv(out, index=False)

###################################################################################
def convert_heart_failure_to_facility_client_df(
    df_raw: pd.DataFrame,
    n_groups: int = 3,
    sex_col: str = "sex",
    male_value: int = 1,
    female_value: int = 0,
    third_group_col: str = "age",
    third_group_threshold: float = 50.0,
    facility_probability: float = 0.5,
    max_capacity: int = 10,
    seed: Optional[int] = None,
) -> Tuple[pd.DataFrame, Dict[str, str], Dict[str, Any]]:
    """
    Convert the heart-failure dataset into the synthetic facility/client format.

    Output schema:
      - Feature columns f1, f2, ..., f_d: one per original column, in the
        original column order. Non-numeric columns are label-encoded with
        ``pd.factorize(sort=True)``; every column is then scaled to unit L2
        norm (an all-zero column stays all zero).
      - is_facility : 1 if facility, 0 if client; each point is independently
        a facility with probability ``facility_probability``.
      - capacity    : 0 for clients, a uniform integer in 1..max_capacity for
        facilities.
      - group1..group_n_groups, computed from the *raw* (unscaled) data:
            group1: df_raw[sex_col] == male_value
            group2: df_raw[sex_col] == female_value
            group3: df_raw[third_group_col] <= third_group_threshold
            group4..: random overlapping groups (each point joins each extra
                      group independently with probability 0.5)
        Only the first ``n_groups`` of these columns are produced.

    Parameters
    ----------
    df_raw : pd.DataFrame
        Original heart-failure dataframe (e.g. read from heart-failure.csv).
        It is not modified.
    n_groups : int
        Number of group columns to generate.
    sex_col : str
        Name of the sex column in df_raw (default "sex").
    male_value : int
        Value in sex_col indicating male (heart-failure: 1).
    female_value : int
        Value in sex_col indicating female (heart-failure: 0).
    third_group_col : str
        Column defining the third group (e.g. "age").
    third_group_threshold : float
        group3 = 1 iff df_raw[third_group_col] <= third_group_threshold.
    facility_probability : float
        Probability that a point is designated as a facility.
    max_capacity : int
        Maximum capacity assigned to a facility (minimum is 1).
    seed : Optional[int]
        Seed for the local ``np.random.Generator``. ``None`` gives a fresh,
        non-reproducible draw. Global random state is never touched.

    Returns
    -------
    df_final : pd.DataFrame
        Columns [f1..f_d, is_facility, capacity, group1..group_n_groups],
        with the same index as ``df_raw``.
    mapping : Dict[str, str]
        Original column name -> new feature name ("f_i").
    stats : Dict[str, Any]
        Keys: seed, num_points, num_features, num_clients, num_facilities,
        facility_probability, max_capacity, num_groups, group_parameters
        (one human-readable rule per group column), and group_counts /
        group_facility_counts / group_client_counts (group name -> int).

    Raises
    ------
    KeyError
        If a column required by the requested groups is missing from df_raw.
    ValueError
        Raised by numpy when at least one facility is drawn and
        ``max_capacity < 1``.
    """
    # Always build the local Generator. Creating it only when a seed was
    # given left `rng` unbound for the default seed=None (NameError below).
    rng = np.random.default_rng(seed)

    # 1. Work on a copy so the caller's frame is not mutated.
    df = df_raw.copy()
    original_cols = list(df.columns)

    # 2. Make every column numeric (label-encode non-numeric ones), as float.
    for col in original_cols:
        if not is_numeric_dtype(df[col]):
            df[col], _ = pd.factorize(df[col], sort=True)
        df[col] = df[col].astype(float)

    # 3. Scale each column to unit L2 norm; all-zero columns stay zero.
    for col in original_cols:
        v = df[col].to_numpy(dtype=float)
        norm = np.linalg.norm(v)
        df[col] = v / norm if norm > 0 else 0.0

    # 4. Rename the normalized columns to f1..f_d.
    feature_cols_new = [f"f{i + 1}" for i in range(len(original_cols))]
    mapping = dict(zip(original_cols, feature_cols_new))
    df_features = df.rename(columns=mapping)

    n = len(df_raw)

    # 5. Group memberships from the *raw* (unscaled, un-encoded) data.
    group_mat = np.zeros((n, n_groups), dtype=int)
    if n_groups >= 1:
        group_mat[:, 0] = (df_raw[sex_col] == male_value).astype(int)
    if n_groups >= 2:
        group_mat[:, 1] = (df_raw[sex_col] == female_value).astype(int)
    if n_groups >= 3:
        group_mat[:, 2] = (df_raw[third_group_col] <= third_group_threshold).astype(int)
    # Random extra groups are drawn before the facility flags; this draw order
    # keeps seeded output identical to the previous implementation.
    for g in range(3, n_groups):
        group_mat[:, g] = (rng.random(n) < 0.5).astype(int)

    group_col_names = [f"group{i + 1}" for i in range(n_groups)]
    for j, name in enumerate(group_col_names):
        df_features[name] = group_mat[:, j]

    # 6. Facilities and their capacities.
    is_facility = (rng.random(n) < facility_probability).astype(int)
    capacities = np.zeros(n, dtype=int)
    fac_idx = np.flatnonzero(is_facility)
    if fac_idx.size > 0:
        capacities[fac_idx] = rng.integers(1, max_capacity + 1, size=fac_idx.size)

    df_features["is_facility"] = is_facility
    df_features["capacity"] = capacities

    # 7. Column order: features, is_facility, capacity, groups.
    df_final = df_features[feature_cols_new + ["is_facility", "capacity"] + group_col_names]

    # 8. Statistics.
    num_points = int(n)
    num_facilities = int(is_facility.sum())
    num_clients = num_points - num_facilities

    is_fac_mask = is_facility == 1
    group_counts: Dict[str, int] = {}
    group_facility_counts: Dict[str, int] = {}
    group_client_counts: Dict[str, int] = {}
    for j, gname in enumerate(group_col_names):
        col_vals = group_mat[:, j]
        group_counts[gname] = int(col_vals.sum())
        group_facility_counts[gname] = int(col_vals[is_fac_mask].sum())
        group_client_counts[gname] = int(col_vals[~is_fac_mask].sum())

    # Describe the rule actually applied to each produced group column
    # (previously hard-coded "male (1)"/"female (0)" and always 3 entries).
    group_parameters = [
        f"male ({male_value})",
        f"female ({female_value})",
        f"{third_group_col} (<= {third_group_threshold})",
    ][:max(n_groups, 0)]
    group_parameters += ["random (p=0.5)"] * max(0, n_groups - 3)

    stats: Dict[str, Any] = {
        "seed": seed,
        "num_points": num_points,
        "num_features": len(feature_cols_new),
        "num_clients": num_clients,
        "num_facilities": num_facilities,
        "facility_probability": facility_probability,
        "max_capacity": max_capacity,
        "num_groups": n_groups,
        "group_parameters": group_parameters,
        "group_counts": group_counts,
        "group_facility_counts": group_facility_counts,
        "group_client_counts": group_client_counts,
    }

    return df_final, mapping, stats

def get_heart_dataset_df(n_groups: int = 3,
                         max_capacity: int = -1,
                         facility_probability: float = 0.3,
                         k: int = 5,
                         seed: int = 123456789,
                         write_to_file: bool = False):
    """
    Load heart-failure.csv and convert it to the facility/client format.

    Paths come from the ``[PATH]`` section of ``config.ini`` (opened relative
    to the current working directory): ``datasets-raw`` holds the input file,
    and, when ``write_to_file`` is true, ``datasets-processed-t<n_groups>``
    receives ``heart-failure-t<n_groups>-cap<max_capacity>.csv``.

    Groups: group1 = sex == 1 (male), group2 = sex == 0 (female),
    group3 = age <= 50.

    Parameters
    ----------
    n_groups : int
        Number of group columns to generate.
    max_capacity : int
        Maximum facility capacity; any value <= 0 is replaced by
        ``int(n / k)`` with ``n`` the number of rows.
    facility_probability : float
        Probability that a point becomes a facility.
    k : int
        Divisor used only to derive the default ``max_capacity``.
    seed : int
        Seed forwarded to the converter.
    write_to_file : bool
        Also write the processed CSV with its commented stats header.

    Returns
    -------
    (df_conv, stats)
        Converted DataFrame and its statistics dict; the column mapping
        produced by the converter is not returned.
    """
    cfg = ConfigParser()
    cfg.read('config.ini')
    raw_path = os.path.join(cfg.get('PATH', 'datasets-raw'), 'heart-failure.csv')

    heart = pd.read_csv(raw_path)
    row_count = heart.shape[0]
    if max_capacity <= 0:
        # Derive a default cap from the number of rows.
        max_capacity = int(row_count / k)

    converted, _mapping, summary = convert_heart_failure_to_facility_client_df(
        heart,
        n_groups=n_groups,
        sex_col="sex",
        male_value=1,
        female_value=0,
        third_group_col="age",
        third_group_threshold=50.0,
        facility_probability=facility_probability,
        max_capacity=max_capacity,
        seed=seed,
    )

    if write_to_file:
        out_dir = cfg.get('PATH', "datasets-processed-t%d" % (n_groups))
        out_path = os.path.join(out_dir, 'heart-failure-t%d-cap%s.csv' % (n_groups, max_capacity))
        write_processed_df_with_stats_to_csv(converted, summary, out_path)

    return converted, summary

##################################################################################
def convert_student_mat_to_facility_client_df(
    df_raw: pd.DataFrame,
    n_groups: int = 3,
    sex_col: str = "sex",
    male_value: str = "M",
    female_value: str = "F",
    third_group_col: str = "age",
    third_group_threshold: float = 18.0,
    facility_probability: float = 0.5,
    max_capacity: int = 10,
    seed: Optional[int] = None,
) -> Tuple[pd.DataFrame, Dict[str, str], Dict[str, Any]]:
    """
    Convert the student-mat dataset into the synthetic facility/client format.

    Steps
    -----
    1. Split original columns into numeric and non-numeric (categorical).
    2. One-hot encode every categorical column into float 0/1 dummies named
       ``<col>_<value>``.
    3. Concatenate numeric + dummy columns (numeric first) into the feature
       matrix and scale each feature column to unit L2 norm (all-zero columns
       stay zero).
    4. Rename feature columns to f1..f_d.
    5. Add:
       - is_facility: each point independently 1 with probability
         ``facility_probability``
       - capacity: 0 for clients, uniform integer in 1..max_capacity for
         facilities
       - group1..group_n_groups (from the raw data):
            group1: df_raw[sex_col] == male_value   (e.g. "M")
            group2: df_raw[sex_col] == female_value (e.g. "F")
            group3: df_raw[third_group_col] < third_group_threshold (minors)
            group4..: random overlapping groups (membership probability 0.5)
    6. Compute summary statistics.

    Only a local ``np.random.Generator`` is used; global random state is
    never modified.

    Parameters
    ----------
    df_raw : pd.DataFrame
        Raw student-mat dataframe (e.g. read from student-mat.csv). Not
        modified.
    n_groups : int
        Number of group columns to generate (default 3).
    sex_col : str
        Name of the sex column (student-mat: "sex").
    male_value : str
        Value in sex_col indicating male ("M").
    female_value : str
        Value in sex_col indicating female ("F").
    third_group_col : str
        Column defining the third group (e.g. "age").
    third_group_threshold : float
        group3 = 1 iff df_raw[third_group_col] < third_group_threshold
        (strict inequality).
    facility_probability : float
        Probability that a data point is marked as a facility.
    max_capacity : int
        Maximum capacity assigned to a facility (min 1).
    seed : Optional[int]
        Seed for the local RNG; ``None`` gives a non-reproducible draw.

    Returns
    -------
    df_final : pd.DataFrame
        Columns f1..f_d, is_facility, capacity, group1..group_n_groups.
    mapping : Dict[str, str]
        Original column -> feature name; a categorical column maps to the
        comma-separated names of its one-hot features.
    stats : Dict[str, Any]
        Keys: seed, num_points, num_features, num_clients, num_facilities,
        facility_probability, max_capacity, num_groups, group_parameters
        (one human-readable rule per group column), and group_counts /
        group_facility_counts / group_client_counts (group name -> int).
    """
    # Local RNG only -- the global random / np.random state is never touched.
    rng = np.random.default_rng(seed)

    # 1. Work on a copy so the caller's frame is not mutated.
    df = df_raw.copy()

    # 2. Split columns by dtype.
    numeric_cols: List[str] = []
    categorical_cols: List[str] = []
    for col in df.columns:
        if is_numeric_dtype(df[col]):
            numeric_cols.append(col)
        else:
            categorical_cols.append(col)

    df_numeric = df[numeric_cols].astype(float) if numeric_cols else pd.DataFrame(index=df.index)

    # One-hot encode each categorical column c -> c_<value1>, c_<value2>, ...
    dummy_frames = []
    cat_dummy_cols: Dict[str, List[str]] = {}  # categorical col -> its dummy column names
    for col in categorical_cols:
        dummies = pd.get_dummies(df[col].astype("category"), prefix=col).astype(float)
        dummy_frames.append(dummies)
        cat_dummy_cols[col] = list(dummies.columns)
    df_cat_dummies = pd.concat(dummy_frames, axis=1) if dummy_frames else pd.DataFrame(index=df.index)

    # 3. Feature matrix: numeric columns first, then the dummies.
    df_features_raw = pd.concat([df_numeric, df_cat_dummies], axis=1)
    feature_orig_cols = list(df_features_raw.columns)

    for col in feature_orig_cols:
        v = df_features_raw[col].to_numpy(dtype=float)
        norm = np.linalg.norm(v)
        df_features_raw[col] = v / norm if norm > 0 else 0.0

    # 4. Rename feature columns to f1..f_d.
    feature_cols_new = [f"f{i + 1}" for i in range(len(feature_orig_cols))]
    rename_dict = dict(zip(feature_orig_cols, feature_cols_new))
    df_features = df_features_raw.rename(columns=rename_dict)

    # Original column -> feature name(s).
    mapping: Dict[str, str] = {
        col: rename_dict[col] for col in numeric_cols if col in rename_dict
    }
    for col in categorical_cols:
        mapping[col] = ",".join(rename_dict[dcol] for dcol in cat_dummy_cols.get(col, []))

    n = len(df_raw)

    # 5. Group memberships from the *raw* data.
    group_mat = np.zeros((n, n_groups), dtype=int)
    if n_groups >= 1:
        group_mat[:, 0] = (df_raw[sex_col] == male_value).astype(int)
    if n_groups >= 2:
        group_mat[:, 1] = (df_raw[sex_col] == female_value).astype(int)
    if n_groups >= 3:
        group_mat[:, 2] = (df_raw[third_group_col] < third_group_threshold).astype(int)
    # Random extra groups are drawn before the facility flags; this draw order
    # keeps seeded output identical to the previous implementation.
    for g in range(3, n_groups):
        group_mat[:, g] = (rng.random(n) < 0.5).astype(int)

    group_col_names = [f"group{i + 1}" for i in range(n_groups)]
    for j, name in enumerate(group_col_names):
        df_features[name] = group_mat[:, j]

    # 6. Facilities and their capacities.
    is_facility = (rng.random(n) < facility_probability).astype(int)
    capacities = np.zeros(n, dtype=int)
    fac_idx = np.flatnonzero(is_facility)
    if fac_idx.size > 0:
        capacities[fac_idx] = rng.integers(1, max_capacity + 1, size=fac_idx.size)

    df_features["is_facility"] = is_facility
    df_features["capacity"] = capacities

    # 7. Column order: features, is_facility, capacity, groups.
    df_final = df_features[feature_cols_new + ["is_facility", "capacity"] + group_col_names]

    # 8. Statistics.
    num_points = int(n)
    num_facilities = int(is_facility.sum())
    num_clients = num_points - num_facilities

    is_fac_mask = is_facility == 1
    group_counts: Dict[str, int] = {}
    group_facility_counts: Dict[str, int] = {}
    group_client_counts: Dict[str, int] = {}
    for j, gname in enumerate(group_col_names):
        col_vals = group_mat[:, j]
        group_counts[gname] = int(col_vals.sum())
        group_facility_counts[gname] = int(col_vals[is_fac_mask].sum())
        group_client_counts[gname] = int(col_vals[~is_fac_mask].sum())

    # Describe the rule actually applied to each produced group column.
    # The old hard-coded "male (1)", "female (0)", "age (<= ...)" did not
    # match the "M"/"F" values or the strict "<" comparison used above.
    group_parameters = [
        f"male ({male_value})",
        f"female ({female_value})",
        f"{third_group_col} (< {third_group_threshold})",
    ][:max(n_groups, 0)]
    group_parameters += ["random (p=0.5)"] * max(0, n_groups - 3)

    stats: Dict[str, Any] = {
        "seed": seed,
        "num_points": num_points,
        "num_features": len(feature_cols_new),
        "num_clients": num_clients,
        "num_facilities": num_facilities,
        "facility_probability": facility_probability,
        "max_capacity": max_capacity,
        "num_groups": n_groups,
        "group_parameters": group_parameters,
        "group_counts": group_counts,
        "group_facility_counts": group_facility_counts,
        "group_client_counts": group_client_counts,
    }

    return df_final, mapping, stats

def get_student_mat_dataset_df(n_groups: int = 3,
                               max_capacity: int = -1,
                               facility_probability: float = 0.3,
                               k: int = 5,
                               seed: int = 123456789,
                               write_to_file: bool = False):
    """
    Load student-mat.csv and convert it to the facility/client format.

    Paths come from the ``[PATH]`` section of ``config.ini`` (opened relative
    to the current working directory): ``datasets-raw`` holds the input file,
    and, when ``write_to_file`` is true, ``datasets-processed-t<n_groups>``
    receives ``student-mat-t<n_groups>-cap<max_capacity>.csv``.

    Groups: group1 = sex "M", group2 = sex "F", group3 = age < 18 (minors).

    Parameters
    ----------
    n_groups : int
        Number of group columns to generate.
    max_capacity : int
        Maximum facility capacity; any value <= 0 is replaced by
        ``int(n / k)`` with ``n`` the number of rows.
    facility_probability : float
        Probability that a point becomes a facility.
    k : int
        Divisor used only to derive the default ``max_capacity``.
    seed : int
        Seed forwarded to the converter.
    write_to_file : bool
        Also write the processed CSV with its commented stats header.

    Returns
    -------
    (df_conv, stats)
        Converted DataFrame and its statistics dict.
    """
    config = ConfigParser()
    config.read('config.ini')
    datasets_raw_dir = config.get('PATH', 'datasets-raw')
    in_file = os.path.join(datasets_raw_dir, 'student-mat.csv')

    # The UCI student-performance files are semicolon-separated.
    df_raw = pd.read_csv(in_file, sep=';')
    n = len(df_raw)
    if max_capacity <= 0:
        max_capacity = int(n / k)

    # Use the student-specific converter: it one-hot encodes the categorical
    # columns and applies the strict `age < 18` rule. The heart-failure
    # converter (called here before) label-encodes them and uses `<=`.
    df_conv, _mapping, stats = convert_student_mat_to_facility_client_df(
        df_raw,
        n_groups=n_groups,
        sex_col="sex",
        male_value="M",
        female_value="F",
        third_group_col="age",
        third_group_threshold=18.0,
        facility_probability=facility_probability,
        max_capacity=max_capacity,
        seed=seed,
    )

    if write_to_file:
        datasets_processed_dir = config.get('PATH', 'datasets-processed-t%d' % (n_groups))
        out_file = os.path.join(datasets_processed_dir,
                                'student-mat-t%d-cap%s.csv' % (n_groups, max_capacity))
        write_processed_df_with_stats_to_csv(
            df_conv,
            stats,
            out_file
        )

    return df_conv, stats

###############################################################################
def convert_student_por_to_facility_client_df(
    df_raw: pd.DataFrame,
    n_groups: int = 3,
    sex_col: str = "sex",
    male_value: str = "M",
    female_value: str = "F",
    third_group_col: str = "age",
    third_group_threshold: float = 18.0,
    facility_probability: float = 0.5,
    max_capacity: int = 10,
    seed: Optional[int] = None,
) -> Tuple[pd.DataFrame, Dict[str, str], Dict[str, Any]]:
    """
    Convert the student-por dataset into the synthetic facility/client format.

    student-por uses the same schema as student-mat, so this simply forwards
    every argument unchanged to ``convert_student_mat_to_facility_client_df``
    and returns its ``(df_final, mapping, stats)`` result.

    With the defaults the groups are: group1 = sex == male_value,
    group2 = sex == female_value, group3 = df_raw[third_group_col] <
    third_group_threshold (minors).
    """
    # Same schema as student-mat: delegate the whole conversion.
    forwarded = dict(
        n_groups=n_groups,
        sex_col=sex_col,
        male_value=male_value,
        female_value=female_value,
        third_group_col=third_group_col,
        third_group_threshold=third_group_threshold,
        facility_probability=facility_probability,
        max_capacity=max_capacity,
        seed=seed,
    )
    return convert_student_mat_to_facility_client_df(df_raw=df_raw, **forwarded)

def get_student_por_dataset_df(
    n_groups: int = 3,
    max_capacity: int = -1,
    facility_probability: float = 0.3,
    k: int = 5,
    seed: int = 123456789,
    write_to_file: bool = True,
):
    """
    Load student-por.csv and convert it to the facility/client format.

    Paths come from the ``[PATH]`` section of ``config.ini`` (opened relative
    to the current working directory): ``datasets-raw`` holds the input file,
    and, when ``write_to_file`` is true (the default here),
    ``datasets-processed-t<n_groups>`` receives
    ``student-por-t<n_groups>-cap<max_capacity>.csv``.

    Groups: group1 = sex "M", group2 = sex "F", group3 = age < 18 (minors).

    Parameters
    ----------
    n_groups : int
        Number of group columns to generate.
    max_capacity : int
        Maximum facility capacity; any value <= 0 is replaced by
        ``int(n / k)`` with ``n`` the number of rows.
    facility_probability : float
        Probability that a point becomes a facility.
    k : int
        Divisor used only to derive the default ``max_capacity``.
    seed : int
        Seed forwarded to the converter.
    write_to_file : bool
        Also write the processed CSV with its commented stats header.

    Returns
    -------
    (df_conv, stats)
        Converted DataFrame and its statistics dict.
    """
    cfg = ConfigParser()
    cfg.read('config.ini')
    raw_path = os.path.join(cfg.get('PATH', 'datasets-raw'), 'student-por.csv')

    # Semicolon-separated, like student-mat.
    students = pd.read_csv(raw_path, sep=";")
    row_count = students.shape[0]
    if max_capacity <= 0:
        max_capacity = int(row_count / k)

    converted, _mapping, summary = convert_student_por_to_facility_client_df(
        students,
        n_groups=n_groups,
        sex_col="sex",
        male_value="M",
        female_value="F",
        third_group_col="age",
        third_group_threshold=18.0,
        facility_probability=facility_probability,
        max_capacity=max_capacity,
        seed=seed,
    )

    if write_to_file:
        out_dir = cfg.get('PATH', 'datasets-processed-t%d' % (n_groups))
        out_path = os.path.join(out_dir, 'student-por-t%d-cap%s.csv' % (n_groups, max_capacity))
        write_processed_df_with_stats_to_csv(converted, summary, out_path)

    return converted, summary

###################################################################################

def convert_npha_doctor_visits_to_facility_client_df(
    df_raw: pd.DataFrame,
    n_groups: int = 3,
    gender_col: str = "gender",
    male_value: Any = "Male",
    female_value: Any = "Female",
    employment_col: str = "employment",
    employment_positive_value: Any = "Employed",
    facility_probability: float = 0.5,
    max_capacity: int = 10,
    seed: Optional[int] = None,
) -> Tuple[pd.DataFrame, Dict[str, str], Dict[str, Any]]:
    """
    Convert the NPHA doctor-visits dataset into the synthetic facility/client format.

    Steps
    -----
    1. Split columns into numeric and non-numeric (categorical).
    2. One-hot encode every categorical column into float 0/1 dummies named
       ``<col>_<value>``.
    3. Concatenate numeric + dummy columns (numeric first) into the feature
       matrix and scale each feature column to unit L2 norm (all-zero columns
       stay zero).
    4. Rename feature columns to f1..f_d.
    5. Add:
       - is_facility: each point independently 1 with probability
         ``facility_probability``
       - capacity: 0 for clients, uniform integer in 1..max_capacity for
         facilities
       - group1..group_n_groups (from the raw data):
            group1: df_raw[gender_col] == male_value
            group2: df_raw[gender_col] == female_value
            group3: "employed", from df_raw[employment_col]:
                * non-numeric column: value == employment_positive_value
                * numeric column: value == employment_positive_value when
                  that marker is numeric; otherwise (e.g. the default
                  "Employed", which can never match a number) value == 1
            group4..: random overlapping groups (membership probability 0.5)
         A group whose source column is absent from df_raw is left all zero.
    6. Compute summary statistics.

    Only a local ``np.random.Generator`` is used; global random state is
    never modified.

    Parameters
    ----------
    df_raw : pd.DataFrame
        Raw NPHA doctor-visits dataframe. Not modified.
    n_groups : int
        Number of group columns (default 3).
    gender_col : str
        Name of the gender column.
    male_value : Any
        Value in gender_col indicating male (e.g. "Male" or a numeric code).
    female_value : Any
        Value in gender_col indicating female (e.g. "Female" or a numeric code).
    employment_col : str
        Column defining the third group (employment status).
    employment_positive_value : Any
        Value denoting "employed" (see the group3 rule above).
    facility_probability : float
        Probability that a point is designated as a facility.
    max_capacity : int
        Maximum capacity assigned to a facility (minimum is 1).
    seed : Optional[int]
        Seed for the local RNG; ``None`` gives a non-reproducible draw.

    Returns
    -------
    df_final : pd.DataFrame
        Columns f1..f_d, is_facility, capacity, group1..group_n_groups.
    mapping : Dict[str, str]
        Original column -> feature name(s); numeric columns map 1:1,
        categorical columns map to comma-separated one-hot feature names.
    stats : Dict[str, Any]
        Keys: seed, num_points, num_features, num_clients, num_facilities,
        facility_probability, max_capacity, num_groups, group_parameters
        (one human-readable rule per group column), and group_counts /
        group_facility_counts / group_client_counts (group name -> int).
    """
    # Local RNG only -- the global random / np.random state is never touched.
    rng = np.random.default_rng(seed)

    # 1. Work on a copy so the caller's frame is not mutated.
    df = df_raw.copy()

    # 2. Split columns by dtype.
    numeric_cols: List[str] = []
    categorical_cols: List[str] = []
    for col in df.columns:
        if is_numeric_dtype(df[col]):
            numeric_cols.append(col)
        else:
            categorical_cols.append(col)

    df_numeric = df[numeric_cols].astype(float) if numeric_cols else pd.DataFrame(index=df.index)

    # One-hot encode each categorical column c -> c_<value1>, c_<value2>, ...
    dummy_frames = []
    cat_dummy_cols: Dict[str, List[str]] = {}  # categorical col -> its dummy column names
    for col in categorical_cols:
        dummies = pd.get_dummies(df[col].astype("category"), prefix=col).astype(float)
        dummy_frames.append(dummies)
        cat_dummy_cols[col] = list(dummies.columns)
    df_cat_dummies = pd.concat(dummy_frames, axis=1) if dummy_frames else pd.DataFrame(index=df.index)

    # 3. Feature matrix: numeric columns first, then the dummies.
    df_features_raw = pd.concat([df_numeric, df_cat_dummies], axis=1)
    feature_orig_cols = list(df_features_raw.columns)

    for col in feature_orig_cols:
        v = df_features_raw[col].to_numpy(dtype=float)
        norm = np.linalg.norm(v)
        df_features_raw[col] = v / norm if norm > 0 else 0.0

    # 4. Rename feature columns to f1..f_d.
    feature_cols_new = [f"f{i + 1}" for i in range(len(feature_orig_cols))]
    rename_dict = dict(zip(feature_orig_cols, feature_cols_new))
    df_features = df_features_raw.rename(columns=rename_dict)

    # Original column -> feature name(s).
    mapping: Dict[str, str] = {
        col: rename_dict[col] for col in numeric_cols if col in rename_dict
    }
    for col in categorical_cols:
        mapping[col] = ",".join(rename_dict[dcol] for dcol in cat_dummy_cols.get(col, []))

    n = len(df_raw)

    # 5. Group memberships from the *raw* data. Missing source columns leave
    #    the corresponding group all-zero (deliberate best-effort behavior).
    group_mat = np.zeros((n, n_groups), dtype=int)
    gender_present = gender_col in df_raw.columns
    if n_groups >= 1 and gender_present:
        group_mat[:, 0] = (df_raw[gender_col] == male_value).astype(int)
    if n_groups >= 2 and gender_present:
        group_mat[:, 1] = (df_raw[gender_col] == female_value).astype(int)

    # Resolve the value that marks "employed". For a numeric column, honor a
    # numeric marker (it used to be silently ignored in favor of 1); a
    # non-numeric marker such as the default "Employed" falls back to code 1.
    employment_present = employment_col in df_raw.columns
    employment_numeric = employment_present and is_numeric_dtype(df_raw[employment_col])
    employed_value: Any = employment_positive_value
    if employment_numeric and not isinstance(employment_positive_value, (int, float, np.number)):
        employed_value = 1

    if n_groups >= 3 and employment_present:
        employment = df_raw[employment_col]
        if employment_numeric:
            group_mat[:, 2] = (employment.astype(float) == float(employed_value)).astype(int)
        else:
            group_mat[:, 2] = (employment == employed_value).astype(int)

    # Random extra groups are drawn before the facility flags; this draw order
    # keeps seeded output identical to the previous implementation.
    for g in range(3, n_groups):
        group_mat[:, g] = (rng.random(n) < 0.5).astype(int)

    group_col_names = [f"group{i + 1}" for i in range(n_groups)]
    for j, name in enumerate(group_col_names):
        df_features[name] = group_mat[:, j]

    # 6. Facilities and their capacities.
    is_facility = (rng.random(n) < facility_probability).astype(int)
    capacities = np.zeros(n, dtype=int)
    fac_idx = np.flatnonzero(is_facility)
    if fac_idx.size > 0:
        capacities[fac_idx] = rng.integers(1, max_capacity + 1, size=fac_idx.size)

    df_features["is_facility"] = is_facility
    df_features["capacity"] = capacities

    # 7. Column order: features, is_facility, capacity, groups.
    df_final = df_features[feature_cols_new + ["is_facility", "capacity"] + group_col_names]

    # 8. Statistics.
    num_points = int(n)
    num_facilities = int(is_facility.sum())
    num_clients = num_points - num_facilities

    is_fac_mask = is_facility == 1
    group_counts: Dict[str, int] = {}
    group_facility_counts: Dict[str, int] = {}
    group_client_counts: Dict[str, int] = {}
    for j, gname in enumerate(group_col_names):
        col_vals = group_mat[:, j]
        group_counts[gname] = int(col_vals.sum())
        group_facility_counts[gname] = int(col_vals[is_fac_mask].sum())
        group_client_counts[gname] = int(col_vals[~is_fac_mask].sum())

    # Describe the rule actually applied to each produced group column
    # (previously hard-coded "male (1)"/"female (0)" regardless of coding).
    group_parameters = [
        f"male ({male_value})",
        f"female ({female_value})",
        f"{employment_col} ({employed_value})",
    ][:max(n_groups, 0)]
    group_parameters += ["random (p=0.5)"] * max(0, n_groups - 3)

    stats: Dict[str, Any] = {
        "seed": seed,
        "num_points": num_points,
        "num_features": len(feature_cols_new),
        "num_clients": num_clients,
        "num_facilities": num_facilities,
        "facility_probability": facility_probability,
        "max_capacity": max_capacity,
        "num_groups": n_groups,
        "group_parameters": group_parameters,
        "group_counts": group_counts,
        "group_facility_counts": group_facility_counts,
        "group_client_counts": group_client_counts,
    }

    return df_final, mapping, stats


def get_npha_doctor_visits_dataset_df(
    n_groups: int =3,
    max_capacity: int = -1,
    facility_probability: float = 0.3,
    k: int = 5,
    seed: int = 123456789,
    write_to_file: bool = True
):
    """
    Load NPHA-doctor-visits.csv and convert it to the facility/client format.

    The raw file is read from the ``datasets-raw`` entry of the ``[PATH]``
    section in ``config.ini`` (read relative to the current working directory).
    Group columns are built by ``convert_npha_doctor_visits_to_facility_client_df``
    from the ``Gender`` (male=1 / female=2) and ``Employment`` (positive=1)
    columns.

    Parameters
    ----------
    n_groups : int
        Number of groups; forwarded to the converter and used to pick the
        output directory key ``datasets-processed-t{n_groups}``.
    max_capacity : int
        Upper bound for facility capacities. A value ``<= 0`` means "derive
        it from the data" as ``max(1, num_rows // k)``.
    facility_probability : float
        Forwarded to the converter.
    k : int
        Divisor used only when ``max_capacity <= 0``; must then be positive.
    seed : int
        RNG seed forwarded to the converter.
    write_to_file : bool
        If True, write the processed DataFrame and its statistics to
        ``NPHA-doctor-visits-t{n_groups}-cap{max_capacity}.csv`` in the
        configured processed-data directory.

    Returns
    -------
    tuple
        ``(df_conv, stats)``: the processed DataFrame and the statistics
        dict produced by the converter.

    Raises
    ------
    ValueError
        If ``max_capacity <= 0`` and ``k <= 0``.
    """
    # Validate before any file I/O so a bad ``k`` fails fast and clearly
    # instead of as a ZeroDivisionError / negative capacity bound later.
    if max_capacity <= 0 and k <= 0:
        raise ValueError(
            f"k must be positive when max_capacity <= 0 (got k={k})"
        )

    config = ConfigParser()
    config.read('config.ini')
    datasets_raw_dir = config.get('PATH', 'datasets-raw')
    in_file = os.path.join(datasets_raw_dir, 'NPHA-doctor-visits.csv')

    df_raw = pd.read_csv(in_file)  # adjust sep=... if needed for your file
    num_rows = len(df_raw)
    if max_capacity <= 0:
        # Derive the bound from the data. Clamp to >= 1: capacities appear to
        # be drawn with rng.integers(1, max_capacity + 1), which raises when
        # max_capacity is 0 (i.e. when num_rows < k).
        max_capacity = max(1, num_rows // k)

    df_conv, mapping, stats = convert_npha_doctor_visits_to_facility_client_df(
        df_raw,
        n_groups=n_groups,
        gender_col="Gender",            # adjust if column name differs
        male_value=1,
        female_value=2,
        employment_col="Employment",    # adjust to your column name
        employment_positive_value=1,  # adjust to your coding
        facility_probability=facility_probability,
        max_capacity=max_capacity,
        seed=seed,
    )

    if write_to_file:
        # The output directory key depends on n_groups (e.g. 'datasets-processed-t4').
        datasets_processed_dir = config.get('PATH', 'datasets-processed-t%d' % (n_groups))
        out_file = os.path.join(datasets_processed_dir,
                                'NPHA-doctor-visits-t%d-cap%s.csv' % (n_groups, max_capacity))
        # NOTE(review): the column ``mapping`` returned by the converter is not
        # written; pass mapping=mapping to the writer if it should be recorded.
        write_processed_df_with_stats_to_csv(
            df_conv,
            stats,
            out_file
        )

    return df_conv, stats

################################################################################
def get_dataset_df(dataset_name: str, 
                   n_groups: int =3,
                   max_capacity: int = -1,
                   facility_probability: float = 0.3,
                   k: int = 5,
                   seed: int = 123456789,
                   write_to_file: bool = False):
    """
    Look up the loader for ``dataset_name`` and return its ``(df, stats)``.

    Recognised names are 'heart', 'student-mat', 'student-por' and 'NPHA'
    (case-sensitive). The remaining arguments are forwarded positionally, in
    signature order, to the selected ``get_*_dataset_df`` function.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not one of the recognised names.
    """
    # Pick the loader first; the call itself is shared by every branch.
    if dataset_name == 'heart':
        loader = get_heart_dataset_df
    elif dataset_name == 'student-mat':
        loader = get_student_mat_dataset_df
    elif dataset_name == 'student-por':
        loader = get_student_por_dataset_df
    elif dataset_name == 'NPHA':
        loader = get_npha_doctor_visits_dataset_df
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    # Every loader takes (n_groups, max_capacity, facility_probability, k,
    # seed, write_to_file) in this order.
    forwarded_args = (n_groups, max_capacity, facility_probability, k, seed, write_to_file)
    df, stats = loader(*forwarded_args)
    return df, stats

###################################################################################
if __name__ == "__main__":
    # Settings shared by every dataset build: 4 groups, capacity derived from
    # the data (max_capacity=-1), and results written to disk. k and seed
    # fall back to each loader's own defaults unless passed explicitly.
    shared_kwargs = dict(n_groups=4, max_capacity=-1, facility_probability=0.3, write_to_file=True)

    for build_dataset in (get_heart_dataset_df,
                          get_student_mat_dataset_df,
                          get_student_por_dataset_df):
        build_dataset(**shared_kwargs)

    # The NPHA build also pins k explicitly.
    get_npha_doctor_visits_dataset_df(k=5, **shared_kwargs)
