# -*- coding: utf-8 -*-
from common.log import access_log
from common.config import load_config
from hgtft.dataset.normalization_processed_data import NormalizationProcessedData
from hgtft.dataset.create_dataset import CreateDataset


def _split_into_batches(items, batch_size):
    """Return consecutive slices of *items*, each at most *batch_size* long.

    Args:
        items: A sliceable sequence (here, the project sample list).
        batch_size: Maximum slice length. It is passed through ``int()``
            because the configured value is not guaranteed to be an int.

    Returns:
        A list of non-empty slices that cover *items* in order. The list is
        empty when *items* is empty.

    Raises:
        ValueError: If ``int(batch_size)`` is less than 1. Without this check,
            0 fails inside ``range`` with an unhelpful message, and a negative
            value silently produces no batches at all.
    """
    size = int(batch_size)
    if size < 1:
        raise ValueError('max_workers must be a positive integer, got %r' % (batch_size,))
    return [items[start:start + size] for start in range(0, len(items), size)]


def _assign_project_split(configs, data_type, projects):
    """Put *projects* into ``configs.project_<data_type>`` and empty the other splits.

    Sets ``project_train``, ``project_validation`` and ``project_test`` on
    *configs*, in that order. The selected split receives *projects* itself
    (not a copy), and the other two receive fresh empty lists. Does nothing
    when *data_type* is not ``'train'``, ``'validation'`` or ``'test'``, so
    any existing values are left in place.
    """
    splits = ('train', 'validation', 'test')
    if data_type not in splits:
        return
    for split in splits:
        setattr(configs, 'project_' + split, projects if split == data_type else [])


def dataset_main(config_name):
    """Run the dataset pipeline for the projects listed in a named configuration.

    Steps:
      1. Load the configuration via ``load_config(config_name)``.
      2. Assign ``configs.project`` to the split selected by
         ``configs.data_type`` (``'train'``, ``'validation'`` or ``'test'``)
         and set the other two split lists to ``[]``. Any other
         ``data_type`` leaves the ``project_*`` split attributes untouched.
      3. Run each step named in ``configs.flow``, in order:
         ``'normalization'`` calls
         ``NormalizationProcessedData.project_normalization`` with
         ``scaler_func='minmax'``, and ``'create'`` builds and starts a
         ``CreateDataset``. Unrecognized step names are ignored. Both steps
         process the projects in batches of ``configs.max_workers``.

    Args:
        config_name: Name of the configuration passed to ``load_config``.

    Raises:
        ValueError: If a batched step runs and ``configs.max_workers`` is not
            a positive integer.
    """
    configs = load_config(config_name)

    project_sample_list = configs.project
    data_type = configs.data_type
    flow_list = configs.flow

    _assign_project_split(configs, data_type, project_sample_list)

    for flow in flow_list:
        if flow == 'normalization':
            access_log.info('====================== start data normalization ======================')
            for batch in _split_into_batches(project_sample_list, configs.max_workers):
                # A new normalizer is built per batch, as in the original code.
                data_normalization = NormalizationProcessedData(configs)
                data_normalization.project_normalization(batch, scaler_func='minmax')
        elif flow == 'create':
            access_log.info('====================== start create dataset ======================')
            for batch in _split_into_batches(project_sample_list, configs.max_workers):
                create_dataset = CreateDataset(configs, project_list=batch)
                # NOTE(review): if CreateDataset is a threading.Thread subclass,
                # start() returns immediately and batches run concurrently with
                # no join. Confirm whether waiting per batch is intended.
                create_dataset.start()
