"""
Analysis and visualization module for PeerQA decontextualization audit.
"""

import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Any, Optional, Tuple
import logging
from scipy import stats
from tqdm import tqdm

logger = logging.getLogger(__name__)


class ExperimentAnalyzer:
    """Analyzes experimental results from the decontextualization audit."""
    
    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and start with empty result containers.

        Args:
            config: Experiment configuration mapping; kept on the instance
                unchanged.
        """
        # Results are filled in by load_results(); the flattened metrics
        # table is built on demand by aggregate_metrics().
        self.metrics_df = None
        self.results = {}
        self.config = config

        # Plot defaults are set globally, so they also affect figures created
        # outside this class in the same process.
        default_figure_size = (12, 8)
        sns.set_style("whitegrid")
        plt.rcParams['figure.figsize'] = default_figure_size
    
    def load_results(self, results_dir: str):
        """Load every ``*.json`` file found directly inside a directory.

        Each file is parsed with ``json.load`` and stored in ``self.results``
        under its file name with the trailing ``.json`` removed. Entries that
        are already in ``self.results`` are kept unless a file with the same
        key replaces them.

        Args:
            results_dir: Directory to scan; sub-directories are not entered.

        Raises:
            OSError: If the directory cannot be listed or a file cannot be read.
            json.JSONDecodeError: If a result file does not contain valid JSON.
        """
        logger.info(f"Loading results from {results_dir}")

        suffix = '.json'
        # os.listdir() returns names in arbitrary order; sorting keeps the
        # experiment order (and every table derived from it) reproducible.
        for filename in sorted(os.listdir(results_dir)):
            if not filename.endswith(suffix):
                continue

            filepath = os.path.join(results_dir, filename)
            # A directory whose name ends in ".json" is not a result file.
            if not os.path.isfile(filepath):
                continue

            # Strip only the trailing extension so "run.json.v2.json" keeps
            # the inner ".json" in its key.
            experiment_name = filename[:-len(suffix)]
            with open(filepath, 'r', encoding='utf-8') as f:
                self.results[experiment_name] = json.load(f)

        logger.info(f"Loaded {len(self.results)} result files")
    
    def aggregate_metrics(self) -> pd.DataFrame:
        """Flatten the nested retrieval metrics of all experiments into one table.

        Experiments without a ``'retrieval_metrics'`` entry are skipped. For
        the others the entry is walked as template -> granularity ->
        retriever, and every leaf becomes one row holding those three labels,
        the experiment name, and the leaf's own key/value pairs as columns.

        Returns:
            The resulting DataFrame, which is also stored on
            ``self.metrics_df``.
        """
        rows = []

        for exp_name, exp_results in self.results.items():
            if 'retrieval_metrics' not in exp_results:
                continue

            by_template = exp_results['retrieval_metrics']
            for template in by_template:
                by_granularity = by_template[template]
                for granularity in by_granularity:
                    by_retriever = by_granularity[granularity]
                    for retriever in by_retriever:
                        labels = {
                            'experiment': exp_name,
                            'template': template,
                            'granularity': granularity,
                            'retriever': retriever,
                        }
                        # Metric values come last, so a metric that reuses a
                        # label name overrides that label.
                        rows.append({**labels, **by_retriever[retriever]})

        self.metrics_df = pd.DataFrame(rows)
        return self.metrics_df
    
    def analyze_by_template(self) -> Dict[str, pd.DataFrame]:
        """Summarise the metrics separately for each template.

        Builds the metrics table first if it has not been aggregated yet.
        Only the columns among ``recall@10``, ``recall@20``, ``ndcg@10`` and
        ``mrr`` that are present in the table are summarised.

        Returns:
            A dict keyed by template. Each value holds the mean and standard
            deviation of the available metrics per (granularity, retriever)
            group. Empty when none of the four metrics is present.
        """
        if self.metrics_df is None:
            self.aggregate_metrics()
        table = self.metrics_df

        # One ['mean', 'std'] request per metric column that actually exists.
        agg_spec = {
            name: ['mean', 'std']
            for name in ('recall@10', 'recall@20', 'ndcg@10', 'mrr')
            if name in table.columns
        }
        if not agg_spec:
            logger.warning("No metrics available for template analysis")
            return {}

        per_template = {}
        for template in table['template'].unique():
            rows = table[table['template'] == template]
            # A label that matches no row (NaN never equals itself) is skipped.
            if rows.empty:
                continue
            per_template[template] = rows.groupby(['granularity', 'retriever']).agg(agg_spec)

        return per_template
    
    def analyze_by_granularity(self) -> Dict[str, pd.DataFrame]:
        """Summarise the metrics separately for each granularity.

        Builds the metrics table first if it has not been aggregated yet.
        Only the columns among ``recall@10``, ``recall@20``, ``ndcg@10`` and
        ``mrr`` that are present in the table are summarised.

        Returns:
            A dict keyed by granularity. Each value holds the mean and
            standard deviation of the available metrics per
            (template, retriever) group. Empty when none of the four metrics
            is present.
        """
        if self.metrics_df is None:
            self.aggregate_metrics()
        table = self.metrics_df

        # One ['mean', 'std'] request per metric column that actually exists.
        agg_spec = {
            name: ['mean', 'std']
            for name in ('recall@10', 'recall@20', 'ndcg@10', 'mrr')
            if name in table.columns
        }
        if not agg_spec:
            logger.warning("No metrics available for granularity analysis")
            return {}

        per_granularity = {}
        for granularity in table['granularity'].unique():
            rows = table[table['granularity'] == granularity]
            # A label that matches no row (NaN never equals itself) is skipped.
            if rows.empty:
                continue
            per_granularity[granularity] = rows.groupby(['template', 'retriever']).agg(agg_spec)

        return per_granularity
    
    def stratified_analysis(self, stratify_by: str = 'domain') -> Dict[str, Any]:
        """Collect the per-category results stored under ``stratified_<stratify_by>``.

        Args:
            stratify_by: Suffix of the result key to read; the default reads
                ``'stratified_domain'``.

        Returns:
            A dict keyed by category. Each value is a DataFrame with one row
            per experiment that reported that category: an ``experiment``
            column plus whatever key/value pairs the experiment stored for it.
            Experiments without the requested key are ignored.
        """
        section = f'stratified_{stratify_by}'
        records_by_category = {}

        for exp_name, exp_results in self.results.items():
            if section not in exp_results:
                continue

            per_category = exp_results[section]
            for category in per_category:
                bucket = records_by_category.setdefault(category, [])
                bucket.append({'experiment': exp_name, **per_category[category]})

        # One table per category is easier to analyse than lists of dicts.
        return {
            category: pd.DataFrame(records)
            for category, records in records_by_category.items()
        }
    
    def statistical_significance(self, metric: Optional[str] = None,
                                confidence: float = 0.95,
                                n_bootstrap: int = 100) -> pd.DataFrame:
        """Test every pair of templates for a difference on one metric.

        Builds the metrics table first if it has not been aggregated yet.
        A paired t-test is used only when the two templates' rows can be
        matched one-to-one on the ``experiment`` / ``granularity`` /
        ``retriever`` columns present in the table; the values are then
        compared in matched order. Otherwise an independent two-sample t-test
        is used. If the table has none of those columns, equal sample length
        decides between the two tests.

        Args:
            metric: Column to compare. Defaults to the first column whose name
                starts with ``recall@``.
            confidence: Confidence level for the bootstrap interval;
                ``1 - confidence`` is the significance threshold.
            n_bootstrap: Number of bootstrap resamples for the interval.

        Returns:
            One row per template pair with columns ``template_1``,
            ``template_2``, ``metric``, ``t_statistic``, ``p_value``,
            ``significant``, ``mean_diff``, ``ci_lower`` and ``ci_upper``.
            An empty DataFrame when the metric is unavailable.
        """
        if self.metrics_df is None:
            self.aggregate_metrics()
        metrics_df = self.metrics_df

        # Use first available recall metric if not specified
        if metric is None:
            recall_metrics = [col for col in metrics_df.columns if col.startswith('recall@')]
            metric = recall_metrics[0] if recall_metrics else None

        if metric is None or metric not in metrics_df.columns:
            logger.warning(f"Metric {metric} not available for statistical testing")
            return pd.DataFrame()

        # Columns that identify "the same run" under two different templates.
        key_cols = [col for col in ('experiment', 'granularity', 'retriever')
                    if col in metrics_df.columns and col != metric]

        def matched_samples(rows1, rows2):
            """Return (values1, values2, paired) for two templates' rows."""
            if not key_cols:
                # Nothing to match on: keep the length-based heuristic.
                values1 = rows1[metric].to_numpy()
                values2 = rows2[metric].to_numpy()
                return values1, values2, len(values1) == len(values2)

            left = rows1[key_cols + [metric]]
            right = rows2[key_cols + [metric]]
            keys_unique = not (left.duplicated(key_cols).any()
                               or right.duplicated(key_cols).any())
            if keys_unique and len(left) == len(right):
                # The inner merge lines values up by key, not by row position.
                merged = left.merge(right, on=key_cols, suffixes=('_1', '_2'))
                if len(merged) == len(left):
                    return (merged[f'{metric}_1'].to_numpy(),
                            merged[f'{metric}_2'].to_numpy(),
                            True)
            return left[metric].to_numpy(), right[metric].to_numpy(), False

        results = []
        templates = metrics_df['template'].unique()

        # Pairwise comparisons between templates
        for i, template1 in enumerate(templates):
            rows1 = metrics_df[metrics_df['template'] == template1]
            for template2 in templates[i + 1:]:
                rows2 = metrics_df[metrics_df['template'] == template2]
                if rows1.empty or rows2.empty:
                    continue

                data1, data2, paired = matched_samples(rows1, rows2)
                if paired:
                    t_stat, p_value = stats.ttest_rel(data1, data2)
                else:
                    t_stat, p_value = stats.ttest_ind(data1, data2)

                # The interval resamples each sample independently, even when
                # the t-test above was paired.
                ci_lower, ci_upper = self._bootstrap_ci(
                    data1, data2, n_bootstrap=n_bootstrap, confidence=confidence)

                results.append({
                    'template_1': template1,
                    'template_2': template2,
                    'metric': metric,
                    't_statistic': t_stat,
                    'p_value': p_value,
                    # A NaN p-value compares False, i.e. "not significant".
                    'significant': bool(p_value < (1 - confidence)),
                    'mean_diff': np.mean(data1) - np.mean(data2),
                    'ci_lower': ci_lower,
                    'ci_upper': ci_upper
                })

        return pd.DataFrame(results)
    
    def _bootstrap_ci(self, data1: np.ndarray, data2: np.ndarray, 
                     n_bootstrap: int = 1000, confidence: float = 0.95) -> Tuple[float, float]:
        """Estimate a percentile bootstrap interval for ``mean(data1) - mean(data2)``.

        Each sample is resampled with replacement, independently of the
        other, using NumPy's global random state.

        Args:
            data1: Observations of the first sample.
            data2: Observations of the second sample.
            n_bootstrap: Number of resampling rounds.
            confidence: Confidence level of the interval.

        Returns:
            ``(lower, upper)`` percentile bounds of the resampled differences.
        """
        size1 = len(data1)
        size2 = len(data2)

        # The left operand is evaluated first, so sample 1 is always drawn
        # before sample 2 within a round.
        mean_gaps = [
            np.mean(np.random.choice(data1, size=size1, replace=True))
            - np.mean(np.random.choice(data2, size=size2, replace=True))
            for _ in range(n_bootstrap)
        ]

        # Cut alpha / 2 off each tail of the resampled distribution.
        alpha = 1 - confidence
        lower_q = 100 * alpha / 2
        upper_q = 100 * (1 - alpha / 2)

        return np.percentile(mean_gaps, lower_q), np.percentile(mean_gaps, upper_q)
    
    def performance_cost_analysis(self) -> pd.DataFrame:
        """Tabulate cost figures next to a recall@10 value for each experiment.

        Only experiments with a ``'performance_metrics'`` entry produce a row.
        Missing cost figures default to 0.

        The ``recall@10`` column is taken from a top-level
        ``retrieval_metrics['recall@10']`` when that key exists. Otherwise it
        is the mean of every numeric ``'recall@10'`` found anywhere inside
        ``retrieval_metrics``, which covers the nested template ->
        granularity -> retriever layout. It is 0 when no value is found.

        Returns:
            DataFrame with columns ``experiment``, ``indexing_time``,
            ``avg_query_latency``, ``memory_usage_mb`` and ``recall@10``;
            empty (no columns) when no experiment reports performance metrics.
        """
        def mean_nested_recall(node):
            """Average the numeric 'recall@10' values at any depth of a dict."""
            found = []
            pending = [node]
            while pending:
                current = pending.pop()
                if not isinstance(current, dict):
                    continue
                for key, value in current.items():
                    if isinstance(value, dict):
                        pending.append(value)
                    elif (key == 'recall@10'
                          and isinstance(value, (int, float))
                          and not isinstance(value, bool)):
                        found.append(value)
            return sum(found) / len(found) if found else 0

        cost_data = []

        for exp_name, exp_results in self.results.items():
            if 'performance_metrics' not in exp_results:
                continue
            perf = exp_results['performance_metrics']

            retrieval = exp_results.get('retrieval_metrics', {})
            if isinstance(retrieval, dict) and 'recall@10' in retrieval:
                # Flat layout: the experiment stores one overall value.
                recall_at_10 = retrieval['recall@10']
            else:
                # Nested layout: a flat lookup would always report 0.
                recall_at_10 = mean_nested_recall(retrieval)

            cost_data.append({
                'experiment': exp_name,
                'indexing_time': perf.get('indexing_time_seconds', 0),
                'avg_query_latency': perf.get('avg_query_latency_ms', 0),
                'memory_usage_mb': perf.get('memory_usage_mb', 0),
                'recall@10': recall_at_10
            })

        return pd.DataFrame(cost_data)
    
    def generate_visualizations(self, output_dir: str):
        """Create the output directory and run every plotting routine in turn.

        Args:
            output_dir: Directory handed to each plotting routine; created if
                it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)

        # Order: template, granularity and retriever comparisons, stratified
        # results, cost/performance trade-off, significance heatmap.
        plot_steps = (
            self._plot_template_comparison,
            self._plot_granularity_comparison,
            self._plot_retriever_comparison,
            self._plot_stratified_results,
            self._plot_cost_performance,
            self._plot_significance_heatmap,
        )
        for draw in plot_steps:
            draw(output_dir)

        logger.info(f"Visualizations saved to {output_dir}")
    
    def _plot_template_comparison(self, output_dir: str):
        """Draw one grouped bar chart per available metric and save them as one image.

        Builds the metrics table first if it has not been aggregated yet.
        Each panel shows the mean metric value per retriever, with one bar per
        template. Only the columns among ``recall@10``, ``recall@20``,
        ``ndcg@10`` and ``mrr`` that are present are plotted; with none
        present, nothing is written.

        Args:
            output_dir: Directory that receives ``template_comparison.png``.
        """
        if self.metrics_df is None:
            self.aggregate_metrics()

        # Only use metrics that exist in the dataframe
        potential_metrics = ['recall@10', 'recall@20', 'ndcg@10', 'mrr']
        available_metrics = [m for m in potential_metrics if m in self.metrics_df.columns]

        if not available_metrics:
            logger.warning("No metrics available for plotting")
            return

        # At most two panels per row.
        n_metrics = len(available_metrics)
        n_cols = min(2, n_metrics)
        n_rows = (n_metrics + n_cols - 1) // n_cols

        # squeeze=False always returns a 2-D array of axes, so one, two or
        # more panels are all indexed the same way.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 6 * n_rows), squeeze=False)
        panels = axes.ravel()  # row-major: panel k sits at (k // n_cols, k % n_cols)

        try:
            for ax, metric in zip(panels, available_metrics):
                try:
                    # Create grouped bar plot
                    pivot_df = self.metrics_df.pivot_table(
                        values=metric,
                        index='retriever',
                        columns='template',
                        aggfunc='mean'
                    )

                    if not pivot_df.empty:
                        pivot_df.plot(kind='bar', ax=ax, rot=45)
                        ax.set_title(f'{metric.upper()} by Template and Retriever')
                        ax.set_xlabel('Retriever')
                        ax.set_ylabel(metric.upper())
                        ax.legend(title='Template', bbox_to_anchor=(1.05, 1), loc='upper left')
                except Exception as e:
                    # Best effort: one unplottable metric must not lose the others.
                    logger.warning(f"Could not plot {metric}: {e}")
                    ax.set_title(f'{metric.upper()} - No Data')

            # Hide unused subplots
            for ax in panels[n_metrics:]:
                ax.set_visible(False)

            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'template_comparison.png'), dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when layout or saving fails.
            plt.close(fig)
    
    def _plot_granularity_comparison(self, output_dir: str):
        """Draw one box-plot panel per granularity and save them side by side.

        Builds the metrics table first if it has not been aggregated yet.
        The plotted metric is the first column whose name starts with
        ``recall@``. Nothing is written when the table has no granularity
        values or no such column.

        Args:
            output_dir: Directory that receives ``granularity_comparison.png``.
        """
        if self.metrics_df is None:
            self.aggregate_metrics()
        frame = self.metrics_df

        granularities = frame['granularity'].unique() if 'granularity' in frame.columns else []
        if len(granularities) == 0:
            logger.warning("No granularities available for plotting")
            return

        recall_columns = [col for col in frame.columns if col.startswith('recall@')]
        if not recall_columns:
            logger.warning("No recall metrics available for granularity comparison")
            return
        metric = recall_columns[0]

        panel_count = len(granularities)
        # squeeze=False yields a (1, panel_count) grid even for one panel.
        fig, axes = plt.subplots(1, panel_count, figsize=(7 * panel_count, 6), squeeze=False)

        for ax, granularity in zip(axes[0], granularities):
            subset = frame[frame['granularity'] == granularity]

            if subset.empty:
                ax.set_title(f'{granularity.capitalize()} - No Data')
                continue

            # Distribution of the metric per template, split by retriever.
            sns.boxplot(data=subset, x='template', y=metric, hue='retriever', ax=ax)
            ax.set_title(f'{metric.upper()} for {granularity.capitalize()} Granularity')
            ax.set_xlabel('Template')
            ax.set_ylabel(metric.upper())
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'granularity_comparison.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_retriever_comparison(self, output_dir: str):
        """Draw a violin plot of one recall metric per retriever and save it.

        Builds the metrics table first if it has not been aggregated yet.
        The plotted metric is the first column whose name starts with
        ``recall@``; nothing is written when there is none. With more than
        one granularity the violins are coloured by granularity, and they are
        drawn as split halves only when there are exactly two granularities.

        Args:
            output_dir: Directory that receives ``retriever_comparison.png``.
        """
        if self.metrics_df is None:
            self.aggregate_metrics()

        # Get the first available recall metric
        recall_metrics = [col for col in self.metrics_df.columns if col.startswith('recall@')]
        if not recall_metrics:
            logger.warning("No recall metrics available for retriever comparison")
            return
        metric = recall_metrics[0]

        # Missing granularities are not counted as a level.
        if 'granularity' in self.metrics_df.columns:
            n_levels = len(self.metrics_df['granularity'].dropna().unique())
        else:
            n_levels = 0

        fig = plt.figure(figsize=(14, 8))
        try:
            if n_levels > 1:
                # Older seaborn releases reject split=True unless the hue has
                # exactly two levels, so only request it in that case.
                sns.violinplot(data=self.metrics_df, x='retriever', y=metric,
                              hue='granularity', split=(n_levels == 2), inner='quartile')
                plt.legend(title='Granularity')
            else:
                sns.violinplot(data=self.metrics_df, x='retriever', y=metric, inner='quartile')

            plt.title('Retriever Performance Distribution')
            plt.xlabel('Retriever')
            plt.ylabel(metric.upper())
            plt.xticks(rotation=45)

            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'retriever_comparison.png'), dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
    
    def _plot_stratified_results(self, output_dir: str):
        """Save a bar chart of mean recall@10 per domain.

        Domains come from ``stratified_analysis('domain')``. Domains whose
        table has no ``recall@10`` column are left out, and nothing is drawn
        or written when no domain remains.

        Args:
            output_dir: Directory that receives ``stratified_domain.png``.
        """
        domain_results = self.stratified_analysis('domain')

        # Collect the data before creating a figure, so a run without
        # plottable domains does not leave an open, unsaved figure behind.
        plot_data = [
            {'domain': domain, 'recall@10': df['recall@10'].mean()}
            for domain, df in domain_results.items()
            if 'recall@10' in df.columns
        ]
        if not plot_data:
            return

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            plot_df = pd.DataFrame(plot_data)
            plot_df.plot(x='domain', y='recall@10', kind='bar', ax=ax)
            ax.set_title('Performance by Domain')
            ax.set_xlabel('Domain')
            ax.set_ylabel('Recall@10')

            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, 'stratified_domain.png'), dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
    
    def _plot_cost_performance(self, output_dir: str):
        """Save two scatter panels: recall@10 against latency and against memory.

        The data comes from ``performance_cost_analysis()``; nothing is drawn
        or written when that table is empty.

        Args:
            output_dir: Directory that receives ``cost_performance.png``.
        """
        cost_df = self.performance_cost_analysis()
        if cost_df.empty:
            return

        fig, axes = plt.subplots(1, 2, figsize=(15, 6))

        # (cost column, x-axis label, panel title), left panel first.
        panels = (
            ('avg_query_latency', 'Average Query Latency (ms)', 'Latency vs Performance Trade-off'),
            ('memory_usage_mb', 'Memory Usage (MB)', 'Memory vs Performance Trade-off'),
        )
        for ax, (cost_column, x_label, title) in zip(axes, panels):
            ax.scatter(cost_df[cost_column], cost_df['recall@10'],
                       s=100, alpha=0.7)
            ax.set_xlabel(x_label)
            ax.set_ylabel('Recall@10')
            ax.set_title(title)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'cost_performance.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_significance_heatmap(self, output_dir: str):
        """Save a heatmap of the pairwise p-values between templates.

        The p-values come from ``statistical_significance()`` with its default
        arguments. Cells without a tested pair, including the diagonal, are
        left blank instead of being shown as 0. Nothing is written when there
        are no test results or fewer than two templates. Any failure is
        logged as a warning rather than raised.

        Args:
            output_dir: Directory that receives ``significance_heatmap.png``.
        """
        try:
            sig_df = self.statistical_significance()

            if sig_df.empty or 'template' not in self.metrics_df.columns:
                return

            templates = list(self.metrics_df['template'].unique())
            n_templates = len(templates)
            if n_templates < 2:
                return

            position = {template: idx for idx, template in enumerate(templates)}

            # NaN marks "no test result". A zero there would be read as a
            # maximally significant p-value.
            sig_matrix = np.full((n_templates, n_templates), np.nan)
            for _, row in sig_df.iterrows():
                i = position.get(row['template_1'])
                j = position.get(row['template_2'])
                if i is None or j is None:
                    continue
                sig_matrix[i, j] = row['p_value']
                sig_matrix[j, i] = row['p_value']

            fig = plt.figure(figsize=(10, 8))
            try:
                # Fixed 0-1 range keeps colours comparable between runs.
                sns.heatmap(sig_matrix, mask=np.isnan(sig_matrix),
                           annot=True, fmt='.3f', cmap='RdYlGn_r',
                           vmin=0.0, vmax=1.0,
                           xticklabels=templates, yticklabels=templates,
                           cbar_kws={'label': 'p-value'})
                plt.title('Statistical Significance Between Templates (p-values)')
                plt.tight_layout()
                plt.savefig(os.path.join(output_dir, 'significance_heatmap.png'), dpi=300, bbox_inches='tight')
            finally:
                # Release the figure even when plotting or saving fails.
                plt.close(fig)
        except Exception as e:
            # Best effort: a failed heatmap must not abort the other outputs.
            logger.warning(f"Could not create significance heatmap: {e}")
    
    def generate_report(self, output_path: str):
        """Write a Markdown-style analysis report to a file.

        The report has an executive summary (best template, granularity and
        retriever by mean of the first ``recall@`` column), the per-template
        summary tables, the significant pairwise template comparisons, and a
        fixed list of recommendations.

        Args:
            output_path: File to write; its parent directory is created if
                needed and an existing file is overwritten.

        Raises:
            OSError: If the report file cannot be written.
        """
        # Aggregate up front. The later sections do this lazily, so without
        # it the executive summary would be silently skipped whenever the
        # report is the first analysis requested.
        if self.metrics_df is None:
            self.aggregate_metrics()
        metrics_df = self.metrics_df

        report = []

        report.append("# PeerQA Decontextualization Audit Report\n")
        report.append("=" * 80 + "\n\n")

        # Executive Summary
        report.append("## Executive Summary\n")
        if metrics_df is not None and not metrics_df.empty:
            # Use the first available recall metric
            recall_metrics = [col for col in metrics_df.columns if col.startswith('recall@')]
            if recall_metrics:
                primary_metric = recall_metrics[0]

                def best_of(column):
                    """Return the label with the highest mean metric, or 'N/A'."""
                    if column not in metrics_df.columns:
                        return 'N/A'
                    means = metrics_df.groupby(column)[primary_metric].mean().dropna()
                    # idxmax has nothing to pick from when every mean is NaN.
                    return means.idxmax() if not means.empty else 'N/A'

                report.append(f"- **Best Template**: {best_of('template')}\n")
                report.append(f"- **Best Granularity**: {best_of('granularity')}\n")
                report.append(f"- **Best Retriever**: {best_of('retriever')}\n\n")
            else:
                report.append("- No recall metrics available\n\n")

        # Detailed Results
        report.append("## Detailed Results\n\n")

        # Template Analysis
        report.append("### Template Analysis\n")
        template_analysis = self.analyze_by_template()
        for template, summary in template_analysis.items():
            report.append(f"\n#### {template}\n")
            report.append(summary.to_string() + "\n")

        # Statistical Significance
        report.append("\n### Statistical Significance\n")
        sig_df = self.statistical_significance()
        if not sig_df.empty:
            sig_summary = sig_df[sig_df['significant'].astype(bool)]
            if sig_summary.empty:
                # Clearer than dumping pandas' "Empty DataFrame" text.
                report.append("No statistically significant differences found.\n")
            else:
                report.append(sig_summary.to_string() + "\n")

        # Recommendations
        # NOTE(review): these lines are fixed text and are not derived from
        # the metrics above, so they can contradict the executive summary.
        # Confirm whether they should be generated from the results.
        report.append("\n## Recommendations\n")
        report.append("Based on the experimental results:\n")
        report.append("1. Use title+heading template for best overall performance\n")
        report.append("2. Paragraph granularity works better for longer documents\n")
        report.append("3. Cross-encoder reranking significantly improves precision\n")
        report.append("4. Consider computational cost trade-offs for production deployment\n")

        # Save report
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(report)

        logger.info(f"Report saved to {output_path}")