"""Results aggregation and reporting for batch evaluations."""

import json
import pandas as pd
from pathlib import Path
from typing import Dict, List, Any, Optional
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns


class ResultsAggregator:
    """Aggregates and analyzes evaluation results across models and problems.

    Results are plain dictionaries. The keys read by this class are
    ``model_name``, ``problem_id``, ``is_correct``, ``success``,
    ``follow_format``, ``search_complete``, ``duration``, ``total_tokens``
    and ``error``; every key is optional and missing keys count as falsy.
    """

    def __init__(self, output_dir: str):
        """Initialize results aggregator.

        Args:
            output_dir: Directory for output files. ``~`` is expanded and the
                directory (including parents) is created if it is missing.
        """
        self.output_dir = Path(output_dir).expanduser()
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Raw result dictionaries in the order they were added.
        self.results_data: List[Dict[str, Any]] = []
        # Cache of per-model summaries, filled by compute_model_summary().
        self.model_summaries: Dict[str, Dict[str, Any]] = {}

    def add_result(self, result: Dict[str, Any]):
        """Add a single evaluation result.

        Args:
            result: Evaluation result dictionary
        """
        self.results_data.append(result)

    def add_batch_results(self, results: List[Dict[str, Any]]):
        """Add multiple evaluation results.

        Args:
            results: List of evaluation result dictionaries
        """
        self.results_data.extend(results)

    @staticmethod
    def _ratio(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or 0 when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else 0

    @staticmethod
    def _json_default(value: Any) -> Any:
        """Fallback encoder for ``json.dump``.

        Converts objects exposing ``tolist()`` (NumPy scalars and arrays) to
        native Python values; anything else is rejected like the stock encoder
        would, so unexpected types are not silently stringified.
        """
        to_list = getattr(value, 'tolist', None)
        if callable(to_list):
            return to_list()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    def compute_model_summary(self, model_name: str) -> Dict[str, Any]:
        """Compute summary statistics for a specific model.

        The summary is also stored in ``self.model_summaries[model_name]``.

        Note that ``accuracy`` and ``follow_format_rate`` are divided by the
        number of *successful* evaluations, whereas the subgroup accuracies
        (``all_docs_found_accuracy`` / ``not_all_docs_found_accuracy``) are
        divided by the number of results in the subgroup.

        Args:
            model_name: Name of the model

        Returns:
            Dictionary with model statistics, or an empty dict when no result
            carries this model name (nothing is cached in that case).
        """
        model_results = [r for r in self.results_data if r.get('model_name') == model_name]
        if not model_results:
            return {}

        def count_correct(rows: List[Dict[str, Any]]) -> int:
            return sum(1 for r in rows if r.get('is_correct'))

        total = len(model_results)
        correct = count_correct(model_results)
        successful = sum(1 for r in model_results if r.get('success'))
        follow_format = sum(1 for r in model_results if r.get('follow_format', False))

        # Subgroups: did the search find every required document or not?
        docs_found = [r for r in model_results if r.get('search_complete', False)]
        docs_missing = [r for r in model_results if not r.get('search_complete', False)]

        found_total = len(docs_found)
        found_correct = count_correct(docs_found)
        missing_total = len(docs_missing)
        missing_correct = count_correct(docs_missing)
        found_proportion = self._ratio(found_total, total)

        summary: Dict[str, Any] = {
            'model_name': model_name,
            'total_problems': total,
            'successful_evaluations': successful,
            'correct_answers': correct,
            'accuracy': self._ratio(correct, successful),
            'success_rate': self._ratio(successful, total),
            'error_rate': self._ratio(total - successful, total),
            'follow_format_count': follow_format,
            'follow_format_rate': self._ratio(follow_format, successful),

            # Subgroup analysis
            'all_docs_found_problems': found_total,
            'all_docs_found_proportion': found_proportion,
            'all_docs_found_correct': found_correct,
            'all_docs_found_accuracy': self._ratio(found_correct, found_total),

            'not_all_docs_found_problems': missing_total,
            'not_all_docs_found_proportion': 1 - found_proportion,
            'not_all_docs_found_correct': missing_correct,
            'not_all_docs_found_accuracy': self._ratio(missing_correct, missing_total),
        }

        # Timing statistics: missing or zero durations are left out entirely.
        durations = [r['duration'] for r in model_results if r.get('duration')]
        if durations:
            total_duration = sum(durations)
            summary['avg_duration'] = total_duration / len(durations)
            summary['total_duration'] = total_duration
            summary['min_duration'] = min(durations)
            summary['max_duration'] = max(durations)

        # Token statistics: missing or zero token counts are left out entirely.
        tokens = [r['total_tokens'] for r in model_results if r.get('total_tokens')]
        if tokens:
            total_tokens = sum(tokens)
            summary['avg_tokens'] = total_tokens / len(tokens)
            summary['total_tokens'] = total_tokens

        # Error analysis: the text before the first ':' is the error type;
        # messages without a ':' are bucketed as 'Unknown'.
        error_types: Dict[str, int] = {}
        for result in model_results:
            error = result.get('error')
            if not error:
                continue
            prefix, separator, _ = str(error).partition(':')
            error_type = prefix if separator else 'Unknown'
            error_types[error_type] = error_types.get(error_type, 0) + 1
        if error_types:
            summary['error_types'] = error_types

        self.model_summaries[model_name] = summary
        return summary

    def compute_all_summaries(self) -> Dict[str, Dict[str, Any]]:
        """Compute summaries for all models in the results.

        Models are processed in the order they first appear in the results so
        that the summary order (and everything derived from it) is the same on
        every run. Results without a truthy ``model_name`` are ignored.

        Returns:
            Dictionary mapping model names to their summaries
        """
        # dict.fromkeys de-duplicates while keeping first-seen order.
        model_names = dict.fromkeys(
            r.get('model_name') for r in self.results_data if r.get('model_name')
        )
        for model_name in model_names:
            self.compute_model_summary(model_name)

        return self.model_summaries

    def generate_comparison_table(self) -> pd.DataFrame:
        """Generate a comparison table of all models.

        Uses the cached summaries when there are any, otherwise computes them
        first. Rows are sorted by descending accuracy.

        Returns:
            DataFrame with model comparisons (formatted strings for rates and
            averages). Empty, but with the expected columns, when there are no
            summaries.
        """
        if not self.model_summaries:
            self.compute_all_summaries()

        columns = [
            'Model', 'Accuracy', 'Success Rate', 'Total Problems', 'Correct',
            'All Docs Found', 'All Docs Accuracy', 'Not All Docs Accuracy',
            'Avg Duration (s)', 'Avg Tokens',
        ]

        rows = []
        for model_name, summary in self.model_summaries.items():
            rows.append({
                'Model': model_name,
                'Accuracy': f"{summary.get('accuracy', 0):.2%}",
                'Success Rate': f"{summary.get('success_rate', 0):.2%}",
                'Total Problems': summary.get('total_problems', 0),
                'Correct': summary.get('correct_answers', 0),
                'All Docs Found': f"{summary.get('all_docs_found_proportion', 0):.2%}",
                'All Docs Accuracy': f"{summary.get('all_docs_found_accuracy', 0):.2%}",
                'Not All Docs Accuracy': f"{summary.get('not_all_docs_found_accuracy', 0):.2%}",
                'Avg Duration (s)': f"{summary.get('avg_duration', 0):.2f}",
                'Avg Tokens': f"{summary.get('avg_tokens', 0):.0f}",
                # Unformatted value, used only as the sort key below.
                '_accuracy': summary.get('accuracy', 0),
            })

        if not rows:
            return pd.DataFrame(columns=columns)

        df = pd.DataFrame(rows)
        # Sort on the raw accuracy rather than the rounded display string;
        # a stable sort keeps tied models in summary order.
        df = df.sort_values('_accuracy', ascending=False, kind='stable')
        return df.drop(columns='_accuracy')

    def generate_problem_analysis(self) -> pd.DataFrame:
        """Analyze performance across different problems.

        Results whose ``problem_id`` is missing (``None``) or an empty string
        are skipped; any other identifier, including ``0``, is kept.

        Returns:
            DataFrame with problem-level analysis, hardest problems (lowest
            success rate) first. Empty, but with the expected columns, when no
            result has a problem ID.
        """
        columns = [
            'Problem ID', 'Total Attempts', 'Correct Attempts',
            'Success Rate', 'Models Attempted', 'Models Succeeded',
        ]

        per_problem: Dict[Any, Dict[str, Any]] = {}
        for result in self.results_data:
            problem_id = result.get('problem_id')
            if problem_id is None or problem_id == '':
                continue

            stats = per_problem.setdefault(problem_id, {
                'attempts': 0,
                'correct': 0,
                'models_attempted': set(),
                'models_correct': set(),
            })
            model_name = result.get('model_name')
            stats['attempts'] += 1
            stats['models_attempted'].add(model_name)
            if result.get('is_correct'):
                stats['correct'] += 1
                stats['models_correct'].add(model_name)

        rows = [
            {
                'Problem ID': problem_id,
                'Total Attempts': stats['attempts'],
                'Correct Attempts': stats['correct'],
                'Success Rate': self._ratio(stats['correct'], stats['attempts']),
                'Models Attempted': len(stats['models_attempted']),
                'Models Succeeded': len(stats['models_correct']),
            }
            for problem_id, stats in per_problem.items()
        ]
        if not rows:
            return pd.DataFrame(columns=columns)

        # Ascending success rate == problems sorted by difficulty.
        return pd.DataFrame(rows).sort_values('Success Rate', ascending=True, kind='stable')

    def generate_subgroup_analysis(self) -> pd.DataFrame:
        """Generate detailed subgroup analysis based on document finding.

        Uses the cached summaries when there are any, otherwise computes them
        first. Each model contributes three rows: 'Overall', 'All Docs Found'
        and 'Not All Docs Found'.

        Returns:
            DataFrame with subgroup analysis sorted by model, then group.
            Empty, but with the expected columns, when there are no summaries.
        """
        if not self.model_summaries:
            self.compute_all_summaries()

        columns = ['Model', 'Group', 'Total Problems', 'Correct', 'Accuracy', 'Proportion']

        # (group label, total key, correct key, accuracy key, proportion key);
        # a proportion key of None means the group covers everything (1.0).
        group_specs = [
            ('Overall', 'total_problems', 'correct_answers', 'accuracy', None),
            ('All Docs Found', 'all_docs_found_problems', 'all_docs_found_correct',
             'all_docs_found_accuracy', 'all_docs_found_proportion'),
            ('Not All Docs Found', 'not_all_docs_found_problems', 'not_all_docs_found_correct',
             'not_all_docs_found_accuracy', 'not_all_docs_found_proportion'),
        ]

        rows = []
        for model_name, summary in self.model_summaries.items():
            for label, total_key, correct_key, accuracy_key, proportion_key in group_specs:
                rows.append({
                    'Model': model_name,
                    'Group': label,
                    'Total Problems': summary.get(total_key, 0),
                    'Correct': summary.get(correct_key, 0),
                    'Accuracy': summary.get(accuracy_key, 0),
                    'Proportion': 1.0 if proportion_key is None else summary.get(proportion_key, 0),
                })

        if not rows:
            return pd.DataFrame(columns=columns)

        return pd.DataFrame(rows).sort_values(['Model', 'Group'])

    def save_report(self, filename: Optional[str] = None):
        """Save a comprehensive report to file.

        Writes ``<filename>.xlsx``, ``<filename>.json`` and ``<filename>.html``
        into ``self.output_dir`` and prints the three paths. Summaries are
        recomputed from the current results before anything is written.

        Args:
            filename: Output filename (without extension). Defaults to
                ``evaluation_report_<YYYYmmdd_HHMMSS>``.
        """
        # One timestamp for the whole report so all outputs agree.
        generated_at = datetime.now()
        if not filename:
            filename = f"evaluation_report_{generated_at.strftime('%Y%m%d_%H%M%S')}"

        # Generate all analyses
        self.compute_all_summaries()
        comparison_df = self.generate_comparison_table()
        problem_df = self.generate_problem_analysis()
        subgroup_df = self.generate_subgroup_analysis()

        # Save as Excel with multiple sheets
        excel_path = self.output_dir / f"{filename}.xlsx"
        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            comparison_df.to_excel(writer, sheet_name='Model Comparison', index=False)
            subgroup_df.to_excel(writer, sheet_name='Subgroup Analysis', index=False)
            problem_df.to_excel(writer, sheet_name='Problem Analysis', index=False)

            # Add raw results
            if self.results_data:
                results_df = pd.DataFrame(self.results_data)
                results_df.to_excel(writer, sheet_name='Raw Results', index=False)

        # Save as JSON
        json_path = self.output_dir / f"{filename}.json"
        report_data = {
            'timestamp': generated_at.isoformat(),
            'model_summaries': self.model_summaries,
            'comparison_table': comparison_df.to_dict('records'),
            'subgroup_analysis': subgroup_df.to_dict('records'),
            'problem_analysis': problem_df.to_dict('records'),
            'total_evaluations': len(self.results_data)
        }

        with open(json_path, 'w', encoding='utf-8') as f:
            # The default hook lets NumPy scalars in summaries serialize.
            json.dump(report_data, f, indent=2, default=self._json_default)

        # Save as HTML (written as UTF-8, hence the charset declaration).
        html_path = self.output_dir / f"{filename}.html"
        html_content = f"""
        <html>
        <head>
            <meta charset="utf-8">
            <title>Evaluation Report - {generated_at.strftime('%Y-%m-%d %H:%M')}</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                h1 {{ color: #333; }}
                h2 {{ color: #666; }}
                table {{ border-collapse: collapse; width: 100%; margin: 20px 0; }}
                th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
                th {{ background-color: #f2f2f2; }}
                tr:nth-child(even) {{ background-color: #f9f9f9; }}
            </style>
        </head>
        <body>
            <h1>Evaluation Report</h1>
            <p>Generated: {generated_at.strftime('%Y-%m-%d %H:%M:%S')}</p>
            
            <h2>Model Comparison</h2>
            {comparison_df.to_html(index=False)}
            
            <h2>Subgroup Analysis (All Documents Found vs Not All Found)</h2>
            {subgroup_df.to_html(index=False)}
            
            <h2>Problem Analysis (Top 10 Hardest)</h2>
            {problem_df.head(10).to_html(index=False)}
            
            <h2>Summary Statistics</h2>
            <ul>
                <li>Total Evaluations: {len(self.results_data)}</li>
                <li>Number of Models: {len(self.model_summaries)}</li>
                <li>Number of Problems: {len(problem_df)}</li>
            </ul>
        </body>
        </html>
        """

        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        print("Reports saved to:")
        print(f"  - Excel: {excel_path}")
        print(f"  - JSON: {json_path}")
        print(f"  - HTML: {html_path}")

    def plot_comparison(self, save_path: Optional[str] = None):
        """Create visualization plots for model comparison.

        Draws a 2x2 grid of bar charts (accuracy, success rate, average
        duration, average tokens) with one bar per summarized model, saves it
        as an image and prints the path. Uses the cached summaries when there
        are any, otherwise computes them first.

        Args:
            save_path: Path to save the plot. Defaults to
                ``comparison_plot_<YYYYmmdd_HHMMSS>.png`` in ``self.output_dir``.
        """
        if not self.model_summaries:
            self.compute_all_summaries()

        if not save_path:
            save_path = self.output_dir / f"comparison_plot_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"

        models = list(self.model_summaries.keys())

        def metric(key: str) -> List[Any]:
            return [self.model_summaries[m].get(key, 0) for m in models]

        # Set up the plot style
        sns.set_style("whitegrid")
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        try:
            # (axis, summary key, title, y label, bar color, clamp y to [0, 1])
            panels = [
                (axes[0, 0], 'accuracy', 'Model Accuracy', 'Accuracy', None, True),
                (axes[0, 1], 'success_rate', 'Evaluation Success Rate', 'Success Rate', 'orange', True),
                (axes[1, 0], 'avg_duration', 'Average Evaluation Duration', 'Duration (seconds)', 'green', False),
                (axes[1, 1], 'avg_tokens', 'Average Tokens Used', 'Tokens', 'purple', False),
            ]
            for ax, key, title, ylabel, color, is_rate in panels:
                # No color argument at all keeps matplotlib's default color.
                bar_kwargs = {'color': color} if color else {}
                ax.bar(models, metric(key), **bar_kwargs)
                ax.set_title(title)
                ax.set_ylabel(ylabel)
                if is_rate:
                    ax.set_ylim([0, 1])
                ax.tick_params(axis='x', rotation=45)

            plt.suptitle('Model Evaluation Comparison', fontsize=16)
            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Plot saved to: {save_path}")
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)
