#!/bin/bash
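# TODO: Core comparison: focal loss, epi_only task, random split
# NOTE(review): the embedded config below does not match this TODO yet — it sets
# loss.node_prediction.name="bce" and dataset.split.method="epitope_ratio".
# Generated individual experiment script with embedded parameters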
# Print the command-line synopsis, the argument list and an example
# invocation to stdout, then terminate the whole script with status 1.
usage() {
    local -a help_text=(
        "Usage: $0 --gpu_id <gpu_id> --batch_size <batch_size> --epochs <epochs> --server <server_name> [--pretrain_epochs <pretrain_epochs>]"
        ""
        "Arguments:"
        "  --gpu_id         GPU ID to use (required)"
        "  --batch_size     Batch size for training (required)"
        "  --epochs         Number of training epochs (required)"
        "  --server         Server name (amai, dice, etc.) (required)"
        "  --pretrain_epochs Number of pretraining epochs (optional, default: 5)"
        ""
        "Example:"
        "  $0 --gpu_id 0 --batch_size 8 --epochs 30 --server amai --pretrain_epochs 10"
        ""
        "Note: Model parameters are embedded in this script."
        "Edit the script directly to modify model configuration for this experiment."
    )
    printf '%s\n' "${help_text[@]}"
    exit 1
}

# Parse command line arguments.
#
# Each recognised flag takes exactly one value, and that value is checked
# before shifting. `shift 2` with fewer than two positional parameters left
# fails WITHOUT shifting anything, so a trailing flag (e.g. `$0 --gpu_id`)
# used to make this loop spin forever.
#
# Sets globals: gpu_id, batch_size, epochs, server, pretrain_epochs
# (each named after its flag without the leading "--").
# Calls usage (which exits 1) on -h/--help, an unknown flag, or a missing value.
parse_args() {
    local flag
    while [[ "$#" -gt 0 ]]; do
        flag="$1"
        case "$flag" in
            -h|--help)
                usage
                ;;
            --gpu_id|--batch_size|--epochs|--server|--pretrain_epochs)
                # Reject a missing value, and treat a following flag as missing
                # too, so `--gpu_id --epochs 3` does not set gpu_id="--epochs".
                if [[ "$#" -lt 2 || "$2" == --* ]]; then
                    echo "Missing value for parameter: $flag"
                    usage
                fi
                # The flag is whitelisted above, so "${flag#--}" is always one
                # of the five known variable names.
                printf -v "${flag#--}" '%s' "$2"
                shift 2
                ;;
            *)
                echo "Unknown parameter: $1"
                usage
                ;;
        esac
    done
}

parse_args "$@"

# Check that all required parameters are provided; usage exits 1 otherwise.
if [[ -z "$gpu_id" || -z "$batch_size" || -z "$epochs" || -z "$server" ]]; then
    echo "Error: All parameters (gpu_id, batch_size, epochs, server) are required."
    usage
fi

# Set default value for pretrain_epochs if not provided
pretrain_epochs=${pretrain_epochs:-5}

# Reject non-integer counts up front. Once the trainer is detached with nohup,
# a bad value would only surface as a failure inside the log file.
# gpu_id is intentionally not checked here.
# NOTE(review): its accepted format depends on trainer.py.
for _count_arg in batch_size epochs pretrain_epochs; do
    if [[ ! "${!_count_arg}" =~ ^[0-9]+$ ]]; then
        echo "Error: --${_count_arg} must be a non-negative integer (got '${!_count_arg}')."
        usage
    fi
done
unset _count_arg

echo "- Server: $server | GPU: $gpu_id | Batch: $batch_size | Epochs: $epochs | Pretrain Epochs: $pretrain_epochs"
echo "- Parameters: Embedded in script (modify script to change model config)"

# Fixed experiment settings passed to the trainer as overrides.
wandb_project="m3epi_v3_dev"
dataset="epiformer_dataset.pkl"

# Run identifier baked in when this script was generated (it is not derived
# from the current time). It also names the log file, which is opened with `>`,
# so re-running this script overwrites the previous run's log.
run_id="raad_20250904-201819"

# W&B notes: server and GPU, followed by a tag that appears to summarise the
# loss setup.
# NOTE(review): the tag says "infonce" while the launch command passes
# loss.contrastive.enabled=false — confirm the tag is intended.
wandb_notes="${server}-gpu${gpu_id}-pre*-*bce-count-reg-no-edge-infonce"

echo "- Run ID: $run_id"
echo "- Wandb Notes: $wandb_notes"
echo "- Using embedded model configuration"

# Create the logs directory. Fail here if that is impossible: otherwise the
# redirection of the background job fails silently, while the script still
# reports a successful start.
if ! mkdir -p logs; then
    echo "Error: could not create logs directory in $(pwd)."
    exit 1
fi

# Execute the experiment.
# Preflight checks: the trainer is detached with nohup, so anything that goes
# wrong after launch is only visible in the log. Catch the obvious failures
# while they can still be reported here.
if [[ ! -f trainer.py ]]; then
    echo "Error: trainer.py not found in $(pwd); run this script from the directory containing trainer.py."
    exit 1
fi
if ! command -v python >/dev/null 2>&1; then
    echo "Error: 'python' not found on PATH."
    exit 1
fi

echo "- Starting training..."

# Launch trainer.py in the background with dotted key=value config overrides.
# stdout and stderr both go to logs/<run_id>_output.log.
#
# num_threads used to be passed twice (num_threads=0 early, num_threads=3
# later). Only the later value is kept, since it is the one that presumably
# took effect (last override wins, as with Hydra-style overrides — confirm).
# NOTE(review): hparams.pretrain.* are passed although
# model.enable_pretraining=false — presumably ignored for this run.
nohup python trainer.py \
    mode=val \
    seed=42 \
    gpu_id="$gpu_id" \
    wandb.project="$wandb_project" \
    dataset.split.method="epitope_ratio" \
    dataset.graph_type="raad-plm" \
    dataset.plm_type="esm2_3b" \
    dataset.graph_num_relations=4 \
    dataset.tensor="$dataset" \
    hparams.train.num_epochs="$epochs" \
    hparams.train.batch_size="$batch_size" \
    hparams.pretrain.num_epochs="$pretrain_epochs" \
    hparams.pretrain.lr=0.0002 \
    hparams.train.learning_rate=0.0003 \
    hparams.train.weight_decay=0.0001013 \
    hparams.train.kfolds=2 \
    hparams.train.regularization.use_l2_reg=false \
    hparams.train.scheduler="reduce_lr_on_plateau" \
    run_id="$run_id" \
    wandb.notes="$wandb_notes" \
    num_threads=3 \
    resume=false \
    model.enable_pretraining=false \
    model.ab_encoder.resmp_type="egnn" \
    model.ab_encoder.residue_layers=4 \
    model.ab_encoder.residue_hidden_dim=128 \
    model.ab_encoder.feature_fusion_type=gated \
    model.ab_encoder.residue_dim=128 \
    model.ab_encoder.plm_dim=128 \
    model.ab_encoder.plm_in_dim=512 \
    model.ag_encoder.resmp_type="egnn" \
    model.ag_encoder.residue_layers=4 \
    model.ag_encoder.residue_hidden_dim=128 \
    model.ag_encoder.feature_fusion_type=concat \
    model.ag_encoder.residue_dim=128 \
    model.ag_encoder.plm_dim=128 \
    model.dropout_rates.decoder=0.2 \
    model.dropout_rates.projections=0.0262 \
    model.dropout=0.2395 \
    model.decoder.type="cross_attention" \
    model.decoder.d_k=64 \
    model.decoder.d_ff=256 \
    model.decoder.d_model=128 \
    model.decoder.n_heads=8 \
    model.decoder.decoder_layers=3 \
    model.decoder.sampling_strat="top_k_mean_2" \
    model.activation="silu" \
    model.use_layer_norm=true \
    callbacks.early_stopping.patience=10 \
    callbacks.checkpoint_interval=2 \
    loss.node_prediction.enabled=true \
    loss.node_prediction.weight=1 \
    loss.node_prediction.name="bce" \
    loss.node_prediction.task="epi_only" \
    loss.node_prediction.bce_weight=4.5030 \
    loss.node_prediction.dice_weight=0.6610 \
    loss.node_prediction.smoothness_weight=0.3521 \
    loss.node_prediction.consistency_weight=0.5807 \
    loss.node_prediction.dice_enabled=true \
    loss.node_prediction.count_regularizer_enabled=true \
    loss.node_prediction.smoothness_enabled=false \
    loss.node_prediction.edge_node_consistency_enabled=false \
    loss.node_prediction.epi_pos_weight=18 \
    loss.node_prediction.para_pos_weight=3 \
    loss.count_regularizer.per_graph_matching=true \
    loss.count_regularizer.epitope_weight=0.1122 \
    loss.count_regularizer.paratope_weight=0.1 \
    loss.count_regularizer.dataset_prior=false \
    loss.count_regularizer.epitope_prior_mean=14.6 \
    loss.count_regularizer.prior_weight=0.05 \
    loss.count_regularizer.anneal_epochs=10 \
    loss.label_smoothing=0.1 \
    loss.class_balance.beta=0.9999 \
    loss.edge_prediction.enabled=false \
    loss.edge_prediction.weight=0.3 \
    loss.edge_prediction.pos_weight=28.2386 \
    loss.edge_count_regularizer.enabled=false \
    loss.edge_count_regularizer.weight=0.1 \
    loss.contrastive.enabled=false \
    loss.contrastive.name="infonce" \
    loss.contrastive.weight=0.5 \
    loss.contrastive.temperature=0.395 \
    loss.contrastive.inter_weight=0.456 \
    loss.contrastive.intra_weight=0.456 \
    loss.gwnce.weight=0.1 \
    loss.gwnce.cut_way=2 \
    loss.gwnce.cut_rate=0.5 \
    model.decoder.predict_distances=false \
    loss.auxiliary_distance.enabled=true \
    loss.auxiliary_distance.weight=0.5 \
    loss.auxiliary_distance.distance_weighting=true \
    loss.auxiliary_distance.class_balancing=true \
    loss.auxiliary_distance.max_distance=20.0 \
    loss.force.enabled=false \
    loss.force.weight=0.01 \
    loss.force.bond_weight=1 \
    loss.force.angle_weight=0.5 \
    loss.force.smooth_alpha=1 \
    loss.force.smooth_weight=0.1 \
    loss.force.bond_tolerance=0.1 \
    loss.force.angle_tolerance=0.1 \
    loss.walle.enabled=false \
    > "logs/${run_id}_output.log" 2>&1 &

# $! is the PID of the background nohup job started above.
pid=$!
echo "- Experiment started successfully (PID: $pid)"
echo "- Monitor with: tail -f logs/${run_id}_output.log"
echo "- Kill with: kill $pid"