from ml_collections import ConfigDict, config_dict
import torch


def add_training_configs(config):
    """Attach the ``training``, ``loss`` and ``optim`` sub-configs to *config*.

    The parent ConfigDict is mutated in place; nothing is returned.
    """
    # --- training loop -------------------------------------------------------
    train_cfg = ConfigDict()
    train_cfg.batch_size = 1000
    train_cfg.n_iters = 100_000
    # Optional: when set, a snapshot of the model is saved every
    # `save_interval` iterations.
    train_cfg.save_interval = config_dict.placeholder(int)
    train_cfg.log_interval = 1000
    train_cfg.eval_interval = 1000
    # NOTE(review): 8196 is not a power of two; possibly meant 8192 — confirm.
    train_cfg.n_eval_samples = 8196
    # Optional: when set, the gradient norm is clipped to this value.
    train_cfg.clip_grad_norm = config_dict.placeholder(float)
    # Optional: when set, the model is loaded from this checkpoint.
    train_cfg.init_checkpoint = config_dict.placeholder(str)
    config.training = train_cfg

    # --- loss ----------------------------------------------------------------
    loss_cfg = ConfigDict()
    loss_cfg.name = "edm"  # one of: "edm", "vp", "ve"
    config.loss = loss_cfg

    # --- optimization --------------------------------------------------------
    optim_cfg = ConfigDict()
    optim_cfg.optimizer = "Adam"
    optim_cfg.lr = 3e-4
    config.optim = optim_cfg


def add_sampling_configs(config):
    """Attach the ``sampling`` sub-config to *config* (mutated in place)."""
    sampler_cfg = ConfigDict()
    sampler_cfg.steps = 100
    # NOTE(review): int-typed default; a type-safe ConfigDict will reject a
    # non-integer float override of this field — consider 10.0 if fractional
    # churn values are ever needed.
    sampler_cfg.S_churn = 10
    config.sampling = sampler_cfg


def add_data_configs(config):
    """Attach the ``data`` sub-config to *config* (mutated in place)."""
    data_cfg = ConfigDict()
    data_cfg.dataset = "checkerboard"
    data_cfg.train_set_size = 1000
    # Minimum distance kept between the samples and the boundary.
    data_cfg.slack = 0.0
    # Optional auxiliary datasets: typed placeholders, unset (None) until
    # overridden. NOTE(review): how the negative / positive / fake sets are
    # consumed is defined elsewhere — confirm against the training code.
    data_cfg.neg_dataset = config_dict.placeholder(str)
    data_cfg.neg_dataset_size = config_dict.placeholder(int)
    data_cfg.pos_dataset = config_dict.placeholder(str)
    data_cfg.fake_dataset_size = config_dict.placeholder(int)
    config.data = data_cfg


def add_model_configs(config):
    """Attach the ``model`` sub-config to *config* (mutated in place).

    Args:
        config: The parent ConfigDict; a ``model`` field is added to it.
    """
    config.model = model = ConfigDict()
    # If nonzero, use group norm in the model (only works for the residual network).
    model.group_norm = 0
    # NOTE(review): name of the preconditioning scheme; the accepted values are
    # defined by the model code, not here.
    model.precond = "edm_simple"
    model.hidden_layers = 4
    model.sigma_min = 0.002
    # Float default on purpose: a type-safe ConfigDict fixes each field's type
    # from its default and refuses float overrides of an int field, so with the
    # old int `80` an override like `sigma_max=40.5` would raise. This also
    # matches the float type of sigma_min and bridge_sigma_max.
    model.sigma_max = 80.0
    # Presumably the input dimensionality — confirm against the model code.
    model.dim = 2
    # Presumably the hidden-layer width — confirm against the model code.
    model.h_dim = 256
    model.act = "silu"  # Choices: ["relu", "leaky_relu", "silu"]
    # If given, will use the given classifier to guide the diffusion process.
    model.classifier = config_dict.placeholder(str)
    # If given, will use the given bridge model to guide the diffusion process.
    model.bridge = config_dict.placeholder(str)
    # If given, the bridge update is added only for sigmas smaller than this value.
    model.bridge_sigma_max = config_dict.placeholder(float)
    # If given, will use the given schedule to scale the bridge update.
    model.bridge_scale_schedule = config_dict.placeholder(str)


def add_distill_configs(config):
    """Attach the ``distill`` sub-config to *config* (mutated in place).

    Holds the settings for distilling from a teacher model; all path/size
    fields are typed placeholders that stay unset (None) until overridden.
    """
    distill_cfg = ConfigDict()
    # Path to the teacher model's checkpoint.
    distill_cfg.checkpoint = config_dict.placeholder(str)
    # Comma-separated list of paths to classifier checkpoints used by the teacher model.
    distill_cfg.classifier = config_dict.placeholder(str)
    # Dataset to distill on.
    distill_cfg.dataset = config_dict.placeholder(str)
    # Size of the distillation dataset.
    distill_cfg.train_set_size = config_dict.placeholder(int)
    # Probability of mixing in the ground-truth training data.
    distill_cfg.mix_gt = 0.0
    config.distill = distill_cfg


def get_default_configs(config_names=("model", "sampling", "data", "training")):
    """Build the default experiment ConfigDict.

    Args:
        config_names: Names of the config groups to include. Each name ``x`` is
            dispatched to the module-level function ``add_x_configs``, which
            attaches its sub-config(s) to the result. Any iterable of strings is
            accepted; a single string is treated as one group name. The default
            is an immutable tuple so it can never be shared-and-mutated across
            calls.

    Returns:
        A ConfigDict holding ``seed`` (an unset int placeholder), ``device``
        (``cuda:0`` when CUDA is available, otherwise ``cpu``) and every
        requested group.

    Raises:
        KeyError: If a name has no matching ``add_<name>_configs`` function.
    """
    if isinstance(config_names, str):
        # Iterating a bare string would dispatch on its individual characters.
        config_names = (config_names,)

    config = ConfigDict()
    config.seed = config_dict.placeholder(int)
    config.device = (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    )

    module_scope = globals()
    prefix, suffix = "add_", "_configs"
    for name in config_names:
        adder = module_scope.get(f"{prefix}{name}{suffix}")
        if not callable(adder):
            available = sorted(
                key[len(prefix):-len(suffix)]
                for key, value in module_scope.items()
                if key.startswith(prefix) and key.endswith(suffix) and callable(value)
            )
            raise KeyError(
                f"Unknown config group {name!r}; available groups: {available}"
            )
        adder(config)

    return config
