#!/usr/bin/env python3
"""
Test script for experiment aggregation utility.

This script demonstrates how to use the experiment aggregation functions
and provides a simple interface for analyzing results.
"""

import sys
from pathlib import Path
import json

# Add the current directory to Python path
sys.path.append(str(Path(__file__).parent))

from experiment_aggregator import (
    find_experiment_folders,
    get_experiment_summary,
    create_comprehensive_dataframe,
    extract_experiment_metadata
)


def build_output_filename(root_path):
    """Return the CSV file name used to save aggregated results for *root_path*.

    The path is flattened into a single file name by dropping every ``.`` and
    turning each ``/`` into ``_``, then appending ``_aggregated_results.csv``,
    e.g. ``../data_files/nlp_training`` ->
    ``_data_files_nlp_training_aggregated_results.csv``.

    Args:
        root_path: Directory that was analysed (``str`` or ``Path``).

    Returns:
        The file name (relative to the current working directory) as a string.
    """
    # as_posix() keeps the separator '/' on every platform so the naming
    # scheme is the same on Windows.
    flattened = Path(root_path).as_posix().replace('.', '').replace('/', '_')
    return flattened + "_aggregated_results.csv"


def _print_experiment_metadata(exp_info):
    """Print the metadata fields of a single experiment record."""
    print(f"  Path: {exp_info.path}")
    print(f"  Machine: {exp_info.machine_name}")
    print(f"  Training Strategy: {exp_info.training_strategy}")
    print(f"  Model: {exp_info.model_name}")
    print(f"  Dataset: {exp_info.dataset_name}")
    print(f"  Epochs: {exp_info.num_epochs}")
    print(f"  Seed: {exp_info.seed}")
    print(f"  Has A2T: {exp_info.has_a2t}")
    print(f"  Has TextFooler: {exp_info.has_textfooler}")


def _unique_or_na(df, column):
    """Return the distinct values of *column* as a plain list, or 'N/A' if absent."""
    if column not in df.columns:
        return 'N/A'
    # drop_duplicates().tolist() keeps first-appearance order (like unique())
    # but yields Python scalars, so the printed list is readable.
    return df[column].drop_duplicates().tolist()


def _report_dataframe(df, root_path):
    """Print an overview of *df*, save it to CSV and show a few sample rows."""
    print(f"  DataFrame shape: {df.shape}")
    print(f"  Columns: {len(df.columns)}")

    print("\n  Column names:")
    for i, col in enumerate(sorted(df.columns), 1):
        print(f"    {i:2d}. {col}")

    print(f"\n  Training strategies: {_unique_or_na(df, 'training_strategy')}")
    print(f"  Models: {_unique_or_na(df, 'model_name')}")
    print(f"  Datasets: {_unique_or_na(df, 'dataset_name')}")

    if 'attack_type' in df.columns:
        print(f"  Attack types: {df['attack_type'].value_counts().to_dict()}")

    # Name the CSV after the directory actually analysed (which may come from
    # the command line), not after the hard-coded default path.
    output_file = Path(build_output_filename(root_path))
    df.to_csv(output_file, index=False)
    print(f"\n  Results saved to: {output_file}")

    # Show sample rows, restricted to whichever key columns are present.
    print("\n  Sample data (first 3 rows):")
    display_cols = ['training_strategy', 'model_name', 'dataset_name', 'attack_type', 'test_accuracy']
    available_cols = [col for col in display_cols if col in df.columns]
    if available_cols:
        print(df[available_cols].head(3).to_string(index=False))


def main(argv=None):
    """Exercise the aggregation utilities on a results directory and report.

    Runs four steps:
      1. print the experiment summary;
      2. list experiment folders;
      3. print the metadata of the first folder;
      4. build the combined DataFrame, print an overview and save it to CSV.

    Args:
        argv: Command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``. The first argument, if given, is the results
            directory; otherwise ``../data_files/nlp_training`` is used.

    Returns:
        1 if the path does not exist or any step raised an exception,
        otherwise 0.
    """
    if argv is None:
        argv = sys.argv[1:]

    # Default path - you can modify this. Relative paths are resolved against
    # the current working directory, not this script's directory.
    data_path = "../data_files/nlp_training"
    root_path = Path(argv[0]) if argv else Path(data_path)

    if not root_path.exists():
        print(f"Error: Path {root_path} does not exist", file=sys.stderr)
        print(f"Usage: python {sys.argv[0]} [path_to_results]", file=sys.stderr)
        return 1

    print(f"Analyzing results in: {root_path}")
    print("=" * 60)

    exit_code = 0

    # Test 1: Get experiment summary (abort on failure).
    print("\n1. Getting experiment summary...")
    try:
        summary = get_experiment_summary(root_path)
        print(json.dumps(summary, indent=2))
    except Exception as e:
        print(f"Error getting summary: {e}", file=sys.stderr)
        return 1

    # Test 2: Find experiment folders (abort on failure; later steps need them).
    print("\n2. Finding experiment folders...")
    try:
        folders = find_experiment_folders(root_path, require_both_attacks=False)
        print(f"Found {len(folders)} experiment folders")

        if folders:
            print("\nFirst 5 experiment folders:")
            for i, folder in enumerate(folders[:5], 1):
                print(f"  {i}. {folder}")
    except Exception as e:
        print(f"Error finding folders: {e}", file=sys.stderr)
        return 1

    # Test 3: Extract metadata from one experiment (failure is reported, not fatal).
    if folders:
        print("\n3. Extracting metadata from first experiment...")
        try:
            exp_info = extract_experiment_metadata(folders[0])
            if exp_info:
                _print_experiment_metadata(exp_info)
            else:
                print("  Failed to extract metadata")
        except Exception as e:
            print(f"Error extracting metadata: {e}", file=sys.stderr)
            exit_code = 1

    # Test 4: Create comprehensive DataFrame (failure is reported with traceback).
    print("\n4. Creating comprehensive DataFrame...")
    try:
        df = create_comprehensive_dataframe(root_path, require_both_attacks=False)
        if df.empty:
            print("  No data found in DataFrame")
        else:
            _report_dataframe(df, root_path)
    except Exception as e:
        print(f"Error creating DataFrame: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        exit_code = 1

    print("\n" + "=" * 60)
    print("Analysis complete!")
    return exit_code


if __name__ == "__main__":
    # Use main()'s return value as the process exit status
    # (a None return maps to exit status 0).
    sys.exit(main())
