import pandas as pd
import sys
import os
import numpy as np
from unittest.mock import patch

from test_mil import main as run_mil_experiment

def _run_single_trial(experiment, argv):
    """Run one experiment call under a patched ``sys.argv`` and return its score.

    Args:
        experiment: Zero-argument callable that reads its configuration from
            ``sys.argv`` and returns a numeric score.
        argv: Argument list installed as ``sys.argv`` for the duration of
            the call; the previous value is restored afterwards.

    Returns:
        The score converted to a built-in ``float``.

    Raises:
        RuntimeError: If the experiment terminates via ``SystemExit``.
        ValueError: If the experiment returns ``None`` or a non-finite score.
        TypeError: If the returned value cannot be converted to ``float``.
    """
    with patch.object(sys, 'argv', argv):
        try:
            score = experiment()
        except SystemExit as exc:
            # SystemExit (raised by sys.exit and by argparse on bad
            # arguments) is not an Exception subclass; translate it so the
            # caller's per-trial handler can skip this trial instead of the
            # whole benchmark process terminating.
            raise RuntimeError(
                f"experiment exited via SystemExit (code={exc.code})"
            ) from exc

    if score is None:
        raise ValueError("experiment returned None instead of a score")

    score = float(score)
    if not np.isfinite(score):
        # One NaN/inf would turn the aggregated mean and std into NaN/inf.
        raise ValueError(f"experiment returned a non-finite score ({score})")
    return score


def main(*, datasets=None, models=None, num_trials=5,
         output_file='mil_benchmark_results_with_std.csv', experiment=None):
    """Run repeated trials per dataset/model pair and save summary statistics.

    For every (dataset, model) combination the experiment is invoked
    ``num_trials`` times with a fixed set of command-line flags and a seed of
    ``42 + trial_index``. Trials that raise, exit early, or return an
    unusable score are reported and skipped. The mean and population
    standard deviation (NumPy's default ``ddof=0``) of the remaining scores
    are collected, written to ``output_file`` as CSV and printed.

    Calling ``main()`` with no arguments uses the original defaults.

    Args:
        datasets: Iterable of dataset names. Defaults to
            ``['tiger', 'fox', 'elephant']``.
        models: Iterable of model names. Defaults to
            ``['kf_pooling', 'hf_pooling']``.
        num_trials: Number of trials to run per combination.
        output_file: Path of the CSV file to write.
        experiment: Zero-argument callable that reads ``sys.argv`` and
            returns a numeric score. Defaults to ``run_mil_experiment``.

    Returns:
        None.
    """
    target_datasets = (
        ['tiger', 'fox', 'elephant'] if datasets is None else list(datasets)
    )
    target_models = (
        ['kf_pooling', 'hf_pooling'] if models is None else list(models)
    )
    if experiment is None:
        # Resolved at call time so the default stays the module-level import.
        experiment = run_mil_experiment

    results_list = []

    for ds in target_datasets:
        for md in target_models:
            trial_scores = []
            print(f"\n{'='*60}")
            print(f"STARTING BENCHMARK: Dataset={ds}, Model={md}")
            print(f"Running {num_trials} trials for stability...")
            print(f"{'='*60}\n")

            for i in range(num_trials):
                print(f"--- Trial {i+1}/{num_trials} ---")

                test_args = [
                    "test_mil.py",
                    "--dataset", ds,
                    "--model", md,
                    "--epochs", "100",
                    "--batch-size", "16",
                    "--lr", "0.001",
                    "--gamma", "0.96",
                    "--seed", str(i + 42),  # distinct seed per trial
                ]

                try:
                    trial_scores.append(
                        _run_single_trial(experiment, test_args)
                    )
                except Exception as e:
                    # Best effort: report the failed trial and keep going.
                    print(f"Error in Trial {i+1} for {md} on {ds}: {e}")

            if trial_scores:
                final_mean = float(np.mean(trial_scores))
                final_std = float(np.std(trial_scores))

                results_list.append({
                    "Dataset": ds,
                    "Model": md,
                    "Mean_AUC": round(final_mean, 4),
                    "Std_Dev": round(final_std, 4),
                    # Successful trials only; may be lower than num_trials.
                    "Trials": len(trial_scores),
                })
            else:
                print(
                    f"Warning: no successful trials for {md} on {ds}; "
                    "omitting it from the results."
                )

    print("\nBenchmark sequence complete.")

    if results_list:
        df = pd.DataFrame(results_list)
        df.to_csv(output_file, index=False)

        print("\n" + "#"*50)
        print(f"SUCCESS: Results saved to {os.path.abspath(output_file)}")
        print("#"*50)
        print(df.to_string(index=False))
    else:
        print("Warning: No results were collected.")

# Run the benchmark only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == "__main__":
    main()