import colorsys

import cv2
import numpy as np
import torch as pt
import torchvision.utils as ptvu
import torch.nn.functional as ptnf


def normaliz_for_visualiz(image: np.ndarray):
    """Min-max rescale ``image`` into [0, 1] for display.

    Non-floating inputs (integers, booleans) are promoted to float64 before
    any arithmetic, so the subtraction cannot wrap around (e.g. int8
    ``127 - (-128)``); floating inputs keep their dtype. A constant image has
    zero range and is returned as all zeros instead of the NaNs a 0/0
    division would produce.
    """
    image = np.asarray(image)
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)
    low = image.min()
    span = image.max() - low
    if span == 0:
        span = 1  # constant image: (image - low) is already all zeros
    return (image - low) / span


def even_resize_and_center_crop(image: np.ndarray, size: int, interp=cv2.INTER_LINEAR):
    """
    Uniformly rescale ``image`` by ``size / min(h, w)`` and cut out the centred
    ``size x size`` window.

    image: array whose first two axes are (h, w).
    size: side length of the square output.
    interp: OpenCV interpolation flag forwarded to ``cv2.resize``.

    The crop is a plain slice (no padding), so the result is ``size x size``
    whenever both rescaled sides are at least ``size``.
    """
    height, width = image.shape[:2]
    scale = size / min(height, width)
    resized = cv2.resize(image, dsize=None, fx=scale, fy=scale, interpolation=interp)
    out_h, out_w = resized.shape[:2]
    top = (out_h - size) // 2
    left = (out_w - size) // 2
    return resized[top : top + size, left : left + size]


def calc_foreground_center_bbox(segment_index, haxis=-2, waxis=-1):
    """
    Find a square crop window of side ``min(h, w)`` centred on the foreground.

    segment_index: array in shape (..,h,w); entries > 0 are foreground. Any
        other axes are pooled: a pixel counts as foreground if it is > 0 at
        any index of those axes.
    haxis, waxis: axes holding height and width (default: the last two).

    Returns an int32 array ``[l, t, r, b]`` with ``r - l == b - t == min(h, w)``,
    lying inside the image. Only the longer side is cropped. Along it, the
    window is centred on the midpoint between the first and last foreground
    positions (or on the image centre when there is no foreground), then
    shifted back inside the image if it would stick out.
    """
    foreground = np.asarray(segment_index) > 0
    ndim = foreground.ndim
    haxis, waxis = haxis % ndim, waxis % ndim
    h, w = foreground.shape[haxis], foreground.shape[waxis]
    side = min(h, w)

    def _extent_center(axis, extent):
        # Project the foreground onto ``axis`` and take the midpoint of the
        # first and last occupied positions.
        others = tuple(a for a in range(ndim) if a != axis)
        occupied = np.flatnonzero(foreground.any(axis=others))
        if occupied.size == 0:
            return (extent - 1) / 2
        return (occupied[0] + occupied[-1]) / 2

    def _window(center, extent):
        # Round only the start and derive the end from it: rounding both ends
        # independently (round-half-to-even) can make the window side +/- 1
        # wide when ``side`` is odd.
        start = int(np.round(center - side / 2))
        start = min(max(start, 0), extent - side)
        return start, start + side

    l, r = (0, w) if w == side else _window(_extent_center(waxis, w), w)
    t, b = (0, h) if h == side else _window(_extent_center(haxis, h), h)
    return np.array([l, t, r, b], dtype="int32")


def index_segment_to_bbox(segment_index: np.ndarray, num=None, max_num=255):
    """
    Compute per-label bounding boxes from an index segmentation.

    segment_index: uint8 array in shape (h,w); 0 is the background and an
        object is labelled by its index ``k >= 1``.
    num: number of objects + the background, i.e. the number of bbox rows.
        Defaults to ``segment_index.max() + 1`` so every label present gets
        a row even if labels are not contiguous.
    max_num: upper bound asserted on ``num``.

    Returns a float32 array in shape (num,4) of ``[l, t, r, b]`` in inclusive
    pixel coordinates. Row ``i`` holds the box of label ``i + 1``; rows of
    labels that do not occur are all zeros, and the last row is always zeros.
    """
    assert segment_index.ndim == 2 and segment_index.dtype == np.uint8
    if num is None:
        # max + 1 (not a count of distinct labels) keeps row i <-> label i+1
        # valid when some labels are missing.
        num = int(segment_index.max()) + 1 if segment_index.size else 1
    assert num <= max_num
    bbox = np.zeros([num, 4], dtype="float32")
    # Label 0 is the background and gets no box.
    for i, idx in enumerate(range(1, num)):
        y, x = np.nonzero(segment_index == idx)
        if len(y) > 0:
            bbox[i] = [x.min(), y.min(), x.max(), y.max()]
    return bbox  # ltrb


def color_segment_to_index_segment_and_bbox(
    segment_color: np.ndarray, num=None, max_num=255
):
    """
    Convert a colour-coded segmentation into an index segmentation plus boxes.

    segment_color: uint8 array in shape (h,w,c=3). Pure black (0,0,0) is the
        background; every other distinct colour is one object.
    num: number of objects + the background, i.e. the number of bbox rows.
        Defaults to (number of non-black colours) + 1.
    max_num: upper bound on objects + background.

    Returns ``(segment_index, bbox)``:
        segment_index: uint8 array in shape (h,w); 0 for background, and
            objects labelled 1..K in ascending order of their packed colour
            value ``r*256**2 + g*256 + b``.
        bbox: float32 array in shape (num,4) of ``[l, t, r, b]`` in inclusive
            pixel coordinates; row ``i`` is the box of label ``i + 1`` and
            unused rows are zeros.
    Returns ``[None, None]`` when there are too many objects: objects +
    background exceeds ``max_num``, or there are more than 255 objects,
    which a uint8 label map cannot hold.

    Raises ValueError if an explicit ``num`` is too small for the objects found.
    """
    assert segment_color.ndim == 3 and segment_color.dtype == np.uint8  # (h,w,c=3)
    # Pack each RGB triple into one int64 key; black packs to 0 (background).
    weights = np.array([256**2, 256**1, 256**0], dtype="int64")
    segment = (segment_color.astype("int64") * weights).sum(2)
    # Remove the background by value: popping the first unique value would
    # discard a real object whenever the image contains no black pixel.
    colors = [c for c in np.unique(segment).tolist() if c != 0]
    if len(colors) + 1 > max_num or len(colors) > np.iinfo(np.uint8).max:
        return [None] * 2
    if num is None:
        num = len(colors) + 1
    assert num <= max_num
    if len(colors) + 1 > num:
        raise ValueError(
            f"num={num} is too small for {len(colors)} objects + the background"
        )
    segment_index = np.zeros_like(segment, dtype="uint8")
    bbox = np.zeros([num, 4], dtype="float32")
    for i, color in enumerate(colors):
        flag = segment == color
        segment_index[flag] = i + 1
        y, x = np.nonzero(flag)
        bbox[i] = [x.min(), y.min(), x.max(), y.max()]
    return segment_index, bbox  # ltrb


def generate_spectrum_colors(num_color):
    """
    Build ``num_color`` fully saturated, full-value colours with evenly spaced
    hues (hue ``k / num_color`` for k = 0..num_color-1).

    Returns a uint8 array in shape (n,c=3) of RGB values; each channel is
    ``int(255 * channel)``, i.e. truncated rather than rounded.
    """
    palette = [
        [int(255 * channel) for channel in colorsys.hsv_to_rgb(k / float(num_color), 1.0, 1.0)]
        for k in range(num_color)
    ]
    return np.array(palette, dtype="uint8")  # (n,c=3)


def draw_segmentation_np(image: np.ndarray, segment: np.ndarray, max_num=0, alpha=0.5):
    """
    Blend a colour-per-label overlay of ``segment`` onto ``image``.

    image: in shape (h,w,c); passed (as CHW) to torchvision's
        ``draw_segmentation_masks``, so it must satisfy that function's dtype
        requirements (presumably uint8 RGB — confirm against torchvision).
    segment: integer labels in shape (h,w).
    max_num: number of label slots; a falsy value means ``segment.max() + 1``.
        Every slot, including label 0, gets its own spectrum colour and is drawn.
    alpha: overlay opacity forwarded to ``draw_segmentation_masks``.

    Returns the blended image in shape (h,w,c) as a numpy array.
    """
    num_classes = max_num if max_num else int(segment.max() + 1)
    palette = generate_spectrum_colors(num_classes).tolist()
    labels = pt.from_numpy(segment.astype("int64"))
    # (h,w,num_classes) one-hot -> (num_classes,h,w) boolean masks.
    masks_chw = ptnf.one_hot(labels, num_classes).bool().permute(2, 0, 1)
    image_chw = pt.from_numpy(image).permute(2, 0, 1)
    blended = ptvu.draw_segmentation_masks(
        image=image_chw,
        masks=masks_chw,
        alpha=alpha,
        colors=palette,
    )
    return blended.permute(1, 2, 0).numpy()
