#!/usr/bin/env python

import os
from pathlib import Path
import pickle
from random import randint, random
from uuid import uuid4
import sys
import traceback

import av
import numpy
from PIL import Image, ImageDraw
import pyttsx3
from tqdm import tqdm

#
# Helpers
#
class Helpers():
    """Abstract interface for shape-sample generators.

    Each method produces one sample of the named shape as a numpy array;
    subclasses decide the medium (image pixels, audio samples, ...).
    """

    def square(self, **kwargs) -> numpy.ndarray:
        # NotImplemented is a comparison sentinel, not an exception class;
        # raising it is itself a TypeError. NotImplementedError is correct.
        raise NotImplementedError()

    def rectangle(self, **kwargs) -> numpy.ndarray:
        raise NotImplementedError()

    def ellipse(self, **kwargs) -> numpy.ndarray:
        raise NotImplementedError()

    def circle(self, **kwargs) -> numpy.ndarray:
        raise NotImplementedError()


class VoiceHelpers(Helpers):
    """Synthesize the spoken name of each shape into a numpy sample array."""

    def __init__(self):
        # No object-wide tts engine, because pyttsx3 blocks after the first
        # call; we only keep the baseline properties so per-call changes can
        # be expressed relative to them, and create a fresh engine each time.
        tts = pyttsx3.init()
        self.baseline_rate = tts.getProperty('rate')
        self.baseline_volume = tts.getProperty('volume')
        self.work_dir = os.path.join('.tmp')
        Path(self.work_dir).mkdir(parents=True, exist_ok=True)
        self.duration_padding = 1  # second

    def square(self, rate_change: float = None, volume_change: float = None) -> numpy.ndarray:
        return self._produce('square', rate_change, volume_change)

    def rectangle(self, rate_change: float = None, volume_change: float = None) -> numpy.ndarray:
        return self._produce('rectangle', rate_change, volume_change)

    def ellipse(self, rate_change: float = None, volume_change: float = None) -> numpy.ndarray:
        return self._produce('ellipse', rate_change, volume_change)

    def circle(self, rate_change: float = None, volume_change: float = None) -> numpy.ndarray:
        return self._produce('circle', rate_change, volume_change)

    def _produce(self, utterance: str,
                 rate_change: float = None, volume_change: float = None) -> numpy.ndarray:
        """Speak ``utterance`` into a temp mp3, decode it, return the samples.

        :param utterance: text to synthesize (the shape name).
        :param rate_change: offset added to the engine's baseline rate.
        :param volume_change: offset added to the engine's baseline volume.
        :return: float32 array of decoded audio frames; a mono stream is
                 duplicated along the channel axis. Returns None on failure
                 (the error is logged to stderr instead of propagating).
        """
        tmp_file = os.path.join(self.work_dir, str(uuid4())) + '.mp3'
        tts = pyttsx3.init()
        try:
            if rate_change:
                tts.setProperty('rate', self.baseline_rate + rate_change)
            if volume_change:
                tts.setProperty('volume', self.baseline_volume + volume_change)
            tts.save_to_file(utterance, tmp_file)
            tts.runAndWait()

            samples = []
            with av.open(tmp_file) as audioc:
                for frame in audioc.decode(audio=0):
                    samples.append(frame.to_ndarray())
            # Drop the final frame: presumably it is shorter than the others,
            # which would make the stacked array ragged -- TODO confirm.
            samples = numpy.array(samples[:-1], dtype=numpy.float32)
            if samples[0].shape[0] == 1:
                # Mono stream: duplicate the channel axis to fake stereo.
                return numpy.hstack([samples, samples])
            return samples
        except Exception:
            # Deliberate best-effort: log and return None so one bad
            # synthesis does not abort a long dataset-generation run.
            print(traceback.format_exc(), file=sys.stderr)
        finally:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)


class ImageHelpers(Helpers):
    """Draw one randomly sized and placed shape into an RGB image array."""

    def __init__(self):
        pass

    def square(self, height: int = 69, width: int = 99, outline: tuple = None, fill: tuple = None, line: int = 1) -> numpy.ndarray:
        # Use _random_scale_down() (as rectangle/ellipse already do) instead
        # of a raw random() factor, so the side cannot degenerate to ~0 and
        # the shape stays visually distinguishable.
        side = int(min(height, width) * self._random_scale_down())
        return self._produce(height, width, 'rectangle', side, side, outline, fill, line)

    def rectangle(self, height: int = 69, width: int = 99, outline: tuple = None, fill: tuple = None, line: int = 1) -> numpy.ndarray:
        return self._produce(height, width, 'rectangle',
                             int(height * self._random_scale_down()),
                             int(width * self._random_scale_down()),
                             outline, fill, line)

    def circle(self, height: int = 69, width: int = 99, outline: tuple = None, fill: tuple = None, line: int = 1) -> numpy.ndarray:
        # Equal width/height turns PIL's ellipse into a circle; bounded by
        # the shorter canvas edge, scaled with the 20% visibility floor.
        diameter = int(min(height, width) * self._random_scale_down())
        return self._produce(height, width, 'ellipse', diameter, diameter, outline, fill, line)

    def ellipse(self, height: int = 69, width: int = 99, outline: tuple = None, fill: tuple = None, line: int = 1) -> numpy.ndarray:
        return self._produce(height, width, 'ellipse',
                             int(height * self._random_scale_down()),
                             int(width * self._random_scale_down()),
                             outline, fill, line)

    def _produce(self, height: int, width: int, shape_fn: str, shape_height: int, shape_width: int,
                 outline: tuple = None, fill: tuple = None, line: int = 1) -> numpy.ndarray:
        """Render ``shape_fn`` onto a black (H, W) canvas; return it as (H, W, 3)."""
        image = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(image)
        getattr(draw, shape_fn)(self._points(height, width, shape_height, shape_width),
                                outline=outline, fill=fill, width=line)
        return numpy.array(image)

    def _points(self, oh: int, ow: int, h: int, w: int) -> list:
        """Pick a random top-left origin keeping the shape inside the canvas.

        The max(1, ...) guards keep randint's bounds valid when the shape is
        (nearly) as large as the canvas.
        """
        origin = (randint(1, max([1, (ow - w) - 1])), randint(1, max([1, (oh - h) - 1])))
        return [origin, (origin[0] + w, origin[1] + h)]

    def _random_scale_down(self) -> float:
        '''
        A scale-down factor not too small, for us to distinguish unambiguously the shapes.

        This guarantees a minimum (e.g. 20% of the original), with randomness up to 100%.
        '''
        return 0.2 + random() * 0.8


def gen(config):
    """Generate a balanced synthetic dataset of shape clips with spoken labels.

    :param config: namespace with ``number`` (sample count), ``fps`` (frames
        per sample) and ``output`` (pickle file name).

    Each sample dict contains:
      - 'video': (fps, C, H, W) repeated frames of one randomly drawn shape,
      - 'audio': decoded TTS samples of the shape name, zero-padded,
      - 'label': (fps,) class index, one per frame.
    The dataset plus an index->name label map are pickled to config.output.
    """
    dataset = []
    imagen = ImageHelpers()
    voigen = VoiceHelpers()
    # The public methods of the abstract base define the class set.
    shape_fns = [m for m in dir(Helpers()) if not m.startswith('_')]
    shape_count = len(shape_fns)
    samples = 0
    label_map = {shape: idx for idx, shape in enumerate(shape_fns)}
    # total= keyword, not the iterable positional: we drive the bar manually.
    with tqdm(total=config.number) as pbar:
        while samples < config.number:
            # Round-robin over the classes keeps the dataset balanced.
            shape_fn = shape_fns[samples % shape_count]
            # randint is inclusive at both ends: 255 is the max channel value.
            fill_color = (randint(100, 255), randint(100, 255), randint(100, 255))
            outline_color = (randint(0, 255), randint(0, 255), randint(0, 255))
            line_width = randint(1, 3)
            rate_change = randint(0, 50) - 25   # +/- 25 around baseline rate
            volume_change = random() - 0.5      # +/- 0.5 around baseline volume

            frame = getattr(imagen, shape_fn)(fill=fill_color, outline=outline_color, line=line_width)
            image = numpy.array([frame] * config.fps)
            label = numpy.array([label_map[shape_fn]] * config.fps)

            sound = getattr(voigen, shape_fn)(rate_change=rate_change, volume_change=volume_change)
            # Frame rate does not seem configurable in pyttsx; pad along the
            # frame axis toward ~22050 total samples. max() guards against a
            # negative pad width when the utterance is already longer.
            desired = max(0, int(22050 / sound.shape[-1]) - sound.shape[0])
            sound = numpy.pad(sound, ((0, desired), (0, 0), (0, 0)))

            dataset.append({
                'video': image.transpose((0, 3, 1, 2)),  # (fps, C, H, W)
                'audio': sound,  # 3-D; presumably (frames, C, samples) -- see VoiceHelpers._produce
                'label': label,
            })
            samples += 1
            pbar.update(1)
    with open(config.output, 'wb') as f:
        # Store the inverted map (index -> shape name) alongside the data.
        pickle.dump((dataset, {label_map[k]: k for k in label_map}), f)


def data(archive: str) -> tuple:
    """Load a dataset archive written by gen().

    :param archive: path to the pickle file.
    :return: the pickled ``(dataset, label_map)`` tuple -- a list of sample
        dicts and a dict mapping class index -> shape name. (The previous
        ``-> list`` annotation was wrong: gen() pickles a tuple.)

    WARNING: pickle.load can execute arbitrary code from the archive; only
    load archives you generated yourself.
    """
    with open(archive, 'rb') as f:
        dataset = pickle.load(f)
    return dataset


if __name__ == '__main__':
    import argparse

    # Command-line front-end: every option has a default, so the script
    # runs with no arguments at all.
    parser = argparse.ArgumentParser(
        description="Generate a balanced shape synthetic dataset with image and audio, for shape classification.")
    parser.add_argument('--number', '-n',
                        type=int,
                        default=10000,
                        required=False,
                        help="Number of (balanced) samples to generate.")
    parser.add_argument('--fps', '-f',
                        type=int,
                        default=30,
                        required=False,
                        help="FPS generated for the video output.")
    parser.add_argument('--output', '-o',
                        default='shape_dataset.pickle',
                        required=False,
                        help="File name of the dataset archive.")
    gen(parser.parse_args())
