_base_ = [
    '../_base_/models/segmenter_vit-b16_mask_dka.py',  
    '../_base_/datasets/brats.py', 
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40e_cosine.py'
]

# Spatial size handed to the data preprocessor below.
crop_size = (512, 512)

# Placeholder dataset root; replace with the real dataset directory.
data_root = '/your_dataset'
# Split list files, passed as ``ann_file`` alongside ``data_root`` in the
# dataloaders.
train_ann = 'ImageSets/Segmentation/train_0.5.txt'
test_ann = 'ImageSets/Segmentation/test.txt'

# Input preprocessing: normalization, channel reordering and padding.
data_preprocessor = {
    'type': 'SegDataPreProcessor',
    'size': crop_size,
    # Per-channel normalization statistics (these values coincide with the
    # widely used ImageNet RGB mean/std on a 0-255 scale).
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'bgr_to_rgb': True,
    # Fill values for padded image pixels and padded segmentation-map pixels.
    # NOTE(review): 255 is presumably the ignore index of the loss - confirm.
    'pad_val': 0,
    'seg_pad_val': 255,
}

# Segmentation model: a ViT backbone with DKA modules plus a Segmenter
# mask-transformer decode head. Every spatial size below is taken from
# ``crop_size`` so the preprocessor, backbone and sliding-window inference
# cannot drift apart.
model = dict(
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='VisionTransformer_dka',
        img_size=crop_size,
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        drop_path_rate=0.1,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        final_norm=True,
        norm_cfg=dict(type='LN', eps=1e-6),
        with_cls_token=True,
        interpolate_mode='bicubic',
        # Only one feature map is requested. With num_layers=12 this is
        # presumably the last layer (0-based index) - confirm.
        out_indices=[11],
        # DKA module settings; input_dim equals embed_dims and middle_dim is
        # a quarter of it.
        dka_cfg=dict(
            type='Dka',
            input_dim=768,
            middle_dim=192
        ),
    ),
    decode_head=dict(
        type='SegmenterMaskTransformerHead',
        # in_channels / embed_dims equal the backbone's embed_dims (768).
        in_channels=768,
        channels=768,
        num_classes=2,
        num_layers=2,
        num_heads=12,
        embed_dims=768,
        dropout_ratio=0.0,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    ),
    # Sliding-window inference: the window equals crop_size and the stride of
    # 480 leaves a 32-pixel overlap between neighbouring 512-pixel windows.
    test_cfg=dict(mode='slide', crop_size=crop_size, stride=(480, 480))
)

# Optimizer setup: AdamW with per-key learning-rate multipliers.
# NOTE(review): the intent appears to be training only the DKA modules and the
# decode head while the rest of the backbone gets a zero learning rate. This
# relies on the optimizer constructor preferring the more specific
# 'backbone.dka' key over 'backbone', and on the DKA parameter names actually
# containing 'backbone.dka' - confirm against the backbone implementation.
optim_wrapper = {
    'type': 'OptimWrapper',
    'optimizer': {'type': 'AdamW', 'lr': 1e-4, 'weight_decay': 0.01},
    'paramwise_cfg': {
        'custom_keys': {
            # 10x the base learning rate.
            'backbone.dka': {'lr_mult': 10.0},
            # Base learning rate, unchanged.
            'decode_head': {'lr_mult': 1.0},
            # Zero learning-rate multiplier.
            'backbone': {'lr_mult': 0.0},
        },
    },
}


# Training dataloader. These keys override / extend the dataset settings
# inherited from the base dataset config.
# NOTE(review): assumes the dataset class from the base config derives from
# MMSegmentation's BaseSegDataset and therefore accepts ``img_suffix`` and
# ``seg_map_suffix`` as dataset-level arguments - confirm.
train_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        data_root=data_root,
        ann_file=train_ann,
        # File suffixes are dataset options, not path prefixes. Inside
        # ``data_prefix`` they would be treated as prefixes and never used
        # when building file names, so they are set here instead.
        img_suffix='.png',
        seg_map_suffix='.png',
        data_prefix=dict(
            img_path='JPEGImages',
            seg_map_path='SegmentationClass'
        )
    )
)

# Validation dataloader: same layout as training, but reads the test split
# and does not shuffle.
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        data_root=data_root,
        ann_file=test_ann,
        # See the note on suffixes in train_dataloader's dataset settings.
        img_suffix='.png',
        seg_map_suffix='.png',
        data_prefix=dict(
            img_path='JPEGImages',
            seg_map_path='SegmentationClass'
        )
    )
)
# Testing reuses the validation dataloader settings as-is.
test_dataloader = val_dataloader

# Evaluation metrics reported during validation; testing reuses the same
# evaluator settings.
val_evaluator = {
    'type': 'IoUMetric',
    'iou_metrics': ['mIoU', 'mDice', 'mFscore'],
}
test_evaluator = val_evaluator

# Output directory for this experiment.
# NOTE(review): the "80" in the directory name does not obviously correspond
# to the "train_0.5" split used for training - confirm the name is intended.
work_dir = './work_dirs/segmenter_vit_brats_80_dka'

# Extra hooks registered for the run; this one is a project-specific hook
# (presumably reports which parameters are trainable - confirm).
custom_hooks = [
    {'type': 'LogTrainableParamsHook', 'priority': 'LOW'},
]
