#!/usr/bin/env bash
#
# Launch scripts/inference.py with the '200m' elastic config (overridden by the
# --update_elastic_config dict below) on a set of mp4 videos. The combined
# stdout/stderr is written to a log file via tee.
#
# Usage:
#   ckpt_path=/path/to/ckpt dataset_path=/path/to/data ./<this script> [extra flags...]
#
# Environment:
#   ckpt_path     checkpoint loaded as "trainstate_params::<ckpt_path>" (required;
#                 either export it or fill in the default below)
#   dataset_path  value for --train_dataset.vision_dataset.paths (required)
#   LOG_FILE      log destination (default: $HOME/output.log)
#
# Any extra command-line arguments are appended after the flags below.
# NOTE(review): whether a repeated flag overrides the earlier one depends on
# inference.py's flag parser — confirm before relying on it.

# -e/-u: abort on errors and unset variables. pipefail: without it, the
# `| tee` at the end would make the script report tee's status and hide a
# failed python run.
set -euo pipefail

# Assign before exporting, so that a failing `cd` inside the substitution
# trips `set -e` instead of being masked by `export`'s own exit status.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
export SCRIPT_DIR
PROJECT_DIR="$(cd -- "$(dirname -- "$SCRIPT_DIR")" &>/dev/null && pwd)"
export PROJECT_DIR

# scripts/inference.py below is a path relative to the project root.
cd -- "$PROJECT_DIR" || exit 1

# Only add the ':' separator when PYTHONPATH is already non-empty. A leading
# empty entry would put the current directory on Python's import path.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PROJECT_DIR}"

# Fill these in here, or supply them through the environment.
export ckpt_path="${ckpt_path:-}"
export dataset_path="${dataset_path:-}"

if [[ -z "$ckpt_path" || -z "$dataset_path" ]]; then
    printf 'error: ckpt_path and dataset_path must be set (edit this script or export them)\n' >&2
    exit 1
fi

log_file="${LOG_FILE:-$HOME/output.log}"

python3 -u scripts/inference.py \
    --threshold=0.003 \
    --mesh_dim='!-1,1,1,1' \
    --dtype='fp32' \
    --load_elastic_config='200m' \
    --update_elastic_config="dict(mask_mode='elastic',min_toks=256,max_toks=4096,frames_per_block=4,patch_size=(1,8,8),bottleneck_type='fsq',fsq_quant_levels=(8,8,8,5,5,5),theta=50000000,max_sequence_length=4096,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=512,remat_attention='',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='',scan_layers=True)" \
    --load_checkpoint="trainstate_params::$ckpt_path" \
    --train_dataset.type='vision_dataset' \
    --train_dataset.vision_dataset.paths="$dataset_path" \
    --train_dataset.vision_dataset.extensions="mp4" \
    --train_dataset.vision_dataset.batch_size=64 \
    --train_dataset.vision_dataset.batch_split_ratios="1.0" \
    --train_dataset.vision_dataset.seq_length=8192 \
    --train_dataset.vision_dataset.resolution=256 \
    --train_dataset.vision_dataset.max_read_queue_size="32" \
    --train_dataset.vision_dataset.max_futures_queue_size="32" \
    --train_dataset.vision_dataset.max_data_queue_size="32" \
    --train_dataset.vision_dataset.max_batch_queue_size=32 \
    --train_dataset.vision_dataset.data_process_workers="32" \
    --train_dataset.vision_dataset.one_seq_per_elem=True \
    --train_dataset.vision_dataset.process_square=True \
    --train_dataset.vision_dataset.seed=12348 \
    --train_dataset.vision_dataset.use_data_sharded_loader=True \
    "$@" \
    2>&1 | tee -- "$log_file"
