# ---------------------------------------------------------------------------
# Training schedule / logging
# ---------------------------------------------------------------------------
grad_max_norm = 35  # presumably the gradient-clipping max norm -- confirm in the train script
print_freq = 10     # presumably the logging interval -- confirm in the train script
max_epochs = 200
warmup_iters = 50   # also used below as the scheduler's warm-up length (warmup_t)
return_len_ = 9     # passed to the model as num_frames; the datasets request return_len_ + 1

# Checkpoint to initialise from. Earlier checkpoints, kept for reference:
#   out/triplane_20250117-v1/latest.pth
#   out/triplane_20250208-v1/latest.pth
load_from = 'out/triplane_20250215-v1/latest.pth'

# ---------------------------------------------------------------------------
# Learning-rate schedule and optimizer
# ---------------------------------------------------------------------------
# NOTE(review): presumably multisteplr_config is only consulted when
# multisteplr is True -- confirm in the training script.
multisteplr = False
multisteplr_config = {
    'decay_t': [87 * 500],  # t_in_epochs=False suggests these are iteration counts -- TODO confirm
    'decay_rate': 0.1,
    'warmup_t': warmup_iters,
    'warmup_lr_init': 1e-6,
    't_in_epochs': False,
}
optimizer = {
    'optimizer': {
        'type': 'AdamW',
        'lr': 1e-3,
        'weight_decay': 0.01,
    },
}

# Root directory of the nuScenes data on the training machine.
data_path = '/data3/xuhr/dataset/nuscenes/'

# Training split. Each sample requests return_len_ + 1 frames, one more than
# the model's num_frames (return_len_).
train_dataset_config = dict(
    type='nuScenesSceneDatasetLidarHexPlane',
    data_path=data_path,
    return_len=return_len_ + 1,
    offset=0,
    imageset='/data3/xuhr/dataset/nuscenes_infos_train_temporal_v3_scene.pkl',
)

# Validation split: same settings as training, different info file.
# dict(mapping, key=...) builds an independent shallow copy with the override
# applied in place, so key order matches the training config.
val_dataset_config = dict(
    train_dataset_config,
    imageset='/data3/xuhr/dataset/nuscenes_infos_val_temporal_v3_scene.pkl',
)

# Dataset wrappers: one wrapper type, differing only in the phase flag.
train_wrapper_config = dict(
    type='lidarqueryhex_dataset_nuscenes',
    phase='train',
)
val_wrapper_config = dict(train_wrapper_config, phase='val')

# DataLoader settings: single-sample batches on one worker; only the
# training loader shuffles.
train_loader = dict(
    batch_size=1,
    shuffle=True,
    num_workers=1,
)
val_loader = dict(train_loader, shuffle=False)

# Loss configuration: four weighted terms handed to 'MultiLoss'.
# Each 'input_dict' appears to map a loss argument name to the key it reads
# from the model outputs / batch -- confirm against the MultiLoss implementation.
loss = {
    'type': 'MultiLoss',
    'loss_cfgs': [
        {
            'type': 'PlaneLoss',
            'weight': 10.0,
            'reg_fff': 1e-6,
            'input_dict': {
                'mask': 'hexplane_mask',
                'logits': 'hexplane_pred',
                'labels': 'hexplane_gt',
            },
        },
        {
            'type': 'CeLoss',
            'weight': 0.1,
            'ignore_label': -100,
            'cls_weight': None,
            'input_dict': {'ce_inputs': 'preds', 'ce_labels': 'xyz_labels'},
        },
        {
            'type': 'LovaszLoss',
            'weight': 0.01,
            'input_dict': {'logits': 'pred_output', 'labels': 'gt_output'},
        },
        # NOTE(review): PoseLoss is enabled, but the model's pose_encoder /
        # pose_decoder entries are commented out -- confirm the model still
        # emits 'rel_pose' and 'output_metas'.
        {
            'type': 'PoseLoss',
            'weight': 1,
            'loss_type': 'l2',
            'num_modes': 3,
            'input_dict': {'rel_pose': 'rel_pose', 'metas': 'output_metas'},
        },
    ],
}



# Model helper values (trailing/leading underscores follow this file's
# convention for helper variables).
_dim_ = 16
expansion = 8
base_channel = 64  # in this section, only referenced by the commented-out pose heads below
n_e_ = 512
# Semantic class count. It is used by both the outer model and the tri-plane
# sub-model, which presumably must agree, so it is defined once here instead of
# being hard-coded twice.
num_classes_ = 18
model = dict(
    type='TransHexplane',
    num_frames=return_len_,
    delta_input=False,
    offset=1,
    prev_steps=4,
    num_classes=num_classes_,
    triplane_cfg=dict(
        type='TripLane',
        encoder_cfg=dict(
            type='TripLaneEncoder',
            z_down=False,
        ),
        decoder_cfg=dict(
            type='TripLaneDecoder',
        ),
        num_classes=num_classes_,
        expansion=expansion,
    ),
    # NOTE(review): unlike triplane_cfg.encoder_cfg, this encoder does not set
    # z_down=False -- confirm the difference is intentional.
    encoder_cfg=dict(
        type='TripLaneEncoder',
    ),
    decoder_cfg=dict(
        type='TripLaneDecoder',
    ),
    # Pose heads, currently disabled. NOTE(review): the loss config still
    # enables PoseLoss -- confirm it does not depend on these heads.
    # pose_encoder=dict(
    #     type = 'PoseEncoder',
    #     in_channels=5,
    #     out_channels=base_channel*2,
    #     num_layers=2,
    #     num_modes=3,
    #     num_fut_ts=1,
    # ),
    # pose_decoder=dict(
    #     type = 'PoseDecoder',
    #     in_channels=base_channel*2,
    #     num_layers=2,
    #     num_modes=3,
    #     num_fut_ts=1,
    # ),
)

# Four square resolutions, halving from 200x200 down to 25x25.
shapes = [[200 // 2 ** level] * 2 for level in range(4)]

# Integer labels 0 through 16 (17 ids).
unique_label = list(range(17))
label_mapping = "./config/label_mapping/nuscenes-occ.yaml"