#!/bin/bash

# Strict mode: abort on a failing command (-e), on expansion of an unset
# variable (-u), and when any stage of a pipeline fails (pipefail), so a typo
# in a variable name cannot silently launch a run with an empty argument.
set -euo pipefail

# Project name for experiment tracking; exported so the training process
# started at the end of this script inherits it.
export WANDB_PROJECT="stlm-logic-reasoning"

# Uncomment and adjust these if you use custom cache directories:
# NOTE(review): recent transformers releases reportedly warn that
# TRANSFORMERS_CACHE is deprecated in favour of HF_HOME — confirm against the
# installed version before relying on it.
# export HF_HOME='/your/custom/hf-cache'
# export TRANSFORMERS_CACHE='/your/custom/hf-cache'
# export HF_DATASETS_CACHE='/your/custom/hf-cache'

# ========================
# Experiment Configuration
# ========================

# Dataset split to train on. Must be exactly one of:
#   canconical-symmetry-grouping | random-80-10-10
# The spelling "canconical" is load-bearing: the value is compared literally
# when the dataset paths are selected further down, and it also ends up in the
# checkpoint/run name.
DATASET_TYPE='random-80-10-10'

# Task variant: legal_move | best_move
EXPERIMENT_MODE='legal_move'

# Move representation: "nl" = natural language, "special" = special move tokens
REPRESENTATION_MODE='special'

# "True" when the selected model is an instruct model (chat-style prompt)
INSTRUCT_MODEL='True'

# Number of processes for Accelerate (set to the total GPUs you want to use).
# Kept commented out here; the value actually used is assigned in the
# training-command section below.
# NUM_PROCESSES=2

# Model to train. Exactly one MODEL_NAME line should be active; the commented
# lines are alternatives kept for quick switching.
MODEL_NAME='meta-llama/Llama-3.2-1B-Instruct'
# MODEL_NAME="openai-community/gpt2"
# MODEL_NAME="meta-llama/Llama-3.2-1B-Instruct"
# MODEL_NAME="MBZUAI/LaMini-GPT-774M"
# MODEL_NAME="google/gemma-3-1b-it" # Multimodal, need to upgrade transformers
# MODEL_NAME="llamafy/Qwen-Qwen2.5-0.5B-Instruct-llamafied"
# MODEL_NAME="llamafy/Qwen-Qwen2.5-1.5B-Instruct-llamafied"
# MODEL_NAME="arnir0/Tiny-LLM"
# MODEL_NAME="distilbert/distilgpt2"
# MODEL_NAME="bigscience/bloomz-560m"
# MODEL_NAME="MBZUAI/LaMini-GPT-124M"
# MODEL_NAME="MBZUAI/LaMini-Cerebras-256M"

# Free-form label appended to the checkpoint/run name to tell experiments apart.
# Earlier label (llama and GPT 124m were already run on canonical and nl with it):
# MISC_INFO="grpo-nl-expt-final-fixed-prompt-and-data-loading"
# NOTE(review): the active label says "nl" while REPRESENTATION_MODE is
# "special" — confirm the label still describes this run.
MISC_INFO='grpo-nl-expt-fix-chat-template'

# TODO: evaluate the trained models on fixed prompt
# TODO: implement coherence reward and update the implementation for the new updated reward function
# TODO: Fix special token expts

# ========================
# Dataset Paths
# ========================

# Resolve the train/validation/test JSON files for the selected DATASET_TYPE.
# Sets TRAIN_PATH, VAL_PATH and TEST_PATH. Any other value is reported on
# stderr (with the offending value and the accepted names) and the script
# exits with status 1 before anything is launched.
DATASET_DIR="/mnt/data/data/stlm-logic/datasets"

# ${DATASET_TYPE:-} keeps an unset variable on the friendly error path below,
# even when the script runs under `set -u`.
case "${DATASET_TYPE:-}" in
    # "canconical" (sic) is the literal identifier used throughout this script;
    # do not correct the spelling here without changing DATASET_TYPE too.
    canconical-symmetry-grouping)
        TRAIN_PATH="${DATASET_DIR}/tictactoe_train.json"
        VAL_PATH="${DATASET_DIR}/tictactoe_val.json"
        TEST_PATH="${DATASET_DIR}/tictactoe_test.json"
        ;;
    random-80-10-10)
        TRAIN_PATH="${DATASET_DIR}/random_train_dataset_0.8_0.1_0.1.json"
        VAL_PATH="${DATASET_DIR}/random_val_dataset_0.8_0.1_0.1.json"
        TEST_PATH="${DATASET_DIR}/random_test_dataset_0.8_0.1_0.1.json"
        ;;
    *)
        # Diagnostics belong on stderr so they are not lost when stdout is
        # redirected to a log or pipe.
        printf 'Invalid DATASET_TYPE selected! Got "%s"; expected "canconical-symmetry-grouping" or "random-80-10-10".\n' \
            "${DATASET_TYPE:-}" >&2
        exit 1
        ;;
esac

# ========================
# Model Configuration
# ========================

# Build the experiment tag: every '/' in the model id becomes '-' (so the tag
# is usable as a single directory name), then the dataset, experiment mode,
# representation mode and MISC_INFO label are appended, joined by underscores.
printf -v MODEL_MARK '%s_%s_%s_%s_%s' \
  "${MODEL_NAME//\//-}" \
  "$DATASET_TYPE" \
  "$EXPERIMENT_MODE" \
  "$REPRESENTATION_MODE" \
  "$MISC_INFO"
# Checkpoint directory handed to the trainer as --output_dir.
OUTPUT_PATH="/mnt/data/data/stlm-logic/updated-checkpoints/${MODEL_MARK}"
# Run name handed to the trainer as --run_name.
RUN_NAME="H100:/${MODEL_MARK}"

# ========================
# Training Command using Accelerate
# ========================

# Configuration files: one for the Accelerate launcher and one GRPO/vLLM
# config that is forwarded to the training script via --config.
ACCELERATE_CONFIG="/home/data/stlm-game-logic/configs/deepspeed_zero3.yaml"
GRPO_VLLM_CONFIG="/home/data/stlm-game-logic/configs/conf-vllm-full.yaml"

# Number of processes Accelerate starts.
# NOTE(review): two GPUs are exposed to the launch below but only one process
# is started; presumably the second GPU is left for vLLM generation — confirm.
NUM_PROCESSES=1

# Debugging switches inherited by the training process.
# NOTE(review): both settings are diagnostic aids that usually cost
# throughput — confirm they are meant to stay enabled for real runs.
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export CUDA_LAUNCH_BLOCKING=1

# Options consumed by `accelerate launch` itself.
launcher_opts=(
  --config_file "$ACCELERATE_CONFIG"
  --num_processes "$NUM_PROCESSES"
)

# Options forwarded to run_grpo.py. Kept in an array so each value, including
# the intentionally empty --dataset_id_or_path, stays a single argument.
# NOTE(review): newer transformers releases spell --evaluation_strategy as
# --eval_strategy — verify against the installed version.
trainer_opts=(
  --config "$GRPO_VLLM_CONFIG"
  --run_name "$RUN_NAME"
  --model_name_or_path "$MODEL_NAME"
  --instruct_model "$INSTRUCT_MODEL"
  --representation_mode "$REPRESENTATION_MODE"
  --train_dataset_path "$TRAIN_PATH"
  --val_dataset_path "$VAL_PATH"
  --test_dataset_path "$TEST_PATH"
  --experiment_mode "$EXPERIMENT_MODE"
  --dataset_id_or_path ""
  --dataset_splits "train"
  --per_device_train_batch_size 8
  --per_device_eval_batch_size 8
  --gradient_accumulation_steps 8
  --evaluation_strategy "steps"
  --output_dir "$OUTPUT_PATH"
  --bf16 True
)

# Launch training. The GPU selection and Accelerate log level are set for this
# one command only. The script path is relative, so run this launcher from the
# directory that contains scripts/.
CUDA_VISIBLE_DEVICES=4,5 ACCELERATE_LOG_LEVEL=debug \
  accelerate launch "${launcher_opts[@]}" \
  scripts/python/run_grpo.py "${trainer_opts[@]}"
