import argparse
from misc.utils import *
import yaml
from preprocessing.imputation import Imputer
from sklearn.preprocessing import LabelEncoder
import warnings
# Ignore every warning category from this point on. The filter is installed
# before the remaining imports so their import-time warnings are hidden too.
warnings.simplefilter('ignore')

from BorutaShap import BorutaShap

import logging
from sklearn.ensemble import IsolationForest
import os
import pickle
import importlib

# Request synchronous CUDA kernel launches so device-side errors surface at
# the offending call (slower, but easier to debug). Set at module load,
# before main() runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

def main():
    """Parse the command line, build the configuration and launch a runner.

    Steps, in order:

    1. Parse CLI arguments (data config, model, self-training, calibration,
       logging/saving switches and model hyper-parameter overrides).
    2. Import the ``config`` object that belongs to the selected model and
       merge the CLI arguments into it with ``merge_config``.
    3. Instantiate the data module named by ``config.data.data_module`` and
       call its ``prepare_data()``.
    4. Set up the logger (under ``report/<name>`` when reporting is enabled).
    5. Shrink trial/epoch/early-stopping settings when ``fast_dev_run`` is on.
    6. Instantiate the runner named by ``config.runner_option.model`` and
       call its ``run()``.

    Exits through ``parser.error`` (status 2) when the selected model has no
    configuration module wired up in this function.
    """
    parser = argparse.ArgumentParser(add_help=True)

    # The available data configs are the file stems found in ./data_config.
    parser.add_argument('--data_config', type=str, default='6M_mortality', choices=[file.split('.')[0] for file in os.listdir('./data_config')] )
    parser.add_argument('--show_drop_cols', '-dc', action='store_true', help="Show drop columns")
    parser.add_argument('--random_seed', '-s', type=int, default=0, help="Value of Random Seed\nDefault is 0")
    # NOTE(review): argparse does not validate the default against `choices`,
    # so the default 'TabNet' is accepted here and rejected further down.
    parser.add_argument('--model', type=str, default='TabNet', choices=["XGB", "FTTransformer", "CategoryEmbedding", "DreamquarkTabNet"], help="Select Model\nDefault is TabNet")

    parser.add_argument('--self_training', default='None', choices=['naive', 'curriculum', "fixed", "fixed_percentiles", "None"])
    parser.add_argument('--limited_training_sample', default=None)
    parser.add_argument('--limited_validation_sample', default=None)
    parser.add_argument('--prior', default="Likelihood", choices=["Likelihood", "Density"])

    parser.add_argument('--data_editor', type=str, default=None, choices=["Mahalanobis", "Likelihood"])
    parser.add_argument('--dist_threshold', type=float, default = None)

    parser.add_argument('--do_pretraining', action='store_true')

    parser.add_argument('--feature_corruption', default=0, type=float)

    parser.add_argument('--save_best_hparams', '-p', action='store_true', help="Allow to save best hyper parameters")
    parser.add_argument('--load_hparams', action='store_true',  help="Allow to load hyperparameters")
    parser.add_argument('--use_CV', action='store_true')
    parser.add_argument('--hparam_search_only', action='store_true', help="Only find best hyperparameters")

    parser.add_argument('--report', action='store_true')
    parser.add_argument('--use_borutashap', action='store_true')

    parser.add_argument('--save_log', action='store_true', help="Allow to save log")
    parser.add_argument('--save_model', action='store_true', help="Allow to save the model")
    parser.add_argument('--save_data', action='store_true')
    parser.add_argument('--save_data_editor', action='store_true')

    parser.add_argument('--use_temperature', action="store_true")
    parser.add_argument('--use_histogram_binning', action="store_true")
    parser.add_argument('--use_spline_calibrator', action="store_true")
    parser.add_argument('--use_gaussian_process', action="store_true")

    # n_jobs
    parser.add_argument('--n_jobs', type=int, default = None)

    # optuna
    parser.add_argument('--n_trials', type=int, default = None)

    # self_training
    parser.add_argument('--alpha', type=float, default = None)
    parser.add_argument('--delta', type=float, default = None)
    parser.add_argument('--threshold', type=float, default = None)
    parser.add_argument('--auto_alpha', action="store_true")

    # model
    parser.add_argument('--hparams', type=str, default = None)
    parser.add_argument('--fast_dev_run', action="store_true")
    parser.add_argument('--device', type=str, default = None)
    parser.add_argument('--gpus', nargs='+', default=None, type = int)
    parser.add_argument('--batch_size', type=int, default = None)
    parser.add_argument('--early_stopping_patience', type=int, default = None)
    parser.add_argument('--n_splits', type=int, default=None)
    args = parser.parse_args()

    # Import only the configuration of the selected model.
    if args.model == "XGB":
        from config.xgb_config import config
    elif args.model == "TabTransformer":
        # NOTE(review): "TabTransformer" is not among the --model choices
        # while "FTTransformer" is; confirm which config module is intended.
        from config.tabtransformer_config import config
    elif args.model == "CategoryEmbedding":
        from config.categoryembedding_config import config
    elif args.model == "DreamquarkTabNet":
        from config.dreamquark_tabnet_config import config
    else:
        # Without this branch `config` stays unbound and the merge below
        # dies with an opaque UnboundLocalError (e.g. for the default model).
        parser.error(
            f"no configuration module is wired up for model '{args.model}'; "
            "pass --model with one of: XGB, CategoryEmbedding, DreamquarkTabNet"
        )

    config = merge_config(config, args)
    print("Configuration is Loaded")

    if args.model in ["FTTransformer", "CategoryEmbedding"]:
        # Restrict the visible GPUs to the configured ids.
        os.environ["CUDA_VISIBLE_DEVICES"]=",".join(map(str, config.model.gpus))

    # The data module class is looked up by name inside the `data_module` package.
    datalib = importlib.import_module('data_module')
    datamodule = getattr(datalib, config.data.data_module)(config)
    print("DataModule is Loaded")

    data, label, numeric_cols, category_cols = datamodule.prepare_data()

    if config.runner_option.report:
        # Report directory name: last path component of the model path,
        # cut at the first '-fold' marker.
        report_path = config.model.path.split('/')[-1].split('-fold')[0]
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(f'report/{report_path}', exist_ok=True)

        logger = setup_logger("logger", config.runner_option, config, save_dir=f"report/{report_path}")
    else:
        logger = setup_logger("logger", config.runner_option, config)

    logger.info("Configs")
    logger.info(config)

    if config.model.fast_dev_run:
        # Smoke-test mode: clamp every budget that exists on the config to 1.
        if hasattr(config.optuna, 'n_trials'):
            config.optuna.n_trials = 1
        for budget in ('max_epochs', 'early_stopping_rounds', 'early_stopping_patience'):
            if hasattr(config.model, budget):
                setattr(config.model, budget, 1)

    import runners
    runner = getattr(runners, config.runner_option.model)(config, data, label, numeric_cols, category_cols, logger)

    import torch
    # current_device() raises when CUDA is unavailable, so only query it when
    # a device can actually be used; device_count() is safe and reports 0.
    if torch.cuda.is_available():
        print('Current cuda device:', torch.cuda.current_device())
    print('Count of using GPUs:', torch.cuda.device_count())

    runner.run()


# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()