# Normalisation-layer config; referenced by the decode head below.
norm_cfg = {'type': 'SyncBN', 'requires_grad': True}

# Segmentor definition: a FusionFormer backbone, a LightHead decoder and a
# FoundationGuidanceHead auxiliary branch.
model = {
    'type': 'EncoderDecoder_Guidance',
    'backbone': {
        'type': 'FusionFormer',
        'embed_dim': 24,
        # Per-stage settings: every list below has one entry per stage.
        'depths': [6, 2],
        'num_heads': [16, 16],
        'segment_frequencies': [3, 1],
        'window_bases': [8, 16],
        'ratio_bases': [2, 1],
        'qk_head_dims': [32, 32],
        'v_head_dims': [32, 32],
        'mlp_ratio': 4.0,
        'drop_path_rate': 0.1,
        'init_values': 1e-5,
        'kernel_norm': True,
        'use_level_embed': False,
        'attention_sum': True,
        'super_res': False,
        'convert_norm': True,
        # Machine-specific absolute path to the backbone weights.
        'pretrained_path':
            '/guest/mnt0/fuzheming/weights/fusion_base_6_2.pth',
    },
    'decode_head': {
        'type': 'LightHead',
        # in_index selects which backbone outputs feed the head;
        # in_channels lists the channel count of each selected output.
        'in_channels': [96, 192, 384],
        'in_index': [0, 2, 3],
        'channels': 192,
        'dropout_ratio': 0.1,
        'embed_dims': [96, 192],
        'num_classes': 19,
        'is_dw': True,
        'use_attn': False,
        'norm_cfg': norm_cfg,
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
        },
    },
    'auxiliary_head': {
        'type': 'FoundationGuidanceHead',
        'in_channels': [192, 384],
        'channels': 192,
        'in_index': [2, 3],
        'base_channels': 48,
        'vit_channels': 768,
        'use_cls': True,
        'num_classes': 19,
        # Machine-specific absolute path; judging by the file name this is
        # a DINOv2 ViT-B/14 (with registers) checkpoint -- TODO confirm.
        'pretrained_path':
            '/guest/mnt0/fuzheming/weights/dinov2_vitb14_reg4_pretrain.pth',
        'loss_decode': {
            'type': 'AlignmentLoss',
            'loss_weight': [1, 1],
            'use_cls': True,
            'cls_only': False,
        },
    },
    # Model training and testing settings.
    'train_cfg': {},
    'test_cfg': {'mode': 'whole'},
}


# Dataset and data-loading settings.
#
# The shared values below are declared once and referenced everywhere they
# are needed, so the train/val/test splits cannot drift apart when one of
# them is edited.
dataset_type = 'CityscapesDataset'
data_root = '/guest/mnt0/fuzheming/FusionFormer_seg/data/Cityscapes'
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}
crop_size = (1024, 1024)

# Training-time pipeline: random rescale, crop to `crop_size`, flip and
# photometric distortion, then normalise and pad back to `crop_size`.
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'LoadAnnotations'},
    {'type': 'Resize', 'img_scale': (2048, 1024), 'ratio_range': (0.5, 2.0)},
    {'type': 'RandomCrop', 'crop_size': crop_size, 'cat_max_ratio': 0.75},
    {'type': 'RandomFlip', 'prob': 0.5},
    {'type': 'PhotoMetricDistortion'},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'Pad', 'size': crop_size, 'pad_val': 0, 'seg_pad_val': 255},
    {'type': 'DefaultFormatBundle'},
    {'type': 'Collect', 'keys': ['img', 'gt_semantic_seg']},
]

# Evaluation pipeline, shared by the val and test splits: a single
# (2048, 1024) scale with flipping disabled.
test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'MultiScaleFlipAug',
        'img_scale': (2048, 1024),
        'flip': False,
        'transforms': [
            {'type': 'Resize', 'keep_ratio': True},
            {'type': 'RandomFlip'},
            {'type': 'Normalize', **img_norm_cfg},
            {'type': 'ImageToTensor', 'keys': ['img']},
            {'type': 'Collect', 'keys': ['img']},
        ],
    },
]

data = {
    'samples_per_gpu': 8,
    'workers_per_gpu': 8,
    'train': {
        'type': dataset_type,
        'data_root': data_root,
        'img_dir': 'leftImg8bit/train',
        'ann_dir': 'gtFine/train',
        'pipeline': train_pipeline,
    },
    'val': {
        'type': dataset_type,
        'data_root': data_root,
        'img_dir': 'leftImg8bit/val',
        'ann_dir': 'gtFine/val',
        'pipeline': test_pipeline,
    },
    # The test split points at the same val images and annotations.
    'test': {
        'type': dataset_type,
        'data_root': data_root,
        'img_dir': 'leftImg8bit/val',
        'ann_dir': 'gtFine/val',
        'pipeline': test_pipeline,
    },
}

# Runtime, optimizer and schedule settings.

# Logging and general runtime.
log_config = {
    'interval': 50,
    'hooks': [
        {'type': 'TextLoggerHook', 'by_epoch': False},
    ],
}
dist_params = {'backend': 'nccl'}
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True

# Optimizer: AdamW with per-key overrides -- 'head' entries get a 10x
# learning-rate multiplier and 'norm' entries get no weight decay.
optimizer = {
    'type': 'AdamW',
    'lr': 6e-5,
    'betas': (0.9, 0.999),
    'weight_decay': 0.01,
    'paramwise_cfg': {
        'custom_keys': {
            'head': {'lr_mult': 10.0},
            'norm': {'decay_mult': 0.0},
        },
    },
}
optimizer_config = {}

# Learning-rate schedule: 'poly' policy with power 1.0 and a linear warmup
# over the first 1500 iterations.
lr_config = {
    'policy': 'poly',
    'warmup': 'linear',
    'warmup_iters': 1500,
    'warmup_ratio': 1e-6,
    'power': 1.0,
    'min_lr': 0,
    'by_epoch': False,
}

# Iteration-based run; checkpointing and evaluation share a 4000-iter period.
runner = {'type': 'IterBasedRunner', 'max_iters': 160000}
checkpoint_config = {'by_epoch': False, 'interval': 4000}
evaluation = {'interval': 4000, 'metric': 'mIoU', 'pre_eval': True}
find_unused_parameters = True
# Disabled override kept for reference; uncomment to resume that earlier run:
# resume_from = '/guest/mnt0/fuzheming/FusionFormer_seg/work_dirs/guidance-base_without_cls_cityscapes/iter_72000.pth'
auto_resume = False
seed = 1440161127
