import os

import pandas as pd
from ucimlrepo import fetch_ucirepo

from impugen.transform.label_encoder import _valid_ind
from impugen.utils import train_test_split
from impugen.utils.data import create_dataset_config

# Datasets processed when this module is run as a script, given as
# (UCI ML repository id, output directory) pairs. The last path component of
# the directory doubles as the dataset name. Uncomment an entry to enable it.
UCI_DATA = [
    # (291, './dataset/uci/airfoil'),
    # (176, './dataset/uci/blood'),
    (59, './dataset/uci/letter'),
    # (17, './dataset/uci/breast'),
    # (732, './dataset/uci/darwin'),  # "Diagnosing Alzheimer's disease from on-line handwriting: A novel dataset and performance benchmarking"
    # (697, './dataset/uci/student'),
    # (75, './dataset/uci/musk'),
    # (165, './dataset/uci/compression'),
    (602, './dataset/uci/bean'),
    # (53, './dataset/uci/iris'),
]

# Maps a UCI variable ``type`` string to the Python type a column of that
# type is converted to in ``get_data`` (Categorical values are quoted there
# rather than cast with ``astype``).
TYPE2DTYPE = {
    'Integer': int,
    'Binary': bool,
    'Continuous': float,
    'Categorical': str,
}


def _cast_to_uci_types(df, variables, ucirepo_id, save_dir):
    """Cast the columns of ``df`` in place according to their UCI variable type.

    Categorical columns have the entries selected by ``_valid_ind`` wrapped in
    single quotes; every other known type is cast with ``astype`` using
    ``TYPE2DTYPE``. This is best effort: any error (unknown type, a variable
    that is not a column of ``df``, a failed cast) is printed and the loop
    moves on to the next column.
    """
    for col, var_type in zip(variables.name, variables.type):
        # follow uci dataset's variable type
        try:
            valid_ind = _valid_ind(df[col])
            to_dtype = TYPE2DTYPE[var_type]
            if ucirepo_id == 732 and df[col].dtype == float and to_dtype == str:  # darwin features are numeric
                continue
            if to_dtype == str:
                df.loc[valid_ind, col] = "'" + df.loc[valid_ind, col].astype(str) + "'"
            else:
                # NOTE(review): for a Binary column stored as strings (e.g.
                # 'yes'/'no'), astype(bool) maps every non-empty string to
                # True -- confirm enabled datasets encode Binary numerically.
                df[col] = df[col].astype(to_dtype)
        except Exception as e:
            print(e, save_dir, col, var_type)


def _write_dataset_configs(name, train_path, test_path, privacy_train_path, privacy_test_path):
    """Write the standard and the ``privacy_``-prefixed dataset YAML configs for ``name``."""
    # Creating the privacy sub-directory also creates its parent, so both
    # target directories exist before the first config is written.
    os.makedirs('./impugen/configs/dataset/privacy', exist_ok=True)
    create_dataset_config(
        './impugen/configs/dataset/%s.yaml' % name,
        name,
        train_csv_path=train_path,
        test_csv_path=test_path,
    )
    create_dataset_config(
        './impugen/configs/dataset/privacy/%s.yaml' % name,
        'privacy_' + name,
        train_csv_path=privacy_train_path,
        test_csv_path=privacy_test_path,
    )


def get_data(ucirepo_id, save_dir, *, skip_existing=False):
    """Download a UCI dataset, cast its columns, and write CSV splits and configs.

    Steps:

    1. Fetch the dataset with ``fetch_ucirepo`` and concatenate its feature
       and target frames column-wise.
    2. Drop every row that contains a missing value (only complete rows are
       kept, for evaluation).
    3. Cast each column according to its UCI variable type (``TYPE2DTYPE``);
       categorical values are wrapped in single quotes.
    4. Write ``train.csv``/``test.csv`` (``train_test_split`` with ratio 0.9)
       and ``privacy_train.csv``/``privacy_test.csv`` (ratio 0.5 over the same
       rows) into ``save_dir``, both seeded with ``random_state=42``.
    5. Write a dataset config for each pair under ``./impugen/configs/dataset``
       and ``./impugen/configs/dataset/privacy``.

    Args:
        ucirepo_id: Dataset id in the UCI ML repository.
        save_dir: Output directory for the CSV files. Its last path component
            (trailing separators ignored) is used as the dataset name.
        skip_existing: If True and both ``train.csv`` and ``test.csv`` already
            exist in ``save_dir``, return without downloading anything.
    """
    name = os.path.basename(os.path.normpath(save_dir))
    train_path = os.path.join(save_dir, 'train.csv')
    test_path = os.path.join(save_dir, 'test.csv')
    if skip_existing and os.path.isfile(train_path) and os.path.isfile(test_path):
        return

    print(ucirepo_id)
    data = fetch_ucirepo(id=ucirepo_id)
    df = pd.concat([data.data.features, data.data.targets], axis=1)

    # Remove missing values *before* casting: astype(bool) turns NaN into True,
    # so dropping afterwards silently kept rows with fabricated Binary values.
    # reset_index keeps row labels equal to positions for _valid_ind.
    df = df.dropna(how='any').reset_index(drop=True)  # remove missing for evaluation

    _cast_to_uci_types(df, data.variables, ucirepo_id, save_dir)
    if ucirepo_id == 722:
        df = df.astype(bool)  # android features are binary

    print(save_dir, len(df))
    print(df.iloc[-4:].to_string())

    train_df, test_df = train_test_split(df, 0.9, random_state=42)
    os.makedirs(save_dir, exist_ok=True)
    train_df.to_csv(train_path, index=False)
    test_df.to_csv(test_path, index=False)

    # Re-split the same rows for privacy evaluation. The train-then-test
    # concatenation order is kept so the seeded split stays reproducible.
    total_df = pd.concat([train_df, test_df])
    privacy_train, privacy_test = train_test_split(total_df, ratio=0.5, random_state=42)
    privacy_train_path = os.path.join(save_dir, 'privacy_train.csv')
    privacy_test_path = os.path.join(save_dir, 'privacy_test.csv')
    privacy_train.to_csv(privacy_train_path, index=False)
    privacy_test.to_csv(privacy_test_path, index=False)

    _write_dataset_configs(name, train_path, test_path, privacy_train_path, privacy_test_path)


if __name__ == '__main__':
    # Regenerate the splits and configs for every enabled entry of UCI_DATA.
    for ucirepo_id, save_dir in UCI_DATA:
        get_data(ucirepo_id, save_dir)
