from typing import Dict, Any, Callable
from common_evaluator import CommonEvaluator
from time_series_datasets.TSQADataset import TSQADataset
from time_series_datasets.pamap2.PAMAP2AccQADataset import PAMAP2AccQADataset
from time_series_datasets.pamap2.PAMAP2CoTQADataset import PAMAP2CoTQADataset

# Import evaluation functions
from evaluate_tsqa import evaluate_tsqa
from evaluate_pamap import evaluate_pamap_acc, evaluate_pamap_cot


def main():
    """Main function to run comprehensive evaluation across all datasets."""
    
    # Define models to evaluate
    model_names = [
        "meta-llama/Llama-3.2-1B",
        # Add more models as needed
        # "google/gemma-3n-e2b",
        # "google/gemma-3n-e2b-it",
        # "microsoft/DialoGPT-medium",
        # "gpt2",
    ]
    
    # Define datasets to evaluate on
    dataset_classes = [
        TSQADataset,
        PAMAP2AccQADataset,
        PAMAP2CoTQADataset,
    ]
    
    # Define evaluation functions
    evaluation_functions = {
        "TSQADataset": evaluate_tsqa,
        "PAMAP2AccQADataset": evaluate_pamap_acc,
        "PAMAP2CoTQADataset": evaluate_pamap_cot,
    }
    
    # Initialize evaluator
    evaluator = CommonEvaluator()
    
    # Run comprehensive evaluation
    results_df = evaluator.evaluate_multiple_models(
        model_names=model_names,
        dataset_classes=dataset_classes,
        evaluation_functions=evaluation_functions,
        max_samples=50,  # Limit for faster testing, set to None for full evaluation
    )
    
    print("\n" + "="*80)
    print("COMPREHENSIVE EVALUATION RESULTS")
    print("="*80)
    print(results_df.to_string(index=False))
    
    # Print summary statistics
    print("\n" + "="*80)
    print("SUMMARY STATISTICS")
    print("="*80)
    
    # Group by dataset and show average metrics
    for dataset in results_df['dataset'].unique():
        dataset_results = results_df[results_df['dataset'] == dataset]
        print(f"\n{dataset}:")
        print(f"  Number of models evaluated: {len(dataset_results)}")
        
        # Show average accuracy
        if 'accuracy' in dataset_results.columns:
            avg_accuracy = dataset_results['accuracy'].mean()
            print(f"  Average accuracy: {avg_accuracy:.1f}%")
    
    # Group by model and show average metrics
    print(f"\nBy Model:")
    for model in results_df['model'].unique():
        model_results = results_df[results_df['model'] == model]
        print(f"\n{model}:")
        print(f"  Number of datasets evaluated: {len(model_results)}")
        
        # Show average accuracy
        if 'accuracy' in model_results.columns:
            avg_accuracy = model_results['accuracy'].mean()
            print(f"  Average accuracy: {avg_accuracy:.1f}%")
    
    return results_df


if __name__ == "__main__":
    main() 