import numpy as np
import cv2
import random
# from rdp import rdp
# from interval import Interval, IntervalSet
import time
from torchvision.transforms import transforms
import torch

# Patch preprocessing pipeline: convert the input to a tensor, then normalize
# it with mean 0.5 and std 0.5.
patch_trans = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)


def canvas_size_google(sketch):
    """
    Work out the canvas extent and the pen start point of a QuickDraw sketch.

    The first row is treated as an offset of unknown size: it is left out of
    the running sum that determines the extent and is folded into the start
    point instead.

    :param sketch: google sketch, quickDraw; rows with three columns
    :return: int list, [start_x, start_y, h, w]
    """
    # Running sum over every row except the first one.
    running = np.cumsum(sketch[1:], axis=0)
    left, top, _ = np.min(running, axis=0)
    right, bottom, _ = np.max(running, axis=0)

    width = right - left
    height = bottom - top

    # Equivalent replacement for the first row: shift the start point by it.
    first_row = sketch[0]
    origin_x = -left - first_row[0]
    origin_y = -top - first_row[1]
    return [int(origin_x), int(origin_y), int(height), int(width)]


def draw_three(sketch, window_name="google", padding=30,
               random_color=False, time=1, show=False, img_size=512):
    """
    Render a QuickDraw sketch on a white canvas and return it as an image.

    The canvas extent and the pen start point come from canvas_size_google();
    the rows of the sketch are then drawn one line segment at a time.

    :param sketch: google quickDraw, (n, 3); columns are (dx, dy, state)
    :param window_name: title of the preview window used when ``show`` is True
    :param padding: currently unused; kept for backward compatibility
    :param random_color: pick a new random colour for every stroke instead of black
    :param time: currently unused; kept for backward compatibility
    :param show: display the canvas and wait for a key press; ESC closes all
        windows and terminates the program
    :param img_size: side length of the returned square image
    :return: the canvas resized to (img_size, img_size)
    """
    # cv2.line rejects a thickness of 0, which int() yields for img_size < 40.
    thickness = max(1, int(img_size * 0.025))
    margin = thickness + 1

    sketch = scale_sketch(sketch, (img_size, img_size))  # scale the sketch.
    start_x, start_y, h, w = canvas_size_google(sketch=sketch)
    side = max(h, w) + 3 * margin
    canvas = np.full((side, side, 3), 255, dtype='uint8')

    def next_color():
        if random_color:
            return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        return (0, 0, 0)

    color = next_color()
    pen_now = np.array([start_x + margin, start_y + margin])
    skip_next = False
    for stroke in sketch:
        delta_x_y = stroke[:2]
        if skip_next:  # the first row after a stroke end is a pen move, not a line
            pen_now += delta_x_y
            skip_next = False
            continue
        pen_next = pen_now + delta_x_y
        # Pass plain Python ints so the OpenCV bindings never see NumPy scalars.
        cv2.line(canvas,
                 (int(pen_now[0]), int(pen_now[1])),
                 (int(pen_next[0]), int(pen_next[1])),
                 color, thickness=thickness)
        # Index the scalar directly: int() on a one-element slice is deprecated
        # in recent NumPy releases.
        if int(stroke[2]) == 1:  # next stroke
            skip_next = True
            color = next_color()
        pen_now = pen_next

    if show:
        # Without imshow() there is no window, so nothing would be displayed.
        cv2.imshow(window_name, canvas)
        key = cv2.waitKeyEx()
        if key == 27:  # esc
            cv2.destroyAllWindows()
            raise SystemExit(0)
    return cv2.resize(canvas, (img_size, img_size))


def make_graph(sketch, graph_picture_size=256, random_color=False, channel_3=False):
    """
    Rasterise a QuickDraw sketch into a square image tensor.

    The sketch is scaled to a 512-pixel working size, drawn stroke by stroke,
    framed with a border as wide as the line thickness and finally resized to
    ``graph_picture_size``.

    :param sketch: google quickDraw, (n, 3); columns are (dx, dy, state)
    :param graph_picture_size: side length of the returned square image
    :param random_color: pick a new random colour for every stroke instead of white
    :param channel_3: draw on a three-channel white canvas; otherwise on a
        single-channel black canvas
    :return: torch tensor built from the resized uint8 canvas
    """
    tmp_img_size = 512
    thickness = int(tmp_img_size * 0.025)
    margin = thickness + 1

    # preprocess
    sketch = scale_sketch(sketch, (tmp_img_size, tmp_img_size))  # scale the sketch.
    start_x, start_y, h, w = canvas_size_google(sketch=sketch)
    side = max(h, w) + 2 * margin

    if channel_3:
        background = 255
        canvas = np.full((side, side, 3), background, dtype='uint8')
    else:
        background = 0
        canvas = np.zeros((side, side), dtype='uint8')

    def next_color():
        if random_color:
            return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        return (255, 255, 255)

    color = next_color()
    pen_now = np.array([start_x + margin, start_y + margin])
    skip_next = False

    # generate canvas.
    for stroke in sketch:
        delta_x_y = stroke[:2]
        if skip_next:  # the first row after a stroke end is a pen move, not a line
            pen_now += delta_x_y
            skip_next = False
            continue
        pen_next = pen_now + delta_x_y
        # Pass plain Python ints so the OpenCV bindings never see NumPy scalars.
        cv2.line(canvas,
                 (int(pen_now[0]), int(pen_now[1])),
                 (int(pen_next[0]), int(pen_next[1])),
                 color, thickness=thickness)
        # Index the scalar directly: int() on a one-element slice is deprecated
        # in recent NumPy releases.
        if int(stroke[2]) != 0:  # next stroke
            skip_next = True
            color = next_color()
        pen_now = pen_next

    # Frame the drawing with a background border `thickness` pixels wide on
    # every side. Padding directly avoids allocating a much larger border that
    # would only be cropped away again, and it also works for very small
    # graph_picture_size values.
    pad_spec = [(thickness, thickness), (thickness, thickness)]
    if channel_3:
        pad_spec.append((0, 0))  # never pad the channel axis
    framed = np.pad(canvas, pad_spec, mode='constant', constant_values=background)

    canvas_global = cv2.resize(framed, (graph_picture_size, graph_picture_size))
    return torch.tensor(canvas_global)

def remove_white_space_image(img_np: np.ndarray):
    """
    Crop a white-background image to the bounding box of its content.

    The background must be pure white, encoded either as 1.0 or as 255. An
    image whose maximum is <= 1.0 is rescaled to the 0-255 range first, and all
    further work is done on the uint8 image. To keep the check cheap the three
    channels are summed per pixel; a pixel counts as content when the sum is
    <= 760 (pure white sums to 765, which leaves some slack).

    :param img_np: three-channel image array of shape (h, w, 3)
    :return: uint8 crop containing every content pixel; the whole uint8 image
        when no content pixel is found
    """
    if np.max(img_np) <= 1.0:  # 1.0 <= 1.0 True
        img_np = (img_np * 255).astype("uint8")
    else:
        img_np = img_np.astype("uint8")

    channel_sum = np.sum(img_np, axis=2)
    rows, cols = np.where(channel_sum <= 760)
    if rows.size == 0:
        # Entirely white image: there is no bounding box to crop to.
        return img_np

    ymin, ymax = rows.min(), rows.max()
    xmin, xmax = cols.min(), cols.max()
    # ymax / xmax are inclusive indices, slice ends are exclusive: add one so
    # the last content row and column are kept.
    return img_np[ymin:ymax + 1, xmin:xmax + 1, :]


def remove_white_space_sketch(sketch):
    """
    Remove the blank margin of a sketch.

    The column-wise minimum of the first two columns is subtracted from those
    columns. The array is modified in place and returned.

    :param sketch: 2-D array whose first two columns are shifted
    :return: the same array object, after shifting
    """
    column_minima = np.min(sketch, axis=0)
    shifted_xy = np.subtract(sketch[:, :2], column_minima[:2])
    sketch[:, :2] = shifted_xy
    return sketch


def scale_sketch(sketch, size=(448, 448)):
    """
    Rescale a sketch so that its longer canvas side maps onto ``size``.

    The first two columns are divided by the longer side reported by
    canvas_size_google() and multiplied by ``size``; the third column is left
    unchanged.

    :param sketch: google quickDraw, (n, 3)
    :param size: target size; size[0] scales the first column and size[1] the second
    :return: rescaled sketch as an int16 array
    """
    _, _, h, w = canvas_size_google(sketch)
    # A sketch without any extent would otherwise divide by zero.
    longest = max(h, w) or 1
    # The `np.float` alias no longer exists in current NumPy; the builtin
    # float is the same dtype.
    divisor = np.array([[longest, longest, 1]], dtype=float)
    target = np.array([[size[0], size[1], 1]], dtype=float)
    sketch_rescale = sketch / divisor * target
    return sketch_rescale.astype("int16")


if __name__ == '__main__':
    # Render the 'test' split of every dataset archive to 64x64 JPEG files,
    # one output folder per category.
    import glob
    import os

    prob = 0.3
    for each_np_path in glob.glob("../dataset/*.npz"):
        # basename() copes with either path separator.
        catname = os.path.basename(each_np_path).split(".")[0]
        out_dir = f"/root/human-study/human/{prob}/{catname}"
        os.makedirs(out_dir, exist_ok=True)

        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # trusted dataset files.
        with np.load(each_np_path, encoding='latin1', allow_pickle=True) as dataset_np:
            npz_ = dataset_np['test']

        for index, sample in enumerate(npz_):
            # make_graph() takes no mask/save options and returns one tensor,
            # so the image is written out here.
            gra = make_graph(sample, graph_picture_size=64,
                             random_color=False, channel_3=False)
            cv2.imwrite(f"{out_dir}/{index}.jpg", gra.numpy())
            print(index)

