#!/bin/bash

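# Training script for remaskator using Hydra parameter overrides.
# Starts a fresh run (checkpointing.resume_from_ckpt=false, empty checkpoint
# paths) with +trainer.max_epochs=10, loader.batch_size=64,
# trainer.accumulate_grad_batches=8 and loader.global_batch_size=512.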

# Fail fast: abort on any command error, on use of an unset variable, or on a
# failed pipeline stage, so a broken step cannot be mistaken for a clean run.
set -euo pipefail

echo "Starting remaskator training..."

# Environment for the training process.
export TOKENIZERS_PARALLELISM=false
# Hard-coded device index; edit this line to run on a different GPU.
export CUDA_VISIBLE_DEVICES=6

# Run identifier; the training command uses it as the prefix of wandb.name.
# NOTE(review): the name advertises "freeze_backbone", but the training command
# passes sampling.freeze_backbone=false -- confirm which one is intended.
RUN_NAME=remaskator-uncond-no-wrap-seqlen512-t_uniform_nucleus1.0-freeze_backbone-first_n_layers_1

# Variant of RUN_NAME with every hyphen replaced by an underscore. Done with
# parameter expansion instead of an unquoted `echo | tr` pipeline.
# NOTE(review): FOLDER_NAME is not read anywhere in the visible script --
# confirm whether it is still needed.
FOLDER_NAME=${RUN_NAME//-/_}

# Run training with Hydra parameter overrides.
#
# The overrides live in an array so that each one is passed as exactly one
# argv word and can carry a comment. Order is the same as it has always been.
hydra_overrides=(
  # Checkpointing and logging: empty paths, no resume, timestamped W&B name.
  eval.checkpoint_path=""
  checkpointing.save_dir=""
  checkpointing.resume_from_ckpt=false
  "wandb.name=${RUN_NAME}-$(date +%Y%m%d_%H%M%S)"

  # Trainer length and model selection.
  +trainer.max_epochs=10
  model=small
  model.length=512
  model.cond_dim_embedding=384

  # Batching: 64 per step x 8 accumulation steps = 512, which equals
  # loader.global_batch_size (presumably required to match -- confirm).
  loader.batch_size=64
  loader.eval_batch_size=64
  trainer.accumulate_grad_batches=8
  loader.global_batch_size=512
  trainer.precision=bf16
  trainer.val_check_interval=1000

  # Text embedder is switched off; the remaining fields are still passed.
  text_embedder.use_text_embedder=false
  text_embedder.model_name=sentence-transformers/all-MiniLM-L6-v2
  text_embedder.noise=0.0

  # Data and optimisation.
  data.wrap=false
  data=openwebtext-split
  optim.lr=1e-4
  callbacks.checkpoint_monitor.monitor=val/loss

  # Sampling and remaskator settings.
  sampling.t_sampling=uniform
  sampling.nucleus_p=1.0
  sampling.freeze_backbone=false
  # NOTE(review): this key looks like a misspelling of "take_first_n_layers".
  # It is left unchanged because the config schema is not visible from this
  # script -- confirm against the Hydra config before renaming.
  remaskator.take_fist_n_layrs=1
  noise=loglinear
  training.remaskator_reweighting=false
)

# Capture the exit status explicitly so a failed run is reported as a failure
# and the status is propagated to whatever launched this script.
train_status=0
python remaskator_train.py "${hydra_overrides[@]}" || train_status=$?

if (( train_status != 0 )); then
  echo "Training failed with exit code ${train_status}" >&2
  exit "${train_status}"
fi

echo "Training completed!"