from utils import TabularConfig, read_config, generate_seed_set
from tabular_dataset import  get_dataset
import os
from globa_utils import setup_seed
from args import get_args, complete_cfg_by_args
from train_eval import train
# os.environ["CUDA_VISIBLE_DEVICES"]='6'


# Module-level arguments object, obtained once when this module is loaded.
# main() reads args.dst_name and passes the whole object to complete_cfg_by_args().
args = get_args()


def main() -> None:
    """Load the dataset configuration and data splits, then launch training.

    Steps, in order:
      1. Read ``<dataset_config_path><dst_name>.yaml`` and complete it from
         the module-level ``args`` via ``complete_cfg_by_args``.
      2. Rewrite ``cfg['result_save_path']`` to live under
         ``<experiment_save_path>/<dst_name>/`` and make sure its parent
         directory exists.
      3. Fetch the 'train' and 'test' splits with the same ``split_seed`` and
         ``test_ratio`` so both calls describe the same partition.
      4. Call ``train`` with the config, the generated seed set and both splits.

    Reads the module-level ``args``; returns nothing.
    """
    dst_name = args.dst_name
    TC = TabularConfig()
    cfg = read_config(cfg_path=TC.get_dataset_config_path() + dst_name + '.yaml')
    complete_cfg_by_args(cfg, args)
    # NOTE: the config is printed before 'result_save_path' is expanded below,
    # so the printed value is the one from the config/args, not the final path.
    print(cfg)
    cfg['result_save_path'] = os.path.join(TC.get_experiment_save_path(), dst_name, cfg['result_save_path'])

    # Create the whole parent directory chain of the result file.
    # os.path.dirname is separator-aware (unlike splitting on '/'), and
    # makedirs(exist_ok=True) both creates missing intermediate directories
    # and tolerates the directory already existing (e.g. concurrent runs).
    save_dir = os.path.dirname(cfg['result_save_path'])
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    train_set, col_info = get_dataset(dst_name, split='train', rand_number=cfg['split_seed'],
                                      test_ratio=cfg['test_ratio'])
    # Column info from the test call is discarded; the train call's is used.
    test_set, _ = get_dataset(dst_name, split='test', rand_number=cfg['split_seed'],
                              test_ratio=cfg['test_ratio'])

    # ============================== model prepare and train =======================================================
    # The seed set is generated before setup_seed(0) is applied; this order is
    # kept as in the original in case generate_seed_set depends on RNG state.
    seed_set = generate_seed_set()
    setup_seed(0)
    train(cfg, seed_set, train_set, test_set, col_info, label_list=None)


# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()

# Example invocation:
# CUDA_VISIBLE_DEVICES=0 python eval_main.py --dst_name reuters --val_method split_free_joint --k 1 --model xgb --result_save_name split_free_joint.pkl

