import tensorflow as tf
from ebm_obj.configs.tetrominoes.default import get_config as get_default_config


def get_config(prefix_path="ebm_obj."):
  """Return the default tetrominoes config with its data section overridden.

  The base config comes from the default tetrominoes `get_config`; this
  function then calls `update` on its `data` entry to set the dataset name,
  input shape, data location and the train/eval preprocessing pipelines.

  Args:
    prefix_path: Passed through unchanged to the default `get_config`.

  Returns:
    The config object produced by the default `get_config`, after its `data`
    section has been updated.
  """
  # Train and eval use the same sequence of preprocessing specs. Each spec is
  # a string (presumably resolved by a preprocessing registry elsewhere in the
  # project -- confirm against the data loader).
  # NOTE(review): the final entry has no call parentheses, unlike the others;
  # it is kept verbatim.
  preproc_steps = (
    "image_rescale_to_float()",
    "image_rescale_to_float(name='mask', min_val=0.0, max_val=1.0)",
    "image_resize(resolution=(35, 35))",
    "image_resize(name='mask', resolution=(35, 35), resize_method='nearest_neighbor')",
    "tetrominoes_mask_preprocess",
  )

  data_overrides = dict(
    tfds_name="tetrominoes_with_masks",
    input_shape=[35, 35, 3],
    data_dir="datasets/tetrominoes_with_masks/tetrominoes_train.tfrecords",
    # Separate list objects per key so the two pipelines are not aliased.
    preproc_train=list(preproc_steps),
    preproc_eval=list(preproc_steps),
  )

  config = get_default_config(prefix_path)
  config.data.update(data_overrides)
  return config
