import os
import pickle

import cv2
import numpy as np


def get_explanation_mask(center_pt, vertex, bottom_length):
    """Sample a random pixel box that encloses ``vertex`` and ``center_pt``.

    Both points are ``(x, y)`` pairs that must share the same x, with
    ``center_pt`` at the larger y.  Each side of the box is pushed outwards
    by a random margin:

    * top edge:    drawn from ``[max(0, vertex_y - 5), vertex_y)``
    * bottom edge: drawn from ``[center_y, min(center_y + 5, 63))``
    * left edge:   drawn from ``[max(0, x - 5), x)``
    * right edge:  drawn from ``[x + 1, min(x + 5, 63))``

    Args:
        center_pt: ``(x, y)`` of the lower point (indexable, at least 2 items).
        vertex: ``(x, y)`` of the upper point (indexable, at least 2 items).
        bottom_length: Accepted for the callers' sake; not used here.

    Returns:
        A list of ``[y, x]`` pairs covering every pixel of the box, both
        edges inclusive, ordered row by row.

    Raises:
        AssertionError: If the x coordinates differ or ``center_pt`` is not
            strictly below ``vertex`` (larger y).
    """
    shared_x, lower_y = center_pt[0], center_pt[1]
    upper_y = vertex[1]
    assert shared_x == vertex[0]
    assert lower_y > upper_y

    # The order of these four draws (bottom, top, right, left) determines the
    # values obtained under a fixed seed, so it must not be rearranged.
    bottom = np.random.randint(lower_y, min(lower_y + 5, 63))
    top = np.random.randint(max(0, upper_y - 5), upper_y)
    right = np.random.randint(vertex[0] + 1, min(vertex[0] + 5, 63))
    left = np.random.randint(max(0, vertex[0] - 5), vertex[0])

    return [
        [row, col]
        for row in range(top, bottom + 1)
        for col in range(left, right + 1)
    ]


def flatten_mask(mask, input_w=64, input_h=64, input_c=1):
    """Turn a list of 2-D pixel coordinates into a flat boolean vector.

    Each entry ``p`` of ``mask`` is mapped to the flat offset
    ``p[0] * input_w + p[1]`` and that offset is repeated once per channel,
    shifted by ``channel * input_w * input_h`` (channel-major layout).

    Args:
        mask: Sequence of indexable pairs; ``p[0]`` is multiplied by
            ``input_w`` and ``p[1]`` is added as-is.
        input_w: Stride applied to the first coordinate.
        input_h: Together with ``input_w`` gives the size of one channel plane.
        input_c: Number of channel planes to mark.

    Returns:
        A 1-D ``bool`` array of length ``input_w * input_h * input_c`` that is
        ``True`` exactly at the computed offsets.
    """
    plane_size = input_w * input_h
    offsets = [
        point[0] * input_w + point[1] + channel * plane_size
        for point in mask
        for channel in range(input_c)
    ]
    flat = np.zeros(plane_size * input_c, dtype=bool)
    flat[offsets] = True
    return flat


def draw_triangle(image, up):
    """Draw one random filled triangle into ``image`` and return its mask box.

    The triangle has a horizontal base (both base corners share one y) and an
    apex whose x is the midpoint of the base.  When ``up`` is truthy the apex
    gets a smaller y than the base, otherwise a larger one.  The shape is
    filled with the value 1 directly in ``image``.

    Args:
        image: 2-D array to draw into; modified in place.  The sampling
            margins below are fixed pixel counts (20, 30, 5).
        up: Truthy for an apex above the base (smaller y), falsy for below.

    Returns:
        The list of ``[y, x]`` pixels produced by ``get_explanation_mask`` for
        the segment between the apex and the midpoint of the base.
    """
    # First base corner; x is drawn before y.
    first_x = np.random.randint(20, image.shape[0] - 20)
    base_y = np.random.randint(30, image.shape[1] - 20)

    # Second base corner: resample until the base is longer than 10 pixels.
    while True:
        second_x = np.random.randint(20, image.shape[0])
        if abs(second_x - first_x) > 10:
            break

    base_length = abs(second_x - first_x)
    apex_x = (first_x + second_x) // 2
    base_mid = [apex_x, base_y]

    if up:
        apex = (apex_x, np.random.randint(5, base_y))
        # Apex is the upper point, the base midpoint the lower one.
        explanation = get_explanation_mask(base_mid, apex, base_length)
    else:
        apex = (apex_x, np.random.randint(base_y + 5, image.shape[1] - 5))
        # Roles swap: the base midpoint is now the upper point.
        explanation = get_explanation_mask(apex, base_mid, base_length)

    # Thickness -1 asks OpenCV to fill the contour instead of outlining it.
    corners = np.array([(first_x, base_y), (second_x, base_y), apex])
    cv2.drawContours(image, [corners], 0, 1, -1)

    return explanation


def create_triangle_dataset(size=1000, up_ratio=0.5, image_size=(64, 64), save_dir=None, is_train=True, random_seed=None):
    """Generate random triangle images with labels and explanation masks.

    Every sample is a blank ``uint8`` image into which ``draw_triangle`` draws
    one filled triangle.  The label is 1 when the triangle was drawn with
    ``up=True`` and 0 otherwise.  The mask marks, in flattened row-major
    order, the pixel box returned by ``draw_triangle``.

    Args:
        size: Number of samples to generate.
        up_ratio: Probability that a sample is labelled 1 (``up``).
        image_size: ``(rows, cols)`` of each image.
            NOTE(review): the mask helper in this file clamps its window at
            pixel 63, so sizes other than 64x64 should be verified before use.
        save_dir: If not ``None``, the dataset is also written there via
            ``save_dataset`` under the name ``"triangle_mono"``.
        is_train: Forwarded to ``save_dataset`` to pick the split suffix.
        random_seed: If not ``None``, seeds NumPy's global RNG first.

    Returns:
        Tuple ``(images, is_up, masks)`` where ``images`` has shape
        ``(size, rows, cols)`` and dtype ``uint8``, ``is_up`` has shape
        ``(size,)`` and ``masks`` is a ``bool`` array of shape
        ``(size, rows * cols)``.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    is_up = np.random.binomial(1, up_ratio, size=size)
    count = is_up.shape[0]

    rows, cols = image_size[0], image_size[1]
    # Allocate the final arrays directly: no throwaway list of blank images,
    # and shapes/dtypes stay well-defined even when ``size`` is 0.
    images = np.zeros((count, rows, cols), np.uint8)
    masks = np.zeros((count, rows * cols), bool)

    for i in range(count):
        # ``images[i]`` is a view, so the triangle is drawn into ``images``.
        explanation = draw_triangle(images[i], up=bool(is_up[i] == 1))
        # Mask entries are [y, x]; the flat stride is therefore the number of
        # columns (width), followed by the number of rows (height).
        masks[i] = flatten_mask(explanation, cols, rows, 1)

    if save_dir is not None:
        save_dataset(images, is_up, masks, save_dir, "triangle_mono", is_train)

    return images, is_up, masks


def save_dataset(data, labels, masks, directory, name, is_train=True):
    """Pickle a dataset to ``<directory>/<name>_<train|test>.pickle``.

    The file holds a dict with the keys ``'data'``, ``'labels'`` and
    ``'masks'``, written with ``pickle.HIGHEST_PROTOCOL``.

    Args:
        data: Samples to store under ``'data'``.
        labels: Labels to store under ``'labels'``.
        masks: Masks to store under ``'masks'``.
        directory: Target directory; created (with parents) if missing.
        name: File name prefix.
        is_train: ``True`` selects the ``train`` suffix, ``False`` ``test``.
    """
    payload = {'data': data, 'labels': labels, 'masks': masks}
    mode = "train" if is_train else "test"
    # Create the target directory up front so a missing path does not discard
    # an already-generated dataset.  An empty string means "current directory"
    # and must be skipped because os.makedirs('') raises.
    if directory:
        os.makedirs(directory, exist_ok=True)
    target = os.path.join(directory, ("{}_{}.pickle").format(name, mode))
    with open(target, 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    # Script entry point: generate 10,000 samples with an unseeded RNG and
    # save them under ./data/valid with the "test" suffix (is_train=False).
    data, labels, masks = create_triangle_dataset(size=10000, save_dir="./data/valid", is_train=False)