"""
Test runner and demonstration script for the document generation benchmark.

This script shows how to use the benchmark framework and runs a sample evaluation.

Author: GitHub Copilot
Date: September 14, 2025
"""

import os
import sys
import json
import logging
from pathlib import Path
import ctypes

# Add the script directory to path for imports, so the sibling modules imported
# below (document_generation, benchmark_analysis) resolve from this file's
# directory regardless of the current working directory. script_dir is also
# used by load_config() to locate the JSON config.
# NOTE: append() puts this directory *last* on sys.path, so an installed module
# with the same name would take precedence over the local one.
script_dir = Path(__file__).parent
sys.path.append(str(script_dir))

from document_generation import DocumentGenerationBenchmark
from benchmark_analysis import BenchmarkAnalyzer

# Emit INFO-and-above records with timestamps; the functions below report
# progress and errors through this module-level logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_config(config_file: str = "benchmark_config.json") -> dict:
    """Read the benchmark configuration JSON stored next to this script.

    Args:
        config_file: File name (resolved against the script directory) of the
            configuration to read.

    Returns:
        The parsed JSON document, or an empty dict when the file does not
        exist (a warning is logged and callers fall back to their defaults).
        Malformed JSON is not handled here and propagates to the caller.
    """
    try:
        with (script_dir / config_file).open('r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        logger.warning(f"Config file {config_file} not found, using defaults")
        parsed = {}
    return parsed

def run_sample_benchmark(domain: str = "Finance", max_queries: int = 5, use_ground_truth: bool = False, use_fuzzy_intent_matching: bool = False):
    """Run the benchmark for one domain, analyse the results and print a summary.

    Args:
        domain: Domain name. Used to pick the conversation/query files from the
            config's ``data_paths`` (falling back to built-in relative paths)
            and to suffix the output directory.
        max_queries: Maximum number of queries passed to
            ``run_comprehensive_benchmark``.
        use_ground_truth: Forwarded to ``run_comprehensive_benchmark``.
        use_fuzzy_intent_matching: Forwarded to ``run_comprehensive_benchmark``;
            selects fuzzy instead of strict intent matching.

    Returns:
        ``(results, analyzer)`` when the run completes, or ``(None, None)`` if
        any step (including reading the config) raised. The error and its
        traceback are logged.
    """
    logger.info(f"Running sample benchmark for {domain} domain...")
    logger.info(f"Intent evaluation mode: {'Fuzzy matching' if use_fuzzy_intent_matching else 'Strict matching (default)'}")

    try:
        # Config loading lives inside the try so a malformed config is reported
        # as a failed run instead of aborting the caller (e.g. the menu loop).
        config = load_config()
        benchmark_config = config.get("benchmark_config", {})
        data_paths = config.get("data_paths", {})

        conversation_file = data_paths.get("conversation_files", {}).get(
            domain, f"../data/{domain}/synthetic_domain_channels_{domain}.json"
        )
        queries_file = data_paths.get("queries_file", {}).get(
            domain, f"./synthetic_queries/generated_user_queries_{domain}.json"
        )
        # Each domain writes into its own suffixed output directory.
        output_dir = data_paths.get("output_directory", "./benchmark_results")
        output_dir = f"{output_dir}_{domain.lower()}"

        benchmark = DocumentGenerationBenchmark(
            azure_endpoint=benchmark_config.get("azure_endpoint"),
            model_name=benchmark_config.get("model_name", "gpt-5-chat"),
            api_version=benchmark_config.get("api_version", "2024-05-01-preview")
        )

        results = benchmark.run_comprehensive_benchmark(
            conversation_file=conversation_file,
            queries_file=queries_file,
            output_dir=output_dir,
            max_queries=max_queries,
            use_ground_truth=use_ground_truth,
            use_fuzzy_intent_matching=use_fuzzy_intent_matching
        )

        logger.info(f"Benchmark completed with {len(results)} results")

        analyzer = BenchmarkAnalyzer(output_dir)
        analyzer.load_results()
        # The return value was never used here; the call is kept in case the
        # visualization/report steps rely on it (TODO confirm in BenchmarkAnalyzer).
        analyzer.calculate_advanced_metrics()
        analyzer.generate_visualizations()
        analyzer.generate_detailed_report()

        _print_sample_summary(domain, results, output_dir)

        return results, analyzer

    except Exception as e:
        # Top-level boundary for one run: keep the traceback in the log and
        # signal failure to the caller instead of propagating.
        logger.exception(f"Error running benchmark: {e}")
        return None, None


def _print_sample_summary(domain, results, output_dir):
    """Print averaged metrics and intent-field insights for one domain run.

    Args:
        domain: Domain name shown in the banner.
        results: Per-query result objects from ``run_comprehensive_benchmark``.
            When empty, only the banner is printed.
        output_dir: Directory reported as the location of the results.
    """
    strong_threshold = 0.8  # average field precision at or above -> "Strong"
    weak_threshold = 0.3    # average field precision at or below -> "Weak"

    print("\n" + "="*60)
    print(f"BENCHMARK SUMMARY - {domain.upper()} DOMAIN")
    print("="*60)

    if not results:
        return

    n = len(results)

    # Context-retrieval precision/recall/F1 come from each result's detailed
    # evaluation. When that entry is not a dict, only the aggregate accuracy
    # exists: it is used as F1, and precision/recall count as 0.0 (unknown).
    context_precision, context_recall, context_f1 = [], [], []
    for r in results:
        context_result = r.detailed_evaluation.get("context_retrieval", {})
        if isinstance(context_result, dict):
            context_precision.append(context_result.get("precision", 0.0))
            context_recall.append(context_result.get("recall", 0.0))
            context_f1.append(context_result.get("f1_score", 0.0))
        else:
            context_f1.append(r.context_retrieval_accuracy)
            context_precision.append(0.0)
            context_recall.append(0.0)

    macro_f1_avg = sum(r.intent_evaluation.macro_f1_score for r in results) / n
    overall_intent_avg = sum(r.intent_capture_accuracy for r in results) / n

    # Per-field intent precision averaged over all queries. A query that did
    # not report a field counts as 0.0 for that field.
    all_intent_fields = set()
    for r in results:
        all_intent_fields.update(r.intent_evaluation.per_field_precision.keys())
    per_field_avg_precision = {
        field: sum(r.intent_evaluation.per_field_precision.get(field, 0.0) for r in results) / n
        for field in all_intent_fields
    }

    avg_scores = {
        "User Profile Accuracy": sum(r.user_profile_accuracy for r in results) / n,
        "Intent Capture Accuracy": overall_intent_avg,
        "Intent Macro-F1 Score": macro_f1_avg,
        "Context Retrieval - Precision": sum(context_precision) / n,
        "Context Retrieval - Recall": sum(context_recall) / n,
        "Context Retrieval - F1": sum(context_f1) / n,
        "Citation Accuracy": sum(r.citation_accuracy for r in results) / n,
        "Document Quality": sum(r.document_quality_score for r in results) / n,
        "Overall Score": sum(r.overall_score for r in results) / n,
    }

    for metric, score in avg_scores.items():
        print(f"{metric:<30}: {score:.3f} ({score:.1%})")

    print("\nIntent Field Performance (Average Precision):")
    for field, precision in sorted(per_field_avg_precision.items()):
        if precision >= strong_threshold:
            performance = "Strong"
        elif precision <= weak_threshold:
            performance = "Weak"
        else:
            performance = "Moderate"
        print(f"  {field.replace('_', ' ').title():<20}: {precision:.3f} ({precision:.1%}) {performance}")

    print(f"\nTotal Queries Processed: {n}")
    print(f"Intent Fields Evaluated: {len(all_intent_fields)}")
    print(f"Results Directory: {output_dir}")

    # Sorted so the insight lines are stable across runs (set order is not).
    strong_fields = sorted(f for f, p in per_field_avg_precision.items() if p >= strong_threshold)
    weak_fields = sorted(f for f, p in per_field_avg_precision.items() if p <= weak_threshold)

    # The Macro-F1 vs accuracy comparison does not depend on whether any field
    # is strong or weak, so it is always reported. In particular, the
    # "balanced" message is most relevant when every field is moderate.
    print("\nResearch Insights:")
    if strong_fields:
        print(f"  Model excels at: {', '.join(f.replace('_', ' ') for f in strong_fields)}")
    if weak_fields:
        print(f"  Model struggles with: {', '.join(f.replace('_', ' ') for f in weak_fields)}")

    gap = abs(macro_f1_avg - overall_intent_avg)
    if gap > 0.1:
        print(f"  Macro-F1 vs Overall Accuracy difference ({gap:.1%}) suggests uneven field performance")
    else:
        print(f"  Balanced performance across intent fields (Macro-F1: {macro_f1_avg:.1%}, Overall: {overall_intent_avg:.1%})")

def run_multi_domain_comparison(domains=None, max_queries=10, use_ground_truth=False, use_fuzzy_intent_matching=False):
    """Run the benchmark for several domains and print a side-by-side comparison.

    Args:
        domains: Domain names to run. Defaults to Finance, Technology,
            Healthcare and Manufacturing.
        max_queries: Per-domain query limit passed to ``run_sample_benchmark``.
        use_ground_truth: Forwarded to ``run_sample_benchmark``.
        use_fuzzy_intent_matching: Forwarded to ``run_sample_benchmark``.
            Defaults to False (strict matching), which is what every
            multi-domain run used before this parameter existed.

    Returns:
        Dict mapping each domain that produced results to a dict holding its
        ``results``, ``analyzer`` and averaged metrics. Domains whose run
        failed or returned no results are omitted, so the dict may be empty.
    """
    if domains is None:
        domains = ["Finance", "Technology", "Healthcare", "Manufacturing"]

    logger.info(f"Running multi-domain comparison for: {', '.join(domains)}")

    all_results = {}

    for domain in domains:
        logger.info(f"Processing {domain} domain...")
        results, analyzer = run_sample_benchmark(
            domain,
            max_queries,
            use_ground_truth=use_ground_truth,
            use_fuzzy_intent_matching=use_fuzzy_intent_matching,
        )

        # A failed run returns (None, None) and an empty run returns []; either
        # way the domain is left out of the comparison.
        if not results:
            continue

        n = len(results)

        # Context-retrieval precision/recall/F1 from the detailed evaluation.
        # Without a dict entry, the aggregate accuracy stands in for F1 and
        # precision/recall count as 0.0 (unknown).
        context_precision, context_recall, context_f1 = [], [], []
        for r in results:
            context_result = r.detailed_evaluation.get("context_retrieval", {})
            if isinstance(context_result, dict):
                context_precision.append(context_result.get("precision", 0.0))
                context_recall.append(context_result.get("recall", 0.0))
                context_f1.append(context_result.get("f1_score", 0.0))
            else:
                context_f1.append(r.context_retrieval_accuracy)
                context_precision.append(0.0)
                context_recall.append(0.0)

        all_results[domain] = {
            "results": results,
            "analyzer": analyzer,
            "avg_score": sum(r.overall_score for r in results) / n,
            "intent_accuracy": sum(r.intent_capture_accuracy for r in results) / n,
            "intent_macro_f1": sum(r.intent_evaluation.macro_f1_score for r in results) / n,
            "profile_accuracy": sum(r.user_profile_accuracy for r in results) / n,
            "context_accuracy": sum(r.context_retrieval_accuracy for r in results) / n,
            "context_precision": sum(context_precision) / n,
            "context_recall": sum(context_recall) / n,
            "context_f1": sum(context_f1) / n,
            "citation_accuracy": sum(r.citation_accuracy for r in results) / n,
            "document_quality": sum(r.document_quality_score for r in results) / n,
        }

    print("\n" + "="*80)
    print("MULTI-DOMAIN BENCHMARK COMPARISON")
    print("="*80)

    if all_results:
        print("Overall Performance:")
        for domain, data in all_results.items():
            print(f"{domain:<15}: {data['avg_score']:.3f} ({data['avg_score']:.1%}) (n={len(data['results'])})")

        print("\nDetailed Metrics Comparison:")
        print(f"{'Domain':<15} {'Overall':<8} {'Profile':<8} {'Intent':<8} {'Macro-F1':<9} {'Ctx-P':<7} {'Ctx-R':<7} {'Ctx-F1':<8} {'Citation':<9} {'Quality':<8}")
        print("-" * 95)

        for domain, data in all_results.items():
            print(f"{domain:<15} "
                  f"{data['avg_score']:.3f}    "
                  f"{data['profile_accuracy']:.3f}    "
                  f"{data['intent_accuracy']:.3f}    "
                  f"{data['intent_macro_f1']:.3f}     "
                  f"{data['context_precision']:.3f}   "
                  f"{data['context_recall']:.3f}   "
                  f"{data['context_f1']:.3f}    "
                  f"{data['citation_accuracy']:.3f}     "
                  f"{data['document_quality']:.3f}")

        best_domain = max(all_results.keys(), key=lambda d: all_results[d]['avg_score'])
        worst_domain = min(all_results.keys(), key=lambda d: all_results[d]['avg_score'])

        print(f"\nBest Performing Domain: {best_domain} ({all_results[best_domain]['avg_score']:.3f})")
        print(f"Worst Performing Domain: {worst_domain} ({all_results[worst_domain]['avg_score']:.3f})")

        print("\nCross-Domain Insights:")

        # Every stored domain has at least one result, so this is > 0.
        total_queries = sum(len(data['results']) for data in all_results.values())

        def weighted_avg(key):
            """Query-count-weighted mean of a per-domain metric across domains."""
            return sum(data[key] * len(data['results']) for data in all_results.values()) / total_queries

        avg_overall = weighted_avg('avg_score')
        avg_intent_accuracy = weighted_avg('intent_accuracy')
        avg_intent_macro_f1 = weighted_avg('intent_macro_f1')
        avg_context_precision = weighted_avg('context_precision')
        avg_context_recall = weighted_avg('context_recall')
        avg_context_f1 = weighted_avg('context_f1')

        print(f"  Average Overall Performance: {avg_overall:.3f} ({avg_overall:.1%})")
        print(f"  Average Intent Accuracy: {avg_intent_accuracy:.3f} ({avg_intent_accuracy:.1%})")
        print(f"  Average Intent Macro-F1: {avg_intent_macro_f1:.3f} ({avg_intent_macro_f1:.1%})")
        print(f"  Average Context Precision: {avg_context_precision:.3f} ({avg_context_precision:.1%})")
        print(f"  Average Context Recall: {avg_context_recall:.3f} ({avg_context_recall:.1%})")
        print(f"  Average Context F1: {avg_context_f1:.3f} ({avg_context_f1:.1%})")
        print(f"  Total Queries Across Domains: {total_queries}")

        # Strongest/weakest components are ranked by the *unweighted* mean
        # across domains, so each domain counts equally regardless of size.
        all_metrics = ['profile_accuracy', 'intent_accuracy', 'intent_macro_f1', 'context_precision', 'context_recall', 'context_f1', 'citation_accuracy', 'document_quality']

        def domain_mean(metric):
            """Unweighted mean of a per-domain metric across domains."""
            return sum(data[metric] for data in all_results.values()) / len(all_results)

        strongest_metric = max(all_metrics, key=domain_mean)
        weakest_metric = min(all_metrics, key=domain_mean)

        print(f"  Strongest Component: {strongest_metric.replace('_', ' ').title()}")
        print(f"  Weakest Component: {weakest_metric.replace('_', ' ').title()}")

        # Only meaningful when real precision/recall values were available.
        if avg_context_precision > 0 and avg_context_recall > 0:
            if avg_context_precision > avg_context_recall + 0.1:
                print(f"  Context Retrieval: High precision ({avg_context_precision:.1%}) but lower recall ({avg_context_recall:.1%}) - conservative selection")
            elif avg_context_recall > avg_context_precision + 0.1:
                print(f"  Context Retrieval: High recall ({avg_context_recall:.1%}) but lower precision ({avg_context_precision:.1%}) - broad selection")
            else:
                print(f"  Context Retrieval: Balanced precision/recall ({avg_context_precision:.1%}/{avg_context_recall:.1%})")

    return all_results

def test_individual_components():
    """Smoke-test the benchmark building blocks on the Finance sample data.

    Loads the Finance conversation file and then exercises, in order:
    user-profile inference, intent capture, document generation on the first
    10 messages, and quality evaluation. It prints a short sample of each
    result.

    Returns:
        True if every step ran without raising, False otherwise. On failure,
        the exception and its traceback are logged.
    """
    logger.info("Testing individual benchmark components...")

    try:
        benchmark = DocumentGenerationBenchmark()

        conversation_file = "../data/Finance/synthetic_domain_channels_Finance.json"
        messages = benchmark.load_conversation_data(conversation_file)

        print(f"Loaded {len(messages)} messages from conversation data")

        # Profile the author of the first message, or a placeholder user when
        # the conversation is empty.
        sample_user = messages[0].get("author", "User_1") if messages else "User_1"
        user_profile = benchmark.infer_user_profile(messages, sample_user)

        print(f"\nSample User Profile for {sample_user}:")
        print(f"  Role: {user_profile.role}")
        print(f"  Expertise: {user_profile.expertise_level}")
        print(f"  Style: {user_profile.communication_style}")
        print(f"  Confidence: {user_profile.confidence_score:.2f}")

        sample_query = "Generate a status report on the project progress for management review"
        intent = benchmark.capture_user_intent(sample_query, {})

        print("\nSample Intent Capture:")
        print(f"  Document Type: {intent.document_type}")
        print(f"  Target Audience: {intent.target_audience}")
        print(f"  Detail Level: {intent.detail_level}")
        print(f"  Format: {intent.format_requirements}")

        # A small slice keeps the generation step quick for a smoke test.
        document = benchmark.generate_document_with_citations(messages[:10], user_profile, intent)

        print("\nSample Document Generation:")
        print(f"  Content Length: {len(document.content)} characters")
        print(f"  Citations Count: {len(document.citations)}")
        print(f"  Content Preview: {document.content[:200]}...")

        quality_scores = benchmark.evaluate_document_quality(document)

        print("\nSample Quality Evaluation:")
        for dimension, score in quality_scores.items():
            # Only numeric scores are shown; other entries are skipped.
            if isinstance(score, (int, float)):
                print(f"  {dimension}: {score}")

        return True

    except Exception as e:
        # Top-level boundary for this smoke test: keep the traceback in the log
        # so the failing component can actually be diagnosed.
        logger.exception(f"Error testing individual components: {e}")
        return False

def validate_setup():
    """Check that the files the benchmark needs exist before running anything.

    Requires at least one domain query file under ``./synthetic_queries``,
    the Finance conversation file, and the benchmark config. A missing
    ``ENDPOINT_URL`` environment variable only produces a warning.

    Returns:
        True if all required files were found, False otherwise. A summary is
        printed in both cases.
    """
    logger.info("Validating benchmark setup...")

    # Check for domain-specific query files
    query_files_found = []
    for domain in ["Finance", "Technology", "Healthcare", "Manufacturing"]:
        query_file = f"./synthetic_queries/generated_user_queries_{domain}.json"
        if os.path.exists(query_file):
            query_files_found.append(query_file)

    # Data files are checked relative to the current working directory, as
    # the relative default paths in run_sample_benchmark presumably are.
    # The config is read by load_config() from the script directory, so it
    # must be checked there. A CWD-relative check can pass or fail
    # independently of whether the config actually loads.
    other_required_files = [
        "../data/Finance/synthetic_domain_channels_Finance.json",
        str(script_dir / "benchmark_config.json"),
    ]

    missing_files = [path for path in other_required_files if not os.path.exists(path)]

    # Check if we have at least one query file
    if not query_files_found:
        missing_files.append("./synthetic_queries/generated_user_queries_<domain>.json (no domain-specific query files found)")

    if missing_files:
        print("Setup validation failed!")
        print("Missing required files:")
        for file_path in missing_files:
            print(f"  - {file_path}")
        return False

    # Reaching here means at least one query file exists.
    print(f"Found {len(query_files_found)} domain-specific query files:")
    for query_file in query_files_found:
        print(f"  - {query_file}")

    # Check Azure OpenAI credentials
    endpoint = os.getenv("ENDPOINT_URL")
    if not endpoint:
        print("Warning: ENDPOINT_URL environment variable not set")
        print("   Using default endpoint from config")

    print("Setup validation passed!")
    return True

def main():
    """Validate the setup, then run the interactive benchmark menu.

    Returns early, after printing a hint, if ``validate_setup()`` fails.
    End-of-input on stdin (``EOFError``) or Ctrl-C ends the menu with a
    goodbye message instead of a traceback.
    """
    print("Document Generation Benchmark Test Runner")
    print("="*50)

    if not validate_setup():
        print("\nPlease fix the setup issues before running the benchmark.")
        return

    try:
        _run_menu()
    except (EOFError, KeyboardInterrupt):
        # input() raises EOFError when stdin is exhausted (piped input or
        # Ctrl-D); treat that, like Ctrl-C, as a request to quit.
        print("\nGoodbye!")


def _prompt_use_ground_truth():
    """Ask for the evaluation mode; return True only when "2" (ground truth) is chosen."""
    print("Evaluation modes:")
    print("1. End-to-end evaluation (using predicted inputs)")
    print("2. Ground truth evaluation (using actual labels)")
    eval_choice = input("Select evaluation mode (1-2): ").strip()
    return eval_choice == "2"


def _run_menu():
    """Show the action menu in a loop and dispatch the chosen action until Exit."""
    while True:
        print("\nAvailable actions:")
        print("1. Test individual components")
        print("2. Run sample benchmark (Finance domain)")
        print("3. Run multi-domain comparison")
        print("4. Run custom benchmark")
        print("5. Analyze existing results")
        print("6. Exit")

        choice = input("\nSelect an action (1-6): ").strip()

        if choice == "1":
            print("\nTesting individual components...")
            if test_individual_components():
                print("All components tested successfully!")
            else:
                print("Component testing failed!")

        elif choice == "2":
            print("\nRunning sample benchmark...")
            use_ground_truth = _prompt_use_ground_truth()

            results, _ = run_sample_benchmark("Finance", max_queries=3, use_ground_truth=use_ground_truth)
            if results:
                print("Sample benchmark completed successfully!")
            else:
                print("Sample benchmark failed!")

        elif choice == "3":
            print("\nRunning multi-domain comparison...")
            use_ground_truth = _prompt_use_ground_truth()

            all_results = run_multi_domain_comparison(max_queries=50, use_ground_truth=use_ground_truth)
            if all_results:
                print("Multi-domain comparison completed!")
            else:
                print("Multi-domain comparison failed!")

        elif choice == "4":
            domain = input("Enter domain (Finance/Technology/Healthcare/Manufacturing): ").strip()
            max_queries_text = input("Enter max queries to process: ").strip()
            # A typo here used to raise ValueError and kill the whole menu;
            # report it and return to the menu instead.
            try:
                max_queries = int(max_queries_text)
            except ValueError:
                print(f"Invalid number of queries: {max_queries_text!r}. Please enter an integer.")
                continue
            use_ground_truth = _prompt_use_ground_truth()

            results, _ = run_sample_benchmark(domain, max_queries, use_ground_truth)
            if results:
                print("Custom benchmark completed!")
            else:
                print("Custom benchmark failed!")

        elif choice == "5":
            results_dir = input("Enter results directory path: ").strip()
            if not results_dir:
                results_dir = "./benchmark_results_finance"

            try:
                analyzer = BenchmarkAnalyzer(results_dir)
                analyzer.load_results()
                analyzer.generate_visualizations()
                analyzer.generate_detailed_report()
                print("Analysis completed!")
            except Exception as e:
                print(f"Analysis failed: {e}")

        elif choice == "6":
            print("Goodbye!")
            break

        else:
            print("Invalid choice. Please select 1-6.")

if __name__ == "__main__":
    # Win32 SetThreadExecutionState flags: keep the system and display awake
    # while a potentially long benchmark session is running.
    ES_CONTINUOUS = 0x80000000
    ES_SYSTEM_REQUIRED = 0x00000001
    ES_DISPLAY_REQUIRED = 0x00000002

    # ctypes.windll only exists on Windows. Elsewhere, skip the keep-awake
    # request instead of crashing with AttributeError before main() runs.
    set_execution_state = None
    if sys.platform == "win32":
        set_execution_state = ctypes.windll.kernel32.SetThreadExecutionState
        set_execution_state(ES_CONTINUOUS | ES_SYSTEM_REQUIRED | ES_DISPLAY_REQUIRED)

    try:
        main()
    finally:
        if set_execution_state is not None:
            # Passing ES_CONTINUOUS alone clears the keep-awake request so
            # normal power management resumes.
            set_execution_state(ES_CONTINUOUS)
