"""
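Files produced for each worker. ``0_0`` stands for a worker id of the form
``<node index>_<worker index>``; ``<n>`` is the 1-based count of ``compute``
calls made by that worker, so every evaluation gets its own shared directory.

--- darts/
------ log/
--------- 0_0/
------------ best_model.pth
------------ best_config.pkl   #  acc, genotype, normal & reduce
------------ info.log
------------ error.log
------------ time.log
------ shared/
--------- 0_0/
------------ <n>/
--------------- config.pkl
--------------- result.pkl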
"""
import logging
import os
import pickle
import shlex
import subprocess
import threading
import time
from typing import Optional

import ConfigSpace as CS
from hpbandster.core.worker import Worker

from boss_config import CONFIG, init_logging

# Echo the code directory (logging is not configured yet), then configure
# logging from <code_dir>/logging_config.yaml, passing CONFIG.log_dir as the
# experiment directory, and create this module's logger.
print("code_dir:", CONFIG.code_dir)
init_logging(
    exp_dir=CONFIG.log_dir,
    config_path=os.path.join(CONFIG.code_dir, "logging_config.yaml"),
)
logger = logging.getLogger(__name__)


class DartsWorker(Worker):
    """hpbandster worker that evaluates one sampled DARTS configuration per ``compute`` call.

    Each evaluation pickles the sampled config into a per-call shared directory,
    runs ``pt.darts-master/cifar10_task.py`` in a shell subprocess, reads back
    ``result.pkl`` and mirrors the log directories to ``CONFIG.bucket_dir`` via
    ``mox.py``.
    """

    def __init__(self, id: str, sleep_interval=0.0, device="cpu", *args, **kwargs):
        """
        Args:
            id: worker id (``<node index>_<worker index>``). If it contains the
                substring ``"EQI"``, the sampled ``budget`` hyperparameter decides
                the number of epochs instead of the optimizer's budget.
            sleep_interval: seconds to sleep at the end of every ``compute`` call.
            device: value forwarded to ``cifar10_task.py --device``.
            *args, **kwargs: forwarded to ``hpbandster.core.worker.Worker``.
        """
        # Same call semantics as before: positional args first, then keywords.
        super().__init__(*args, id=id, **kwargs)
        self.sleep_interval = sleep_interval
        self.id = id  # <node index>_<worker index>
        self.device = device

        # Number of compute() calls so far; gives every evaluation its own shared dir.
        self.cnt = 0

        logger.info(f"--------- worker {self.id} at {self.device} config ----------")
        for key in vars(CONFIG):
            logger.info("%s: %s", key, getattr(CONFIG, key))
        # Maps the EQI "budget" hyperparameter (0..5) to a number of training epochs.
        self.budget_dict = {0: 1, 1: 10, 2: 20, 3: 30, 4: 40, 5: 50}

    def compute(self, config, budget, *args, **kwargs):
        """Train/evaluate one sampled architecture and report its negated accuracy.

        Args:
            config: dictionary containing the sampled configurations by the optimizer
            budget: (float) budget assigned by the optimizer; used as the epoch
                count, except for EQI workers, which use
                ``self.budget_dict[config['budget']]`` instead.

        Returns:
            dictionary with mandatory fields:
                'loss' (scalar): negated best accuracy (accuracy 0.0 if the run crashed)
                'info' (dict): log dict from result.pkl plus 'status' and 'log_dir'
        """
        self.cnt += 1
        self.shared_dir = os.path.join(CONFIG.shared_dir, self.id, str(self.cnt))
        self.log_dir = os.path.join(CONFIG.log_dir, self.id)
        os.makedirs(self.shared_dir, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)

        is_eqi = "EQI" in self.id
        logger.info("##############################################################")
        if is_eqi:
            logger.info(f"## start one work at {self.device}. id:{self.id}, {self.cnt}th fake budget {budget} real budget {config['budget']} ##")
        else:
            logger.info(f"## start one work at {self.device}. id:{self.id}, {self.cnt}th budget {budget} ##")
        logger.info("##############################################################")
        with open(os.path.join(self.shared_dir, "config.pkl"), "wb") as f:
            pickle.dump(config, f)

        epochs = self.budget_dict[config["budget"]] if is_eqi else int(budget)
        cmdline = self._build_cmdline(epochs)
        logger.info("--> cifar_task.py cmdline: %s", cmdline)
        self._run(cmdline)

        # NOTE(review): presumably gives a shared filesystem time to expose
        # result.pkl after the subprocess exits -- confirm.
        time.sleep(0.5)

        best_acc, log = self._load_result()

        logger.info("##############################################################")
        logger.info(f"## finish one work id:{self.id} shared_dir:{self.shared_dir}, acc:{best_acc}  ##")
        logger.info("##############################################################")
        self._sync_logs_to_bucket()

        time.sleep(self.sleep_interval)
        return {
            'loss': -best_acc,  # this is the a mandatory field to run hyperband
            'info': log,  # can be used for any user-defined information - also mandatory
        }

    @staticmethod
    def _quote(value):
        """Shell-quote ``value`` after converting it to ``str``."""
        return shlex.quote(str(value))

    def _build_cmdline(self, epochs):
        """Return the shell command that runs cifar10_task.py for the current evaluation."""
        q = self._quote
        task_dir = os.path.join(CONFIG.code_dir, "pt.darts-master")
        # '--gpus all' in the cmdline has no effect; it is passed for compatibility only.
        return (
            f"cd {q(task_dir)} && python cifar10_task.py"
            f" --log_dir {q(self.log_dir)} --shared_dir {q(self.shared_dir)}"
            f" --code_dir {q(CONFIG.code_dir)}"
            f" --gpus all --device {q(self.device)} --dataset {q(CONFIG.data_set)}"
            f" --epochs {int(epochs)}"
        )

    @staticmethod
    def _run(cmdline):
        """Run ``cmdline`` through the shell; log a warning on a non-zero exit status."""
        returncode = subprocess.call(cmdline, shell=True)
        if returncode != 0:
            logger.warning("command exited with status %s: %s", returncode, cmdline)
        return returncode

    def _load_result(self):
        """Return ``(best_acc, log)`` from result.pkl, or a "crashed" result if unusable."""
        result_path = os.path.join(self.shared_dir, "result.pkl")
        try:
            # result.pkl is written by our own cifar10_task.py subprocess. Never point
            # this at untrusted data: pickle.load can execute arbitrary code.
            with open(result_path, "rb") as f:
                best_acc, log = pickle.load(f)
            best_acc = float(best_acc)
            log["status"] = "finished"
            log["log_dir"] = self.log_dir
        except Exception:
            # Best effort: a failed run is reported as a crashed evaluation instead of
            # killing the worker, but the reason is logged and Ctrl-C still propagates.
            logger.exception("could not load result from %s", result_path)
            best_acc = 0.0
            log = {"status": "crashed", "log_dir": self.log_dir}
        return best_acc, log

    def _sync_logs_to_bucket(self):
        """Mirror this worker's log dir, then the root log dir, to CONFIG.bucket_dir via mox.py."""
        q = self._quote
        for source in (self.log_dir, CONFIG.log_dir):
            # The bucket target is bucket_dir and the local path concatenated as strings.
            target = CONFIG.bucket_dir + source
            self._run(f"cd {q(CONFIG.code_dir)} && python mox.py --source {q(source)} --target {q(target)}")

    @staticmethod
    def get_configspace(alg: Optional[str] = None) -> CS.ConfigurationSpace:
        """Build the search space of continuous architecture weights.

        For intermediate node k (0..3) there are 16, 24, 32 and 40 weights per cell
        type ("normal" and "reduce"), named ``<cell>.<k>.<ii>``. These counts are
        8 * (k + 2): presumably 8 candidate ops on each of the k + 2 inputs of node
        k -- confirm against cifar10_task.py.

        Args:
            alg: when ``"EQI"``, also adds an integer ``budget`` hyperparameter in
                [0, 5], which DartsWorker.budget_dict maps to an epoch count.
        """
        config_space = CS.ConfigurationSpace()
        weights_per_node = (16, 24, 32, 40)
        # Insertion order matches the original: node, then index, then normal before reduce.
        for node, n_weights in enumerate(weights_per_node):
            for i in range(n_weights):
                for cell in ("normal", "reduce"):
                    config_space.add_hyperparameter(
                        CS.UniformFloatHyperparameter(name=f"{cell}.{node}.{i:02d}", lower=0, upper=1)
                    )
        if alg == "EQI":
            config_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="budget", lower=0, upper=5))   # {1, 10, 20, 30, 40, 50}
        return config_space
