import os.path as osp
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import pycocotools.mask as mask_util
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
import cv2

from ..utils import mask2ndarray


def color_val_matplotlib(color):
    """Turn a BGR colour specification into a matplotlib RGB tuple.

    The input is first resolved by ``mmcv.color_val``; the resulting
    channels are then reversed and each one is divided by 255.

    Args:
        color (:obj:`Color`/str/tuple/int/ndarray): Colour input, forwarded
            unchanged to ``mmcv.color_val``.

    Returns:
        tuple[float]: The channels in reversed (RGB) order, each divided
            by 255.
    """
    bgr = mmcv.color_val(color)
    return tuple(channel / 255 for channel in bgr[::-1])


def imshow_det_bboxes(img,
                      bboxes,
                      labels,
                      segms=None,
                      class_names=None,
                      score_thr=0,
                      bbox_color='green',
                      text_color='green',
                      mask_color=None,
                      thickness=2,
                      font_scale=0.5,
                      font_size=13,
                      win_name='',
                      fig_size=(15, 10),
                      show=True,
                      wait_time=0,
                      out_file=None):
    """Plot boxes, captions and optional masks with matplotlib.

    Two images are produced. The matplotlib figure (shown and/or saved to
    ``out_file``) holds every remaining box, its caption and the blended
    masks. The returned array is a separate copy of the input on which
    only the first remaining box and its caption are drawn with OpenCV.

    When ``score_thr`` is positive, the square root of the score column is
    compared against it and only the first box that passes is kept.

    Args:
        img (str or ndarray): The image to be displayed.
        bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
            (n, 5).
        labels (ndarray): Labels of bboxes, shaped (n, ).
        segms (ndarray or None): Masks, shaped (n,h,w) or None
        class_names (list[str]): Names of each classes. When None the
            caption is ``'Obj. <label>'``.
        score_thr (float): Threshold applied to ``sqrt(score)``; requires
            5-column ``bboxes`` when positive. Default: 0
        bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox
            lines. The tuple of color should be in BGR order.
            Default: 'green'
        text_color (str or tuple(int) or :obj:`Color`): Color of texts.
            The tuple of color should be in BGR order. Default: 'green'
        mask_color (str or tuple(int) or :obj:`Color`, optional):
            Color of masks. The tuple of color should be in BGR order.
            When None, one random color per label index is drawn after
            seeding NumPy's global RNG with 42. Default: None
        thickness (int): Line width of the matplotlib outlines; also the
            thickness passed to the OpenCV rectangle and text. Default: 2
        font_scale (float): Not used by the body; kept for the
            deprecation warning. Default: 0.5
        font_size (int): Font size of the matplotlib captions. Default: 13
        win_name (str): Figure label and title. Default: ''
        fig_size (tuple): Figure size of the pyplot figure.
            Default: (15, 10)
        show (bool): Whether to show the figure. Default: True
        wait_time (float): 0 calls ``plt.show()`` directly; any other
            value shows non-blocking, pauses that long and closes the
            figure. Default: 0.
        out_file (str, optional): The filename to write the figure to.
            Default: None

    Returns:
        ndarray: BGR copy of the input with the first remaining box and
            caption drawn on it.
    """
    warnings.warn('"font_scale" will be deprecated in v2.9.0,'
                  'Please use "font_size"')
    assert bboxes.ndim == 2, \
        f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
    assert labels.ndim == 1, \
        f' labels ndim should be 1, but its ndim is {labels.ndim}.'
    assert bboxes.shape[0] == labels.shape[0], \
        'bboxes.shape[0] and labels.shape[0] should have the same length.'
    assert bboxes.shape[1] in (4, 5), \
        f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.'
    img = mmcv.imread(img).copy()

    if score_thr > 0:
        assert bboxes.shape[1] == 5
        # The threshold is applied to the square root of the raw score.
        keep = np.sqrt(bboxes[:, -1]) > score_thr
        # Only the first surviving detection is retained.
        bboxes = bboxes[keep, :][0:1]
        labels = labels[keep][0:1]
        if segms is not None:
            segms = segms[keep, ...]

    mask_colors = []
    if labels.shape[0] > 0:
        num_colors = max(labels) + 1
        if mask_color is None:
            # Fixed seed so each label index always gets the same color.
            np.random.seed(42)
            mask_colors = [
                np.random.randint(0, 256, (1, 3), dtype=np.uint8)
                for _ in range(num_colors)
            ]
        else:
            fixed_rgb = np.array(
                mmcv.color_val(mask_color)[::-1], dtype=np.uint8)
            mask_colors = [fixed_rgb] * num_colors

    bbox_color = color_val_matplotlib(bbox_color)
    text_color = color_val_matplotlib(text_color)

    img = np.ascontiguousarray(mmcv.bgr2rgb(img))

    # Separate canvas for the OpenCV drawing that is returned to the caller.
    vis_img = img.copy()

    plt.figure(win_name, figsize=fig_size)
    plt.title(win_name)
    plt.axis('off')
    ax = plt.gca()

    caption_box = dict(facecolor='black', alpha=0.6, pad=0.7,
                       edgecolor='none')
    cv_color = (36, 255, 12)
    box_patches = []
    edge_colors = []
    for idx, (box, cls_id) in enumerate(zip(bboxes, labels)):
        corners = box.astype(np.int32)
        left, top = corners[0], corners[1]
        right, bottom = corners[2], corners[3]
        outline = np.array([[left, top], [left, bottom],
                            [right, bottom], [right, top]]).reshape((4, 2))
        box_patches.append(Polygon(outline))
        edge_colors.append(bbox_color)

        if class_names is None:
            caption = f'Obj. {cls_id}'
        else:
            caption = class_names[cls_id]
        if len(box) > 4:
            caption += f'|{box[-1]:.02f}'
        ax.text(
            left,
            top,
            f'{caption}',
            bbox=caption_box,
            color=text_color,
            fontsize=font_size,
            verticalalignment='top',
            horizontalalignment='left')

        if segms is not None:
            # Blend the mask color half-and-half into the plotted image.
            blend_color = mask_colors[labels[idx]]
            region = segms[idx].astype(bool)
            img[region] = img[region] * 0.5 + blend_color * 0.5

        if idx == 0:
            vis_img = cv2.rectangle(vis_img, (left, top), (right, bottom),
                                    cv_color, thickness)
            vis_img = cv2.putText(vis_img, f'{caption}', (left, top),
                                  cv2.FONT_HERSHEY_SIMPLEX, 0.5, cv_color,
                                  thickness)

    plt.imshow(img)

    ax.add_collection(
        PatchCollection(
            box_patches,
            facecolor='none',
            edgecolors=edge_colors,
            linewidths=thickness))

    if out_file is not None:
        mmcv.mkdir_or_exist(osp.abspath(osp.dirname(out_file)))
        plt.savefig(out_file)
        if not show:
            plt.close()
    if show:
        if wait_time == 0:
            plt.show()
        else:
            plt.show(block=False)
            plt.pause(wait_time)
            plt.close()
    return mmcv.rgb2bgr(vis_img)

def _broadcast_group_colors(color, num_groups):
    """Return one normalized RGB tuple per bbox group.

    A single color given as a string (or any unsized value such as a
    :obj:`Color` member) is repeated for every group. Any other sized
    input is read as one color entry per group, element by element.
    """
    if isinstance(color, str) or not hasattr(color, '__len__'):
        # Iterating a string such as 'green' would hand single characters
        # to the color lookup, so treat it as one color for all groups.
        color = [color] * num_groups
    return [color_val_matplotlib(entry) for entry in color]


def imshow_det_multi_bboxes(img,
                      bboxes,
                      labels,
                      gt_box=None,
                      segms=None,
                      class_names=None,
                      score_thr=0,
                      bbox_color='green',
                      text_color='green',
                      mask_color=None,
                      thickness=2,
                      font_scale=0.5,
                      font_size=13,
                      win_name='',
                      fig_size=(15, 10),
                      show=True,
                      wait_time=0,
                      out_file=None):
    """Draw several groups of bboxes, one color per group, on an image.

    If ``gt_box`` is an ndarray, only that box is outlined (in green) and
    the detections in ``bboxes`` are not drawn. Otherwise every box of
    every group is drawn on the matplotlib figure, iterating the groups
    from last to first. The returned array is a separate copy of the input
    on which only the first box of each group is drawn with OpenCV.

    Args:
        img (str or ndarray): The image to be displayed.
        bboxes (list[ndarray]): One array of boxes per group, each shaped
            (n, 4) or (n, 5). When ``score_thr`` is positive the last
            column is read as the score.
        labels (list[ndarray]): One label array per group, aligned with
            ``bboxes``.
        gt_box (ndarray, optional): A single box whose elements 0..3 are
            used as corner coordinates. Default: None
        segms (ndarray or None): Not used by this function.
        class_names (list[str]): Names of each classes. When None the
            caption is ``'Obj. <label>'``.
        score_thr (float): Threshold applied to ``sqrt(score)``. When
            positive, only the first passing box of each group is kept.
            Default: 0
        bbox_color (str or :obj:`Color` or sequence): Color of bbox lines.
            A string or :obj:`Color` is used for every group; a sequence
            is read as one color entry per group. Default: 'green'
        text_color (str or :obj:`Color` or sequence): Color of texts,
            interpreted like ``bbox_color``. Default: 'green'
        mask_color (str or tuple(int) or :obj:`Color`, optional): Not
            used by this function. Default: None
        thickness (int): Line width of the matplotlib outlines; also the
            thickness passed to the OpenCV rectangle and text. Default: 2
        font_scale (float): Not used by the body; kept for the
            deprecation warning. Default: 0.5
        font_size (int): Font size of texts; the caption is also shifted
            up by this amount. Default: 13
        show (bool): Whether to show the image. Default: True
        win_name (str): Figure label and title. Default: ''
        fig_size (tuple): Figure size of the pyplot figure.
            Default: (15, 10)
        wait_time (float): 0 calls ``plt.show()`` directly; any other
            value shows non-blocking, pauses that long and closes the
            figure. Default: 0.
        out_file (str, optional): The filename to write the image.
            Default: None

    Returns:
        ndarray: BGR copy of the input with the first box of each group
            drawn on it (unchanged when ``gt_box`` is an ndarray).
    """
    warnings.warn('"font_scale" will be deprecated in v2.9.0,'
                  'Please use "font_size"')
    img = mmcv.imread(img).copy()

    if score_thr > 0:
        # Threshold on the square root of the raw score, then keep only
        # the first surviving detection of each group.
        keep = [np.sqrt(group[:, -1]) > score_thr for group in bboxes]
        bboxes = [group[mask, :][0:1] for group, mask in zip(bboxes, keep)]
        labels = [group[mask][0:1] for group, mask in zip(labels, keep)]

    bbox_color = _broadcast_group_colors(bbox_color, len(bboxes))
    text_color = _broadcast_group_colors(text_color, len(bboxes))

    img = mmcv.bgr2rgb(img)
    img = np.ascontiguousarray(img)

    # Separate canvas for the OpenCV drawing that is returned to the caller.
    vis_img = img.copy()

    plt.figure(win_name, figsize=fig_size)
    plt.title(win_name)
    plt.axis('off')
    ax = plt.gca()

    polygons = []
    color = []

    if isinstance(gt_box, np.ndarray):
        gt_box = gt_box.astype(np.int32)
        poly = [[gt_box[0], gt_box[1]], [gt_box[0], gt_box[3]],
                [gt_box[2], gt_box[3]], [gt_box[2], gt_box[1]]]
        np_poly = np.array(poly).reshape((4, 2))
        polygons.append(Polygon(np_poly))
        color.append((0, 1.0, 0))
    else:
        # Groups are visited from last to first.
        for k in range(len(bboxes) - 1, -1, -1):
            for i, (bbox, label) in enumerate(zip(bboxes[k], labels[k])):
                bbox_int = bbox.astype(np.int32)
                poly = [[bbox_int[0], bbox_int[1]],
                        [bbox_int[0], bbox_int[3]],
                        [bbox_int[2], bbox_int[3]],
                        [bbox_int[2], bbox_int[1]]]
                np_poly = np.array(poly).reshape((4, 2))
                polygons.append(Polygon(np_poly))
                color.append(bbox_color[k])
                label_text = class_names[
                    label] if class_names is not None else f'Obj. {label}'
                if len(bbox) > 4:
                    # Shown score matches the sqrt used for thresholding.
                    label_text += f'|{np.sqrt(bbox[-1]):.02f}'
                ax.text(
                    bbox_int[0],
                    bbox_int[1] - font_size,
                    f'{label_text}',
                    bbox={
                        'facecolor': 'black',
                        'alpha': 0.6,
                        'pad': 0.7,
                        'edgecolor': 'none'
                    },
                    color=text_color[k],
                    fontsize=font_size,
                    verticalalignment='top',
                    horizontalalignment='left')
                if i == 0:
                    vis_img = cv2.rectangle(
                        vis_img, (bbox_int[0], bbox_int[1]),
                        (bbox_int[2], bbox_int[3]), (36, 255, 12),
                        thickness)
                    vis_img = cv2.putText(
                        vis_img, f'{label_text}',
                        (bbox_int[0], bbox_int[1]),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (36, 255, 12),
                        thickness)

    plt.imshow(img)

    p = PatchCollection(
        polygons, facecolor='none', edgecolors=color, linewidths=thickness)
    ax.add_collection(p)

    if out_file is not None:
        dir_name = osp.abspath(osp.dirname(out_file))
        mmcv.mkdir_or_exist(dir_name)
        plt.savefig(out_file)
        if not show:
            plt.close()
    if show:
        if wait_time == 0:
            plt.show()
        else:
            plt.show(block=False)
            plt.pause(wait_time)
            plt.close()
    return mmcv.rgb2bgr(vis_img)

def imshow_gt_det_bboxes(img,
                         annotation,
                         result,
                         class_names=None,
                         score_thr=0,
                         gt_bbox_color=(255, 102, 61),
                         gt_text_color=(255, 102, 61),
                         gt_mask_color=(255, 102, 61),
                         det_bbox_color=(72, 101, 241),
                         det_text_color=(72, 101, 241),
                         det_mask_color=(72, 101, 241),
                         thickness=2,
                         font_size=13,
                         win_name='',
                         fig_size=(15, 10),
                         show=True,
                         wait_time=0,
                         out_file=None):
    """Visualize ground truth and detection results on one image.

    ``imshow_det_bboxes`` is called twice: first with the ground truth
    (never shown, never saved), then with the detections drawn on the
    image returned by the first call.

    Args:
      img (str or ndarray): The image to be displayed.
      annotation (dict): Ground truth annotations; must contain the keys
          'gt_bboxes' and 'gt_labels', and may contain 'gt_masks'.
      result (tuple[list] or list): The detection result, can be either
          (bbox, segm) or just bbox.
      class_names (list[str]): Names of each classes.
      score_thr (float): Score threshold forwarded to the detection pass
          only.  Default: 0
      gt_bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
           The tuple of color should be in BGR order. Default: (255, 102, 61)
      gt_text_color (str or tuple(int) or :obj:`Color`):Color of texts.
           The tuple of color should be in BGR order. Default: (255, 102, 61)
      gt_mask_color (str or tuple(int) or :obj:`Color`, optional):
           Color of masks. The tuple of color should be in BGR order.
           Default: (255, 102, 61)
      det_bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
           The tuple of color should be in BGR order. Default: (72, 101, 241)
      det_text_color (str or tuple(int) or :obj:`Color`):Color of texts.
           The tuple of color should be in BGR order. Default: (72, 101, 241)
      det_mask_color (str or tuple(int) or :obj:`Color`, optional):
           Color of masks. The tuple of color should be in BGR order.
           Default: (72, 101, 241)
      thickness (int): Thickness of lines. Default: 2
      font_size (int): Font size of texts. Default: 13
      win_name (str): The window name. Default: ''
      fig_size (tuple): Figure size of the pyplot figure. Default: (15, 10)
      show (bool): Whether to show the image. Default: True
      wait_time (float): Forwarded to the detection pass. Default: 0.
      out_file (str, optional): The filename to write the image.
         Default: None

    Returns:
        ndarray: The image returned by the detection pass of
            ``imshow_det_bboxes``.
    """
    assert 'gt_bboxes' in annotation
    assert 'gt_labels' in annotation
    assert isinstance(
        result,
        (tuple, list)), f'Expected tuple or list, but get {type(result)}'

    gt_masks = annotation.get('gt_masks', None)
    if gt_masks is not None:
        gt_masks = mask2ndarray(gt_masks)

    img = mmcv.imread(img)

    # Options that are identical for the ground-truth and detection passes.
    common = dict(
        class_names=class_names,
        thickness=thickness,
        font_size=font_size,
        win_name=win_name,
        fig_size=fig_size)

    img = imshow_det_bboxes(
        img,
        annotation['gt_bboxes'],
        annotation['gt_labels'],
        gt_masks,
        bbox_color=gt_bbox_color,
        text_color=gt_text_color,
        mask_color=gt_mask_color,
        show=False,
        **common)

    bbox_result, segm_result = (
        result if isinstance(result, tuple) else (result, None))
    if isinstance(segm_result, tuple):
        segm_result = segm_result[0]  # ms rcnn

    # bbox_result holds one array per class; the position is the label.
    bboxes = np.vstack(bbox_result)
    labels = np.concatenate([
        np.full(cls_boxes.shape[0], cls_idx, dtype=np.int32)
        for cls_idx, cls_boxes in enumerate(bbox_result)
    ])

    segms = None
    if segm_result is not None and len(labels) > 0:  # non empty
        encoded = mmcv.concat_list(segm_result)
        segms = mask_util.decode(encoded).transpose(2, 0, 1)

    return imshow_det_bboxes(
        img,
        bboxes,
        labels,
        segms=segms,
        score_thr=score_thr,
        bbox_color=det_bbox_color,
        text_color=det_text_color,
        mask_color=det_mask_color,
        show=show,
        wait_time=wait_time,
        out_file=out_file,
        **common)
