from nerv.training import BaseParams


class SlotFormerParams(BaseParams):
    """Training configuration for the 'LSlotFormer' model on 'langtable_slots'.

    All settings are plain class attributes. Derived values (batch sizes,
    Transformer widths, rollout length) are computed once, when the class
    body is executed, from the attributes defined above them.
    """

    project = 'PROJECT_NAME'

    # ---- training settings ----
    gpus = 1  # 2 GPUs should also be good
    max_epochs = 100  # ~100k steps
    save_interval = 0.25  # save every 0.25 epoch
    eval_interval = 1  # evaluate every epoch
    save_epoch_end = True  # save ckp at the end of every epoch
    n_samples = 1  # visualization after each epoch

    # ---- optimizer settings ----
    # Adam optimizer, Cosine decay with Warmup
    # no weight decay, no gradient clipping
    optimizer = 'Adam'
    lr = 2e-4
    warmup_steps_pct = 0.05  # warmup in the first 5% of total steps

    # ---- data settings ----
    dataset = 'langtable_slots'
    data_root = 'DATA_PATH'
    slots_root = 'EXTRACTED_SLOT_PATH'
    n_sample_frames = 6 + 10  # 6 burn-in, 10 rollout
    frame_offset = 1  # no offset
    video_len = 50  # take the first 50 frames of each video
    # batch size of 2 in total, split evenly across the GPUs
    train_batch_size = 2 // gpus
    val_batch_size = 2 * train_batch_size
    num_workers = 8

    # ---- model configs ----
    model = 'LSlotFormer'
    resolution = (128, 128)
    input_frames = 6  # burn-in frames

    num_slots = 6
    slot_size = 128
    slot_dict = {
        'num_slots': num_slots,
        'slot_size': slot_size,
    }

    # Rollouter
    rollout_dict = {
        'num_slots': num_slots,
        'slot_size': slot_size,
        'history_len': input_frames,
        't_pe': 'sin',  # sine temporal P.E.
        'slots_pe': '',  # no slots P.E.
        # Transformer-related configs
        'd_model': 2 * slot_size,  # twice the slot size
        'num_layers': 8,
        'num_heads': 8,
        'ffn_dim': 4 * (2 * slot_size),  # 4x the d_model value
        'norm_first': True,
    }

    # CNN Decoder
    dec_dict = {
        'img_channels': 3,
        'dec_channels': (128, 64, 64, 64, 64),
        'dec_resolution': (16, 16),
        'dec_ks': 5,
        'dec_norm': '',
        'dec_ckp_path': 'PRETRAINED_SAVI_CKP_PATH',
    }

    # ---- loss configs ----
    loss_dict = {
        # frames left after the burn-in history: 16 - 6 = 10
        'rollout_len': n_sample_frames - rollout_dict['history_len'],
        # Image reconstruction loss is switched off in this config.
        # NOTE(review): the previous comment described this loss as
        # "important for predicted image quality" -- confirm that False
        # is the intended setting.
        'use_img_recon_loss': False,
    }

    slot_recon_loss_w = 1.0
    # presumably has no effect while use_img_recon_loss is False -- TODO confirm
    img_recon_loss_w = 0.1
