import os
import pickle

import cv2
import numpy as np


def create_fox_cat_dataset(size=1000, fox_ratio=0.5, image_size=(64, 64, 3), save_dir=None, is_train=True):
    """Generate a synthetic "fox vs. cat" dataset on black backgrounds.

    Every image starts black and receives a random filled rectangle as a
    distractor.  Images labelled 1 ("fox") then get a random filled triangle,
    images labelled 0 ("cat") a random filled circle.  The mask marks only the
    triangle/circle pixels, not the rectangle.

    Args:
        size: number of images to generate.
        fox_ratio: probability that an image is labelled 1 (fox).
        image_size: ``(height, width, channels)`` of each image.
        save_dir: when not None, the dataset is pickled there via
            ``save_dataset`` under the name ``"fox_cat_dark"``.
        is_train: selects the train/test file suffix when saving.

    Returns:
        images: uint8 array of shape ``(size, height, width, channels)``.
        is_fox: array of shape ``(size,)`` holding 0/1 labels.
        masks: bool array of shape ``(size, height * width * channels)``
            produced by ``flatten_mask``.
    """
    is_fox = np.random.binomial(1, fox_ratio, size=size)
    height, width, channels = image_size
    images = np.zeros((size, height, width, channels), np.uint8)
    # Preallocate so size == 0 still yields a 2-D bool array.
    masks = np.zeros((size, height * width * channels), bool)
    for i, label in enumerate(is_fox):
        image = images[i]  # a view: the draw_* helpers paint into `images` in place
        draw_rectangle(image)
        if label == 1:
            # Draw a fox
            shape_mask = draw_triangle(image)
        else:
            # Draw a cat
            shape_mask = draw_circle(image)
        # flatten_mask indexes pixels as row * input_w + col, so input_w must be
        # the number of columns (width), not the number of rows.
        masks[i] = flatten_mask(shape_mask, width, height, channels)

    if save_dir is not None:
        save_dataset(images, is_fox, masks, save_dir, "fox_cat_dark", is_train)

    return images, is_fox, masks


def create_triangle_circle_dataset(size=1000, triangle_ratio=0.5, image_size=(64, 64, 3), save_dir=None, is_train=True):
    """Generate a synthetic "triangle vs. circle" dataset on black backgrounds.

    Images labelled 1 contain one random filled triangle; images labelled 0
    contain one random filled circle.  The mask marks the shape's pixels.

    Args:
        size: number of images to generate.
        triangle_ratio: probability that an image is labelled 1 (triangle).
        image_size: ``(height, width, channels)`` of each image.
        save_dir: when not None, the dataset is pickled there via
            ``save_dataset`` under the name ``"triangle_circle_dark"``.
        is_train: selects the train/test file suffix when saving.

    Returns:
        images: uint8 array of shape ``(size, height, width, channels)``.
        is_triangle: array of shape ``(size,)`` holding 0/1 labels.
        masks: bool array of shape ``(size, height * width * channels)``
            produced by ``flatten_mask``.
    """
    is_triangle = np.random.binomial(1, triangle_ratio, size=size)
    height, width, channels = image_size
    images = np.zeros((size, height, width, channels), np.uint8)
    # Preallocate so size == 0 still yields a 2-D bool array.
    masks = np.zeros((size, height * width * channels), bool)
    for i, label in enumerate(is_triangle):
        image = images[i]  # a view: the draw_* helpers paint into `images` in place
        if label == 1:
            # Draw a triangle
            shape_mask = draw_triangle(image)
        else:
            # Draw a circle
            shape_mask = draw_circle(image)
        # flatten_mask indexes pixels as row * input_w + col, so input_w must be
        # the number of columns (width), not the number of rows.
        masks[i] = flatten_mask(shape_mask, width, height, channels)

    if save_dir is not None:
        save_dataset(images, is_triangle, masks, save_dir, "triangle_circle_dark", is_train)

    return images, is_triangle, masks


def save_dataset(data, labels, masks, directory, name, is_train=True):
    """Pickle a dataset split to ``<directory>/<name>_<train|test>.pickle``.

    The pickled object is a dict with keys ``'data'``, ``'labels'`` and
    ``'masks'``.  ``directory`` is created (with parents) if it does not exist,
    so a long dataset generation is not lost to a missing output folder.  An
    empty ``directory`` writes into the current working directory.

    Args:
        data, labels, masks: objects stored under the matching dict keys.
        directory: output directory.
        name: file-name prefix.
        is_train: ``True`` -> ``_train`` suffix, ``False`` -> ``_test`` suffix.

    Returns:
        The path of the written pickle file.
    """
    mode = "train" if is_train else "test"
    if directory:
        os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, "{}_{}.pickle".format(name, mode))
    payload = {'data': data, 'labels': labels, 'masks': masks}
    with open(path, 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return path


def flatten_mask(mask, input_w=64, input_h=64, input_c=3):
    """Convert a list of ``[row, col]`` pixels into a flat boolean mask.

    Each pixel is marked in every one of ``input_c`` channel planes at index
    ``c * input_w * input_h + row * input_w + col``.  That is the layout of a
    channel-first ``(input_c, input_h, input_w)`` array flattened in C order.
    ``input_w`` must therefore be the number of columns (image width) and
    ``input_h`` the number of rows.

    Args:
        mask: sequence of ``[row, col]`` pairs (list of lists or an ``(n, 2)``
            integer array); may be empty.
        input_w: image width (number of columns).
        input_h: image height (number of rows).
        input_c: number of channels.

    Returns:
        1-D bool array of length ``input_w * input_h * input_c``.
    """
    # Vectorized: this runs once per generated image, so avoid a per-pixel,
    # per-channel Python loop.
    points = np.asarray(mask, dtype=np.intp).reshape(-1, 2)
    pixel_offsets = points[:, 0] * input_w + points[:, 1]
    plane_offsets = np.arange(input_c) * (input_w * input_h)
    # Broadcast (input_c, 1) + (1, n): every pixel in every channel plane.
    flat_indices = (plane_offsets[:, None] + pixel_offsets[None, :]).ravel()
    new_mask = np.zeros(input_w * input_h * input_c, dtype=bool)
    new_mask[flat_indices] = True
    return new_mask


def draw_triangle(image, color=None):
    """Draw a random filled triangle onto ``image`` in place and return its pixels.

    Points are OpenCV ``(x, y)`` tuples.  ``coord1`` and ``coord2`` share the
    same y coordinate and are more than 10 px apart along x.  The apex
    ``coord3`` sits at their (integer) x midpoint and more than 10 px away from
    ``coord1`` along y.  The first and second coordinates are drawn from
    ``image.shape[0]`` and ``image.shape[1]`` respectively.
    NOTE(review): this mixes rows/columns and looks like it assumes a square
    image; confirm before using non-square sizes.

    Args:
        image: array to draw on; modified in place.
        color: fill colour; a random 3-element list of ints in [0, 255] when None.

    Returns:
        List of ``[row, col]`` pairs, in row-major order, for every pixel the
        filled triangle covers.

    Raises:
        ValueError: if the image is too small for a vertex to be placed more
            than 10 px from ``coord1`` (the rejection sampling would never end).
    """
    if color is None:
        color = np.random.randint(256, size=3).tolist()
    # Random draws happen in the same order as before, so seeded runs reproduce
    # the same triangles.
    coord1 = (np.random.randint(image.shape[0]), np.random.randint(image.shape[1]))
    if max(coord1[0], image.shape[0] - 1 - coord1[0]) <= 10:
        raise ValueError("image too small: no base vertex can be placed more than 10 px from coord1")
    coord2 = (np.random.randint(image.shape[0]), coord1[1])
    while np.abs(coord2[0] - coord1[0]) <= 10:
        # Make sure the length is not too short
        coord2 = (np.random.randint(image.shape[0]), coord1[1])
    if max(coord1[1], image.shape[1] - 1 - coord1[1]) <= 10:
        raise ValueError("image too small: no apex can be placed more than 10 px from coord1")
    apex_x = int((coord1[0] + coord2[0]) / 2)
    coord3 = (apex_x, np.random.randint(image.shape[1]))
    while np.abs(coord3[1] - coord1[1]) <= 10:
        # Make sure the height is not too short
        coord3 = (apex_x, np.random.randint(image.shape[1]))
    vertices = np.array([coord1, coord2, coord3])

    # Rasterize onto a separate zero canvas with value 1 to recover the mask.
    # Comparing against the fill colour on a white canvas would mark the whole
    # image whenever the random colour happens to be (255, 255, 255).
    canvas = np.zeros(image.shape[:2], np.uint8)
    cv2.drawContours(canvas, [vertices], 0, 1, -1)
    mask = np.argwhere(canvas).tolist()

    # Draw on the target image
    cv2.drawContours(image, [vertices], 0, color, -1)

    return mask


def draw_circle(image, color=None):
    """Draw a random filled circle onto ``image`` in place and return its pixels.

    The centre is an OpenCV ``(x, y)`` tuple whose coordinates are drawn from
    ``[10, image.shape[0] - 10)`` and ``[10, image.shape[1] - 10)``.  The
    radius is the smallest distance from the centre to an image border, as
    computed below, so the disc touches its nearest border.
    NOTE(review): ``center[0]`` (x) is compared with ``image.shape[0]`` (rows),
    which looks like a square-image assumption; confirm before using
    non-square sizes.

    Args:
        image: array to draw on; modified in place.
        color: fill colour; a random 3-element list of ints in [0, 255] when None.

    Returns:
        List of ``[row, col]`` pairs, in row-major order, for every pixel the
        filled circle covers.
    """
    if color is None:
        color = np.random.randint(256, size=3).tolist()
    center = (np.random.randint(10, image.shape[0] - 10), np.random.randint(10, image.shape[1] - 10))
    # NOTE(review): this random radius was never used by the original code (the
    # circle always uses the largest radius that fits).  The draw is kept only
    # so the global random stream, and hence seeded datasets, stay unchanged.
    np.random.randint(np.min(center))
    radius = min(center[0], image.shape[0] - center[0], center[1], image.shape[1] - center[1])

    # Rasterize onto a separate zero canvas with value 1 to recover the mask.
    # Comparing against the fill colour on a white canvas would mark the whole
    # image whenever the random colour happens to be (255, 255, 255).
    canvas = np.zeros(image.shape[:2], np.uint8)
    cv2.circle(canvas, center, radius, 1, -1)
    mask = np.argwhere(canvas).tolist()

    # Draw on the target image
    cv2.circle(image, center, radius, color, -1)

    return mask


def draw_rectangle(image, color=None):
    """Paint a random filled rectangle onto ``image`` in place.

    Two opposite corners are sampled as OpenCV ``(x, y)`` points whose
    coordinates come from ``range(image.shape[0])`` and ``range(image.shape[1])``.
    The second corner is resampled until both side lengths exceed 5 px.

    Args:
        image: array to draw on; modified in place.
        color: fill colour; a random 3-element list of ints in [0, 255] when None.
    """
    fill = np.random.randint(256, size=3).tolist() if color is None else color

    first_corner = (np.random.randint(image.shape[0]), np.random.randint(image.shape[1]))
    while True:
        second_corner = (np.random.randint(image.shape[0]), np.random.randint(image.shape[1]))
        wide_enough = abs(first_corner[0] - second_corner[0]) > 5
        tall_enough = abs(first_corner[1] - second_corner[1]) > 5
        if wide_enough and tall_enough:
            break

    cv2.rectangle(image, first_corner, second_corner, fill, -1)


def create_two_lines_dataset(size=1000, intersect_ratio=0.5, image_size=(64, 64, 3), save_dir=None, is_train=True):
    """Generate a synthetic "do the two lines intersect?" dataset.

    Each black image gets two random 1-px line segments.  For label 1 the
    segments share at least one pixel; for label 0 they share none.  The mask
    marks the shared pixels, so it is all-False for label 0.

    Args:
        size: number of images to generate.
        intersect_ratio: probability that an image is labelled 1 (intersecting).
        image_size: ``(height, width, channels)`` of each image.
        save_dir: when not None, the dataset is pickled there via
            ``save_dataset`` under the name ``"two_lines"``.
        is_train: selects the train/test file suffix when saving.

    Returns:
        images: uint8 array of shape ``(size, height, width, channels)``.
        lines_are_intersect: array of shape ``(size,)`` holding 0/1 labels.
        masks: bool array of shape ``(size, height * width * channels)``
            produced by ``flatten_mask``.
    """
    lines_are_intersect = np.random.binomial(1, intersect_ratio, size=size)
    height, width, channels = image_size
    images = np.zeros((size, height, width, channels), np.uint8)
    # Preallocate so size == 0 still yields a 2-D bool array.
    masks = np.zeros((size, height * width * channels), bool)
    for i, label in enumerate(lines_are_intersect):
        image = images[i]  # a view: draw_line paints into `images` in place
        shared_pixels = draw_line(image, intersection=bool(label == 1))
        # flatten_mask indexes pixels as row * input_w + col, so input_w must be
        # the number of columns (width), not the number of rows.
        masks[i] = flatten_mask(shared_pixels, width, height, channels)

    if save_dir is not None:
        save_dataset(images, lines_are_intersect, masks, save_dir, "two_lines", is_train)

    return images, lines_are_intersect, masks


def get_mask(image, segment1, segment2):
    """Return the pixels shared by two rasterized 1-px line segments.

    Only ``image.shape[:2]`` is used; the segments are drawn on fresh blank
    canvases.  Pre-existing pixel values in ``image`` therefore cannot be
    mistaken for line pixels.  Overlap is judged on rasterized pixels, so
    segments that cross between pixel centres (e.g. two diagonals) can give an
    empty result.

    Args:
        image: array whose height/width define the canvas size.
        segment1: pair of OpenCV ``(x, y)`` endpoints.
        segment2: pair of OpenCV ``(x, y)`` endpoints.

    Returns:
        List of ``[row, col]`` pairs, in row-major order, lying on both
        segments; empty when they share no pixel.
    """
    canvas1 = np.zeros(image.shape[:2], np.uint8)
    canvas2 = np.zeros(image.shape[:2], np.uint8)
    cv2.line(canvas1, segment1[0], segment1[1], 1, 1)
    cv2.line(canvas2, segment2[0], segment2[1], 1, 1)
    return np.argwhere(canvas1 & canvas2).tolist()


def is_intersect(image, segment1, segment2):
    """Tell whether two rasterized segments share at least one pixel.

    Args:
        image: array whose size defines the drawing canvas (see ``get_mask``).
        segment1: pair of OpenCV ``(x, y)`` endpoints.
        segment2: pair of OpenCV ``(x, y)`` endpoints.

    Returns:
        ``True`` when ``get_mask`` reports any shared pixel, else ``False``.
    """
    shared = get_mask(image, segment1, segment2)
    return bool(shared)


def draw_line(image, color1=None, color2=None, intersection=True):
    """Draw two random 1-px line segments onto ``image`` in place.

    Both colours default to random 3-element lists in [0, 255].  An all-zero
    colour, whether random or supplied, is resampled so the line stays visible
    on a black background.  Segment 1's endpoints differ by at least a fifth of
    ``image.shape[0]`` / ``image.shape[1]`` along each axis.

    Segment 2 is placed by one of two rules:

    * ``intersection=True``: if the first random segment 2 shares no pixel with
      segment 1, it is re-anchored at a random pixel of segment 1.  Its other
      end is then resampled until it meets the same length rule.
    * ``intersection=False``: segment 2 is resampled until it shares no pixel
      with segment 1.

    NOTE(review): segment 2 has no minimum length in the non-intersecting case,
    and in the intersecting case only when re-anchoring happened.  Line length
    may therefore correlate with the label; confirm this is intended.

    Returns:
        ``get_mask`` of the two segments, i.e. their shared ``[row, col]``
        pixels (empty when ``intersection`` is False).
    """
    rows, cols = image.shape[0], image.shape[1]

    def random_color():
        return np.random.randint(256, size=3).tolist()

    def random_point():
        return (np.random.randint(rows), np.random.randint(cols))

    def too_short(p, q):
        return abs(p[0] - q[0]) < rows / 5 or abs(p[1] - q[1]) < cols / 5

    # Keep the exact order of random draws so seeded runs are reproducible.
    if color1 is None:
        color1 = random_color()
    if color2 is None:
        color2 = random_color()
    while np.array(color1).sum() == 0:
        color1 = random_color()
    while np.array(color2).sum() == 0:
        color2 = random_color()

    start1, end1 = random_point(), random_point()
    while too_short(start1, end1):
        start1, end1 = random_point(), random_point()
    segment1 = (start1, end1)

    start2, end2 = random_point(), random_point()
    while is_intersect(image, segment1, (start2, end2)) != intersection:
        if intersection:
            # The picked pixel is (row, col); OpenCV points are (x, y) = (col, row).
            row, col = pick_a_point_on_line(image, segment1)
            start2 = (col, row)
            assert is_intersect(image, segment1, (start2, end2))
            while too_short(start2, end2):
                end2 = random_point()
        else:
            start2, end2 = random_point(), random_point()
    segment2 = (start2, end2)

    mask = get_mask(image, segment1, segment2)
    cv2.line(image, start1, end1, color1, 1)
    cv2.line(image, start2, end2, color2, 1)
    return mask

def pick_a_point_on_line(image, segment):
    """Return a uniformly random pixel ``(row, col)`` lying on a rasterized segment.

    Only ``image.shape[:2]`` is used; the segment is drawn on a blank canvas.
    Existing pixel values in ``image`` therefore cannot be picked by mistake.
    Exactly one ``np.random.randint`` draw is consumed.

    Args:
        image: array whose height/width define the canvas size.
        segment: pair of OpenCV ``(x, y)`` endpoints.

    Returns:
        ``(row, col)`` tuple of Python ints.  Note that this is array order,
        not OpenCV ``(x, y)`` order.

    Raises:
        ValueError: from ``np.random.randint`` if no pixel of the segment falls
            inside the canvas.
    """
    canvas = np.zeros(image.shape[:2], np.uint8)
    cv2.line(canvas, segment[0], segment[1], 1, 1)
    pixels = np.argwhere(canvas)  # row-major order
    row, col = pixels[np.random.randint(len(pixels))].tolist()
    return row, col

if __name__ == "__main__":
    # Build the two-lines benchmark: 20k training and 10k test images, pickled
    # into ../data (resolved against the current working directory).
    output_dir = "../data"
    data, labels, masks = create_two_lines_dataset(size=20000, save_dir=output_dir, is_train=True)
    data_test, labels_test, masks_test = create_two_lines_dataset(
        size=10000, save_dir=output_dir, is_train=False
    )