# Copyright (c) OpenMMLab. All rights reserved.
import unittest
from unittest import TestCase

import torch
from parameterized import parameterized

from mmdet.structures import DetDataSample
from mmdet.testing import demo_mm_inputs, get_detector_cfg
from mmdet.utils import register_all_modules


class TestGLIP(TestCase):
    """Smoke tests for building a GLIP detector and running prediction."""

    def setUp(self):
        """Register all mmdet modules before each test."""
        register_all_modules()

    @parameterized.expand(
        ['glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py'])
    def test_init(self, cfg_file):
        """Build the detector from ``cfg_file`` and check its components.

        Args:
            cfg_file (str): Config path passed to ``get_detector_cfg``.
        """
        model = get_detector_cfg(cfg_file)
        # NOTE(review): presumably clears the backbone init config so that
        # no pretrained checkpoint is loaded during the test -- confirm.
        model.backbone.init_cfg = None

        from mmdet.registry import MODELS
        detector = MODELS.build(model)
        self.assertTrue(detector.backbone)
        self.assertTrue(detector.language_model)
        self.assertTrue(detector.neck)
        self.assertTrue(detector.bbox_head)

    @parameterized.expand([
        ('glip/glip_atss_swin-t_a_fpn_dyhead_pretrain_obj365.py', ('cpu',
                                                                   'cuda'))
    ])
    def test_glip_forward_predict_mode(self, cfg_file, devices):
        """Run ``forward(mode='predict')`` on every requested device.

        Each device is exercised in its own sub-test. The ``'cuda'``
        sub-test is reported as skipped when CUDA is unavailable.

        Args:
            cfg_file (str): Config path passed to ``get_detector_cfg``.
            devices (tuple[str]): Devices to test; each entry must be
                ``'cpu'`` or ``'cuda'``.
        """
        model = get_detector_cfg(cfg_file)
        model.backbone.init_cfg = None

        from mmdet.registry import MODELS
        self.assertTrue(
            all(device in ('cpu', 'cuda') for device in devices))

        for device in devices:
            with self.subTest(device=device):
                detector = MODELS.build(model)

                if device == 'cuda':
                    if not torch.cuda.is_available():
                        # ``unittest.skip(reason)`` only builds a decorator;
                        # returning it would end the test as a silent pass.
                        # Raising SkipTest records an actual skip.
                        raise unittest.SkipTest(
                            'test requires GPU and torch+cuda')
                    detector = detector.cuda()

                # test custom_entities is True
                packed_inputs = demo_mm_inputs(
                    2, [[3, 128, 128], [3, 125, 130]],
                    texts=['a', 'b'],
                    custom_entities=True)
                data = detector.data_preprocessor(packed_inputs, False)
                # Test forward test
                detector.eval()
                with torch.no_grad():
                    batch_results = detector.forward(**data, mode='predict')
                    self.assertEqual(len(batch_results), 2)
                    self.assertIsInstance(batch_results[0], DetDataSample)

                # NOTE(review): the custom_entities=False case below is
                # disabled; the reason is not visible in this file --
                # confirm before re-enabling or removing it.
                # test custom_entities is False
                # packed_inputs = demo_mm_inputs(
                #     2, [[3, 128, 128], [3, 125, 130]],
                #     texts=['a', 'b'],
                #     custom_entities=False)
                # data = detector.data_preprocessor(packed_inputs, False)
                # # Test forward test
                # detector.eval()
                # with torch.no_grad():
                #     batch_results = detector.forward(**data, mode='predict')
                #     self.assertEqual(len(batch_results), 2)
                #     self.assertIsInstance(batch_results[0], DetDataSample)
