import mesh_tensorflow.optimize
import mesh_tensorflow.transformer.dataset
import mesh_tensorflow.transformer.learning_rate_schedules
import mesh_tensorflow.transformer.t2t_vocabulary
import mesh_tensorflow.transformer.transformer
import mesh_tensorflow.transformer.transformer_layers
import mesh_tensorflow.transformer.utils
import t5.data.mixtures
import t5.models.mesh_transformer

# Macros:
# ==============================================================================
# Shared values, referenced elsewhere in this file with the `%name` syntax:
#   d_ff          -> encoder/ and decoder/DenseReluDense.hidden_size
#   d_kv          -> encoder/ and decoder/SelfAttention.key_value_size
#   d_model       -> encoder/ and decoder/Unitransformer.d_model
#   dropout_rate  -> dropout_rate of DenseReluDense, LayerStack and
#                    SelfAttention (both encoder/ and decoder/ scopes)
#   MIXTURE_NAME  -> mesh_train_dataset_fn.mixture_or_task_name
#   num_heads     -> encoder/ and decoder/SelfAttention.num_heads
#   num_layers    -> encoder/ and decoder/make_layer_stack.num_layers
# Changing a macro here changes every binding that references it.
d_ff = 2816
d_kv = 64
d_model = 1024
dropout_rate = 0.1
MIXTURE_NAME = 'c4_fitb_plus_fite'
num_heads = 16
num_layers = 24

# Parameters for adafactor_decay_rate_pow:
# ==============================================================================
# No binding visible in this file references adafactor_decay_rate_pow
# (AdafactorOptimizer.decay_rate is bound to None).
# NOTE(review): presumably the optimizer falls back to this function when
# decay_rate is None -- confirm against mesh_tensorflow.optimize.
adafactor_decay_rate_pow.exponent = 0.8
adafactor_decay_rate_pow.offset = 0

# Parameters for AdafactorOptimizer:
# ==============================================================================
# Settings for the optimizer selected by
# `run.optimizer = @optimize.AdafactorOptimizer`.
# decay_rate, exclude_from_parameter_scale and stacked_dim_names are bound to
# None here; what the optimizer does with None is not visible in this file.
AdafactorOptimizer.beta1 = 0.0
AdafactorOptimizer.clipping_threshold = 1.0
AdafactorOptimizer.decay_rate = None
AdafactorOptimizer.epsilon1 = 1e-30
AdafactorOptimizer.epsilon2 = 0.001
AdafactorOptimizer.exclude_from_parameter_scale = None
AdafactorOptimizer.factored = True
AdafactorOptimizer.min_dim_size_to_factor = 128
AdafactorOptimizer.multiply_by_parameter_scale = True
AdafactorOptimizer.stacked_dim_names = None

# Parameters for Bitransformer:
# ==============================================================================
# Applies to the model as a whole (run.model_type is 'bitransformer').
# Note this is a separate setting from
# encoder/ and decoder/Unitransformer.shared_embedding_and_softmax_weights,
# which are both False in this file.
Bitransformer.shared_embedding = True

# Parameters for learning_rate_schedules.constant:
# ==============================================================================
# Value for the schedule selected by
# `run.learning_rate_schedule = @learning_rate_schedules.constant`.
learning_rate_schedules.constant.value = 0.0008

# Parameters for decoder/DenseReluDense:
# ==============================================================================
# The decoder/ and encoder/ scopes carry identical DenseReluDense bindings;
# keep the two sections in sync when editing either one.
# hidden_size and dropout_rate come from the %d_ff and %dropout_rate macros.
# NOTE(review): `activation` is a two-element list; presumably this selects a
# gated feed-forward (gelu branch multiplied by a linear branch) -- confirm
# against mesh_tensorflow.transformer.transformer_layers.DenseReluDense.
decoder/DenseReluDense.activation = ['gelu', 'linear']
decoder/DenseReluDense.dropout_rate = %dropout_rate
decoder/DenseReluDense.hidden_size = %d_ff
decoder/DenseReluDense.use_bias = False

# Parameters for encoder/DenseReluDense:
# ==============================================================================
# Same values as the decoder/ scope above.
encoder/DenseReluDense.activation = ['gelu', 'linear']
encoder/DenseReluDense.dropout_rate = %dropout_rate
encoder/DenseReluDense.hidden_size = %d_ff
encoder/DenseReluDense.use_bias = False

# Parameters for enc_dec_attention:
# ==============================================================================
# None.

# Parameters for enc_dec_attention_bias:
# ==============================================================================
# None.

# Parameters for decoder/EncDecAttention:
# ==============================================================================
# EncDecAttention appears only in the decoder layer list
# (decoder/make_layer_stack.layers). Its relative_attention_type is None,
# whereas both SelfAttention scopes in this file use 'bias_shared'.
decoder/EncDecAttention.relative_attention_type = None

# Parameters for get_variable_dtype:
# ==============================================================================
# Only activation_dtype is bound here; any other dtype arguments of
# get_variable_dtype keep whatever defaults the function defines.
get_variable_dtype.activation_dtype = 'bfloat16'

# Parameters for get_vocab_embedding_cls:
# ==============================================================================
# None.

# Parameters for get_vocabulary:
# ==============================================================================
# Bound to None rather than to %MIXTURE_NAME. How get_vocabulary picks a
# vocabulary in that case is not visible in this file.
get_vocabulary.mixture_or_task_name = None

# Parameters for init_checkpoint_variable_mapping:
# ==============================================================================
# No mapping function is configured.
# NOTE(review): presumably variables are then matched to run.init_checkpoint
# by their unmodified names -- confirm.
init_checkpoint_variable_mapping.mapping_fn = None

# Parameters for decoder/LayerStack:
# ==============================================================================
# The decoder/ and encoder/ scopes carry identical LayerStack bindings.
# Sublayer lists, in the order written:
#   sublayers_initial   : dropout
#   sublayers_per_layer : rms_norm, call_layer, dropout, residual
#   sublayers_final     : rms_norm, dropout
# norm_epsilon matches sublayer_rms_norm.epsilon (both 1e-06).
decoder/LayerStack.dropout_rate = %dropout_rate
decoder/LayerStack.norm_epsilon = 1e-06
decoder/LayerStack.recompute_grads = False
decoder/LayerStack.sublayers_final = \
    [@transformer.sublayer_rms_norm, @transformer.sublayer_dropout]
decoder/LayerStack.sublayers_initial = [@transformer.sublayer_dropout]
decoder/LayerStack.sublayers_per_layer = \
    [@transformer.sublayer_rms_norm,
     @transformer.sublayer_call_layer,
     @transformer.sublayer_dropout,
     @transformer.sublayer_residual]

# Parameters for encoder/LayerStack:
# ==============================================================================
# Same values as the decoder/ scope above.
encoder/LayerStack.dropout_rate = %dropout_rate
encoder/LayerStack.norm_epsilon = 1e-06
encoder/LayerStack.recompute_grads = False
encoder/LayerStack.sublayers_final = \
    [@transformer.sublayer_rms_norm, @transformer.sublayer_dropout]
encoder/LayerStack.sublayers_initial = [@transformer.sublayer_dropout]
encoder/LayerStack.sublayers_per_layer = \
    [@transformer.sublayer_rms_norm,
     @transformer.sublayer_call_layer,
     @transformer.sublayer_dropout,
     @transformer.sublayer_residual]

# Parameters for make_bitransformer:
# ==============================================================================
# These names equal the `decoder/` and `encoder/` scope prefixes used by the
# scoped bindings throughout this file.
make_bitransformer.decoder_name = 'decoder'
make_bitransformer.encoder_name = 'encoder'

# Parameters for decoder/make_layer_stack:
# ==============================================================================
# Decoder layer list: SelfAttention, EncDecAttention, DenseReluDense.
# num_layers comes from the %num_layers macro, shared with the encoder.
decoder/make_layer_stack.block_scope = True
decoder/make_layer_stack.layers = \
    [@mesh_tensorflow.transformer.transformer_layers.SelfAttention,
     @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
     @mesh_tensorflow.transformer.transformer_layers.DenseReluDense]
decoder/make_layer_stack.num_layers = %num_layers

# Parameters for encoder/make_layer_stack:
# ==============================================================================
# Encoder layer list: SelfAttention, DenseReluDense (no EncDecAttention).
encoder/make_layer_stack.block_scope = True
encoder/make_layer_stack.layers = \
    [@mesh_tensorflow.transformer.transformer_layers.SelfAttention,
     @mesh_tensorflow.transformer.transformer_layers.DenseReluDense]
encoder/make_layer_stack.num_layers = %num_layers

# Parameters for mesh_train_dataset_fn:
# ==============================================================================
# Settings for the function selected by
# `run.train_dataset_fn = @t5.models.mesh_transformer.mesh_train_dataset_fn`.
# The mixture/task name comes from the %MIXTURE_NAME macro.
mesh_train_dataset_fn.mixture_or_task_name = %MIXTURE_NAME
mesh_train_dataset_fn.pack = True
mesh_train_dataset_fn.seed = None
mesh_train_dataset_fn.shuffle = True
# NOTE(review): integer 0, unlike the True/False booleans used elsewhere in
# this file; presumably treated as "do not use the cached dataset" -- confirm.
mesh_train_dataset_fn.use_cached = 0

# Parameters for pack_dataset:
# ==============================================================================
# Related to mesh_train_dataset_fn.pack = True in this file.
# NOTE(review): presumably requires the custom packing ops to be available in
# the runtime environment -- confirm.
pack_dataset.use_custom_ops = True

# Parameters for pack_or_pad:
# ==============================================================================
# None.

# Parameters for rewrite_stack_variables:
# ==============================================================================
# 536870912 = 2**29. The unit (elements or bytes) is not visible in this file.
rewrite_stack_variables.max_combined_variable_size = 536870912

# Parameters for run:
# ==============================================================================
# Top-level bindings for `run`. They select mode 'train' and model type
# 'bitransformer', and wire in the optimizer, learning-rate schedule, mesh
# shape and training dataset function configured elsewhere in this file.
run.autostack = True
# A (name, amount) pair; 1048576 = 2**20.
run.batch_size = ('tokens_per_batch', 1048576)
run.checkpoint_input_pipeline = False
run.dataset_split = 'train'
run.ensemble_inputs = None
run.eval_checkpoint_step = None
run.eval_dataset_fn = None
run.eval_summary_dir = None
run.export_checkpoint_step = None
run.export_path = ''
# Checkpoint to initialize from; its filename carries step number 1000000.
run.init_checkpoint = \
    'gs://t5-data/pretrained_models/t5.1.1.large/model.ckpt-1000000'
run.iterations_per_loop = 100
run.keep_checkpoint_max = None
# Comma-separated `name:mesh_axis` pairs: d_ff, heads and vocab map to
# 'model'; batch and experts map to 'batch'; ensemble maps to 'ensemble'.
run.layout_rules = \
    'ensemble:ensemble,batch:batch,d_ff:model,heads:model,vocab:model,experts:batch'
run.learning_rate_schedule = @learning_rate_schedules.constant
run.mesh_devices = None
# The trailing `()` makes Gin call tpu_mesh_shape (see its own parameters
# section) and bind the returned value instead of the function itself.
run.mesh_shape = @mesh_tensorflow.transformer.utils.tpu_mesh_shape()
run.mode = 'train'
run.model_type = 'bitransformer'
run.optimizer = @optimize.AdafactorOptimizer
run.output_eval_examples = True
run.perplexity_eval_steps = 50
run.predict_fn = None
run.save_checkpoints_steps = 4000
run.seen_data_init_step = 0
# Keys 'inputs' and 'targets' give per-feature lengths.
run.sequence_length = {'inputs': 512, 'targets': 100}
run.skip_seen_data = False
run.total_run_steps = None
run.train_dataset_fn = @t5.models.mesh_transformer.mesh_train_dataset_fn
# 1052000 is 52000 more than the step number in run.init_checkpoint.
# NOTE(review): presumably train_steps is an absolute global step, i.e. this
# run trains for 52000 further steps -- confirm.
run.train_steps = 1052000
run.variable_filter = None

# Parameters for decoder/SelfAttention:
# ==============================================================================
# The decoder/ and encoder/ scopes carry identical SelfAttention bindings;
# keep the two sections in sync when editing either one.
# dropout_rate, key_value_size and num_heads come from the %dropout_rate,
# %d_kv and %num_heads macros.
decoder/SelfAttention.attention_func = None
decoder/SelfAttention.attention_kwargs = None
decoder/SelfAttention.combine_dims = True
decoder/SelfAttention.dropout_rate = %dropout_rate
decoder/SelfAttention.fold_scaling_into_initializer = True
decoder/SelfAttention.keep_query_heads_dims = False
decoder/SelfAttention.key_value_size = %d_kv
decoder/SelfAttention.num_heads = %num_heads
decoder/SelfAttention.num_memory_heads = 0
decoder/SelfAttention.relative_attention_num_buckets = 32
decoder/SelfAttention.relative_attention_type = 'bias_shared'
decoder/SelfAttention.shared_kv = False

# Parameters for encoder/SelfAttention:
# ==============================================================================
# Same values as the decoder/ scope above.
encoder/SelfAttention.attention_func = None
encoder/SelfAttention.attention_kwargs = None
encoder/SelfAttention.combine_dims = True
encoder/SelfAttention.dropout_rate = %dropout_rate
encoder/SelfAttention.fold_scaling_into_initializer = True
encoder/SelfAttention.keep_query_heads_dims = False
encoder/SelfAttention.key_value_size = %d_kv
encoder/SelfAttention.num_heads = %num_heads
encoder/SelfAttention.num_memory_heads = 0
encoder/SelfAttention.relative_attention_num_buckets = 32
encoder/SelfAttention.relative_attention_type = 'bias_shared'
encoder/SelfAttention.shared_kv = False

# Parameters for serialize_num_microbatches:
# ==============================================================================
# 8192 = 2**13. For comparison, run.batch_size requests 1048576 tokens per
# batch; how the number of microbatches is derived from these two values is
# not visible in this file.
serialize_num_microbatches.tokens_per_microbatch_per_replica = 8192

# Parameters for SimdMeshImpl:
# ==============================================================================
# NOTE(review): presumably all-reduces over groups of at most 8 devices run in
# bfloat16 and larger groups use a wider dtype -- confirm against
# mesh_tensorflow's SimdMeshImpl.
SimdMeshImpl.allreduce_in_bfloat16_max_group_size = 8

# Parameters for sublayer_call_layer:
# ==============================================================================
# None.

# Parameters for sublayer_dropout:
# ==============================================================================
# None.

# Parameters for sublayer_legacy_dropout:
# ==============================================================================
# None.

# Parameters for sublayer_legacy_final_rms_norm:
# ==============================================================================
# None.

# Parameters for sublayer_legacy_rms_norm:
# ==============================================================================
# None.

# Parameters for sublayer_mask_padding:
# ==============================================================================
# None.

# Parameters for sublayer_residual:
# ==============================================================================
# None.

# Parameters for sublayer_rms_norm:
# ==============================================================================
# sublayer_rms_norm is listed in sublayers_per_layer and sublayers_final of
# both LayerStack scopes. The value matches LayerStack.norm_epsilon (1e-06).
sublayer_rms_norm.epsilon = 1e-06

# Parameters for tpu_estimator_model_fn:
# ==============================================================================
# init_variable_filter is the empty string and hierarchical_tiling_spec is
# None; how tpu_estimator_model_fn interprets those values is not visible in
# this file.
tpu_estimator_model_fn.hierarchical_tiling_spec = None
tpu_estimator_model_fn.init_variable_filter = ''
tpu_estimator_model_fn.model_info_file = 'model-info.txt'
tpu_estimator_model_fn.outer_batch_size = 1
tpu_estimator_model_fn.tpu_summaries = False

# Parameters for tpu_mesh_shape:
# ==============================================================================
# Inputs to the function evaluated by
# `run.mesh_shape = @mesh_tensorflow.transformer.utils.tpu_mesh_shape()`.
# The 'model' axis named in run.layout_rules is presumably sized by
# model_parallelism -- confirm against tpu_mesh_shape.
tpu_mesh_shape.ensemble_parallelism = None
tpu_mesh_shape.model_parallelism = 2
tpu_mesh_shape.tpu_topology = '4x4'

# Parameters for unit_scaling_convention:
# ==============================================================================
# Disabled here; which code paths consult this value is not visible in this
# file.
unit_scaling_convention.value = False

# Parameters for decoder/Unitransformer:
# ==============================================================================
# The decoder/ and encoder/ scopes differ only in loss_denominator
# (233472.0 for the decoder, None for the encoder); every other binding is
# identical. d_model comes from the %d_model macro.
decoder/Unitransformer.d_model = %d_model
decoder/Unitransformer.ensemble = None
decoder/Unitransformer.input_full_attention = False
decoder/Unitransformer.label_smoothing = 0.0
# NOTE(review): 233472 = 2048 * 114, and 2048 = 1048576 / 512 (tokens per
# batch in run.batch_size divided by the 'inputs' length). However,
# run.sequence_length sets 'targets' to 100, not 114 -- confirm this
# denominator is still the intended one for this run.
decoder/Unitransformer.loss_denominator = 233472.0
decoder/Unitransformer.loss_fn = None
decoder/Unitransformer.loss_on_targets_only = False
decoder/Unitransformer.max_length = 512
decoder/Unitransformer.positional_embedding = False
decoder/Unitransformer.shared_embedding_and_softmax_weights = False
decoder/Unitransformer.sinusoid_positional_embedding = False
decoder/Unitransformer.token_dropout_rate = 0.0
decoder/Unitransformer.vocab_divisor = 128
decoder/Unitransformer.z_loss = 0.0001

# Parameters for encoder/Unitransformer:
# ==============================================================================
# Same values as the decoder/ scope above, except loss_denominator is None.
encoder/Unitransformer.d_model = %d_model
encoder/Unitransformer.ensemble = None
encoder/Unitransformer.input_full_attention = False
encoder/Unitransformer.label_smoothing = 0.0
encoder/Unitransformer.loss_denominator = None
encoder/Unitransformer.loss_fn = None
encoder/Unitransformer.loss_on_targets_only = False
encoder/Unitransformer.max_length = 512
encoder/Unitransformer.positional_embedding = False
encoder/Unitransformer.shared_embedding_and_softmax_weights = False
encoder/Unitransformer.sinusoid_positional_embedding = False
encoder/Unitransformer.token_dropout_rate = 0.0
encoder/Unitransformer.vocab_divisor = 128
encoder/Unitransformer.z_loss = 0.0001

# Parameters for VarianceScalingInitializer:
# ==============================================================================
# No binding visible in this file references VarianceScalingInitializer, so
# which variables it initializes cannot be determined from here.
VarianceScalingInitializer.distribution = 'normal'
VarianceScalingInitializer.mode = 'fan_in'
VarianceScalingInitializer.scale = 1.0

# Parameters for VocabEmbedding:
# ==============================================================================
# Disabled here; the effect of enabling it is defined by VocabEmbedding and
# is not visible in this file.
VocabEmbedding.scale_variable_like_classifier_weights = False
