#!/bin/bash
#
# GPU 2: Multi-Model Hyperfitting (Generalization Study)
# Tests hyperfitting across different model families
# Expected time: ~8-12 hours total on 4090
#
# Usage: CUDA_VISIBLE_DEVICES=1 ./scripts/gpu2_multi_model.sh
#

# Abort the script on the first unhandled command failure.
set -o errexit

# Startup banner.
printf '%s\n' \
    "============================================================" \
    "GPU 2: Multi-Model Hyperfitting Study" \
    "============================================================"

# Settings shared by every model run.
# NUM_SAMPLES, NUM_EPOCHS and LEARNING_RATE are forwarded to the trainer by
# run_model. SEQ_LENGTH and BATCH_SIZE are not referenced anywhere else in
# this script: each MODELS entry carries its own batch size and sequence
# length, and those per-model values are the ones actually used.
NUM_SAMPLES=2000
SEQ_LENGTH=256
NUM_EPOCHS=20        # fewer epochs, because several models are trained
LEARNING_RATE=1e-6
BATCH_SIZE=8         # smaller batch for larger models

# Models to study. Adjust the list to whatever is available and fits in VRAM.
#
# Every entry is a single string with four '|'-separated fields:
#     "model_name|short_name|batch_size|seq_length"
#
# All entries are small base (causal, pre-trained) checkpoints. The larger
# ones use a smaller batch size and sequence length so they fit in VRAM.
# The remarks about training data below are carried over from the original
# notes and have not been verified here.
MODELS=()

# Qwen 2.5 1.5B (base): reportedly trained on ~18T tokens, a large step up
# from Qwen2; likely the strongest 1.5B base model for distribution experiments.
MODELS+=("Qwen/Qwen2.5-1.5B|qwen2.5_1.5b_base|8|256")

# Gemma 2 2B (base): reportedly distilled from a larger model, with a latent
# space quite different from Gemma 1; a recommended hyperfitting test case.
MODELS+=("google/gemma-2-2b|gemma2_2b_base|1|128")

# SmolLM2 1.7B (base): reportedly trained on 11T tokens and aimed at beating
# Llama 3.2 1B and Qwen 2.5 1.5B as a base model. Disabled; uncomment to run.
# MODELS+=("HuggingFaceTB/SmolLM2-1.7B|smollm2_1.7b_base|8|256")

# Llama 3.2 3B (base): only if the hardware can hold a 3B model; noticeably
# more capable than the 1B variant.
MODELS+=("meta-llama/Llama-3.2-3B|llama3.2_3b_base|1|128")

# StableLM 2 1.6B (base), optional: a comparison point with a standard
# Phi/Llama-like architecture trained on a different data mixture.
# Disabled; uncomment to run.
# MODELS+=("stabilityai/stablelm-2-1.6b|stablelm2_1.6b_base|8|256")

#######################################
# Hyperfit one model and run the evaluation experiments on the result.
#
# Steps: create the output directories, check that the model can be loaded,
# train with src/hyperfitting_trainer.py, evaluate with
# src/run_experiments.py, then print the first and last training loss.
#
# Globals:
#   NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE (read)
# Arguments:
#   $1 - model name, passed to from_pretrained() and to both Python scripts
#   $2 - short name used to build the checkpoint and results directories
#   $3 - training batch size
#   $4 - training sequence length
# Outputs:
#   Progress messages on stdout.
# Returns:
#   0 on success; 1 if the output directories cannot be created, the model
#   cannot be loaded, training fails, or the experiments fail.
#######################################
run_model() {
    local MODEL_NAME=$1
    local SHORT_NAME=$2
    local BATCH=$3
    local MODEL_SEQ_LENGTH=$4

    local SAVE_DIR="./checkpoints/${SHORT_NAME}_hyperfitted"
    local RESULTS_DIR="./results/${SHORT_NAME}"
    local HISTORY_FILE="${SAVE_DIR}/training_history.json"

    echo "Using sequence length $MODEL_SEQ_LENGTH for $SHORT_NAME"

    # Quoted so that a short name containing spaces yields one directory.
    mkdir -p -- "$SAVE_DIR" "$RESULTS_DIR" || {
        echo "Could not create output directories for $MODEL_NAME"
        return 1
    }

    echo ""
    echo "============================================================"
    echo "Processing: $MODEL_NAME"
    echo "============================================================"
    echo "[$(date)] Starting..."

    # Check if model can be loaded.
    # The model name is handed over as an argument (sys.argv[1]) instead of
    # being pasted into the Python source, so quotes in the name cannot break
    # or alter the snippet. Testing the command directly with `if !` keeps the
    # failure branch reachable even when errexit is active for this function.
    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code; keep MODELS restricted to trusted sources.
    echo "Testing model access..."
    if ! python -c '
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
name = sys.argv[1]
try:
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
    print(f"✓ Model loaded successfully: {sum(p.numel() for p in model.parameters())/1e9:.2f}B parameters")
    del model
    torch.cuda.empty_cache()
except Exception as e:
    print(f"✗ Failed to load model: {e}")
    sys.exit(1)
' "$MODEL_NAME"; then
        echo "Skipping $MODEL_NAME - failed to load"
        return 1
    fi

    # Run hyperfitting with gradient checkpointing enabled to reduce VRAM use.
    echo "[$(date)] Training $SHORT_NAME..."
    python src/hyperfitting_trainer.py \
        --model_name "$MODEL_NAME" \
        --num_samples "$NUM_SAMPLES" \
        --sequence_length "$MODEL_SEQ_LENGTH" \
        --num_epochs "$NUM_EPOCHS" \
        --learning_rate "$LEARNING_RATE" \
        --batch_size "$BATCH" \
        --save_dir "$SAVE_DIR" \
        --torch_dtype bfloat16 \
        --gradient_checkpointing || {
            echo "Training failed for $MODEL_NAME"
            return 1
        }

    # Run experiments
    echo "[$(date)] Running experiments for $SHORT_NAME..."
    python src/run_experiments.py \
        --mode experiments \
        --original_model "$MODEL_NAME" \
        --hyperfitted_model "${SAVE_DIR}/final" \
        --num_eval_samples 100 \
        --output_dir "$RESULTS_DIR" \
        --torch_dtype bfloat16 || {
            echo "Experiments failed for $MODEL_NAME"
            return 1
        }

    echo "[$(date)] Completed $SHORT_NAME!"

    # Print quick summary (best effort: a missing or unreadable history file
    # must not turn a successful run into a failure).
    echo ""
    echo "--- $SHORT_NAME Summary ---"
    if [[ ! -f "$HISTORY_FILE" ]] || ! python -c '
import json, sys
losses = json.load(sys.stdin)["train_loss"]
print(f"  Loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
' < "$HISTORY_FILE"; then
        echo "  (training history unavailable)"
    fi

    # No explicit GPU cleanup step: every python invocation above is its own
    # process, so a separate process calling torch.cuda.empty_cache() would
    # have no cached memory of its own to release.

    return 0
}

# Drive run_model over every configured entry and sort the short names into
# SUCCESSFUL / FAILED according to its exit status.
printf 'Will process %s models\n\n' "${#MODELS[@]}"

SUCCESSFUL=()
FAILED=()

for entry in "${MODELS[@]}"; do
    # Split "model_name|short_name|batch_size|seq_length" into its fields.
    IFS='|' read -r cfg_model cfg_short cfg_batch cfg_seq <<< "$entry"

    # Running run_model on the left of || records its status without letting
    # errexit abort the whole study when a single model fails.
    status=0
    run_model "$cfg_model" "$cfg_short" "$cfg_batch" "$cfg_seq" || status=$?

    case "$status" in
        0) SUCCESSFUL+=("$cfg_short") ;;
        *) FAILED+=("$cfg_short") ;;
    esac

    echo ""
    sleep 5  # Brief pause between models
done

#######################################
# Print the path of the most recently modified
# <results_dir>/run_*/temperature_matching.json.
#
# A model that has been evaluated more than once has several run_*
# directories; passing the raw glob to `[ -f ... ]` or `cat` then breaks, so
# exactly one file is selected here.
#
# Arguments:
#   $1 - results directory of one model
# Outputs:
#   The selected path on stdout; nothing when no file matches.
# Returns:
#   0 if a file was found, 1 otherwise.
#######################################
latest_temperature_matching() {
    local results_dir=$1
    local candidate
    local newest=""

    for candidate in "$results_dir"/run_*/temperature_matching.json; do
        # An unmatched glob is left as the literal pattern; -f filters it out.
        [[ -f "$candidate" ]] || continue
        if [[ -z "$newest" || "$candidate" -nt "$newest" ]]; then
            newest=$candidate
        fi
    done

    [[ -n "$newest" ]] || return 1
    printf '%s\n' "$newest"
}

# Final summary
echo ""
echo "============================================================"
echo "MULTI-MODEL STUDY COMPLETE"
echo "============================================================"
echo ""
echo "Successful: ${SUCCESSFUL[*]}"
echo "Failed: ${FAILED[*]}"
echo ""
echo "Results saved in ./results/<model_name>/"

# Generate comparison table: original vs. hyperfitted mean TTR per model.
# Models without a temperature_matching.json are skipped silently.
echo ""
echo "--- Quick Comparison ---"
for short_name in "${SUCCESSFUL[@]}"; do
    if result_file=$(latest_temperature_matching "./results/${short_name}"); then
        echo "$short_name:"
        # Best effort: stderr is discarded on purpose so an incomplete or
        # differently shaped JSON file shows up as "(results pending)".
        python -c '
import json, sys
data = json.load(sys.stdin)
agg = data["results"]["generation_comparison"]["aggregated"]
original_ttr = agg["original_greedy"]["mean_ttr"]
hyperfitted_ttr = agg["hyperfitted_greedy"]["mean_ttr"]
print(f"  Original TTR: {original_ttr:.3f}")
print(f"  Hyperfitted TTR: {hyperfitted_ttr:.3f}")
print(f"  Improvement: {hyperfitted_ttr - original_ttr:.3f}")
' < "$result_file" 2>/dev/null || echo "  (results pending)"
    fi
done
