import yaml
from pathlib import Path

# https://stackoverflow.com/questions/9169025/how-can-i-add-a-python-tuple-to-a-yaml-file-using-pyyaml
class PrettySafeLoader(yaml.SafeLoader):
    """``yaml.SafeLoader`` variant that additionally accepts ``!!python/tuple``.

    Only the tuple tag is added on top of ``SafeLoader``; every other tag is
    handled exactly as ``SafeLoader`` handles it.
    """

    def construct_python_tuple(self, node):
        """Construct a Python tuple from a YAML sequence node."""
        items = self.construct_sequence(node)
        return tuple(items)


# Register the tuple constructor on this subclass (PyYAML copies the
# constructor table per class, so the stock SafeLoader is left untouched).
PrettySafeLoader.add_constructor(
    "tag:yaml.org,2002:python/tuple",
    PrettySafeLoader.construct_python_tuple,
)

class NaiveConfig:
    """Load a YAML config file and split it into runtime and task sections.

    Args:
        config_file_path: Path to the YAML configuration file.
        root_path: Optional override for ``runtime.root_path``; forwarded to
            both ``RuntimeConfig`` and ``TaskRuntime``.

    Attributes:
        runtime: ``RuntimeConfig`` built from the ``runtime`` section.
        task_runtime: ``TaskRuntime`` built from the ``task_runtime`` section.
        config_dict: The raw parsed mapping.

    Raises:
        ValueError: If the file does not contain a YAML mapping at top level
            (e.g. the file is empty).
    """

    def __init__(self, config_file_path, root_path=None):
        # Explicit encoding: without it the platform locale decides how the
        # file is decoded, which differs between machines.
        with open(config_file_path, "r", encoding="utf-8") as f:
            # PrettySafeLoader is a SafeLoader subclass that only adds the
            # python/tuple tag, so arbitrary objects are not constructed.
            config_dict = yaml.load(f, Loader=PrettySafeLoader)

        # An empty file loads as None; fail here with a clear message rather
        # than with an opaque TypeError deep inside RuntimeConfig.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"{config_file_path}: expected a YAML mapping at top level, "
                f"got {type(config_dict).__name__}"
            )

        self.runtime = RuntimeConfig(config_dict, root_path=root_path)
        self.task_runtime = TaskRuntime(config_dict, root_path=root_path)
        self.config_dict = config_dict

class RuntimeConfig:
    """Runtime settings read from the ``runtime`` section of a config dict.

    Every key of ``config_dict["runtime"]`` becomes an attribute, except:

    * ``root_path`` -- stored as a ``Path``; the ``root_path`` argument, when
      given, overrides the configured value.
    * ``artifact_dir`` -- not stored under its own name; it is joined onto
      ``root_path`` and stored as ``base_dir``.

    The derived paths ``checkpoint_dir``, ``imagenet_checkpoint``,
    ``weight_checkpoint_dir`` and ``log_dir`` are placed under ``base_dir``.

    Args:
        config_dict: Parsed configuration; must contain a ``runtime`` mapping.
        root_path: Optional override for ``runtime.root_path``.

    Raises:
        KeyError: If ``runtime`` is missing, or if ``root_path`` is None and
            ``runtime.root_path`` is missing.
    """

    def __init__(self, config_dict, root_path=None):
        runtime_params = config_dict["runtime"]

        # Resolve the root *before* walking the keys: artifact_dir is joined
        # onto it, so the result must not depend on the key order in the YAML
        # file. This also lets the root_path argument work when the config has
        # no runtime.root_path entry.
        if root_path is None:
            root_path = runtime_params["root_path"]
        self.root_path = Path(root_path)

        for key, value in runtime_params.items():
            if key == "root_path":
                continue  # already resolved above
            if key == "artifact_dir":
                self.base_dir = Path(self.root_path, value)
            else:
                setattr(self, key, value)

        self.checkpoint_dir = Path(self.base_dir, "checkpoints")
        self.imagenet_checkpoint = Path(self.checkpoint_dir, "imagenet12.pth.tar")
        self.weight_checkpoint_dir = Path(self.base_dir, "weights")

        self.log_dir = Path(self.base_dir, "logs")


class TaskRuntime:
    """Task-level settings read from the ``task_runtime`` section of a config.

    Every key of ``config_dict["task_runtime"]`` becomes an attribute. The
    keys ``asset_root``, ``vdd_root`` and ``dataset_root`` are joined onto
    ``root_path`` and stored as ``Path`` objects.

    The section must provide ``task_order`` (sequence of task names) and
    ``task_config`` (mapping task name -> dict with a ``num_classes`` entry);
    from these the following are derived:

    * ``num_tasks`` -- ``len(task_order)``.
    * ``total_classes`` -- sum of ``num_classes`` over ``task_order``.
    * ``offsets`` -- ``offsets[i]`` is the sum of ``num_classes`` of the
      tasks before ``task_order[i]``.

    Args:
        config_dict: Parsed configuration with ``task_runtime`` (and, when
            ``root_path`` is None, ``runtime.root_path``).
        root_path: Optional override for ``runtime.root_path``.
    """

    def __init__(self, config_dict, root_path=None):
        if root_path is None:
            root_path = config_dict["runtime"]["root_path"]
        self.root_path = Path(root_path)

        # Keys whose values are paths relative to root_path.
        rooted_keys = ("asset_root", "vdd_root", "dataset_root")

        for key, value in config_dict["task_runtime"].items():
            if key in rooted_keys:
                setattr(self, key, Path(self.root_path, value))
            else:
                setattr(self, key, value)

        # Derived attributes. One pass over task_order yields both the
        # per-task offsets and, as the final running total, total_classes.
        self.num_tasks = len(self.task_order)
        self.offsets = []
        offset = 0
        for task in self.task_order:
            self.offsets.append(offset)
            offset += self.task_config[task]["num_classes"]
        self.total_classes = offset
