#!/usr/bin/env bash
#
# Launch the autoregressive (STP) baseline and a sweep of multi-token
# prediction (MTP) models via `torchrun -m mtp.train`.
#
# Every run is attempted even if an earlier one fails; the script exits
# non-zero at the end if any run failed. Ctrl-C / SIGTERM aborts the whole
# sweep instead of skipping to the next configuration.
#
# Environment overrides (defaults reproduce the original runs exactly):
#   N_COMPONENTS  space-separated values for circuit.n_component (default "1")
#   N_TOKENS      space-separated values for circuit.n_token (default "1 2 3 4 5")
#   EXPNAME       value passed as training.expname to the MTP runs (default "test")

# No `set -e`: a failing run must not prevent the remaining runs.
set -uo pipefail

# Bash defers a trap until the foreground command exits, so an interrupt
# stops the sweep once the current torchrun returns.
trap 'printf "interrupted; aborting sweep\n" >&2; exit 130' INT
trap 'printf "terminated; aborting sweep\n" >&2; exit 143' TERM

read -r -a components <<<"${N_COMPONENTS:-1}"
read -r -a tokens <<<"${N_TOKENS:-1 2 3 4 5}"
expname=${EXPNAME:-test}

failures=0

# Run one command, logging it to stderr; count (but do not abort on) failure.
run() {
  printf '==> %s\n' "$*" >&2
  if ! "$@"; then
    printf 'error: run failed: %s\n' "$*" >&2
    failures=$((failures + 1))
  fi
}

# Autoregressive (next-token) baseline.
# NOTE(review): this run uses data=fineweb10b on 1 process, while the MTP sweep
# below uses data=finewebedu10B (plus training/lm=finewebedu) on 3 processes —
# confirm the baseline is meant to differ in dataset and world size.
run torchrun --standalone --nproc_per_node=1 -m mtp.train data=fineweb10b model=stp

# Multi-token prediction sweep over circuit.n_component (r) and
# circuit.n_token (n). With r > 1 the model is a mixture of softmaxes.
# The intended sweep is r in {1, 3, 5, 8}; only r=1 runs by default because
# of memory constraints (override with N_COMPONENTS="1 3 5 8").
# NOTE(review): all MTP runs share training.expname; if expname selects the
# output/log directory, later runs may overwrite earlier ones — confirm.
for r in "${components[@]}"; do
  for n in "${tokens[@]}"; do
    run torchrun --standalone --nproc_per_node=3 -m mtp.train \
      data=finewebedu10B training=finewebedu lm=finewebedu \
      model=mtp circuit=cp mt_head=linear \
      "circuit.n_token=${n}" "circuit.n_component=${r}" \
      "training.expname=${expname}" model.model.beta=0
  done
done

if ((failures > 0)); then
  printf '%d run(s) failed\n' "$failures" >&2
  exit 1
fi
