#!/bin/bash

# ==================== Configurable Parameters ====================
# Every scalar setting below can be overridden from the environment, e.g.
#   DATASET_NAME=cifar10 NORM_TYPE=l2 bash <script>
# The values shown after ":-" are the defaults used when the variable is unset
# or empty. MODEL_NAMES is an array and must be edited here.
export CUDA_VISIBLE_DEVICES=0,1,2,3
BENCHMARK_NAME="${BENCHMARK_NAME:-clip_benchmark}"      # Name for your benchmark run
DATASET_NAME="${DATASET_NAME:-imagenet100}"             # Options: imagenet, imagenet100, cifar10, cifar100
MODEL_NAMES=("ViT-B-32")                                # Model architectures to evaluate - used for batch size selection
MODELS_FILE="${MODELS_FILE:-models_to_benchmark.txt}"   # Text file containing paths to model checkpoints (one per line)
ATTACK_TYPE="${ATTACK_TYPE:-apgd}"                      # Attack type (default: apgd)
NORM_TYPE="${NORM_TYPE:-linf}"                          # Norm for adversarial examples (options: linf, l2)
ITERATIONS="${ITERATIONS:-50}"                          # Number of attack iterations
NUM_WORKERS="${NUM_WORKERS:-20}"                        # Number of data loading workers

# Check if models file exists
# MODELS_FILE must be a regular file; it is read again at the end of the
# script to list the evaluated checkpoints in the README. Diagnostics go to
# stderr so they are not mixed into captured stdout.
if [ ! -f "$MODELS_FILE" ]; then
    {
        echo "Error: Models file '$MODELS_FILE' not found."
        echo "Please create a text file with one model checkpoint path per line."
    } >&2
    exit 1
fi

# ==================== Dataset Paths ====================
# Map DATASET_NAME to its on-disk root (replace /YOUR_ROOT_PATH with the local
# data root). An unrecognised name aborts the script with exit status 1.
case "$DATASET_NAME" in
    imagenet)
        DATASET_ROOT="/YOUR_ROOT_PATH/data/ILSVRC/Data/CLS-LOC"
        ;;
    imagenet100)
        DATASET_ROOT="/YOUR_ROOT_PATH/data/imagenet100/data"
        ;;
    cifar10|cifar100)
        # Both CIFAR variants share one cache directory.
        # NOTE(review): machine-specific absolute path, unlike the placeholder
        # paths above — adjust for your environment.
        DATASET_ROOT="/mnt/raid10/ak-research-01/ak-research-01/codes/.cache/cifar"
        ;;
    *)
        {
            echo "Unknown dataset name: $DATASET_NAME"
            echo "Valid options: imagenet, imagenet100, cifar10, cifar100"
        } >&2
        exit 1
        ;;
esac

# ==================== Create Results Directory ====================
# All outputs (logs/, summary.csv, plots/, README.md, benchmark.log) live
# under RESULTS_DIR.
BASE_RESULTS_DIR="benchmark_results"

# Timestamp of this invocation; recorded in the main log header.
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# NOTE(review): a per-run directory would be
# "${BASE_RESULTS_DIR}/${BENCHMARK_NAME}_${TIMESTAMP}", but this script
# defaults to a fixed directory so the summary step below reads logs that
# already exist there. Export RESULTS_DIR to summarize a different run.
RESULTS_DIR="${RESULTS_DIR:-${BASE_RESULTS_DIR}/clip_benchmark_convnext}"

# Per-model/epsilon log files are read from here by the summary step.
LOGS_DIR="${RESULTS_DIR}/logs"

# `mkdir -p` on the deepest path also creates RESULTS_DIR and its parents.
if ! mkdir -p -- "$LOGS_DIR"; then
    echo "Error: cannot create results directory '$LOGS_DIR'" >&2
    exit 1
fi

# Main log file
LOG_FILE="${RESULTS_DIR}/benchmark.log"

# ==================== Main Script ====================
# Record the run configuration in the main log and echo it to the console.
# The echoes are grouped so one `tee` appends them all to LOG_FILE.
{
    echo "Starting benchmark: $BENCHMARK_NAME"
    echo "Dataset: $DATASET_NAME"
    echo "Models file: $MODELS_FILE"
    echo "Timestamp: $TIMESTAMP"
    echo "Results directory: $RESULTS_DIR"
    echo "========================================"
} | tee -a "$LOG_FILE"

# NOTE(review): this script runs no evaluation commands between the header
# above and this status line; confirm the benchmark runs happen elsewhere
# before relying on this message.
echo "All benchmarks completed successfully!" | tee -a "$LOG_FILE"
echo "Results saved to $RESULTS_DIR"

# Generate summary file with key metrics
# One CSV row per per-epsilon log file in LOGS_DIR; metrics that a log does
# not contain become empty cells.
SUMMARY_FILE="$RESULTS_DIR/summary.csv"
echo "Model,Checkpoint,TrainEps,Rho,Epsilon,Clean Accuracy,Robust Accuracy,Robustness Drop,Robustness Retention,L2_adv_clean,Cosine_sim_adv_clean,L2_adv_orig,L2_clean_orig,L2_clean_orig_normalized,Cosine_sim_adv_orig,Cosine_sim_clean_orig,L2_orig_adv_orig_clean,Cosine_sim_orig_adv_orig_clean,Attack Time" > "$SUMMARY_FILE"

# log_field LOG LABEL FIELD [EXCLUDE]
#   Print whitespace-separated field FIELD (a field number, or "NF" for the
#   last field) of the FIRST line in LOG that contains the literal text LABEL
#   and, when EXCLUDE is given, does not also contain EXCLUDE. Prints nothing
#   when no line matches. Only the first match is used, so a label that
#   appears twice cannot inject a newline into the CSV row.
log_field() {
    awk -v label="$2" -v fld="$3" -v skip="${4:-}" '
        index($0, label) && (skip == "" || !index($0, skip)) {
            v = (fld == "NF") ? $NF : $fld
            print v
            exit
        }' "$1"
}

# Filename patterns for training eps ("eps<num>") and rho. Rho is accepted
# both as "rho<num>" (the documented naming) and as "rho_<num>".
TRAIN_EPS_RE='eps([0-9.]+)'
RHO_RE='rho_?([0-9.]+)'

# Extract key metrics from all log files
for LOG in "$LOGS_DIR"/*_eps*.log; do
    # Without nullglob an unmatched pattern is passed through literally;
    # skip it instead of emitting a row of empty metrics.
    [ -e "$LOG" ] || continue

    # Extract model name and parameters from filename
    # Example: ViT-B-32_ViT-B-32_openai_imagenet100_l2_pgd_eps1_rho0.1_iygCq_ViT-B-32_openai_imagenet100_l2_pgd_eps1_rho0.1_iygCq.pt_eps2.log
    NAME=${LOG##*/}
    MODEL=${NAME%%_*}                    # first "_"-separated field
    REST=${NAME#*_}
    CHECKPOINT="${MODEL}_${REST%%_*}"    # first two "_"-separated fields

    # Training eps / rho: first occurrence in the file name. If the name
    # carries no training eps, the first match is the evaluation-eps suffix.
    TRAIN_EPS=""
    if [[ $NAME =~ $TRAIN_EPS_RE ]]; then
        TRAIN_EPS=${BASH_REMATCH[1]}
    fi
    RHO=""
    if [[ $NAME =~ $RHO_RE ]]; then
        RHO=${BASH_REMATCH[1]}
    fi

    # Evaluation epsilon: the last "_" field, e.g. "eps2.log" -> "2".
    EPS=${NAME##*_}
    EPS=${EPS#eps}
    EPS=${EPS%.log}

    # Extract accuracy metrics (third field, e.g. "Clean Accuracy: 45.2%")
    CLEAN_ACC=$(log_field "$LOG" "Clean Accuracy:" 3)
    ROBUST_ACC=$(log_field "$LOG" "Robust Accuracy:" 3)
    ROBUST_DROP=$(log_field "$LOG" "Robustness Drop:" 3)
    ROBUST_RETENTION=$(log_field "$LOG" "Robustness Retention:" 3)
    CLEAN_ACC=${CLEAN_ACC%\%}
    ROBUST_ACC=${ROBUST_ACC%\%}
    ROBUST_DROP=${ROBUST_DROP%\%}
    ROBUST_RETENTION=${ROBUST_RETENTION%\%}

    # Extract embedding distance metrics (last field on the line).
    # Raw L2 labels exclude lines containing "Normalized": e.g.
    # "Normalized L2 (Clean, Original):" also contains "L2 (Clean, Original):".
    L2_ADV_CLEAN=$(log_field "$LOG" "L2 (Adv, Clean):" NF "Normalized")
    COSINE_ADV_CLEAN=$(log_field "$LOG" "Cosine Similarity (Adv, Clean):" NF)

    # Extract original model comparison metrics if available
    L2_ADV_ORIG=$(log_field "$LOG" "L2 (Adv, Original):" NF "Normalized")
    L2_CLEAN_ORIG=$(log_field "$LOG" "L2 (Clean, Original):" NF "Normalized")
    L2_CLEAN_ORIG_NORM=$(log_field "$LOG" "Normalized L2 (Clean, Original):" NF)
    COSINE_ADV_ORIG=$(log_field "$LOG" "Cosine Similarity (Adv, Original):" NF)
    COSINE_CLEAN_ORIG=$(log_field "$LOG" "Cosine Similarity (Clean, Orig):" NF)
    if [ -z "$COSINE_CLEAN_ORIG" ]; then
        # NOTE(review): every other label spells out "Original"; accept that
        # spelling too in case the log uses it here.
        COSINE_CLEAN_ORIG=$(log_field "$LOG" "Cosine Similarity (Clean, Original):" NF)
    fi

    # Original model's own adversarial-vs-clean metrics
    L2_ORIG_ADV_ORIG_CLEAN=$(log_field "$LOG" "L2 (Orig-Adv, Orig-Clean):" NF)
    COSINE_ORIG_ADV_ORIG_CLEAN=$(log_field "$LOG" "Cosine Sim (Orig-Adv, Orig-Clean):" NF)

    # If the original-model metrics are absent, blank the whole group
    if [ -z "$L2_ADV_ORIG" ]; then
        L2_ADV_ORIG=""
        L2_CLEAN_ORIG=""
        L2_CLEAN_ORIG_NORM=""
        COSINE_ADV_ORIG=""
        COSINE_CLEAN_ORIG=""
        L2_ORIG_ADV_ORIG_CLEAN=""
        COSINE_ORIG_ADV_ORIG_CLEAN=""
    fi

    # Extract timing info ("Attack Generation Time: 12.3s" -> "12.3")
    ATTACK_TIME=$(log_field "$LOG" "Attack Generation Time:" 4)
    ATTACK_TIME=${ATTACK_TIME%s}

    # Add to summary
    echo "$MODEL,$CHECKPOINT,$TRAIN_EPS,$RHO,$EPS,$CLEAN_ACC,$ROBUST_ACC,$ROBUST_DROP,$ROBUST_RETENTION,$L2_ADV_CLEAN,$COSINE_ADV_CLEAN,$L2_ADV_ORIG,$L2_CLEAN_ORIG,$L2_CLEAN_ORIG_NORM,$COSINE_ADV_ORIG,$COSINE_CLEAN_ORIG,$L2_ORIG_ADV_ORIG_CLEAN,$COSINE_ORIG_ADV_ORIG_CLEAN,$ATTACK_TIME" >> "$SUMMARY_FILE"
done

echo "Summary saved to $SUMMARY_FILE"

# Create a combined plot of robustness vs. epsilon for all models
# Writes a small Python/matplotlib script next to the results, runs it on
# summary.csv, and stores PNGs under RESULTS_DIR/plots. Requires python3
# with pandas, numpy and matplotlib; a failure is reported but not fatal.
if command -v python3 >/dev/null 2>&1; then
    echo "Generating summary plots..."

    # Create plots directory
    PLOTS_DIR="${RESULTS_DIR}/plots"
    mkdir -p -- "$PLOTS_DIR"

    # Create the plotting script (quoted 'EOF': no shell expansion inside)
    PLOT_SCRIPT="$RESULTS_DIR/plot_results.py"
    cat > "$PLOT_SCRIPT" << 'EOF'
"""Plot benchmark metrics from a summary CSV.

Usage: python3 plot_results.py SUMMARY_CSV PLOTS_DIR
"""
import os
import sys

import matplotlib
matplotlib.use('Agg')  # file output only; no display needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

csv_file = sys.argv[1]
plots_dir = sys.argv[2]
df = pd.read_csv(csv_file)

ACCURACY_COLUMNS = ['Clean Accuracy', 'Robust Accuracy', 'Robustness Drop', 'Robustness Retention']
L2_COLUMNS = ['L2_adv_clean', 'L2_adv_orig', 'L2_clean_orig', 'L2_clean_orig_normalized', 'L2_orig_adv_orig_clean']
COSINE_COLUMNS = ['Cosine_sim_adv_clean', 'Cosine_sim_adv_orig', 'Cosine_sim_clean_orig', 'Cosine_sim_orig_adv_orig_clean']
X_LABEL = 'Epsilon (pixels)'

# Coerce the x-axis and all metric columns to numbers; unparseable cells -> NaN.
for col in ['Epsilon'] + ACCURACY_COLUMNS + L2_COLUMNS + COSINE_COLUMNS:
    if col in df.columns:
        df[col] = pd.to_numeric(df[col], errors='coerce')

# summary.csv rows follow the shell's lexical glob order (eps1, eps16, eps2, ...);
# sort by epsilon so each model's line is drawn left to right.
df = df.sort_values('Epsilon', kind='stable')


def iter_models(frame):
    """Yield (checkpoint, rows) for each checkpoint in first-appearance order."""
    for model in frame['Checkpoint'].unique():
        yield model, frame[frame['Checkpoint'] == model]


def has_data(col):
    """True if the column exists and holds at least one non-NaN value."""
    return col in df.columns and not df[col].isnull().all()


def plot_metric_grid(columns, suptitle, ylabels, filename, figsize=None):
    """One subplot per metric (two per row), one line per checkpoint."""
    n_metrics = len(columns)
    n_cols = min(2, n_metrics)
    n_rows = int(np.ceil(n_metrics / 2))
    if figsize is None:
        figsize = (14, 4 * n_rows)
    # squeeze=False always returns a 2-D array of axes, even for 1x1 / 1xN.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    fig.suptitle(suptitle, fontsize=16)

    for i, (metric, ylabel) in enumerate(zip(columns, ylabels)):
        ax = axs[i // n_cols, i % n_cols]
        for model, model_data in iter_models(df):
            ax.plot(model_data['Epsilon'], model_data[metric], 'o-', label=model)
        ax.set_title(metric.replace('_', ' '))
        ax.set_xlabel(X_LABEL)
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
        ax.legend()

    # Hide unused subplots (odd number of metrics)
    for i in range(n_metrics, n_rows * n_cols):
        fig.delaxes(axs[i // n_cols, i % n_cols])

    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.savefig(os.path.join(plots_dir, filename), dpi=300)
    plt.close(fig)


def plot_combined(columns, label_prefix, title, ylabel, filename):
    """All given metrics for all checkpoints on a single axis."""
    plt.figure(figsize=(12, 8))
    for model, model_data in iter_models(df):
        for metric in columns:
            if not model_data[metric].isnull().all():
                label = f'{model} ({metric.replace(label_prefix, "")})'
                plt.plot(model_data['Epsilon'], model_data[metric], 'o-', label=label)
    plt.title(title)
    plt.xlabel(X_LABEL)
    plt.ylabel(ylabel)
    plt.grid(True, alpha=0.3)
    plt.legend(loc='best', fontsize='small')
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, filename), dpi=300)
    plt.close()


def plot_orig_vs_finetuned(finetuned_col, original_col, title, ylabel, filename):
    """Fine-tuned vs. original model adversarial-vs-clean metric per checkpoint."""
    plt.figure(figsize=(12, 8))
    for model, model_data in iter_models(df):
        plt.plot(model_data['Epsilon'], model_data[finetuned_col], 'o-',
                 label=f'{model} (Fine-tuned: Adv vs Clean)')
        plt.plot(model_data['Epsilon'], model_data[original_col], 's--',
                 label=f'{model} (Original: Adv vs Clean)')
    plt.title(title)
    plt.xlabel(X_LABEL)
    plt.ylabel(ylabel)
    plt.grid(True, alpha=0.3)
    plt.legend(loc='best')
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, filename), dpi=300)
    plt.close()


# ==================== PLOT GROUP 1: ACCURACY METRICS ====================
plot_metric_grid(ACCURACY_COLUMNS, 'Accuracy Metrics vs. Epsilon',
                 ['Accuracy (%)', 'Accuracy (%)', 'Accuracy Drop (%)', 'Retention (%)'],
                 'accuracy_metrics.png', figsize=(16, 12))

# Combined accuracy plot
plt.figure(figsize=(12, 8))
for model, model_data in iter_models(df):
    plt.plot(model_data['Epsilon'], model_data['Clean Accuracy'], 'o-', label=f'{model} (Clean)')
    plt.plot(model_data['Epsilon'], model_data['Robust Accuracy'], 's--', label=f'{model} (Robust)')
plt.title('Clean vs. Robust Accuracy')
plt.xlabel(X_LABEL)
plt.ylabel('Accuracy (%)')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(plots_dir, 'clean_vs_robust_accuracy.png'), dpi=300)
plt.close()

# ==================== PLOT GROUP 2: L2 DISTANCE METRICS ====================
available_l2_columns = [col for col in L2_COLUMNS if has_data(col)]
if available_l2_columns:
    plot_metric_grid(available_l2_columns, 'L2 Distance Metrics vs. Epsilon',
                     ['L2 Distance'] * len(available_l2_columns), 'l2_distance_metrics.png')
    plot_combined(available_l2_columns, 'L2_', 'L2 Distance Metrics', 'L2 Distance',
                  'combined_l2_metrics.png')

# ==================== PLOT GROUP 3: COSINE SIMILARITY METRICS ====================
available_cosine_columns = [col for col in COSINE_COLUMNS if has_data(col)]
if available_cosine_columns:
    plot_metric_grid(available_cosine_columns, 'Cosine Similarity Metrics vs. Epsilon',
                     ['Cosine Similarity'] * len(available_cosine_columns),
                     'cosine_similarity_metrics.png')
    plot_combined(available_cosine_columns, 'Cosine_sim_', 'Cosine Similarity Metrics',
                  'Cosine Similarity', 'combined_cosine_metrics.png')

# ==================== SPECIAL PLOTS: ORIGINAL VS FINE-TUNED ====================
if has_data('L2_orig_adv_orig_clean') and has_data('L2_adv_clean'):
    plot_orig_vs_finetuned('L2_adv_clean', 'L2_orig_adv_orig_clean',
                           'Comparison of L2 Distances in Original vs Fine-tuned Model',
                           'L2 Distance', 'original_vs_finetuned_l2.png')

if has_data('Cosine_sim_orig_adv_orig_clean') and has_data('Cosine_sim_adv_clean'):
    plot_orig_vs_finetuned('Cosine_sim_adv_clean', 'Cosine_sim_orig_adv_orig_clean',
                           'Comparison of Cosine Similarity in Original vs Fine-tuned Model',
                           'Cosine Similarity', 'original_vs_finetuned_cosine.png')

print(f"Plots saved to {plots_dir}")
EOF

    # Run the plotting script; only claim success if it actually succeeded.
    if python3 "$PLOT_SCRIPT" "$SUMMARY_FILE" "$PLOTS_DIR"; then
        echo "Summary plots generated in $PLOTS_DIR"
    else
        echo "Warning: plot generation failed (are pandas, numpy and matplotlib installed?)" >&2
    fi
else
    echo "Python not found, skipping plot generation"
fi

# Create a README file with benchmark information
README_FILE="$RESULTS_DIR/README.md"

# EPS_VALUES may be unset. In that case, list the distinct evaluation
# epsilons found in the summary CSV (5th column, header skipped) instead of
# leaving the field blank.
if [ "${#EPS_VALUES[@]}" -gt 0 ]; then
    EPS_LIST="${EPS_VALUES[*]}"
elif [ -f "$SUMMARY_FILE" ]; then
    EPS_LIST=$(tail -n +2 "$SUMMARY_FILE" | cut -d',' -f5 | grep -v '^$' | sort -n -u | paste -s -d ' ' -)
else
    EPS_LIST=""
fi

# Unquoted EOF: ${...} and $(date) are expanded when the README is written.
cat > "$README_FILE" << EOF
# Benchmark Results: ${BENCHMARK_NAME}

## Configuration
- **Date**: $(date)
- **Dataset**: ${DATASET_NAME}
- **Model Types**: ${MODEL_NAMES[*]}
- **Attack Type**: ${ATTACK_TYPE}
- **Norm Type**: ${NORM_TYPE}
- **Attack Iterations**: ${ITERATIONS}
- **Epsilon Values**: ${EPS_LIST:-n/a}

## Contents
- \`logs/\`: Individual log files for each model/epsilon combination
- \`plots/\`: Visualizations of benchmark results
- \`summary.csv\`: CSV file with all metrics
- \`benchmark.log\`: Full benchmark execution log

## Metrics Included
- Clean and robust accuracy
- Robustness drop and retention
- L2 distances between:
  - Adversarial and clean embeddings
  - Adversarial and original embeddings
  - Clean and original embeddings
  - Original model's adversarial and clean embeddings
- Cosine similarities between all embedding pairs
- Timing information

## Models Evaluated
EOF

# Add model information to README
# One bullet (file basename) per checkpoint path in MODELS_FILE that exists
# on disk. Blank lines and lines starting with '#' are skipped. A final line
# without a trailing newline is still processed, and a trailing CR (from CRLF
# files) is stripped.
while IFS= read -r CHECKPOINT || [ -n "$CHECKPOINT" ]; do
    CHECKPOINT=${CHECKPOINT%$'\r'}

    # Skip empty lines or commented lines
    if [[ -z "$CHECKPOINT" || "$CHECKPOINT" == \#* ]]; then
        continue
    fi

    if [ -f "$CHECKPOINT" ]; then
        echo "- \`$(basename -- "$CHECKPOINT")\`"
    fi
done < "$MODELS_FILE" >> "$README_FILE"

echo "README created at $README_FILE"
echo "Benchmark complete!"