# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from typing import Tuple

import cv2
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.utils import track_iter_progress

from mmdet.apis import init_detector
from mmdet.registry import VISUALIZERS
from mmdet.structures import DetDataSample

try:
    import ffmpegcv
except ImportError:
    raise ImportError(
        'Please install ffmpegcv with:\n\n    pip install ffmpegcv')


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDetection video demo with GPU acceleration')
    parser.add_argument('video', help='Video file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument('--out', type=str, help='Output video file')
    parser.add_argument('--show', action='store_true', help='Show video')
    parser.add_argument(
        '--nvdecode', action='store_true', help='Use NVIDIA decoder')
    parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='The interval of show (s), 0 is block')
    args = parser.parse_args()
    return args


def prefetch_batch_input_shape(model: nn.Module, ori_wh: Tuple[int,
                                                               int]) -> dict:
    cfg = model.cfg
    w, h = ori_wh
    cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    data = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
    data = test_pipeline(data)
    _, data_sample = model.data_preprocessor([data], False)
    batch_input_shape = data_sample[0].batch_input_shape
    return batch_input_shape


def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
              ori_shape: Tuple[int, int]) -> dict:
    assert frame_resize.shape[:2] == batch_input_shape
    data_sample = DetDataSample()
    data_sample.set_metainfo({
        'img_shape':
        batch_input_shape,
        'ori_shape':
        ori_shape,
        'scale_factor': (batch_input_shape[0] / ori_shape[0],
                         batch_input_shape[1] / ori_shape[1])
    })
    frame_resize = torch.from_numpy(frame_resize).permute((2, 0, 1))
    data = {'inputs': frame_resize, 'data_sample': data_sample}
    return data


def main():
    args = parse_args()
    assert args.out or args.show, \
        ('Please specify at least one operation (save/show the '
         'video) with the argument "--out" or "--show"')

    model = init_detector(args.config, args.checkpoint, device=args.device)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then pass to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

    if args.nvdecode:
        VideoCapture = ffmpegcv.VideoCaptureNV
    else:
        VideoCapture = ffmpegcv.VideoCapture
    video_origin = VideoCapture(args.video)

    batch_input_shape = prefetch_batch_input_shape(
        model, (video_origin.width, video_origin.height))
    ori_shape = (video_origin.height, video_origin.width)
    resize_wh = batch_input_shape[::-1]
    video_resize = VideoCapture(
        args.video,
        resize=resize_wh,
        resize_keepratio=True,
        resize_keepratioalign='topleft')

    video_writer = None
    if args.out:
        video_writer = ffmpegcv.VideoWriter(args.out, fps=video_origin.fps)

    with torch.no_grad():
        for i, (frame_resize, frame_origin) in enumerate(
                zip(track_iter_progress(video_resize), video_origin)):
            data = pack_data(frame_resize, batch_input_shape, ori_shape)
            result = model.test_step([data])[0]

            visualizer.add_datasample(
                name='video',
                image=frame_origin,
                data_sample=result,
                draw_gt=False,
                show=False,
                pred_score_thr=args.score_thr)

            frame_mask = visualizer.get_image()

            if args.show:
                cv2.namedWindow('video', 0)
                mmcv.imshow(frame_mask, 'video', args.wait_time)
            if args.out:
                video_writer.write(frame_mask)

    if video_writer:
        video_writer.release()
    video_origin.release()
    video_resize.release()

    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
