#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ABOUT:
=======
Command-line interface for launching model training from a JSON config file.
"""
import json
import os
import argparse
import jax
import pickle
import shutil
import sys
import gc

from dloaders.init_dataloader import init_dataloader

def main():
    """
    Parse a JSON config file and launch training for the model it names.

    Usage: CUDA_VISIBLE_DEVICES=n python Pair_Alignment.py -configs file.json

    The config file's top-level keys become attributes of an
    ``argparse.Namespace``. Its ``pred_model_type`` entry selects which
    training wrapper and dataset initializer are imported. The imports are
    deferred, so only the selected pipeline's modules are loaded.

    Raises:
        RuntimeError: if JAX does not see exactly one device.
        ValueError: if ``-configs`` is not a ``.json`` file, if the config
            does not define ``pred_model_type``, or if the model type is not
            recognized.
    """
    # require exactly one visible JAX device (intended: a single GPU).
    # Explicit raise rather than assert so the check survives `python -O`.
    err_ms = 'SELECT GPU TO RUN THIS COMPUTATION ON with CUDA_VISIBLE_DEVICES=DEVICE_NUM'
    if len(jax.devices()) != 1:
        raise RuntimeError(err_ms)

    # initialize argparse
    parser = argparse.ArgumentParser(prog='Pair_Alignment')
    parser.add_argument('-configs',
                        type = str,
                        required=True,
                        help='Load configs from a single file, in json format.')
    top_level_args = parser.parse_args()

    # helper: open a single config file and extract additional arguments
    def read_config_file(config_file):
        with open(config_file, 'r', encoding='utf-8') as f:
            contents = json.load(f)
        t_args = argparse.Namespace()
        t_args.__dict__.update(contents)
        # Re-parse the command line on top of the config values; CLI
        # arguments (e.g. -configs) are written onto the same namespace.
        return parser.parse_args(namespace=t_args)

    # read argparse
    if not top_level_args.configs.endswith('.json'):
        raise ValueError(
            f'-configs must point to a single .json file, '
            f'got {top_level_args.configs!r}'
        )
    print(f'TRAINING WITH: {top_level_args.configs}')
    args = read_config_file(top_level_args.configs)

    pred_model_type = getattr(args, 'pred_model_type', None)
    if pred_model_type is None:
        raise ValueError(
            f"config file {top_level_args.configs!r} does not define "
            f"'pred_model_type'"
        )

    # import correct wrappers, dataloader initializers
    # NOTE(review): collate_fn is imported in each branch but not used in
    # this function; kept in case the import has side effects — confirm
    # whether it can be dropped.
    if 'pairhmm_indp_sites' in pred_model_type:
        from cli.train_pairhmm_indp_sites import train_pairhmm_indp_sites as train_fn
        from dloaders.init_counts_dset import init_counts_dset as init_datasets
        from dloaders.CountsDset import jax_collator as collate_fn

    elif pred_model_type in ['pairhmm_frag_and_site_classes',
                             'pairhmm_nested_tkf',
                             'neural_hmm',
                             'feedforward']:
        from dloaders.init_full_len_dset import init_full_len_dset as init_datasets
        from dloaders.FullLenDset import jax_collator as collate_fn

        if pred_model_type in ['pairhmm_frag_and_site_classes', 'pairhmm_nested_tkf']:
            from cli.train_pairhmm_transit_mixes import train_pairhmm_transit_mixes as train_fn

        elif pred_model_type == 'neural_hmm':
            from cli.train_neural_hmm import train_neural_hmm as train_fn

        elif pred_model_type == 'feedforward':
            from cli.train_feedforward import train_feedforward as train_fn

    else:
        # Without this, init_datasets/train_fn would be unbound below and the
        # script would die with an opaque UnboundLocalError.
        raise ValueError(f'Unrecognized pred_model_type: {pred_model_type!r}')

    # make dataloder list
    dload_dict = init_datasets( args,
                                'train',
                                training_argparse = None,
                                include_dataloader = True )

    # train model
    train_fn( args, dload_dict )

# Launch the CLI only when this file is executed directly, not on import.
if "__main__" == __name__:
    main()
