# Scalar hyperparameters shared by the sections below.

# Number of slots; presumably up to 10 objects plus one background slot — confirm.
max_num = 1 + 10
# Spatial sizes as [h, w]. resolut2 is the resolution handed to the slot
# encoder (h2w2 / positional embedding); resolut1 appears unused in this config
# (presumably the 128 / 4 VQ latent size) — verify before removing.
resolut1 = [32] * 2
resolut2 = [128] * 2
num_code = 2**12  # VQ codebook entries (4096)
embed_dim0 = 4  # VQ latent channels; 256 was noted as an alternative
embed_dim = 2**8  # slot / attention width (256)

total_step = 50_000
val_interval = total_step // 50  # i.e. 50 validation rounds over training
batch_size_t = 32  # halving (// 2) was noted as an alternative
batch_size_v = batch_size_t
num_work = 4  # presumably data-loader worker count
lr = 2e-4  # peak learning rate (see the warmup/cosine schedule)

### datum

# Per-sample preprocessing shared by the train and val datasets. No resize is
# applied, so samples are presumably already stored at resolut2 (128x128) —
# confirm against the LMDB contents.
transforms = [
    dict(type="Filter", keys=["image", "segment"]),
    # Disabled augmentation, kept for experiments:
    # dict(
    #     type="RandomResizedCrop",
    #     keys=["image", "segment"],
    #     size=[128, 128],
    #     scale=[0.5, 1],
    #     interp=["bilinear", "nearest-exact"],
    # ),
    # Maps pixel values in [0, 255] to [-1, 1].
    dict(type="Normalize", keys=["image"], mean=127.5, std=127.5),
]

# One ClevrTex config per split. Each gets its own Compose dict, but both share
# the same `transforms` list (the val side could use e.g. transforms[0::2]).
# `base_dir=...` is a placeholder, presumably filled in by the runner.
dataset_t, dataset_v = (
    dict(
        type="ClevrTex",
        data_file="clevrtex/data.lmdb",
        split=split_name,
        transform=dict(type="Compose", transforms=transforms),
        base_dir=...,
    )
    for split_name in ("train", "val")
)

### model

model = dict(
    type="SlotDiffusionImage",
    # VQ-VAE whose encoder + codebook define the embed_dim0-channel latent the
    # diffusion UNet below operates on. It is listed in `freez`; no pixel
    # decoder is configured here (decode=None).
    mediat=dict(
        type="VQVAE",
        encode=dict(
            type="Sequential",
            modules=[
                dict(
                    type="ResNet",
                    model_name="resnet18.fb_swsl_ig1b_ft_in1k",
                    in_dim=3,
                    k0=4,
                    strides=[4, 1, 1, 1],  # only the first stage downsamples (x4)
                    gn=16,
                ),
                # 1x1 projection of the (assumed 256-channel) ResNet output to
                # the codebook dimension.
                dict(
                    type="Conv2d",
                    in_channels=256,
                    out_channels=embed_dim0,
                    kernel_size=1,
                ),
            ],
        ),
        decode=None,
        codebook=dict(type="Codebook", num_embed=num_code, embed_dim=embed_dim0),
    ),
    # Slot-encoder features. The "big" branch is a non-learned ViT (arch "v1s8",
    # presumably DINO ViT-S/8, 384-d) projected to 64 channels. The "little"
    # branch is a stride-1 CNN with 64 channels. The two are concatenated
    # (fuse="cat"), hence 64 * 2 channels downstream.
    encode_backbone=dict(
        type="BigLittle",
        big=dict(type="Dinolet", arch="v1s8", num_block=None, learn=False),
        little=dict(
            type="CNN",
            channel0=3,
            channels=[64, 64, 64, 64],
            kernels=[5, 5, 5, 5],
            strides=[1, 1, 1, 1],
        ),
        bpre=None,
        lpre=None,
        bpost=dict(
            type="Conv2d", in_channels=384, out_channels=64, kernel_size=3, padding=1
        ),
        lpost=None,
        fuse="cat",
        post=dict(
            type="Conv2d",
            in_channels=64 * 2,
            out_channels=64 * 2,
            kernel_size=5,
            padding=2,
        ),
    ),
    h2w2=resolut2,  # presumably the (h, w) of the fused feature map — confirm
    encode_posit_embed=dict(
        type="CartesianPositionalEmbedding2d", resolut=resolut2, embed_dim=64 * 2
    ),
    # Normalize and lift the 64 * 2 fused features to the slot width embed_dim.
    encode_project=dict(
        type="Sequential",
        modules=[
            dict(type="LayerNorm", normalized_shape=64 * 2),
            dict(type="MLP", channel0=64 * 2, channels=[embed_dim, embed_dim]),
        ],
    ),
    initializ=dict(type="LearntInitializ", num_slot=max_num, slot_dim=embed_dim),
    correct=dict(
        type="SlotAttention",
        num_iter=3,
        embed_dim=embed_dim,
        ffn_dim=embed_dim * 2,
        dropout=0.01,
        trunc_bp="bi-level",
    ),
    noise_sched=dict(
        type="NoiseSchedJjd",
        # beta_schedule="linear",
        # beta_start=0.00085,  # 0.0015
        # beta_end=0.012,  # 0.0195
        beta_schedule="scaled_linear",  # SlotDiffuz
        beta_start=0.0015,
        beta_end=0.0195,
        num_train_timesteps=1000,
    ),
    # Denoising UNet over the embed_dim0-channel VQ latent.
    decode_backbone=dict(
        type="UNet2dCondition",
        in_channels=embed_dim0,
        out_channels=embed_dim0,
        down_block_types=[
            "DownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
        ],
        mid_block_type="UNetMidBlock2DCrossAttn",
        up_block_types=[
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "UpBlock2D",
        ],
        block_out_channels=[128, 256, 384, 512],
        layers_per_block=2,
        # Tied to the slot width (slot_dim=embed_dim above) instead of a
        # hard-coded 256, since the slots are presumably the cross-attention
        # context. This keeps the two in sync if embed_dim changes.
        cross_attention_dim=embed_dim,
        transformer_layers_per_block=1,
        dropout=0.1,
    ),
)
model_imap = dict(input="image")  # conditioned<random
model_omap = ["zidx", "noise", "decode", "segment", "correct", "attent"]
# Parameter-name prefix pairs relating this model's `m.mediat.*` weights to the
# `m.encode.*` / `m.codebook.*` names, presumably of a separately trained VQ-VAE
# checkpoint — confirm the mapping direction in the loader.
ckpt_map = [
    ["m.mediat.encode.", "m.encode."],
    ["m.mediat.codebook.", "m.codebook."],
]
freez = ["m.mediat"]  # parameter prefixes kept frozen, presumably

### learn

param_groups = None  # presumably "all model parameters in one group"
optimiz = dict(type="Adam", params=param_groups, lr=lr)
gscale = dict(type="GradScaler")
gclip = dict(type="ClipGradNorm", max_norm=1)

# Training loss: MSE between the model outputs "decode" (prediction) and
# "noise" (target). Nothing is read from the batch.
loss_fn = dict(
    mse_d=dict(
        metric=dict(type="MSELoss"),
        map=dict(output=dict(input="decode", target="noise"), batch=dict()),
    ),
)
# Evaluation: ARI between predicted and ground-truth segments, once with
# fg=False and once with fg=True (presumably foreground-only).
metric_fn = dict(
    ari=dict(
        metric=dict(type="ARI", fg=False),
        map=dict(output=dict(input="segment"), batch=dict(target="segment")),
    ),
    ari_fg=dict(
        metric=dict(type="ARI", fg=True),
        map=dict(output=dict(input="segment"), batch=dict(target="segment")),
    ),
)

# Per-step training hooks: move the batch to the device, cosine-anneal
# model.m.initializ.sigma from 1 to 0, and drive the learning rate of param
# group 0 with a warmup over total_step // 20 steps, then cosine decay to 0.
before_step = [
    dict(type="ToDevice", keys=["batch.image", "batch.segment"]),
    dict(
        type="CosineAnnealing",
        assigns=["model.m.initializ.sigma=value"],
        base_values=[1],
        min_values=[0],
        total_step=total_step,
    ),
    dict(
        type="LinearCosineAnnealing",
        assigns=["optimiz.param_groups[0]['lr']=value"],
        base_values=[lr],
        min_values=[0],
        warmup_step=total_step // 20,
        total_step=total_step,
    ),
]
after_forward = []
callback_t = [
    dict(type="Callback", before_step=before_step, after_forward=after_forward),
    dict(type="AverageLog", log_file=...),
]
# Validation reuses only the ToDevice hook (no schedule updates) and shares
# the very same AverageLog config dict object as training.
callback_v = [
    dict(type="Callback", before_step=before_step[:1], after_forward=after_forward),
    callback_t[1],
    # Start saving from the halfway point. Integer step count, consistent with
    # the other `//`-derived step quantities (total_step * 0.5 gave a float).
    dict(type="SaveModel", save_dir=..., since_step=total_step // 2),
]

### loop

# Top-level training-loop config. Every `...` entry is a placeholder,
# presumably replaced by the runner with the built object of the same name
# configured above.
loop = dict(
    type="Loop",
    **dict.fromkeys(
        (
            "dataset_t",
            "dataset_v",
            "model",
            "optimiz",
            "loss_fn",
            "metric_fn",
            "callback_t",
            "callback_v",
        ),
        ...,
    ),
    total_step=total_step,
    val_interval=val_interval,
)
