from typing import Dict, List, Tuple

import dataclasses
import collections
from clu import preprocess_spec

import tensorflow as tf

# Alias of clu's feature-dictionary type; used as the input/output annotation
# of every preprocessing op in this module.
Features = preprocess_spec.Features


def all_ops():
  """Returns the preprocessing ops defined in this module.

  Thin wrapper that delegates to ``preprocess_spec.get_all_ops`` with this
  module's ``__name__``.
  """
  return preprocess_spec.get_all_ops(__name__)

@dataclasses.dataclass
class ImageRescaleToFloat:
  """Casts one feature to float32 and rescales it to [min_val, max_val].

  The value is divided by 255 and then mapped affinely, so an input in
  [0, 255] (presumably 8-bit pixels; TODO confirm for other keys) lands in
  [min_val, max_val]. The features dict is updated in place and returned.

  Raises:
    ValueError: if ``name`` is not a key of the features.
  """
  # Key of the feature to rescale.
  name: str = "image"
  # Output value for an input of 0.
  min_val: float = -1.0
  # Output value for an input of 255.
  max_val: float = 1.0

  def __call__(self, features: Features) -> Features:
    if self.name not in features:
      raise ValueError("{} not found in features".format(self.name))
    span = self.max_val - self.min_val
    as_float = tf.cast(features[self.name], dtype=tf.float32)
    # Same evaluation order as ((x / 255) * span) + min_val.
    features[self.name] = as_float / 255.0 * span + self.min_val
    return features


@dataclasses.dataclass
class ImageCrop:
  """Crops a fixed spatial window out of one feature.

  ``x_min:x_max`` slices the third-to-last axis and ``y_min:y_max`` the
  second-to-last axis. The last (channel) axis is kept whole. Any leading
  axes are left untouched, so a single ``[H, W, C]`` image and stacked or
  batched ``[..., H, W, C]`` inputs are both supported. For a rank-3 input
  the result is identical to ``image[x_min:x_max, y_min:y_max, :]``.

  Raises:
    ValueError: if ``name`` is not a key of the features.
  """
  # Key of the feature to crop.
  name: str = "image"
  # Range along the third-to-last axis (rows, for an [H, W, C] layout).
  x_min: int = 29
  x_max: int = 221
  # Range along the second-to-last axis (columns, for an [H, W, C] layout).
  y_min: int = 64
  y_max: int = 256

  def __call__(self, features: Features) -> Features:
    if self.name not in features:
      raise ValueError("{} not found in features".format(self.name))
    # The Ellipsis keeps any leading batch/time/object axes intact instead of
    # mistakenly cropping them when the input has rank > 3.
    features[self.name] = features[self.name][
        ..., self.x_min:self.x_max, self.y_min:self.y_max, :]
    return features


@dataclasses.dataclass
class MaskCrop:
  """Crops a fixed spatial window out of a stack of masks.

  ``x_min:x_max`` slices the third-to-last axis and ``y_min:y_max`` the
  second-to-last axis. The last (channel) axis and all leading axes (e.g. an
  object axis, plus any batch axes) are kept whole. For a rank-4 input the
  result is identical to ``mask[:, x_min:x_max, y_min:y_max, :]``.

  Raises:
    ValueError: if ``name`` is not a key of the features.
  """
  # Key of the feature to crop.
  name: str = "mask"
  # Range along the third-to-last axis (rows, for a [..., H, W, C] layout).
  x_min: int = 29
  x_max: int = 221
  # Range along the second-to-last axis (columns, for a [..., H, W, C] layout).
  y_min: int = 64
  y_max: int = 256

  def __call__(self, features: Features) -> Features:
    if self.name not in features:
      raise ValueError("{} not found in features".format(self.name))
    # The Ellipsis covers however many leading axes the masks have, rather
    # than assuming exactly one.
    features[self.name] = features[self.name][
        ..., self.x_min:self.x_max, self.y_min:self.y_max, :]
    return features


@dataclasses.dataclass
class ImageResize:
  """Resizes one feature to ``resolution`` with ``tf.image.resize``.

  ``resize_method`` selects the interpolation: "bilinear",
  "nearest_neighbor", "bicubic" or "area". The features dict is updated in
  place and returned.

  Raises:
    ValueError: if ``name`` is not a key of the features, or (checked after
      that) if ``resize_method`` is not one of the supported names.
  """
  # Key of the feature to resize.
  name: str = "image"
  # Target spatial size, passed straight to tf.image.resize as ``size``.
  resolution: Tuple[int,
                    int] = dataclasses.field(default_factory=lambda: (128, 128))
  # Name of the interpolation method (see class docstring).
  resize_method: str = "bilinear"

  def __call__(self, features: Features) -> Features:
    if self.name not in features:
      raise ValueError("{} not found in features".format(self.name))
    # Built per call so the mapping sits next to its only use.
    methods = {
        "bilinear": tf.image.ResizeMethod.BILINEAR,
        "nearest_neighbor": tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        "bicubic": tf.image.ResizeMethod.BICUBIC,
        "area": tf.image.ResizeMethod.AREA,
    }
    if self.resize_method not in methods:
      raise ValueError(
          "{} not found in resize methods".format(self.resize_method))
    features[self.name] = tf.image.resize(
        features[self.name],
        self.resolution,
        method=methods[self.resize_method])
    return features

@dataclasses.dataclass
class ImageClipByValue:
  """Clamps one feature elementwise into [min_val, max_val].

  The features dict is updated in place and returned.

  Raises:
    ValueError: if ``name`` is not a key of the features.
  """
  # Key of the feature to clip.
  name: str = "image"
  # Lower clipping bound.
  min_val: float = -1.0
  # Upper clipping bound.
  max_val: float = 1.0

  def __call__(self, features: Features) -> Features:
    if self.name not in features:
      raise ValueError("{} not found in features".format(self.name))
    features[self.name] = tf.clip_by_value(
        features[self.name], self.min_val, self.max_val)
    return features

@dataclasses.dataclass
class ClevrLabelPreprocess:
  """Builds the per-object property target for CLEVR.

  When ``get_properties`` is set, each object's properties are concatenated
  into one row in this order:
    * the 3D coordinates, shifted and scaled from [-3, 3] to [0, 1];
    * one-hot size (depth 2), material (2), shape (3) and color (8);
    * a trailing 1 flagging a real object.
  The rows are then zero-padded to ``max_n_objects``, so padding rows carry
  a 0 flag. The result is ``{"image": ..., "target": ...}``.

  When ``get_properties`` is unset, only ``{"image": ...}`` is returned.
  All other input keys are dropped in both cases.
  """
  # Whether to build the "target" property tensor.
  get_properties: bool = True
  # Number of rows the per-object target is zero-padded to.
  max_n_objects: int = 10

  def __call__(self, features: Features) -> Features:
    if not self.get_properties:
      # No target to build: keep just the image. (Falling through here used
      # to reference an unbound ``properties_dict``.)
      return {"image": features["image"]}

    objects = features["objects"]
    # Originally the x, y, z positions are in [-3, 3].
    # We re-normalize them to [0, 1].
    coords = (objects["3d_coords"] + 3.) / 6.
    # One-hot encoding of the discrete features.
    size = tf.one_hot(objects["size"], 2)
    material = tf.one_hot(objects["material"], 2)
    shape_obj = tf.one_hot(objects["shape"], 3)
    color = tf.one_hot(objects["color"], 8)
    properties_tensor = tf.concat(
        [coords, size, material, shape_obj, color], axis=1)

    # Add a 1 indicating these are real objects.
    num_objects = tf.shape(properties_tensor)[0]
    properties_tensor = tf.concat(
        [properties_tensor, tf.ones([num_objects, 1])], axis=1)

    # Pad the remaining objects with all-zero rows.
    # NOTE(review): assumes num_objects <= max_n_objects; otherwise the pad
    # amount below is negative. Confirm against the dataset.
    properties_pad = tf.pad(
        properties_tensor,
        [[0, self.max_n_objects - num_objects], [0, 0]],
        "CONSTANT")
    return {
        "image": features["image"],
        "target": properties_pad,
    }


@dataclasses.dataclass
class ClevrMaskPreprocess:
  """Collapses per-object masks into a single integer segmentation map.

  Channel 0 of ``features[name]`` is thresholded at 0.5. Each output pixel
  holds the index along the object axis (third-to-last after dropping the
  channel) of the highest-indexed object whose mask covers it, or 0 if none
  does. The map is stored under ``tgt_name``. The source key is deleted
  unless it equals ``tgt_name``. Any leading batch axes are preserved.

  Raises:
    ValueError: if ``name`` is not a key of the features.
  """
  # Key of the stacked per-object masks to read.
  name: str = "mask"
  # Key the segmentation map is written to.
  tgt_name: str = "segmentations"

  def __call__(self, features: Features) -> Features:
    # Explicit check instead of ``assert``, which is stripped under -O.
    if self.name not in features:
      raise ValueError("{} not found in features".format(self.name))
    masks = features[self.name][..., 0]
    if not masks.dtype.is_floating:
      # TF comparisons need both operands in one dtype, and 0.5 is not an
      # integer/bool value, so cast e.g. uint8 or bool masks before the
      # threshold.
      masks = tf.cast(masks, tf.float32)
    # Assumed layout now: [..., n_objects, height, width].
    masks = tf.cast(masks >= 0.5, tf.int32)
    # Shape [n_objects, 1, 1]: broadcasts over the spatial axes and any
    # leading batch axes, so no flatten/reshape round-trip is needed.
    instance_ids = tf.range(tf.shape(masks)[-3])[:, tf.newaxis, tf.newaxis]
    features[self.tgt_name] = tf.reduce_max(masks * instance_ids, axis=-3)
    if self.tgt_name != self.name:
      del features[self.name]
    return features

@dataclasses.dataclass
class TetrominoesMaskPreprocess:
  """Collapses per-object masks into a single integer segmentation map.

  Channel 0 of ``features[name]`` is thresholded at 0.5. Each output pixel
  holds the index along the object axis (third-to-last after dropping the
  channel) of the highest-indexed object whose mask covers it, or 0 if none
  does. The map is stored under ``tgt_name``. The source key is deleted
  unless it equals ``tgt_name``. Any leading batch axes are preserved.

  Raises:
    ValueError: if ``name`` is not a key of the features.
  """
  # Key of the stacked per-object masks to read.
  name: str = "mask"
  # Key the segmentation map is written to.
  tgt_name: str = "segmentations"

  def __call__(self, features: Features) -> Features:
    # Explicit check instead of ``assert``, which is stripped under -O.
    if self.name not in features:
      raise ValueError("{} not found in features".format(self.name))
    masks = features[self.name][..., 0]
    if not masks.dtype.is_floating:
      # TF comparisons need both operands in one dtype, and 0.5 is not an
      # integer/bool value, so cast e.g. uint8 or bool masks before the
      # threshold.
      masks = tf.cast(masks, tf.float32)
    # Assumed layout now: [..., n_objects, height, width].
    masks = tf.cast(masks >= 0.5, tf.int32)
    # Shape [n_objects, 1, 1]: broadcasts over the spatial axes and any
    # leading batch axes, so no flatten/reshape round-trip is needed.
    instance_ids = tf.range(tf.shape(masks)[-3])[:, tf.newaxis, tf.newaxis]
    features[self.tgt_name] = tf.reduce_max(masks * instance_ids, axis=-3)
    if self.tgt_name != self.name:
      del features[self.name]
    return features
