#!/usr/bin/env python3
"""
Enhanced visualization utilities for PINN experiments with statistical information.
"""

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
import warnings

# Set style.
# NOTE: these two calls run at module level, i.e. once when this module is
# imported, and they configure matplotlib / seaborn globally rather than for
# a single figure.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


class EnhancedVisualizer:
    """Enhanced visualization with statistical information."""
    
    def __init__(self, figsize: Tuple[int, int] = (12, 8), dpi: int = 300):
        """
        Initialize enhanced visualizer.
        
        Args:
            figsize: Default figure size
            dpi: Figure DPI
        """
        self.figsize = figsize
        self.dpi = dpi
        self.colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
    
    def plot_performance_comparison_with_stats(
        self, 
        results: Dict[str, Any], 
        statistical_analysis: Dict[str, Any],
        save_path: Optional[Path] = None
    ) -> None:
        """
        Plot performance comparison with statistical information.
        
        Args:
            results: Experimental results
            statistical_analysis: Statistical analysis results
            save_path: Path to save the plot
        """
        problems = list(results.keys())
        n_problems = len(problems)
        
        fig, axes = plt.subplots(2, n_problems, figsize=(self.figsize[0], self.figsize[1]))
        if n_problems == 1:
            axes = axes.reshape(2, 1)
        
        for i, problem in enumerate(problems):
            # Performance comparison (top row)
            ax_perf = axes[0, i]
            self._plot_problem_performance(ax_perf, problem, results[problem], statistical_analysis.get(problem, {}))
            
            # Training time comparison (bottom row)
            ax_time = axes[1, i]
            self._plot_problem_training_time(ax_time, problem, results[problem], statistical_analysis.get(problem, {}))
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=self.dpi, bbox_inches='tight')
        plt.close()
    
    def _plot_problem_performance(
        self, 
        ax: plt.Axes, 
        problem: str, 
        problem_results: Dict[str, List], 
        analysis: Dict[str, Any]
    ) -> None:
        """Plot performance comparison for a specific problem."""
        # Prepare data
        data_for_plot = []
        labels = []
        colors = []
        
        methods = ['standard', 'rpit', 'bayesian']
        method_colors = {'standard': self.colors[0], 'rpit': self.colors[1], 'bayesian': self.colors[2]}
        
        for method in methods:
            if method in problem_results and problem_results[method]:
                losses = [r['final_train_loss'] for r in problem_results[method]]
                data_for_plot.append(losses)
                labels.append(method.upper())
                colors.append(method_colors[method])
        
        if not data_for_plot:
            ax.text(0.5, 0.5, 'No data available', ha='center', va='center', transform=ax.transAxes)
            ax.set_title(f'{problem.replace("_", " ").title()} Problem')
            return
        
        # Create box plot with enhanced styling
        bp = ax.boxplot(data_for_plot, tick_labels=labels, patch_artist=True, 
                       showmeans=True, meanline=True, showfliers=True)
        
        # Color the boxes
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
        
        # Add statistical annotations
        if 'confidence_intervals' in analysis:
            self._add_confidence_intervals(ax, data_for_plot, labels, analysis['confidence_intervals'])
        
        # Add significance indicators
        if 'pairwise_tests' in analysis:
            self._add_significance_indicators(ax, data_for_plot, labels, analysis['pairwise_tests'])
        
        ax.set_title(f'{problem.replace("_", " ").title()} Problem', fontsize=14, fontweight='bold')
        ax.set_ylabel('Final Training Loss', fontsize=12)
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3)
        
        # Rotate x-axis labels if needed
        if len(max(labels, key=len)) > 6:
            ax.tick_params(axis='x', rotation=45)
    
    def _plot_problem_training_time(
        self, 
        ax: plt.Axes, 
        problem: str, 
        problem_results: Dict[str, List], 
        analysis: Dict[str, Any]
    ) -> None:
        """Plot training time comparison for a specific problem."""
        # Prepare data
        data_for_plot = []
        labels = []
        colors = []
        
        methods = ['standard', 'rpit', 'bayesian']
        method_colors = {'standard': self.colors[0], 'rpit': self.colors[1], 'bayesian': self.colors[2]}
        
        for method in methods:
            if method in problem_results and problem_results[method]:
                times = [r['training_time'] for r in problem_results[method]]
                data_for_plot.append(times)
                labels.append(method.upper())
                colors.append(method_colors[method])
        
        if not data_for_plot:
            ax.text(0.5, 0.5, 'No data available', ha='center', va='center', transform=ax.transAxes)
            ax.set_title(f'{problem.replace("_", " ").title()} Training Time')
            return
        
        # Create box plot
        bp = ax.boxplot(data_for_plot, tick_labels=labels, patch_artist=True, 
                       showmeans=True, meanline=True, showfliers=True)
        
        # Color the boxes
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
        
        # Add confidence intervals
        if 'confidence_intervals' in analysis:
            self._add_confidence_intervals(ax, data_for_plot, labels, analysis['confidence_intervals'], metric='times')
        
        ax.set_title(f'{problem.replace("_", " ").title()} Training Time', fontsize=14, fontweight='bold')
        ax.set_ylabel('Training Time (seconds)', fontsize=12)
        ax.grid(True, alpha=0.3)
        
        # Rotate x-axis labels if needed
        if len(max(labels, key=len)) > 6:
            ax.tick_params(axis='x', rotation=45)
    
    def _add_confidence_intervals(
        self, 
        ax: plt.Axes, 
        data: List[List[float]], 
        labels: List[str], 
        ci_data: Dict[str, Any],
        metric: str = 'losses'
    ) -> None:
        """Add confidence intervals to the plot."""
        for i, (label, values) in enumerate(zip(labels, data)):
            if label.lower() in ci_data and metric in ci_data[label.lower()]:
                ci_info = ci_data[label.lower()][metric]
                x_pos = i + 1
                
                # Add confidence interval as error bars
                mean_val = ci_info['mean']
                lower = ci_info['ci_lower']
                upper = ci_info['ci_upper']
                
                ax.errorbar(x_pos, mean_val, yerr=[[mean_val - lower], [upper - mean_val]], 
                           fmt='none', color='black', capsize=5, capthick=2, alpha=0.7)
    
    def _add_significance_indicators(
        self, 
        ax: plt.Axes, 
        data: List[List[float]], 
        labels: List[str], 
        pairwise_tests: Dict[str, Any]
    ) -> None:
        """Add significance indicators between groups."""
        # Find significant pairs
        significant_pairs = []
        for pair_name, test_result in pairwise_tests.items():
            if test_result['losses']['significant']:
                method1, method2 = pair_name.split('_vs_')
                if method1.upper() in labels and method2.upper() in labels:
                    idx1 = labels.index(method1.upper())
                    idx2 = labels.index(method2.upper())
                    significant_pairs.append((idx1, idx2, test_result['losses']['p_value']))
        
        # Add significance lines
        if significant_pairs:
            y_max = max(max(values) for values in data)
            y_range = y_max - min(min(values) for values in data)
            
            for i, (idx1, idx2, p_val) in enumerate(significant_pairs):
                y_pos = y_max + (i + 1) * 0.05 * y_range
                
                # Draw line
                ax.plot([idx1 + 1, idx2 + 1], [y_pos, y_pos], 'k-', linewidth=1)
                
                # Add significance level
                if p_val < 0.001:
                    sig_text = '***'
                elif p_val < 0.01:
                    sig_text = '**'
                elif p_val < 0.05:
                    sig_text = '*'
                else:
                    sig_text = 'ns'
                
                ax.text((idx1 + idx2 + 2) / 2, y_pos, sig_text, ha='center', va='bottom', fontsize=10)
    
    def plot_statistical_summary_table(
        self, 
        statistical_analysis: Dict[str, Any], 
        save_path: Optional[Path] = None
    ) -> None:
        """
        Create a comprehensive statistical summary table.
        
        Args:
            statistical_analysis: Statistical analysis results
            save_path: Path to save the plot
        """
        fig, ax = plt.subplots(figsize=(16, 10))
        
        # Prepare table data
        table_data = []
        headers = ['Problem', 'Method', 'Mean Loss', 'CI (95%)', 'Std Dev', 'n', 'Rank']
        
        for problem, analysis in statistical_analysis.items():
            if problem == 'cross_problem':
                continue
                
            if 'descriptive_stats' in analysis and 'confidence_intervals' in analysis:
                # Sort methods by performance
                method_means = {}
                for method, stats in analysis['descriptive_stats'].items():
                    method_means[method] = stats['losses']['mean']
                
                sorted_methods = sorted(method_means.items(), key=lambda x: x[1])
                
                for rank, (method, mean_loss) in enumerate(sorted_methods, 1):
                    desc_stats = analysis['descriptive_stats'][method]['losses']
                    ci_stats = analysis['confidence_intervals'][method]['losses']
                    
                    ci_str = f"[{ci_stats['ci_lower']:.3f}, {ci_stats['ci_upper']:.3f}]"
                    
                    table_data.append([
                        problem.replace('_', ' ').title(),
                        method.upper(),
                        f"{desc_stats['mean']:.6f}",
                        ci_str,
                        f"{desc_stats['std']:.6f}",
                        str(desc_stats['n']),
                        str(rank)
                    ])
        
        # Create table
        if table_data:
            table = ax.table(cellText=table_data, colLabels=headers, 
                           cellLoc='center', loc='center')
            table.auto_set_font_size(False)
            table.set_fontsize(9)
            table.scale(1.2, 1.5)
            
            # Style the table
            for i in range(len(headers)):
                table[(0, i)].set_facecolor('#40466e')
                table[(0, i)].set_text_props(weight='bold', color='white')
            
            # Color code by rank
            for i, row in enumerate(table_data):
                rank = int(row[-1])
                if rank == 1:
                    color = '#d4edda'  # Light green for best
                elif rank == 2:
                    color = '#fff3cd'  # Light yellow for second
                else:
                    color = '#f8d7da'  # Light red for others
                
                for j in range(len(headers)):
                    table[(i+1, j)].set_facecolor(color)
        else:
            ax.text(0.5, 0.5, 'No statistical data available', 
                   ha='center', va='center', fontsize=14, 
                   bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
        
        ax.axis('off')
        ax.set_title('Statistical Summary of All Experiments', fontsize=16, fontweight='bold', pad=20)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=self.dpi, bbox_inches='tight')
        plt.close()
    
    def plot_effect_sizes(
        self, 
        statistical_analysis: Dict[str, Any], 
        save_path: Optional[Path] = None
    ) -> None:
        """
        Plot effect sizes between methods.
        
        Args:
            statistical_analysis: Statistical analysis results
            save_path: Path to save the plot
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        
        problems = [p for p in statistical_analysis.keys() if p != 'cross_problem']
        
        for i, problem in enumerate(problems):
            ax = axes[i]
            
            if 'effect_sizes' in statistical_analysis[problem]:
                effect_sizes = statistical_analysis[problem]['effect_sizes']
                
                # Prepare data for heatmap
                methods = ['standard', 'rpit', 'bayesian']
                effect_matrix = np.zeros((len(methods), len(methods)))
                
                for pair_name, effect_data in effect_sizes.items():
                    method1, method2 = pair_name.split('_vs_')
                    if method1 in methods and method2 in methods:
                        idx1 = methods.index(method1)
                        idx2 = methods.index(method2)
                        effect_matrix[idx1, idx2] = effect_data['losses']
                        effect_matrix[idx2, idx1] = -effect_data['losses']  # Symmetric
                
                # Create heatmap
                im = ax.imshow(effect_matrix, cmap='RdBu_r', vmin=-2, vmax=2)
                
                # Add text annotations
                for j in range(len(methods)):
                    for k in range(len(methods)):
                        if j != k:
                            text = ax.text(k, j, f'{effect_matrix[j, k]:.2f}',
                                         ha="center", va="center", color="black", fontweight='bold')
                
                ax.set_xticks(range(len(methods)))
                ax.set_yticks(range(len(methods)))
                ax.set_xticklabels([m.upper() for m in methods])
                ax.set_yticklabels([m.upper() for m in methods])
                ax.set_title(f'{problem.replace("_", " ").title()} Effect Sizes', fontweight='bold')
                
                # Add colorbar
                cbar = plt.colorbar(im, ax=ax)
                cbar.set_label("Cohen's d", rotation=270, labelpad=20)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=self.dpi, bbox_inches='tight')
        plt.close()
    
    def plot_overall_performance_radar(
        self, 
        statistical_analysis: Dict[str, Any], 
        save_path: Optional[Path] = None
    ) -> None:
        """
        Create a radar chart showing overall performance across problems.
        
        Args:
            statistical_analysis: Statistical analysis results
            save_path: Path to save the plot
        """
        if 'cross_problem' not in statistical_analysis:
            return
        
        # Prepare data
        problems = [p for p in statistical_analysis.keys() if p != 'cross_problem']
        methods = ['standard', 'rpit', 'bayesian']
        
        # Calculate normalized performance (lower is better, so invert)
        performance_data = {}
        for method in methods:
            performance_data[method] = []
            for problem in problems:
                if problem in statistical_analysis and 'descriptive_stats' in statistical_analysis[problem]:
                    if method in statistical_analysis[problem]['descriptive_stats']:
                        mean_loss = statistical_analysis[problem]['descriptive_stats'][method]['losses']['mean']
                        performance_data[method].append(1.0 / (1.0 + mean_loss))  # Normalize and invert
                    else:
                        performance_data[method].append(0.0)
                else:
                    performance_data[method].append(0.0)
        
        # Create radar chart
        fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))
        
        angles = np.linspace(0, 2 * np.pi, len(problems), endpoint=False).tolist()
        angles += angles[:1]  # Complete the circle
        
        for i, method in enumerate(methods):
            values = performance_data[method] + performance_data[method][:1]  # Complete the circle
            ax.plot(angles, values, 'o-', linewidth=2, label=method.upper(), color=self.colors[i])
            ax.fill(angles, values, alpha=0.25, color=self.colors[i])
        
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels([p.replace('_', ' ').title() for p in problems])
        ax.set_ylim(0, 1)
        ax.set_title('Overall Performance Across Problems\n(Higher is Better)', 
                    size=16, fontweight='bold', pad=20)
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))
        ax.grid(True)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=self.dpi, bbox_inches='tight')
        plt.close()
