# -*- coding: utf-8 -*-
import argparse
from copy import deepcopy
import os

from common.constants import SEED
from common.utils import load_config_json
from common.datasets import dataset_load_methods

from tabular_diffusion.denoising_models.complete_transformer_denoising_model import CompleteTransformerDenoisingModel
from tabular_diffusion.utils_func import main_train, main_test


def _load_split(load_method, data_dir, seed):
    """Load a dataset's train/test split using the settings shared by every experiment.

    Args:
        load_method: Dataset-specific loader taken from ``dataset_load_methods``.
        data_dir: Directory containing the dataset files.
        seed: Seed forwarded to the loader (presumably controls splitting/shuffling — confirm in the loader).

    Returns:
        Whatever ``load_method`` returns; callers unpack it as
        ``(x_train, y_train, x_test, y_test, meta)``.
    """
    print('Loading data...')
    return load_method(data_dir=data_dir, seed=seed, device='cpu', include_target=False,
                       cat_weights_method="Uniform")


def main():
    """Command-line entry point for the TabGenDDPM experiments.

    Parses the CLI arguments, validates the dataset name against
    ``dataset_load_methods``, loads the JSON configuration
    ``<cwd>/conf/tabular_diffusion/diffusion_<dataset>.json`` and then, depending on
    the flags, trains the diffusion model (``-td``) and/or runs the supervised
    ML-utility test (``-sv``). Both steps may run in one invocation; each reloads
    the data with its own seed (``train_params[SEED]`` vs ``exp_params[SEED]``).
    """
    parser = argparse.ArgumentParser(description='AILab_TabGen')
    parser.add_argument('-d', '--dataset', help='Dataset name', type=str, required=True)
    parser.add_argument('-td', '--train_diffusion', action='store_true',
                        help='Train the diffusion model', required=False)
    parser.add_argument('-sv', '--supervised_model', action='store_true',
                        help='Supervised model experiment', required=False)
    args = vars(parser.parse_args())

    dataset_name = args['dataset']
    run_train = args['train_diffusion']
    run_test = args['supervised_model']

    # Validate the dataset name before touching the filesystem. Otherwise an unknown
    # name fails while loading its (non-existent) config file and this message is never shown.
    load_method = dataset_load_methods.get(dataset_name)
    if load_method is None:
        available = ', '.join(f"'{name}'" for name in dataset_load_methods.keys())
        print(f"Dataset '{dataset_name}' is not available. Please indicate one of the following: {available}")
        return

    if not (run_train or run_test):
        print('Nothing to do: pass -td/--train_diffusion and/or -sv/--supervised_model.')
        return

    # Paths are resolved relative to the current working directory, not to this file.
    cwd = os.getcwd()
    path_config_file = os.path.join(cwd, "conf", "tabular_diffusion", 'diffusion_' + dataset_name + '.json')
    data_dir = os.path.join(cwd, "data")

    print(f'Current TabGenDDPM config path: \n{path_config_file}')
    print(f'Current data path: \n{data_dir}')

    # Load model configuration parameters
    train_params, optimizer_params, model_params, mask_params, exp_params = load_config_json(path_config_file)

    if run_train:
        ################################################################################################################
        # TRAIN
        ################################################################################################################
        print('TRAIN-SCRIPT')
        x_train_torch, y_train_torch, x_test_torch, y_test_torch, meta = _load_split(
            load_method, data_dir, train_params[SEED])

        main_train(dn_fn=CompleteTransformerDenoisingModel,
                   x_train_torch=x_train_torch,
                   y_train_torch=y_train_torch,
                   x_test_torch=x_test_torch,
                   y_test_torch=y_test_torch,
                   meta=meta,
                   model_params=model_params,
                   optimizer_params=optimizer_params,
                   train_params=train_params,
                   mask_params=mask_params)

    if run_test:
        ################################################################################################################
        # ML UTILITY TEST
        ################################################################################################################
        print('ML UTILITY TEST-SCRIPT')
        # Data is reloaded with the experiment seed, which may differ from the training seed.
        x_train_torch, y_train_torch, x_test_torch, y_test_torch, meta = _load_split(
            load_method, data_dir, exp_params[SEED])

        main_test(dn_fn=CompleteTransformerDenoisingModel,
                  x_train_torch=x_train_torch,
                  y_train_torch=y_train_torch,
                  x_test_torch=x_test_torch,
                  y_test_torch=y_test_torch,
                  meta=meta,
                  model_params=model_params,
                  train_params=train_params,
                  svt_exp_params=exp_params,
                  dataset_name=dataset_name)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
