"""input pipeline for CLEVR"""

import abc
import dataclasses
import functools
from typing import Dict, List, Tuple

from clu import deterministic_data
from clu import preprocess_spec
from ebm_obj.datasets import preprocessing
from ebm_obj.datasets.utils import get_batch_dims
from ebm_obj.datasets import dataset_registry
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
"""Data utils."""
import collections
import tensorflow as tf
import tensorflow_datasets as tfds


def clevr_filter_fn(example, max_n_objects=10):
  """Keeps only examples with at most `max_n_objects` objects.

  Per-object features are stored only for visible/instantiated objects, so the
  leading dimension of `example["objects"]["3d_coords"]` is used as the object
  count.

  Args:
    example: Dictionary of tensors, decoded from a tf.Example message.
    max_n_objects: Integer, largest number of objects (excl. background) an
      example may contain to pass the filter.

  Returns:
    Boolean scalar tensor, true when the example should be kept.
  """
  coords = example["objects"]["3d_coords"]
  # Leading dimension of the per-object coordinates is the object count.
  n_objects = tf.shape(coords)[0]
  limit = tf.constant(max_n_objects, dtype=tf.int32)
  return tf.less_equal(n_objects, limit)


def preprocess_example(features: Dict[str, tf.Tensor],
                       preprocess_strs: List[str]) -> Dict[str, tf.Tensor]:
  """Applies a chain of preprocessing ops to one data example.

  Args:
    features: Dictionary holding the tensors of a single data example.
    preprocess_strs: List of strings in clu.preprocess_spec format, each one
      describing a single preprocessing operation.

  Returns:
    Dictionary with the preprocessed tensors of the example.
  """
  available_ops = preprocessing.all_ops()
  # clu.preprocess_spec expects the individual op strings separated by "|".
  spec = "|".join(preprocess_strs)
  pipeline = preprocess_spec.parse(spec, available_ops)
  return pipeline(features)  # pytype: disable=bad-return-type  # allow-recursive-types


@dataset_registry.register("clevr:3.1.0")
def create_clevr_dataset(
    config: ml_collections.ConfigDict,
    data_rng) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
  """Creates the training and evaluation datasets for CLEVR.

  Args:
    config: Configuration providing `tfds_name`, `data_dir`, `batch_size`,
      `preproc_train`, `preproc_eval`, `max_instances` and
      `shuffle_buffer_size`. The optional entries `train_split` and
      `validation_split` default to "train" and "validation".
    data_rng: Random key forwarded as `rng` when building the training
      dataset.

  Returns:
    A `(train_ds, eval_ds)` tuple. The training dataset is filtered to
    examples with at most `config.max_instances` objects, shuffled, and built
    with `num_epochs=None`. The evaluation dataset is a single unshuffled
    pass without padding.
  """

  dataset_builder = tfds.builder(
      config.tfds_name, data_dir=config.data_dir)
  batch_dims = get_batch_dims(config.batch_size)

  train_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_train)
  eval_preprocess_fn = functools.partial(
      preprocess_example, preprocess_strs=config.preproc_eval)

  train_split_name = config.get("train_split", "train")
  eval_split_name = config.get("validation_split", "validation")

  # Each host reads only its own portion of the training split.
  train_split = deterministic_data.get_read_instruction_for_host(
      train_split_name, dataset_info=dataset_builder.info)

  # Drop training scenes containing more objects than the configured limit.
  filter_fn = functools.partial(
      clevr_filter_fn, max_n_objects=config.max_instances)

  train_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=train_split,
      rng=data_rng,
      filter_fn=filter_fn,
      preprocess_fn=train_preprocess_fn,
      cache=False,
      shuffle_buffer_size=config.shuffle_buffer_size,
      batch_dims=batch_dims,
      num_epochs=None,
      shuffle=True)

  # NOTE(review): unlike the training set, the eval set is not filtered by
  # `max_instances` -- confirm this is intentional.
  eval_split = deterministic_data.get_read_instruction_for_host(
      eval_split_name, dataset_info=dataset_builder.info, drop_remainder=False)
  eval_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=eval_split,
      rng=None,
      preprocess_fn=eval_preprocess_fn,
      cache=False,
      batch_dims=batch_dims,
      num_epochs=1,
      shuffle=False,
      pad_up_to_batches=None)

  return train_ds, eval_ds
