#!/bin/bash
#
# Launch an elastic-config training run through scripts/train.py.
#
# Usage: <this script> [extra train.py flags...]
#   Any extra arguments are appended after the default flags below.
#   NOTE(review): whether a repeated flag overrides the earlier default
#   depends on train.py's flag parser — confirm before relying on it.
#
# Environment (each may be pre-set by the caller; otherwise the default
# written below is used):
#   dataset_path   value for --train_dataset.vision_dataset.paths
#                  (must be non-empty; the script exits otherwise)
#   WANDB_API_KEY  exported to the training process (default: empty)
#   project_id     logger project id (default: elastic)
#   experiment_id  logger experiment id (default: elastic-train-test)
#
# The script changes into the parent directory of its own directory
# (PROJECT_DIR) so the relative path scripts/train.py resolves, and appends
# PROJECT_DIR to PYTHONPATH.

set -euo pipefail

die() { printf '%s\n' "$*" >&2; exit 1; }

# Resolve this script's directory and the project root (its parent) as
# absolute paths, independent of the caller's working directory. Assign
# before exporting so a failed resolution is not masked by `export`.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" \
  || die "cannot resolve script directory"
PROJECT_DIR="$(cd -- "$(dirname -- "$SCRIPT_DIR")" &>/dev/null && pwd)" \
  || die "cannot resolve project directory"
export SCRIPT_DIR PROJECT_DIR

# Quoted and checked: an unquoted, empty $PROJECT_DIR would make `cd` go to
# $HOME and the relative scripts/train.py below would not be found.
cd -- "$PROJECT_DIR" || die "cannot cd to $PROJECT_DIR"

# Append the project root without producing an empty leading entry when
# PYTHONPATH is unset or empty.
export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$PROJECT_DIR"

# Caller-provided values win; the fallbacks are the script's defaults.
export WANDB_API_KEY="${WANDB_API_KEY:-}"
export dataset_path="${dataset_path:-}"

export project_id="${project_id:-elastic}"
export experiment_id="${experiment_id:-elastic-train-test}"

# Fail fast instead of launching a run with an empty dataset path.
[[ -n "$dataset_path" ]] \
  || die "dataset_path is empty: set it in this script or in the environment"

python3 -u scripts/train.py \
    --mesh_dim='!-1,1,1,1' \
    --dtype='fp32' \
    --total_steps=2000000 \
    --log_freq=100 \
    --eval_freq=5000 \
    --eval_thresholds='0.015,0.003' \
    --save_model_freq=0 \
    --save_milestone_freq=5000 \
    --load_elastic_config='200m' \
    --update_elastic_config="dict(mask_mode='elastic',min_toks=128,frames_per_block=4,max_toks=2048,patch_size=(2,8,8),bottleneck_type='vae',vae_bottleneck_dim=8,theta=5000000,max_sequence_length=8192,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=512,remat_attention='',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='',scan_layers=True)" \
    --load_checkpoint='' \
    --load_dataset_state='' \
    --optimizer.type='adamw' \
    --optimizer.accumulate_gradient_steps=1 \
    --optimizer.adamw_optimizer.weight_decay=1e-4 \
    --optimizer.adamw_optimizer.lr=1e-4 \
    --optimizer.adamw_optimizer.end_lr=1e-4 \
    --optimizer.adamw_optimizer.lr_warmup_steps=2000 \
    --optimizer.adamw_optimizer.lr_decay_steps=2000000 \
    --train_dataset.type='vision_dataset' \
    --train_dataset.vision_dataset.paths="$dataset_path" \
    --train_dataset.vision_dataset.extensions="mp4" \
    --train_dataset.vision_dataset.batch_size=128 \
    --train_dataset.vision_dataset.batch_split_ratios="1.0" \
    --train_dataset.vision_dataset.seq_length=8192 \
    --train_dataset.vision_dataset.resolution=256 \
    --train_dataset.vision_dataset.max_read_queue_size="32" \
    --train_dataset.vision_dataset.max_futures_queue_size="16" \
    --train_dataset.vision_dataset.max_data_queue_size="16" \
    --train_dataset.vision_dataset.max_batch_queue_size=32 \
    --train_dataset.vision_dataset.data_process_workers="16" \
    --train_dataset.vision_dataset.process_square=True \
    --train_dataset.vision_dataset.n_block_skip=12 \
    --train_dataset.vision_dataset.seed=1237 \
    --train_dataset.vision_dataset.use_data_sharded_loader=True \
    --checkpointer.save_optimizer_state=True \
    --autoresume=False \
    --logger.append_uuid=False \
    --logger.online=False \
    --logger.project_id="$project_id" \
    --logger.experiment_id="$experiment_id" \
    --logger.output_dir="$HOME/experiment_output/$project_id" \
    --logger.wandb_dir="$HOME/experiment_output/$project_id" \
    "$@"
