#!/usr/bin/env python3
"""
Experimental Results Analyzer for LinearizeLLM

This script analyzes the results from the experimental setup comparing
three different information scenarios:
1. No Context: Only LaTeX optimization problem
2. Partial Information: No parameter info to detection/reformulation agents
3. Full Information: Complete parameter and context information
"""

import json
import os
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Any, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict


class ContextExperimentResultsAnalyzer:
    """
    Analyzer for context experiment results comparing different information scenarios.
    """
    
    def __init__(self, results_dir: str = "data/context_experiment_results"):
        """
        Initialize the results analyzer.
        
        Args:
            results_dir: Directory containing experimental results
        """
        self.results_dir = Path(results_dir)
        self.scenarios = ['no_context', 'partial_info', 'full_info']
        self.scenario_names = {
            'no_context': 'No Context',
            'partial_info': 'Partial Information',
            'full_info': 'Full Information'
        }
        
        # Load context experiment summary
        self.summary_file = self.results_dir / "context_experiment_summary.json"
        self.context_experiment_summary = self._load_context_experiment_summary()
        
        # Load all individual results
        self.all_results = self._load_all_results()
        
        print(f"📊 LinearizeLLM Context Experiment Results Analyzer Initialized")
        print(f"   📁 Results Directory: {self.results_dir}")
        print(f"   📈 Total Experiments: {len(self.all_results)}")
        print(f"   🔍 Scenarios: {list(self.scenario_names.keys())}")
    
    def _load_context_experiment_summary(self) -> Dict[str, Any]:
        """Load the context experiment summary file."""
        if self.summary_file.exists():
            with open(self.summary_file, 'r') as f:
                return json.load(f)
        else:
            print(f"⚠️ Warning: Context experiment summary not found at {self.summary_file}")
            return {}
    
    def _load_all_results(self) -> Dict[str, Dict[str, Any]]:
        """Load all individual experiment results."""
        results = {}
        
        for scenario in self.scenarios:
            scenario_dir = self.results_dir / scenario
            if not scenario_dir.exists():
                continue
                
            for instance_dir in scenario_dir.iterdir():
                if not instance_dir.is_dir():
                    continue
                    
                for seed_dir in instance_dir.iterdir():
                    if not seed_dir.is_dir() or not seed_dir.name.startswith('seed_'):
                        continue
                    
                    # Load context experiment results
                    results_file = seed_dir / "context_experiment_results.json"
                    error_file = seed_dir / "context_experiment_error.json"
                    
                    if results_file.exists():
                        with open(results_file, 'r') as f:
                            result_data = json.load(f)
                            key = f"{instance_dir.name}_{scenario}_{seed_dir.name}"
                            results[key] = result_data
                    elif error_file.exists():
                        with open(error_file, 'r') as f:
                            error_data = json.load(f)
                            key = f"{instance_dir.name}_{scenario}_{seed_dir.name}"
                            results[key] = error_data
        
        return results
    
    def create_results_dataframe(self) -> pd.DataFrame:
        """
        Create a pandas DataFrame with all experimental results.

        Each entry of ``self.all_results`` whose key has the form
        ``"<instance>_<scenario>_seed_<N>"`` becomes one row. Entries whose
        key cannot be parsed that way are skipped.

        Returns:
            DataFrame with one row per experiment. The optimization and
            pattern columns are only populated for results that contain the
            corresponding section, so they may hold missing values.
        """
        data = []

        for key, result in self.all_results.items():
            parsed = self._split_result_key(key)
            if parsed is None:
                continue
            instance, scenario, seed = parsed

            failed = 'error' in result
            metadata = result.get('context_experiment_metadata', {}) or {}

            row = {
                'instance': instance,
                'scenario': scenario,
                'scenario_name': self.scenario_names[scenario],
                'seed': seed,
                'success': not failed,
                'timestamp': metadata.get('timestamp', ''),
                'llm_model': metadata.get('llm_model', '')
            }

            # Extract optimization results if available
            if 'optimization_results' in result:
                opt_results = result['optimization_results']
                row.update({
                    'optimization_success': opt_results.get('success', False),
                    'objective_value': opt_results.get('objective_value'),
                    'solve_time': opt_results.get('solve_time'),
                    'status': opt_results.get('status', ''),
                    'error_message': opt_results.get('error', '')
                })

            # Pattern flags are derived from marker substrings in the
            # detection output; each flag records presence, not a count.
            if 'extracted_patterns' in result:
                pattern_text = result['extracted_patterns']
                row.update({
                    'has_nonlinearities': 'NON-LINEARITIES DETECTED: YES' in pattern_text,
                    'bilinear_patterns': 'BILINEAR_PATTERNS:' in pattern_text,
                    'min_patterns': 'MIN_PATTERNS:' in pattern_text,
                    'max_patterns': 'MAX_PATTERNS:' in pattern_text,
                    'absolute_patterns': 'ABSOLUTE_PATTERNS:' in pattern_text,
                    'quotient_patterns': 'QUOTIENT_PATTERNS:' in pattern_text,
                    'monotone_patterns': 'MONOTONE_TRANSFORMATION_PATTERNS:' in pattern_text
                })

            # Add error information if experiment failed
            if failed:
                row['error'] = result['error']
            row['experiment_failed'] = failed

            data.append(row)

        return pd.DataFrame(data)

    def _split_result_key(self, key: str):
        """
        Split a result key into ``(instance, scenario, seed)``.

        Scenario names themselves contain underscores, so the key cannot be
        tokenised on ``_``. Instead the last ``_<scenario>_seed_`` marker is
        located, which also tolerates underscores in the instance name.

        Returns:
            The parsed triple, or ``None`` when no known scenario marker with
            a non-empty instance and an all-digit seed is present.
        """
        for scenario in self.scenarios:
            instance, marker, seed_text = key.rpartition(f"_{scenario}_seed_")
            if not marker or not instance:
                continue
            # int() would also accept forms such as "1_0"; require plain digits.
            if seed_text.isascii() and seed_text.isdigit():
                return instance, scenario, int(seed_text)
        return None
    
    def analyze_scenario_performance(self, df: pd.DataFrame) -> Dict[str, Any]:
        """
        Analyze performance across different scenarios.
        
        Args:
            df: DataFrame with experimental results
            
        Returns:
            Dictionary with performance analysis
        """
        analysis = {}
        
        # Overall success rates
        success_rates = df.groupby('scenario')['success'].agg(['mean', 'count']).round(3)
        analysis['success_rates'] = success_rates.to_dict()
        
        # Optimization success rates
        if 'optimization_success' in df.columns:
            opt_success_rates = df.groupby('scenario')['optimization_success'].agg(['mean', 'count']).round(3)
            analysis['optimization_success_rates'] = opt_success_rates.to_dict()
        
        # Pattern detection analysis
        if 'has_nonlinearities' in df.columns:
            pattern_detection = df.groupby('scenario')['has_nonlinearities'].agg(['mean', 'count']).round(3)
            analysis['pattern_detection_rates'] = pattern_detection.to_dict()
        
        # Solve time analysis (for successful optimizations)
        if 'solve_time' in df.columns:
            solve_times = df[df['optimization_success'] == True].groupby('scenario')['solve_time'].agg(['mean', 'std', 'count'])
            analysis['solve_time_analysis'] = solve_times.to_dict()
        
        # Error analysis
        error_counts = df.groupby('scenario')['experiment_failed'].sum()
        analysis['error_counts'] = error_counts.to_dict()
        
        return analysis
    
    def analyze_instance_performance(self, df: pd.DataFrame) -> Dict[str, Any]:
        """
        Analyze performance across different instances.
        
        Args:
            df: DataFrame with experimental results
            
        Returns:
            Dictionary with instance-specific analysis
        """
        analysis = {}
        
        # Success rates by instance and scenario
        instance_success = df.groupby(['instance', 'scenario'])['success'].agg(['mean', 'count']).round(3)
        analysis['instance_success_rates'] = instance_success.to_dict()
        
        # Optimization success by instance and scenario
        if 'optimization_success' in df.columns:
            instance_opt_success = df.groupby(['instance', 'scenario'])['optimization_success'].agg(['mean', 'count']).round(3)
            analysis['instance_optimization_success_rates'] = instance_opt_success.to_dict()
        
        # Pattern detection by instance
        if 'has_nonlinearities' in df.columns:
            instance_patterns = df.groupby(['instance', 'scenario'])['has_nonlinearities'].agg(['mean', 'count']).round(3)
            analysis['instance_pattern_detection_rates'] = instance_patterns.to_dict()
        
        return analysis
    
    def create_performance_plots(self, df: pd.DataFrame, output_dir: str = "analysis_plots"):
        """
        Create performance comparison plots.

        Writes up to three PNG files into ``output_dir``:
        ``scenario_comparison.png`` (always),
        ``instance_performance_heatmap.png`` (when more than one instance is
        present) and ``solve_time_comparison.png`` (when at least one
        successful optimization has a solve time).

        Args:
            df: DataFrame with experimental results
            output_dir: Directory to save plots; created, including missing
                parent directories, if it does not exist
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # The seaborn style sheet is named 'seaborn-v0_8' in newer matplotlib
        # releases and 'seaborn' in older ones; use whichever is installed and
        # keep the default style if neither is.
        for style_name in ('seaborn-v0_8', 'seaborn'):
            if style_name in plt.style.available:
                plt.style.use(style_name)
                break
        sns.set_palette("husl")

        scenario_colors = ['#ff7f0e', '#2ca02c', '#1f77b4']

        def rate_by_scenario(column: str) -> pd.Series:
            # Optional boolean columns may be object dtype (True/False/NaN);
            # cast to float so the mean skips missing entries.
            return df[column].astype(float).groupby(df['scenario']).mean()

        # 1. Success Rate Comparison
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('LinearizeLLM Context Experiment Results - Scenario Comparison', fontsize=16)

        # Overall success rates
        success_rates = rate_by_scenario('success')
        axes[0, 0].bar(success_rates.index, success_rates.values, color=scenario_colors)
        axes[0, 0].set_title('Overall Success Rate')
        axes[0, 0].set_ylabel('Success Rate')
        axes[0, 0].set_ylim(0, 1)

        # Optimization success rates
        if 'optimization_success' in df.columns:
            opt_success_rates = rate_by_scenario('optimization_success')
            axes[0, 1].bar(opt_success_rates.index, opt_success_rates.values, color=scenario_colors)
            axes[0, 1].set_title('Optimization Success Rate')
            axes[0, 1].set_ylabel('Success Rate')
            axes[0, 1].set_ylim(0, 1)

        # Pattern detection rates
        if 'has_nonlinearities' in df.columns:
            pattern_rates = rate_by_scenario('has_nonlinearities')
            axes[1, 0].bar(pattern_rates.index, pattern_rates.values, color=scenario_colors)
            axes[1, 0].set_title('Pattern Detection Rate')
            axes[1, 0].set_ylabel('Detection Rate')
            axes[1, 0].set_ylim(0, 1)

        # Error counts
        error_counts = df.groupby('scenario')['experiment_failed'].sum()
        axes[1, 1].bar(error_counts.index, error_counts.values, color=['#d62728', '#d62728', '#d62728'])
        axes[1, 1].set_title('Error Count')
        axes[1, 1].set_ylabel('Number of Errors')

        plt.tight_layout()
        plt.savefig(output_path / 'scenario_comparison.png', dpi=300, bbox_inches='tight')
        plt.close()

        # 2. Instance-specific performance heatmap
        if df['instance'].nunique() > 1:
            fig, axes = plt.subplots(1, 3, figsize=(18, 6))
            fig.suptitle('Instance-Specific Performance by Scenario', fontsize=16)

            for i, scenario in enumerate(self.scenarios):
                scenario_data = df[df['scenario'] == scenario]
                if len(scenario_data) > 0:
                    instance_success = scenario_data.groupby('instance')['success'].mean()

                    # One-row heatmap: a cell per instance for this scenario.
                    heatmap_data = instance_success.values.reshape(1, -1)
                    sns.heatmap(heatmap_data,
                              xticklabels=instance_success.index,
                              yticklabels=[self.scenario_names[scenario]],
                              annot=True, fmt='.2f', cmap='RdYlGn',
                              ax=axes[i], cbar_kws={'label': 'Success Rate'})
                    axes[i].set_title(self.scenario_names[scenario])
                    axes[i].set_xlabel('Instance')

            plt.tight_layout()
            plt.savefig(output_path / 'instance_performance_heatmap.png', dpi=300, bbox_inches='tight')
            plt.close()

        # 3. Solve time comparison (if available)
        if 'solve_time' in df.columns and df['solve_time'].notna().any():
            solved = df[df['optimization_success'] == True]
            solved = solved[solved['solve_time'].notna()]

            # Plot only scenarios that actually have data, in canonical order,
            # so each tick label is guaranteed to describe the box above it.
            present = set(solved['scenario'])
            plotted = [s for s in self.scenarios if s in present]

            # Create the figure only once there is something to draw, so no
            # empty figure is left open.
            if plotted:
                fig, ax = plt.subplots(figsize=(10, 6))
                sns.boxplot(data=solved, x='scenario', y='solve_time', order=plotted, ax=ax)
                ax.set_title('Solve Time Comparison (Successful Optimizations)')
                ax.set_xlabel('Scenario')
                ax.set_ylabel('Solve Time (seconds)')
                ax.set_xticks(range(len(plotted)))
                ax.set_xticklabels([self.scenario_names[s] for s in plotted])

                plt.tight_layout()
                plt.savefig(output_path / 'solve_time_comparison.png', dpi=300, bbox_inches='tight')
                plt.close()
    
    def generate_summary_report(self, df: pd.DataFrame, output_file: str = "experimental_analysis_report.md"):
        """
        Generate a comprehensive summary report.
        
        Args:
            df: DataFrame with experimental results
            output_file: Output file for the report
        """
        # Perform analyses
        scenario_analysis = self.analyze_scenario_performance(df)
        instance_analysis = self.analyze_instance_performance(df)
        
        # Generate report
        report = []
        report.append("# LinearizeLLM Context Experiment Analysis Report")
        report.append("")
        report.append(f"**Generated:** {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report.append(f"**Total Experiments:** {len(df)}")
        report.append(f"**Instances:** {len(df['instance'].unique())}")
        report.append(f"**Scenarios:** {len(df['scenario'].unique())}")
        report.append(f"**Seeds per combination:** {len(df['seed'].unique())}")
        report.append("")
        
        # Overall statistics
        report.append("## Overall Statistics")
        report.append("")
        
        total_experiments = len(df)
        successful_experiments = df['success'].sum()
        failed_experiments = total_experiments - successful_experiments
        
        report.append(f"- **Total Experiments:** {total_experiments}")
        report.append(f"- **Successful Experiments:** {successful_experiments} ({successful_experiments/total_experiments*100:.1f}%)")
        report.append(f"- **Failed Experiments:** {failed_experiments} ({failed_experiments/total_experiments*100:.1f}%)")
        report.append("")
        
        # Scenario comparison
        report.append("## Scenario Performance Comparison")
        report.append("")
        
        for scenario in self.scenarios:
            scenario_data = df[df['scenario'] == scenario]
            scenario_success = scenario_data['success'].mean()
            scenario_count = len(scenario_data)
            
            report.append(f"### {self.scenario_names[scenario]}")
            report.append(f"- **Success Rate:** {scenario_success:.3f} ({scenario_success*100:.1f}%)")
            report.append(f"- **Total Experiments:** {scenario_count}")
            
            if 'optimization_success' in df.columns:
                opt_success = scenario_data['optimization_success'].mean()
                report.append(f"- **Optimization Success Rate:** {opt_success:.3f} ({opt_success*100:.1f}%)")
            
            if 'has_nonlinearities' in df.columns:
                pattern_detection = scenario_data['has_nonlinearities'].mean()
                report.append(f"- **Pattern Detection Rate:** {pattern_detection:.3f} ({pattern_detection*100:.1f}%)")
            
            report.append("")
        
        # Instance performance
        report.append("## Instance Performance")
        report.append("")
        
        instance_summary = df.groupby('instance').agg({
            'success': ['mean', 'count'],
            'optimization_success': ['mean', 'count'] if 'optimization_success' in df.columns else ['count', 'count']
        }).round(3)
        
        report.append("| Instance | Success Rate | Total | Opt Success Rate | Opt Total |")
        report.append("|----------|-------------|-------|------------------|-----------|")
        
        for instance in df['instance'].unique():
            instance_data = df[df['instance'] == instance]
            success_rate = instance_data['success'].mean()
            total_count = len(instance_data)
            
            if 'optimization_success' in df.columns:
                opt_success_rate = instance_data['optimization_success'].mean()
                opt_total = instance_data['optimization_success'].count()
            else:
                opt_success_rate = 0
                opt_total = 0
            
            report.append(f"| {instance} | {success_rate:.3f} | {total_count} | {opt_success_rate:.3f} | {opt_total} |")
        
        report.append("")
        
        # Error analysis
        report.append("## Error Analysis")
        report.append("")
        
        if 'error' in df.columns:
            error_data = df[df['experiment_failed'] == True]
            if len(error_data) > 0:
                error_types = error_data['error'].value_counts()
                report.append("### Common Error Types")
                report.append("")
                for error_type, count in error_types.head(10).items():
                    report.append(f"- **{error_type}:** {count} occurrences")
                report.append("")
        
        # Save report
        with open(output_file, 'w') as f:
            f.write('\n'.join(report))
        
        print(f"📄 Analysis report saved to: {output_file}")
    
    def run_complete_analysis(self, output_dir: str = "analysis_results"):
        """
        Run complete analysis and generate all outputs.
        
        Args:
            output_dir: Directory to save analysis results
        """
        output_path = Path(output_dir)
        output_path.mkdir(exist_ok=True)
        
        print("🔍 Running complete context experiment analysis...")
        
        # Create DataFrame
        df = self.create_results_dataframe()
        
        if len(df) == 0:
            print("❌ No context experiment results found!")
            return
        
        # Save DataFrame
        df.to_csv(output_path / "context_experiment_results.csv", index=False)
        print(f"📊 Results DataFrame saved to: {output_path / 'context_experiment_results.csv'}")
        
        # Perform analyses
        scenario_analysis = self.analyze_scenario_performance(df)
        instance_analysis = self.analyze_instance_performance(df)
        
        # Save analysis results
        with open(output_path / "context_scenario_analysis.json", 'w') as f:
            json.dump(scenario_analysis, f, indent=2, default=str)
        
        with open(output_path / "context_instance_analysis.json", 'w') as f:
            json.dump(instance_analysis, f, indent=2, default=str)
        
        # Create plots
        plots_dir = output_path / "plots"
        self.create_performance_plots(df, str(plots_dir))
        
        # Generate report
        report_file = output_path / "context_experiment_analysis_report.md"
        self.generate_summary_report(df, str(report_file))
        
        print(f"✅ Complete analysis saved to: {output_path}")
        
        # Print summary
        print("\n📈 Analysis Summary:")
        print(f"   📊 Total experiments: {len(df)}")
        print(f"   ✅ Successful: {df['success'].sum()} ({df['success'].mean()*100:.1f}%)")
        print(f"   ❌ Failed: {(~df['success']).sum()} ({(~df['success']).mean()*100:.1f}%)")
        
        if 'optimization_success' in df.columns:
            print(f"   🎯 Optimization successful: {df['optimization_success'].sum()} ({df['optimization_success'].mean()*100:.1f}%)")


def main():
    """Parse the command-line options and run the full analysis pipeline."""
    cli = argparse.ArgumentParser(
        description='Analyze LinearizeLLM context experiment results'
    )
    cli.add_argument(
        '--results-dir',
        type=str,
        default="data/context_experiment_results",
        help='Directory containing context experiment results',
    )
    cli.add_argument(
        '--output-dir',
        type=str,
        default="context_analysis_results",
        help='Directory to save context experiment analysis results',
    )
    # NOTE(review): the two switches below are accepted but never consulted;
    # plots and the report are always produced by run_complete_analysis.
    cli.add_argument(
        '--create-plots',
        action='store_true',
        help='Create performance plots',
    )
    cli.add_argument(
        '--generate-report',
        action='store_true',
        help='Generate analysis report',
    )
    options = cli.parse_args()

    # Building the analyzer loads every result from disk; the analysis run
    # then writes all outputs below the chosen directory.
    ContextExperimentResultsAnalyzer(
        results_dir=options.results_dir
    ).run_complete_analysis(output_dir=options.output_dir)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main() 