# Hyper-parameters shared by the sections of this config.
max_num = 10 + 1  # presumably up to 10 objects plus one extra slot — confirm
resolut1 = [32] * 2  # presumably an (H, W) resolution — TODO confirm where used
resolut2 = [128] * 2  # presumably the (H, W) input image resolution
num_code = 2**12  # codebook size (number of code vectors)
embed_dim0 = 4  # codebook embedding dim; 256 was noted as an alternative

# Schedule: validate 50 times over the whole run (every 600 steps).
total_step = 30_000
val_interval = total_step // 50
batch_size_t = 64
batch_size_v = batch_size_t  # validation reuses the training batch size
num_work = 4  # presumably the number of data-loader workers
lr = 2e-3  # base learning rate

### datum

# Image pipeline (images are presumably 128x128): Filter on "image", then
# Normalize with mean = std = 127.5, which presumably maps [0, 255] pixel
# values to [-1, 1].
# Augmentation is intentionally omitted: RandomResizedCrop (size [128, 128],
# scale [0.5, 1], bilinear interpolation) and RandomFlip were both found to
# hurt VAE training.
transforms = [
    dict(type="Filter", keys=["image"]),
    dict(type="Normalize", keys=["image"], mean=127.5, std=127.5),
]

# Train / val ClevrTex datasets differ only in `split`. Each gets its own
# Compose wrapper around the shared `transforms` list. `base_dir=...` is a
# placeholder, presumably filled in by the config loader at runtime.
dataset_t, dataset_v = (
    dict(
        type="ClevrTex",
        data_file="clevrtex/data.lmdb",
        split=split_name,
        transform=dict(type="Compose", transforms=transforms),
        base_dir=...,
    )
    for split_name in ("train", "val")
)

### model

# VQ-VAE.
# - encode: a ResNet backbone followed by a 1x1 conv projecting its features
#   to `embed_dim0` channels.
# - codebook: `num_code` vectors of size `embed_dim0`.
# - decode: a CNN mapping `embed_dim0` channels back to 3 image channels.
model = dict(
    type="VQVAE",
    encode=dict(
        type="Sequential",
        modules=[
            dict(
                type="ResNet",
                model_name="resnet18.fb_swsl_ig1b_ft_in1k",
                in_dim=3,
                k0=4,
                strides=[4, 1, 1, 1],
                gn=16,
            ),
            # NOTE(review): assumes the ResNet wrapper above outputs 256
            # channels with these settings — confirm against its builder.
            dict(
                type="Conv2d",
                in_channels=256,
                out_channels=embed_dim0,
                kernel_size=1,
            ),
        ],
    ),
    decode=dict(
        type="CNN",
        channel0=embed_dim0,
        channels=[64] * 7 + [3],  # seven 64-channel layers, then 3 channels
        kernels=[1, 3, 3, 1, 3, 3, 1, 1],
        strides=[1] * 8,
        # Per-layer type codes are interpreted by the CNN builder; presumably
        # code 2 is an upsampling layer — TODO confirm.
        ctypes=[0, 0, 0, 2, 0, 0, 2, 0],
        gn=1,
    ),
    codebook=dict(type="Codebook", num_embed=num_code, embed_dim=embed_dim0),
)
# Presumably maps the model's `input` argument to the batch's "image" entry.
model_imap = dict(input="image")
# Names of the model outputs; loss_fn refers to "encode", "quant" and "decode".
model_omap = ["encode", "zidx", "quant", "decode"]
ckpt_map = None  # presumably: no checkpoint weights loaded / remapped
freez = None  # presumably: no parameters frozen

### learn

# None: no explicit parameter groups (presumably all model parameters) — confirm.
param_groups = None
optimiz = dict(type="Adam", params=param_groups, lr=lr)
gscale = dict(type="GradScaler")  # presumably mixed-precision loss scaling
gclip = dict(type="ClipGradNorm", max_norm=1)  # gradient-norm clipping at 1

# Training objective. Each entry pairs a `metric` with a `map` that selects its
# `input` / `target` from the model outputs and the batch. The `Detach`
# transform presumably stops gradients flowing through `target`.
loss_fn = {
    # Reconstruction: decoder output vs. the input image.
    "recon_d": dict(
        metric=dict(type="MSELoss"),
        map=dict(output=dict(input="decode"), batch=dict(target="image")),
    ),
    # Codebook term: quantized vectors vs. the (detached) encodings.
    "align": dict(
        metric=dict(type="MSELoss"),
        map=dict(output=dict(input="quant", target="encode"), batch={}),
        transform=dict(type="Detach", keys=["target"]),
    ),
    # Commitment term: encodings vs. the (detached) quantized vectors,
    # down-weighted by 0.25.
    "commit": dict(
        metric=dict(type="MSELoss"),
        map=dict(output=dict(input="encode", target="quant"), batch={}),
        transform=dict(type="Detach", keys=["target"]),
        weight=0.25,
    ),
    # Perceptual term: LPIPS between decoder output and the input image.
    "lpips": dict(
        metric=dict(type="LPIPSLoss"),
        map=dict(output=dict(input="decode"), batch=dict(target="image")),
    ),
}
metric_fn = {}  # no extra evaluation metrics

# Hooks run before every training step:
# 1. move the batch's image to the device;
# 2. set param group 0's learning rate via LinearCosineAnnealing, from base
#    `lr` down to 0, with a warmup of total_step // 20 steps.
#    (The linear-warmup-then-cosine shape is presumed from the name.)
before_step = [
    dict(type="ToDevice", keys=["batch.image"]),
    dict(
        type="LinearCosineAnnealing",
        assigns=["optimiz.param_groups[0]['lr']=value"],
        base_values=[lr],
        min_values=[0],
        warmup_step=total_step // 20,
        total_step=total_step,
    ),
]
# NOTE: no weight-EMA callback is used. The original author noted
# "EMA is bad ???"; that it hurts results is suspected but unverified.
callback_t = [
    dict(type="Callback", before_step=before_step),
    dict(type="AverageLog", log_file=...),
]
# Validation keeps only the ToDevice hook (no LR updates during eval) and
# reuses the very same AverageLog config object as training.
callback_v = [
    dict(type="Callback", before_step=before_step[:1]),
    callback_t[1],
    # Integer step count, consistent with the other `total_step // k`
    # quantities (was the float `total_step * 0.5`).
    dict(type="SaveModel", save_dir=..., since_step=total_step // 2),
]

### loop

# Top-level training loop. The components set to `...` are placeholders,
# presumably resolved by the config loader (e.g. to the same-named entries
# defined above) — TODO confirm.
loop = dict(
    type="Loop",
    **dict.fromkeys(
        (
            "dataset_t",
            "dataset_v",
            "model",
            "optimiz",
            "loss_fn",
            "metric_fn",
            "callback_t",
            "callback_v",
        ),
        ...,
    ),
    total_step=total_step,
    val_interval=val_interval,
)
