# Base config files this config inherits from (MMEngine `_base_` convention).
# Judging by their paths they supply the model, dataset, runtime and schedule
# settings; the assignments further down in this file override or extend them.
# Keep this a plain list literal of relative paths.
_base_ = [
    '../_base_/models/segmenter_vit-b16_mask.py',
    '../_base_/datasets/brats.py', 
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40e_cosine.py' 
]


# Dataset location. '/your_dataset' is a placeholder: point it at the real
# dataset directory before launching a run.
data_root = '/your_dataset'

# Split files handed to the dataloaders below as `ann_file`.
train_ann = 'ImageSets/Segmentation/train_80.txt'
test_ann = 'ImageSets/Segmentation/test.txt'


# Input preprocessing settings: target size, normalization statistics,
# channel-order flip, and the padding values for images and label maps.
data_preprocessor = {
    'type': 'SegDataPreProcessor',
    'size': (512, 512),
    # Per-channel values; they match the commonly used ImageNet statistics.
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'bgr_to_rgb': True,
    'pad_val': 0,
    # NOTE(review): presumably the label value ignored by the loss -- confirm.
    'seg_pad_val': 255,
}

# Model overrides: plug in the preprocessor defined above, set the pretrained
# checkpoint URL, and configure the decode head for two classes.
model = {
    'data_preprocessor': data_preprocessor,
    'pretrained': 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segmenter/vit_base_p16_384_20220308-96dfe169.pth',
    'decode_head': {
        'num_classes': 2,
    },
}

# Optimizer settings: AdamW with a 1e-4 learning rate and 0.01 weight decay.
# NOTE(review): these keys are presumably merged with the optimizer from the
# base schedule; if that one is not AdamW, leftover keys could reach the AdamW
# constructor -- confirm against the base schedule file.
optim_wrapper = {
    'type': 'OptimWrapper',
    'optimizer': {
        'type': 'AdamW',
        'lr': 1e-4,
        'weight_decay': 0.01,
    },
}

# Training dataloader: one shuffled sample per batch, read from the training
# split file under `data_root`.
train_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        data_root=data_root,
        ann_file=train_ann,
        # File suffixes are dataset-level arguments (`img_suffix` and
        # `seg_map_suffix` of MMSegmentation's BaseSegDataset). Every entry of
        # `data_prefix` is treated as a directory prefix, so suffixes placed
        # there are never applied to the file names.
        img_suffix='.png',
        seg_map_suffix='.png',
        # Sub-directories holding the images and the label maps.
        data_prefix=dict(
            img_path='JPEGImages',
            seg_map_path='SegmentationClass'
        )
    )
)

# Validation dataloader: one sample per batch, no shuffling, read from the
# test split file under `data_root`.
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        data_root=data_root,
        ann_file=test_ann,
        # File suffixes are dataset-level arguments (`img_suffix` and
        # `seg_map_suffix` of MMSegmentation's BaseSegDataset). Every entry of
        # `data_prefix` is treated as a directory prefix, so suffixes placed
        # there are never applied to the file names.
        img_suffix='.png',
        seg_map_suffix='.png',
        # Sub-directories holding the images and the label maps.
        data_prefix=dict(
            img_path='JPEGImages',
            seg_map_path='SegmentationClass'
        )
    )
)

# Testing reuses the validation dataloader settings unchanged.
test_dataloader = val_dataloader


# Validation metrics: IoU-based evaluator reporting mIoU, mDice and mFscore.
val_evaluator = {
    'type': 'IoUMetric',
    'iou_metrics': ['mIoU', 'mDice', 'mFscore'],
}

# Testing reports exactly the same metrics as validation.
test_evaluator = val_evaluator


# Output directory configured for this experiment.
work_dir = './work_dirs/segmenter_vit_b16_brats_80_full'

# Extra hooks registered for the run.
# NOTE(review): `LogTrainableParamsHook` is not defined in this file; it must
# be registered elsewhere in the project -- confirm it is importable.
custom_hooks = [
    {
        'type': 'LogTrainableParamsHook',
        'priority': 'LOW',
    },
]
