#!/bin/bash

# Train Crystal Classifier
# This script trains a transformer model to classify whether crystal strings are permuted or not

# list_models: print the catalogue of model-size presets to stdout, one per
# line, each with the model identifier and approximate parameter count.
# Takes no arguments.
list_models() {
    printf '%s\n' \
        "Available model sizes:" \
        "  nano: microsoft/DialoGPT-tiny (~33M parameters) - SMALLEST" \
        "  micro: distilbert-base-uncased (~66M parameters)" \
        "  tiny: bert-base-uncased (~110M parameters)" \
        "  small: roberta-base (~125M parameters)" \
        "  medium: roberta-large (~355M parameters)" \
        "  large: microsoft/DialoGPT-medium (~774M parameters)" \
        "  custom: Use --model-name to specify custom model"
}

# If the first argument is --list-models, print the model catalogue and
# exit successfully without starting a training run. Any other first
# argument (or none) falls through to the rest of the script.
case "$1" in
    --list-models)
        list_models
        exit 0
        ;;
esac

# Configuration.
# Each setting keeps its built-in default unless an environment variable of
# the same name is already set to a non-empty value, so a run can be tweaked
# without editing this file, e.g.:
#   MODEL_SIZE=small NUM_EPOCHS=3 bash <this script>

# Experiment name; the default embeds a timestamp taken when the script starts.
RUN_NAME="${RUN_NAME:-crystal_permutation_classifier_$(date +%Y%m%d_%H%M%S)}"

# Data path (override or modify this to point to your data)
DATA_PATH="${DATA_PATH:-data/basic}"

# Model selection (choose from: nano, micro, tiny, small, medium, large, custom)
MODEL_SIZE="${MODEL_SIZE:-nano}"

# Training parameters
BATCH_SIZE="${BATCH_SIZE:-16}"
MAX_LENGTH="${MAX_LENGTH:-512}"
LEARNING_RATE="${LEARNING_RATE:-2e-5}"
NUM_EPOCHS="${NUM_EPOCHS:-10}"

echo "Starting crystal classification training..."
echo "Run name: $RUN_NAME"
echo "Data path: $DATA_PATH"
echo "Model size: $MODEL_SIZE"

# Run the training.
# The exit status is checked explicitly: if the Python process fails, report
# it on stderr and exit with the same code instead of printing the success
# message and returning 0 to the caller.
python crystal_classifier.py \
    --run-name "$RUN_NAME" \
    --data-path "$DATA_PATH" \
    --model-size "$MODEL_SIZE" \
    --batch-size "$BATCH_SIZE" \
    --max-length "$MAX_LENGTH" \
    --learning-rate "$LEARNING_RATE" \
    --num-epochs "$NUM_EPOCHS" \
    --use-amp || {
    status=$?
    echo "Training failed with exit code $status" >&2
    exit "$status"
}

echo "Training completed!"
echo "Check wandb for training logs and metrics"