# --- shared hyperparameters ---
# Slot count; presumably 10 objects plus one background slot — confirm.
max_num = 1 + 10
resolut1 = [32] * 2  # (h1, w1): token grid used by h1w1 and the decoder positions
resolut2 = [128] * 2  # (h2, w2): grid used by h2w2 and the encoder positions
num_code = 4096  # number of VQ codebook entries / decoder readout classes
embed_dim = 256  # width shared by codebook, slots and decoder

total_step = 50_000
val_interval = total_step // 50  # i.e. 50 validation rounds per run
batch_size_t = 32  # training batch size
batch_size_v = batch_size_t  # validation reuses the training batch size
num_work = 4  # presumably data-loader worker processes
lr = 2e-4  # peak learning rate (annealed in before_step)

### datum

# Per-sample pipeline: keep only the image/segment keys, then normalize the
# image as (x - 127.5) / 127.5 (maps [0, 255] to [-1, 1] if pixels are 0..255).
transforms = [
    {"type": "Filter", "keys": ["image", "segment"]},
    # Optional crop augmentation to (128, 128), currently disabled:
    # dict(
    #     type="RandomResizedCrop",
    #     keys=["image", "segment"],
    #     size=[128, 128],
    #     scale=[0.5, 1],
    #     interp=["bilinear", "nearest-exact"],
    # ),
    {"type": "Normalize", "keys": ["image"], "mean": 127.5, "std": 127.5},
]
dataset_t = {
    "type": "ClevrTex",
    "data_file": "clevrtex/data.lmdb",
    "split": "train",
    "transform": {"type": "Compose", "transforms": transforms},
    "base_dir": ...,
}
# Validation set: same source and transform list as training; only `split`
# differs. It gets its own Compose dict rather than sharing the training one.
# NOTE(review): do not slice the list as `transforms[0::2]` for validation —
# with the crop commented out, that slice keeps only Filter and drops Normalize.
dataset_v = {
    **dataset_t,
    "split": "val",
    "transform": {"type": "Compose", "transforms": transforms},
}

### model

# SLATE-style object-centric model. Sub-modules as configured here:
#   mediat    - VQ-VAE with encoder + codebook only (decode=None); its weights are
#               loaded via `ckpt_map` and presumably frozen via `freez` below.
#   encode_*  - image feature backbone -> positional embedding -> projection to
#               `embed_dim`.
#   initializ - learnt slot initialization with `max_num` slots.
#   correct   - SlotAttention over the projected features.
#   decode_*  - transformer decoder over resolut1[0] * resolut1[1] positions with
#               a linear readout to `num_code` classes.
# NOTE(review): how SLATE wires these together lives in the model code, not here.
model = dict(
    type="SLATE",
    mediat=dict(
        type="VQVAE",
        encode=dict(
            type="Sequential",
            modules=[
                # Presumably downsamples 4x overall (128 -> 32 = resolut1) — confirm.
                dict(
                    type="ResNet",
                    model_name="resnet18.fb_swsl_ig1b_ft_in1k",
                    in_dim=3,
                    k0=4,
                    strides=[4, 1, 1, 1],
                    gn=16,
                ),
                # 1x1 projection to the codebook width; assumes the ResNet above
                # emits 256 channels — confirm.
                dict(
                    type="Conv2d",
                    in_channels=256,
                    out_channels=embed_dim,
                    kernel_size=1,
                ),
            ],
        ),
        decode=None,  # no VQ-VAE decoder is built in this config
        codebook=dict(type="Codebook", num_embed=num_code, embed_dim=embed_dim),
    ),
    h1w1=resolut1,
    # Two-branch encoder: "big" pretrained features + "little" CNN, fused by concat.
    encode_backbone=dict(
        type="BigLittle",
        # Presumably DINO ViT-S/8 (384-dim tokens, matching bpost.in_channels);
        # learn=False presumably keeps it frozen — confirm.
        big=dict(type="Dinolet", arch="v1s8", num_block=None, learn=False),
        # Four 5x5 conv layers at stride 1, 64 channels each.
        little=dict(
            type="CNN",
            channel0=3,
            channels=[64, 64, 64, 64],
            kernels=[5, 5, 5, 5],
            strides=[1, 1, 1, 1],
        ),
        bpre=None,
        lpre=None,
        # Reduce the big branch from 384 to 64 channels.
        bpost=dict(
            type="Conv2d", in_channels=384, out_channels=64, kernel_size=3, padding=1
        ),
        lpost=None,
        fuse="cat",  # 64 (big) + 64 (little) channels -> 64 * 2
        post=dict(
            type="Conv2d",
            in_channels=64 * 2,
            out_channels=64 * 2,
            kernel_size=5,
            padding=2,
        ),
    ),
    h2w2=resolut2,
    # Positional embedding over the resolut2 grid at the fused width (64 * 2).
    encode_posit_embed=dict(
        type="CartesianPositionalEmbedding2d", resolut=resolut2, embed_dim=64 * 2
    ),
    # LayerNorm + MLP: 64 * 2 -> embed_dim -> embed_dim.
    encode_project=dict(
        type="Sequential",
        modules=[
            dict(type="LayerNorm", normalized_shape=64 * 2),
            dict(type="MLP", channel0=64 * 2, channels=[embed_dim, embed_dim]),
        ],
    ),
    # `sigma` of this module is annealed 1 -> 0 by the CosineAnnealing entry
    # in before_step.
    initializ=dict(type="LearntInitializ", num_slot=max_num, slot_dim=embed_dim),
    correct=dict(
        type="SlotAttention",
        num_iter=3,
        embed_dim=embed_dim,
        ffn_dim=embed_dim * 2,
        dropout=0.01,
        # Presumably a truncated/bi-level backprop scheme through the
        # iterations — confirm against the SlotAttention implementation.
        trunc_bp="bi-level",
    ),
    # Learnt [1, 1, embed_dim] vector initialized from randn; presumably a
    # beginning-of-sequence token for the decoder.
    decode_bos=dict(type="Parameter", func="randn", size=[1, 1, embed_dim]),
    # One learnt position per cell of the resolut1 grid (32 * 32 = 1024).
    decode_posit_embed=dict(
        type="LearntPositionalEmbedding1d",
        length=resolut1[0] * resolut1[1],
        embed_dim=embed_dim,
        dropout=0.1,
    ),
    decode_backbone=dict(
        type="TransformDecodeOCL",
        embed_dim=embed_dim,
        num_head=4,
        ffn_dim=embed_dim * 4,
        dropout=0.1,
        num_layer=4,
    ),
    # Scores over the num_code codebook entries; presumably the "prob" output
    # that loss_fn compares against "zidx".
    decode_readout=dict(
        type="Linear", in_features=embed_dim, out_features=num_code, bias=False
    ),
)
# Model I/O wiring: batch "image" is fed as the model's `input`; the listed
# outputs are exposed by name (prob/zidx/segment feed loss_fn and metric_fn).
model_imap = {"input": "image"}
model_omap = "zidx prob segment correct attent".split()
# Checkpoint key-prefix pairs for the VQ-VAE parts (encode, codebook); which
# side is the checkpoint and which the model is defined by the loader — confirm.
ckpt_map = [
    [f"m.mediat.{part}.", f"m.{part}."] for part in ("encode", "codebook")
]
# Presumably freezes every parameter under the VQ-VAE prefix.
freez = ["m.mediat"]

### learn

# Optimizer setup. `params=None` presumably lets the trainer supply the
# model's trainable parameters — confirm against the loop implementation.
param_groups = None
optimiz = {"type": "Adam", "params": param_groups, "lr": lr}
gscale = {"type": "GradScaler"}  # presumably mixed-precision grad scaling
gclip = {"type": "ClipGradNorm", "max_norm": 1}  # gradient-norm clip at 1

# Training loss: cross-entropy between the model's "prob" output and the
# "zidx" output (both taken from model outputs; nothing from the batch).
loss_fn = {
    "ce_p": {
        "metric": {"type": "CrossEntropyLoss"},
        "map": {"output": {"input": "prob", "target": "zidx"}, "batch": {}},
    },
}
# Evaluation: ARI of predicted vs. ground-truth segments, once over all pixels
# (fg=False) and once foreground-only (fg=True). Each entry gets fresh dicts.
metric_fn = {
    name: {
        "metric": {"type": "ARI", "fg": fg},
        "map": {"output": {"input": "segment"}, "batch": {"target": "segment"}},
    }
    for name, fg in (("ari", False), ("ari_fg", True))
}

# Hooks run before each training step. Validation reuses only the first
# entry (ToDevice), so the schedules below advance during training only.
before_step = [
    {"type": "ToDevice", "keys": ["batch.image", "batch.segment"]},
    # Cosine-anneal the slot initializer's sigma from 1 down to 0.
    {
        "type": "CosineAnnealing",
        "assigns": ["model.m.initializ.sigma=value"],
        "base_values": [1],
        "min_values": [0],
        "total_step": total_step,
    },
    # Learning-rate schedule: presumably linear warm-up over the first 5% of
    # steps, then cosine decay from `lr` to 0.
    {
        "type": "LinearCosineAnnealing",
        "assigns": ["optimiz.param_groups[0]['lr']=value"],
        "base_values": [lr],
        "min_values": [0],
        "warmup_step": total_step // 20,
        "total_step": total_step,
    },
]
after_forward = []
# Training callbacks: per-step hooks plus running-average logging.
callback_t = [
    dict(type="Callback", before_step=before_step, after_forward=after_forward),
    dict(type="AverageLog", log_file=...),
]
# Validation callbacks: only the ToDevice hook (no schedule updates), the same
# AverageLog object as training, and model saving from mid-training onward.
callback_v = [
    dict(type="Callback", before_step=before_step[:1], after_forward=after_forward),
    callback_t[1],
    # Floor division keeps since_step an int, like every other step count in
    # this config (`total_step * 0.5` produced a float).
    dict(type="SaveModel", save_dir=..., since_step=total_step // 2),
]

### loop

# Top-level training loop. Fields set to `...` are placeholders, presumably
# filled from the same-named objects in this config by the loader — confirm.
loop = {
    "type": "Loop",
    "dataset_t": ...,
    "dataset_v": ...,
    "model": ...,
    "optimiz": ...,
    "loss_fn": ...,
    "metric_fn": ...,
    "callback_t": ...,
    "callback_v": ...,
    "total_step": total_step,
    "val_interval": val_interval,
}
