#!/bin/bash

# HEdit Training Script
# This script trains the KV correction MLP model

# Abort on command failures, unset variables, and failed pipeline stages so a
# broken run is never reported as a successful one.
set -euo pipefail

# ========================================
# Configuration - MODIFY THESE PATHS
# ========================================

# Data configuration
DATA_DIR="./data/train_mlp"  # Path to your training data directory
DATASET_NAMES="dataset1 dataset2"  # Space-separated dataset names
LAYER_IDX=40  # Layer index to use for training

# Model configuration
INPUT_DIM=7168  # hidden_dim + 2*kv_dim (e.g., 5120 + 2048)
OUTPUT_DIM=2048  # 2*kv_dim (e.g., 2 * 1024)
HIDDEN_DIM1=2048
HIDDEN_DIM2=1024
DROPOUT=0.1

# Training configuration
BATCH_SIZE=16
NUM_EPOCHS=100
LEARNING_RATE=1e-4
WEIGHT_DECAY=1e-5
TRAIN_RATIO=0.8

# Output configuration
SAVE_DIR="./checkpoints/layer_${LAYER_IDX}"

# Device configuration
DEVICE="cuda"  # Use "cpu" if no GPU available
SEED=42

# Interpreter and entry point. PYTHON may be overridden from the environment
# (e.g. PYTHON=python3 ./train.sh). The entry point path is relative, so this
# script must be run from the directory that contains examples/.
PYTHON="${PYTHON:-python}"
TRAIN_SCRIPT="examples/demo_training.py"

# ========================================
# Run Training
# ========================================

# Print an error message to stderr and exit non-zero.
die() {
    printf 'Error: %s\n' "$*" >&2
    exit 1
}

# Fail fast with a clear message instead of an obscure Python traceback.
command -v "${PYTHON}" >/dev/null 2>&1 || die "python interpreter not found: ${PYTHON}"
[[ -f "${TRAIN_SCRIPT}" ]] || die "training script not found: ${TRAIN_SCRIPT} (run from the repository root)"
[[ -d "${DATA_DIR}" ]] || die "data directory not found: ${DATA_DIR}"

# Split the space-separated dataset list into an array. Each name is still
# passed as a separate word after --dataset_names, as before, but without
# accidental glob expansion.
read -r -a dataset_names <<< "${DATASET_NAMES}"
(( ${#dataset_names[@]} > 0 )) || die "DATASET_NAMES is empty"

echo "=========================================="
echo "HEdit Training"
echo "=========================================="
echo "Data directory: ${DATA_DIR}"
echo "Datasets: ${DATASET_NAMES}"
echo "Layer index: ${LAYER_IDX}"
echo "Save directory: ${SAVE_DIR}"
echo "=========================================="

# Capture the exit status explicitly so a failure is reported and propagated
# instead of falling through to the success banner.
status=0
"${PYTHON}" "${TRAIN_SCRIPT}" \
    --data_dir "${DATA_DIR}" \
    --dataset_names "${dataset_names[@]}" \
    --layer_idx "${LAYER_IDX}" \
    --input_dim "${INPUT_DIM}" \
    --output_dim "${OUTPUT_DIM}" \
    --hidden_dim1 "${HIDDEN_DIM1}" \
    --hidden_dim2 "${HIDDEN_DIM2}" \
    --dropout "${DROPOUT}" \
    --batch_size "${BATCH_SIZE}" \
    --num_epochs "${NUM_EPOCHS}" \
    --learning_rate "${LEARNING_RATE}" \
    --weight_decay "${WEIGHT_DECAY}" \
    --train_ratio "${TRAIN_RATIO}" \
    --save_dir "${SAVE_DIR}" \
    --device "${DEVICE}" \
    --seed "${SEED}" || status=$?

if (( status != 0 )); then
    printf 'Error: training failed with exit status %d\n' "${status}" >&2
    exit "${status}"
fi

echo ""
echo "=========================================="
echo "Training completed!"
echo "Model saved to: ${SAVE_DIR}/best_model.pt"
echo "=========================================="
