#!/usr/bin/env python3
"""
Multi-Agent Experiment Runner for RadAgents - Comparison Mode
Specialized for comparing current and prior chest X-ray images

Usage:
    # Simple comparison
    python run_multi_agent_experiment_comparison.py --current data/current.png --prior data/prior.png --query "Has the heart size changed?"

    # Concrete example (with smart tool loading disabled)
    python run_multi_agent_experiment_comparison.py --current data/cardic_example/mild_enlarge_cardi.png --prior data/cardic_example/normal_cxr.jpeg --query "Decide if cardiomegaly is improving, stable, or worsening." --no-smart-loading
    
    # Specific finding comparison
    python run_multi_agent_experiment_comparison.py --current data/current.png --prior data/prior.png --query "Compare the pleural effusion with prior"
    
    # Batch comparison experiments
    python run_multi_agent_experiment_comparison.py --config experiments/comparison_experiments.json
"""

import os
import sys
import json
import argparse
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional, List
import warnings

# Add project root to path
# (Concretely: the directory that contains this script is appended to sys.path.)
# NOTE(review): presumably this lets the `radagents` import further down resolve
# when this module is loaded from a different working directory — confirm.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Load environment variables (gracefully handle if dotenv not available)
try:
    from dotenv import load_dotenv
    load_dotenv()  # Load .env file if it exists
except ImportError:
    # python-dotenv is treated as optional: its absence is reported, not fatal.
    print("Note: python-dotenv not available. Using system environment variables only.")

# Suppress warnings for cleaner output
# This installs a blanket "ignore" filter (no category or module restriction),
# and it runs before the `radagents` import that follows in this file.
warnings.filterwarnings("ignore")

from radagents.multi_agent.interface import (
    initialize_multi_agent_system,
    run_multi_agent_analysis
)


def setup_logging(log_dir: str = "log") -> str:
    """Create the log directory if needed and return a timestamped log prefix.

    Args:
        log_dir: Directory in which experiment output is written. It is
            created when missing, including any missing parent directories.

    Returns:
        A path prefix of the form
        ``{log_dir}/comparison_experiment_{YYYYmmdd_HHMMSS}`` built from the
        current local time. The prefix itself is not created on disk; callers
        append their own suffix to it.
    """
    # parents=True so that a nested directory such as "runs/exp1/log" works on
    # first use instead of raising FileNotFoundError.
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{log_dir}/comparison_experiment_{timestamp}"


def save_comparison_results(results: Dict[str, Any], log_prefix: str) -> None:
    """Dump the experiment results to ``<log_prefix>_results.json``.

    Args:
        results: Result dictionary to serialize. Values that the JSON encoder
            cannot handle natively are converted with ``str``.
        log_prefix: Path prefix to which ``_results.json`` is appended.

    The destination path is echoed to stdout once the file has been written.
    """
    destination = "{}_results.json".format(log_prefix)

    # Serialize straight into the open handle, two-space indented.
    with open(destination, mode='w') as handle:
        json.dump(results, handle, indent=2, default=str)

    print("Results saved to: {}".format(destination))


def print_comparison_results(results: Dict[str, Any], query: str, current_path: str, prior_path: str) -> None:
    """Print a formatted comparison report to stdout.

    The report has four parts: a header (query, image paths and the
    ``query_intent`` / ``comparison_focus`` / ``activated_agents`` /
    ``execution_time`` entries of *results*), the top-level ``findings``, a
    per-agent prior-vs-current breakdown taken from ``agent_results``, and the
    final synthesis (``answer`` / ``confidence`` / ``needs_review``).

    Optional entries may be absent or ``None``; they are then rendered as
    empty / zero instead of raising.

    Args:
        results: Result dictionary of a comparison analysis.
        query: The query that was asked; echoed in the header.
        current_path: Path of the current image; echoed in the header.
        prior_path: Path of the prior image; echoed in the header.
    """
    # Entries inside a finding's ``measurements`` that are never displayed.
    hidden_measurements = ('vcot_disagreement', 'tool_confidence', 'final_confidence')
    # Per-agent cap on how many pathologies are printed in the detailed section.
    max_pathologies = 3

    print("\n" + "="*80)
    print("COMPARISON ANALYSIS RESULTS")
    print("="*80)

    comparison_focus = ', '.join(map(str, results.get('comparison_focus') or []))
    activated_agents = ', '.join(map(str, results.get('activated_agents') or []))
    execution_time = results.get('execution_time') or 0

    print(f"\nQuery: {query}")
    print(f"Current Image: {current_path}")
    print(f"Prior Image: {prior_path}")
    print(f"Query Intent: {results.get('query_intent', 'comparison')}")
    print(f"Comparison Focus: {comparison_focus}")
    print(f"Activated Agents: {activated_agents}")
    print(f"Execution Time: {execution_time:.2f} seconds")

    # Print comparison findings
    top_findings = results.get('findings') or []
    if top_findings:
        print(f"\n{'COMPARISON FINDINGS':<20}:")
        for i, finding in enumerate(top_findings, 1):
            temporal = finding.get('temporal_context', '')
            temporal_label = f" [{str(temporal).upper()}]" if temporal else ""
            confidence = finding.get('confidence') or 0
            print(f"  {i}. Pathology: {finding.get('pathology', 'Unknown')}{temporal_label}")
            print(f"     Confidence: {confidence:.3f}")
            print(f"     Evidence: {finding.get('evidence', 'None')}")
            # Show measurements if available
            for key, val in (finding.get('measurements') or {}).items():
                if key not in hidden_measurements:
                    print(f"     {key}: {val}")

    # Print agent-specific comparisons
    agent_results = results.get('agent_results') or {}
    if agent_results:
        print(f"\n{'DETAILED AGENT COMPARISONS':<30}:")
        for agent_name, agent_result in agent_results.items():
            # Only agents that report having performed a comparison are shown.
            if not (isinstance(agent_result, dict) and agent_result.get('comparison_performed')):
                continue

            print(f"\n  {agent_name.upper()}:")
            agent_findings = agent_result.get('findings') or []

            # Summary: number of findings per timepoint.
            if agent_findings:
                temporal_counts = {'current': 0, 'prior': 0, 'unknown': 0}
                for finding in agent_findings:
                    temporal = finding.get('temporal_context') or 'unknown'
                    temporal_counts[temporal] = temporal_counts.get(temporal, 0) + 1

                print("    Summary:")
                if temporal_counts['current'] > 0:
                    print(f"      - Current image: {temporal_counts['current']} finding(s)")
                if temporal_counts['prior'] > 0:
                    print(f"      - Prior image: {temporal_counts['prior']} finding(s)")
                if temporal_counts['unknown'] > 0:
                    print(f"      - Unspecified: {temporal_counts['unknown']} finding(s)")

            # Pair up findings by pathology. When several findings share the
            # same pathology and timepoint, the last one wins.
            pathology_groups = {}
            for finding in agent_findings:
                pathology = finding.get('pathology', 'Unknown')
                temporal = finding.get('temporal_context', 'unknown')
                slots = pathology_groups.setdefault(pathology, {'current': None, 'prior': None})
                if temporal in ('current', 'prior'):
                    slots[temporal] = finding

            shown_groups = list(pathology_groups.items())[:max_pathologies]
            for pathology, temporal_findings in shown_groups:
                current = temporal_findings['current']
                prior = temporal_findings['prior']

                if current and prior:
                    # Both timepoints - show comparison
                    curr_conf = current.get('confidence') or 0
                    prior_conf = prior.get('confidence') or 0
                    print(f"    {pathology}:")
                    print(f"      Prior: {prior.get('evidence', 'N/A')} (conf: {prior_conf:.1%})")
                    print(f"      Current: {current.get('evidence', 'N/A')} (conf: {curr_conf:.1%})")

                    # Show measurements that differ. ``or {}`` guards against an
                    # explicit None on either side; dict.fromkeys de-duplicates
                    # while keeping a stable (prior first, then current) order.
                    curr_meas = current.get('measurements') or {}
                    prior_meas = prior.get('measurements') or {}
                    for key in dict.fromkeys([*prior_meas, *curr_meas]):
                        if key in hidden_measurements:
                            continue
                        prior_val = prior_meas.get(key, 'N/A')
                        curr_val = curr_meas.get(key, 'N/A')
                        if prior_val != curr_val:
                            print(f"      {key}: {prior_val} → {curr_val}")
                elif current:
                    # New finding
                    print(f"    {pathology}: NEW")
                    print(f"      {current.get('evidence', 'N/A')}")
                elif prior:
                    # Resolved finding
                    print(f"    {pathology}: RESOLVED")
                    print(f"      Was: {prior.get('evidence', 'N/A')}")

            # Make the truncation visible instead of silently dropping entries.
            not_shown = len(pathology_groups) - len(shown_groups)
            if not_shown > 0:
                print(f"    ... {not_shown} more pathology group(s) not shown")

    # Print final synthesis
    if 'answer' in results:
        overall_confidence = results.get('confidence') or 0
        print(f"\n{'SYNTHESIS':<20}: {results['answer']}")
        print(f"{'CONFIDENCE':<20}: {overall_confidence:.3f}")
        # A missing flag is reported as "Yes".
        print(f"{'NEEDS REVIEW':<20}: {'Yes' if results.get('needs_review', True) else 'No'}")

    print("="*80)


def run_comparison_experiment(
    current_image_path: str,
    prior_image_path: str,
    query: str,
    model: str = "gpt-4o",
    temperature: float = 0.7,
    required_agents: Optional[List[str]] = None,
    execution_mode: str = "sequential",  # Sequential often better for comparison
    smart_tool_loading: bool = True,
    device: str = "cuda",
    model_dir: str = "./models",
    temp_dir: str = "temp",
    save_logs: bool = True,
    **kwargs
) -> Dict[str, Any]:
    """
    Compare a current chest X-ray against a prior one with the multi-agent system.

    Both image paths are checked first; the system is then initialized, the
    analysis is run with the prior image attached, and an
    ``experiment_metadata`` entry describing the run is added to the results.

    Args:
        current_image_path: Path to the current chest X-ray image
        prior_image_path: Path to the prior chest X-ray image
        query: Medical query about comparison
        model: LLM model to use
        temperature: LLM temperature
        required_agents: List of agents to use (None = auto-detect)
        execution_mode: "parallel" or "sequential"
        smart_tool_loading: Whether to use smart tool loading
        device: Device for model execution
        model_dir: Directory containing model weights
        temp_dir: Directory for temporary files
        save_logs: Whether to save the results under a timestamped log prefix
        **kwargs: Extra keyword arguments forwarded to the system initializer

    Returns:
        Dictionary containing comparison analysis results

    Raises:
        FileNotFoundError: If either image path does not exist.
        Exception: Anything raised while initializing, analysing or saving is
            printed and then re-raised unchanged.
    """
    # Fail fast on missing inputs; the current image is checked before the prior.
    for label, image_path in (("Current", current_image_path), ("Prior", prior_image_path)):
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"{label} image not found: {image_path}")

    for banner_line in (
        "Starting comparison experiment...",
        f"Current Image: {current_image_path}",
        f"Prior Image: {prior_image_path}",
        f"Query: {query}",
        f"Model: {model}",
        f"Execution Mode: {execution_mode}",
    ):
        print(banner_line)

    # A log prefix only exists when log saving was requested.
    log_prefix = setup_logging() if save_logs else None
    if save_logs:
        print(f"Log prefix: {log_prefix}")

    started_at = time.time()

    try:
        print("\nInitializing multi-agent system...")
        orchestrator = initialize_multi_agent_system(
            model=model,
            temperature=temperature,
            model_dir=model_dir,
            temp_dir=temp_dir,
            device=device,
            execution_mode=execution_mode,
            smart_tool_loading=smart_tool_loading,
            required_agents=required_agents,
            **kwargs
        )

        print("\nRunning comparison analysis...")
        results = run_multi_agent_analysis(
            orchestrator=orchestrator,
            query=query,
            image_path=current_image_path,
            prior_image_path=prior_image_path,  # the prior study to compare against
        )

        # Record how this run was configured alongside its results.
        results['experiment_metadata'] = dict(
            current_image_path=current_image_path,
            prior_image_path=prior_image_path,
            query=query,
            model=model,
            temperature=temperature,
            execution_mode=execution_mode,
            smart_tool_loading=smart_tool_loading,
            required_agents=required_agents,
            timestamp=datetime.now().isoformat(),
            total_experiment_time=time.time() - started_at,
            comparison_mode=True,
        )

        if log_prefix:
            save_comparison_results(results, log_prefix)

        return results

    except Exception as exc:
        print(f"Error during comparison experiment: {exc}")
        raise


def load_experiment_config(config_path: str) -> Dict[str, Any]:
    """Load an experiment configuration from a JSON file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly instead of relying on the platform's locale encoding.
    # "utf-8-sig" reads plain UTF-8 and also strips a leading byte-order mark,
    # which json.load would otherwise reject.
    with open(config_path, 'r', encoding='utf-8-sig') as f:
        return json.load(f)


def run_multiple_comparisons(experiments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Run a batch of comparison experiments, continuing past failures.

    Args:
        experiments: One keyword-argument dictionary per experiment; each is
            expanded into a ``run_comparison_experiment`` call.

    Returns:
        Exactly one entry per experiment, in input order: the experiment's
        result dictionary on success, or ``{'error': <message>,
        'experiment_config': <config>}`` when the experiment raised.
    """
    results = []
    total = len(experiments)

    for i, exp_config in enumerate(experiments, 1):
        print(f"\n{'='*60}")
        print(f"RUNNING COMPARISON EXPERIMENT {i}/{total}")
        print(f"{'='*60}")

        try:
            result = run_comparison_experiment(**exp_config)
        except Exception as e:
            print(f"Experiment {i} failed: {e}")
            results.append({
                'error': str(e),
                'experiment_config': exp_config
            })
            continue

        results.append(result)

        # Displaying is guarded separately: a formatting problem must neither
        # abort the batch nor add a second (error) entry for an experiment
        # whose result has already been recorded.
        try:
            print_comparison_results(
                result,
                exp_config['query'],
                exp_config['current_image_path'],
                exp_config['prior_image_path']
            )
        except Exception as e:
            print(f"Warning: could not display results for experiment {i}: {e}")

    return results


def main():
    """Command-line entry point for comparison experiments.

    Two mutually exclusive modes are supported:

    * ``--current`` / ``--prior`` / ``--query`` run a single experiment and,
      unless ``--quiet`` is given, print its report.
    * ``--config`` loads a JSON file and runs a batch. The file may be an
      object with an ``experiments`` list, a top-level list of experiments,
      or a single experiment object.

    The process exits with status 2 on invalid arguments (argparse), and with
    status 1 when the single experiment raises or when at least one batch
    experiment failed.
    """
    parser = argparse.ArgumentParser(
        description="Run RadAgents Comparison Experiments",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Simple comparison
  python run_multi_agent_experiment_comparison.py \\
    --current data/current_cxr.png \\
    --prior data/prior_cxr.png \\
    --query "Has the heart size changed?"
  
  # Specific finding comparison
  python run_multi_agent_experiment_comparison.py \\
    --current data/current_cxr.png \\
    --prior data/prior_cxr.png \\
    --query "Compare the pleural effusion with the prior study"
  
  # Load from config file
  python run_multi_agent_experiment_comparison.py --config experiments/comparison_experiments.json
  
  # Comparison with specific agents
  python run_multi_agent_experiment_comparison.py \\
    --current data/current.png \\
    --prior data/prior.png \\
    --query "Compare cardiac and breathing findings" \\
    --agents cardiac breathing synthesis
        """
    )

    # Main experiment arguments
    parser.add_argument('--current', type=str, help='Path to current chest X-ray image')
    parser.add_argument('--prior', type=str, help='Path to prior chest X-ray image')
    parser.add_argument('--query', type=str, help='Comparison query about the images')
    parser.add_argument('--config', type=str, help='Path to experiment configuration JSON file')

    # Model configuration
    parser.add_argument('--model', type=str, default='gpt-4o',
                       help='LLM model to use (default: gpt-4o)')
    parser.add_argument('--temperature', type=float, default=0.7,
                       help='LLM temperature (default: 0.7)')

    # System configuration
    parser.add_argument('--agents', nargs='*', default=None,
                       help='Required agents (default: auto-detect based on query)')
    parser.add_argument('--execution-mode', choices=['parallel', 'sequential'],
                       default='sequential', help='Agent execution mode (default: sequential for comparison)')
    parser.add_argument('--device', type=str, default='cuda',
                       help='Device for model execution')
    parser.add_argument('--model-dir', type=str, default='./models',
                       help='Directory containing model weights')
    parser.add_argument('--temp-dir', type=str, default='temp',
                       help='Directory for temporary files')

    # Utility options
    parser.add_argument('--no-smart-loading', action='store_true',
                       help='Disable smart tool loading')
    parser.add_argument('--no-logs', action='store_true',
                       help='Disable log saving')
    parser.add_argument('--quiet', action='store_true',
                       help='Suppress detailed output')

    args = parser.parse_args()

    # Validate arguments
    if not args.config and not (args.current and args.prior and args.query):
        parser.error("Either --config or all of --current, --prior, and --query must be provided")

    if args.config and (args.current or args.prior or args.query):
        parser.error("Cannot use --config with --current/--prior/--query")

    # Batch bookkeeping; both stay at zero in single-experiment mode.
    total_count = 0
    failed_count = 0

    try:
        if args.config:
            # Load and run multiple experiments from config
            print(f"Loading comparison experiments from: {args.config}")
            config = load_experiment_config(args.config)

            if isinstance(config, dict) and 'experiments' in config:
                experiments = config['experiments']
            elif isinstance(config, list):
                experiments = config  # Top-level list of experiment configs
            else:
                experiments = [config]  # Single experiment config

            results = run_multiple_comparisons(experiments)

            # A failed experiment is recorded as a dict carrying both 'error'
            # and 'experiment_config'; count those so the exit status reflects
            # the batch outcome.
            total_count = len(results)
            failed_count = sum(
                1 for entry in results
                if isinstance(entry, dict) and 'error' in entry and 'experiment_config' in entry
            )

        else:
            # Run single comparison experiment
            experiment_args = {
                'current_image_path': args.current,
                'prior_image_path': args.prior,
                'query': args.query,
                'model': args.model,
                'temperature': args.temperature,
                'required_agents': args.agents,
                'execution_mode': args.execution_mode,
                'smart_tool_loading': not args.no_smart_loading,
                'device': args.device,
                'model_dir': args.model_dir,
                'temp_dir': args.temp_dir,
                'save_logs': not args.no_logs
            }

            result = run_comparison_experiment(**experiment_args)

            # --quiet is only consulted here, i.e. in single-experiment mode.
            if not args.quiet:
                print_comparison_results(
                    result,
                    args.query,
                    args.current,
                    args.prior
                )

    except Exception as e:
        print(f"Comparison experiment failed: {e}")
        sys.exit(1)

    if failed_count:
        # Do not report success when part of the batch failed.
        print(f"\n{failed_count} of {total_count} comparison experiment(s) failed.")
        sys.exit(1)

    print("\nComparison experiment completed successfully!")

if __name__ == "__main__":
    main()

# Example usage:
# python run_multi_agent_experiment_comparison.py --current data/cardiac_example/current.png --prior data/cardiac_example/prior.png --query "Has the heart size changed compared to the prior study?"
