from utils import TabularConfig, read_config
from tabular_dataset import  get_dataset
import os
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pickle
import argparse


def cal_scaler(data_x, numeric_col):
    """Fit a ``StandardScaler`` on the selected columns of ``data_x``.

    Args:
        data_x: Container indexed as ``data_x[numeric_col]`` to obtain the
            values to fit on (presumably a pandas DataFrame -- confirm
            against the caller).
        numeric_col: Collection of column selectors to scale; may be
            ``None`` or empty.

    Returns:
        The fitted ``StandardScaler``, or an empty dict when ``numeric_col``
        is ``None`` or has length zero.
    """
    # Guard clause: with no columns to scale there is nothing to fit, so an
    # empty dict is handed back as the placeholder value.
    if numeric_col is None or len(numeric_col) == 0:
        return {}

    fitted = StandardScaler()
    fitted.fit(data_x[numeric_col])
    return fitted


def save_scaler_main(dst_name):
    """Fit a scaler on the training split of a dataset and pickle it to disk.

    Reads ``<dataset_config_path><dst_name>.yaml`` for the ``split_seed`` and
    ``test_ratio`` settings, loads the ``'train'`` split with them, fits a
    scaler via ``cal_scaler`` and writes the result to
    ``<scaler_save_path>/<dst_name>_scaler.pkl``.

    Args:
        dst_name: Dataset name; used to locate the YAML config, passed to
            ``get_dataset``, and used as the prefix of the output file name.

    Returns:
        None. The only result is the pickle file written to disk.
    """
    TC = TabularConfig()
    # NOTE(review): plain concatenation assumes get_dataset_config_path()
    # already ends with a path separator -- confirm in utils.
    cfg = read_config(cfg_path=TC.get_dataset_config_path() + dst_name + '.yaml')

    scaler_save_path = TC.get_scaler_save_path()
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(scaler_save_path, exist_ok=True)

    train_set, col_info = get_dataset(dst_name, split='train', rand_number=cfg['split_seed'],
                                      test_ratio=cfg['test_ratio'])
    # train_set[0] is passed as the data to fit on and col_info.cont_name as
    # the columns to scale (presumably the continuous feature names -- confirm
    # in tabular_dataset).
    scaler = cal_scaler(train_set[0], col_info.cont_name)

    # Use a context manager so the file is flushed and closed deterministically
    # instead of leaking the handle until garbage collection.
    scaler_file = os.path.join(scaler_save_path, dst_name + '_scaler.pkl')
    with open(scaler_file, 'wb') as fout:
        pickle.dump(scaler, fout)



