# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.testing import demo_track_inputs, random_boxes
from mmdet.utils import register_all_modules


class TestByteTracker(TestCase):
    """Smoke tests for the ``ByteTracker`` registered in ``MODELS``.

    One tracker instance is built in :meth:`setUpClass` and shared by every
    test method of this class.
    """

    @classmethod
    def setUpClass(cls):
        """Register all modules and build the tracker used by the tests."""
        register_all_modules(init_default_scope=True)

        cls.num_objs = 30

        tracker_cfg = {
            'type': 'ByteTracker',
            'motion': {'type': 'KalmanFilter'},
            'obj_score_thrs': {'high': 0.6, 'low': 0.1},
            'init_track_thr': 0.7,
            'weight_iou_with_det_scores': True,
            'match_iou_thrs': {'high': 0.1, 'low': 0.5, 'tentative': 0.3},
            'num_tentatives': 3,
            'num_frames_retain': 30,
        }
        cls.tracker = MODELS.build(tracker_cfg)
        # A Kalman filter is attached explicitly as ``kf``.
        # NOTE(review): presumably ``update`` needs ``kf`` before ``track``
        # has ever been called -- confirm against ByteTracker.
        kalman_cfg = {'type': 'KalmanFilter'}
        cls.tracker.kf = TASK_UTILS.build(kalman_cfg)
        cls.num_frames_retain = tracker_cfg['num_frames_retain']

    def test_init(self):
        """Feed one frame of random boxes to ``update`` and check the memo."""
        count = self.num_objs
        track_ids = torch.arange(count)
        boxes = random_boxes(count, 512)
        box_scores = torch.ones(count)
        box_labels = torch.zeros(count)

        self.tracker.update(
            ids=track_ids,
            bboxes=boxes,
            scores=box_scores,
            labels=box_labels,
            frame_ids=0)

        expected_memo_items = [
            'ids', 'bboxes', 'scores', 'labels', 'frame_ids'
        ]
        assert self.tracker.ids == list(track_ids)
        assert self.tracker.memo_items == expected_memo_items

    def test_track(self):
        """Run ``track`` on each frame of a demo clip and check the shapes."""
        with torch.no_grad():
            demo_inputs = demo_track_inputs(batch_size=1, num_frames=2)
            video_sample = demo_inputs['data_samples'][0]

            for frame_idx in range(len(video_sample)):
                frame_sample = video_sample[frame_idx]

                # Ground-truth instances stand in for detections; they get
                # an all-ones fake score per box.
                num_boxes = len(frame_sample.gt_instances.bboxes)
                fake_scores = torch.ones(num_boxes)
                frame_sample.pred_instances = \
                    frame_sample.gt_instances.clone()
                frame_sample.pred_instances.scores = torch.FloatTensor(
                    fake_scores)

                tracked = self.tracker.track(data_sample=frame_sample)

                out_bboxes = tracked.bboxes
                out_labels = tracked.labels

                assert out_bboxes.shape[1] == 4
                assert out_bboxes.shape[0] == out_labels.shape[0]
