import ml_collections


def get_config(prefix_path="ebm_obj."):
  """Build the default hyperparameter configuration.

  Args:
    prefix_path: String prepended to every relative module name (for example
      "models.ImageModel") to form the `module` entries of the model config.

  Returns:
    An `ml_collections.ConfigDict` with the training schedule, optimizer,
    logging, data, loss, model and debug-visualization settings.
  """
  new_config = ml_collections.ConfigDict
  new_reference = ml_collections.config_dict.FieldReference

  def module_at(relative_name):
    # Joins the package prefix and a relative module path into one string.
    return "".join((prefix_path, relative_name))

  def simple_cnn(transpose):
    # Four-layer SimpleCNN spec; every call returns freshly built lists so
    # the encoder and decoder backbones never share list objects.
    return new_config({
        "module": module_at("modules.SimpleCNN"),
        "features": [32] * 4,
        "kernel_size": [(5, 5)] * 4,
        "strides": [(1, 1)] * 4,
        "layer_transpose": [transpose] * 4,
    })

  def linear_pos_emb(**extra_fields):
    # PositionEmbedding spec; `extra_fields` are appended after the common
    # keys, in the order given.
    return new_config({
        "module": module_at("modules.PositionEmbedding"),
        "embedding_type": "linear",
        "update_type": "project_add",
        **extra_fields,
    })

  # References shared between several fields of the config.
  sampler_steps_ref = new_reference(5)
  batch_size_ref = new_reference(128)
  single_batch_ref = new_reference(False)

  config = new_config()

  # ---- Run setup and training schedule. ----
  config.seed = 42
  config.seed_data = True
  config.workdir = None
  config.batch_size = batch_size_ref
  config.num_train_steps = 500000
  config.num_decay_steps = 500000
  config.num_eval_steps = -1

  config.run_eval = False
  config.debug_fit_single_batch = single_batch_ref

  # ---- Adam optimizer settings. ----
  config.learning_rate = 2e-4
  config.warmup_steps = 2500
  config.max_grad_norm = 1.0
  config.lr_end_value = 0.0

  # ---- Logging / evaluation / checkpoint cadence (in steps). ----
  config.log_loss_every_steps = 50
  config.eval_every_steps = 10000
  config.checkpoint_every_steps = 10000

  # ---- Metrics: eval metrics mirror the train ones with an "eval_" prefix.
  metric_names = ("loss", "ari", "ari_nobg")
  config.train_metrics_spec = {name: name for name in metric_names}
  config.eval_metrics_spec = {f"eval_{name}": name for name in metric_names}

  # ---- Data pipeline. ----
  # The train and eval preprocessing pipelines currently hold the same ops;
  # each gets its own list so they can be edited independently.
  # Disabled alternatives that used to sit in the train pipeline:
  #   "image_crop()",
  #   "mask_crop(name='segmentations')",
  #   "image_resize()",
  #   "image_resize(name='segmentations', resize_method='nearest_neighbor')",
  preproc_ops = (
      "image_rescale_to_float()",
      "image_rescale_to_float(name='mask', min_val=0.0, max_val=1.0)",
      "multi_dsprites_mask_preprocess",
  )
  # NOTE(review): `tfds_name` names CLEVR while the preprocessing op is
  # `multi_dsprites_mask_preprocess`, and `input_shape` is 35x35 while the
  # decoder `resolution` below is 64x64 -- confirm these are intentional.
  data = new_config({
      "tfds_name": "clevr:3.1.0",  # Dataset for training/eval.
      "data_dir": "",
      # "data_dir": "gs://tensorflow-datasets/datasets",
      "input_shape": [35, 35, 3],
      "train_split_size": 60_000,
      "validation_split_size": 320,
      # Derived from the shared batch-size reference (10x the batch size).
      "shuffle_buffer_size": batch_size_ref * 10,
      "train_batch_size": batch_size_ref,
      "validation_batch_size": 32,
      "preproc_train": list(preproc_ops),
      "preproc_eval": list(preproc_ops),
  })
  data.max_instances = 5
  data.logging_min_n_colors = data.max_instances
  data.debug_fit_single_batch = single_batch_ref
  config.data = data

  # ---- Slots. ----
  # One more slot than `max_instances`; used for metrics and for the shape of
  # the sampler's `z_initializer` below.
  config.num_slots = data.max_instances + 1
  config.slot_embed_dim = 64

  # ---- Targets and losses. ----
  # Dictionary of targets and corresponding channels. Losses need to match.
  config.targets = {"image": 3}

  config.losses = new_config({
      f"recon_{name}": {
          "loss_type": "recon_image",
          "key": name,
          "reduction_type": "mean",
      } for name in config.targets
  })

  # ---- Model: energy-based sampler. ----
  z_initializer = new_config({
      "module": module_at("modules.GaussianStateInit"),
      "shape": [config.num_slots, config.slot_embed_dim],
      "learnable": True,
  })

  image_transform = new_config({
      "module": module_at("modules.CNNPosEmbTransform"),
      "backbone": simple_cnn(transpose=False),
      "pos_emb": linear_pos_emb(
          output_transform=new_config({
              "module": module_at("modules.MLP"),
              "hidden_size": 32,
              "layernorm": "pre",
          })),
      "reduction": "spatial_flatten",
  })

  slot_transform = new_config({"module": module_at("modules.Identity")})

  fuse_transform = new_config({
      "module": module_at("modules.FuseModule"),
      "attention_block": new_config({
          "module": module_at("modules.CrossAttention1DBlock"),
          "mlp_dim_mul": 1,
          "num_heads": 4,
          "dropout_rate": 0.0,
          "attention_dropout_rate": 0.0,
      }),
      "qkv_features": 128,
      "num_blocks": 2,
  })

  energy_head = new_config({
      "module": module_at("modules.EBMOutputModule"),
      "mlp_dims": [32],
  })

  ebm = new_config({
      "module": module_at("modules.SimpleEBM"),
      "image_transform": image_transform,
      "slot_transform": slot_transform,
      "fuse_transform": fuse_transform,
      "output_transform": energy_head,
  })

  sampler = new_config({
      "module": module_at("modules.ULASampler"),
      "z_initializer": z_initializer,
      "ebm": ebm,
      "dt": 1e-2,
      "wn": 1.0,
      "num_steps": sampler_steps_ref,
  })

  # ---- Model: decoder. ----
  # One Dense readout per target, with as many features as the target has
  # channels.
  readout_modules = [
      new_config({
          "module": module_at("modules.Dense"),
          "features": config.targets[name],
      }) for name in config.targets
  ]

  decoder = new_config({
      "module": module_at("modules.SpatialBroadcastDecoder"),
      # Update if data resolution or strides change.
      "resolution": (64, 64),
      "backbone": simple_cnn(transpose=True),
      "pos_emb": linear_pos_emb(),
      "target_readout": new_config({
          "module": module_at("modules.Readout"),
          "keys": list(config.targets),
          "readout_modules": readout_modules,
      }),
  })

  config.model = new_config({
      "module": module_at("models.ImageModel"),
      "sampler": sampler,
      "decoder": decoder,
  })

  # ---- Debug logging / visualization. ----
  # Define which video-shaped variables to log/visualize: the decoder alpha
  # masks plus one combined reconstruction per target.
  video_paths = {"recon_masks": "SpatialBroadcastDecoder_0/alphas"}
  video_paths.update({
      f"{name}_recon": f"SpatialBroadcastDecoder_0/{name}_combined"
      for name in config.targets
  })
  config.debug_var_video_paths = video_paths

  # Define which attention matrices to log/visualize (none enabled).
  config.debug_var_attn_paths = {
      # "corrector_attn": "SlotAttention_0/InvertedDotProductAttention_0/GeneralizedDotProductAttention_0/attn"  # pylint: disable=line-too-long
  }

  # Widths of attention matrices (for reshaping to image grid).
  # NOTE(review): "corrector_attn" has no matching entry in
  # `debug_var_attn_paths` above -- presumably unused; confirm.
  config.debug_var_attn_widths = {
      "corrector_attn": 64,
  }

  return config