#!/bin/bash
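# One-click multi-task training script.
# Config auto-detection (e.g. choosing the offline config when offline data is
# available) is delegated to the Python training script unless --config is given.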

set -e  # abort the whole run as soon as any unguarded command fails

# Run from the directory containing this script, wherever it was invoked from,
# so that relative paths (such as the training script) resolve predictably.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd -- "$SCRIPT_DIR"

printf '%s\n' \
    "==========================================" \
    "Training Multi-Task Classification" \
    "==========================================" \
    "Working directory: $PWD" \
    ""

# Defaults. Each may be overridden by exporting the variable of the same name
# before running, and again by the matching command-line option below.
# For the MAX_*_SAMPLES limits, 0 means "use all samples".
: "${GPU:=0}"
: "${EPOCHS:=20}"
: "${LR:=5e-4}"
: "${BATCH_SIZE:=256}"
: "${MAX_SAMPLES:=0}"
: "${MAX_VAL_SAMPLES:=0}"
: "${MAX_TEST_SAMPLES:=0}"
: "${NUM_WORKERS:=1}"
: "${EXPERIMENT_NAME:=}"   # empty => flag not passed on; name is auto-generated
: "${CONFIG_FILE:=}"       # empty => flag not passed on; config is auto-detected

# Parse command-line options. Every option except --help/-h takes exactly one
# value, which overrides the corresponding environment/default value.
# Errors are reported on stderr and abort with exit status 1.
while [[ $# -gt 0 ]]; do
    case "$1" in
        --gpu|--epochs|--lr|--batch_size|--max_samples|--max_val_samples|--max_test_samples|--num_workers|--experiment_name|--config)
            # Guard against a trailing option with no value: otherwise
            # `shift 2` fails, which either kills the script silently
            # (under `set -e`) or leaves the loop spinning forever.
            if [[ $# -lt 2 ]]; then
                echo "ERROR: option $1 requires a value" >&2
                echo "Use --help for usage information" >&2
                exit 1
            fi
            case "$1" in
                --gpu)              GPU="$2" ;;
                --epochs)           EPOCHS="$2" ;;
                --lr)               LR="$2" ;;
                --batch_size)       BATCH_SIZE="$2" ;;
                --max_samples)      MAX_SAMPLES="$2" ;;
                --max_val_samples)  MAX_VAL_SAMPLES="$2" ;;
                --max_test_samples) MAX_TEST_SAMPLES="$2" ;;
                --num_workers)      NUM_WORKERS="$2" ;;
                --experiment_name)  EXPERIMENT_NAME="$2" ;;
                --config)           CONFIG_FILE="$2" ;;
            esac
            shift 2
            ;;
        --help|-h)
            # The "(default: ...)" values are the current values, so they
            # already reflect environment overrides and any options parsed
            # before --help.
            echo "Usage: $0 [OPTIONS]"
            echo ""
            echo "Options:"
            echo "  --gpu GPU_ID              GPU ID to use (default: $GPU)"
            echo "  --epochs N                Number of epochs (default: $EPOCHS)"
            echo "  --lr LEARNING_RATE        Learning rate (default: $LR)"
            echo "  --batch_size N            Batch size (default: $BATCH_SIZE)"
            echo "  --max_samples N           Max training samples, 0 for all (default: $MAX_SAMPLES)"
            echo "  --max_val_samples N       Max validation samples, 0 for all (default: $MAX_VAL_SAMPLES)"
            echo "  --max_test_samples N      Max test samples, 0 for all (default: $MAX_TEST_SAMPLES)"
            echo "  --num_workers N           Number of data loader workers (default: $NUM_WORKERS)"
            echo "  --experiment_name NAME    Experiment name (default: auto-generated)"
            echo "  --config PATH             Config file path (default: auto-detect)"
            echo "  --help, -h                Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1" >&2
            echo "Use --help for usage information" >&2
            exit 1
            ;;
    esac
done

# Pre-flight checks: fail fast, with the diagnostic on stderr, before any
# environment is changed or training is launched.

# The training script is looked up relative to the current working directory
# (the directory this wrapper lives in).
TRAIN_SCRIPT="train_mortality_los_complete.py"
if [[ ! -f "$TRAIN_SCRIPT" ]]; then
    echo "ERROR: Training script not found: $TRAIN_SCRIPT" >&2
    exit 1
fi

# The training script is launched with python3, so it must be on PATH.
if ! command -v python3 &> /dev/null; then
    echo "ERROR: python3 not found in PATH" >&2
    exit 1
fi

# Expose only the requested device(s) to CUDA, and enable expandable segments
# in the PyTorch caching allocator (intended to reduce memory fragmentation;
# see the PyTorch CUDA memory-management docs).
# NOTE(review): the same id is also forwarded as --gpu below. Once
# CUDA_VISIBLE_DEVICES is set, the selected device is renumbered as index 0
# inside the process. Confirm the training script does not additionally select
# cuda:<gpu>, which would fail for any GPU other than 0 — TODO verify.
export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True'
export CUDA_VISIBLE_DEVICES="$GPU"

# Collect the training-script arguments in an array so that values containing
# spaces survive intact. Optional flags are appended only when non-empty.
CMD_ARGS=(
    --gpu "$GPU" --epochs "$EPOCHS" --lr "$LR" --batch_size "$BATCH_SIZE"
    --max_samples "$MAX_SAMPLES"
    --max_val_samples "$MAX_VAL_SAMPLES"
    --max_test_samples "$MAX_TEST_SAMPLES"
    --num_workers "$NUM_WORKERS"
)
[[ -z "$EXPERIMENT_NAME" ]] || CMD_ARGS+=(--experiment_name "$EXPERIMENT_NAME")
[[ -z "$CONFIG_FILE" ]] || CMD_ARGS+=(--config "$CONFIG_FILE")

# Print the effective configuration so the run log records what was used.
echo "Configuration:"
echo "  GPU: $GPU"
echo "  Epochs: $EPOCHS"
echo "  Learning Rate: $LR"
echo "  Batch Size: $BATCH_SIZE"
echo "  Max Training Samples: $MAX_SAMPLES"
echo "  Max Validation Samples: $MAX_VAL_SAMPLES"
echo "  Max Test Samples: $MAX_TEST_SAMPLES"
echo "  Num Workers: $NUM_WORKERS"
echo "  Experiment Name: ${EXPERIMENT_NAME:-auto-generated}"
if [ -n "$CONFIG_FILE" ]; then
    echo "  Config File: $CONFIG_FILE"
else
    echo "  Config File: auto-detect (will use offline config if offline data available)"
fi
echo ""

# Show the GPU(s) selected via --gpu/GPU, if the NVIDIA tools are installed.
# nvidia-smi enumerates physical GPUs and ignores CUDA_VISIBLE_DEVICES, so the
# requested id(s) are queried explicitly; taking the first output line would
# always report GPU 0. The query is informational only, so a failure (e.g. an
# invalid id) prints a warning instead of aborting the run under `set -e`.
if command -v nvidia-smi &> /dev/null; then
    echo "GPU Information:"
    if ! nvidia-smi --id="$GPU" --query-gpu=index,name,memory.total --format=csv,noheader; then
        echo "WARNING: could not query GPU '$GPU' via nvidia-smi" >&2
    fi
    echo ""
fi

# Run training.
# The exit status is captured explicitly: under `set -e` a bare failing
# `python3 ...` would terminate this script on the spot, and the failure
# summary below would otherwise be unreachable.
echo "Starting training..."
# %q shell-quotes each argument so the logged command can be copy-pasted as-is
# (e.g. an experiment name containing spaces stays a single argument).
printf 'Command: python3'
printf ' %q' "$TRAIN_SCRIPT" "${CMD_ARGS[@]}"
printf '\n'
echo ""

EXIT_CODE=0
python3 "$TRAIN_SCRIPT" "${CMD_ARGS[@]}" || EXIT_CODE=$?

if [[ $EXIT_CODE -eq 0 ]]; then
    echo ""
    echo "=========================================="
    echo "Training completed successfully!"
    echo "=========================================="
else
    echo ""
    echo "=========================================="
    echo "Training failed with exit code: $EXIT_CODE"
    echo "=========================================="
    # Propagate the trainer's status to whoever invoked this wrapper.
    exit "$EXIT_CODE"
fi

