import os
import pickle
import argparse

import hpbandster.core.nameserver as hpns
import hpbandster.core.result as hpres

from hpbandster.optimizers import BOHB

import logging
logging.basicConfig(level=logging.DEBUG)


from bohb_worker import PyTorchWorker as worker

# Command-line interface shared by the master process and the worker processes.
parser = argparse.ArgumentParser(description='BOHB experiments')

# Optimizer budget / schedule.
parser.add_argument(
    '--min_budget',
    default=1,
    type=float,
    help='Minimum number of epochs for training.',
)
parser.add_argument(
    '--max_budget',
    default=9,
    type=float,
    help='Maximum number of epochs for training.',
)
parser.add_argument(
    '--n_iterations',
    default=16,
    type=int,
    help='Number of iterations performed by the optimizer',
)

# Process role and communication set-up.
parser.add_argument(
    '--worker',
    action='store_true',
    help='Flag to turn this into a worker process',
)
parser.add_argument(
    '--run_id',
    type=str,
    help='A unique run id for this optimization run. An easy option is to use the job id of the clusters scheduler.',
)
parser.add_argument(
    '--nic_name',
    default='lo',
    type=str,
    help='Which network interface to use for communication.',
)
parser.add_argument(
    '--shared_directory',
    default='.',
    type=str,
    help='A directory that is accessible for all processes, e.g. a NFS share.',
)

# Experiment settings forwarded to the worker / optimizer further down.
parser.add_argument('--dataset', default='fashion', type=str)
parser.add_argument('--arch', default='conv_mnist', type=str)
parser.add_argument('--alpha', default=0.05, type=float)
parser.add_argument('--mode', default='train_loss_agn', type=str)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--thread', default=0, type=int)
parser.add_argument('--eta', default=2, type=int)

args = parser.parse_args()



# Every process (master and workers alike) has to look up the host that
# belongs to the selected network interface.
host = hpns.nic_name_to_host(args.nic_name)

# Debug log name for this process: "<shared_directory>-<thread>".
debug_file = '-'.join(map(str, [args.shared_directory, args.thread]))

# Append the parsed arguments to the debug log under bohb_debug/.
# The parent directory is created first: opening in append mode does not
# create missing directories, and a shared_directory containing path
# separators introduces nested components below bohb_debug/.
debug_path = 'bohb_debug/' + debug_file
os.makedirs(os.path.dirname(debug_path), exist_ok=True)
with open(debug_path, 'a') as f:
    f.write(str(args))


if args.worker:
    # Worker mode: configure a worker, serve jobs, then terminate the script
    # so that none of the master-only code below is executed.
    # Local imports: only this code path needs them.
    import sys
    import time

    # Short artificial delay to make sure the nameserver is already running.
    time.sleep(5)

    w = worker(run_id=args.run_id, host=host, timeout=120)
    w.load_dataset(args.dataset)

    # The listed datasets are treated as 10-class problems; every other
    # dataset name falls back to 100 classes.
    if args.dataset in ['cifar10', 'fashion', 'mnist', 'svhn']:
        num_classes = 10
    else:
        num_classes = 100

    w.set_model(args.arch, num_classes)
    w.set_alpha(args.alpha)
    w.set_mode(args.mode)
    w.set_debug_file(debug_file)

    # Pick up the nameserver credentials from the shared directory
    # (presumably written there by the master's NameServer - confirm).
    w.load_nameserver_credentials(working_directory=args.shared_directory)
    w.run(background=False)

    # sys.exit instead of the site-provided exit(): the latter is a helper
    # for the interactive shell and is not guaranteed to exist in programs.
    sys.exit(0)


# Log results live. This is most useful for really long runs, where
# intermediate results could already be interesting. The core.result
# submodule contains the functionality to read the two generated files
# (results.json and configs.json) and create a Result object.
result_logger = hpres.json_result_logger(directory=args.shared_directory, overwrite=False)

# Start a nameserver.
NS = hpns.NameServer(run_id=args.run_id, host=host, port=0, working_directory=args.shared_directory)
ns_host, ns_port = NS.start()

# From here on the nameserver is running, so every exit path - including
# failures - has to shut it down again.
try:
    # Start a local worker in the master process.
    w = worker(run_id=args.run_id, host=host, nameserver=ns_host, nameserver_port=ns_port, timeout=120)
    # NOTE: --gpu is parsed but currently not forwarded (w.set_gpu is disabled).
    w.load_dataset(args.dataset)

    # The listed datasets are treated as 10-class problems; every other
    # dataset name falls back to 100 classes.
    if args.dataset in ['cifar10', 'fashion', 'mnist', 'svhn']:
        num_classes = 10
    else:
        num_classes = 100

    w.set_model(args.arch, num_classes)
    w.set_alpha(args.alpha)
    w.set_mode(args.mode)
    w.set_debug_file(debug_file)

    w.run(background=True)

    # Create the optimizer. The configuration space is requested from the
    # worker class, with a flag telling it whether the architecture is
    # 'resnet110_cifar'.
    bohb = BOHB(
        configspace=worker.get_configspace(args.arch == 'resnet110_cifar'),
        eta=args.eta,
        run_id=args.run_id,
        host=host,
        nameserver=ns_host,
        nameserver_port=ns_port,
        result_logger=result_logger,
        min_budget=args.min_budget,
        max_budget=args.max_budget,
    )

    try:
        res = bohb.run(n_iterations=args.n_iterations, min_n_workers=1)

        # Store the results next to the live logs in the shared directory.
        with open(os.path.join(args.shared_directory, 'results.pkl'), 'wb') as fh:
            pickle.dump(res, fh)
    finally:
        # Shut down the optimizer and its workers even if the run or the
        # pickle dump raised.
        bohb.shutdown(shutdown_workers=True)
finally:
    # Always stop the nameserver last.
    NS.shutdown()

