from nerv.training import BaseParams


class SlotFormerParams(BaseParams):
    """Experiment settings for the end-to-end SAVi + SlotFormer model with an
    action ("Decision Transformer") decoder.

    All settings are plain class attributes. Derived values (batch size,
    feed-forward widths, rollout length) are computed once, when the class
    body executes, from the attributes defined above them.
    """

    project = 'PROJECT_NAME'

    # ---- training settings ----
    gpus = 1  # 2 GPUs should also be good
    max_epochs = 1
    # NOTE(review): earlier comments described these two intervals as
    # "save every 0.25 epoch" / "evaluate every 5 epochs", which does not
    # obviously match the value 10 -- confirm the unit used by the trainer.
    save_interval = 10
    eval_interval = 10
    save_epoch_end = False  # whether to save a checkpoint at each epoch end
    n_samples = 10  # number of samples used for visualization

    # ---- optimizer settings ----
    # Adam optimizer, cosine decay with warmup;
    # no weight decay, no gradient clipping.
    optimizer = 'Adam'
    lr = 1e-4
    warmup_steps_pct = 0.05  # warmup in the first 5% of total steps

    # ---- data settings ----
    dataset = 'langtable_action'
    data_root = 'DATA_PATH'
    # slots_root = './data/language-table/language_table_blocktoblock_4block_sim/savi_langtable_slots.pkl'
    n_sample_frames = 6 + 10  # 6 burn-in frames + 10 rollout frames
    # NOTE(review): previously annotated "no offset"; presumably 1 means
    # consecutive frames are sampled -- confirm against the dataset code.
    frame_offset = 1
    video_len = 50  # take the first 50 frames of each video
    train_batch_size = 4 // gpus  # a total of 4, split across the GPUs
    val_batch_size = train_batch_size
    num_workers = 4

    # ---- model configs ----
    model = 'RoboticsLSlotFormer'
    resolution = (128, 128)
    input_frames = 6  # burn-in frames

    # ==== SAVi + SlotFormer end-to-end configs ====

    # Slot Attention
    num_slots = 6
    slot_size = 128
    slot_dict = {
        # original note: "at most 5 objects per scene" (presumably the
        # sixth slot covers the background -- TODO confirm)
        'num_slots': num_slots,
        'slot_size': slot_size,
        'slot_mlp_size': 256,
        'num_iterations': 2,
    }

    # CNN encoder
    enc_dict = {
        'enc_channels': (3, 64, 64, 64, 64),
        'enc_ks': 5,
        'enc_out_channels': 128,
        'enc_norm': '',
    }

    # CNN decoder
    dec_dict = {
        'dec_channels': (128, 64, 64, 64, 64),
        'dec_resolution': (16, 16),
        'dec_ks': 5,
        'dec_norm': '',
        'dec_ckp_path': 'PRETRAINED_SAVI_CKP_PATH',
    }

    # Predictor
    pred_dict = {
        'pred_type': 'transformer',
        'pred_rnn': True,
        'pred_norm_first': True,
        'pred_num_layers': 2,
        'pred_num_heads': 4,
        'pred_ffn_dim': slot_size * 4,  # 4x the slot size
        'pred_sg_every': None,
    }

    # Rollouter
    rollout_dict = {
        'num_slots': num_slots,
        'slot_size': slot_size,
        'history_len': input_frames,
        't_pe': 'sin',  # sine temporal P.E.
        'slots_pe': '',  # no slots P.E.
        # Transformer-related configs
        'd_model': slot_size * 2,
        'num_layers': 8,
        'num_heads': 8,
        'ffn_dim': slot_size * 2 * 4,  # 4x d_model
        'norm_first': True,
    }

    # ==== Decision Transformer decoder ====

    act_dec_dict = {
        # NOTE(review): previously annotated "input t-5~t slots (max: 6)",
        # which does not match a value of 1 -- confirm intended history.
        'prev_len': 1,
        'next_len': 10,  # input t+1~t+10 slots (max: 10)
        't_pe': 'sin',  # sine temporal P.E.
        'slots_pe': '',  # no slots P.E.
        # Transformer-related configs
        'd_model': 128,
        'num_layers': 2,
        'num_heads': 8,
        'inst_size': 768,
        'act_size': 2,
        'ffn_dim': 128,
        'norm_first': True,
        'num_sample_tasks': 10,
        'task_steps': 10,
        'task_batch_size': 5,
        'upload_video': 'none',  # controls uploading video to wandb
        'dec_type': 'transformer',
        'mask': False,
        'inst': False,
        'pool_inst': False,
        'sample_task': True,
    }

    # ---- loss configs ----
    loss_dict = {
        # frames left after the burn-in window (history_len == input_frames)
        'rollout_len': n_sample_frames - input_frames,
        # original note: "important for predicted image quality"
        # (disabled in this config)
        'use_img_recon_loss': False,
        'use_post_recon_loss': False,
        'use_action_loss': True,
        'kld_method': 'none',  # standard SAVi
    }

    # loss weights
    action_loss_w = 1.0
    slot_recon_loss_w = 1.0
    img_recon_loss_w = 0.1


