import numpy as np
import random
from spriteworld import environment as spriteworld_environment
from spriteworld import factor_distributions as distribs
from spriteworld import sprite_generators
from spriteworld import tasks
import argparse
from spriteworld import renderers as spriteworld_renderers
import torch
import matplotlib.pyplot as plt


def random_sprites_config(num_objects, seed):
    """
    Build the spriteworld environment config used to render the dataset.
    Adapted from code given by Loic Matthey: https://gist.github.com/Azhag/f249b51b4cf3bfe568584f3a827708c1

    Args:
        num_objects: maximum number of objects in image. When it is 1 every
            scene holds exactly one triangle; otherwise a scene holds between
            2 and num_objects sprites (inclusive), each a square, triangle or
            circle.
        seed: value used to seed both Python's `random` module and NumPy's
            global RNG (the sprite-count sampler below draws from np.random).

    Returns:
        dict with keys 'task', 'action_space', 'renderers', 'init_sprites' and
        'max_episode_length', meant to be unpacked into
        spriteworld_environment.Environment(**config).
    """
    random.seed(seed)
    np.random.seed(seed)

    if num_objects == 1:
        leading_factors = [
            distribs.Beta('x', 3., 3.),
            distribs.Beta('y', 3., 3.),
            distribs.Discrete('shape', ['triangle']),
            distribs.Beta('scale', 3., 3.),
        ]
        hue = distribs.Beta('c0', 3., 3.)
        fewest_sprites = 1
    else:
        leading_factors = [
            distribs.Continuous('x', .1, .8),
            distribs.Continuous('y', .1, .8),
            distribs.Discrete('shape', ['square', 'triangle', 'circle']),
            distribs.Continuous('scale', .08, .13),
        ]
        hue = distribs.Continuous('c0', 0., 1.)
        fewest_sprites = 2

    # Factor order: x, y, shape, scale, angle, c0, c1, c2.
    factors = distribs.Product(leading_factors + [
        distribs.Continuous('angle', 0, 0),
        hue,
        distribs.Continuous('c1', 3., 3.),
        distribs.Continuous('c2', 1., 1.),
    ])

    def num_sprites():
        # randint's upper bound is exclusive, so counts lie in
        # [fewest_sprites, num_objects]; for num_objects == 1 this is always 1.
        return np.random.randint(fewest_sprites, num_objects + 1)

    sprite_gen = sprite_generators.generate_sprites(factors, num_sprites=num_sprites)

    renderers = dict(
        image=spriteworld_renderers.PILRenderer(
            image_size=(64, 64),
            anti_aliasing=5,
            color_to_rgb=spriteworld_renderers.color_maps.hsv_to_rgb,
        ),
        attributes=spriteworld_renderers.SpriteFactors(
            factors=('x', 'y', 'shape', 'angle', 'scale', 'c0', 'c1', 'c2')),
    )

    return dict(
        task=tasks.NoReward(),
        action_space=None,
        renderers=renderers,
        init_sprites=sprite_gen,
        max_episode_length=1,
    )


def get_masks(slot_images, threshold=8):
    """
    Merge per-slot object renderings into one integer segmentation mask per image.

    Args:
        slot_images: array of shape (num_images, num_slots, H, W, C) holding
            one rendering per slot, channels last. H and W need not be equal.
        threshold: a pixel belongs to a slot's object when the sum of the
            absolute channel values exceeds this value; the default (8) filters
            out faint rendering artifacts.

    Returns:
        uint8 array of shape (num_images, H, W). 0 marks background pixels
        (no slot above threshold); value k (1..num_slots) marks slot k - 1.
        Where several slots cover a pixel, the lowest slot index wins
        (torch.argmax returns the first maximal index).
    """
    # (N, S, H, W, C) -> (N, S, C, H, W), then sum absolute values over channels.
    images = torch.from_numpy(slot_images).permute(0, 1, 4, 2, 3)
    intensity = images.abs().sum(2)

    # Binarize each slot; the threshold removes artifacts.
    object_mask = (intensity > threshold).float()  # (N, S, H, W)

    # Background is wherever no slot has an object pixel.
    background_mask = (object_mask.sum(1, keepdim=True) == 0).float()  # (N, 1, H, W)

    # Channel 0 is background, channel k is slot k - 1.
    masks = torch.cat((background_mask, object_mask), 1)
    return masks.argmax(dim=1).numpy().astype('uint8')


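# Dataset collection: render frames from the environment and record per-sprite latents.
# The per-sprite images are then reduced to masks by get_masks (absolute value summed over color channels).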
def collect_frames(config, max_objects, num_frames, shape_dict, verbose=True):
    """
    Instantiate config as environment and get single images from it.
    Adapted from code given by Loic Matthey: https://gist.github.com/Azhag/f249b51b4cf3bfe568584f3a827708c1

    Args:
        config: sprites config given by random_sprites_config
        max_objects: maximum number of objects in image
        num_frames: number of samples in dataset
        shape_dict: mapping from sprite shape name to the integer code stored
            in the latents (only used when max_objects > 1)
        verbose: if True (the default), print the index of every frame as it
            is generated

    Returns:
        Tuple (obs, Z, masks):
            obs: uint8 array with the first image of every frame (the
                renderer's last output, moved to the front — presumably the
                composite scene).
            Z: float array of shape (num_frames, max_objects, num_factors) with
                per-sprite latents [x, y, scale, c0], plus the shape code from
                shape_dict when max_objects > 1. Rows for absent sprites stay 0.
            masks: get_masks applied to the per-sprite images when
                max_objects > 1, otherwise None.
    """
    env = spriteworld_environment.Environment(**config)
    # random_sprites_config draws only triangles for a single object, so the
    # shape code is recorded as a fifth latent only for multi-object data.
    num_factors = 4 if max_objects == 1 else 5
    Z = np.zeros((num_frames, max_objects, num_factors))
    images = []
    for i in range(num_frames):
        if verbose:
            print(i)
        ts = env.reset()
        sprites = env._sprites
        for j, sprite in enumerate(sprites):
            Z[i, j, :4] = (sprite.x, sprite.y, sprite.scale, sprite.c0)
            if num_factors > 4:
                Z[i, j, 4] = shape_dict[sprite.shape]

        # NOTE(review): assumes the 'image' renderer yields a list whose last
        # entry is the composite scene and whose other entries are per-sprite
        # renderings — confirm against the renderer in use. A new list is built
        # (last entry first) rather than mutating the observation in place.
        rendered = ts.observation['image']
        frame = [rendered[-1]] + list(rendered[:-1])

        # Pad with blank images so every frame has max_objects slot images.
        # The blank matches the rendered image shape instead of a fixed size.
        blank = np.zeros(np.shape(frame[0]), dtype='uint8')
        frame.extend(blank for _ in range(max_objects - len(sprites)))

        images.append(frame)

    images = np.array(images).astype('uint8')
    obs = images[:, 0]
    masks = get_masks(images[:, 1:max_objects + 1]) if max_objects > 1 else None

    return obs, Z, masks


def gen_sprites(max_objects, num_obs, save_dir=None):
    """
    Generate the sprites dataset and optionally save it as a compressed .npz archive.
    Adapted from code given by Loic Matthey: https://gist.github.com/Azhag/f249b51b4cf3bfe568584f3a827708c1

    Args:
        max_objects: maximum number of objects in image
        num_obs: number of samples in dataset
        save_dir: directory in which "<max_objects>_obj_sprites.npz" is written
            (created if missing). If None (the default), nothing is written.

    Returns:
        Tuple (X, Z, masks) of observations, latents and segmentation masks,
        as produced by collect_frames; masks is None when max_objects == 1.
    """
    shape_dict = {"triangle": 1, "circle": 2, "square": 3}
    config = random_sprites_config(max_objects, seed=0)
    X, Z, masks = collect_frames(config, max_objects, num_obs, shape_dict)

    if save_dir is not None:
        import os

        os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, str(max_objects) + "_obj_sprites")
        # Arrays are stored positionally (arr_0 = X, arr_1 = Z, arr_2 = masks);
        # masks is omitted when there are none.
        arrays = (X, Z) if masks is None else (X, Z, masks)
        np.savez_compressed(path, *arrays)

    return X, Z, masks


if __name__ == "__main__":
    # Command-line entry point: generate the dataset with the given sizes.
    parser = argparse.ArgumentParser(description="Generate a spriteworld sprites dataset.")
    parser.add_argument("--max_objects", type=int, default=4,
                        help="maximum number of objects per image")
    parser.add_argument("--nobs", type=int, default=25,
                        help="number of images to generate")
    args = parser.parse_args()
    gen_sprites(max_objects=args.max_objects, num_obs=args.nobs)

    