from omegaconf import OmegaConf

import detectron2.data.transforms as T
from detectron2.config import LazyCall as L
from detectron2.data import (
    RefCOCOMapper,
    build_detection_test_loader,
    build_detection_train_loader,
    get_detection_dataset_dicts,
)
from detectron2.evaluation import ReferEvaluator

dataloader = OmegaConf.create()

# Training-time augmentations: a multi-scale shortest-edge resize
# (sample_style="choice" picks one of the listed short-edge lengths)
# followed by a horizontal flip.
# Underscore-prefixed helpers are skipped by LazyConfig when it collects
# the exported config, so only ``dataloader`` is visible to the trainer.
_train_augmentations = [
    L(T.ResizeShortestEdge)(
        short_edge_length=(640, 672, 704, 736, 768, 800),
        sample_style="choice",
        max_size=1333,
    ),
    L(T.RandomFlip)(horizontal=True),
]

# Mapper for training samples; the test mappers read ``image_format`` back
# from here through an interpolation, so this is the single source of truth.
_train_mapper = L(RefCOCOMapper)(
    is_train=True,
    augmentations=_train_augmentations,
    image_format="BGR",
    use_instance_mask=True,
)

dataloader.train = L(build_detection_train_loader)(
    dataset=L(get_detection_dataset_dicts)(names='grefcoco-unc-train'),
    mapper=_train_mapper,
    total_batch_size=16,
    num_workers=4,
)

# Registered gRefCOCO ("unc") dataset names to evaluate. This list is the
# single driver for both ``dataloader.tests`` and ``dataloader.evaluators``,
# so the two stay aligned by index.
# NOTE(review): the training split is included as well; confirm that
# evaluating on train is intended.
grefcoco_test_dataset_names = [
    "grefcoco-unc-train",
    "grefcoco-unc-val",
    "grefcoco-unc-testA",
    "grefcoco-unc-testB",
]
def _grefcoco_test_loader(dataset_name):
    """Return a lazy test-loader config for one registered gRefCOCO split.

    Empty images are kept (``filter_empty=False``) so every sample of the
    split is evaluated.
    """
    test_mapper = L(RefCOCOMapper)(
        is_train=False,
        # Relative interpolation, resolved against the final config tree:
        # "." is this mapper, ".." the loader entry, "..." the ``tests`` list,
        # "...." the ``dataloader`` node, so the training image format is reused.
        image_format="${....train.mapper.image_format}",
    )
    return L(build_detection_test_loader)(
        dataset=L(get_detection_dataset_dicts)(names=dataset_name, filter_empty=False),
        mapper=test_mapper,
        num_workers=4,
        batch_size=4,
    )


dataloader.tests = [_grefcoco_test_loader(split_name) for split_name in grefcoco_test_dataset_names]

# One ReferEvaluator per test split. It is built from the same name list, in
# the same order, as ``dataloader.tests``, so index i in both lists refers to
# the same dataset.
dataloader.evaluators = [
    L(ReferEvaluator)(dataset_name=split_name)
    for split_name in grefcoco_test_dataset_names
]