FOR ALL TRAINING
==================
"training_wkdir": (STRING) directory where results are stored
"rng_seednum": (INT) random number for the experiment

"data_dir": (STRING) folder where data is
"train_dset_splits": (LIST of strings) prefixes that belong in train set
"dev_dset_splits": (LIST of strings) prefixes that belong in train set
"test_dset_splits": (LIST of strings) prefixes that belong in train set

"toss_alignments_longer_than": (INT) remove alignments longer than this number; we use 512
"batch_size": (INT) batch size for training

"num_epochs": (INT) number of epochs
"optimizer_config": dictionary containing optimizer values; contains the following-
    "init_value": (FLOAT) learning rate to start at
    "peak_value": (FLOAT) learning rate to peak at
    "end_value": (FLOAT) learning rate to end at
    "warmup_steps": (INT) warmup steps
    "weight_decay": (FLOAT) weight decay
    "every_k_schedule": (INT) for gradient accumulation, update gradients every k batches (set to 1 if no gradient accumulation)

"early_stop_cond1_atol": (FLOAT) atol for comparing lowest and best loss, if loss rises or stagnates, add to patience (1e-3 is a good value)
"early_stop_cond2_gap": (FLOAT) if loss rises above this amount, add to patience (0.5 is a good value)
"patience": (INT) patience for early stopping

"interms_for_tboard": dictionary containing what intermediates to write while/after training; contains the following-
    "sow_outputs": (BOOL) write intermediates from neural networks
    "embeddings": (BOOL) write neural embeddings, if training neural model
    "gradients": (BOOL) record gradients
    "weights": (BOOL) record weights, if training neural model
    "optimizer": (BOOL) record optimizer parameters
    "forward_pass_outputs": (BOOL) write all outputs from forward pass of the network, if training neural model
    "attn_weights": (BOOL) if training a transformer, write the attention maps

"save_arrs": (BOOL) save extra arrays, like evolutionary parameters, etc.
"save_per_sample_losses": (BOOL)save loglike per sample
"record_metrics_every_n_steps": (INT) how many training batches before recording internal metrics
"checkpoint_every_t_seconds": (INT) how often (in seconds) to save the flax parameters


Unused, but do not change:
==========================
"chunk_length": not used; keep at 512
"update_grads": not used; keep at true
"use_scan_fns": not used; keep at false