from enum import Enum, IntEnum
import json
import os
from random import Random

from PIL import Image, ImageDraw
from tqdm.auto import tqdm

# Side length, in pixels, of every (square) generated image.
SIZE = 300
# Minimum gap, in pixels, kept between a shape's bounding box and the border.
MARGIN = 10
# Maximum absolute random jitter applied to hue, position and size.
VARIANCE = 15
# Inclusive (low, high) bounds for the randomly drawn saturation channel.
SATURATION_RANGE = (128, 255)
# Inclusive (low, high) bounds for the randomly drawn value channel.
VALUE_RANGE = (64, 255)

# Number of images generated per (shape, color, x, y, width, height) combo.
EXAMPLES = 50
# Output directory that receives the PNG files and metadata.jsonl.
ROOT = f'data/shapes-{EXAMPLES}'
# Seed of the random generator used for all jitter.
SEED = 0


class Color(IntEnum):
    """Nominal hue of each color category, before random jitter."""

    # Original 0-360 hue space is mapped to 0-255 by PIL
    RED = 0
    GREEN = 85
    BLUE = 170


class Position(IntEnum):
    """Nominal centre coordinate of a shape, used for both x and y."""

    # Values are pixel offsets inside the SIZE x SIZE canvas, before jitter.
    START = 50
    MIDDLE = 150
    END = 250


class Shape(Enum):
    """Kind of primitive that make_image draws."""

    ELLIPSE = 0
    RECTANGLE = 1
    TRIANGLE = 2


class Size(IntEnum):
    """Nominal extent of a shape's bounding box, used for width and height."""

    # Values are in pixels, before jitter and before clamping to the margins.
    SMALL = 40
    MEDIUM = 100
    LARGE = 160

class HigherOrder(Enum):
    """Binary label derived from shape and hue by get_higher_order."""

    # Assigned when the combined score is <= 0.5.
    BUBA = 0
    # Assigned when the combined score is > 0.5.
    KIKI = 1


def get_higher_order(shape, hsv):
    """Score a shape/colour pair on the buba-kiki axis and label it.

    The score is the sum of two parts:

    * a shape part: 0 for an ellipse, 0.25 for a rectangle and 0.5 for a
      triangle (any other value also contributes 0);
    * a hue part computed by ``get_hue_score``, going from 0 for blue to
      0.5 for red.

    A blue ellipse is therefore maximally "buba" and a red triangle
    maximally "kiki".

    Args:
        shape: The ``Shape`` member that was drawn.
        hsv: A 3-item ``(hue, saturation, value)`` sequence; only the hue
            is used.

    Returns:
        A ``(score, label)`` tuple where ``label`` is ``HigherOrder.BUBA``
        when ``score <= 0.5`` and ``HigherOrder.KIKI`` otherwise.
    """
    hue, _saturation, _value = hsv

    shape_score = (
        0.25 if shape == Shape.RECTANGLE
        else 0.5 if shape == Shape.TRIANGLE
        else 0.
    )
    score = shape_score + get_hue_score(hue)

    # Scores exactly on the threshold count as buba.
    if score <= 0.5:
        return score, HigherOrder.BUBA
    return score, HigherOrder.KIKI


def get_hue_score(hue):
    """Map a hue linearly onto a score that is 0 at blue and 0.5 at red.

    Hues above ``Color.BLUE + VARIANCE`` are treated as red hues that
    wrapped around the 0-255 scale and are shifted down by 256, so they
    become small negative numbers. After that shift, a hue of ``-VARIANCE``
    scores 0.5 and a hue of ``Color.BLUE + VARIANCE`` scores 0.
    """
    highest_blue = Color.BLUE + VARIANCE
    # Circle back to red through negative values.
    unwrapped = hue - 256 if hue > highest_blue else hue

    span = Color.BLUE + 2 * VARIANCE
    return 0.5 * (1. - (unwrapped + VARIANCE) / span)


def get_bit8_score(bit8_val):
    """Rescale an 8-bit channel value by dividing it by 256.

    An input of 0 yields 0.0 and an input of 255 yields 255/256, so values
    in the 0-255 range land in the half-open interval [0, 1).
    """
    levels = 256.  # number of distinct values representable in 8 bits
    return bit8_val / levels


def make_image(shape: Shape,
               color: Color,
               x: Position,
               y: Position,
               width: Size,
               height: Size,
               rng: Random) -> tuple:
    """Draw one randomly jittered shape on a uniform canvas.

    The arguments give the nominal category of the example. The hue,
    centre position and size that are actually drawn are each perturbed by
    a random integer in ``[-VARIANCE, VARIANCE]``; saturation and value are
    drawn from ``SATURATION_RANGE`` and ``VALUE_RANGE``.

    Args:
        shape: Kind of primitive to draw.
        color: Nominal hue; the jittered hue is wrapped modulo 256.
        x: Nominal horizontal centre of the shape.
        y: Nominal vertical centre of the shape.
        width: Nominal width of the bounding box.
        height: Nominal height of the bounding box.
        rng: Source of randomness. Values are drawn in a fixed order (hue,
            saturation, value, x, y, width, height), so a seeded generator
            gives reproducible output.

    Returns:
        An ``(image, metadata)`` tuple. ``image`` is the ``SIZE`` x ``SIZE``
        image in ``'HSV'`` mode. ``metadata`` holds the nominal enum members
        under ``'shape'``, ``'color'``, ``'x'``, ``'y'``, ``'width'`` and
        ``'height'``, the drawn channels divided by 256 under ``'hue'``,
        ``'saturation'`` and ``'value'``, and the result of
        ``get_higher_order`` under ``'higher-order'`` and
        ``'higher-order-score'``.

    Raises:
        ValueError: If ``shape`` is not one of the known ``Shape`` members.
    """
    # Zero-saturation background, so only the shape carries color.
    image = Image.new('HSV', (SIZE, SIZE), (0, 0, 192))
    draw = ImageDraw.Draw(image)

    # Red sits at hue 0, so negative jitter must wrap around to 241-255.
    hue = (color + rng.randint(-VARIANCE, VARIANCE)) % 256
    saturation = rng.randint(*SATURATION_RANGE)
    value = rng.randint(*VALUE_RANGE)
    hsv = (hue, saturation, value)

    ho_score, ho_label = get_higher_order(shape, hsv)
    # Position and size are recorded as their nominal categories (captured
    # before the jitter below); the color channels are the drawn values.
    metadata = {
        'shape': shape,
        'color': color,
        'x': x,
        'y': y,
        'width': width,
        'height': height,
        'hue': get_bit8_score(hue),
        'saturation': get_bit8_score(saturation),
        'value': get_bit8_score(value),
        'higher-order': ho_label,
        'higher-order-score': ho_score,
    }

    x += rng.randint(-VARIANCE, VARIANCE)
    y += rng.randint(-VARIANCE, VARIANCE)
    width += rng.randint(-VARIANCE, VARIANCE)
    height += rng.randint(-VARIANCE, VARIANCE)

    # NOTE: the clamp is asymmetric. A box overflowing the left/top margin
    # is shifted inward and keeps its size; a box overflowing the
    # right/bottom margin is truncated instead.
    left = max(MARGIN, x - width // 2)
    right = min(SIZE - MARGIN, left + width)

    top = max(MARGIN, y - height // 2)
    bottom = min(SIZE - MARGIN, top + height)

    if shape == Shape.ELLIPSE:
        draw.ellipse((left, top, right, bottom), fill=hsv)
    elif shape == Shape.RECTANGLE:
        draw.rectangle((left, top, right, bottom), fill=hsv)
    elif shape == Shape.TRIANGLE:
        # Isosceles triangle: apex at the top centre, base along the bottom.
        apex = ((left + right) // 2, top)
        draw.polygon((apex, (left, bottom), (right, bottom)), fill=hsv)
    else:
        raise ValueError(f'Unknown shape: {shape!r}')

    return image, metadata


# One argument tuple per image: every combination of the categorical
# factors, each repeated EXAMPLES times in a row.
args = [
    (shape, color, x, y, width, height)
    for shape in Shape
    for color in Color
    for x in Position
    for y in Position
    for width in Size
    for height in Size
    for _ in range(EXAMPLES)
]

os.makedirs(ROOT, exist_ok=True)

# Metadata entries that hold enum members and must be turned into strings
# before they can be serialized as JSON.
enum_keys = ('shape', 'color', 'x', 'y', 'width', 'height', 'higher-order')

# A single generator is shared by all images so the whole dataset is
# determined by SEED.
rng = Random(SEED)
all_metadata = []
for i, arg in enumerate(tqdm(args)):
    file_name = f'{i:05}.png'

    image, metadata = make_image(*arg, rng=rng)
    # Images are drawn in HSV mode; convert to RGB before saving as PNG.
    image.convert('RGB').save(f'{ROOT}/{file_name}')

    metadata.update({k: metadata[k].name.lower() for k in enum_keys})
    metadata['file_name'] = file_name
    all_metadata.append(metadata)

# One JSON object per line, in the same order as the image files.
with open(f'{ROOT}/metadata.jsonl', 'w') as f:
    f.write('\n'.join(map(json.dumps, all_metadata)))
