# -*- coding: utf-8 -*-
"""
Local and Parallel (using processes)
"""

import argparse
import logging
import os
import subprocess
import time
import yaml

from config import CONFIG, init_logging


logging.basicConfig(level=logging.INFO)

# ---------------------------------------------------------------------------
# Command-line options of the search launcher.
# parse_known_args() is used at the end, so flags that are not declared here
# do not cause an error; they are returned in `unparsed` instead.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser("auto augment bohb")

# Optimizer / search-budget settings.
parser.add_argument("--ss_steps", default=3, type=int)  # deprecated
parser.add_argument(
    "--min_budget",
    type=float,
    default=1,
    help="Minimum number of epochs for training.",
)
parser.add_argument(
    "--max_budget",
    type=float,
    default=1,  # previously 9
    help="Maximum number of epochs for training.",
)
parser.add_argument(
    "--n_iterations",
    type=int,
    default=10,  # previously 16
    help="Number of iterations performed by the optimizer",
)
parser.add_argument(
    "--n_workers",
    type=int,
    default=2,
    help="Number of workers to run in parallel.",
)
parser.add_argument(
    "--run_id",
    type=str,
    default="autoaugment_boss",
    help=(
        "A unique run id for this optimization run. "
        "An easy option is to use the job id of the clusters scheduler."
    ),
)
parser.add_argument(
    "--n_repeat",
    type=int,
    default=2,
    help="repeat number of budgets in BOSS.py",
)

# Previously declared options, kept for reference (currently not declared):
# parser.add_argument('--data_url', type=str, default=None, help='s3 path of dataset')
# parser.add_argument('--train_url', type=str, default=None, help='s3 path of outputs')
# parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
# parser.add_argument('--init_method', type=str, default=None, help='master address')
# parser.add_argument("--rank", type=int, help="A unique id for node", default=0)
# parser.add_argument("--world_size", type=int, help="number of tasks?", default=1)

# Filesystem locations.
parser.add_argument("--code_dir", type=str, default="/home/work/user-job-dir/pt_eqi_aa")
parser.add_argument(
    "--shared_dir",
    type=str,
    default="/cache/aa/eqi_random_cifar10_test/shared",  # must not end with "/"
)
parser.add_argument("--log_dir", type=str, default="/cache/aa/eqi_random_cifar10_test/log")
parser.add_argument("--bucket_dir", default="s3://bucket-010/yujun")

# Experiment selection.
parser.add_argument("--data_set", choices=["cifar10", "cifar100"], default="cifar10")
parser.add_argument(
    "--net_name", type=str, choices=["wideresnet", "simplenet"], default="wideresnet"
)
parser.add_argument(
    "--alg",
    type=str,
    choices=["EQI", "BOSS", "BOHB", "HB", "SH", "Random", "BO"],
    default="EQI",
)
parser.add_argument("--work_space", choices=["cloud", "local"], default="cloud")
parser.add_argument("--init_config", choices=["random", "lhd"], default="random")

args, unparsed = parser.parse_known_args()


# Sanity-check that the output directories belong to this experiment: both
# paths must mention the data set name and the lower-cased algorithm name.
# Reported through parser.error (exit status 2) rather than `assert`, so the
# check is not silently stripped when Python runs with -O.
if not (args.data_set in args.log_dir and args.data_set in args.shared_dir):
    parser.error(
        f"--log_dir and --shared_dir must contain the data set name '{args.data_set}'"
    )
if not (args.alg.lower() in args.log_dir and args.alg.lower() in args.shared_dir):
    parser.error(
        f"--log_dir and --shared_dir must contain the algorithm name '{args.alg.lower()}'"
    )

# Bucket-side mirrors of the local directories: bucket_dir + local absolute path.
bucket_log_dir = args.bucket_dir + args.log_dir
bucket_shared_dir = args.bucket_dir + args.shared_dir

# Send args to CONFIG: every attribute CONFIG already has is overwritten with
# the parsed argument of the same name. Each CONFIG attribute must exist on
# `args`, otherwise getattr raises AttributeError. Iterate over a snapshot so
# the loop never walks a mapping that is being written to.
for key in list(vars(CONFIG)):
    setattr(CONFIG, key, getattr(args, key))

# Call once, from the main script only.
init_logging(exp_dir=CONFIG.log_dir, config_path=os.path.join(args.code_dir, "logging_config.yaml"))

# Persist the effective configuration next to the logs. The `with` block
# closes (and flushes) the file before anything else runs.
with open(os.path.join(args.log_dir, "autoaugment_config.yaml"), "w") as config_file:
    yaml.dump(vars(CONFIG), config_file)

logger = logging.getLogger(__name__)
logger.info("----------- experiment config ------------")
for key in vars(CONFIG):
    logger.info("%s: %s", key, getattr(CONFIG, key))
logger.info("----------- experiment arguments ------------")
for key in vars(args):
    logger.info("%s: %s", key, getattr(args, key))
# parse_known_args() accepted these without complaint; make them visible so
# that a mistyped flag is not silently dropped.
if unparsed:
    logger.info("unrecognized arguments (ignored): %s", unparsed)


# initial work at begining.
if args.work_space == "cloud":
    logger.info("---------- install hybandster, copy data -----------")
    if args.data_set == "cifar10":
        subprocess.call(f"cd {args.code_dir} && bash prepare.sh {args.code_dir} {args.bucket_dir+'/Data/CIFAR10'} /cache/data/CIFAR10 ", shell=True)
    elif args.data_set == "cifar100":
        subprocess.call(f"cd {args.code_dir} && bash prepare.sh {args.code_dir} {args.bucket_dir+'/Data/CIFAR100'} /cache/data/CIFAR100 ", shell=True)


logger.info(f" code dir: {args.code_dir}")
logger.info(f" curren path (pwd): {os.getcwd()}")
logger.info(f" ls ./: {os.listdir('./')}")
logger.info(f" ls /home/work/user-job-dir: {os.listdir('/home/work/user-job-dir')}")

# Options forwarded to start.py, in command-line order.
launch_options = {
    "log_dir": args.log_dir,
    "shared_dir": args.shared_dir,
    "alg": args.alg,
    "init_config": args.init_config,
    "n_iterations": args.n_iterations,
    "n_workers": args.n_workers,
}
# Produces: "cd <code_dir> && python start.py --log_dir=... --n_workers=N "
cmdline = f"cd {args.code_dir} && python start.py " + "".join(
    f"--{name}={value} " for name, value in launch_options.items()
)

# The f-prefix was missing here, so the literal "{cmdline}" was logged.
logger.info("### cmdline: %s", cmdline)

# Copy the current log directory to the bucket before launching the search.
sync_ret = subprocess.call(f"cd {CONFIG.code_dir} && python mox.py --source {args.log_dir} --target {bucket_log_dir} ", shell=True)
if sync_ret != 0:
    logger.warning("mox.py log sync exited with code %d", sync_ret)

search_ret = subprocess.call(cmdline, shell=True)
if search_ret != 0:
    logger.error("start.py exited with code %d", search_ret)
