import ml_collections

# Module-level knobs that get_config() copies into the returned config.
num_pieces, num_digits = 9, 4
epochs = 10

# dataset_size is only used here to derive num_batches.
# NOTE(review): 60000 presumably matches the MNIST training split -- confirm.
dataset_size, batch_size = 60_000, 64
# Floor division: a trailing partial batch is not counted.
num_batches = dataset_size // batch_size

def config_dict(**kwargs):
    """Build an ``ml_collections.ConfigDict`` from keyword arguments.

    The collected keywords are handed to ``ConfigDict`` as its
    ``initial_dictionary``, so each keyword becomes a field of the result.
    """
    fields = kwargs
    return ml_collections.ConfigDict(initial_dictionary=fields)

def get_config():
    """Assemble and return the experiment configuration.

    The result is a nested config built with ``config_dict``: top-level
    dataset / evaluation fields plus ``CNN``, ``transformer`` and ``train``
    sub-configs, where ``train`` itself contains a ``diffusion`` sub-config.
    ``num_pieces``, ``num_digits``, ``batch_size``, ``num_batches`` and
    ``epochs`` are taken from the module-level values of the same names.
    """
    # Stored under the "CNN" key.
    cnn_cfg = config_dict(
        in_channels=1,
        hidden_channels1=32,
        kernel_size1=5,
        stride1=1,
        padding1=2,
        hidden_channels2=64,
        kernel_size2=5,
        stride2=1,
        padding2=2,
    )

    # Stored under the "transformer" key.
    transformer_cfg = config_dict(
        embd_dim=128,
        d_hid=64,
        n_layers=2,
        nhead=8,
        dropout=0.1,
    )

    # Stored under train.diffusion.
    diffusion_cfg = config_dict(
        num_timesteps=2,
        # transition options: "insert" | "swap" | "riffle"
        transition="riffle",
        lazy=True,
        # reverse options: "original" | "PL" | "generalized_PL"
        reverse="generalized_PL",
        # original note: 0 ~ T inclusive
        reverse_steps=[],
        latent=False,
    )

    # Stored under the "train" key.
    train_cfg = config_dict(
        record_wandb=False,
        run_name="sort-MNIST_2layers_3300",
        save_model=True,
        sample_N=3,
        method="diffusion",
        diffusion=diffusion_cfg,
        batch_size=batch_size,
        num_batches=num_batches,
        epochs=epochs,
        warmup_steps=4300,
        learning_rate=1e-5,
        scheduler=True,
        ema_rate=0.999,
        reinforce_N=10,
        reinforce_ema_rate=0.995,
    )

    return config_dict(
        # dataset options: "unscramble-MNIST" | "unscramble-noisy-MNIST" | "sort-MNIST"
        dataset="unscramble-noisy-MNIST",
        num_pieces=num_pieces,
        num_digits=num_digits,
        image_size=28,
        CNN=cnn_cfg,
        transformer=transformer_cfg,
        train=train_cfg,
        use_ema=False,
        eval_only=True,
        save_wrong_images=False,
        beam_search=True,
        beam_size=5,
    )
