# -*- coding: utf-8 -*-
"""
Local and Parallel (using processes)
"""

import argparse
import logging
import os
import pickle
import subprocess
import time
import yaml
import torch

# from hpbandster.optimizers import BOHB as BOHB
from hpbandster.optimizers import EQI, BOHB, RandomSearch, HyperBand, BO  # BOSS, BO,  SuccessiveHalving
import hpbandster.core.nameserver as hpns


from boss_config import CONFIG, init_logging


# Command-line interface of the search driver. Options are registered in the
# same order as before so that `--help` output is unchanged.
parser = argparse.ArgumentParser("auto augment bohb")

# --- search schedule -------------------------------------------------------
parser.add_argument("--ss_steps", default=3, type=int)  # deprecated
parser.add_argument(
    "--min_budget",
    type=float,
    default=1,
    help="Minimum number of epochs for training.",
)
parser.add_argument(
    "--max_budget",
    type=float,
    default=1,  # was 9 originally
    help="Maximum number of epochs for training.",
)
parser.add_argument(
    "--n_iterations",
    type=int,
    default=5,  # was 16 originally
    help="Number of iterations performed by the optimizer",
)
parser.add_argument(
    "--n_workers",
    type=int,
    default=8,
    help="Number of workers to run in parallel.",
)
parser.add_argument(
    "--run_id",
    type=str,
    default="autoaugment_boss",
    help="A unique run id for this optimization run. An easy option is to use the job id of the clusters scheduler.",
)
parser.add_argument(
    "--n_repeat",
    type=int,
    default=2,
    help="repeat number of budgets in BOSS.py",
)

# NOTE: the options data_url, train_url, batch_size, init_method, rank and
# world_size used to be declared here; they are currently not registered.

# --- locations, task and algorithm selection -------------------------------
parser.add_argument("--bucket_dir", default="s3://bucket-auto/yujun")
parser.add_argument("--log_dir", default="/cache/darts/eqi_random_cifar10_test/log")
parser.add_argument("--shared_dir", default="/cache/darts/eqi_random_cifar10_test/shared")
parser.add_argument("--code_dir", default="/home/work/user-job-dir/pt_eqi_darts")
parser.add_argument("--data_set", choices=["cifar10", "cifar100"], default="cifar10")
parser.add_argument("--net_name", type=str, choices=["wideresnet", "simplenet"], default="wideresnet")
parser.add_argument(
    "--alg",
    type=str,
    choices=["EQI", "BOSS", "BOHB", "HB", "SH", "Random", "BO"],
    default="EQI",
)
parser.add_argument("--work_space", choices=["cloud", "local"], default="cloud")
parser.add_argument("--init_config", choices=["random", "lhd"], default="random")
# Kept as a string here; it is converted to a Python list further down.
parser.add_argument("--num_configs", type=str, default="[10,10,10,10,10]")
parser.add_argument("--nsport", type=int, default=36000)

# parse_known_args() tolerates unrecognised options; they end up in `unparsed`.
args, unparsed = parser.parse_known_args()


# Sanity-check the output locations: both directories must mention the data
# set and the (lower-cased) algorithm name. Explicit raises are used instead
# of `assert` so the checks still run under `python -O`.
if args.data_set not in args.log_dir or args.data_set not in args.shared_dir:
    raise ValueError(
        f"--log_dir and --shared_dir must both contain the data set name {args.data_set!r}"
    )
if args.alg.lower() not in args.log_dir or args.alg.lower() not in args.shared_dir:
    raise ValueError(
        f"--log_dir and --shared_dir must both contain the algorithm name {args.alg.lower()!r}"
    )

# Remote counterparts of the local directories (plain string concatenation).
bucket_log_dir = args.bucket_dir + args.log_dir
bucket_shared_dir = args.bucket_dir + args.shared_dir

# Send args to CONFIG: every attribute CONFIG already has is overwritten by
# the parsed argument of the same name. getattr() raises AttributeError if
# CONFIG carries an attribute that has no matching command-line option.
for key in vars(CONFIG):
    setattr(CONFIG, key, getattr(args, key))

# Call once, in the main script.
init_logging(exp_dir=CONFIG.log_dir, config_path=os.path.join(args.code_dir, "logging_config.yaml"))

# Persist the effective configuration next to the logs; the context manager
# guarantees the file is flushed and closed.
with open(os.path.join(args.log_dir, "autoaugment_config.yaml"), "w") as config_file:
    yaml.dump(vars(CONFIG), config_file)
logger = logging.getLogger(__name__)


# Step 1: start a nameserver (see example_1) on the loopback interface.
logger.info("## start a nameserver...")
NS = hpns.NameServer(port=args.nsport, host='127.0.0.1', run_id=args.run_id)
ns_host, ns_port = NS.start()
print(f"## ns_host:{ns_host} ns_port:{ns_port}")

# Step 2: launch the workers, one shell command per worker. The trailing `&`
# puts each worker in the background, so subprocess.call() returns as soon as
# the shell has spawned it.
logger.info(f"## start {args.n_workers} workers...")
for worker_idx in range(args.n_workers):
    # Spread the workers round-robin over the visible GPUs, or use the CPU.
    if torch.cuda.is_available():
        device_flag = f"--device cuda:{worker_idx%torch.cuda.device_count()} "
    else:
        device_flag = "--device cpu "

    cmdline = "".join([
        f"cd {CONFIG.code_dir} && python call_worker.py ",
        f"--id {args.alg}_{worker_idx} --sleep_interval 0.5 ",
        device_flag,
        f"--run_id {args.run_id} ",
        f"--code_dir {CONFIG.code_dir} --shared_dir {CONFIG.shared_dir} --log_dir {CONFIG.log_dir} --bucket_dir {CONFIG.bucket_dir} ",
        f"--alg {CONFIG.alg} --data_set {CONFIG.data_set} ",
        f"--nsport {args.nsport} &",
    ])
    logger.info(f"---> process ID_{worker_idx} cmdline: {cmdline}")
    subprocess.call(cmdline, shell=True)
    # Pause for one second between consecutive launches.
    time.sleep(1)


# Step 3: Run an optimizer
logger.info(f"------------- start optimizer: {args.alg} -------------")
import ast

from worker import DartsWorker

# Keyword arguments shared by every optimizer constructor below.
optimizer_kwargs = dict(
    eta=3,
    run_id=args.run_id,
    min_budget=args.min_budget,
    max_budget=args.max_budget,
    nameserver_port=args.nsport,
)

# --num_configs arrives as a string such as "[10,10,10,10,10]". It is parsed
# with ast.literal_eval() rather than eval(), so only Python literals are
# accepted and a command-line value can never execute arbitrary code.
if args.alg == "EQI":
    automl = EQI(
        configspace=DartsWorker.get_configspace("EQI"),
        num_configs=ast.literal_eval(args.num_configs),
        init_method=args.init_config,
        log_dir=args.log_dir,
        **optimizer_kwargs,
    )
elif args.alg == "BOHB":
    automl = BOHB(configspace=DartsWorker.get_configspace(), **optimizer_kwargs)
elif args.alg == "HB":
    automl = HyperBand(configspace=DartsWorker.get_configspace(), **optimizer_kwargs)
elif args.alg == "Random":
    automl = RandomSearch(configspace=DartsWorker.get_configspace(), **optimizer_kwargs)
elif args.alg == "BO":
    automl = BO(
        configspace=DartsWorker.get_configspace(),
        num_configs=ast.literal_eval(args.num_configs),
        **optimizer_kwargs,
    )
else:
    # "BOSS" and "SH" are accepted by --alg but no optimizer is wired up for
    # them in this script, so they end up here.
    raise ValueError(args.alg)

# Run the search, requiring all launched workers to be available.
result = automl.run(n_iterations=args.n_iterations, min_n_workers=args.n_workers)

# Step 4: store results
result_path = os.path.join(args.log_dir, "final_result.pkl")
with open(result_path, "wb") as result_file:
    pickle.dump(result, result_file)

# In the cloud work space, hand the log directory to mox.py with the bucket
# path as target (presumably an upload; mox.py is not part of this file).
if args.work_space == "cloud":
    upload_cmd = f"cd {args.code_dir} && python mox.py --source {args.log_dir} --target {bucket_log_dir}"
    subprocess.call(upload_cmd, shell=True)


# Step 5: shutdown
logger.info(f"------------- shutdown {args.alg} -------------")
if args.alg not in ("EQI", "BOSS", "BOHB", "HB", "SH", "Random", "BO"):
    raise ValueError(args.alg)
# Stop the optimizer together with its workers, then the nameserver.
automl.shutdown(shutdown_workers=True)
NS.shutdown()

# Step 6: Analysis
id2config = result.get_id2config_mapping()
incumbent = result.get_incumbent_id()
all_runs = result.get_all_runs()

best_found_config = id2config[incumbent]['config']
# NOTE(review): the results dict is indexed with the hard-coded key 1 here and
# in the EQI branch below, while the BO/Random branch indexes it with
# args.max_budget. This presumably only lines up for the other algorithms when
# max_budget == 1 -- confirm.
best_found_genotype = result.data[incumbent].results[1]['info']['genotype']
logger.info(f"Best found configuration: {best_found_config}")
logger.info(f"Best found genotype: {best_found_genotype}")
logger.info(f'Best found results: {result.get_runs_by_id(incumbent)}')
logger.info('A total of %i unique configurations where sampled.' % len(id2config))
logger.info('A total of %i runs where executed.' % len(all_runs))
# One "full function evaluation" is a run at the maximum budget, so the spent
# budget is normalised by args.max_budget (previously a hard-coded 2).
spent_budget = sum(r.budget for r in all_runs)
logger.info('Total budget corresponds to %.1f full function evaluations.' % (spent_budget / args.max_budget))
logger.info('The run took  %.1f seconds to complete.' % (all_runs[-1].time_stamps['finished'] - all_runs[0].time_stamps['started']))


# Per-configuration report. The total is only accumulated for EQI, BO and
# Random; for every other algorithm it is logged as 0.
total_budget = 0
if args.alg == "EQI":
    # For EQI the budget is read from the configuration itself.
    for k, v in result.data.items():
        b = v.config['budget']
        logger.info(f"config id:{k}, budget:{b}, loss:{v.results[1]['loss']}")
        total_budget += b
elif args.alg in ["BO", "Random"]:
    # Every configuration is counted with the maximum budget.
    for k, v in result.data.items():
        logger.info(f"config id:{k}, budget:{args.max_budget}, loss:{v.results[args.max_budget]['loss']}")
        total_budget += args.max_budget
logger.info(f"Total budget: {total_budget}")


logger.info("-------------- start test 25 subpolicies --------------")


# Validate the performance of the found architecture by running
# cifar10_augment_task.py.

final_dir = os.path.join(CONFIG.log_dir, "final_log")
os.makedirs(final_dir, exist_ok=True)

# Keep the winning configuration alongside the final run's logs.
with open(os.path.join(final_dir, "config.pkl"), "wb") as config_file:
    pickle.dump(best_found_config, config_file)

# The genotype is passed on the command line wrapped in double quotes.
# epochs was originally 600; it was reduced to 100 to shorten the run time.
augment_dir = os.path.join(CONFIG.code_dir, 'pt.darts-master')
cmdline = (
    f"cd {augment_dir} && "
    f"python cifar10_augment_task.py --log_dir {final_dir} --shared_dir {final_dir} --code_dir {CONFIG.code_dir} "
    f"--epochs 100 --genotype \"{best_found_genotype}\" --gpus all --dataset {CONFIG.data_set} "
)
subprocess.call(cmdline, shell=True)


# Save results: in the cloud work space, hand the log directory (which now
# also holds final_log) to mox.py with the bucket path as target.
if args.work_space == "cloud":
    sync_cmd = f"cd {args.code_dir} && python mox.py --source {CONFIG.log_dir} --target {bucket_log_dir}"
    subprocess.call(sync_cmd, shell=True)
