"""Report generation for component evaluation results."""

import json
from pathlib import Path
from typing import Dict, List, Any, Optional
from datetime import datetime


class ReportGenerator:
    """Generates evaluation reports in various formats."""
    
    def __init__(self, results_dir: Path, reports_dir: Path):
        """Initialize report generator.
        
        Args:
            results_dir: Directory containing evaluation results
            reports_dir: Directory to save reports
        """
        self.results_dir = results_dir
        self.reports_dir = reports_dir
        self.reports_dir.mkdir(parents=True, exist_ok=True)
    
    def _get_component_table_headers(self, component: str) -> tuple:
        """Look up the summary-table column titles for a component.

        Args:
            component: Component name

        Returns:
            Tuple of (traditional_headers, llm_judge_headers); unknown
            components get a single generic column for each group.
        """
        # Rebuilt on every call so each caller receives fresh list objects.
        layouts = (
            ('annotation',
             ["BLEU", "ROUGE-L", "Semantic"],
             ["Accuracy", "Complete", "Clarity"]),
            ('scene',
             ["BLEU", "ROUGE-L", "Semantic", "Coverage", "Safety Avg", "Temporal Avg", "Coherence Avg"],
             ["Extract", "Temporal", "Safety"]),
            ('violation',
             ["Precision", "Recall", "F1", "Accuracy", "Safety Avg"],
             ["Detection", "Explain", "Legal"]),
            ('accident',
             ["Precision", "Recall", "F1", "Accuracy", "Causality", "Safety"],
             ["RiskAssess", "Consequence", "Context"]),
            ('assessment',
             ["ScoreCorr", "RiskAcc", "Coverage Avg"],
             ["Assessment", "Advice", "Justify"]),
        )

        for name, traditional_headers, llm_headers in layouts:
            if component == name:
                return traditional_headers, llm_headers

        return ["Traditional"], ["LLM Judge"]
    
    def _extract_component_metrics(self, component: str, traditional_scores: List[Dict[str, float]],
                                   llm_judge_scores: List[Dict[str, Any]]) -> tuple:
        """Extract component-specific metric averages.

        Every numeric value is averaged per metric name across the supplied
        score dictionaries. Entries that are not dictionaries, values that
        are not ``int``/``float``, and NaN values are ignored for both the
        traditional and the LLM judge scores.

        When ``traditional_scores`` is non-empty, derived summary metrics are
        added for some components:

        * ``scene``: ``safety_avg``, ``temporal_avg`` and ``coherence_avg``,
          each averaged over the member metrics that are actually present.
        * ``violation``: ``safety_avg`` over two fixed safety metrics, where
          a missing metric counts as 0.
        * ``accident``: ``causality`` and ``safety``, each taken from a
          primary metric name with a fallback name.
        * ``assessment``: ``coverage_avg`` over three fixed coverage metrics,
          where a missing metric counts as 0.

        Args:
            component: Component name
            traditional_scores: List of traditional metric dictionaries
            llm_judge_scores: List of LLM judge metric dictionaries

        Returns:
            Tuple of (traditional_metrics, llm_judge_metrics)
        """
        import math  # local import: math is not imported at module level

        def mean_by_metric(score_dicts):
            """Average the usable numeric values for each metric name."""
            totals = {}
            counts = {}
            for score_dict in score_dicts or ():
                if not isinstance(score_dict, dict):
                    continue
                for metric, value in score_dict.items():
                    if not isinstance(value, (int, float)):
                        continue
                    # Only floats can be NaN; skipping the check for ints also
                    # avoids converting very large ints to float.
                    if isinstance(value, float) and math.isnan(value):
                        continue
                    totals[metric] = totals.get(metric, 0) + value
                    counts[metric] = counts.get(metric, 0) + 1
            return {metric: totals[metric] / counts[metric] for metric in totals}

        def mean_of(values):
            """Return the arithmetic mean of ``values``, or 0 when empty."""
            return sum(values) / len(values) if values else 0

        traditional_metrics = mean_by_metric(traditional_scores)

        # Derived metrics are only added when some traditional scores were supplied.
        if traditional_scores:
            if component == 'scene':
                scene_groups = {
                    'safety_avg': ('critical_scene_f1', 'critical_scene_precision',
                                   'critical_scene_recall', 'safety_weighted_detection'),
                    'temporal_avg': ('temporal_order_accuracy', 'order_preservation',
                                     'sequence_similarity'),
                    'coherence_avg': ('scene_coherence', 'semantic_transitions',
                                      'narrative_flow'),
                }
                for derived_name, members in scene_groups.items():
                    # Only metrics that are present contribute to the group mean.
                    present = [traditional_metrics[m] for m in members if m in traditional_metrics]
                    traditional_metrics[derived_name] = mean_of(present)

            elif component == 'violation':
                # Missing metrics count as 0 here, unlike the scene groups.
                safety_members = ('safety_safety_criticality', 'safety_critical_event_ratio')
                traditional_metrics['safety_avg'] = mean_of(
                    [traditional_metrics.get(m, 0) for m in safety_members])

            elif component == 'accident':
                # Fall back to the alternative metric name when the primary one
                # is missing or zero.
                traditional_metrics['causality'] = (
                    traditional_metrics.get('temporal_causality_score', 0)
                    or traditional_metrics.get('safety_temporal_causality', 0))
                traditional_metrics['safety'] = (
                    traditional_metrics.get('safety_criticality_score', 0)
                    or traditional_metrics.get('safety_safety_criticality', 0))

            elif component == 'assessment':
                # Missing metrics count as 0 here as well.
                coverage_members = ('advice_coverage', 'strengths_coverage', 'weaknesses_coverage')
                traditional_metrics['coverage_avg'] = mean_of(
                    [traditional_metrics.get(m, 0) for m in coverage_members])

        llm_judge_metrics = mean_by_metric(llm_judge_scores)

        return traditional_metrics, llm_judge_metrics

    def generate_component_report(self, component: str, results: Dict[str, Any]) -> str:
        """Generate a markdown report for a single component.

        The report contains a summary, one performance table row per model,
        per-model detail tables for traditional metrics and LLM judge
        scores, and a timing analysis when timing summaries are available.
        A model whose result is ``None`` is reported as failed.

        Args:
            component: Component name
            results: Evaluation results for the component, keyed by model
                name

        Returns:
            Report content as markdown string
        """
        report_lines = [
            f"# {component.title()} Component Evaluation Report",
            "",
            f"**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"**Component:** {component}",
            "",
            "## Summary",
            "",
        ]

        if not results:
            report_lines.extend([
                "❌ No evaluation results available for this component.",
                "",
                "Please run the evaluation first:",
                "```bash",
                f"uv run python -m evaluation.component_eval --component {component}",
                "```",
            ])
            return '\n'.join(report_lines)

        # Model performance summary (results is non-empty here, so the
        # success-rate division is safe).
        successful_models = sum(1 for r in results.values() if r is not None)
        total_models = len(results)

        trad_headers, llm_headers = self._get_component_table_headers(component)
        header_parts = ["Model", "Videos"] + trad_headers + llm_headers + ["Avg Time", "Speed", "Status"]

        report_lines.extend([
            f"- **Models Evaluated:** {successful_models}/{total_models}",
            f"- **Success Rate:** {successful_models/total_models*100:.1f}%",
            "",
            "## Model Performance",
            "",
            "| " + " | ".join(header_parts) + " |",
            "|" + "|".join("-------" for _ in header_parts) + "|",
        ])

        metric_column_count = len(trad_headers) + len(llm_headers)
        for model, result in results.items():
            if result is None:
                # Failed model: placeholder cells for every metric and timing column.
                row = [model, "0"] + ["-"] * metric_column_count + ["-", "-", "❌ Failed"]
            else:
                row = self._build_model_row(component, model, result)
            report_lines.append("| " + " | ".join(row) + " |")

        report_lines.extend([
            "",
            "## Detailed Results",
            "",
            "### Traditional Metrics",
            "",
        ])
        for model, result in results.items():
            if result and result.get('traditional_metrics'):
                report_lines.append(f"#### {model}")
                report_lines.append("")
                self._add_traditional_metrics_details(report_lines, result['traditional_metrics'], component)

        report_lines.extend([
            "",
            "### LLM Judge Evaluation",
            "",
        ])
        for model, result in results.items():
            if result and result.get('llm_judge_scores'):
                report_lines.append(f"#### {model}")
                report_lines.append("")
                self._add_llm_judge_details(report_lines, result['llm_judge_scores'], component)

        model_timings = {
            model: result['timing_summary']
            for model, result in results.items()
            if result and result.get('timing_summary')
        }
        report_lines.extend(self._build_timing_section(model_timings))

        return '\n'.join(report_lines)

    def _build_model_row(self, component: str, model: str, result: Dict[str, Any]) -> List[str]:
        """Build the performance-table cells for one model that produced a result.

        Args:
            component: Component name (selects which metrics are shown)
            model: Model name, used as the first cell
            result: Result dictionary for the model

        Returns:
            List of cell strings matching the header produced for ``component``
        """
        # Metric keys shown per component, in header order.
        traditional_columns = {
            'annotation': ('bleu', 'rouge_l', 'semantic_similarity'),
            'scene': ('bleu', 'rouge_l', 'semantic_similarity', 'scene_coverage',
                      'safety_avg', 'temporal_avg', 'coherence_avg'),
            'violation': ('precision', 'recall', 'f1', 'accuracy', 'safety_avg'),
            'accident': ('precision', 'recall', 'f1', 'accuracy', 'causality', 'safety'),
            'assessment': ('safety_score_mae', 'risk_level_accuracy', 'coverage_avg'),
        }
        # LLM judge keys per component plus the number of decimals to print.
        llm_columns = {
            'annotation': (('accuracy_score', 'completeness_score', 'clarity_score'), 0),
            'scene': (('extraction_quality', 'temporal_coherence', 'safety_relevance'), 0),
            'violation': (('detection_accuracy', 'explanation_quality', 'legal_consistency'), 0),
            'accident': (('risk_assessment_accuracy', 'consequence_prediction', 'context_understanding'), 0),
            'assessment': (('assessment_accuracy', 'advice_actionability', 'score_justification'), 1),
        }

        video_count = len(result.get('video_results', {}))
        traditional_scores = result.get('traditional_metrics', [])
        llm_judge_scores = result.get('llm_judge_scores', [])
        timing_summary = result.get('timing_summary', {})

        trad_metrics, llm_metrics = self._extract_component_metrics(
            component, traditional_scores, llm_judge_scores)

        row = [model, str(video_count)]

        trad_keys = traditional_columns.get(component)
        if trad_keys is None:
            # Unknown component: a single overall average column.
            row.append(f"{self._calculate_traditional_average(traditional_scores):.2f}")
        else:
            row.extend(f"{trad_metrics.get(key, 0):.2f}" for key in trad_keys)

        llm_spec = llm_columns.get(component)
        if llm_spec is None:
            row.append(f"{llm_metrics.get('overall_quality', 0):.1f}")
        else:
            llm_keys, decimals = llm_spec
            row.extend(f"{llm_metrics.get(key, 0):.{decimals}f}" for key in llm_keys)

        avg_time_str = "-"
        speed_str = "-"
        if timing_summary:
            avg_time_str = self._format_time(timing_summary.get('mean_time', 0))
            total_time = timing_summary.get('total_time', 0)
            if total_time > 0:
                # Processing speed in videos per minute.
                speed_str = f"{video_count / (total_time / 60):.1f}/min"
        row.append(avg_time_str)
        row.append(speed_str)

        row.append("✅ Success" if video_count > 0 else "⚠️ Partial")
        return row

    def _build_timing_section(self, model_timings: Dict[str, Dict[str, Any]]) -> List[str]:
        """Build the markdown lines of the timing analysis section.

        Args:
            model_timings: Timing summary dictionaries keyed by model name

        Returns:
            Lines to append to the report; empty when there are no timings.
            Comparative insights are only produced for two or more models.
        """
        if not model_timings:
            return []

        import statistics  # local import: statistics is not imported at module level

        fmt = self._format_time

        def describe_throughput(timing):
            """Describe videos/minute, guarding against a zero total time."""
            total_time = timing.get('total_time', 1)
            if not total_time:
                return "(throughput unavailable: no processing time recorded)"
            return f"at {timing.get('count', 0) / (total_time / 60):.1f} videos/minute"

        lines = [
            "",
            "### Timing Analysis",
            "",
            "| Model | Mean Time | Median Time | Min Time | Max Time | Total Time | Videos |",
            "|-------|-----------|-------------|----------|----------|------------|--------|",
        ]

        # Fastest first; models without a mean time sort last.
        sorted_models = sorted(model_timings.items(),
                               key=lambda item: item[1].get('mean_time', float('inf')))

        for model, stats in sorted_models:
            lines.append(
                f"| {model} | {fmt(stats.get('mean_time', 0))} | {fmt(stats.get('median_time', 0))} "
                f"| {fmt(stats.get('min_time', 0))} | {fmt(stats.get('max_time', 0))} "
                f"| {fmt(stats.get('total_time', 0))} | {stats.get('count', 0)} |"
            )

        if len(sorted_models) < 2:
            return lines

        fastest_name, fastest = sorted_models[0]
        slowest_name, slowest = sorted_models[-1]
        fastest_label = fmt(fastest.get('mean_time', 0))
        slowest_label = fmt(slowest.get('mean_time', 0))

        # A zero mean time for the fastest model would make the ratio undefined.
        fastest_mean = fastest.get('mean_time', 1)
        if fastest_mean:
            speed_text = f"{slowest.get('mean_time', 1) / fastest_mean:.1f}x faster (fastest vs slowest)"
        else:
            speed_text = "n/a (fastest model reports a zero mean time)"

        lines.extend([
            "",
            "#### Performance Insights",
            "",
            f"- **Fastest Model**: {fastest_name} ({fastest_label} average)",
            f"- **Slowest Model**: {slowest_name} ({slowest_label} average)",
            f"- **Speed Difference**: {speed_text}",
            f"- **Performance Range**: {fastest_label} - {slowest_label}",
            "",
        ])

        # Distribution statistics only consider models with a positive mean time.
        positive_means = [stats['mean_time'] for _, stats in sorted_models
                          if stats.get('mean_time', 0) > 0]
        if positive_means:
            median_mean = statistics.median(positive_means)
            std_dev = statistics.stdev(positive_means) if len(positive_means) > 1 else 0
            total_videos = sum(stats.get('count', 0) for _, stats in sorted_models)
            total_time = sum(stats.get('total_time', 0) for _, stats in sorted_models)

            # Bucket models relative to the median: below 80%, above 120%, or in between.
            fast_cutoff = median_mean * 0.8
            slow_cutoff = median_mean * 1.2
            fast_models = []
            medium_models = []
            slow_models = []
            for model, stats in sorted_models:
                mean_time = stats.get('mean_time', 0)
                if mean_time < fast_cutoff:
                    fast_models.append((model, mean_time))
                elif mean_time > slow_cutoff:
                    slow_models.append((model, mean_time))
                else:
                    medium_models.append((model, mean_time))

            lines.extend([
                "#### Performance Distribution",
                "",
                f"- **Total Processing Time**: {fmt(total_time)} across {total_videos:,} video evaluations",
                f"- **Average Processing Time**: {fmt(median_mean)} (median)",
                f"- **Performance Std Dev**: ±{fmt(std_dev)}",
                "",
                "#### Performance Categories",
                "",
            ])

            if fast_models:
                lines.append(f"🚀 **Fast Models** (<{fmt(fast_cutoff)}): {len(fast_models)} models")
                # Show the three fastest only.
                lines.extend(f"   • {model}: {fmt(mean_time)}" for model, mean_time in fast_models[:3])
                if len(fast_models) > 3:
                    lines.append(f"   • ... and {len(fast_models) - 3} more")

            if medium_models:
                lines.append(f"⚡ **Medium Speed** ({fmt(fast_cutoff)}-{fmt(slow_cutoff)}): {len(medium_models)} models")

            if slow_models:
                lines.append(f"🐌 **Slower Models** (>{fmt(slow_cutoff)}): {len(slow_models)} models")
                # Show the two slowest only.
                lines.extend(f"   • {model}: {fmt(mean_time)}" for model, mean_time in slow_models[-2:])

            lines.append("")

        lines.extend([
            "#### Efficiency Analysis",
            "",
            "- **Most Consistent**: Model with lowest timing variance",
            f"- **Best Throughput**: {fastest_name} {describe_throughput(fastest)}",
            f"- **Resource Intensive**: {slowest_name} {describe_throughput(slowest)}",
            "",
        ])
        return lines
    
    def _format_time(self, seconds: float) -> str:
        """Format time duration in a human-readable way.

        The unit is chosen from the value *after* rounding to the precision
        that will be displayed, so a duration never prints with an
        out-of-range component such as ``"1000ms"``, ``"60.0s"`` or
        ``"1m 60s"``.

        Args:
            seconds: Duration in seconds

        Returns:
            ``"<n>ms"`` below one second, ``"<n.n>s"`` below one minute,
            ``"<m>m <s>s"`` below one hour, otherwise ``"<h>h <m>m"``
        """
        millis = round(seconds * 1000)
        if millis < 1000:
            return f"{millis}ms"

        if round(seconds, 1) < 60:
            return f"{seconds:.1f}s"

        whole_seconds = round(seconds)
        if whole_seconds < 3600:
            minutes, remaining_seconds = divmod(whole_seconds, 60)
            return f"{minutes}m {remaining_seconds}s"

        # Minutes are truncated for hour-scale durations; the max() covers a
        # value just below one hour that rounded up to exactly 3600 seconds.
        hours, remainder = divmod(max(int(seconds), 3600), 3600)
        return f"{hours}h {remainder // 60}m"
    
    def _calculate_traditional_average(self, traditional_scores: List[Dict[str, float]]) -> float:
        """Return the mean of every usable numeric value across all score dicts.

        Values that are not ``int``/``float`` or that are NaN are ignored.
        The result is 0.0 when there is nothing to average.
        """
        if not traditional_scores:
            return 0.0

        # Flatten all dictionaries into one list of usable numbers.
        usable = [
            number
            for scores in traditional_scores
            for number in scores.values()
            if isinstance(number, (int, float)) and not isnan(number)
        ]

        if not usable:
            return 0.0
        return sum(usable) / len(usable)
    
    def _add_traditional_metrics_details(self, report_lines: List[str], 
                                       traditional_scores: List[Dict[str, float]], 
                                       component: str):
        """Append a per-metric average table for traditional metrics.

        ``report_lines`` is extended in place. Values that are not
        ``int``/``float`` or that are NaN are left out of the averages. When
        no score dictionaries are supplied a single notice line is appended
        instead; otherwise the section ends with a blank line.
        """
        if not traditional_scores:
            report_lines.append("No traditional metrics available.")
            return

        # metric name -> [running total, number of contributing values]
        accumulators = {}
        for scores in traditional_scores:
            for name, number in scores.items():
                if not isinstance(number, (int, float)) or isnan(number):
                    continue
                slot = accumulators.setdefault(name, [0, 0])
                slot[0] = slot[0] + number
                slot[1] = slot[1] + 1

        if accumulators:
            report_lines.append("| Metric | Average Score |")
            report_lines.append("|--------|---------------|")

            for name in sorted(accumulators):
                total, occurrences = accumulators[name]
                label = name.replace('_', ' ').title()
                report_lines.append(f"| {label} | {total / occurrences:.3f} |")

        report_lines.append("")
    
    def _add_llm_judge_details(self, report_lines: List[str], 
                             judge_scores: List[Dict[str, Any]], 
                             component: str):
        """Add LLM judge details to report.

        Appends a table of per-dimension averages to ``report_lines`` in
        place. A key counts as a score dimension when it ends with
        ``_score`` or is one of the known dimension names below. Entries
        that are not dictionaries, values that are not ``int``/``float``,
        and NaN values are skipped rather than raising or distorting the
        averages.

        Args:
            report_lines: Report lines to extend
            judge_scores: List of LLM judge score dictionaries
            component: Component name (not used when building the table)
        """
        if not judge_scores:
            report_lines.append("No LLM judge scores available.")
            return

        import math  # local import: math is not imported at module level

        # Dimension names that do not follow the "*_score" naming pattern.
        named_dimensions = frozenset({
            'extraction_quality', 'temporal_coherence', 'safety_relevance',
            'detection_accuracy', 'explanation_quality', 'legal_consistency',
            'risk_assessment_accuracy', 'consequence_prediction', 'context_understanding',
            'assessment_accuracy', 'advice_actionability', 'score_justification',
            'overall_quality',
        })

        score_sums = {}
        score_counts = {}

        for score_dict in judge_scores:
            # Non-dict entries (e.g. None) carry no scores to average.
            if not isinstance(score_dict, dict):
                continue
            for key, value in score_dict.items():
                if not isinstance(key, str):
                    continue
                if not (key.endswith('_score') or key in named_dimensions):
                    continue
                if not isinstance(value, (int, float)):
                    continue
                if isinstance(value, float) and math.isnan(value):
                    continue
                score_sums[key] = score_sums.get(key, 0) + value
                score_counts[key] = score_counts.get(key, 0) + 1

        if score_sums:
            report_lines.extend([
                "| Dimension | Average Score (1-10) |",
                "|-----------|---------------------|"
            ])

            for score_key in sorted(score_sums):
                avg = score_sums[score_key] / score_counts[score_key]
                display_name = score_key.replace('_', ' ').title()
                report_lines.append(f"| {display_name} | {avg:.1f} |")

        report_lines.append("")
    
    def save_component_report(self, component: str, results: Dict[str, Any]) -> Path:
        """Save component report to file.
        
        Args:
            component: Component name
            results: Evaluation results
            
        Returns:
            Path to saved report file
        """
        report_content = self.generate_component_report(component, results)
        report_file = self.reports_dir / f"{component}_evaluation_report.md"
        
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write(report_content)
        
        return report_file


def isnan(value):
    """Check if value is NaN (handles various types).

    The value is converted with ``float()`` first, so numeric strings such
    as ``"nan"`` are recognised as well. Anything that cannot be converted
    to a float, including integers too large to be represented, is
    reported as not NaN instead of raising.

    Args:
        value: Any object

    Returns:
        True if ``float(value)`` is NaN, False otherwise
    """
    import math  # local import: math is not imported at module level

    try:
        return math.isnan(float(value))
    except (TypeError, ValueError, OverflowError):
        # Not convertible to a float (or too large for one), so not NaN.
        return False