import os
import pickle

import cv2
import numpy as np


def get_explanation_mask(center_pt, vertex, bottom_length, image_w=64):
    """Sample a random rectangular "explanation" region around a shape's top vertex.

    Points are ``(x, y)`` pairs, the convention used by the OpenCV drawing calls
    in this module. The sampled rectangle always contains ``vertex``.
    - Its top edge lies 1-5 rows above the vertex (clamped at row 0).
    - Its bottom edge lies between the vertex and ``center_pt``.
    - Its left/right edges lie roughly within half of ``bottom_length`` on
      either side of the vertex.

    Args:
        center_pt: ``(x, y)`` point directly below ``vertex`` (same x).
        vertex: ``(x, y)`` top vertex of the shape.
        bottom_length: Width of the shape's base; bounds the horizontal extent.
        image_w: Image width in pixels. Columns stay strictly below it.
            Defaults to 64, matching the default 64x64 image size.

    Returns:
        List of ``[row, col]`` pixel coordinates covering the rectangle
        ``[ymin, ymax] x [xmin, xmax]`` (both inclusive), in row-major order.

    Raises:
        AssertionError: if ``center_pt`` is not directly below ``vertex``.
        ValueError: raised by ``np.random.randint`` if a sampling range is empty.
    """
    assert center_pt[0] == vertex[0]
    assert center_pt[1] > vertex[1]
    vx, vy = vertex[0], vertex[1]
    half_base = 0.5 * bottom_length
    # randint's upper bound is exclusive. The x bounds can be fractional, so
    # truncate them explicitly rather than relying on randint to do it.
    # The call order (ymax, ymin, xmax, xmin) is kept so that seeded runs
    # reproduce the same masks.
    ymax = np.random.randint(vy + 1, center_pt[1])
    ymin = np.random.randint(max(0, vy - 5), vy)
    xmax = np.random.randint(vx + 1, int(min(vx + half_base, image_w)))
    xmin = np.random.randint(int(max(0, vx - half_base)), vx)
    return [[i, j] for i in range(ymin, ymax + 1) for j in range(xmin, xmax + 1)]


def flatten_mask(mask, input_w=64, input_h=64, input_c=1):
    """Convert a list of ``[row, col]`` pixels into a flat boolean mask.

    Each pixel is marked in every channel. The flat layout is
    channel-major: ``index = channel * input_w * input_h + row * input_w + col``.

    Args:
        mask: Iterable of ``[row, col]`` pairs.
        input_w: Image width (row stride).
        input_h: Image height.
        input_c: Number of channels.

    Returns:
        1-D ``bool`` array of length ``input_w * input_h * input_c``.
    """
    plane_size = input_w * input_h
    flat_indices = [
        point[0] * input_w + point[1] + channel * plane_size
        for point in mask
        for channel in range(input_c)
    ]
    flat_mask = np.zeros((plane_size * input_c), bool)
    flat_mask[flat_indices] = True
    return flat_mask
        

def draw_triangle(image):
    """Draw a random filled isosceles triangle (value 1) onto ``image`` in place.

    The base is horizontal and wider than 10 px. The apex sits above the base
    midpoint, more than 10 px higher.

    Args:
        image: 2-D array to draw on; modified in place.

    Returns:
        interior: List of ``[row, col]`` pixels that are filled but not on the
            triangle's 1-pixel outline.
        explanation: ``[row, col]`` pixels from ``get_explanation_mask``
            around the apex.
        base_y: y (row) of the triangle's base.
        center_pt: ``[x, y]`` with the apex x and the y halfway between the
            base and the apex.
    """
    height, width = image.shape[:2]
    # OpenCV points are (x, y) = (column, row): x must be bounded by the image
    # width (shape[1]) and y by the height (shape[0]).
    coord1 = (np.random.randint(20, width - 20), np.random.randint(30, height - 20))
    # Second base corner on the same row; resample until the base is > 10 px.
    while True:
        coord2 = (np.random.randint(20, width), coord1[1])
        if abs(coord2[0] - coord1[0]) > 10:
            break
    # Apex above the base midpoint; resample until the height is > 10 px.
    apex_x = int((coord1[0] + coord2[0]) / 2)
    while True:
        coord3 = (apex_x, np.random.randint(1, coord1[1]))
        if abs(coord3[1] - coord1[1]) > 10:
            break

    # int32 is the integer contour dtype OpenCV works with.
    polygon = np.array([coord1, coord2, coord3], dtype=np.int32)

    # Render filled and outline-only versions on blank canvases. The interior
    # is every filled pixel that is not on the outline. np.argwhere returns
    # indices in row-major order, the same order as a nested row/col scan.
    filled = np.zeros(image.shape, np.uint8)
    outline = np.zeros(image.shape, np.uint8)
    cv2.drawContours(filled, [polygon], 0, 1, -1)
    cv2.drawContours(outline, [polygon], 0, 1, 1)
    interior = np.argwhere((filled == 1) & (outline != 1)).tolist()

    # Draw on the target image.
    cv2.drawContours(image, [polygon], 0, 1, -1)

    center_pt = [coord3[0], int((coord1[1] + coord3[1]) / 2)]
    explanation = get_explanation_mask(center_pt, coord3, abs(coord2[0] - coord1[0]))

    return interior, explanation, coord1[1], center_pt

def draw_circle(image):
    """Draw a random filled circle (value 1) onto ``image`` in place.

    Args:
        image: 2-D array to draw on; modified in place.

    Returns:
        interior: List of ``[row, col]`` pixels that are filled but not on the
            circle's 1-pixel outline.
        explanation: ``[row, col]`` pixels from ``get_explanation_mask``
            around the circle's topmost point.
        bottom_y: y (row) of the circle's lowest point, ``center[1] + radius``.
        center: ``(x, y)`` centre of the circle.
    """
    height, width = image.shape[:2]
    # OpenCV points are (x, y) = (column, row): x is bounded by the image
    # width (shape[1]) and y by the height (shape[0]).
    center = (np.random.randint(20, width - 20), np.random.randint(20, height - 30))
    # `point` is the circle's topmost point, directly above the centre, so the
    # radius is its vertical distance to the centre. The lower bound
    # center[1] - center[0] keeps radius <= center[0]. Resample until the
    # circle fits inside the image with the required margins.
    while True:
        point = (center[0], np.random.randint(max(5, center[1] - center[0]), center[1] - 5))
        radius = center[1] - point[1]
        fits = (
            radius >= 5
            and center[1] + radius < height - 5
            and center[0] + radius < width
            and center[0] - radius - 5 > 0
        )
        if fits:
            break

    # Render filled and outline-only versions on blank canvases. The interior
    # is every filled pixel that is not on the outline, in row-major order.
    filled = np.zeros(image.shape, np.uint8)
    outline = np.zeros(image.shape, np.uint8)
    cv2.circle(filled, center, radius, 1, -1)
    cv2.circle(outline, center, radius, 1, 1)
    interior = np.argwhere((filled == 1) & (outline != 1)).tolist()

    # Draw on the target image.
    cv2.circle(image, center, radius, 1, -1)
    explanation = get_explanation_mask(center, point, 2 * radius)

    return interior, explanation, center[1] + radius, center


def create_triangle_circle_dataset(size=1000, fox_ratio=0.5, image_size=(64, 64), save_dir=None, is_train=True, random_seed=None):
    """Generate single-channel images, each containing a triangle ("fox") or a circle ("cat").

    Args:
        size: Number of images to generate.
        fox_ratio: Probability that an image gets a triangle (label 1).
        image_size: ``(height, width)`` of each image.
        save_dir: If given, the dataset is also pickled into this directory
            via ``save_dataset``.
        is_train: Selects the "train"/"test" suffix of the saved file.
        random_seed: Optional seed for NumPy's global RNG (``np.random.seed``),
            making the output reproducible.

    Returns:
        images: ``uint8`` array of shape ``(size, height, width)``; shapes are
            drawn with value 1.
        is_fox: Array of shape ``(size,)`` with 0/1 labels (1 = triangle).
        masks: ``bool`` array of shape ``(size, height * width)`` holding the
            flattened explanation masks.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    height, width = image_size
    is_fox = np.random.binomial(1, fox_ratio, size=size)
    images = np.zeros((size, height, width), np.uint8)
    masks = []
    # Iterating `images` yields views, so the draw_* helpers fill the array in place.
    for image, label in zip(images, is_fox):
        if label == 1:
            # Draw a fox
            _, explanation, _, _ = draw_triangle(image)
        else:
            # Draw a cat
            _, explanation, _, _ = draw_circle(image)

        # Explanation pixels are [row, col] and flatten_mask computes
        # row * input_w + col, so input_w must be the image *width*.
        masks.append(flatten_mask(explanation, input_w=width, input_h=height, input_c=1))

    # The explicit reshape keeps a well-formed (0, height*width) array when size == 0.
    masks = np.array(masks, dtype=bool).reshape(size, height * width)

    if save_dir is not None:
        save_dataset(images, is_fox, masks, save_dir, "triangle_circle_mono", is_train)

    return images, is_fox, masks

def save_dataset(data, labels, masks, directory, name, is_train=True):
    """Pickle a dataset to ``<directory>/<name>_<train|test>.pickle``.

    The pickle holds a dict with keys ``'data'``, ``'labels'`` and ``'masks'``.
    ``directory`` and any missing parents are created first. An empty string
    means the current working directory.

    Note: only unpickle the resulting file from a trusted source.
    """
    payload = {'data': data, 'labels': labels, 'masks': masks}
    mode = "train" if is_train else "test"
    # Create the output folder up front so a long generation run does not
    # fail at the very end on a missing directory.
    if directory:
        os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, "{}_{}.pickle".format(name, mode))
    with open(path, 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == "__main__":
    # Build the default 10,000-sample training split and pickle it into ./data.
    train_images, train_labels, train_masks = create_triangle_circle_dataset(
        size=10000,
        save_dir="./data",
        is_train=True,
    )