import argparse
import os

import yaml

# Names exported by ``from <module> import *``; the other functions defined
# in this file (``setdefault``, ``dict2namespace``, ``_get_raw_config``) are
# not part of that list.
__all__ = ["get_config", "print_config"]


def get_config(args):
    """Build the run configuration from YAML files and CLI overrides.

    The YAML file named by ``args.config`` is loaded and every key it lacks
    is filled in from ``default.yml``. The merged dict is converted to nested
    ``argparse.Namespace`` objects. Unless ``args.script_type`` equals
    ``"finetuning"``, selected command-line values then override the values
    that came from the files.

    Args:
        args: Parsed command-line arguments. ``config`` and ``script_type``
            are always read. For non-finetuning scripts the function also
            reads ``consistent``, ``step_lr``, ``nsigma``, ``batch_size``,
            ``model_types``, ``ODI_steps``, ``fid_num_samples``,
            ``begin_ckpt``, ``end_ckpt``, ``adam``, ``adam_beta``,
            ``D_adam``, ``D_adam_beta`` and ``D_steps``.

    Returns:
        The merged configuration as nested ``argparse.Namespace`` objects.

    Note:
        For non-finetuning scripts ``args.ODI_steps`` is rewritten in place
        from ``-1`` to ``None``.
    """
    config = dict2namespace(
        setdefault(_get_raw_config(args.config), _get_raw_config("default.yml"))
    )

    # Fill values the YAML files may omit from the related sections.
    if not hasattr(config.sampling, "sigma_dist"):
        config.sampling.sigma_dist = config.model.sigma_dist
    if not hasattr(config.biggan, "resolution"):
        config.biggan.resolution = config.data.image_size

    # Command-line overrides are skipped entirely for fine-tuning scripts.
    if args.script_type == "finetuning":
        return config

    if args.consistent:
        config.sampling.consistent = args.consistent
        config.sampling.noise_first = False

    # For the numeric options below, 0 leaves the file value untouched.
    if args.nsigma != 0:
        config.sampling.nsigma = args.nsigma
    if args.step_lr != 0:
        config.sampling.step_lr = args.step_lr
    if args.batch_size != 0:
        config.sampling.batch_size = args.batch_size
        config.fast_fid.batch_size = args.batch_size

    # TODO: experimental and is only using model_types
    # Cap the sampling batch size when exactly one of the listed model types
    # is selected on one of the listed datasets.
    if args.model_types is not None and len(args.model_types) == 1:
        model_type = args.model_types[0]
        if model_type in [0, 6, 23]:
            batch_cap = 200
        elif model_type in [8]:
            batch_cap = 800
        else:
            batch_cap = None
        if batch_cap is not None and config.data.dataset in [
            "tinyImages",
            "CIFAR10",
        ]:
            config.sampling.batch_size = min(batch_cap, config.sampling.batch_size)

    # -1 is the command-line spelling of "unset".
    if args.ODI_steps == -1:
        args.ODI_steps = None

    if args.fid_num_samples != 0:
        config.fast_fid.num_samples = args.fid_num_samples
    if args.begin_ckpt != 0:
        config.fast_fid.begin_ckpt = args.begin_ckpt
        config.sampling.ckpt_id = args.begin_ckpt
    if args.end_ckpt != 0:
        # The end checkpoint must come from ``end_ckpt`` itself, not from
        # ``begin_ckpt``.
        config.fast_fid.end_ckpt = args.end_ckpt

    if args.adam:
        config.optim.beta1 = args.adam_beta[0]
        config.optim.beta2 = args.adam_beta[1]
    if args.D_adam:
        config.optim.adv_beta1 = args.D_adam_beta[0]
        config.optim.adv_beta2 = args.D_adam_beta[1]
    if args.D_steps != 0:
        config.adversarial.D_steps = args.D_steps

    return config


def _get_raw_config(name):
    """Load a YAML config file by name and return its parsed content.

    ``name`` is resolved against ``../blended_diffusion/configs`` relative to
    the directory that contains this module. The file is parsed with
    ``yaml.FullLoader``.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(module_dir, "../blended_diffusion/configs", name)
    with open(config_path, "r") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


def setdefault(config, default):
    """Recursively fill missing keys of ``config`` from ``default``.

    Values already present in ``config`` always win. When both sides hold a
    dict under the same key, the merge descends into it; if ``config`` holds
    a non-dict value where ``default`` has a dict, the ``config`` value is
    kept unchanged.

    Args:
        config: Dict to complete. It is modified in place.
        default: Dict providing fallback values. Nested dicts that are
            missing from ``config`` are inserted by reference, not copied.

    Returns:
        The same ``config`` object, after the defaults were merged in.
    """
    for key, fallback in default.items():
        current = config.get(key)
        if isinstance(fallback, dict) and isinstance(current, dict):
            # Both sides are sections: merge key by key.
            setdefault(current, fallback)
        else:
            # dict.setdefault keeps an existing entry and only adds a
            # missing one, so explicit overrides are never replaced.
            config.setdefault(key, fallback)
    return config


def dict2namespace(config):
    """Convert a (possibly nested) dict into ``argparse.Namespace`` objects.

    Each key becomes an attribute of the returned namespace. Values that are
    dicts are converted recursively; every other value is stored as is.
    """
    ns = argparse.Namespace()
    for attr, raw in config.items():
        converted = dict2namespace(raw) if isinstance(raw, dict) else raw
        setattr(ns, attr, converted)
    return ns


def print_config(config):
    """Print ``config`` as block-style YAML between two 80-character rules.

    The opening rule is made of ``>`` characters and the closing rule of
    ``<`` characters.
    """
    rule_width = 80
    print(">" * rule_width)
    rendered = yaml.dump(config, default_flow_style=False)
    print(rendered)
    print("<" * rule_width)
