#!/bin/bash
#SBATCH --job-name=training
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-task=1
#SBATCH --gres=gpu:4
#SBATCH --time=2-00:00:00
#SBATCH --partition=accelerated-h100
#SBATCH --cpus-per-task=32

# Launch Llama-3-8B-Instruct "live1+" fine-tuning on the refined Ego4D
# narration stream dataset with DeepSpeed ZeRO-2 on a single node.
#
# Effective global batch size:
#   per_device_train_batch_size (1) x gradient_accumulation_steps (16)
#   x GPUs (4, from the --gres=gpu:4 directive) = 64,
# consistent with the "bs64" tag in --run_name.
#
# Setup steps fail fast: without these checks a broken conda setup would
# still launch train.py under the wrong interpreter and burn the allocation.

die() { printf 'error: %s\n' "$*" >&2; exit 1; }

conda_sh="${HOME}/anaconda3/etc/profile.d/conda.sh"
[[ -f "$conda_sh" ]] || die "conda init script not found: $conda_sh"
# shellcheck source=/dev/null
source "$conda_sh" || die "failed to source $conda_sh"
conda activate vlm || die "failed to activate conda env 'vlm'"
command -v deepspeed >/dev/null || die "deepspeed not found on PATH in env 'vlm'"

export WANDB_PROJECT=Ego4D_Narration

# Arguments kept in an array so each flag/value pair is one line and a stray
# trailing space cannot silently break a backslash continuation.
train_args=(
  --deepspeed configs/deepspeed/zero2.json
  --live_version live1+
  --train_datasets ego4d_refined_narration_stream_train
  --eval_datasets ego4d_refined_narration_stream_val
  --num_train_epochs 2
  --per_device_train_batch_size 1
  --per_device_eval_batch_size 1
  --gradient_accumulation_steps 16
  --gradient_checkpointing True
  # NOTE(review): newer transformers releases rename this to --eval_strategy;
  # confirm against the installed version before upgrading.
  --evaluation_strategy no
  --prediction_loss_only False
  --save_strategy steps
  --save_steps 200
  --save_total_limit 5
  --learning_rate 0.0002
  --optim adamw_torch
  --lr_scheduler_type cosine
  --warmup_ratio 0.05
  --logging_steps 100
  --dataloader_num_workers 16
  --bf16 True
  --tf32 True
  --report_to wandb
  --max_grad_norm 1.0
  --run_name base_zero2_llama8B_bs64_lr0.0002_gradnorm1_round2
  --max_num_frames 1200
  --attn_implementation sdpa
  --output_dir outputs/ego4d_refined_narration_stream_train_8B_round2/live1+
  --llm_pretrained meta-llama/Meta-Llama-3-8B-Instruct
  --sample_max_frames 1200
)

# Last command: its exit status becomes the job's exit status.
deepspeed train.py "${train_args[@]}"
