
import torch


ZERO_VAE_CACHE = [
    torch.zeros(1, 16, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 384, 2, 60, 104),
    torch.zeros(1, 192, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 384, 2, 120, 208),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 192, 2, 240, 416),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832),
    torch.zeros(1, 96, 2, 480, 832)
]

feat_names = [f"vae_cache_{i}" for i in range(len(ZERO_VAE_CACHE))]
ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names
