"""
Analysis and visualization tools for document generation benchmark results.

This module provides detailed analysis, metrics calculation, and visualization
for the document generation benchmark results.

Author: GitHub Copilot
Date: September 14, 2025
"""

import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Dict, List, Any
import os
from datetime import datetime
import logging

# Configure the root logger (INFO and above) as an import-time side effect and
# create the module-level logger used by the analyzer and main() below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class BenchmarkAnalyzer:
    """Analyzer for document generation benchmark results"""
    
    def __init__(self, results_dir: str):
        """Create an analyzer bound to a results directory.

        Nothing is read from disk here; both data attributes start out as
        ``None`` and are filled in by ``load_results``.

        Args:
            results_dir: Directory in which ``load_results`` looks for
                ``benchmark_summary.json``.
        """
        # Placeholders until load_results() has run.
        self.detailed_results = None
        self.summary_data = None
        self.results_dir = results_dir
        
    def load_results(self):
        """Read ``benchmark_summary.json`` from the results directory.

        The parsed JSON is stored in ``self.summary_data``; its
        ``"detailed_results"`` entry (an empty list when the key is absent)
        is stored in ``self.detailed_results``.

        Raises:
            FileNotFoundError: If the summary file does not exist.
        """
        summary_file = os.path.join(self.results_dir, "benchmark_summary.json")

        if os.path.exists(summary_file):
            with open(summary_file, 'r', encoding='utf-8') as handle:
                parsed = json.load(handle)
        else:
            raise FileNotFoundError(f"Summary file not found: {summary_file}")

        self.summary_data = parsed
        self.detailed_results = parsed.get("detailed_results", [])
        logger.info(f"Loaded {len(self.detailed_results)} benchmark results")
    
    def calculate_advanced_metrics(self) -> Dict[str, Any]:
        """Compute summary statistics, correlations and grouped breakdowns.

        Results are loaded from disk first when none are held in memory.

        Returns:
            A dict with the keys ``score_statistics`` (mean, std, median, min
            and max for each of the four score columns), ``correlations``,
            ``performance_by_document_type``, ``performance_by_user_role``,
            ``citation_analysis`` and ``quality_dimension_analysis``.

        Raises:
            FileNotFoundError: Propagated from ``load_results`` when the
                summary file is missing.
            ValueError: If there are no detailed results to analyze.
        """
        if not self.detailed_results:
            self.load_results()

        # Without this guard an empty result set surfaces as an opaque
        # KeyError on the first column lookup below.
        if not self.detailed_results:
            raise ValueError(
                f"No detailed benchmark results available in {self.results_dir}"
            )

        results_df = pd.DataFrame(self.detailed_results)

        score_columns = (
            "user_profile_accuracy",
            "intent_capture_accuracy",
            "citation_accuracy",
            "document_quality_score",
        )
        # Same five statistics for every score column.
        score_statistics = {
            column: {
                "mean": results_df[column].mean(),
                "std": results_df[column].std(),
                "median": results_df[column].median(),
                "min": results_df[column].min(),
                "max": results_df[column].max(),
            }
            for column in score_columns
        }

        return {
            "score_statistics": score_statistics,
            "correlations": self._calculate_correlations(results_df),
            "performance_by_document_type": self._analyze_by_document_type(),
            "performance_by_user_role": self._analyze_by_user_role(),
            "citation_analysis": self._analyze_citations(),
            "quality_dimension_analysis": self._analyze_quality_dimensions(),
        }
    
    def _calculate_correlations(self, df: pd.DataFrame) -> Dict[str, float]:
        """Calculate correlations between different metrics"""
        correlation_matrix = df[["user_profile_accuracy", "intent_capture_accuracy", 
                               "citation_accuracy", "document_quality_score"]].corr()
        
        return {
            "profile_intent_correlation": correlation_matrix.loc["user_profile_accuracy", "intent_capture_accuracy"],
            "intent_quality_correlation": correlation_matrix.loc["intent_capture_accuracy", "document_quality_score"],
            "citation_quality_correlation": correlation_matrix.loc["citation_accuracy", "document_quality_score"],
            "profile_quality_correlation": correlation_matrix.loc["user_profile_accuracy", "document_quality_score"]
        }
    
    def _analyze_by_document_type(self) -> Dict[str, Dict[str, float]]:
        """Average every score per document type.

        Records in ``self.detailed_results`` are grouped by
        ``result["detailed_evaluation"]["intent"]["document_type"]``.

        Returns:
            Mapping of document type to its record ``count`` and the mean of
            each score, in order of first appearance.
        """
        # (bucket name, key of the record the value is read from)
        sources = (
            ("scores", "overall_score"),
            ("profile_accuracy", "user_profile_accuracy"),
            ("intent_accuracy", "intent_capture_accuracy"),
            ("citation_accuracy", "citation_accuracy"),
            ("quality_scores", "document_quality_score"),
        )

        grouped = {}
        for record in self.detailed_results:
            type_name = record["detailed_evaluation"]["intent"]["document_type"]
            bucket = grouped.setdefault(type_name, {name: [] for name, _ in sources})
            for name, key in sources:
                bucket[name].append(record[key])

        return {
            type_name: {
                "average_score": np.mean(bucket["scores"]),
                "count": len(bucket["scores"]),
                "avg_profile_accuracy": np.mean(bucket["profile_accuracy"]),
                "avg_intent_accuracy": np.mean(bucket["intent_accuracy"]),
                "avg_citation_accuracy": np.mean(bucket["citation_accuracy"]),
                "avg_quality_score": np.mean(bucket["quality_scores"]),
            }
            for type_name, bucket in grouped.items()
        }
    
    def _analyze_by_user_role(self) -> Dict[str, Dict[str, float]]:
        """Average every score per user role.

        Records in ``self.detailed_results`` are grouped by
        ``result["detailed_evaluation"]["user_profile"]["role"]``.

        Returns:
            Mapping of role to its record ``count`` and the mean of each
            score, in order of first appearance.
        """
        bucket_names = (
            "scores",
            "profile_accuracy",
            "intent_accuracy",
            "citation_accuracy",
            "quality_scores",
        )
        # Record keys, position-aligned with bucket_names.
        record_keys = (
            "overall_score",
            "user_profile_accuracy",
            "intent_capture_accuracy",
            "citation_accuracy",
            "document_quality_score",
        )

        per_role = {}
        for entry in self.detailed_results:
            role_name = entry["detailed_evaluation"]["user_profile"]["role"]
            if role_name not in per_role:
                per_role[role_name] = {name: [] for name in bucket_names}
            collected = per_role[role_name]
            for name, key in zip(bucket_names, record_keys):
                collected[name].append(entry[key])

        averaged = {}
        for role_name, collected in per_role.items():
            overall = collected["scores"]
            averaged[role_name] = {
                "average_score": np.mean(overall),
                "count": len(overall),
                "avg_profile_accuracy": np.mean(collected["profile_accuracy"]),
                "avg_intent_accuracy": np.mean(collected["intent_accuracy"]),
                "avg_citation_accuracy": np.mean(collected["citation_accuracy"]),
                "avg_quality_score": np.mean(collected["quality_scores"]),
            }
        return averaged
    
    def _analyze_citations(self) -> Dict[str, Any]:
        """Summarise citation usage across all results.

        Returns:
            A dict with ``total_documents``, ``documents_with_citations``,
            ``avg_citations_per_doc``, ``citation_accuracy_distribution``
            (one value per result), ``most_cited_messages`` (message id to
            citation count) and ``citation_relevance_scores`` (one value per
            citation).

        Note:
            ``avg_citations_per_doc`` divides by the number of documents
            that have at least one citation, not by ``total_documents``; it
            stays ``0`` when no document has citations.
        """
        document_total = len(self.detailed_results)
        cited_documents = 0
        citation_total = 0
        accuracy_values = []
        times_cited = {}
        relevance_values = []

        for record in self.detailed_results:
            cited = record["detailed_evaluation"]["document"]["citations"]

            if cited:
                cited_documents += 1
                citation_total += len(cited)

                for entry in cited:
                    message_id = entry["message_id"]
                    times_cited[message_id] = times_cited.get(message_id, 0) + 1
                    relevance_values.append(entry["context_relevance"])

            accuracy_values.append(record["citation_accuracy"])

        if cited_documents > 0:
            average = citation_total / cited_documents
        else:
            average = 0

        return {
            "total_documents": document_total,
            "documents_with_citations": cited_documents,
            "avg_citations_per_doc": average,
            "citation_accuracy_distribution": accuracy_values,
            "most_cited_messages": times_cited,
            "citation_relevance_scores": relevance_values,
        }
    
    def _analyze_quality_dimensions(self) -> Dict[str, Dict[str, float]]:
        """Summarise each quality dimension over all results.

        Scores are read from the optional ``"quality_scores"`` mapping inside
        each result's ``"detailed_evaluation"``. A dimension that no result
        reports is left out of the output.

        Returns:
            Mapping of dimension name to its ``mean``, ``std`` (``np.std``
            with its defaults), ``median``, ``min`` and ``max``.
        """
        dimension_names = (
            "factual_accuracy", "citation_quality", "structure_organization",
            "completeness", "readability", "professional_quality", "consistency",
        )

        summary = {}
        for name in dimension_names:
            # One (possibly empty) score mapping per result.
            score_maps = (
                record["detailed_evaluation"].get("quality_scores", {})
                for record in self.detailed_results
            )
            values = [scores[name] for scores in score_maps if name in scores]

            if not values:
                continue

            summary[name] = {
                "mean": np.mean(values),
                "std": np.std(values),
                "median": np.median(values),
                "min": np.min(values),
                "max": np.max(values),
            }

        return summary
    
    def generate_visualizations(self, output_dir: str = None):
        """Render every benchmark chart as a PNG file.

        Results are loaded from disk first when none are held in memory, the
        same way ``calculate_advanced_metrics`` does, so this method can be
        called directly on a fresh analyzer. When there are still no results
        after loading, a warning is logged and nothing is plotted.

        Args:
            output_dir: Destination directory, created when missing. Defaults
                to a ``visualizations`` sub-directory of ``self.results_dir``.

        Raises:
            FileNotFoundError: Propagated from ``load_results`` when the
                summary file is missing.
        """
        if not self.detailed_results:
            self.load_results()
        if not self.detailed_results:
            # Every plot indexes score columns that an empty frame lacks.
            logger.warning("No benchmark results available; skipping visualizations")
            return

        if output_dir is None:
            output_dir = os.path.join(self.results_dir, "visualizations")

        os.makedirs(output_dir, exist_ok=True)

        # The seaborn style sheets were renamed to 'seaborn-v0_8*' in
        # matplotlib 3.6; fall back to the old name on earlier releases.
        style_name = "seaborn-v0_8"
        if style_name not in plt.style.available:
            style_name = "seaborn"
        plt.style.use(style_name)
        sns.set_palette("husl")

        plot_steps = (
            self._plot_score_distribution,         # 1. Overall score distribution
            self._plot_component_performance,      # 2. Performance by benchmark component
            self._plot_document_type_performance,  # 3. Document type performance
            self._plot_user_role_performance,      # 4. User role performance
            self._plot_quality_dimensions,         # 5. Quality dimensions analysis
            self._plot_citation_analysis,          # 6. Citation analysis
            self._plot_correlation_heatmap,        # 7. Correlation heatmap
        )
        for plot_step in plot_steps:
            plot_step(output_dir)

        logger.info(f"Visualizations saved to {output_dir}")
    
    def _plot_score_distribution(self, output_dir: str):
        """Write ``score_distribution.png``, a 2x2 overview of the scores.

        Panels, row by row: histogram of ``overall_score``, box plot of the
        four component metrics, ``overall_score`` against record index, and a
        bar chart of each metric's mean.

        Args:
            output_dir: Directory the PNG is written into.
        """
        figure, axes = plt.subplots(2, 2, figsize=(15, 12))
        (hist_ax, box_ax), (trend_ax, mean_ax) = axes

        frame = pd.DataFrame(self.detailed_results)
        metric_columns = [
            "user_profile_accuracy",
            "intent_capture_accuracy",
            "citation_accuracy",
            "document_quality_score",
        ]

        # Top-left: histogram of the overall score.
        hist_ax.hist(frame["overall_score"], bins=20, alpha=0.7, edgecolor='black')
        hist_ax.set_title("Overall Score Distribution")
        hist_ax.set_xlabel("Score")
        hist_ax.set_ylabel("Frequency")

        # Top-right: one box per metric.
        per_metric_series = [frame[column] for column in metric_columns]
        box_ax.boxplot(per_metric_series, labels=metric_columns)
        box_ax.set_title("Score Distribution by Metric")
        box_ax.set_ylabel("Score")
        box_ax.tick_params(axis='x', rotation=45)

        # Bottom-left: overall score in record order.
        record_positions = range(len(frame))
        trend_ax.plot(record_positions, frame["overall_score"], marker='o', alpha=0.7)
        trend_ax.set_title("Score Progression")
        trend_ax.set_xlabel("Query Index")
        trend_ax.set_ylabel("Overall Score")

        # Bottom-right: mean of each metric.
        averages = [frame[column].mean() for column in metric_columns]
        mean_ax.bar(metric_columns, averages, alpha=0.7)
        mean_ax.set_title("Average Performance by Metric")
        mean_ax.set_ylabel("Average Score")
        mean_ax.tick_params(axis='x', rotation=45)

        figure.tight_layout()
        figure.savefig(os.path.join(output_dir, "score_distribution.png"), dpi=300, bbox_inches='tight')
        plt.close(figure)
    
    def _plot_component_performance(self, output_dir: str):
        """Write ``component_performance.png``: mean score per component.

        Each bar is a metric's mean with its standard deviation as the error
        bar and the mean printed above it; the y-axis is fixed to 0-5.

        Args:
            output_dir: Directory the PNG is written into.
        """
        frame = pd.DataFrame(self.detailed_results)

        figure, ax = plt.subplots(figsize=(12, 8))

        # (tick label, column the bar is computed from)
        component_columns = (
            ("User Profile\nInference", "user_profile_accuracy"),
            ("Intent\nCapture", "intent_capture_accuracy"),
            ("Citation\nAccuracy", "citation_accuracy"),
            ("Document\nQuality", "document_quality_score"),
        )
        tick_labels = [label for label, _ in component_columns]

        positions = np.arange(len(tick_labels))
        bar_width = 0.35

        averages = [frame[column].mean() for _, column in component_columns]
        spreads = [frame[column].std() for _, column in component_columns]

        rects = ax.bar(positions, averages, bar_width, yerr=spreads, alpha=0.7, capsize=5)

        ax.set_xlabel("Benchmark Components")
        ax.set_ylabel("Score")
        ax.set_title("Performance by Benchmark Component")
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels)
        ax.set_ylim(0, 5)

        # Print each mean just above the top of its bar.
        for rect, average in zip(rects, averages):
            label_x = rect.get_x() + rect.get_width()/2.
            label_y = rect.get_height() + 0.05
            ax.text(label_x, label_y, f'{average:.2f}', ha='center', va='bottom')

        figure.tight_layout()
        figure.savefig(os.path.join(output_dir, "component_performance.png"), dpi=300, bbox_inches='tight')
        plt.close(figure)
    
    def _plot_document_type_performance(self, output_dir: str):
        """Write ``document_type_performance.png``: mean score per type.

        Each bar is annotated with the number of records behind it. Nothing
        is drawn or written when there are no document types.

        Args:
            output_dir: Directory the PNG is written into.
        """
        by_type = self._analyze_by_document_type()
        if not by_type:
            return

        figure, ax = plt.subplots(figsize=(12, 8))

        type_names = list(by_type.keys())
        averages = [by_type[name]["average_score"] for name in type_names]
        sample_sizes = [by_type[name]["count"] for name in type_names]

        rects = ax.bar(type_names, averages, alpha=0.7)

        # Sample size just above each bar.
        for rect, sample_size in zip(rects, sample_sizes):
            label_x = rect.get_x() + rect.get_width()/2.
            label_y = rect.get_height() + 0.05
            ax.text(label_x, label_y, f'n={sample_size}', ha='center', va='bottom', fontsize=8)

        ax.set_xlabel("Document Type")
        ax.set_ylabel("Average Score")
        ax.set_title("Performance by Document Type")
        ax.tick_params(axis='x', rotation=45)

        figure.tight_layout()
        figure.savefig(os.path.join(output_dir, "document_type_performance.png"), dpi=300, bbox_inches='tight')
        plt.close(figure)
    
    def _plot_user_role_performance(self, output_dir: str):
        """Write ``user_role_performance.png``: mean score per user role.

        Each bar is annotated with the number of records behind it. Nothing
        is drawn or written when there are no roles.

        Args:
            output_dir: Directory the PNG is written into.
        """
        by_role = self._analyze_by_user_role()
        if not by_role:
            return

        figure, ax = plt.subplots(figsize=(12, 8))

        role_names = list(by_role.keys())
        averages = [by_role[name]["average_score"] for name in role_names]
        sample_sizes = [by_role[name]["count"] for name in role_names]

        rects = ax.bar(role_names, averages, alpha=0.7)

        # Sample size just above each bar.
        for rect, sample_size in zip(rects, sample_sizes):
            label_x = rect.get_x() + rect.get_width()/2.
            label_y = rect.get_height() + 0.05
            ax.text(label_x, label_y, f'n={sample_size}', ha='center', va='bottom', fontsize=8)

        ax.set_xlabel("User Role")
        ax.set_ylabel("Average Score")
        ax.set_title("Performance by User Role")
        ax.tick_params(axis='x', rotation=45)

        figure.tight_layout()
        figure.savefig(os.path.join(output_dir, "user_role_performance.png"), dpi=300, bbox_inches='tight')
        plt.close(figure)
    
    def _plot_quality_dimensions(self, output_dir: str):
        """Write ``quality_dimensions.png``: mean score per quality dimension.

        Bars carry the standard deviation as error bars and the mean as a
        text label; the y-axis is fixed to 0-5. Nothing is drawn or written
        when no dimension has any scores.

        Args:
            output_dir: Directory the PNG is written into.
        """
        per_dimension = self._analyze_quality_dimensions()
        if not per_dimension:
            return

        figure, ax = plt.subplots(figsize=(14, 8))

        names = list(per_dimension.keys())
        averages = [per_dimension[name]["mean"] for name in names]
        spreads = [per_dimension[name]["std"] for name in names]

        rects = ax.bar(names, averages, yerr=spreads, alpha=0.7, capsize=5)

        ax.set_xlabel("Quality Dimension")
        ax.set_ylabel("Score")
        ax.set_title("Document Quality by Dimension")
        ax.tick_params(axis='x', rotation=45)
        ax.set_ylim(0, 5)

        # Print each mean just above the top of its bar.
        for rect, average in zip(rects, averages):
            label_x = rect.get_x() + rect.get_width()/2.
            label_y = rect.get_height() + 0.05
            ax.text(label_x, label_y, f'{average:.2f}', ha='center', va='bottom')

        figure.tight_layout()
        figure.savefig(os.path.join(output_dir, "quality_dimensions.png"), dpi=300, bbox_inches='tight')
        plt.close(figure)
    
    def _plot_citation_analysis(self, output_dir: str):
        """Write ``citation_analysis.png``, a 2x2 figure on citation usage.

        Panels, row by row: histogram of citation accuracy, share of
        documents with and without citations, the ten most cited messages,
        and a histogram of citation relevance scores. A panel with nothing
        to show keeps its title and gets a short placeholder text.

        Args:
            output_dir: Directory the PNG is written into.
        """
        citation_stats = self._analyze_citations()

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        def mark_empty(ax, message):
            # Centered placeholder for a panel that has no data.
            ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)

        # Citation accuracy distribution
        ax1.hist(citation_stats["citation_accuracy_distribution"], bins=20, alpha=0.7, edgecolor='black')
        ax1.set_title("Citation Accuracy Distribution")
        ax1.set_xlabel("Citation Accuracy")
        ax1.set_ylabel("Frequency")

        # Documents with vs without citations
        with_citations = citation_stats["documents_with_citations"]
        without_citations = citation_stats["total_documents"] - with_citations
        ax2.set_title("Documents with Citations")
        if with_citations + without_citations > 0:
            ax2.pie([with_citations, without_citations],
                    labels=["With Citations", "Without Citations"], autopct='%1.1f%%')
        else:
            # Wedge sizes are normalised by their sum, which is zero here.
            mark_empty(ax2, "No documents")

        # Most cited messages (top 10)
        ax3.set_title("Most Cited Messages (Top 10)")
        if citation_stats["most_cited_messages"]:
            most_cited = sorted(citation_stats["most_cited_messages"].items(),
                                key=lambda item: item[1], reverse=True)[:10]
            msgs, counts = zip(*most_cited)
            # IDs loaded from JSON may be numbers; normalise before slicing.
            labels = [str(msg) for msg in msgs]
            ax3.bar(range(len(labels)), counts, alpha=0.7)
            ax3.set_xlabel("Message ID")
            ax3.set_ylabel("Citation Count")
            ax3.set_xticks(range(len(labels)))
            ax3.set_xticklabels(
                [label[:10] + "..." if len(label) > 10 else label for label in labels],
                rotation=45,
            )
        else:
            mark_empty(ax3, "No citations")

        # Citation relevance scores
        ax4.set_title("Citation Relevance Scores")
        if citation_stats["citation_relevance_scores"]:
            ax4.hist(citation_stats["citation_relevance_scores"], bins=20, alpha=0.7, edgecolor='black')
            ax4.set_xlabel("Relevance Score")
            ax4.set_ylabel("Frequency")
        else:
            mark_empty(ax4, "No citations")

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "citation_analysis.png"), dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_correlation_heatmap(self, output_dir: str):
        """Write ``correlation_heatmap.png`` for the five score columns.

        The heatmap shows ``DataFrame.corr()`` over the four component
        scores plus ``overall_score``, annotated with the coefficients.

        Args:
            output_dir: Directory the PNG is written into.
        """
        frame = pd.DataFrame(self.detailed_results)

        heatmap_columns = [
            "user_profile_accuracy",
            "intent_capture_accuracy",
            "citation_accuracy",
            "document_quality_score",
            "overall_score",
        ]
        selected = frame[heatmap_columns]

        figure, ax = plt.subplots(figsize=(10, 8))

        sns.heatmap(
            selected.corr(),
            annot=True,
            cmap='coolwarm',
            center=0,
            square=True,
            ax=ax,
            cbar_kws={"shrink": .8},
        )

        ax.set_title("Correlation Matrix of Benchmark Metrics")

        figure.tight_layout()
        figure.savefig(os.path.join(output_dir, "correlation_heatmap.png"), dpi=300, bbox_inches='tight')
        plt.close(figure)
    
    def generate_detailed_report(self, output_file: str = None):
        """Write a plain-text analysis report.

        The figures come from ``calculate_advanced_metrics``; the "basic
        statistics" averages are read from
        ``self.summary_data["average_scores"]`` when summary data is present.

        Args:
            output_file: Path of the report. Defaults to
                ``detailed_report.txt`` inside ``self.results_dir``.
        """
        report_path = output_file
        if report_path is None:
            report_path = os.path.join(self.results_dir, "detailed_report.txt")

        metrics = self.calculate_advanced_metrics()
        dashed_rule = "-" * 30 + "\n"

        with open(report_path, 'w', encoding='utf-8') as f:
            emit = f.write

            def start_section(title):
                # Section title followed by a dashed rule.
                emit(title + "\n")
                emit(dashed_rule)

            emit("DOCUMENT GENERATION BENCHMARK - DETAILED ANALYSIS REPORT\n")
            emit("=" * 60 + "\n\n")
            emit(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            # Basic statistics
            start_section("BASIC STATISTICS")
            emit(f"Total queries processed: {len(self.detailed_results)}\n")
            if self.summary_data:
                averages = self.summary_data["average_scores"]
                labelled_keys = (
                    ("Average user profile accuracy", "user_profile_accuracy"),
                    ("Average intent capture accuracy", "intent_capture_accuracy"),
                    ("Average citation accuracy", "citation_accuracy"),
                    ("Average document quality score", "document_quality_score"),
                )
                for label, key in labelled_keys:
                    emit(f"{label}: {averages[key]:.3f}\n")
                emit(f"Overall average score: {averages['overall_score']:.3f}\n\n")

            # Advanced metrics
            start_section("ADVANCED METRICS")

            emit("Score Statistics:\n")
            for metric_name, summary in metrics["score_statistics"].items():
                emit(f"  {metric_name}:\n")
                emit(f"    Mean: {summary['mean']:.3f} ± {summary['std']:.3f}\n")
                emit(f"    Median: {summary['median']:.3f}\n")
                emit(f"    Range: {summary['min']:.3f} - {summary['max']:.3f}\n\n")

            emit("Correlations:\n")
            for pair_name, coefficient in metrics["correlations"].items():
                emit(f"  {pair_name}: {coefficient:.3f}\n")
            emit("\n")

            # Grouped performance: both sections share one line format.
            grouped_sections = (
                ("PERFORMANCE BY DOCUMENT TYPE", "performance_by_document_type"),
                ("PERFORMANCE BY USER ROLE", "performance_by_user_role"),
            )
            for title, metrics_key in grouped_sections:
                start_section(title)
                for group_name, perf in metrics[metrics_key].items():
                    emit(f"{group_name}: {perf['average_score']:.3f} (n={perf['count']})\n")
                emit("\n")

            # Citation analysis
            start_section("CITATION ANALYSIS")
            citations = metrics["citation_analysis"]
            emit(f"Documents with citations: {citations['documents_with_citations']}/{citations['total_documents']}\n")
            emit(f"Average citations per document: {citations['avg_citations_per_doc']:.2f}\n")
            if citations['most_cited_messages']:
                emit("\nMost cited messages:\n")
                ranked = sorted(citations['most_cited_messages'].items(),
                                key=lambda item: item[1], reverse=True)
                for message_id, times_cited in ranked[:5]:
                    emit(f"  {message_id}: {times_cited} citations\n")
            emit("\n")

            # Quality dimensions
            start_section("QUALITY DIMENSIONS ANALYSIS")
            for dimension, summary in metrics["quality_dimension_analysis"].items():
                emit(f"{dimension}: {summary['mean']:.3f} ± {summary['std']:.3f}\n")

        logger.info(f"Detailed report saved to {report_path}")


def main(results_dir: str = "./benchmark_results"):
    """Run the full analysis pipeline on a benchmark results directory.

    Loads the results, computes the advanced metrics, renders the
    visualizations and writes the detailed report. Any failure is printed
    and logged with its traceback rather than raised.

    Args:
        results_dir: Directory containing ``benchmark_summary.json``.
            Defaults to ``./benchmark_results``.
    """
    if not os.path.exists(results_dir):
        print(f"Results directory not found: {results_dir}")
        print("Please run the benchmark first using document_generation.py")
        return

    # Initialize analyzer
    analyzer = BenchmarkAnalyzer(results_dir)

    try:
        analyzer.load_results()

        # Computed up front so malformed results fail before any plot is
        # drawn; the report step recomputes the metrics it needs.
        analyzer.calculate_advanced_metrics()

        analyzer.generate_visualizations()
        analyzer.generate_detailed_report()
    except Exception as e:
        print(f"Error during analysis: {e}")
        # logger.exception records the traceback along with the message.
        logger.exception("Analysis error: %s", e)
    else:
        print("Analysis completed successfully!")
        print(f"Results available in: {results_dir}")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
