  # Pretraining input pipelines, expressed as '|'-separated preprocessing ops.
  # Train: decode_jpeg_and_inception_crop(INPUT_RES, area_min=20) + flip_lr.
  # Eval: decode + resize_small(RESIZE_RES) + central_crop(INPUT_RES).
  # Both then append the ops returned by get_image_normalization_transform,
  # one-hot "label" (1000 classes) into "labels", and keep only "image" and
  # "labels".
  # NOTE(review): the plain '+' concatenation assumes the normalization string
  # ends with a '|' separator — confirm in get_image_normalization_transform.
  config.dataset_configs.pp_train = (
      f'decode_jpeg_and_inception_crop({INPUT_RES}, area_min=20)|'
      'flip_lr|' + get_image_normalization_transform(config) +
      'onehot(1000, key="label", key_result="labels")|'
      'keep("image", "labels")')

  config.dataset_configs.pp_eval = (
      'decode(channels=3)|'
      f'resize_small({RESIZE_RES})|central_crop({INPUT_RES})|' +
      get_image_normalization_transform(config) +
      'onehot(1000, key="label", key_result="labels")|'
      'keep("image", "labels")')

  # Random masking: strategy 'constant' with prob 0.75.
  config.random_mask = config_structure.RandomMask()
  mask_cfg = config.random_mask
  mask_cfg.strategy = 'constant'
  mask_cfg.prob = 0.75

  # Tokenizer: CLS token on, 2D sin-cos positional embeddings, relative
  # attention bias off; encode_mask_tokens and mask_attention_to_mask_tokens
  # are both disabled.
  tokenizer_cfg = config.model.tokenizer
  tokenizer_cfg.add_cls_token = True
  tokenizer_cfg.encode_mask_tokens = False
  tokenizer_cfg.use_relative_attention_bias = False
  tokenizer_cfg.positional_embeddings_type = '2d_sin_cos'
  tokenizer_cfg.mask_attention_to_mask_tokens = False

  # Encoder: dropout disabled, final layer norm enabled.
  encoder_cfg = config.model.encoder
  encoder_cfg.dropout_rate = 0.
  encoder_cfg.use_final_layernorm = True

  # Reconstruction head: num_layers=8, hidden_size=512, num_heads=16,
  # mlp_dim=2048, every dropout/droplayer rate at zero. Its
  # mask_attention_to_mask_tokens and positional_embeddings_type match the
  # tokenizer values set above.
  config.model.heads.reconstruct = config_structure.ReconstructionHead()
  recon_head_cfg = config.model.heads.reconstruct
  recon_head_cfg.num_heads = 16
  recon_head_cfg.hidden_size = 512
  recon_head_cfg.mlp_dim = 2048
  recon_head_cfg.num_layers = 8
  recon_head_cfg.attention_dropout_rate = 0.
  recon_head_cfg.dropout_rate = 0.
  recon_head_cfg.droplayer_rate = 0.
  recon_head_cfg.use_final_layernorm = True
  recon_head_cfg.mask_attention_to_mask_tokens = False
  recon_head_cfg.positional_embeddings_type = '2d_sin_cos'

  # Reconstruction loss uses the 'L2' distance metric.
  config.loss.reconstruction = config_structure.ReconstructionLoss()
  config.loss.reconstruction.distance_metric = 'L2'

  # Optimizer overrides.
  optimizer_cfg = config.optimizer_configs
  optimizer_cfg.weight_decay = 0.05
  optimizer_cfg.b2 = 0.95

  # Explicitly None; presumably disables gradient-norm clipping — confirm in
  # the trainer.
  config.max_grad_norm = None

  # Pretraining length: 800 epochs of ImageNet, or just 3 steps for a local
  # smoke run.
  steps_per_epoch = _IMAGENET_TRAIN_SIZE // config.batch_size
  if runlocal:
    config.num_training_steps = 3
  else:
    config.num_training_steps = 800 * steps_per_epoch

  # LR schedule: 40 warmup epochs, base LR 1.5e-4 scaled by batch_size / 256.
  # NOTE(review): steps_per_cycle (used by the linear-probe cosine schedule
  # below) is not set for pretraining in this chunk — confirm it is provided
  # elsewhere.
  lr_cfg = config.lr_configs
  lr_cfg.factors = 'constant * cosine_decay * linear_warmup'
  lr_cfg.warmup_steps = int(40 * steps_per_epoch)
  lr_cfg.base_learning_rate = 1.5e-4 * config.batch_size / 256.

  # Linear-probe evaluation settings.
  linear_probe_batch_size = 32 if runlocal else 16384
  linear_probe_steps_per_epoch = (
      _IMAGENET_TRAIN_SIZE // linear_probe_batch_size)
  config.linear_probe = get_common_linear_probe_config(linear_probe_batch_size)
  probe_cfg = config.linear_probe
  # Run the probe every 10000 steps (setting this to -1 would run it only at
  # the end of training).
  probe_cfg.log_eval_steps = 10000
  probe_cfg.log_train_summary_steps = 100
  probe_cfg.representation_layer = 'representation'
  probe_cfg.pooling_type = 'token'
  probe_cfg.add_layernorm = False
  probe_cfg.add_batchnorm = True
  probe_cfg.momentum = 0.9
  probe_cfg.use_global_batchnorm = False

  # ImageNet probe: LARS optimizer; 100 epochs normally, 10 steps locally.
  imagenet_linear_config = probe_cfg.datasets['imagenet']
  imagenet_linear_config.optimizer_configs.optimizer = 'lars'
  if runlocal:
    imagenet_linear_config.num_training_epochs = None
    imagenet_linear_config.num_training_steps = 10
  else:
    imagenet_linear_config.num_training_epochs = 100
    imagenet_linear_config.num_training_steps = None
  # The appropriate label-smoothing value is still an open question.
  imagenet_linear_config.label_smoothing = 0.
  # NOTE(review): weight decay is set on the dataset config rather than under
  # optimizer_configs (where the optimizer is chosen) — confirm the probe
  # trainer reads this field.
  imagenet_linear_config.weight_decay = 0.

  # Cosine schedule with 10 warmup epochs; one cycle spans the whole probe.
  imagenet_linear_config.lr_configs.factors = (
      'constant * cosine_decay * linear_warmup')
  imagenet_linear_config.lr_configs.warmup_steps = (
      10 * linear_probe_steps_per_epoch)
  if runlocal:
    total_steps = imagenet_linear_config.num_training_steps
  else:
    total_steps = (
        imagenet_linear_config.num_training_epochs *
        linear_probe_steps_per_epoch)
  imagenet_linear_config.lr_configs.steps_per_cycle = total_steps
  imagenet_linear_config.lr_configs.base_learning_rate = (
      0.1 * linear_probe_batch_size / 256.)

  # Linear-probe input pipelines. As written they are op-for-op identical to
  # the pretraining pp_train / pp_eval strings; presumably kept as separate
  # strings so the probe preprocessing can diverge from pretraining.
  pp_train_lineval = (
      f'decode_jpeg_and_inception_crop({INPUT_RES}, area_min=20)|'
      'flip_lr|' + get_image_normalization_transform(config) +
      'onehot(1000, key="label", key_result="labels")|'
      'keep("image", "labels")')

  pp_eval_lineval = (
      'decode(channels=3)|'
      f'resize_small({RESIZE_RES})|central_crop({INPUT_RES})|' +
      get_image_normalization_transform(config) +
      'onehot(1000, key="label", key_result="labels")|'
      'keep("image", "labels")')
  imagenet_linear_config.dataset_configs.pp_train = pp_train_lineval
  imagenet_linear_config.dataset_configs.pp_eval = pp_eval_lineval
  imagenet_linear_config.prefetch_to_host = 'autotune'
