import os
import sys
from basic_config import PATH_TO_CIFAR, TMP_DATETIME_FILE
sys.path.append(PATH_TO_CIFAR)
import train as cifar_train
import hyperparameters.vgg8_cifar10_baseline as cifar10_vgg8_hyperparams
import hyperparameters.vgg11_cifar10_baseline as cifar10_vgg11_hyperparams
import hyperparameters.vgg11_half_cifar10_baseline as cifar10_vgg11_half_hyperparams
import hyperparameters.vgg11_doub_cifar10_baseline as cifar10_vgg11_doub_hyperparams
import hyperparameters.vgg11_quad_cifar10_baseline as cifar10_vgg11_quad_hyperparams
import hyperparameters.vgg13_cifar10_baseline as cifar10_vgg13_hyperparams
import hyperparameters.vgg13_student_cifar10_baseline as cifar10_vgg13_student_hyperparams
import hyperparameters.vgg13_half_cifar10_baseline as cifar10_vgg13_half_hyperparams
import hyperparameters.vgg13_doub_cifar10_baseline as cifar10_vgg13_doub_hyperparams
import hyperparameters.vgg13_quad_cifar10_baseline as cifar10_vgg13_quad_hyperparams
import hyperparameters.vgg16_cifar10_baseline as cifar10_vgg16_hyperparams
import hyperparameters.vgg19_cifar10_baseline as cifar10_vgg19_hyperparams
import hyperparameters.resnet18_nobias_cifar10_baseline as cifar10_resnet18_nobias_hyperparams
import hyperparameters.resnet18_nobias_nobn_cifar10_baseline as cifar10_resnet18_nobias_nobn_hyperparams
import hyperparameters.resnet18_eighth_nobias_nobn_cifar10_baseline as cifar10_resnet18_eighth_nobias_nobn_hyperparams
import hyperparameters.resnet18_fourth_nobias_nobn_cifar10_baseline as cifar10_resnet18_fourth_nobias_nobn_hyperparams
import hyperparameters.resnet18_half_nobias_nobn_cifar10_baseline as cifar10_resnet18_half_nobias_nobn_hyperparams
import hyperparameters.resnet18_doub_nobias_nobn_cifar10_baseline as cifar10_resnet18_doub_nobias_nobn_hyperparams
import hyperparameters.resnet34_nobias_cifar10_baseline as cifar10_resnet34_nobias_hyperparams
import hyperparameters.resnet34_nobias_nobn_cifar10_baseline as cifar10_resnet34_nobias_nobn_hyperparams
import hyperparameters.resnet34_half_nobias_nobn_cifar10_baseline as cifar10_resnet34_half_nobias_nobn_hyperparams
import hyperparameters.resnet34_doub_nobias_nobn_cifar10_baseline as cifar10_resnet34_doub_nobias_nobn_hyperparams
import copy
from log import logger, get_first_timestamp

# Number of models main() trains, one after another. Each model gets the base
# config's seed offset by its index and one entry in main()'s `gpus` list.
num_models = 1

def main():
    """Train ``num_models`` CIFAR models from one baseline hyperparameter config.

    Positional command-line arguments (all optional):

    1. ``model_type`` or ``model_type@architecture_type``. Defaults to
       ``cifar10`` with architecture ``vgg11``; when no ``@`` is given the
       architecture is ``vgg11``.
    2. ``sub_type``: extra name component of the config (inserted as
       ``<sub_type>_`` before ``hyperparams``). When omitted, no component is
       inserted and ``plain`` is used in log messages and the output path.
    3. Integer GPU value handed to ``cifar_train.main`` for every model.
       Defaults to ``2``.

    The base config is the ``config`` attribute of the module alias named
    ``<model_type>_<architecture_type>_[<sub_type>_]hyperparams`` imported at
    the top of this file. Each model gets a deep copy of that config with its
    ``seed`` increased by the model index, and is written under
    ``./cifar_models/exp_<model_type>_<architecture_type>_<sub_type>_<timestamp>/model_<idx>/``.
    After all models have been trained, ``TMP_DATETIME_FILE`` is removed.

    Raises:
        KeyError: if no imported hyperparameter alias matches the requested
            model type / architecture / sub-type combination.
        ValueError: if the GPU argument is not an integer.
    """
    gpus = [2] * num_models

    if len(sys.argv) >= 2:
        model_type = str(sys.argv[1])
        if '@' in model_type:
            model_type, architecture_type = model_type.split('@')
        else:
            architecture_type = 'vgg11'
    else:
        model_type = 'cifar10'
        architecture_type = 'vgg11'

    if len(sys.argv) >= 3:
        # `sub_type` carries a trailing underscore so it can be spliced
        # directly into the config alias; `sub_type_str` is the display form.
        sub_type = str(sys.argv[2]) + '_'
        sub_type_str = str(sys.argv[2])
    else:
        sub_type = ''
        sub_type_str = 'plain'

    if len(sys.argv) >= 4:
        gpu_num = int(sys.argv[3])
        gpus = [gpu_num] * num_models

    # Resolve the config by alias name. On a miss, report which aliases exist
    # instead of surfacing a bare KeyError with only the synthesized name.
    config_name = f'{model_type}_{architecture_type}_{sub_type}hyperparams'
    try:
        hyperparams_module = globals()[config_name]
    except KeyError:
        available = sorted(
            name for name in list(globals()) if name.endswith('_hyperparams')
        )
        raise KeyError(
            f'No hyperparameter config named {config_name!r}. '
            f'Available configs: {available}'
        ) from None
    base_config = hyperparams_module.config

    logger.info('base_config is {}'.format(base_config))
    logger.info("gpus are {}".format(gpus))
    logger.info(f'Model type is {model_type} and sub_type is {sub_type_str}')

    # One timestamp for the whole run so every model shares an experiment dir.
    timestamp = get_first_timestamp()

    assert len(gpus) == num_models
    for idx in range(num_models):
        # Deep copy so per-model changes never leak into the shared base config.
        model_config = copy.deepcopy(base_config)
        model_config['seed'] = model_config['seed'] + idx
        logger.info("Model with idx {} runnning with seed {} on GPU {}".format(
            idx,
            model_config['seed'],
            gpus[idx]))

        model_output_dir = './cifar_models/exp_{}_{}_{}_{}/model_{}/'.format(
            model_type,
            architecture_type,
            sub_type_str,
            timestamp,
            idx)
        logger.info("This model with idx {} will be saved at {}".format(idx, model_output_dir))

        accuracy = cifar_train.main(model_config, model_output_dir, gpus[idx])
        # The returned accuracy is scaled by 100 for logging only.
        logger.info("The accuracy of model with idx {} is {}".format(idx, accuracy*100))

    logger.info("Done training all the models")
    os.remove(TMP_DATETIME_FILE)

# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()