import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pickle
from time import time
from matplotlib.patches import Polygon
import math
import pickle
import os

# Base colours as (R, G, B) tuples on a 0..255 scale. Insertion order is
# significant: the generation loop picks a colour by indexing into
# list(colors.keys()).
colors = dict(
    red=(255, 0, 0),
    green=(0, 255, 0),
    blue=(0, 0, 255),
    cyan=(0, 255, 255),
    magenta=(255, 0, 255),
    yellow=(255, 255, 0),
)

def _render_patch(patch):
    """Draw ``patch`` alone on a fresh 32x32-pixel canvas and return its RGBA pixels.

    The axes span [-1.2, 1.2] on both x and y and their lines/ticks are
    hidden, so the patch is the only thing drawn inside the axes.
    """
    # 0.32 in x 100 dpi = 32 px per side. dpi is pinned so a customised
    # rcParams['figure.dpi'] cannot silently change the output size.
    fig, ax = plt.subplots(figsize=(0.32, 0.32), dpi=100)
    try:
        ax.set_xlim((-1.2, 1.2))
        ax.set_ylim((-1.2, 1.2))
        ax.add_artist(patch)
        ax.axis('off')
        fig.tight_layout(pad=0)
        fig.canvas.draw()
        # Public accessor for the Agg pixel buffer (replaces the private
        # ``canvas.renderer._renderer``); np.array copies it so the data
        # outlives the figure.
        return np.array(fig.canvas.buffer_rgba())
    finally:
        # Release the figure even if drawing fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)


def draw_shape(n, color, traintest):
    """Render one randomly jittered shape as three RGBA images.

    The same random draw produces three variants:

    * the shape filled with a randomly perturbed version of ``color``;
    * the same outline filled with pure green (``'#00FF00'``), i.e. with the
      colour information removed;
    * a fixed equilateral triangle filled with the perturbed colour, i.e.
      with the shape information removed.

    Args:
        n: Number of polygon vertices. ``n == 7`` is special-cased and drawn
            as an ellipse rather than a heptagon.
        color: Base RGB colour on a 0..255 scale (3 components).
        traintest: ``'train'`` selects the narrower jitter ranges; any other
            value selects the wider ranges.

    Returns:
        ``[shape, shape_colorblind, shape_shapeblind]``: three RGBA pixel
        arrays, each 32x32 pixels with 4 channels.
    """
    if traintest == 'train':
        # Colour jitter in [-32, 32), radius in [0.8, 1.2).
        color = color + np.random.randint(low=-32, high=32, size=3)
        radius = 0.8 + 0.4 * np.random.random()
        if n == 7:  # ellipse: minor/major axis ratio in [0.8, 1.0)
            shorter_radius = radius * (0.8 + np.random.random() * 0.2)
        # Angular steps of 2*pi/n scaled by a factor in [0.8, 1.2).
        angles = np.cumsum(2 * math.pi / n * (0.8 + np.random.random(n - 1) * 0.4))
    else:
        # Colour jitter in [-64, 64), radius in [0.6, 1.2).
        color = color + np.random.randint(low=-64, high=64, size=3)
        radius = 0.6 + 0.6 * np.random.random()
        if n == 7:  # ellipse: minor/major axis ratio in [0.6, 1.0)
            shorter_radius = radius * (0.6 + np.random.random() * 0.4)
        # Angular steps of 2*pi/n scaled by a factor in [0.6, 1.4).
        angles = np.cumsum(2 * math.pi / n * (0.6 + np.random.random(n - 1) * 0.8))

    # Clamp the jittered colour back into 0..255, then rescale to the 0..1
    # range matplotlib expects for RGB tuples.
    color = np.clip(color, 0, 255) / 255.

    if n == 7:
        # The first cumulative angle doubles as a random rotation (degrees).
        rotation_deg = angles[0] / np.pi * 180.
        polygon = matplotlib.patches.Ellipse(
            [0, 0], width=radius * 2, height=shorter_radius * 2,
            color=color, angle=rotation_deg)
        polygon_colorblind = matplotlib.patches.Ellipse(
            [0, 0], width=radius * 2, height=shorter_radius * 2,
            color='#00FF00', angle=rotation_deg)
    else:
        # First vertex at angle 0, then one per cumulative angle -> n vertices.
        pos = [[radius, 0]]
        pos += [[radius * math.cos(a), radius * math.sin(a)] for a in angles]
        # ``closed`` must be passed by keyword: newer matplotlib versions make
        # it keyword-only.
        polygon = Polygon(pos, closed=True, color=color)
        polygon_colorblind = Polygon(pos, closed=True, color='#00FF00')

    # Fixed equilateral triangle on the unit circle (vertices at 0/120/240 deg).
    polygon_shapeblind = Polygon(
        [[1, 0], [-0.5, 0.86602540378], [-0.5, -0.86602540378]],
        closed=True, color=color)

    return [_render_patch(polygon),
            _render_patch(polygon_colorblind),
            _render_patch(polygon_shapeblind)]

# Generate 1000 samples for each (split, vertex count n in 3..7) pair and
# save, per pair, the stacked rendered arrays (.npy) and the list of sampled
# base-colour names (.pkl) under ./data/shapes/.
# makedirs also creates a missing ./data parent; os.mkdir would raise there.
os.makedirs('./data/shapes', exist_ok=True)

# Hoisted: the candidate names are the same for every sample.
color_choices = list(colors.keys())

for traintest in ['train', 'test']:
    for n in range(3, 8):
        color_names = []
        arrays = []
        for i in range(1000):
            color_name = color_choices[np.random.randint(len(color_choices))]
            color = np.array(list(colors[color_name]))
            s = np.array(draw_shape(n, color, traintest))
            # Three variants (plain / colorblind / shapeblind), 32x32 RGBA each.
            assert s.shape == (3, 32, 32, 4)
            color_names.append(color_name)
            arrays.append(s)
            if i % 100 == 0:
                print(i)
        np.save('./data/shapes/{}_{}.npy'.format(n, traintest), np.stack(arrays))
        # Context manager guarantees the pickle file is flushed and closed.
        with open('./data/shapes/colors_{}_{}.pkl'.format(n, traintest), 'wb') as f:
            pickle.dump(color_names, f)
