import tensorflow as tf
from ebm_obj.configs.clevr.default import get_config as get_default_config


def get_config(prefix_path="ebm_obj."):
  """Build the CLEVR-with-masks config on top of the default CLEVR config.

  Starts from the default CLEVR config and overrides entries of its `data`
  section:
    * the dataset name, `clevr_with_masks`;
    * the data location;
    * the train and eval preprocessing pipelines, which are the same op list.

  Args:
    prefix_path: Passed unchanged to the default CLEVR `get_config`.

  Returns:
    The config object produced by the default `get_config`, with its `data`
    section updated via `update(...)`.
  """
  config = get_default_config(prefix_path)

  # One preprocessing pipeline, applied identically during training and
  # evaluation. Ops with `name='mask'` target the mask feature. Those ops:
  #   * rescale with an explicit 0.0-1.0 range;
  #   * resize with nearest-neighbour interpolation, presumably so that
  #     discrete mask values are not blended (confirm against the op
  #     definitions).
  shared_ops = (
    "image_rescale_to_float()",
    "image_rescale_to_float(name='mask', min_val=0.0, max_val=1.0)",
    "image_crop()",
    "mask_crop()",
    "image_resize()",
    "image_resize(name='mask', resize_method='nearest_neighbor')",
    "clevr_mask_preprocess",
  )

  # NOTE(review): `data_dir` holds a single .tfrecords file path rather than
  # a directory; confirm the loader expects a file here.
  overrides = dict(
    tfds_name="clevr_with_masks",
    data_dir="datasets/clevr_with_masks/clevr_with_masks_train.tfrecords",
    # Build separate lists so the two entries never share one mutable object.
    preproc_train=list(shared_ops),
    preproc_eval=list(shared_ops),
  )
  config.data.update(overrides)
  return config
