### hyperparameters

max_num = 1 + 10  # NOTE(review): presumably 10 objects + 1 background slot — confirm
resolut1 = [32, 32]  # NOTE(review): presumably the latent grid (128 / encoder stride 4) — confirm
resolut2 = [128, 128]  # presumably the input image size; matches the "(128,128)" transforms note
num_code = 4096  # each Codebook below gets int(num_code ** 0.5) == 64 entries
embed_dim0 = 4  # latent channel width (an earlier note listed 256 — presumably a tried alternative)
num_scale = 3  # number of scales; the model emits one decode{i} per scale

total_step = 30_000
val_interval = total_step // 50  # i.e. 50 validation rounds over the whole run
batch_size_t = 64  # training batch size
batch_size_v = batch_size_t  # validation uses the same batch size as training
num_work = 4  # NOTE(review): presumably the DataLoader worker count — confirm
lr = 0.002  # peak learning rate; the LR schedule warms up to this value

### datum

# Preprocessing for the 128x128 ClevrTex images.
# NOTE: RandomResizedCrop (size=[128, 128], scale=[0.5, 1], bilinear interp) and
# RandomFlip were both found to hurt VAE training, so no augmentation is applied.
transforms = [
    dict(type="Filter", keys=["image"]),
    # Assuming Normalize computes (x - mean) / std on 0..255 pixels, this maps
    # images to [-1, 1] — confirm against the Normalize implementation.
    dict(type="Normalize", keys=["image"], mean=127.5, std=127.5),
]

# Train and val datasets differ only in `split`. Each one gets its own Compose
# dict, and both Compose dicts wrap the same `transforms` list.
dataset_t, dataset_v = [
    dict(
        type="ClevrTex",
        data_file="clevrtex/data.lmdb",
        split=split_name,
        transform=dict(type="Compose", transforms=transforms),
        base_dir=...,  # placeholder, presumably filled in at launch time
    )
    for split_name in ("train", "val")
]

### model

# Multi-scale VQ-VAE: ResNet encoder -> per-level codebooks -> CNN decoder.
model = dict(
    type="VQVAEMultiScale",
    # ResNet-18 trunk followed by a 1x1 conv down to the latent width.
    # NOTE(review): in_channels=256 assumes that is the ResNet's output width here — confirm.
    encode=dict(
        type="Sequential",
        modules=[
            dict(
                type="ResNet",
                model_name="resnet18.fb_swsl_ig1b_ft_in1k",
                in_dim=3,
                k0=4,
                strides=[4, 1, 1, 1],
                gn=16,
            ),
            dict(
                type="Conv2d",
                in_channels=256,
                out_channels=embed_dim0,
                kernel_size=1,
            ),
        ],
    ),
    # Eight conv layers, all stride 1: seven of width 64, then 3 output channels.
    decode=dict(
        type="CNN",
        channel0=embed_dim0,
        channels=[64] * 7 + [3],
        kernels=[1, 3, 3, 1, 3, 3, 1, 1],
        strides=[1] * 8,
        ctypes=[0, 0, 0, 2, 0, 0, 2, 0],
        gn=1,
    ),
    # num_scale + 1 independent codebooks, each with int(sqrt(num_code)) entries.
    codebook=dict(
        type="ModuleList",
        modules=[
            dict(type="Codebook", num_embed=int(num_code**0.5), embed_dim=embed_dim0)
            for _level in range(num_scale + 1)
        ],
    ),
    # Maps 2 * embed_dim0 channels back to embed_dim0.
    # NOTE(review): presumably two concatenated embeddings — confirm.
    project=dict(
        type="LinearPinv2d",
        in_channel=embed_dim0 * 2,
        out_channel=embed_dim0,
    ),
    num_scale=num_scale,
    normaliz=True,
    learn=True,
)
# Presumably binds the model's `input` argument to the batch's "image" entry.
model_imap = dict(input="image")
# Output names, grouped by kind and then ordered by scale:
# encode0..N-1, encode0_..N-1_, zidx*, quant*_, quant*, decode*.
model_omap = [
    template.format(i)
    for template in ("encode{}", "encode{}_", "zidx{}", "quant{}_", "quant{}", "decode{}")
    for i in range(num_scale)
]
ckpt_map = None  # presumably: no pretrained-checkpoint remapping
freez = None  # presumably: no frozen parameters

### learn

# NOTE(review): presumably None means "optimize all model parameters" — confirm.
param_groups = None
optimiz = dict(
    type="Adam",
    params=param_groups,
    lr=lr,  # initial value; the LinearCosineAnnealing callback rewrites param_groups[0]['lr']
)
gscale = dict(type="GradScaler")  # presumably AMP loss scaling
gclip = dict(type="ClipGradNorm", max_norm=1)  # gradient-norm clipping threshold of 1

# Four loss families, one entry per scale i in range(num_scale):
#   recon_d{i}: MSE between decode{i} and the image resized by 1 / 2**i
#   align{i}:   MSE pulling quant{i}_ toward the detached encode{i}_ (codebook term)
#   commit{i}:  MSE pulling encode{i}_ toward the detached quant{i}_, weight 0.25
#               (commitment term)
#   lpips{i}:   LPIPS between decode{i} and the image resized by 1 / 2**i
# NOTE(review): the Resize scales assume decode{i} is produced at 1 / 2**i of the
# input resolution — confirm against VQVAEMultiScale.
loss_fn = dict(
    **{
        f"recon_d{i}": dict(
            metric=dict(type="MSELoss"),
            map=dict(output=dict(input=f"decode{i}"), batch=dict(target="image")),
            transform=dict(type="Resize", keys=["target"], scale=1 / 2**i),
        )
        for i in range(num_scale)
    },
    **{
        f"align{i}": dict(
            metric=dict(type="MSELoss"),
            map=dict(
                output=dict(input=f"quant{i}_", target=f"encode{i}_"), batch=dict()
            ),
            # stop-gradient on the encoder side: only the codebook side is updated
            transform=dict(type="Detach", keys=["target"]),
        )
        for i in range(num_scale)
    },
    **{
        f"commit{i}": dict(
            metric=dict(type="MSELoss"),
            map=dict(
                output=dict(input=f"encode{i}_", target=f"quant{i}_"), batch=dict()
            ),
            # stop-gradient on the codebook side: only the encoder side is updated
            transform=dict(type="Detach", keys=["target"]),
            weight=0.25,
        )
        for i in range(num_scale)
    },
    **{
        f"lpips{i}": dict(
            metric=dict(type="LPIPSLoss"),
            # plain string, same as recon_d (an f-string here had no placeholders)
            map=dict(output=dict(input=f"decode{i}"), batch=dict(target="image")),
            transform=dict(type="Resize", keys=["target"], scale=1 / 2**i),
        )
        for i in range(num_scale)
    },
)
metric_fn = dict()  # no extra metrics are configured

before_step = [
    # Presumably moves batch["image"] onto the training device before each step.
    dict(type="ToDevice", keys=["batch.image"]),
    # Anneals model.m.alpha from 0.5 to 0 over the first half of training
    # (presumably holding it afterwards, per the "Constant" suffix — confirm).
    # NOTE(review): the earlier note "0.5 > 1, 0.2, 0.1" presumably means 0.5
    # outperformed those base values.
    dict(
        type="CosineAnnealingConstant",
        assigns=["model.m.alpha.data[...]=value"],
        base_values=[0.5],
        min_values=[0],
        cos_step=total_step // 2,
    ),
    # Learning-rate schedule: presumably linear warmup to `lr` over the first
    # total_step // 20 steps, then cosine decay to 0 by total_step.
    dict(
        type="LinearCosineAnnealing",
        assigns=["optimiz.param_groups[0]['lr']=value"],
        base_values=[lr],
        min_values=[0],
        warmup_step=total_step // 20,
        total_step=total_step,
    ),
]
after_step = [
    # Drives replace_rate on all num_scale + 1 codebooks: 1 for the first half of
    # training, 0 for the second (assuming const_values[k] covers
    # [points[k], points[k + 1])).
    # NOTE(review): earlier notes read "Squarewave > CosineAnnealingConstant" and
    # "ttl (comment) > ttl/2 (uncomment)"; they presumably compare switch-off
    # points — confirm.
    dict(
        type="Squarewave",
        assigns=[
            f"model.m.codebook._modules['{i}'].replace_rate.data[...]=value"
            for i in range(1 + num_scale)
        ],
        const_values=[1, 0],
        points=[0, total_step // 2, total_step],
    )
]
# NOTE(review): an earlier note read "EMA is bad ???"; presumably an EMA callback
# was tried and dropped — confirm.
callback_t = [
    dict(type="Callback", before_step=before_step, after_step=after_step),
    dict(type="AverageLog", log_file=...),
]
# Validation runs only the device transfer (no schedules).
# NOTE(review): callback_v reuses the same ToDevice and AverageLog dict objects
# as training, so anything that mutates one config mutates the other.
callback_v = [
    dict(type="Callback", before_step=before_step[:1]),
    callback_t[1],
    # Integer division, like every other schedule point in this file;
    # `total_step * 0.5` would produce the float 15000.0.
    dict(type="SaveModel", save_dir=..., since_step=total_step // 2),
]

### loop

# Top-level training loop. Each component entry is the `...` placeholder,
# presumably replaced by the launcher with the built object — confirm in the runner.
loop = dict(
    type="Loop",
    **dict.fromkeys(
        (
            "dataset_t",
            "dataset_v",
            "model",
            "optimiz",
            "loss_fn",
            "metric_fn",
            "callback_t",
            "callback_v",
        ),
        ...,
    ),
    total_step=total_step,
    val_interval=val_interval,
)
