import yaml
from pathlib import Path

# https://stackoverflow.com/questions/9169025/how-can-i-add-a-python-tuple-to-a-yaml-file-using-pyyaml
class PrettySafeLoader(yaml.SafeLoader):
    """``yaml.SafeLoader`` subclass that also builds tuples.

    A constructor for the ``tag:yaml.org,2002:python/tuple`` tag is
    registered on this class (see below), so such nodes load as ``tuple``.
    """

    def construct_python_tuple(self, node):
        """Return the items of the sequence ``node`` as a ``tuple``."""
        elements = self.construct_sequence(node)
        return tuple(elements)


# Tag handled by PrettySafeLoader.construct_python_tuple.
_PYTHON_TUPLE_TAG = 'tag:yaml.org,2002:python/tuple'
PrettySafeLoader.add_constructor(_PYTHON_TUPLE_TAG, PrettySafeLoader.construct_python_tuple)


class ConfigPlain:
    """Top-level experiment configuration.

    Combines the runtime settings read from ``runtime_config_file`` with the
    benchmark settings stored at
    ``<root_path>/configs/benchmarks/<benchmark>.yaml``.

    Attributes:
        runtime: ``RuntimeConfig`` built from the parsed runtime config file.
        benchmark_config: ``BenchmarkConfig`` loaded from the benchmark file.
        config_dict: The full parsed runtime config file.
    """

    def __init__(self, benchmark, runtime_config_file, root_path, exp_name, wandb_username=None, wandb_project=None):
        """Load the runtime and benchmark configuration files.

        Args:
            benchmark: Benchmark name; selects
                ``configs/benchmarks/<benchmark>.yaml`` under ``root_path``.
            runtime_config_file: Path to the runtime YAML file.
            root_path: Project root directory.
            exp_name: Experiment name, forwarded to ``RuntimeConfig``.
            wandb_username: Forwarded to ``RuntimeConfig``.
            wandb_project: Forwarded to ``RuntimeConfig``.

        Raises:
            OSError: If a config file cannot be opened.
            yaml.YAMLError: If a config file is not valid YAML.
            ValueError: If the runtime config file does not contain a
                top-level YAML mapping (e.g. it is empty).
        """
        # Binary mode lets PyYAML detect the file encoding (UTF-8 / UTF-16
        # via BOM) instead of depending on the platform's locale default.
        with open(runtime_config_file, "rb") as f:
            config_dict = yaml.load(f, Loader=PrettySafeLoader)

        # An empty document parses to None; report it here, naming the file,
        # rather than failing with an opaque error inside RuntimeConfig.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Runtime config {runtime_config_file} must contain a YAML mapping, "
                f"got {type(config_dict).__name__}"
            )

        self.runtime = RuntimeConfig(config_dict, root_path, exp_name, wandb_username=wandb_username, wandb_project=wandb_project)

        # Benchmark-specific settings live in a per-benchmark YAML file.
        benchmark_config_path = Path(root_path, "configs", "benchmarks", f"{benchmark}.yaml")
        self.benchmark_config = BenchmarkConfig(benchmark_config_path)

        self.config_dict = config_dict

class RuntimeConfig:
    """Runtime settings plus the artifact paths derived from them.

    Every entry of ``config_dict["runtime"]`` is copied onto the instance as
    an attribute. That section must provide ``dataset``, which together with
    ``root_path`` and ``exp_name`` determines ``artifact_dir``. The other
    directory attributes are fixed sub-paths of ``artifact_dir`` or of the
    shared ``artifacts/imagenet_base`` tree. Derived attributes are assigned
    after the copied entries, so they win over same-named config keys.
    """

    def __init__(self, config_dict, root_path, exp_name, wandb_username=None, wandb_project=None):
        """Populate runtime settings, derived paths and W&B settings.

        Args:
            config_dict: Parsed config; only its ``"runtime"`` mapping is read.
            root_path: Project root; stored as a ``Path``.
            exp_name: Experiment name; last component of ``artifact_dir`` and
                the W&B ``group``.
            wandb_username: Stored as ``wandb["entity"]``.
            wandb_project: Stored as ``wandb["project"]``.

        Raises:
            KeyError: If ``config_dict`` has no ``"runtime"`` entry.
            AttributeError: If the runtime section defines no ``dataset``.
        """
        section = config_dict["runtime"]
        self.root_path = Path(root_path)

        # Runtime keys are applied after root_path, so a "root_path" key in
        # the config replaces the constructor argument (its value is used
        # as-is, without conversion to Path).
        for name, setting in section.items():
            setattr(self, name, setting)

        # Per-experiment artifact tree: <root>/artifacts/<dataset>/<exp_name>.
        self.artifact_dir = Path(self.root_path, "artifacts") / self.dataset / exp_name
        experiment = self.artifact_dir

        self.checkpoint_dir = experiment / "checkpoints"  # fine-tuned models
        self.checkpoint_with_mean_dir = experiment / "checkpoints_with_mean"  # fine-tuned models with computed means
        self.config_checkpoint_dir = experiment / "configs"
        self.mapping_dir = experiment / "mapping"

        # ImageNet base artifacts sit in a shared location that does not
        # depend on dataset or exp_name.
        imagenet_base = Path(self.root_path, "artifacts", "imagenet_base")
        self.imagenet_base_mapping = imagenet_base / "mapping" / "imagenet12.yaml"
        self.imagenet_base_config = imagenet_base / "configs" / "imagenet12.yaml"
        self.imagenet_base_checkpoint = imagenet_base / "checkpoints" / "imagenet12.pth.tar"

        self.temp_dir = experiment / "temp"
        self.log_dir = experiment / "logs"
        self.supernet_config_dir = experiment / "supernet_configs"

        # Weights & Biases run settings.
        self.wandb = dict(project=wandb_project, entity=wandb_username, group=exp_name)


class BenchmarkConfig:
    """Benchmark settings loaded from a YAML file.

    Every top-level key of the file becomes an attribute of the instance,
    with the parsed value as-is.
    """

    def __init__(self, task_config_file):
        """Load ``task_config_file`` and expose its top-level keys as attributes.

        Args:
            task_config_file: Path to the benchmark YAML file.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            ValueError: If the file does not contain a top-level YAML
                mapping (e.g. it is empty).
        """
        # Binary mode lets PyYAML detect the file encoding (UTF-8 / UTF-16
        # via BOM) instead of depending on the platform's locale default.
        with open(task_config_file, "rb") as f:
            config_dict = yaml.load(f, Loader=PrettySafeLoader)

        # An empty document parses to None, and a scalar/list has no .items();
        # report either case clearly, naming the file.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"Benchmark config {task_config_file} must contain a YAML mapping, "
                f"got {type(config_dict).__name__}"
            )

        for key, value in config_dict.items():
            setattr(self, key, value)
