#!/bin/bash

# Project name exported to child processes (the training run launched below);
# presumably picked up by the Weights & Biases logging integration.
WANDB_PROJECT='stlm-logic-reasoning'
export WANDB_PROJECT
# In case we want to store models elsewhere
# export HF_HOME='/home/data_nagar/mnt_data_nagar/hf-cache'
# export TRANSFORMERS_CACHE='/home/data_nagar/mnt_data_nagar/hf-cache'
# export HF_DATASETS_CACHE='/home/data_nagar/mnt_data_nagar/hf-cache'

# NOTE: verify that the max-new-tokens setting matches the experimental setup before launching.

# ========================
# Experiment Configuration
# ========================

# Dataset split that selects the train/val/test files further down.
# Options: canconical-symmetry-grouping, random-80-10-10
# NOTE: the spelling "canconical" is compared literally below and becomes part
# of the checkpoint directory name, so it must not be "corrected" here alone.
DATASET_TYPE='canconical-symmetry-grouping'

# Forwarded to sft.py as --experiment_mode.
# Options: legal_move, best_move
EXPERIMENT_MODE='legal_move'

# Forwarded to sft.py as --representation_mode.
# Options: ascii, natural, move_seq_explained, move_seq_special
# TODO: only one of "natural" / "move_seq_explained" is needed, not both;
# decide which to keep and remove the other.
REPRESENTATION_MODE='move_seq_explained'

# Base model to fine-tune. Alternatives that have been used:
# MODEL_NAME="openai-community/gpt2"
# MODEL_NAME="meta-llama/Llama-3.2-1B-Instruct"
# MODEL_NAME="arnir0/Tiny-LLM"
MODEL_NAME='distilbert/distilgpt2'

# Free-form tag appended to the run/checkpoint name.
MISC_INFO='baseline-expts'

# baseline-expts: the sequence of moves that led to the current game state is not provided.

# ========================
# Dataset Paths
# ========================

# Resolve the train/val/test JSON files for the selected DATASET_TYPE.
# Sets TRAIN_PATH, VAL_PATH and TEST_PATH. An unrecognised DATASET_TYPE is
# reported on stderr (with the offending value) and aborts with status 1.
DATASET_DIR="/mnt/data/data/stlm-logic/datasets"

case "$DATASET_TYPE" in
    canconical-symmetry-grouping)
        TRAIN_PATH="${DATASET_DIR}/tictactoe_train.json"
        VAL_PATH="${DATASET_DIR}/tictactoe_val.json"
        TEST_PATH="${DATASET_DIR}/tictactoe_test.json"
        ;;
    random-80-10-10)
        TRAIN_PATH="${DATASET_DIR}/random_train_dataset_0.8_0.1_0.1.json"
        VAL_PATH="${DATASET_DIR}/random_val_dataset_0.8_0.1_0.1.json"
        TEST_PATH="${DATASET_DIR}/random_test_dataset_0.8_0.1_0.1.json"
        ;;
    *)
        # Diagnostics belong on stderr so they are not mixed into captured output.
        printf "Invalid DATASET_TYPE selected: '%s' (expected canconical-symmetry-grouping or random-80-10-10)\n" \
            "$DATASET_TYPE" >&2
        exit 1
        ;;
esac

# ========================
# Model Configuration
# ========================

# Build the run identifier from the experiment settings. Every '/' in the
# model name is turned into '-' so the identifier is a single path component.
printf -v MODEL_MARK '%s_%s_%s_%s_%s' \
    "${MODEL_NAME//\//-}" \
    "$DATASET_TYPE" \
    "$EXPERIMENT_MODE" \
    "$REPRESENTATION_MODE" \
    "$MISC_INFO"

# Checkpoint directory and run name both derive from the identifier.
OUTPUT_PATH="/mnt/data/data/stlm-logic/checkpoints/${MODEL_MARK}"
RUN_NAME="H100:/${MODEL_MARK}"

# ========================
# Training Command
# ========================

# Launch scripts/python/sft.py through the DeepSpeed launcher on master port
# 29536 with `--include localhost:0` (per DeepSpeed's include syntax: slot 0
# on this host only).
# The script and DeepSpeed config paths are relative, so this file must be run
# from the directory that contains scripts/ and configs/.
# Every variable expansion is double-quoted so that each value reaches sft.py
# as exactly one argument, even if it contains whitespace or glob characters.
deepspeed --master_port 29536 --include localhost:0 \
  scripts/python/sft.py \
    --run_name "$RUN_NAME" \
    --model_name_or_path "$MODEL_NAME" \
    --instruction_model False \
    --train_dataset_path "$TRAIN_PATH" \
    --val_dataset_path "$VAL_PATH" \
    --test_dataset_path "$TEST_PATH" \
    --experiment_mode "$EXPERIMENT_MODE" \
    --representation_mode "$REPRESENTATION_MODE" \
    --num_train_epochs 5 \
    --save_steps 200 \
    --eval_steps 100 \
    --save_total_limit=20 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "steps" \
    --save_strategy "steps" \
    --logging_strategy "steps" \
    --output_dir "$OUTPUT_PATH" \
    --deepspeed configs/ds_config_zero3.json \
    --metric_for_best_model="eval_loss" \
    --greater_is_better False \
    --use_liger True \
    --remove_unused_columns True \
    --load_best_model_at_end True \
    --bf16 True \
    --use_vllm True \
    --vllm_gpu_memory_utilization 0.85 \
    --seed 42

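# Optional debugging/logging flags. To enable one, insert it into the command
# above before the final `--seed 42` line, keeping the backslash as the very
# last character on the line (a space after it breaks the continuation).
# --log_level "debug" \
# --learning_rate 1e-5 \
# --weight_decay 0.01 \
# --max_grad_norm 1.0 \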
