from torch_geometric.graphgym.register import register_config


@register_config("overwrite_defaults")
def overwrite_defaults_cfg(cfg):
    """Override selected default values that GraphGym first sets in
    ``torch_geometric.graphgym.config.set_cfg``.

    WARNING: at the time of writing, custom config-setting functions such as
    this one are executed in a random order (see ``set_cfg``). Therefore only
    change options that exist in core GraphGym here; never reset options that
    are added by custom config functions, since those may not exist yet or
    may be set again afterwards.
    """

    # Training (and validation) pipeline mode. GraphGym's 'standard' mode
    # uses PyTorch Lightning since PyG 2.1, so select the custom pipeline.
    cfg.train.mode = "custom"

    # Default dataset name placeholder.
    cfg.dataset.name = "none"

    # Rounding precision, overriding GraphGym's default value.
    cfg.round = 5


@register_config("extended_cfg")
def extended_cfg(cfg):
    """Add general extended config options, with their default values, to
    ``cfg`` and ``cfg.train``."""

    # Extra name tag used when `run_dir` and `wandb_name` are auto-generated.
    cfg.name_tag = ""

    train_cfg = cfg.train

    # If True (and cfg.train.enable_ckpt is also True), always checkpoint the
    # current best model according to validation performance; if False,
    # checkpoint at the cfg.train.eval_period frequency instead.
    train_cfg.ckpt_best = False
    # NOTE(review): presumably the batch size for test/evaluation loaders —
    # confirm against the loader code.
    train_cfg.batch_size_test = 256
    # NOTE(review): presumably the number of gradient-accumulation steps
    # (1 meaning no accumulation) — confirm against the training loop.
    train_cfg.accum_gradient_steps = 1
    # NOTE(review): the consumer of this limit is not visible here; confirm
    # its meaning where the sampler reads it.
    train_cfg.sampler_graph_limit = 10

    # Identifiers for the current run. The meaning is inferred from the names
    # only; "none" appears to be a placeholder to be filled in later
    # (TODO confirm).
    cfg.group_id = "none"
    cfg.rank = 0
    cfg.run_dir = "none"
    cfg.run_id = 0
