from typing import Dict, List
import torch
import colorsys
import random
import numpy as np
from skimage.draw import line_aa, circle_perimeter_aa
import cv2
from .util import select_data


def _gen_random_colors(N, bright=True):
    """Return ``N`` fully saturated RGB tuples with evenly spaced hues, shuffled.

    The HSV value (brightness) is 1.0 when ``bright`` is true and 0.7
    otherwise; components are floats in [0, 1]. The order comes from
    ``random.shuffle`` and therefore depends on the global ``random`` state.
    """
    value = 1.0 if bright else 0.7
    palette = [colorsys.hsv_to_rgb(k / N, 1, value) for k in range(N)]
    random.shuffle(palette)
    return palette


# RGB colors (float components in [0, 1]) indexed by label id. Entry 0 (white)
# is the background; entries 1-15 line up with the names at the same index in
# _names_in_static_label_colors. The 256 hue-spaced colors appended from
# _gen_random_colors are shuffled with the global `random` state at import
# time, so the colors for ids >= 16 (including 'hat' and 'earr' below) can
# differ between runs unless `random` is seeded before this module is imported.
_static_label_colors = [
    np.array((1.0, 1.0, 1.0), np.float32),  # background
    np.array((255, 250, 79), np.float32) / 255.0,  # face
    np.array([255, 125, 138], np.float32) / 255.0,  # lb
    np.array([213, 32, 29], np.float32) / 255.0,  # rb
    np.array([0, 144, 187], np.float32) / 255.0,  # le
    np.array([0, 196, 253], np.float32) / 255.0,  # re
    np.array([255, 129, 54], np.float32) / 255.0,  # nose
    np.array([88, 233, 135], np.float32) / 255.0,  # ulip
    np.array([0, 117, 27], np.float32) / 255.0,  # llip
    np.array([255, 76, 249], np.float32) / 255.0,  # imouth
    np.array((1.0, 0.0, 0.0), np.float32),  # hair
    np.array((255, 250, 100), np.float32) / 255.0,  # lr
    np.array((255, 250, 100), np.float32) / 255.0,  # rr
    np.array((250, 245, 50), np.float32) / 255.0,  # neck
    np.array((0.0, 1.0, 0.5), np.float32),  # cloth
    np.array((1.0, 0.0, 0.5), np.float32),  # eyeg
] + _gen_random_colors(256)

# Label names in the same order as _static_label_colors: the index of a name
# here is the index of its color there (used by _blend_labels to map label
# names to colors).
_names_in_static_label_colors = [
    'background', 'face', 'lb', 'rb', 'le', 're', 'nose',
    'ulip', 'llip', 'imouth', 'hair', 'lr', 'rr', 'neck',
    'cloth', 'eyeg', 'hat', 'earr'
]


def _blend_labels(image, labels, label_names_dict=None,
                  default_alpha=0.6, color_offset=None):
    """Overlay a color-coded label map on an image.

    Args:
        image: HxWx3 array, or None to render the label colors on black.
            The image is divided by its maximum before blending; an image
            whose maximum is not positive is treated as all black (instead
            of producing NaN from a division by zero).
        labels: HxW integer label map. Pixels labelled 0 keep the
            (normalized) image value unchanged.
        label_names_dict: None to color label ``i`` with
            ``_static_label_colors[i % len(_static_label_colors)]``.
            Otherwise a dict or list mapping label id -> name. Names found in
            ``_names_in_static_label_colors`` use the matching static color,
            other names are drawn white, and ids missing from a dict are
            treated like background.
        default_alpha: Weight of the label color when ``image`` is given
            (the image keeps weight ``1 - default_alpha``).
        color_offset: Optional value added to every color that is not
            all-zero.

    Returns:
        Float HxWx3 array clipped to [0, 1]. ``image`` itself is not modified.
    """
    assert labels.ndim == 2
    # An empty label map has nothing to draw (and .max() would raise).
    max_label = int(labels.max()) if labels.size else 0
    bg_mask = labels == 0
    if label_names_dict is None:
        colors = _static_label_colors
    else:
        colors = [np.array((1.0, 1.0, 1.0), np.float32)]
        for i in range(1, max_label + 1):
            if isinstance(label_names_dict, dict) and i not in label_names_dict:
                # Unnamed ids are shown as background and get a black color.
                bg_mask |= labels == i
                colors.append(np.zeros((3)))
                continue
            label_name = label_names_dict[i]
            if label_name in _names_in_static_label_colors:
                color = _static_label_colors[
                    _names_in_static_label_colors.index(label_name)]
            else:
                color = np.array((1.0, 1.0, 1.0), np.float32)
            colors.append(color)

    if color_offset is not None:
        shifted = []
        for c in colors:
            nc = np.array(c)
            # Leave the all-zero placeholder colors of unnamed ids untouched.
            if nc.any():
                nc += color_offset
            shifted.append(nc)
        colors = shifted

    if image is None:
        # Separate arrays so drawing into `image` never alters `orig_image`.
        orig_image = np.zeros(
            [labels.shape[0], labels.shape[1], 3], np.float32)
        image = orig_image.copy()
        alpha = 1.0
    else:
        peak = np.max(image)
        if peak > 0:
            orig_image = image / peak
        else:
            # An all-black image would divide by zero and yield NaN.
            orig_image = np.zeros(image.shape, np.float32)
        image = orig_image * (1.0 - default_alpha)
        alpha = default_alpha

    if max_label > 0:
        # Row i holds the color of label i; row 0 stays black so background
        # (and any negative id) receives no color, as in a per-label loop
        # starting at 1.
        palette = np.zeros((max_label + 1, 3), np.float32)
        for i in range(1, max_label + 1):
            palette[i] = colors[i % len(colors)]
        image += alpha * palette[np.where(labels > 0, labels, 0)]

    np.clip(image, 0.0, 1.0, out=image)
    image[bg_mask] = orig_image[bg_mask]
    return image


def _paint_white_aa(image, rr, cc, val):
    """Blend white (255) into pixels ``(rr, cc)`` of a 3-channel ``image``.

    ``val`` holds one blend weight per pixel (the coverage returned by the
    anti-aliased ``skimage.draw`` routines); it is repeated over 3 channels.
    """
    weights = val[:, None][:, [0, 0, 0]]
    image[rr, cc] = image[rr, cc] * (1.0 - weights) + weights * 255


def _draw_hwc(image: torch.Tensor, data: Dict[str, torch.Tensor]):
    """Draw rectangles/scores, points and segmentation overlays on one image.

    Args:
        image: HxWxC tensor. Drawing assumes a 0-255 intensity range: lines
            and text are drawn with value 255 and segmentation is blended
            after dividing by 255.
        data: Mapping from annotation tag to content. Handled tags:
            ``'rects'`` -- iterable of ``x1, y1, x2, y2`` boxes (clamped to
            the image); when ``'scores'`` is also present, ``scores[i]`` is
            printed at the bottom-left corner of box ``i``.
            ``'points'`` -- iterable of ``npoints x 2`` arrays of ``(x, y)``.
            ``'seg'`` -- dict with ``'label_names'`` and ``'logits'`` (an
            iterable of ``nclasses x h x w`` tensors).
            Any other tag is ignored.

    Returns:
        A new tensor with ``image``'s dtype, on ``image.device``; the input
        tensor is not modified.
    """
    device = image.device
    # detach() lets tensors that require grad be converted to numpy, and
    # order='C' gives cv2.putText a contiguous buffer to draw into (the input
    # is usually a permuted, non-contiguous view of a CHW tensor).
    image = np.array(image.detach().cpu().numpy(), copy=True, order='C')
    dtype = image.dtype
    h, w, _ = image.shape
    has_scores = 'scores' in data
    score_error_reported = False

    for tag, content in data.items():
        if tag == 'rects':
            for cid, rect in enumerate(content):
                x1, y1, x2, y2 = [int(v) for v in rect]
                y1, y2 = [max(min(v, h - 1), 0) for v in (y1, y2)]
                x1, x2 = [max(min(v, w - 1), 0) for v in (x1, x2)]
                # Top, bottom, left and right edges as (x_a, y_a, x_b, y_b).
                edges = ((x1, y1, x2, y1), (x1, y2, x2, y2),
                         (x1, y1, x1, y2), (x2, y1, x2, y2))
                for xa, ya, xb, yb in edges:
                    _paint_white_aa(image, *line_aa(ya, xa, yb, xb))
                if not has_scores:
                    continue
                try:
                    score = data['scores'][cid].item()
                    cv2.putText(image, f'{score:0.3f}', org=(x1, y2),
                                fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                                fontScale=0.6, color=(255, 255, 255),
                                thickness=1)
                except Exception as e:
                    # Score labels are best effort: report only the first
                    # failure and keep drawing the remaining annotations.
                    if not score_error_reported:
                        print('Failed to draw scores on image.')
                        print(e)
                    score_error_reported = True

        elif tag == 'points':
            for pts in content:
                # pts: npoints x 2 array of (x, y)
                for x, y in pts:
                    cx = max(min(int(x), w - 1), 0)
                    cy = max(min(int(y), h - 1), 0)
                    rr, cc, val = circle_perimeter_aa(cy, cx, 1)
                    # The radius-1 circle can spill past the border.
                    inside = (rr >= 0) & (rr < h) & (cc >= 0) & (cc < w)
                    _paint_white_aa(image, rr[inside], cc[inside], val[inside])

        elif tag == 'seg':
            label_names = content['label_names']
            for seg_logits in content['logits']:
                # seg_logits: nclasses x h x w. softmax is monotonic, so the
                # argmax of the raw logits selects the same class (up to
                # floating-point ties) without computing probabilities.
                seg_labels = seg_logits.argmax(dim=0).cpu().numpy()
                blended = _blend_labels(image.astype(np.float32) / 255,
                                        seg_labels,
                                        label_names_dict=label_names)
                image = (blended * 255).astype(dtype)

    return torch.from_numpy(image).to(device=device)


def draw_bchw(images: torch.Tensor, data: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Draw per-image annotations onto a batch of CHW images.

    Args:
        images: BxCxHxW tensor.
        data: Annotations in the format drawn by ``_draw_hwc``, plus an
            ``'image_ids'`` entry. For image ``i`` the boolean mask
            ``i == data['image_ids']`` is passed to ``select_data``
            (presumably keeping only that image's annotations -- confirm
            against ``select_data``).

    Returns:
        BxCxHxW tensor with the annotations drawn. An empty batch is returned
        as a clone instead of failing in ``torch.stack`` on an empty list.
    """
    if len(images) == 0:
        return images.clone()
    drawn = []
    for image_id, image_chw in enumerate(images):
        selected_data = select_data(image_id == data['image_ids'], data)
        drawn.append(
            _draw_hwc(image_chw.permute(1, 2, 0), selected_data).permute(2, 0, 1))
    return torch.stack(drawn, dim=0)

def draw_landmarks(img, bbox=None, landmark=None, color=(0, 255, 0)):
    """Return an annotated version of ``img`` with a bounding box and landmarks.

    Input:
    - img: gray or RGB image. Drawing is done on the array returned by
      ``cv2.UMat(img).get()`` rather than on ``img`` itself.
    - bbox: optional sequence whose first four values are x1, y1, x2, y2;
      drawn as a 2px rectangle with color (0, 0, 255).
    - landmark: optional array-like of (x, y) rows, e.g. shape (5, 2); each
      point is drawn as a filled circle of radius 2.
    - color: color of the landmark circles.
    Output:
    - the annotated image array
    """
    canvas = cv2.UMat(img).get()
    if bbox is not None:
        left, top, right, bottom = np.array(bbox)[:4].astype(np.int32)
        cv2.rectangle(canvas, (left, top), (right, bottom), (0, 0, 255), 2)
    if landmark is None:
        return canvas
    for px, py in np.array(landmark).astype(np.int32):
        cv2.circle(canvas, (int(px), int(py)), 2, color, -1)
    return canvas

def draw_landmarks_only_eyes(img, bbox=None, landmark=None, color=(0, 255, 0)):
    """Draw a bounding box and the eye landmarks directly onto ``img``.

    Unlike ``draw_landmarks`` this does not copy the image: ``img`` is drawn
    on in place and returned.

    Input:
    - img: gray or RGB image array (modified in place)
    - bbox: optional sequence whose first four values are x1, y1, x2, y2;
      drawn as a 2px rectangle with color (0, 0, 255)
    - landmark: optional array-like of (x, y) rows; only rows 60-75 and
      96-97 are drawn (presumably the eye contours and pupils of a 98-point
      layout -- confirm against the landmark model)
    - color: color of the landmark dots
    Output:
    - img marked with the eye landmarks and bbox
    """
    if bbox is not None:
        x1, y1, x2, y2 = np.array(bbox)[:4].astype(np.int32)
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
    # Select the eye rows only after the None check, so that the default
    # landmark=None no longer raises a TypeError.
    if landmark is not None:
        points = np.asarray(landmark)
        eye_points = np.vstack((points[60:76, :], points[96:98, :]))
        for x, y in eye_points.astype(np.int32):
            cv2.circle(img, (int(x), int(y)), 2, color, -1)
    return img