# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import unittest
from unittest.mock import MagicMock, patch

import pytest

from mmdet.datasets import DATASETS


@patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock)
@pytest.mark.parametrize('dataset',
                         ['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
def test_custom_classes_override_default(dataset):
    """Check that the ``classes`` argument overrides a dataset's ``CLASSES``.

    ``load_annotations`` and ``_filter_imgs`` are replaced by mocks through
    the decorators above, and ``ann_file`` is itself a mock, so the test only
    exercises how the ``classes`` argument ends up in ``CLASSES``.

    Covered inputs: a tuple, a list, a list that is not a subset of the
    default classes, ``None`` (keeps the defaults) and a path to a text file
    with one class name per line.

    Args:
        dataset (str): Name of the dataset class looked up in ``DATASETS``.
    """
    import tempfile
    from contextlib import ExitStack

    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.CLASSES

    def build_dataset(classes):
        """Instantiate ``dataset_class`` with the given ``classes`` value."""
        return dataset_class(
            ann_file=MagicMock(),
            pipeline=[],
            classes=classes,
            test_mode=True,
            img_prefix='VOC2007' if dataset == 'VOCDataset' else '')

    with ExitStack() as stack:
        if dataset in ['CocoDataset', 'CityscapesDataset']:
            # Stub the COCO-specific attributes (presumably set by the real
            # ``load_annotations`` - TODO confirm). ``patch.object`` restores
            # the class on exit, so the mocks do not leak into other tests
            # the way a plain class-attribute assignment would.
            for attr in ('coco', 'cat_ids'):
                stack.enter_context(
                    patch.object(
                        dataset_class, attr, MagicMock(), create=True))

        # Test setting classes as a tuple
        custom_dataset = build_dataset(('bus', 'car'))
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == ('bus', 'car')
        print(custom_dataset)

        # Test setting classes as a list
        custom_dataset = build_dataset(['bus', 'car'])
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == ['bus', 'car']
        print(custom_dataset)

        # Test overriding not a subset
        custom_dataset = build_dataset(['foo'])
        assert custom_dataset.CLASSES != original_classes
        assert custom_dataset.CLASSES == ['foo']
        print(custom_dataset)

        # Test default behavior
        custom_dataset = build_dataset(None)
        assert custom_dataset.CLASSES == original_classes
        print(custom_dataset)

        # Test sending file path
        with tempfile.TemporaryDirectory() as tmpdir:
            # Join with a separator so the file is created *inside* the
            # temporary directory and is removed together with it.
            path = osp.join(tmpdir, 'classes.txt')
            with open(path, 'w') as f:
                f.write('bus\ncar\n')

            # Build and check the dataset while the file still exists.
            custom_dataset = build_dataset(path)
            assert custom_dataset.CLASSES != original_classes
            assert custom_dataset.CLASSES == ['bus', 'car']
            print(custom_dataset)


class CustomDatasetTests(unittest.TestCase):
    """Checks ``data_infos`` produced by ``XMLDataset`` for on-disk data."""

    def setUp(self):
        """Locate the shared ``data`` directory and the dataset class."""
        super().setUp()
        # The data directory sits three levels above this file.
        base_dir = __file__
        for _ in range(3):
            base_dir = osp.dirname(base_dir)
        self.data_dir = osp.join(base_dir, 'data')
        self.dataset_class = DATASETS.get('XMLDataset')

    def test_data_infos__default_db_directories(self):
        """Read data laid out in the Pascal VOC directory structure."""
        voc_root = osp.join(self.data_dir, 'VOCdevkit', 'VOC2007')
        split_file = osp.join(voc_root, 'ImageSets', 'Main', 'trainval.txt')

        dataset = self.dataset_class(
            data_root=voc_root,
            ann_file=split_file,
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True)

        expected_infos = [
            dict(
                id='000001',
                filename=osp.join('JPEGImages', '000001.jpg'),
                width=353,
                height=500)
        ]
        self.assertListEqual(expected_infos, dataset.data_infos)

    def test_data_infos__overridden_db_subdirectories(self):
        """Read data laid out in a customized directory structure."""
        custom_root = osp.join(self.data_dir, 'custom_dataset')
        split_file = osp.join(custom_root, 'trainval.txt')

        # Images and annotations both live in the ``images`` sub-directory.
        dataset = self.dataset_class(
            data_root=custom_root,
            ann_file=split_file,
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True,
            img_prefix='',
            img_subdir='images',
            ann_subdir='images')

        expected_infos = [
            dict(
                id='000001',
                filename=osp.join('images', '000001.jpg'),
                width=353,
                height=500)
        ]
        self.assertListEqual(expected_infos, dataset.data_infos)
