from torch.optim import Adam, SGD

# Lookup table from dataset name to a list of integer sizes.
# Judging by the name, these are hidden-layer widths for the model built per
# dataset; the consumer is not visible here, so confirm against the caller.
# Each entry is its own list object, so mutating one does not affect another.
DATA_HIDDEN_DIM = dict(
    dsprites=[256, 128],
    shapes3d=[256] * 2,
    mpi3d=[256] * 2,
)

# Registry resolving a lowercase optimizer name to its torch.optim class.
# The classes themselves are stored (not instances); callers are expected to
# instantiate them with their own parameters and hyperparameters.
OPTIMIZER = dict(
    adam=Adam,
    sgd=SGD,
)
