import os
import openml
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, LabelEncoder

def ensure_dir(path):
    """Create directory *path* (including missing parents) if it is absent.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* exists but is not a directory.
        OSError: For any other failure reported by ``os.makedirs``.
    """
    # exist_ok=True avoids the check-then-create race of
    # ``if not os.path.exists(path): os.makedirs(path)`` and still fails
    # loudly when the path is occupied by a regular file.
    os.makedirs(path, exist_ok=True)

def fetch_openml_dataset_and_rename_label(dataset_id):
    """Download an OpenML dataset and expose its target column as ``"label"``.

    Args:
        dataset_id: OpenML dataset identifier passed to
            ``openml.datasets.get_dataset``.

    Returns:
        A tuple ``(df, "label", dataset_name)`` where ``df`` is the dataframe
        returned by ``dataset.get_data`` with the default target column
        renamed to ``"label"``, and ``dataset_name`` is ``dataset.name``.

    Raises:
        ValueError: If the dataset's default target attribute is missing from
            the dataframe (e.g. it is ``None`` or names several columns), or
            if renaming would collide with an existing ``"label"`` column.
    """
    dataset = openml.datasets.get_dataset(dataset_id, download_data=True)
    target_name = dataset.default_target_attribute
    # Only the dataframe is needed; the remaining return values are discarded.
    df, *_ = dataset.get_data(dataset_format="dataframe")

    # Without this check the rename below would be a silent no-op and callers
    # would receive a label column name that does not exist in ``df``.
    if target_name is None or target_name not in df.columns:
        raise ValueError(
            f"Default target attribute {target_name!r} of dataset "
            f"{dataset_id} is not a single column of the downloaded data."
        )

    if target_name != "label":
        # Renaming onto an existing column would produce duplicate column
        # names, which breaks every later ``df["label"]`` lookup.
        if "label" in df.columns:
            raise ValueError(
                f"Dataset {dataset_id} already has a non-target column named "
                f"'label'; cannot rename target {target_name!r}."
            )
        df = df.rename(columns={target_name: "label"})
    return df, "label", dataset.name

def preprocess_and_balance(df, label_column, samples_per_class, save_path):
    """Balance classes, preprocess features, encode labels and write a CSV.

    Steps, in order:

    1. Keep only classes with at least ``samples_per_class`` rows.
    2. Down-sample every kept class to the size of the smallest kept class
       (note: this is the smallest eligible class count, which may be larger
       than ``samples_per_class``).
    3. Mean-impute and standardise numeric feature columns.
    4. One-hot encode ``object``/``category`` feature columns.
    5. Integer-encode the label column with ``LabelEncoder``.
    6. Write the result to ``save_path`` without the index.

    Args:
        df: Input dataframe containing features and ``label_column``.
        label_column: Name of the target column.
        samples_per_class: Minimum number of rows a class needs to be kept.
        save_path: Destination path of the CSV written by this function.

    Returns:
        The balanced, preprocessed dataframe (same content as the CSV).

    Raises:
        ValueError: If no class has at least ``samples_per_class`` rows.
    """
    # The label is excluded from both feature lists so that a numeric target
    # is never imputed or standardised as if it were a feature.
    numeric_columns = [
        col
        for col in df.select_dtypes(include=['number']).columns
        if col != label_column
    ]
    categorical_columns = [
        col
        for col in df.select_dtypes(include=['object', 'category']).columns
        if col != label_column
    ]

    class_counts = df[label_column].value_counts()
    eligible_classes = class_counts[class_counts >= samples_per_class]
    if eligible_classes.empty:
        raise ValueError("No class has enough samples above the threshold.")
    min_samples = int(eligible_classes.min())
    filtered_df = df[df[label_column].isin(eligible_classes.index)]

    # Iterate the groups explicitly instead of ``groupby(...).apply``:
    # * each group keeps the label column regardless of how the installed
    #   pandas version treats grouping columns inside ``apply``;
    # * ``observed=True`` skips categorical levels that were filtered out,
    #   which would otherwise show up as empty groups and fail to sample.
    # Every group is sampled with the same fixed seed, as before.
    sampled_groups = [
        group.sample(min_samples, random_state=42)
        for _, group in filtered_df.groupby(label_column, observed=True)
    ]
    balanced_df = pd.concat(sampled_groups).reset_index(drop=True)

    if numeric_columns:
        # NOTE: imputer and scaler are fitted on all balanced rows; if the
        # result is split into train/test afterwards, the test rows have
        # contributed to these statistics.
        imputed = SimpleImputer(strategy='mean').fit_transform(
            balanced_df[numeric_columns]
        )
        balanced_df[numeric_columns] = StandardScaler().fit_transform(imputed)

    if categorical_columns:
        balanced_df = pd.get_dummies(
            balanced_df, columns=categorical_columns, dtype=int
        )

    balanced_df[label_column] = LabelEncoder().fit_transform(
        balanced_df[label_column]
    )
    balanced_df.to_csv(save_path, index=False)
    return balanced_df

def custom_train_test_split_per_class(df, label_column, samples_per_class_train):
    """Split *df* into train/test with a fixed number of train rows per class.

    For every class in ``label_column``, ``samples_per_class_train`` rows are
    sampled (seed 42) for the train set and the remaining rows of that class
    go to the test set. Both sets are then shuffled (seed 42) and re-indexed
    from zero.

    Args:
        df: Dataframe containing ``label_column``.
        label_column: Name of the column holding the class labels.
        samples_per_class_train: Number of train rows to draw per class.

    Returns:
        A tuple ``(train_df, test_df)``.

    Raises:
        ValueError: If a class has fewer than ``samples_per_class_train``
            rows, or if *df* contains no classes at all.
    """
    # ``group.drop(train.index)`` removes rows by index *label*; with duplicate
    # labels it would discard more rows than were sampled. A fresh RangeIndex
    # makes labels unique and does not change which rows are drawn, because
    # sampling is positional and the outputs are re-indexed anyway.
    df = df.reset_index(drop=True)

    train_parts = []
    test_parts = []
    # observed=True: ignore categorical levels that have no rows.
    for label, group in df.groupby(label_column, observed=True):
        if len(group) < samples_per_class_train:
            raise ValueError(f"Not enough samples in class {label} for the requested train split.")
        train = group.sample(samples_per_class_train, random_state=42)
        train_parts.append(train)
        test_parts.append(group.drop(train.index))

    if not train_parts:
        raise ValueError("Cannot split: the input contains no labelled rows.")

    train_df = pd.concat(train_parts).sample(frac=1, random_state=42).reset_index(drop=True)
    test_df = pd.concat(test_parts).sample(frac=1, random_state=42).reset_index(drop=True)
    return train_df, test_df

if __name__ == "__main__":
    # --- Run configuration -------------------------------------------------
    dataset_id = 42477  # OpenML ID for default-of-credit-card-clients_categorical
    # Minimum class size for a class to be kept by preprocess_and_balance;
    # kept classes are down-sampled to the smallest kept class's row count.
    samples_per_class = 2500
    # Rows drawn per class for the train split; the rest go to the test split.
    samples_per_class_train = 2000
    dest_dir = "openml_datasets/default-of-credit-card-clients_categorical"
    ensure_dir(dest_dir)

    # --- Step 1: download, balance and preprocess --------------------------
    df, label_column, dataset_name = fetch_openml_dataset_and_rename_label(dataset_id)
    save_path = os.path.join(dest_dir, f"preprocessed_{dataset_name}.csv")
    processed_df = preprocess_and_balance(df, label_column, samples_per_class, save_path)
    print(f"Processed {dataset_name}, shape={processed_df.shape}")

    # --- Step 2: per-class train/test split --------------------------------
    train_df, test_df = custom_train_test_split_per_class(
        processed_df, label_column, samples_per_class_train
    )

    # --- Step 3: bundle each split as arrays and persist everything --------
    bundles = []
    for split_df in (train_df, test_df):
        # Drop the label once and reuse the feature frame for both the
        # matrix and the column names.
        features = split_df.drop(columns=[label_column])
        bundles.append(
            {
                "X": features.to_numpy(),
                "y": split_df[label_column].to_numpy(),
                "feature_names": list(features.columns),
                "label_column": label_column,
            }
        )
    train_dict, test_dict = bundles

    np.savez(os.path.join(dest_dir, "train_dict.npz"), **train_dict)
    np.savez(os.path.join(dest_dir, "test_dict.npz"), **test_dict)

    csv_outputs = (
        ("train_data.csv", train_df),
        ("test_df.csv", test_df),
        ("full_balanced_processed.csv", processed_df),
    )
    for csv_name, frame in csv_outputs:
        frame.to_csv(os.path.join(dest_dir, csv_name), index=False)

    # --- Summary -----------------------------------------------------------
    print(f"Train set: X={train_dict['X'].shape}, y={train_dict['y'].shape}")
    print(f"Test set:  X={test_dict['X'].shape}, y={test_dict['y'].shape}")
    print(f"Feature names: {train_dict['feature_names']}")
    print(f"Files written to: {dest_dir}/train_dict.npz, test_dict.npz, full_balanced_processed.csv")
