
# Inherited base configs: model, dataset, runtime and training schedule.
# mmengine extracts this list before the rest of the file is executed, so it
# must stay a plain literal that does not reference any other name defined here.
_base_ = [
    '../_base_/models/segmenter_vit-b16_mask_prompt.py',
    '../_base_/datasets/brats.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_40e_cosine.py'
]

# Spatial size every batch is padded to by the preprocessor below
# (presumably also the crop size of the inherited pipeline — confirm).
crop_size = (512, 512)

# Dataset location and split files; the split files are given relative to
# ``data_root``.
data_root = '/your_dataset'  # placeholder: point this at the local dataset
train_ann = 'ImageSets/Segmentation/train_80.txt'
test_ann = 'ImageSets/Segmentation/test.txt'

# Batch preprocessing: normalize with the usual ImageNet mean/std (0-255
# scale), convert BGR -> RGB, pad images with 0 and label maps with 255
# (presumably the ignore index — confirm against the decode head).
data_preprocessor = {
    'type': 'SegDataPreProcessor',
    'size': crop_size,
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'bgr_to_rgb': True,
    'pad_val': 0,
    'seg_pad_val': 255,
}

# Overrides merged into the inherited model config: attach the preprocessor
# above and set the backbone's prompt options.
model = {
    'data_preprocessor': data_preprocessor,
    'backbone': {'prompt_length': 5, 'prompt_dropout': 0.1},
}

# AdamW with per-parameter-group LR multipliers. mmengine matches each
# ``custom_keys`` entry as a substring of the parameter name and uses the
# longest matching key. So 'prompt_embeddings' (presumably under ``backbone.``)
# beats 'backbone' and keeps the base LR. Every other backbone parameter
# gets lr = 0.
#
# NOTE(review): lr_mult=0.0 only zeroes the step size. Backbone parameters
# keep requires_grad=True, so their gradients and AdamW state are still
# allocated, and a trainable-parameter count will include them. Also, if the
# inherited cosine schedule uses an absolute ``eta_min`` > 0, check that it
# does not lift this group's lr above zero.
optim_wrapper = {
    'type': 'OptimWrapper',
    'optimizer': {'type': 'AdamW', 'lr': 1e-4, 'weight_decay': 0.01},
    'paramwise_cfg': {
        'custom_keys': {
            'prompt_embeddings': {'lr_mult': 1.0},
            'decode_head': {'lr_mult': 1.0},
            'backbone': {'lr_mult': 0.0},
        },
    },
}

# Dataloaders. Only the fields set here override the inherited dataset config;
# mmengine merges nested dicts key by key.
#
# ``img_suffix`` / ``seg_map_suffix`` are arguments of the MMSegmentation
# dataset itself, not keys of ``data_prefix``. ``data_prefix`` maps data keys
# to directories that are joined onto ``data_root``, so suffixes placed inside
# it are never read and the dataset falls back to its default suffixes.
train_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,  # keep worker processes alive between epochs
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        data_root=data_root,
        ann_file=train_ann,
        img_suffix='.png',
        seg_map_suffix='.png',
        data_prefix=dict(
            img_path='JPEGImages',
            seg_map_path='SegmentationClass',
        ),
    ),
)

# NOTE(review): validation reads ``test_ann``. Any checkpoint selection driven
# by validation metrics is therefore done on the test split — confirm intended.
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        data_root=data_root,
        ann_file=test_ann,
        img_suffix='.png',
        seg_map_suffix='.png',
        data_prefix=dict(
            img_path='JPEGImages',
            seg_map_path='SegmentationClass',
        ),
    ),
)
# The test loop evaluates on the same loader/split as validation.
test_dataloader = val_dataloader

# IoU-family metrics (mIoU, mDice, mFscore); validation and test share one
# evaluator object.
val_evaluator = test_evaluator = {
    'type': 'IoUMetric',
    'iou_metrics': ['mIoU', 'mDice', 'mFscore'],
}

# Output directory for logs and checkpoints of this run.
work_dir = './work_dirs/segmenter_vit_brats_80_prompt'

# NOTE(review): LogTrainableParamsHook is not defined in this file. Presumably
# it is a project-local hook that must be registered in the HOOKS registry
# (e.g. via ``custom_imports``) — confirm.
custom_hooks = [{'type': 'LogTrainableParamsHook', 'priority': 'LOW'}]
