
# Sampling sweep over the loop size.
#
# For every LOOP_SIZE in LOOP_SIZES this script derives the total number of
# sampling steps and the two fractions T_ON and T_OFF, then launches one
# `python main.py --multirun` run in sample_eval mode with those values passed
# as command-line overrides. A failing run does not stop the sweep; the loop
# simply moves on to the next size.

FIRST_SLOPE_SIZE=115
SECOND_SLOPE_SIZE=13
# Whitespace-separated list; it is split into words on purpose by the loop.
LOOP_SIZES="0 16 32 48 64 90 106 128 144 160"
ALPHA_ON=0.9

for LOOP_SIZE in ${LOOP_SIZES}; do
    # all steps = first_slope_size + second_slope_size + loop_size
    STEPS=$((FIRST_SLOPE_SIZE + SECOND_SLOPE_SIZE + LOOP_SIZE))

    # The shell only does integer arithmetic, so the fractions are computed by
    # Python. Values are passed as argv instead of being pasted into the
    # Python source text.
    # T_ON = 1 - first_slope_size / steps
    T_ON=$(python3 -c 'import sys; print(1 - float(sys.argv[1]) / float(sys.argv[2]))' \
        "${FIRST_SLOPE_SIZE}" "${STEPS}")
    # T_OFF = second_slope_size / steps
    T_OFF=$(python3 -c 'import sys; print(float(sys.argv[1]) / float(sys.argv[2]))' \
        "${SECOND_SLOPE_SIZE}" "${STEPS}")

    # Never launch a run with empty t_on/t_off overrides.
    if [ -z "${T_ON}" ] || [ -z "${T_OFF}" ]; then
        echo "Skipping loop_size=${LOOP_SIZE}: could not compute t_on/t_off (is python3 available?)" >&2
        continue
    fi

    echo "Running with steps=${STEPS}, t_on=${T_ON}, t_off=${T_OFF}, alpha_on=${ALPHA_ON}, loop_size=${LOOP_SIZE}"

    # The two hydra.sweep.* overrides are single-quoted so the shell leaves
    # their ${...} placeholders untouched (presumably resolved by Hydra at
    # run time). Every other override is double-quoted so each one reaches
    # main.py as exactly one argument.
    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
        --multirun \
        'hydra.sweep.dir=different_step_size/remaskator/${now:%Y-%m-%d}/${now:%H-%M-%S}' \
        'hydra.sweep.subdir="${hydra.job.num}_steps=${sampling.steps}"' \
        mode=sample_eval \
        loader.batch_size=64 \
        loader.eval_batch_size=64 \
        sampling.num_sample_batches=64 \
        "sampling.steps=${STEPS}" \
        data.wrap=False \
        data=openwebtext-split \
        parameterization=subs \
        backbone=dit \
        model.length=512 \
        model.cond_dim_embedding=384 \
        seed=11 \
        sampling.predictor=ddpm_cache \
        sampling.remdm_mode=loop_star_shape \
        "+sampling.t_on=${T_ON}" \
        "+sampling.t_off=${T_OFF}" \
        "+sampling.alpha_on=${ALPHA_ON}" \
        sampling.eta=0.008 \
        sampling.remaskator_temperature=1.0 \
        "sampling.remaskator_t_off=${T_OFF}" \
        "sampling.remaskator_t_on=${T_ON}" \
        sampling.nucleus_p=0.9 \
        noise=loglinear \
        "noise.t_off=${T_OFF}" \
        "noise.t_on=${T_ON}" \
        sampling.remaskator_checkpoint_path=null \
        sampling.freeze_backbone=false \
        eval.checkpoint_path=data/checkpoints/uncond_seqlen512.ckpt \
        text_embedder.use_text_embedder=false \
        text_embedder.use_condition_during_sampling_until=1.0 \
        text_embedder.embedding_ema_decay=0.0 \
        text_embedder.num_embedding_updates=0 \
        text_embedder.model_name=sentence-transformers/all-MiniLM-L6-v2 \
        text_embedder.cond_dropout=0.0 \
        text_embedder.random_projection_dim=null \
        text_embedder.noise=0.0 \
        "wandb.name=sample_$(date +%Y%m%d_%H%M%S)" \
        sampling.sample_embeddings_from=validation \
        sampling.gaussian_checkpoint_path=null \
        embedding_diffusion.num_layers=8 \
        embedding_diffusion.hidden_dim=512 \
        embedding_diffusion.net_type=transformer \
        embedding_diffusion.seq_len=8 \
        embedding_diffusion.num_heads=8 \
        embedding_diffusion.timesteps=1000 \
        embedding_diffusion.t_sampling_exponent=0.5 \
        +wandb.offline=true

done