#!/bin/bash

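# There are two output directories for each model: "cvxNN_trained_${model_name}" and "finetuned_cvx_${model_name}_inference_ready".
# The output log file is "Cvxdpo_pipeline_sftbase_log_datasplit_${TODAY}.txt" when any selected model name ends in "_sft",
# otherwise "Cvxdpo_pipeline_log_datasplit_${TODAY}.txt" (TODAY is the run date, formatted YYYYMMDD).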

# Weights & Biases credential. A WANDB_API_KEY already present in the caller's
# environment takes precedence; the embedded value is only a fallback so that
# existing invocations keep working unchanged.
# SECURITY(review): a credential committed to source control should be treated
# as leaked -- rotate this key and then drop the hard-coded fallback.
export WANDB_API_KEY="${WANDB_API_KEY:-14eb70e2cb32fb3e6701239a44ae3ccbbfd6b8bf}"

# Models to run through the pipeline, in processing order.
# Enable or disable an entry by uncommenting / commenting its line.
model_names=()

# SFT base models.
model_names+=(
    # "dolphin_imdb_sft"
    # "dolphin_edu_sft"
    # "dolphin_ultra_sft"
    # "llama_edu_sft"
    # "llama_imdb_sft"
    # "llama_ultra_sft"
    # "mistral_edu_sft"
    "mistral_imdb_sft"
    # "mistral_ultra_sft"
    # "distilgpt2_edu_sft"
    # "distilgpt2_imdb_sft"
    # "distilgpt2_ultra_sft"
    # "gpt2_edu_sft"
    # "gpt2_imdb_sft"
    # "gpt2_ultra_sft"
)

# Models without SFT.
model_names+=(
    "dolphin_imdb"
    "dolphin_edu"
    "dolphin_ultra"
    "llama_edu"
    "llama_imdb"
    "llama_ultra"
    "mistral_edu"
    "mistral_imdb"
    "mistral_ultra"
    "distilgpt2_edu"
    "distilgpt2_imdb"
    "distilgpt2_ultra"
    "gpt2_edu"
    "gpt2_imdb"
    "gpt2_ultra"
)

# Run date (YYYYMMDD); it becomes part of the log file name.
TODAY="$(date +%Y%m%d)"

# Python entry points for the two pipeline stages (relative to the working directory).
CRONOS_SCRIPT="cronos_trainer.py"
FINETUNE_SCRIPT="finetune_cvxdpo.py"

# Pick the log file name: start from the plain name and switch to the
# "sftbase" variant as soon as one selected model name ends in "_sft".
has_sft=false
LOG_FILE="Cvxdpo_pipeline_log_datasplit_${TODAY}.txt"
for m in "${model_names[@]}"; do
    case "$m" in
        *_sft)
            has_sft=true
            LOG_FILE="Cvxdpo_pipeline_sftbase_log_datasplit_${TODAY}.txt"
            break
            ;;
    esac
done

#######################################
# Estimate the total compute, in TFLOPs, of a run of the given length.
#
# Uses a fixed throughput of 231000 GFLOP/s (the figure this script assumes
# for an NVIDIA RTX 4090 at 70% bf16 Tensor Core efficiency), i.e. the result
# is duration * 231000 / 1000, truncated to two decimal places.
#
# Arguments: $1 - duration in seconds (non-negative integer or decimal)
# Outputs:   the estimate on stdout, e.g. "2310.00" for 10 seconds
# Returns:   0 on success; 1, with a message on stderr and nothing on stdout,
#            when the duration is missing or not a non-negative number
#######################################
estimate_tflops() {
    local duration=${1:-}
    local -r gflops_per_sec=231000
    local hundredths

    if [[ $duration =~ ^[0-9]+$ ]]; then
        # Whole seconds (what the pipeline passes): builtin integer arithmetic,
        # computed in hundredths of a TFLOP so two decimals can be printed.
        # The 10# prefix stops a leading zero from being read as octal.
        hundredths=$(( gflops_per_sec * 10#$duration / 10 ))
        printf '%d.%02d\n' "$(( hundredths / 100 ))" "$(( hundredths % 100 ))"
    elif [[ $duration =~ ^[0-9]*[.][0-9]+$ ]]; then
        # Fractional seconds need real decimal arithmetic; delegate to bc.
        echo "scale=2; $gflops_per_sec * $duration / 1000" | bc
    else
        printf 'estimate_tflops: invalid duration: "%s"\n' "$duration" >&2
        return 1
    fi
}

# Start a fresh log for this run: truncate any existing file of the same name
# and write the banner, the start time and a blank separator line.
{
    echo "--- CVX-DPO FULL PIPELINE RUN LOG ---"
    echo "Start time: $(date)"
    echo ""
} > "$LOG_FILE"

# Helper script that records one model's end-to-end summary in Weights & Biases.
# Its content is the same for every model, so it is written once, before the
# loop. The quoted 'EOF' delimiter keeps the shell from expanding anything
# inside the Python source.
# Positional arguments: model name, pipeline duration, CRONOS duration,
# fine-tune duration, total TFLOPs, trained model path, fine-tuned model dir.
cat > log_final_results.py << 'EOF'
import wandb
import sys

# Initialize wandb run for final combined metrics
wandb.init(
    project="Neurips_coala", 
    name=f"pipeline_summary_{sys.argv[1]}", 
    config={
        "model_name": sys.argv[1],
        "pipeline_duration": float(sys.argv[2]),
        "cronos_duration": float(sys.argv[3]),
        "finetune_duration": float(sys.argv[4]),
        "total_tflops": float(sys.argv[5]),
        "trained_model_path": sys.argv[6],
        "finetuned_model_dir": sys.argv[7],
    }
)

# Log the metrics
wandb.log({
    "pipeline_duration": float(sys.argv[2]),
    "cronos_duration": float(sys.argv[3]),
    "finetune_duration": float(sys.argv[4]),
    "total_duration": float(sys.argv[3]) + float(sys.argv[4]),
    "total_tflops": float(sys.argv[5]),
})

wandb.finish()
EOF

# Run both pipeline stages for every selected model. A model whose CRONOS stage
# fails, or leaves no results file / trained model behind, is skipped; a failed
# fine-tune is recorded but the timing summary is still written.
for model_name in "${model_names[@]}"; do
    echo
    echo
    echo "----------------------------------------"
    echo "Processing model: $model_name"
    echo "----------------------------------------"

    echo "[$model_name] Start time: $(date)" >> "$LOG_FILE"
    pipeline_start_time=$(date +%s)

    # Directory where the CRONOS results file and trained model are looked up.
    output_dir="/home/miria/CVXDPO/cvxNN/cvxNN_trained_${model_name}"

    # ---- Stage 1: CRONOS training ----
    echo "Running CRONOS trainer for $model_name..."
    cronos_start_time=$(date +%s)

    if ! python "$CRONOS_SCRIPT" --model_name "$model_name"; then
        echo "[$model_name] ERROR: CRONOS training failed!" >> "$LOG_FILE"
        echo "[$model_name] ERROR: CRONOS training failed!"
        continue  # Skip to next model
    fi

    cronos_end_time=$(date +%s)
    cronos_duration=$((cronos_end_time - cronos_start_time))
    cronos_tflops=$(estimate_tflops "$cronos_duration")

    # Append CRONOS results to log
    cronos_results_file="${output_dir}/cronos_results.txt"
    if [[ -f "$cronos_results_file" ]]; then
        {
            echo "[$model_name] CRONOS Results:"
            cat "$cronos_results_file"
            echo "[$model_name] CRONOS Training time: $cronos_duration seconds"
            echo "[$model_name] CRONOS Estimated TFLOPS: $cronos_tflops"
        } >> "$LOG_FILE"
    else
        {
            echo "[$model_name] CRONOS Results: Not Found"
            echo "[$model_name] ERROR: CRONOS results file not found!"
            echo "[$model_name] ERROR: Expected file at: $cronos_results_file"
        } >> "$LOG_FILE"
        # Also report on the console, like the other skip conditions do.
        echo "[$model_name] ERROR: CRONOS results file not found!"
        continue  # Skip to next model if no results
    fi

    # Trained model file expected inside the output directory.
    model_path="${output_dir}/${model_name}_trained_cvx_mlp.pkl"
    if [[ ! -f "$model_path" ]]; then
        echo "[$model_name] ERROR: Trained model not found at $model_path" >> "$LOG_FILE"
        echo "[$model_name] ERROR: Trained model not found at $model_path"
        continue  # Skip to next model
    fi

    # ---- Stage 2: CVX-DPO fine-tuning ----
    echo "Running CVX-DPO fine-tuning for $model_name..."
    # Remember how long the log is before this model's fine-tune output is
    # appended, so the output-dir lookup below only sees this model's lines
    # and can never pick up a marker left by an earlier model.
    finetune_log_offset=$(wc -l < "$LOG_FILE")
    finetune_start_time=$(date +%s)
    (
        echo "========== [FINE-TUNE LOG: $model_name] =========="
        finetune_rc=0
        python "$FINETUNE_SCRIPT" --model_path "$model_path" \
            --cronos_training_time "$cronos_duration" \
            --cronos_tflops "$cronos_tflops" || finetune_rc=$?
        if (( finetune_rc != 0 )); then
            # tee sends this to both the console and the log; a separate
            # append to the log would record the line twice.
            echo "[$model_name] ERROR: Fine-tuning failed!"
        fi
        echo "========== [END FINE-TUNE LOG: $model_name] =========="
        # Hand the Python exit status to the parent shell via PIPESTATUS.
        exit "$finetune_rc"
    ) | tee -a "$LOG_FILE"
    finetune_status=${PIPESTATUS[0]}
    finetune_end_time=$(date +%s)
    finetune_duration=$((finetune_end_time - finetune_start_time))
    finetune_tflops=$(estimate_tflops "$finetune_duration")

    # Extract the fine-tuned model directory from this model's part of the log:
    # last line carrying the [FINETUNE_OUTPUT_DIR] marker, marker removed,
    # surrounding whitespace trimmed. Empty when no such line was printed.
    finetuned_output_dir=$(tail -n "+$((finetune_log_offset + 1))" "$LOG_FILE" \
        | grep '\[FINETUNE_OUTPUT_DIR\]' \
        | tail -n1 \
        | sed 's/\[FINETUNE_OUTPUT_DIR\]//g' \
        | xargs)

    # Timing and TFLOPs usage for the entire pipeline
    pipeline_end_time=$(date +%s)
    pipeline_duration=$((pipeline_end_time - pipeline_start_time))
    total_duration=$((cronos_duration + finetune_duration))
    # The TFLOPs figures carry decimals, so this sum needs bc.
    total_tflops=$(echo "scale=2; $cronos_tflops + $finetune_tflops" | bc)
    pipeline_duration_min=$((pipeline_duration / 60))

    {
        echo "[$model_name] End time: $(date)"
        echo "[$model_name] Pipeline duration: $pipeline_duration seconds (~${pipeline_duration_min} min)"
        echo "[$model_name] CRONOS duration: $cronos_duration seconds"
        echo "[$model_name] Finetune duration: $finetune_duration seconds"
        echo "[$model_name] Finetune exit status: $finetune_status"
        echo "[$model_name] Total compute duration: $total_duration seconds"
        echo "[$model_name] Total TFLOPS used: $total_tflops"
        echo "[$model_name] Trained model: $model_path"
        echo "[$model_name] Finetuned model dir: $finetuned_output_dir"
        echo "----------------------------------------"
        echo ""
    } >> "$LOG_FILE"

    # Log complete end-to-end results to wandb; a failure here is reported
    # but does not stop the remaining models from being processed.
    if ! python log_final_results.py "$model_name" "$pipeline_duration" \
            "$cronos_duration" "$finetune_duration" "$total_tflops" \
            "$model_path" "$finetuned_output_dir"; then
        echo "[$model_name] WARNING: Logging final results to Weights & Biases failed!" \
            | tee -a "$LOG_FILE"
    fi
done

# Record the completion time in the log and tell the user where the summaries
# went. The project name printed here matches the one the generated logging
# script passes to wandb.init().
echo "Pipeline completed at: $(date)" >> "$LOG_FILE"
echo "All results have been logged to Weights & Biases project: Neurips_coala"

# Clean up temporary script
rm -f -- log_final_results.py