# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import (build_ddp, build_dp, fast_compat_cfg, get_device,
                         replace_cfg_vals, rfnext_init_model,
                         setup_multi_processes, update_data_root)
from mqb_general_process import make_qmodel_for_mmd, prepocess
from mqbench.utils.state import *
import global_placeholder
from mqb_general_process import *
from mmcv.image import tensor2imgs
import numpy as np
from mmdet.core.visualization import imshow_det_bboxes
import sys

import cv2
import matplotlib.pyplot as plt
import mmcv
import numpy as np
from mmdet.core.visualization.image import draw_bboxes, draw_labels, _get_adaptive_scales, get_palette, palette_val

# Small margin (in pixels) added to the figure width/height in
# imshow_det_bboxes to avoid precision loss caused by matplotlib's truncation
# (https://github.com/matplotlib/matplotlib/issues/15363).
EPS = 1e-2

def parse_args():
    """Parse the command-line arguments of the offline evaluation script.

    No checkpoint argument is taken: the detections are read from an existing
    pickle file given as ``pkl_result_path``.

    Side effects:
        Sets ``os.environ['LOCAL_RANK']`` from ``--local_rank`` when that
        variable is not already defined.

    Returns:
        argparse.Namespace: The parsed arguments. ``eval_options`` is filled
        from the deprecated ``--options`` when only the latter is given.

    Raises:
        ValueError: If both ``--options`` and ``--eval-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('pkl_result_path', help='pkl result file')
    parser.add_argument('--seed', type=int, default=None, help='random seed')

    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecate), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # Only export LOCAL_RANK when the environment has not provided it already.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args


def main():
    """Evaluate and/or format pre-computed detection results.

    The test dataset is built from the config, the detections are loaded from
    ``args.pkl_result_path`` and handed to ``dataset.format_results`` and/or
    ``dataset.evaluate`` depending on the command-line flags. No model is
    built and no inference is run in this function.

    Raises:
        AssertionError: If none of ``--out``, ``--eval``, ``--format-only``,
            ``--show`` or ``--show-dir`` is given.
        ValueError: If ``--eval`` and ``--format-only`` are both given, or if
            ``--out`` does not end with ``.pkl`` / ``.pickle``.
    """
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(args.config)

    # replace the ${key} with the value of cfg.key
    cfg = replace_cfg_vals(cfg)

    # update data root according to MMDET_DATASETS
    update_data_root(cfg)

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    cfg = fast_compat_cfg(cfg)

    # set multi-process settings
    setup_multi_processes(cfg)

    # --seed defaults to None and torch.manual_seed() rejects None, so the
    # RNGs are only seeded when a seed was explicitly requested.
    if args.seed is not None:
        set_random_seed(args.seed)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]
    cfg.device = get_device()

    # NOTE: init_dist() is not called in this script; the flag below is only
    # forwarded to the dataloader arguments.
    distributed = args.launcher != 'none'

    test_dataloader_default_args = dict(
        samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)

    # in case the test dataset is concatenated
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    test_loader_cfg = {
        **test_dataloader_default_args,
        **cfg.data.get('test_dataloader', {})
    }

    rank, _ = get_dist_info()
    # The metrics file is only prepared on rank 0 and only when a work dir
    # was requested; it stays None otherwise.
    json_file = None
    if args.work_dir is not None and rank == 0:
        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')

    # build the dataset and dataloader; the dataloader is not consumed by the
    # evaluation below, it is what single_gpu_draw() expects for rendering.
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset, **test_loader_cfg)

    # Model construction, checkpoint loading and inference are intentionally
    # skipped: the detections are loaded from the given pkl file instead.
    outputs = mmcv.load(args.pkl_result_path)
    print(f'Loaded pkl result : {args.pkl_result_path}')

    # To render the loaded detections, create the output directory and call
    # single_gpu_draw(outputs, data_loader, out_dir=args.show_dir) here.

    if rank != 0:
        return

    kwargs = {} if args.eval_options is None else args.eval_options
    if args.format_only:
        dataset.format_results(outputs, **kwargs)
    if args.eval:
        eval_kwargs = cfg.get('evaluation', {}).copy()
        # hard-code way to remove EvalHook args
        for key in [
                'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                'rule', 'dynamic_intervals'
        ]:
            eval_kwargs.pop(key, None)
        eval_kwargs.update(dict(metric=args.eval, **kwargs))
        metric = dataset.evaluate(outputs, **eval_kwargs)
        print(metric)
        metric_dict = dict(config=args.config, metric=metric)
        if json_file is not None:
            mmcv.dump(metric_dict, json_file)


def set_random_seed(seed, deterministic=True):
    """Set random seed.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int | None): Seed to be used. When ``None`` the function does
            nothing at all (``torch.manual_seed`` does not accept ``None``),
            so neither the RNGs nor the cuDNN flags are touched.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: True.
    """
    if seed is None:
        return

    import random

    import numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False





def single_gpu_draw(results,
                    data_loader,
                    show=False,
                    out_dir=None,
                    show_score_thr=0.3):
    """Draw pre-computed detection results on the images of a dataloader.

    No model is run: for batch ``k`` the detections are taken from
    ``results[k * batch_size:(k + 1) * batch_size]``. Only images whose
    ``filename`` contains one of the hard-coded ``selected_imgs`` keys are
    rendered.

    Args:
        results (list): Per-image detection results. Assumed to be ordered
            like the samples yielded by ``data_loader`` -- TODO confirm
            against the script that produced the pkl file.
        data_loader: Dataloader providing ``dataset``, ``batch_size`` and
            batches with ``'img'`` and ``'img_metas'`` entries.
        show (bool): Whether to display the rendered images. Default: False.
        out_dir (str | None): Directory the rendered images are written to,
            using each image's ``ori_filename``. Default: None.
        show_score_thr (float): Currently ignored; a fixed threshold of 0.5
            is used. NOTE(review): confirm whether this parameter should
            replace the fixed value.
    """
    dataset = data_loader.dataset
    palette = getattr(dataset, 'PALETTE', None)
    prog_bar = mmcv.ProgressBar(len(dataset))
    batch_size = data_loader.batch_size

    # Images to render; set to None to render every image.
    # VOC selection used previously:
    #   ['000062.jpg', '000069.jpg', '000074.jpg', '000108.jpg',
    #    '000852.jpg', '001285.jpg', '001745.jpg', '006193.jpg']
    # COCO selection:
    selected_imgs = [
        '000000508639.jpg', '000000509403.jpg', '000000512648.jpg',
        '000000512985.jpg', '000000513524.jpg', '000000513580.jpg',
        '000000515350.jpg', '000000515982.jpg', '000000516316.jpg',
        '000000516318.jpg', '000000525155.jpg', '000000525322.jpg',
        '000000530099.jpg', '000000531707.jpg', '000000532493.jpg',
        '000000542127.jpg', '000000542423.jpg', '000000542776.jpg',
        '000000543528.jpg', '000000546011.jpg', '000000547383.jpg',
        '000000547502.jpg', '000000551804.jpg', '000000551815.jpg',
        '000000554579.jpg', '000000555705.jpg',
    ]

    # Fixed drawing parameters.
    score_thr = 0.5
    font_size = 16
    thickness = 2

    for batch_idx, data in enumerate(data_loader):
        start = batch_idx * batch_size
        result = results[start:start + batch_size]

        if show or out_dir:
            if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
                img_tensor = data['img'][0]
            else:
                img_tensor = data['img'][0].data[0]
            img_metas = data['img_metas'][0].data[0]
            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)

            # A separate index name keeps the batch counter intact.
            for img_idx, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                # Filtering mode: skip images that are not selected.
                if selected_imgs is not None and not has_matched_img(
                        img_meta, selected_imgs):
                    continue

                # Crop away the padding, then resize back to the original
                # image size.
                h, w, _ = img_meta['img_shape']
                img_show = img[:h, :w, :]

                ori_h, ori_w = img_meta['ori_shape'][:-1]
                img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                if out_dir:
                    out_file = osp.join(out_dir, img_meta['ori_filename'])
                else:
                    out_file = None

                show_result(
                    dataset,
                    img_show,
                    result[img_idx],
                    bbox_color=palette,
                    text_color=palette,
                    mask_color=palette,
                    show=show,
                    out_file=out_file,
                    score_thr=score_thr,
                    font_size=font_size,
                    thickness=thickness)

        # Advance by the number of results in this batch so that a shorter
        # final batch does not over-count.
        for _ in range(len(result)):
            prog_bar.update()
    return


def has_matched_img(img_meta, selected_imgs):
    """Return True if any key in ``selected_imgs`` occurs in the file name.

    Args:
        img_meta (dict): Image meta info; only ``img_meta['filename']`` is
            read.
        selected_imgs (Iterable[str]): Keys tested with the ``in`` operator
            against the file name (substring match for strings).

    Returns:
        bool: Whether at least one key matched.
    """
    return any(key in img_meta['filename'] for key in selected_imgs)
    
    
def show_result(dataset,
                img,
                result,
                score_thr=0.3,
                bbox_color=(72, 101, 241),
                text_color=(72, 101, 241),
                mask_color=None,
                thickness=2,
                font_size=13,
                win_name='',
                show=False,
                wait_time=0,
                out_file=None):
    """Render one image's detection ``result`` on top of ``img``.

    The per-class result is flattened into box and label arrays (plus masks
    when a segmentation result is present) and passed to
    ``imshow_det_bboxes``.

    Args:
        dataset: Object whose ``CLASSES`` attribute supplies the class names.
        img (str or ndarray): The image to draw on; read via ``mmcv.imread``
            and copied before drawing.
        result (list or tuple): Either ``bbox_result`` (one array per class)
            or a tuple ``(bbox_result, segm_result)``.
        score_thr (float, optional): Minimum score of bboxes to be shown.
            Default: 0.3.
        bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox lines.
            Default: (72, 101, 241).
        text_color (str or tuple(int) or :obj:`Color`): Color of texts.
            Default: (72, 101, 241).
        mask_color (None or str or tuple(int) or :obj:`Color`):
            Color of masks. Default: None.
        thickness (int): Thickness of lines. Default: 2.
        font_size (int): Font size of texts. Default: 13.
        win_name (str): The window name. Default: ''.
        show (bool): Whether to show the image. Ignored (treated as False)
            when ``out_file`` is not None. Default: False.
        wait_time (float): Forwarded to ``imshow_det_bboxes``. Default: 0.
        out_file (str or None): The filename to write the image.
            Default: None.

    Returns:
        ndarray or None: The drawn image, only when neither showing nor
        writing to ``out_file``; otherwise None.
    """
    canvas = mmcv.imread(img).copy()

    # Split the raw result into its box and (optional) mask parts.
    bbox_result, segm_result = result, None
    if isinstance(result, tuple):
        bbox_result, segm_result = result
        if isinstance(segm_result, tuple):
            # ms rcnn
            segm_result = segm_result[0]

    # One label per box, derived from the class position in bbox_result.
    bboxes = np.vstack(bbox_result)
    labels = np.concatenate([
        np.full(cls_boxes.shape[0], cls_idx, dtype=np.int32)
        for cls_idx, cls_boxes in enumerate(bbox_result)
    ])

    # Collect the segmentation masks, if there is anything to draw.
    segms = None
    if segm_result is not None and len(labels) > 0:
        flat_masks = mmcv.concat_list(segm_result)
        if isinstance(flat_masks[0], torch.Tensor):
            segms = torch.stack(flat_masks, dim=0).detach().cpu().numpy()
        else:
            segms = np.stack(flat_masks, axis=0)

    # Writing to a file takes precedence over opening a window.
    show_window = show if out_file is None else False

    rendered = imshow_det_bboxes(
        canvas,
        bboxes,
        labels,
        segms,
        class_names=dataset.CLASSES,
        score_thr=score_thr,
        bbox_color=bbox_color,
        text_color=text_color,
        mask_color=mask_color,
        thickness=thickness,
        font_size=font_size,
        win_name=win_name,
        show=show_window,
        wait_time=wait_time,
        out_file=out_file)

    if show_window or out_file:
        return None
    return rendered



def imshow_det_bboxes(img,
                      bboxes=None,
                      labels=None,
                      segms=None,
                      class_names=None,
                      score_thr=0,
                      bbox_color='green',
                      text_color='green',
                      mask_color=None,
                      thickness=2,
                      font_size=8,
                      win_name='',
                      show=True,
                      wait_time=0,
                      out_file=None):
    """Draw bboxes and class labels (with scores) on an image.

    Args:
        img (str | ndarray): The image to be displayed.
        bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
            (n, 5).
        labels (ndarray): Labels of bboxes.
        segms (ndarray | None): Masks, shaped (n,h,w) or None.
        class_names (list[str]): Names of each classes.
        score_thr (float): Minimum score of bboxes to be shown. Default: 0.
        bbox_color (list[tuple] | tuple | str | None): Colors of bbox lines.
           If a single color is given, it will be applied to all classes.
           The tuple of color should be in RGB order. Default: 'green'.
        text_color (list[tuple] | tuple | str | None): Colors of texts.
           If a single color is given, it will be applied to all classes.
           The tuple of color should be in RGB order. Default: 'green'.
        mask_color (list[tuple] | tuple | str | None, optional): Colors of
           masks. If a single color is given, it will be applied to all
           classes. The tuple of color should be in RGB order.
           Default: None.
        thickness (int): Thickness of lines. Default: 2.
        font_size (int): Font size of texts. Default: 8.
        show (bool): Whether to show the image. Default: True.
        win_name (str): The window name. Default: ''.
        wait_time (float): Pause (passed to ``plt.pause``) when showing
            non-blocking; 0 means a blocking ``plt.show()``. Default: 0.
        out_file (str, optional): The filename to write the image.
            Default: None.

    Returns:
        ndarray: The image with bboxes drawn on it (converted back to BGR).
    """
    assert bboxes is None or bboxes.ndim == 2, \
        f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
    assert labels.ndim == 1, \
        f' labels ndim should be 1, but its ndim is {labels.ndim}.'
    assert bboxes is None or bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
        f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.'
    assert bboxes is None or bboxes.shape[0] <= labels.shape[0], \
        'labels.shape[0] should not be less than bboxes.shape[0].'
    assert segms is None or segms.shape[0] == labels.shape[0], \
        'segms.shape[0] and labels.shape[0] should have the same length.'
    assert segms is not None or bboxes is not None, \
        'segms and bboxes should not be None at the same time.'

    img = mmcv.imread(img).astype(np.uint8)

    # Drop detections whose score does not exceed the threshold.
    if score_thr > 0:
        assert bboxes is not None and bboxes.shape[1] == 5
        scores = bboxes[:, -1]
        inds = scores > score_thr
        bboxes = bboxes[inds, :]
        labels = labels[inds]
        if segms is not None:
            segms = segms[inds, ...]

    img = mmcv.bgr2rgb(img)
    width, height = img.shape[1], img.shape[0]
    img = np.ascontiguousarray(img)

    fig = plt.figure(win_name, frameon=False)
    plt.title(win_name)
    canvas = fig.canvas
    dpi = fig.get_dpi()
    # add a small EPS to avoid precision lost due to matplotlib's truncation
    # (https://github.com/matplotlib/matplotlib/issues/15363)
    fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)

    # remove white edges by set subplot margin
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    ax = plt.gca()
    ax.axis('off')

    max_label = int(max(labels) if len(labels) > 0 else 0)
    text_palette = palette_val(get_palette(text_color, max_label + 1))
    # text_colors[k] is the text colour of labels[k].
    text_colors = [text_palette[label] for label in labels]

    num_bboxes = 0
    if bboxes is not None:
        num_bboxes = bboxes.shape[0]
        bbox_palette = palette_val(get_palette(bbox_color, max_label + 1))
        colors = [bbox_palette[label] for label in labels[:num_bboxes]]
        draw_bboxes(ax, bboxes, colors, alpha=0.8, thickness=thickness)

        horizontal_alignment = 'left'
        # Place each label just inside the top-left corner of its box.
        positions = bboxes[:, :2].astype(np.int32) + thickness
        areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
        scales = _get_adaptive_scales(areas)
        scores = bboxes[:, 4] if bboxes.shape[1] == 5 else None
        draw_labels(
            ax,
            labels[:num_bboxes],
            positions,
            scores=scores,
            class_names=class_names,
            color=text_colors,
            font_size=font_size,
            scales=scales,
            horizontal_alignment=horizontal_alignment)

    if segms is not None:
        # draw_masks is not among the names this file imports explicitly at
        # module level, so it is imported here to keep the mask path usable.
        from mmdet.core.visualization.image import draw_masks

        mask_palette = get_palette(mask_color, max_label + 1)
        colors = [mask_palette[label] for label in labels]
        colors = np.array(colors, dtype=np.uint8)
        draw_masks(ax, img, segms, colors, with_edge=True)

        # Masks without a matching box get their label drawn at the centroid
        # of their largest connected component.
        if num_bboxes < segms.shape[0]:
            segms = segms[num_bboxes:]
            horizontal_alignment = 'center'
            areas = []
            positions = []
            for mask in segms:
                _, _, stats, centroids = cv2.connectedComponentsWithStats(
                    mask.astype(np.uint8), connectivity=8)
                # Row 0 of stats is skipped, hence the +1 offset.
                largest_id = np.argmax(stats[1:, -1]) + 1
                positions.append(centroids[largest_id])
                areas.append(stats[largest_id, -1])
            areas = np.stack(areas, axis=0)
            scales = _get_adaptive_scales(areas)
            draw_labels(
                ax,
                labels[num_bboxes:],
                positions,
                class_names=class_names,
                # draw_labels indexes colours from 0, so pass the slice that
                # lines up with labels[num_bboxes:].
                color=text_colors[num_bboxes:],
                font_size=font_size,
                scales=scales,
                horizontal_alignment=horizontal_alignment)

    plt.imshow(img)

    # Rasterise the figure and convert the RGBA buffer back to a BGR image.
    stream, _ = canvas.print_to_buffer()
    buffer = np.frombuffer(stream, dtype='uint8')
    if sys.platform == 'darwin':
        width, height = canvas.get_width_height(physical=True)
    img_rgba = buffer.reshape(height, width, 4)
    rgb, _ = np.split(img_rgba, [3], axis=2)
    img = rgb.astype('uint8')
    img = mmcv.rgb2bgr(img)

    if show:
        # We do not use cv2 for display because in some cases, opencv will
        # conflict with Qt, it will output a warning: Current thread
        # is not the object's thread. You can refer to
        # https://github.com/opencv/opencv-python/issues/46 for details
        if wait_time == 0:
            plt.show()
        else:
            plt.show(block=False)
            plt.pause(wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    plt.close()

    return img


def draw_labels(ax,
                labels,
                positions,
                scores=None,
                class_names=None,
                color='w',
                font_size=8,
                scales=None,
                horizontal_alignment='left'):
    """Write one text label per (position, label) pair onto ``ax``.

    Each text is the class name (or ``'class <label>'`` when no names are
    given), optionally followed by ``|<score>`` with two decimals, drawn on
    a semi-transparent black box via ``ax.text``.

    Args:
        ax (matplotlib.Axes): The axes to draw on.
        labels (ndarray): The labels with the shape of (n, ).
        positions (ndarray): The (x, y) position of each label.
        scores (ndarray): The score of each label. Default: None.
        class_names (list[str]): The class names. Default: None.
        color (list[tuple] | matplotlib.color): A list gives one color per
            label (indexed by position in ``labels``); anything else is used
            for every label. Default: 'w'.
        font_size (int): Font size of texts. Default: 8.
        scales (list[float]): Per-label multipliers applied to
            ``font_size``. Default: None.
        horizontal_alignment (str): The horizontal alignment method of
            texts. Default: 'left'.

    Returns:
        matplotlib.Axes: The same axes that was passed in.
    """
    per_label_color = isinstance(color, list)
    box_style = {
        'facecolor': 'black',
        'alpha': 0.6,
        'pad': 0.7,
        'edgecolor': 'none'
    }

    for idx, (anchor, cls_id) in enumerate(zip(positions, labels)):
        if class_names is None:
            caption = f'class {cls_id}'
        else:
            caption = class_names[cls_id]
        if scores is not None:
            caption += f'|{scores[idx]:.02f}'

        if per_label_color:
            caption_color = color[idx]
        else:
            caption_color = color

        if scales is None:
            scaled_size = font_size
        else:
            scaled_size = font_size * scales[idx]

        ax.text(
            anchor[0],
            anchor[1],
            f'{caption}',
            # Each text gets its own copy of the box style.
            bbox=dict(box_style),
            color=caption_color,
            fontsize=scaled_size,
            verticalalignment='top',
            horizontalalignment=horizontal_alignment)

    return ax






# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
