#!/bin/bash

# Train the baseline autoregressive (single-token prediction) model; the MTP runs below restore from its checkpoint

torchrun --standalone --nproc_per_node=1 -m mtp.train data=shakespeare_char training=shakespeare_char \
	model=stp \
	lm.n_layer=4 lm.n_head=4 lm.n_embd=256 \
	lm.model.encoder_only=false
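
# The MTP runs below restore from a hard-coded checkpoint path produced by a
# baseline run like the one above. A minimal sketch, assuming the
# logs/<date>/<time>/model@<step>.pt layout seen in those paths, for picking
# up the newest checkpoint automatically:
#
# CKPT=$(ls -t logs/*/*/model@*.pt | head -n 1)
# ... and then pass lm.model.from_checkpoint="$CKPT" instead of the literal path.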

# Train the multi-token prediction (MTP) models on top of the frozen baseline.
# Note that a multi-token prediction model with r > 1 mixture components is a
# Mixture of Softmaxes.
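#
# As a sketch of the general Mixture-of-Softmaxes form (not necessarily the
# exact parameterization of the head in this repo): with mixture weights
# pi_c(h) and per-component logits W_c h for components c = 1..r,
#
#   p(x | h) = sum_{c=1}^{r} pi_c(h) * softmax(W_c h)[x]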

torchrun --standalone --nproc_per_node=1 -m mtp.train data=shakespeare_char training=shakespeare_char \
	model=mtp-cp \
	model.model.kl_algorithm=binary_approx \
	lm.n_layer=4 lm.n_head=4 lm.n_embd=256 \
	lm.model.encoder_only=false \
	lm.model.freeze=true \
	lm.model.lm=null lm.model.from_checkpoint=logs/2025-03-19/16-31-59/model\@2000.pt \
	training.save_model_every=100

torchrun --standalone --nproc_per_node=1 -m mtp.train data=shakespeare_char training=shakespeare_char \
	model=mtp-hmm \
	model.model.kl_algorithm=binary_approx \
	lm.n_layer=4 lm.n_head=4 lm.n_embd=256 \
	lm.model.encoder_only=false \
	lm.model.freeze=true \
	lm.model.lm=null lm.model.from_checkpoint=logs/2025-03-19/16-31-59/model\@2000.pt \
	training.save_model_every=100

torchrun --standalone --nproc_per_node=1 -m mtp.train data=shakespeare_char training=shakespeare_char \
	model=mtp-cp \
	model.model.kl_algorithm=full \
	lm.n_layer=4 lm.n_head=4 lm.n_embd=256 \
	lm.model.encoder_only=false \
	lm.model.freeze=true \
	lm.model.lm=null lm.model.from_checkpoint=logs/2025-03-19/16-31-59/model\@2000.pt \
	training.save_model_every=100

torchrun --standalone --nproc_per_node=1 -m mtp.train data=shakespeare_char training=shakespeare_char \
	model=mtp-hmm \
	model.model.kl_algorithm=full \
	lm.n_layer=4 lm.n_head=4 lm.n_embd=256 \
	lm.model.encoder_only=false \
	lm.model.freeze=true \
	lm.model.lm=null lm.model.from_checkpoint=logs/2025-03-19/16-31-59/model\@2000.pt \
	training.save_model_every=100
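
# The four runs above differ only in the model head (mtp-cp vs. mtp-hmm) and
# the KL algorithm (binary_approx vs. full). A compact equivalent sweep, as a
# sketch reusing exactly the overrides above:
#
# for m in mtp-cp mtp-hmm;
# do
# 	for alg in binary_approx full;
# 	do
# 		torchrun --standalone --nproc_per_node=1 -m mtp.train data=shakespeare_char training=shakespeare_char model=$m model.model.kl_algorithm=$alg lm.n_layer=4 lm.n_head=4 lm.n_embd=256 lm.model.encoder_only=false lm.model.freeze=true lm.model.lm=null lm.model.from_checkpoint=logs/2025-03-19/16-31-59/model\@2000.pt training.save_model_every=100
# 	done
# done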

# Sweep over n_component (r) and n_token (n). NOTE: these commented-out
# commands use model.* overrides (e.g. model.n_layer) where the live commands
# above use lm.*; update them before reviving this sweep.
# for r in 1 3 5 8;
# do
# 	for n in 1 2 3 4 5;
# 	do
# 		torchrun --standalone --nproc_per_node=1 -m mtp.train data=shakespeare_char training=shakespeare_char model=mtp model.n_layer=4 model.n_head=4 model.n_embd=256 model.n_token=$n model.n_component=$r
# 		# TODO: Also restore from checkpoint and keep LM frozen
# 		# torchrun --standalone --nproc_per_node=1 -m mtp.train data=shakespeare_char training=shakespeare_char model=mtp model.n_component=5 model.n_token=3 model.n_head=4 model.n_embd=256 model.mt_head_hparams.expander_type=mlp model.mt_head_hparams.tok_transformer_n_layer=1 model.lm.freeze=false model.lm.lm=null model.lm.from_checkpoint=logs/2025-02-16/21-54-18/model\@500.pt  training.save_model_every=100
# 	done
# done
