# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmdet.registry import MODELS
from mmdet.structures import ReIDDataSample
from mmdet.utils import register_all_modules


class TestLinearReIDHead(TestCase):
    """Unit tests for the ``LinearReIDHead`` built through ``MODELS``.

    The checks use ``TestCase`` assertion methods instead of bare ``assert``
    statements so that they are still executed under ``python -O`` and report
    the mismatching values on failure.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Build the head and the fixtures shared by every test.

        Creates a ``LinearReIDHead`` (128 input channels, 32 output
        channels, 2 classes), a tuple of two random ``(4, 128)`` tensors
        used as inputs, and four ``ReIDDataSample`` objects labelled
        ``0, 0, 1, 1``.
        """
        register_all_modules()
        head_cfg = dict(
            type='LinearReIDHead',
            num_fcs=1,
            in_channels=128,
            fc_channels=64,
            out_channels=32,
            num_classes=2,
            loss_cls=dict(type='mmpretrain.CrossEntropyLoss', loss_weight=1.0),
            loss_triplet=dict(type='TripletLoss', margin=0.3, loss_weight=1.0),
            norm_cfg=dict(type='BN1d'),
            act_cfg=dict(type='ReLU'))
        cls.head = MODELS.build(head_cfg)
        cls.inputs = (torch.rand(4, 128), torch.rand(4, 128))
        # NOTE(review): two samples per label presumably lets the triplet
        # loss form both positive and negative pairs -- confirm against the
        # TripletLoss implementation.
        cls.data_samples = [
            ReIDDataSample().set_gt_label(label) for label in (0, 0, 1, 1)
        ]

    def test_forward(self):
        """Forward pass yields one ``out_channels``-sized row per sample."""
        outputs = self.head(self.inputs)
        self.assertEqual(outputs.shape, (4, 32))

    def test_loss(self):
        """``loss`` returns exactly the expected keys, with losses >= 0."""
        losses = self.head.loss(self.inputs, self.data_samples)
        expected_keys = {'triplet_loss', 'ce_loss', 'accuracy_top-1'}
        self.assertSetEqual(set(losses.keys()), expected_keys)
        self.assertGreaterEqual(losses['ce_loss'].item(), 0)
        self.assertGreaterEqual(losses['triplet_loss'].item(), 0)

    def test_predict(self):
        """``predict`` returns data samples carrying a 32-d feature."""
        predictions = self.head.predict(self.inputs, self.data_samples)
        expected_shape = (32, )
        for pred in predictions:
            self.assertIsInstance(pred, ReIDDataSample)
            self.assertIsInstance(pred.pred_feature, torch.Tensor)
            self.assertIsInstance(pred.gt_label.label, torch.Tensor)
            self.assertEqual(pred.pred_feature.shape, expected_shape)
