"""Tetrominoes with masks dataset reader."""

import abc
import dataclasses
import functools
from typing import Dict, List, Tuple

from clu import deterministic_data
from clu import preprocess_spec
from ebm_obj.datasets import preprocessing
from ebm_obj.datasets.utils import get_batch_dims
from ebm_obj.datasets import dataset_registry
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
"""Data utils."""
import collections

import os
from ebm_obj.datasets.clevr import preprocess_example
from ebm_obj.datasets.utils import get_batch_dims


# The TFRecord files are stored GZIP-compressed.
COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP')
# Spatial size of every image and mask. Kept as a list so it can be spliced
# into the feature shapes below.
IMAGE_SIZE = [35, 35]
# Upper bound on foreground plus background entities per scene; equal to the
# number of segmentation masks stored for each example.
MAX_NUM_ENTITIES = 4
# Features serialized as raw byte strings; `_decode` turns them into uint8.
BYTE_FEATURES = ['mask', 'image']

# Fixed-length feature spec handed to `tf.io.parse_single_example`: feature
# name -> (shape, serialized dtype).
features = dict(
    image=tf.io.FixedLenFeature([*IMAGE_SIZE, 3], tf.string),
    mask=tf.io.FixedLenFeature([MAX_NUM_ENTITIES, *IMAGE_SIZE, 1], tf.string),
    x=tf.io.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
    y=tf.io.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
    shape=tf.io.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
    color=tf.io.FixedLenFeature([MAX_NUM_ENTITIES, 3], tf.float32),
    visibility=tf.io.FixedLenFeature([MAX_NUM_ENTITIES], tf.float32),
)


def _decode(example_proto):
  """Parse one serialized `tf.Example` into a dict of tensors.

  The proto is parsed against the module-level `features` spec. Each entry
  named in `BYTE_FEATURES` is then converted from a tensor of byte strings
  into a `tf.uint8` tensor of the same shape (assuming one byte per string
  element, see below).

  Args:
    example_proto: String tensor holding one serialized `tf.Example`.

  Returns:
    Dict mapping feature names to decoded tensors.
  """
  parsed = tf.io.parse_single_example(example_proto, features)
  for name in BYTE_FEATURES:
    # `decode_raw` appends a trailing axis with one entry per byte. Each string
    # element is presumably a single byte, so that axis has size 1 and is
    # squeezed away (TODO confirm against the record writer).
    byte_values = tf.io.decode_raw(parsed[name], tf.uint8)
    parsed[name] = tf.squeeze(byte_values, axis=-1)
  return parsed


def dataset(tfrecords_path, read_buffer_size=None, map_parallel_calls=None):
  """Read, decompress, and parse a GZIP-compressed TFRecords file.

  Args:
    tfrecords_path: str. Path to the dataset file.
    read_buffer_size: int. Number of bytes in the read buffer. See documentation
      for `tf.data.TFRecordDataset.__init__`.
    map_parallel_calls: int. Number of elements decoded asynchronously in
      parallel. See documentation for `tf.data.Dataset.map`.

  Returns:
    An unbatched `tf.data.Dataset` (the result of `.map`, not a
    `TFRecordDataset`) whose elements are the feature dicts produced by
    `_decode`.
  """
  raw_records = tf.data.TFRecordDataset(
      tfrecords_path, compression_type=COMPRESSION_TYPE,
      buffer_size=read_buffer_size)
  return raw_records.map(_decode, num_parallel_calls=map_parallel_calls)


def clevr_filter_fn(example, max_n_objects=10):
  """Keep examples whose summed `visibility` flags stay within a limit.

  Entities are counted by summing the per-entity `visibility` feature, so
  every entity slot whose flag is set contributes to the count.
  NOTE(review): if the background entity also has its flag set, it is
  counted too; callers may compensate (e.g. by passing `max_instances + 1`),
  so confirm before relying on "objects excluding background" semantics.

  A small epsilon is added to the float threshold, presumably to guard
  against floating-point error in the summed flags.

  Examples that carry no `visibility` feature are always kept.

  Args:
    example: Dictionary of tensors, decoded from a tf.Example message.
    max_n_objects: Maximum allowed value of the summed visibility flags.

  Returns:
    Scalar boolean tensor, usable as a `tf.data.Dataset.filter` predicate.
  """
  if "visibility" not in example:
    return tf.constant(True)
  num_visible = tf.reduce_sum(example["visibility"])
  threshold = tf.constant(max_n_objects + 1e-5, dtype=tf.float32)
  return tf.less_equal(num_visible, threshold)

@dataset_registry.register("tetrominoes_with_masks")
def create_tetrominoes_with_masks_dataset(
    config: ml_collections.ConfigDict,
    data_rng
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
  """Create Tetrominoes-with-masks datasets for training and evaluation.

  Both splits are read from the record files matched by `config.data_dir`.
  The training split is the first `config.train_split_size` decoded records.
  The validation split is the `config.validation_split_size` records that
  immediately follow them. When `config.debug_fit_single_batch` is set, each
  split is instead truncated to `config.batch_size` records (validation still
  starts after the training offset).

  Args:
    config: Configuration. Fields read here: `data_dir`, `preproc_train`,
      `preproc_eval`, `shuffle_buffer_size`, `train_split_size`,
      `validation_split_size`, `debug_fit_single_batch`, `batch_size` and
      `max_instances`.
    data_rng: Random key forwarded as `rng` to
      `deterministic_data.create_dataset` for both splits.

  Returns:
    A tuple `(train_dataset, validation_dataset)`. The training dataset is
    shuffled and repeated (`num_epochs=None`); the validation dataset is a
    single unshuffled pass.
  """

  class _PrebuiltDatasetBuilder:
    """Minimal stand-in for a dataset builder that returns a fixed dataset."""

    def __init__(self, ds):
      self._ds = ds

    def as_dataset(self, *unused_args, **unused_kwargs):
      # `split`, `shuffle_files`, etc. are ignored: the dataset is prebuilt.
      return self._ds

  # Pieces that do not depend on the split are built once.
  filter_fn = functools.partial(
      clevr_filter_fn, max_n_objects=config.max_instances + 1)
  batch_dims = get_batch_dims(config.batch_size)
  # `list_files` shuffles file names by default and reshuffles on every
  # iteration. Each split below is a skip/take window over its own pipeline,
  # so the file order must be fixed. Otherwise train and validation can
  # overlap when `data_dir` is a pattern matching several files.
  files = tf.data.Dataset.list_files(config.data_dir, shuffle=False)
  data_reader = functools.partial(
      tf.data.TFRecordDataset,
      compression_type=COMPRESSION_TYPE,
      # NOTE(review): `buffer_size` is the TFRecord read buffer in bytes, but
      # it is fed `shuffle_buffer_size`; confirm this is intended.
      buffer_size=config.shuffle_buffer_size)

  datasets = {}
  for key in ("train", "validation"):
    is_train = key == "train"
    preprocess_fn = functools.partial(
        preprocess_example,
        preprocess_strs=(
            config.preproc_train if is_train else config.preproc_eval))

    # Under tf.data's default (deterministic) options, parallel interleave
    # preserves element order. The skip/take split below relies on this.
    dataset = files.interleave(
        data_reader, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.map(
        _decode, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Validation records start right after the training records.
    if not is_train:
      dataset = dataset.skip(config.get("train_split_size"))
    if config.debug_fit_single_batch:
      dataset = dataset.take(config.batch_size)
    else:
      dataset = dataset.take(config.get(f"{key}_split_size"))
    # NOTE: `filter_fn` is applied inside `create_dataset`, i.e. after the
    # `take` above. A split can therefore end up smaller than its configured
    # size when examples are filtered out.

    datasets[key] = deterministic_data.create_dataset(
        _PrebuiltDatasetBuilder(dataset),
        split="",  # Ignored by `_PrebuiltDatasetBuilder.as_dataset`.
        rng=data_rng,
        filter_fn=filter_fn,
        preprocess_fn=preprocess_fn,
        cache=False,
        shuffle_buffer_size=config.shuffle_buffer_size,
        batch_dims=batch_dims,
        num_epochs=None if is_train else 1,
        shuffle=is_train,
        # TODO(tkipf): Fix and pass `pad_up_to_batches`.
    )

  return datasets["train"], datasets["validation"]

