import os
import yaml
import json
import datetime
import argparse

from ProtLig_GPCRclassA.envyaml import EnvYAML

# Keywords recognised in the config's ACTION entry, in matching priority order.
# More specific keywords come first because matching uses ``in``: for a string
# ACTION such as 'train_conc', the bare 'train' keyword would match as well.
ACTION_KEYWORDS = (
    'train_conc',
    'eval_conc',
    'train_ec50_regression',
    'eval_ec50_regression',
    'train',
    'eval',
    'predict',
    'precompute',
    'insights',
)


def select_action_script(action):
    """Return the script entry point that handles ``action``.

    Parameters
    ----------
    action :
        Value of the config's ``ACTION`` entry. Each keyword is tested with
        ``keyword in action``, i.e. a substring test when ``action`` is a
        string and a membership test when it is a container.

    Returns
    -------
    tuple of (str, str) or None
        ``(module_path, function_name)`` for the first keyword of
        ``ACTION_KEYWORDS`` found in ``action``, or ``None`` when no keyword
        matches.
    """
    for keyword in ACTION_KEYWORDS:
        if keyword in action:
            # Every script module exposes a function named after the module:
            # ProtLig_GPCRclassA.scripts.main_<keyword>_script.main_<keyword>_script
            function_name = 'main_{}_script'.format(keyword)
            return 'ProtLig_GPCRclassA.scripts.' + function_name, function_name
    return None


def main():
    """Parse the command line, load the config and run the requested script.

    Command-line options override or extend the loaded config:
    ``--job_array`` sets ``params['SLURM_JOB_ARRAY']``, ``--restore_model_dir``
    sets ``params['RESTORE_MODEL_DIR']`` and ``--output_file`` sets
    ``params['OUTPUT_FILE']``.

    Raises
    ------
    ValueError
        If ``--restore_model_dir`` is given while ACTION does not contain
        'eval', or if ACTION matches none of ``ACTION_KEYWORDS``.
    """
    import importlib

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True,
                        help='config file path')
    parser.add_argument('--cuda_device', type=int,
                        help='Set environment variable CUDA_VISIBLE_DEVICES')
    parser.add_argument('--output_file', type=str, default=None,
                        help='File name where output is saved')
    parser.add_argument('--job_array', action='store_true',
                        help='Whether job array is used.')
    parser.add_argument('--restore_model_dir', type=str, default=None,
                        help='Directory to restore data from. This is for validation only.')

    args = parser.parse_args()
    print('Config file: {}'.format(args.config))
    print('---------------')

    # Set visible devices. This happens before any script module is imported.
    if args.cuda_device is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
        print('Setting CUDA_VISIBLE_DEVICES to: {}'.format(os.environ['CUDA_VISIBLE_DEVICES']))

    # Read params through EnvYAML (rather than a plain YAML load) so that
    # environment variables referenced in the config are parsed.
    env = EnvYAML(args.config, flatten=False)
    params = env.yaml_config

    action = params['ACTION']

    if args.job_array:
        params['SLURM_JOB_ARRAY'] = True

    if 'eval' not in action and args.restore_model_dir is not None:
        raise ValueError('ACTION is not eval but --restore_model_dir was provided.')

    if args.restore_model_dir is not None:
        params['RESTORE_MODEL_DIR'] = args.restore_model_dir

    if args.output_file is not None:
        params['OUTPUT_FILE'] = args.output_file

    target = select_action_script(action)
    if target is None:
        # Previously an unrecognised ACTION fell through every branch and the
        # script exited successfully without doing anything.
        raise ValueError(
            'Unknown ACTION: {!r}. Expected it to contain one of: {}'.format(
                action, ', '.join(ACTION_KEYWORDS)))

    # Import lazily so that only the selected script's module is loaded, and
    # only after CUDA_VISIBLE_DEVICES has been set above.
    module_path, function_name = target
    script = getattr(importlib.import_module(module_path), function_name)
    script(params)


if __name__ == '__main__':
    main()