import os
import warnings
warnings.filterwarnings('ignore')

from nes.optimizers.baselearner_train.genotypes import Genotype, DARTS, AmoebaNet
from nes.optimizers.cluster_worker import REWorker as Worker


if __name__ == "__main__":
    import argparse

    # Parse the command line: which architecture to train (a reference
    # architecture selected by a --train_* flag, or arch_<arch_id>.txt read
    # from --arch_path), in which search space (--nb201) and on which dataset.
    parser = argparse.ArgumentParser()
    mutex = parser.add_mutually_exclusive_group()
    parser.add_argument('--working_directory', type=str, required=True,
                        help='directory where the generated results are saved')
    parser.add_argument('--arch_path', type=str, default=None,
                        help='directory where the architecture genotypes are '
                             '(required unless a --train_* flag is given)')
    parser.add_argument('--arch_id', type=int, default=0,
                        help='architecture id number')
    parser.add_argument('--seed_id', type=int, default=0,
                        help='seed number')
    parser.add_argument('--global_seed', type=int, default=1,
                        help='global seed number')
    parser.add_argument('--num_epochs', type=int, default=15,
                        help='Number of epochs to train the baselearner')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='Mini-batch size to train the baselearner')
    parser.add_argument('--dataset', type=str, default='fmnist',
                        help='image dataset')
    parser.add_argument('--scheme', type=str, default='deepens_rs',
                        help='scheme name, i.e. nes or deepens variants')
    mutex.add_argument('--train_darts', action='store_true', default=False,
                       help='evaluate the arch found by DARTS')
    mutex.add_argument('--train_pcdarts', action='store_true', default=False,
                       help='evaluate the arch found by PC-DARTS')
    mutex.add_argument('--train_gdas', action='store_true', default=False,
                       help='evaluate the arch found by GDAS')
    mutex.add_argument('--train_global_optima', action='store_true', default=False,
                       help='evaluate the best architecture in nb201')
    mutex.add_argument('--train_amoebanet', action='store_true', default=False,
                       help='evaluate the arch found by RE')
    parser.add_argument('--nb201', action='store_true', default=False,
                        help='NAS-bench-201 space')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='debug mode: run for a single mini-batch')
    args = parser.parse_args()

    # Validate via parser.error (prints usage and exits) instead of assert,
    # which is stripped when Python runs with -O.
    if args.global_seed <= 0:
        parser.error("global seed should be greater than 0")

    args.working_directory = os.path.join(args.working_directory,
                                          'run_%d' % args.global_seed)

    # NB201 architecture ids of the reference architectures, per dataset.
    # An empty string (or a missing key) means no id is available for that
    # dataset/architecture combination.
    opt_to_id = {
        'cifar10' : {'DARTS'   : '001835',
                     'GDAS'    : '003928',
                     'Optima'  : '006111'},
        'cifar100': {'DARTS'   : '',
                     'GDAS'    : '003203',
                     'Optima'  : '009930'},
        'imagenet': {'DARTS'   : '004771',
                     'PC-DARTS' : '002800',
                     'Optima'  : '010767'},
    }

    # Reference architecture requested on the command line, or None. The
    # mutually exclusive group guarantees at most one flag is set.
    selected = [name for flag, name in ((args.train_darts, 'DARTS'),
                                        (args.train_pcdarts, 'PC-DARTS'),
                                        (args.train_gdas, 'GDAS'),
                                        (args.train_global_optima, 'Optima'),
                                        (args.train_amoebanet, 'AmoebaNet'))
                if flag]
    preset = selected[0] if selected else None

    def read_arch_file():
        """Return the raw text of arch_<arch_id>.txt inside --arch_path."""
        if args.arch_path is None:
            parser.error('--arch_path is required when no reference '
                         'architecture (--train_*) is selected')
        path = os.path.join(args.arch_path, 'arch_%d.txt' % args.arch_id)
        with open(path, 'r') as f:
            return f.read()

    if args.nb201:
        # In NB201 the genotype is the architecture id string itself.
        if preset == 'AmoebaNet':
            parser.error("Cannot evaluate AmoebaNet on NB201!")
        if preset is None:
            genotype = read_arch_file()
        else:
            genotype = opt_to_id.get(args.dataset, {}).get(preset, '')
            if not genotype:
                parser.error("no NB201 %s architecture is available for "
                             "dataset '%s'" % (preset, args.dataset))
    else:
        if preset == 'GDAS':
            parser.error("GDAS not implemented yet!")
        # These presets only exist as NB201 ids above; previously they were
        # silently ignored here and the arch file was trained instead.
        if preset in ('PC-DARTS', 'Optima'):
            parser.error("%s is only available in the NB201 space "
                         "(pass --nb201)" % preset)
        if preset == 'DARTS':
            genotype = DARTS
        elif preset == 'AmoebaNet':
            genotype = AmoebaNet
        else:
            # SECURITY: eval() executes arbitrary code from the file; only
            # point --arch_path at trusted genotype files. The text is
            # presumably a Genotype(...) expression (Genotype is imported at
            # file top and otherwise unused) -- evaluated here at module
            # level so that name resolves.
            genotype = eval(read_arch_file())

    # no need for a master node in the NES-RS case. Just instantiate a worker
    # that trains the architectures and computes predictions on clean and
    # shifted data
    worker = Worker(working_directory=args.working_directory,
                    num_epochs=args.num_epochs,
                    batch_size=args.batch_size,
                    run_id=args.arch_id,
                    scheme=args.scheme,
                    dataset=args.dataset,
                    nb201=args.nb201,
                    debug=args.debug)

    worker.compute(genotype,
                   budget=args.num_epochs,
                   config_id=(args.arch_id, 0, 0),
                   seed_id=args.seed_id,
                   global_seed=args.global_seed)

