import os

from sklearn.datasets import *

from impugen.utils import train_test_split
from impugen.utils.data import create_dataset_config
import pandas as pd

# Datasets to export: each entry pairs a scikit-learn loader callable with the
# directory that receives its CSV files. The last path component of that
# directory is used by get_data() as the dataset name.
SKLEARN_DATA = [(fetch_california_housing, './dataset/sklearn/housing')]


def get_data(func, save_dir):
    """Export one scikit-learn dataset as CSV files plus impugen configs.

    Calls ``func(as_frame=True)``, keeps the feature columns followed by the
    single target column, and writes into ``save_dir``:

    * ``data.csv`` -- the full table,
    * ``train.csv`` / ``test.csv`` -- split made with ``train_test_split(df,
      0.9, random_state=42)``,
    * ``privacy_train.csv`` / ``privacy_test.csv`` -- a second split of the
      re-concatenated train and test rows with ``ratio=0.5``.

    Two dataset configs are generated through ``create_dataset_config``:
    ``./impugen/configs/dataset/<name>.yaml`` (with the target column) and
    ``./impugen/configs/dataset/privacy/<name>.yaml`` (without one), where
    ``<name>`` is the last path component of ``save_dir``.

    Args:
        func: Dataset loader accepting ``as_frame=True`` and returning a
            mapping with ``'frame'``, ``'feature_names'`` and
            ``'target_names'`` entries.
        save_dir: Output directory; created if missing. A trailing path
            separator is tolerated.

    Raises:
        ValueError: If the dataset does not expose exactly one target name.
    """
    # Normalise first so that 'a/b/housing/' still yields 'housing' instead of
    # the empty string a plain split('/') would produce.
    name = os.path.basename(os.path.normpath(save_dir))

    config_dir = './impugen/configs/dataset'
    privacy_config_dir = config_dir + '/privacy'

    os.makedirs(save_dir, exist_ok=True)
    # Creating the nested privacy directory also creates config_dir, so both
    # exist before any config file is written.
    os.makedirs(privacy_config_dir, exist_ok=True)

    data = func(as_frame=True)
    df = data['frame']
    target_names = data['target_names']
    # Explicit raise rather than assert: the check must survive `python -O`.
    if len(target_names) != 1:
        raise ValueError(
            'expected exactly one target column, got %d' % len(target_names)
        )
    target_name = target_names[0]

    # Features first, target last.
    df = df[[*data['feature_names'], target_name]]

    full_csv = os.path.join(save_dir, 'data.csv')
    train_csv = os.path.join(save_dir, 'train.csv')
    test_csv = os.path.join(save_dir, 'test.csv')
    privacy_train_csv = os.path.join(save_dir, 'privacy_train.csv')
    privacy_test_csv = os.path.join(save_dir, 'privacy_test.csv')

    df.to_csv(full_csv, index=False)

    # NOTE(review): 0.9 is presumably the train fraction -- confirm against
    # impugen.utils.train_test_split.
    train_df, test_df = train_test_split(df, 0.9, random_state=42)
    train_df.to_csv(train_csv, index=False)
    test_df.to_csv(test_csv, index=False)
    create_dataset_config(
        f'{config_dir}/{name}.yaml',
        name,
        train_csv_path=train_csv,
        test_csv_path=test_csv,
        target_column=target_name,
    )

    # The privacy split is drawn from the train rows followed by the test
    # rows (not from `df` directly), keeping the row order the split sees.
    total_df = pd.concat([train_df, test_df])
    privacy_train_df, privacy_test_df = train_test_split(
        total_df, ratio=0.5, random_state=42
    )
    privacy_train_df.to_csv(privacy_train_csv, index=False)
    privacy_test_df.to_csv(privacy_test_csv, index=False)
    create_dataset_config(
        f'{privacy_config_dir}/{name}.yaml',
        'privacy_' + name,
        train_csv_path=privacy_train_csv,
        test_csv_path=privacy_test_csv,
    )



if __name__ == '__main__':
    # Export every registered dataset when this file is run as a script.
    for loader, output_dir in SKLEARN_DATA:
        get_data(loader, output_dir)
