#!/bin/bash
#
# Throughput benchmark for EvaByte via `python -m mtp.generate`.
#
# Runs one `--mode stp` baseline, then sweeps `--mode mtp` configurations over
# adaptors x circuits x n_token x n_component. For every successful run the
# last stdout line (presumably a JSON record, given the .jsonl target — TODO
# confirm) is appended to:
#   $MTP_ROOT/outputs/results/throughput_evabyte_<filetag>_<subsample_prompts>.jsonl
#
# Usage: <script> [filetag] [prompt_subset_index]
#   filetag              tag embedded in the output file name (default: unspecified)
#   prompt_subset_index  forwarded as --prompt-subset-index (default: 0)
#
# Required environment: MTP_ROOT (root of the output tree).
#
# Exit status: 0 if every run was recorded, 1 if any run failed (failed runs
# are reported on stderr and skipped; the rest of the sweep still runs).

set -euo pipefail

# Fail fast instead of silently writing under /outputs/results.
: "${MTP_ROOT:?MTP_ROOT must be set}"

readonly device=cuda
readonly num_tokens=1024
readonly subsample_prompts=100
readonly prompt_source="tulu-valid"
readonly filetag="${1:-unspecified}"
readonly prompt_subset_index="${2:-0}"
readonly seed=42

readonly results_dir="$MTP_ROOT/outputs/results"
readonly results_file="$results_dir/throughput_evabyte_${filetag}_${subsample_prompts}.jsonl"
mkdir -p -- "$results_dir"

export GPUS=1
export GPU=1

args=(
  --device "$device"
  --num-tokens "$num_tokens"
  --random-seed "$seed"
  --prompt-subset-index "$prompt_subset_index"
)
gen_args=(
  --task chat
  --subsample-prompts "$subsample_prompts"
  --prompt-source "$prompt_source"
)

# Number of runs that failed or produced no output.
failures=0

#######################################
# Run one `mtp.generate` invocation and record its last stdout line.
# Globals:   results_file (read), failures (written)
# Arguments: all arguments are forwarded to `python -m mtp.generate`, followed
#            by `--use-cache --disable-eos` (same position as before).
# Outputs:   appends one line to $results_file on success; warning on stderr
#            otherwise.
#######################################
run_generate() {
  local last_line
  # With pipefail, a python failure is no longer masked by `tail`'s status,
  # so error output is never appended to the results file.
  if last_line=$(python -m mtp.generate "$@" --use-cache --disable-eos | tail -n 1) \
      && [[ -n "$last_line" ]]; then
    # printf re-adds the newline, so records never get concatenated.
    printf '%s\n' "$last_line" >> "$results_file"
  else
    printf 'warning: mtp.generate failed or printed nothing; not recorded: %s\n' "$*" >&2
    failures=$((failures + 1))
  fi
}

# For the models below we do not stop generating on EOS (--disable-eos), since
# we only want to measure how long each one takes to produce the *same* number
# of tokens. This matters because some of the models are randomly initialised
# and have no reasonable EOS completion.

# Baseline run.
run_generate \
  --mode stp \
  "${args[@]}" \
  "${gen_args[@]}" \
  lm=evabyte lm.model.encoder_only=False data.vocab_size=320

# MTP sweep. Iteration order (adaptor -> circuit -> n_token -> n_component)
# determines the line order of the results file.
for adaptor in none lora-last-16; do
  mtp_lm_args=(
    model=mtp lm=evabyte lm.model.encoder_only=False "adaptor=$adaptor"
    data=tulu3 data.vocab_size=320 mt_head=linear-evabyte
  )

  for circuit in fully_factorized cp hmm btree; do
    # The fully factorized circuit is only run with a single component.
    if [[ "$circuit" == "fully_factorized" ]]; then
      n_components=(1)
    else
      n_components=(8 16 32 64 128)
    fi

    for n_token in 8 16; do
      for n_component in "${n_components[@]}"; do
        circuit_args=(
          "circuit=$circuit"
          "circuit.n_token=$n_token"
          "circuit.n_component=$n_component"
        )
        run_generate \
          --mode mtp \
          "${args[@]}" "${gen_args[@]}" "${mtp_lm_args[@]}" "${circuit_args[@]}"
      done
    done
  done
done

if (( failures > 0 )); then
  printf '%d run(s) failed; see warnings above\n' "$failures" >&2
  exit 1
fi
