"""
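Directory layout used by AutoAugmentWorker:

--- autoaugment/
------ log/  # results returned by BOHB
--------- 0_0/  # one directory per worker id, "<node index>_<worker index>"
------------ best_model.pth
------------ best_config.pkl   # acc, genotype, normal & reduce
------------ info.log
------------ error.log
------------ time.log
------ shared/  # files exchanged between the worker and its task subprocess
--------- 0_0/
------------ 1/, 2/, ...  # one sub-directory per compute() call
--------------- config.pkl   # sampled configuration, written by the worker
--------------- result.pkl   # (best_acc, log), read back by the worker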
"""
import logging
import os
import pickle
import subprocess
import time

import ConfigSpace as CS
from auto_augment import operations
from config import CONFIG, init_logging
from hpbandster.core.worker import Worker


# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


class AutoAugmentWorker(Worker):
    """hpbandster worker that evaluates one sampled configuration per compute() call.

    Each call pickles the sampled configuration into a fresh shared directory,
    runs ``mnist_task.py`` in a subprocess, and reads back the ``result.pkl``
    that the task is expected to write into that same directory.
    """

    def __init__(self, _id: str,
                 sleep_interval=0.0,
                 device="cpu",
                 option="search",
                 *args, **kwargs):
        """
        Args:
            _id: worker identifier, "<node index>_<worker index>"; used as the
                sub-directory name under CONFIG.log_dir and CONFIG.shared_dir.
            sleep_interval: seconds to sleep at the end of every compute() call.
            device: device string stored on the instance; not used by the
                command line currently built in compute().
            option: task option ("search" by default); stored on the instance,
                not used by the command line currently built in compute().
            *args, **kwargs: forwarded to hpbandster's Worker.__init__.
        """
        # Positional *args are always bound before keywords, so this is the same
        # call as ``super().__init__(id=_id, *args, **kwargs)``.
        super().__init__(*args, id=_id, **kwargs)
        self.sleep_interval = sleep_interval
        self.id = _id  # "<node index>_<worker index>"
        self.device = device
        # Previously accepted but silently dropped; kept so the task command
        # line can reference it.
        self.option = option

        self.cnt = 0  # number of compute() calls made so far
        self.log_dir = os.path.join(CONFIG.log_dir, self.id)
        self.prev_best_acc = 0

        init_logging(exp_dir=CONFIG.log_dir,
                     config_path=os.path.join(CONFIG.code_dir, "logging_config.yaml"))

        os.makedirs(self.log_dir, exist_ok=True)

    def compute(self, config, budget, *args, **kwargs):
        """Evaluate one sampled configuration in a task subprocess.

        Args:
            config: dictionary containing the configuration sampled by the optimizer.
            budget: (float) amount of time/epochs/etc. the model can use to train.
                Currently only logged; it is not passed to the task.

        Returns:
            dict with the fields hpbandster requires:
                'loss' (float): the negated best accuracy reported by the task
                    (0 when the task crashed or its result could not be read).
                'info' (dict): the task's log dict, with 'log_dir' and 'status' set.
        """
        self.cnt += 1

        # A fresh directory per call, so files from different calls never collide.
        self.shared_dir = os.path.join(CONFIG.shared_dir, self.id, str(self.cnt))
        os.makedirs(self.shared_dir, exist_ok=True)

        logger.info("-------- worker %s, %dth compute with %s --------", self.id, self.cnt, budget)
        # "budget" is only a hyperparameter when the space was built with
        # alg="EQI"; indexing it directly raised KeyError for every other space.
        logger.info("--------- budget: %s", config.get("budget"))

        # Store the config for the task subprocess to read.
        with open(os.path.join(self.shared_dir, "config.pkl"), "wb") as f:
            pickle.dump(config, f)

        # Run the task in a subprocess.
        # Alternative CIFAR-10 task (disabled):
        # cmdline = f"cd {CONFIG.code_dir} && python cifar10_task.py --option {self.option} --log_dir {self.log_dir} --shared_dir {self.shared_dir} --code_dir {CONFIG.code_dir} --device {self.device} "
        # cmdline += f"--epochs {int(budget)} --name {CONFIG.net_name} --auto_augment {CONFIG.alg} --prev_best_acc {self.prev_best_acc} --data_set {CONFIG.data_set} "
        # An argv list (no shell) keeps paths with spaces or shell
        # metacharacters intact.
        cmd = ["python", "mnist_task.py", "--shared_dir", self.shared_dir]
        logger.info("--> cmdline: %s", " ".join(cmd))
        try:
            returncode = subprocess.call(cmd)
        except OSError:
            # Same outcome as the former shell launch failing: no result.pkl,
            # so the call is reported as "crashed" below.
            logger.exception("failed to launch task")
            returncode = None
        if returncode != 0:
            logger.warning("task exited with return code %s (worker %s, call %d)",
                           returncode, self.id, self.cnt)

        # NOTE(review): presumably gives the filesystem time to settle before
        # result.pkl is read - confirm.
        time.sleep(0.5)

        best_acc, log = self._read_result()

        if best_acc > self.prev_best_acc:
            self.prev_best_acc = best_acc

        logger.info("best_acc: %s", best_acc)
        logger.info("log: %s", log)

        # --- lyj ---
        # # Transfer results
        # bucket_log_dir = CONFIG.bucket_dir + self.log_dir
        # subprocess.call(f"cd {CONFIG.code_dir} && python mox.py --source {self.log_dir} --target {bucket_log_dir} ", shell=True)

        time.sleep(self.sleep_interval)
        return {
            'loss': -float(best_acc),  # mandatory field for hyperband; the optimizer minimises it
            'info': log,  # free-form user information - also mandatory
        }

    def _read_result(self):
        """Load ``(best_acc, log)`` from result.pkl in the current shared directory.

        Returns ``(0, {"status": "crashed", "log_dir": self.id})`` when the file
        is missing, unreadable, or malformed. A failed task is therefore reported
        to the optimizer as a zero-accuracy run instead of aborting the worker.
        """
        logger.info("----------- id: %s, shared_dir: %s ---------", self.id, self.shared_dir)
        result_path = os.path.join(self.shared_dir, "result.pkl")
        try:
            # NOTE(review): pickle.load can execute arbitrary code; this file is
            # assumed to be written only by our own task subprocess.
            with open(result_path, "rb") as f:
                best_acc, log = pickle.load(f)
            best_acc = float(best_acc)
            log["log_dir"] = self.id
            log["status"] = "finished"
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate; the failure reason is logged instead of hidden.
            logger.exception("could not read task result from %s", result_path)
            return 0, {"status": "crashed", "log_dir": self.id}
        return best_acc, log

    @staticmethod
    def get_configspace(alg: str = None):
        """Build the BOHB search space.

        For each i in 0..4 and each j in {1, 2} the space contains:
            op{i}{j}: categorical, one of the keys of ``operations``;
            p{i}{j}:  integer, uniform on [0, 10];
            m{i}{j}:  integer, uniform on [0, 9].
        Hyperparameters are added in the order op/p/m for j=1, then for j=2.
        When ``alg == "EQI"`` an extra integer hyperparameter "budget" on
        [1, 10] is added.
        """
        # ConfigSpace documents `choices` as a list/tuple; a dict_keys view is
        # not a Sequence, so materialise it once.
        op_names = list(operations.keys())
        config_space = CS.ConfigurationSpace()
        for i in range(5):
            for j in (1, 2):
                config_space.add_hyperparameter(
                    CS.CategoricalHyperparameter(name=f'op{i}{j}', choices=op_names))
                config_space.add_hyperparameter(
                    CS.UniformIntegerHyperparameter(name=f'p{i}{j}', lower=0, upper=10))
                config_space.add_hyperparameter(
                    CS.UniformIntegerHyperparameter(name=f'm{i}{j}', lower=0, upper=9))
        if alg == "EQI":
            config_space.add_hyperparameter(
                CS.UniformIntegerHyperparameter(name="budget", lower=1, upper=10))
        return config_space
