# ---------------------------------------------------------------------------
# Run schedule
# ---------------------------------------------------------------------------
grad_max_norm = 35   # presumably the gradient-clipping max norm — confirm in trainer
print_freq = 10      # presumably the logging interval (in iterations) — confirm
max_epochs = 200
warmup_iters = 200   # reused as ``warmup_t`` of the multi-step scheduler below
return_len_ = 10     # forwarded as ``return_len`` to the train/val dataset configs

# Multi-step LR decay settings. NOTE(review): presumably only consulted when
# ``multisteplr`` is True; the key names look like timm's MultiStepLRScheduler
# arguments — confirm against the training script.
multisteplr = False
multisteplr_config = dict(
    # With t_in_epochs=False the boundary is presumably counted in iterations;
    # 87 * 500 looks like "epochs x iterations-per-epoch" — TODO confirm.
    decay_t=[87 * 500],
    decay_rate=0.1,
    warmup_t=warmup_iters,
    warmup_lr_init=1e-6,
    t_in_epochs=False,
)

# Optimiser spec: AdamW with a 1e-3 learning rate and 0.01 weight decay.
# NOTE(review): the outer ``optimizer`` wrapper presumably leaves room for
# sibling keys (e.g. parameter-wise settings) read by the builder — confirm.
optimizer = dict(optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.01))

# Root directory of the nuScenes data. The info-pickle paths below are
# absolute and machine-specific; adjust all three together when relocating.
data_path = '/data3/xuhr/dataset/nuscenes/'

# Training split.
train_dataset_config = dict(
    type='nuScenesSceneDatasetLidarQuery',
    data_path=data_path,
    return_len=return_len_,
    offset=0,
    flip=True,
    times=2,
    imageset='/data3/xuhr/dataset/nuscenes_infos_train_temporal_v3_scene.pkl',
)

# Validation split: identical to the training split except for ``imageset``.
# NOTE(review): ``flip=True`` and ``times=2`` are enabled here as well. If
# these are training-time augmentation / repetition switches, enabling them
# for validation is probably unintended — confirm against the dataset class.
val_dataset_config = dict(
    type='nuScenesSceneDatasetLidarQuery',
    data_path=data_path,
    return_len=return_len_,
    offset=0,
    flip=True,
    times=2,
    imageset='/data3/xuhr/dataset/nuscenes_infos_val_temporal_v3_scene.pkl',
)

# Dataset wrappers: same wrapper type for both splits, distinguished by phase.
train_wrapper_config = dict(type='lidarquery_dataset_nuscenes', phase='train')
val_wrapper_config = dict(type='lidarquery_dataset_nuscenes', phase='val')

# DataLoader settings; the splits differ only in whether samples are shuffled.
train_loader = dict(batch_size=1, shuffle=True, num_workers=1)
val_loader = dict(batch_size=1, shuffle=False, num_workers=1)

# Composite loss: a cross-entropy term (weight 10.0) and a Lovasz term
# (weight 1.0), presumably combined as a weighted sum by ``MultiLoss``.
# Each ``input_dict`` maps the loss function's argument names to the keys it
# reads from the model/batch outputs.
# NOTE(review): the two terms read different keys ('preds'/'xyz_labels' vs
# 'pred_output'/'gt_output') — confirm the model emits all four.
loss = dict(
    type='MultiLoss',
    loss_cfgs=[
        dict(
            type='CeLoss',
            weight=10.0,
            ignore_label=-100,
            use_weight=False,
            cls_weight=None,
            input_dict={'ce_inputs': 'preds', 'ce_labels': 'xyz_labels'},
        ),
        dict(
            type='LovaszLoss',
            weight=1.0,
            input_dict={'logits': 'pred_output', 'labels': 'gt_output'},
        ),
    ],
)


# Checkpoint to initialise from; presumably an empty string means "train
# from scratch" — confirm in the training script.
load_from = ''

# Model hyper-parameters. Within this config only ``expansion`` is wired into
# ``model``; ``_dim_``, ``base_channel`` and ``n_e_`` are not referenced here
# and may be read by other code — NOTE(review): confirm before removing.
_dim_ = 16
expansion = 8
base_channel = 64
n_e_ = 512

# Tri-plane autoencoder: encoder/decoder sub-configs plus shared settings.
model = dict(
    type='TripLane',
    encoder_cfg=dict(type='TripLaneEncoder', z_down=False),
    decoder_cfg=dict(type='TripLaneDecoder'),
    num_classes=18,
    expansion=expansion,
)

# Square spatial sizes, halving at each level (presumably the multi-scale
# feature resolutions used by the model — confirm).
shapes = [[side, side] for side in (200, 100, 50, 25)]

# Class ids 0..16. NOTE(review): ``model.num_classes`` is 18, so id 17 is
# presumably the empty/free class excluded from evaluation — confirm.
unique_label = list(range(17))
# Label-mapping YAML, resolved relative to the working directory.
label_mapping = "./config/label_mapping/nuscenes-occ.yaml"