"""
Visualization Toolkit for Entropy Analysis

Publication-quality plotting utilities for entropy-based early stopping experiments.
Supports creation of figures matching the paper's visual style.
"""

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Optional, Tuple
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import seaborn as sns
import warnings
# NOTE(review): this installs a process-wide filter that hides *all* warnings,
# including ones raised by code that merely imports this module. Consider
# narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

# Set publication-quality style.
# 'Times New Roman' is missing on many Linux/CI machines; the extra serif
# entries ('DejaVu Serif' ships with matplotlib) keep figures in a serif face
# there instead of falling back to matplotlib's default sans-serif font.
plt.rcParams.update({
    'font.size': 12,
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Times', 'DejaVu Serif'],
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 11,
    'figure.titlesize': 18,
    'text.usetex': False,  # Set to True if you have LaTeX installed
    'figure.figsize': (8, 6),
    'axes.linewidth': 1.2,
    'grid.linewidth': 0.5,
    'lines.linewidth': 2,
    'patch.linewidth': 1.2
})

# Professional color palette shared by every plot in this module.
COLORS = {
    'primary': '#1f77b4',    # Blue
    'secondary': '#ff7f0e',  # Orange
    'tertiary': '#2ca02c',   # Green
    'quaternary': '#d62728', # Red
    'correct': '#2E86AB',    # Professional blue
    'incorrect': '#A23B72',  # Professional magenta
    'accent': '#F18F01',     # Professional orange
    'dark': '#2F2F2F'        # Dark gray
}

class EntropyVisualization:
    """
    Main visualization class for entropy analysis results.

    Creates publication-quality plots for entropy distributions,
    threshold analysis, token savings, and performance metrics.

    Every public ``plot_*`` / ``create_*`` method draws on a new figure,
    writes ``<save_path>.pdf`` and ``<save_path>.png`` (300 dpi) when a
    ``save_path`` is given, and returns the figure.
    """

    # Palette keys cycled, in order, for per-model / per-method series.
    _SERIES_COLOR_KEYS = ('primary', 'secondary', 'tertiary', 'quaternary')

    def __init__(self, figsize: Tuple[int, int] = (10, 6)):
        """
        Initialize visualization toolkit.

        Args:
            figsize: Default (width, height) in inches for the single-panel
                plots (entropy distributions, token savings, Cohen's d,
                accuracy breakdown).
        """
        self.figsize = figsize
        self.colors = COLORS

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _series_colors(self, count: int) -> List[str]:
        """Return ``count`` colours, cycling the four categorical palette entries."""
        palette = [self.colors[key] for key in self._SERIES_COLOR_KEYS]
        return [palette[i % len(palette)] for i in range(count)]

    @staticmethod
    def _label_bars(ax, bars, fmt: str, offset: float, **text_kwargs) -> None:
        """Write ``fmt.format(height)`` just past the tip of every bar.

        Labels go above non-negative bars and below negative ones, so they
        never sit inside the bar.
        """
        for bar in bars:
            height = bar.get_height()
            above = height >= 0
            ax.text(bar.get_x() + bar.get_width() / 2.,
                    height + offset if above else height - offset,
                    fmt.format(height),
                    ha='center', va='bottom' if above else 'top',
                    **text_kwargs)

    @staticmethod
    def _rotate_long_labels(ax, labels, max_len: int) -> None:
        """Rotate *ax*'s x tick labels 45 degrees when any label exceeds ``max_len`` chars."""
        if labels and max(len(str(label)) for label in labels) > max_len:
            for tick_label in ax.get_xticklabels():
                tick_label.set_rotation(45)
                tick_label.set_horizontalalignment('right')

    @staticmethod
    def _finish_axes(ax, grid_axis: str = 'both') -> None:
        """Apply the shared dashed grid and hide the top/right spines."""
        ax.grid(axis=grid_axis, alpha=0.3, linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    @staticmethod
    def _save(fig, save_path: Optional[str]) -> None:
        """Save *fig* as ``save_path + '.pdf'`` and ``'.png'`` (300 dpi) when a path is given."""
        # fig.savefig (not plt.savefig) so the right figure is written even if
        # another figure became "current" in the meantime.
        if save_path:
            for extension in ('.pdf', '.png'):
                fig.savefig(save_path + extension, dpi=300, bbox_inches='tight')

    @staticmethod
    def _effect_size_color(d: float) -> str:
        """Colour for Cohen's d by magnitude |d|: negligible/small/medium/large."""
        magnitude = abs(d)
        if magnitude < 0.2:
            return '#cccccc'  # Light gray for negligible
        if magnitude < 0.5:
            return '#ffcc00'  # Yellow for small
        if magnitude < 0.8:
            return '#ff8800'  # Orange for medium
        return '#ff4444'      # Red for large

    # ------------------------------------------------------------------
    # Plots
    # ------------------------------------------------------------------
    def plot_entropy_distributions(self,
                                   correct_entropies: List[float],
                                   incorrect_entropies: List[float],
                                   threshold: Optional[float] = None,
                                   title: str = "Entropy Distributions",
                                   save_path: Optional[str] = None) -> plt.Figure:
        """
        Plot entropy distributions for correct vs incorrect answers.

        Both inputs may be lists or NumPy arrays. One group may be empty (for
        example, a model that answered everything correctly). In that case
        its histogram and mean line are omitted and Cohen's d is shown as
        ``n/a``.

        Args:
            correct_entropies: Entropy values for correct predictions
            incorrect_entropies: Entropy values for incorrect predictions
            threshold: Early stopping threshold to highlight
            title: Plot title
            save_path: Path prefix to save the figure (``.pdf`` and ``.png``)

        Returns:
            Matplotlib figure object

        Raises:
            ValueError: If both groups are empty.
        """
        # Convert explicitly and concatenate: ``a + b`` on NumPy arrays adds
        # element-wise (and fails for different lengths) instead of joining.
        correct = np.asarray(correct_entropies, dtype=float).ravel()
        incorrect = np.asarray(incorrect_entropies, dtype=float).ravel()
        all_entropies = np.concatenate([correct, incorrect])
        if all_entropies.size == 0:
            raise ValueError("correct_entropies and incorrect_entropies are both empty")

        fig, ax = plt.subplots(figsize=self.figsize)

        # Shared bin edges keep the two densities comparable; widen a
        # zero-width range so the bins are not degenerate.
        low, high = all_entropies.min(), all_entropies.max()
        if low == high:
            low, high = low - 0.5, high + 0.5
        bins = np.linspace(low, high, 30)

        groups = (
            (correct, 'Correct Answers', 'Correct Mean', self.colors['correct']),
            (incorrect, 'Incorrect Answers', 'Incorrect Mean', self.colors['incorrect']),
        )

        # Histograms first, then mean lines (keeps the original legend order).
        for values, hist_label, _, color in groups:
            if values.size:
                ax.hist(values, bins=bins, alpha=0.7, label=hist_label,
                        color=color, density=True, edgecolor='black', linewidth=0.5)
        for values, _, mean_label, color in groups:
            if values.size:
                mean = values.mean()
                ax.axvline(mean, color=color, linestyle='--', linewidth=3,
                           label=f'{mean_label} ({mean:.3f})')

        # Add threshold line if provided
        if threshold is not None:
            ax.axvline(threshold, color='black', linestyle=':', linewidth=2, alpha=0.8,
                       label=f'Threshold ({threshold:.3f})')

        # Cohen's d using the unweighted mean of the two population variances
        # (ddof=0) as the pooled variance; reported as 0 when both are constant.
        if correct.size and incorrect.size:
            pooled_std = np.sqrt((correct.var() + incorrect.var()) / 2)
            cohens_d = abs(incorrect.mean() - correct.mean()) / pooled_std if pooled_std > 0 else 0
            d_text = f'{cohens_d:.3f}'
        else:
            d_text = 'n/a'

        ax.set_xlabel('Entropy (bits)', fontsize=14, fontweight='bold')
        ax.set_ylabel('Density', fontsize=14, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)

        ax.legend(loc='upper right', frameon=True, fancybox=True, shadow=True)
        self._finish_axes(ax)

        # Add statistics box
        stats_text = f'n_correct = {correct.size}\nn_incorrect = {incorrect.size}\nCohen\'s d = {d_text}'
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontsize=11,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        fig.tight_layout()
        self._save(fig, save_path)
        return fig

    def plot_token_savings(self,
                           model_results: Dict[str, Dict],
                           title: str = "Token Savings by Model",
                           save_path: Optional[str] = None) -> plt.Figure:
        """
        Plot token savings achieved by different models/methods.

        Args:
            model_results: Dict mapping model names to result dicts; each
                must contain a ``'token_savings'`` entry (in percent).
            title: Plot title
            save_path: Path prefix to save the figure (``.pdf`` and ``.png``)

        Returns:
            Matplotlib figure object (an empty chart if ``model_results`` is empty)
        """
        fig, ax = plt.subplots(figsize=self.figsize)

        models = list(model_results)
        savings = [model_results[model]['token_savings'] for model in models]

        bars = ax.bar(models, savings, color=self._series_colors(len(models)),
                      alpha=0.8, edgecolor='black', linewidth=1)
        self._label_bars(ax, bars, '{:.1f}%', 0.5, fontweight='bold', fontsize=12)

        ax.set_ylabel('Token Savings (%)', fontsize=14, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        self._rotate_long_labels(ax, models, 10)

        if savings:
            # Axis starts at 0, extends below zero only if a model *costs*
            # tokens, and leaves 5 points of headroom for the value labels.
            lowest = min(savings)
            ax.set_ylim(lowest - 5 if lowest < 0 else 0, max(max(savings), 0) + 5)
        self._finish_axes(ax, grid_axis='y')

        fig.tight_layout()
        self._save(fig, save_path)
        return fig

    def plot_threshold_analysis(self,
                                threshold_results: Dict[str, Dict],
                                title: str = "Threshold Method Comparison",
                                save_path: Optional[str] = None) -> plt.Figure:
        """
        Plot comparison of different threshold methods.

        Left panel: token savings vs threshold accuracy, one point per method.
        Right panel: the threshold value chosen by each method.

        Args:
            threshold_results: Dict mapping method names to metric dicts with
                ``'token_savings'``, ``'threshold_accuracy'`` and ``'threshold'``.
            title: Plot title
            save_path: Path prefix to save the figure (``.pdf`` and ``.png``)

        Returns:
            Matplotlib figure object
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        methods = list(threshold_results)
        savings = [threshold_results[method]['token_savings'] for method in methods]
        accuracy = [threshold_results[method]['threshold_accuracy'] for method in methods]
        thresholds = [threshold_results[method]['threshold'] for method in methods]
        # Cycled palette: any number of methods gets a colour.
        colors = self._series_colors(len(methods))

        # Plot 1: Token Savings vs Threshold Accuracy
        for method, x_val, y_val, color in zip(methods, savings, accuracy, colors):
            ax1.scatter(x_val, y_val, s=150, color=color,
                        alpha=0.8, edgecolor='black', linewidth=2, label=method)

        ax1.set_xlabel('Token Savings (%)', fontsize=14, fontweight='bold')
        ax1.set_ylabel('Threshold Accuracy (%)', fontsize=14, fontweight='bold')
        ax1.set_title('Efficiency vs Accuracy Trade-off', fontsize=14, fontweight='bold')
        ax1.legend(frameon=True, fancybox=True, shadow=True)
        self._finish_axes(ax1)

        # Plot 2: Threshold Values
        bars = ax2.bar(methods, thresholds, color=colors, alpha=0.8,
                       edgecolor='black', linewidth=1)
        self._label_bars(ax2, bars, '{:.3f}', 0.005, fontweight='bold', fontsize=11)

        ax2.set_ylabel('Threshold Value (bits)', fontsize=14, fontweight='bold')
        ax2.set_title('Threshold Values by Method', fontsize=14, fontweight='bold')
        self._rotate_long_labels(ax2, methods, 8)
        self._finish_axes(ax2, grid_axis='y')

        fig.suptitle(title, fontsize=16, fontweight='bold', y=1.02)
        fig.tight_layout()
        self._save(fig, save_path)
        return fig

    def plot_cohens_d_comparison(self,
                                 model_results: Dict[str, float],
                                 title: str = "Effect Size Comparison",
                                 save_path: Optional[str] = None) -> plt.Figure:
        """
        Plot Cohen's d effect sizes across models/datasets.

        Bars are coloured by effect-size magnitude |d| (negligible < 0.2 <=
        small < 0.5 <= medium < 0.8 <= large). Dashed reference lines mark
        0.2 / 0.5 / 0.8, and the y-range always includes them.

        Args:
            model_results: Dict mapping model names to Cohen's d values
            title: Plot title
            save_path: Path prefix to save the figure (``.pdf`` and ``.png``)

        Returns:
            Matplotlib figure object
        """
        fig, ax = plt.subplots(figsize=self.figsize)

        models = list(model_results)
        cohens_d_values = list(model_results.values())

        bars = ax.bar(models, cohens_d_values,
                      color=[self._effect_size_color(d) for d in cohens_d_values],
                      alpha=0.8, edgecolor='black', linewidth=1)
        self._label_bars(ax, bars, '{:.3f}', 0.01, fontweight='bold', fontsize=12)

        # Reference lines for the conventional effect-size benchmarks, with
        # their names just right of the last bar.
        for level, name, color in ((0.2, 'Small', 'gray'),
                                   (0.5, 'Medium', 'orange'),
                                   (0.8, 'Large', 'red')):
            ax.axhline(y=level, color=color, linestyle='--', alpha=0.7, linewidth=1.5)
            ax.text(len(models) - 0.4, level, name, fontsize=10, va='center', color=color)

        ax.set_ylabel("Cohen's d Effect Size", fontsize=14, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        self._rotate_long_labels(ax, models, 10)

        if cohens_d_values:
            # Keep every benchmark line (up to 0.8) and any negative bar in view.
            lowest = min(cohens_d_values)
            ax.set_ylim(lowest - 0.1 if lowest < 0 else 0,
                        max(max(cohens_d_values), 0.8) + 0.1)
        self._finish_axes(ax, grid_axis='y')

        fig.tight_layout()
        self._save(fig, save_path)
        return fig

    def plot_accuracy_breakdown(self,
                                results: Dict[str, Dict],
                                title: str = "Accuracy Breakdown by Model",
                                save_path: Optional[str] = None) -> plt.Figure:
        """
        Plot step-1 vs final accuracy breakdown.

        Draws three grouped bars per model: step-1, 4-step (final) and
        threshold accuracy.

        Args:
            results: Dict mapping model names to metric dicts with
                ``'step1_accuracy'``, ``'final_accuracy'`` and
                ``'threshold_accuracy'`` (in percent).
            title: Plot title
            save_path: Path prefix to save the figure (``.pdf`` and ``.png``)

        Returns:
            Matplotlib figure object
        """
        fig, ax = plt.subplots(figsize=self.figsize)

        models = list(results)
        x = np.arange(len(models))
        width = 0.25
        series = (
            ('step1_accuracy', 'Step-1 Accuracy', self.colors['primary'], -width),
            ('final_accuracy', '4-Step Accuracy', self.colors['secondary'], 0.0),
            ('threshold_accuracy', 'Threshold Accuracy', self.colors['tertiary'], width),
        )

        for key, label, color, offset in series:
            values = [results[model][key] for model in models]
            bars = ax.bar(x + offset, values, width, label=label,
                          color=color, alpha=0.8, edgecolor='black')
            self._label_bars(ax, bars, '{:.1f}%', 0.5, fontsize=10, fontweight='bold')

        ax.set_xlabel('Model', fontsize=14, fontweight='bold')
        ax.set_ylabel('Accuracy (%)', fontsize=14, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.set_xticks(x)
        ax.set_xticklabels(models)
        self._rotate_long_labels(ax, models, 12)

        ax.legend(loc='upper right', frameon=True, fancybox=True, shadow=True)
        self._finish_axes(ax, grid_axis='y')

        fig.tight_layout()
        self._save(fig, save_path)
        return fig

    def create_framework_overview(self,
                                  save_path: Optional[str] = None) -> plt.Figure:
        """
        Create a framework overview diagram.

        The flowchart has four process boxes along the top (input, step-1
        reasoning, top-k logprobs, Shannon entropy). Below them is a decision
        diamond for the threshold test, followed by the two outcomes ("Early
        Stop" on yes, "Continue Reasoning" on no). Arrow endpoints are derived
        from the shape geometry so they touch the shape edges.

        Args:
            save_path: Path prefix to save the figure (``.pdf`` and ``.png``)

        Returns:
            Matplotlib figure object
        """
        fig, ax = plt.subplots(figsize=(12, 8))

        half_w, half_h = 0.8, 0.4   # process-box half-width / half-height
        diamond_r = 1.0             # decision-diamond centre-to-vertex distance

        components = [
            {"name": "Input Problem", "pos": (1, 6), "color": self.colors['primary']},
            {"name": "LLM Reasoning\n(Step 1)", "pos": (4, 6), "color": self.colors['secondary']},
            {"name": "Extract\nTop-k Logprobs", "pos": (7, 6), "color": self.colors['tertiary']},
            {"name": "Calculate\nShannon Entropy", "pos": (10, 6), "color": self.colors['quaternary']},
            # Two lines so the label fits inside the diamond.
            {"name": "Entropy <\nThreshold?", "pos": (6, 3), "color": self.colors['accent'],
             "decision": True},
            {"name": "Early Stop\n(Save Tokens)", "pos": (3, 1), "color": self.colors['correct']},
            {"name": "Continue\nReasoning", "pos": (9, 1), "color": self.colors['incorrect']}
        ]

        # Draw components
        for comp in components:
            x, y = comp["pos"]
            if comp.get("decision"):
                # orientation=0 puts a vertex straight up -> a diamond
                # (pi/4 would rotate it into an axis-aligned square).
                shape = patches.RegularPolygon((x, y), 4, radius=diamond_r,
                                               orientation=0,
                                               facecolor=comp["color"],
                                               edgecolor='black', linewidth=2, alpha=0.8)
            else:
                shape = Rectangle((x - half_w, y - half_h), 2 * half_w, 2 * half_h,
                                  facecolor=comp["color"], edgecolor='black',
                                  linewidth=2, alpha=0.8)
            ax.add_patch(shape)
            ax.text(x, y, comp["name"],
                    ha='center', va='center', fontsize=11, fontweight='bold',
                    wrap=True)

        # Arrows, computed from the geometry above.
        process_row = [comp["pos"] for comp in components[:4]]
        entropy_x, entropy_y = process_row[-1]
        dec_x, dec_y = components[4]["pos"]
        stop_x, stop_y = components[5]["pos"]
        cont_x, cont_y = components[6]["pos"]

        # Right edge of each top-row box -> left edge of the next one.
        arrows = [((x0 + half_w, y0), (x1 - half_w, y1))
                  for (x0, y0), (x1, y1) in zip(process_row, process_row[1:])]
        arrows += [
            ((entropy_x, entropy_y - half_h), (dec_x, dec_y + diamond_r)),  # Entropy -> Decision
            ((dec_x - diamond_r, dec_y), (stop_x, stop_y + half_h)),       # Decision -> Early Stop (Yes)
            ((dec_x + diamond_r, dec_y), (cont_x, cont_y + half_h)),       # Decision -> Continue (No)
        ]

        for start, end in arrows:
            ax.annotate('', xy=end, xytext=start,
                        arrowprops=dict(arrowstyle='->', lw=2, color='black'))

        # Add decision labels
        ax.text(4.5, 2, 'Yes', fontsize=12, fontweight='bold', color='green')
        ax.text(7.5, 2, 'No', fontsize=12, fontweight='bold', color='red')

        ax.set_xlim(0, 12)
        ax.set_ylim(0, 7)
        ax.set_aspect('equal')
        ax.axis('off')
        ax.set_title('Entropy-Based Early Stopping Framework',
                     fontsize=18, fontweight='bold', pad=20)

        fig.tight_layout()
        self._save(fig, save_path)
        return fig

def _dashboard_metric(result: Dict, key: str) -> float:
    """Return ``result[key]`` as a float, or NaN when the metric is absent/None."""
    value = result.get(key)
    return float('nan') if value is None else float(value)


def _dashboard_series(experiment_results: Dict, models: List[str], key: str) -> List[float]:
    """Collect metric *key* for each model in *models* (NaN where missing)."""
    return [_dashboard_metric(experiment_results[model], key) for model in models]


def _dashboard_despine(ax, grid_axis: str = 'y') -> None:
    """Apply the module's dashed grid and hide the top/right spines."""
    ax.grid(axis=grid_axis, alpha=0.3, linestyle='--')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def _dashboard_bars(ax, labels: List[str], values: List[float], title: str,
                    ylabel: str, fmt: str, color: str) -> None:
    """Draw a single-series bar chart with value labels on *ax*."""
    x = np.arange(len(labels))
    bars = ax.bar(x, values, color=color, alpha=0.8, edgecolor='black', linewidth=1)
    for bar, value in zip(bars, values):
        if np.isfinite(value):  # missing metrics (NaN) get no label
            ax.annotate(fmt.format(value),
                        (bar.get_x() + bar.get_width() / 2., value),
                        xytext=(0, 3 if value >= 0 else -3), textcoords='offset points',
                        ha='center', va='bottom' if value >= 0 else 'top', fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=30, ha='right')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_ylabel(ylabel)
    _dashboard_despine(ax)


def _dashboard_accuracy(ax, experiment_results: Dict, models: List[str]) -> None:
    """Grouped bars of step-1, 4-step and threshold accuracy per model."""
    x = np.arange(len(models))
    width = 0.25
    series = (('step1_accuracy', 'Step-1', COLORS['primary'], -width),
              ('final_accuracy', '4-Step', COLORS['secondary'], 0.0),
              ('threshold_accuracy', 'Threshold', COLORS['tertiary'], width))
    for key, label, color, offset in series:
        ax.bar(x + offset, _dashboard_series(experiment_results, models, key), width,
               label=label, color=color, alpha=0.8, edgecolor='black')
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=30, ha='right')
    ax.set_title('Accuracy Comparison', fontsize=13, fontweight='bold')
    ax.set_ylabel('Accuracy (%)')
    ax.legend(fontsize=9)
    _dashboard_despine(ax)


def _dashboard_tradeoff(ax, experiment_results: Dict, models: List[str]) -> None:
    """Scatter token savings against threshold accuracy, one labelled point per model."""
    savings = _dashboard_series(experiment_results, models, 'token_savings')
    accuracy = _dashboard_series(experiment_results, models, 'threshold_accuracy')
    ax.scatter(savings, accuracy, s=150, color=COLORS['primary'], alpha=0.8,
               edgecolor='black', linewidth=1.5)
    for model, x_val, y_val in zip(models, savings, accuracy):
        if np.isfinite(x_val) and np.isfinite(y_val):
            ax.annotate(str(model), (x_val, y_val), xytext=(6, 6),
                        textcoords='offset points', fontsize=10)
    ax.set_xlabel('Token Savings (%)')
    ax.set_ylabel('Threshold Accuracy (%)')
    ax.set_title('Threshold Analysis: Efficiency vs Accuracy', fontsize=13, fontweight='bold')
    _dashboard_despine(ax, grid_axis='both')


def _dashboard_table(ax, experiment_results: Dict, models: List[str]) -> None:
    """Tabulate the headline metrics per model ('n/a' where a metric is missing)."""
    ax.axis('off')
    ax.set_title('Performance Summary', fontsize=13, fontweight='bold')
    if not models:
        ax.text(0.5, 0.5, 'No results', ha='center', va='center', transform=ax.transAxes)
        return
    columns = (('Step-1 Acc.', 'step1_accuracy', '{:.1f}%'),
               ('4-Step Acc.', 'final_accuracy', '{:.1f}%'),
               ('Threshold Acc.', 'threshold_accuracy', '{:.1f}%'),
               ('Token Savings', 'token_savings', '{:.1f}%'),
               ("Cohen's d", 'cohens_d', '{:.3f}'),
               ('Threshold', 'threshold', '{:.3f}'))
    rows = []
    for model in models:
        row = [str(model)]
        for _, key, fmt in columns:
            value = _dashboard_metric(experiment_results[model], key)
            row.append(fmt.format(value) if np.isfinite(value) else 'n/a')
        rows.append(row)
    table = ax.table(cellText=rows,
                     colLabels=['Model'] + [name for name, _, _ in columns],
                     loc='center', cellLoc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 1.4)


def _dashboard_framework(ax) -> None:
    """Draw the early-stopping pipeline as a left-to-right row of boxes."""
    stages = (("Input\nProblem", COLORS['primary']),
              ("LLM Reasoning\n(Step 1)", COLORS['secondary']),
              ("Extract Top-k\nLogprobs", COLORS['tertiary']),
              ("Shannon\nEntropy", COLORS['quaternary']),
              ("Entropy <\nThreshold?", COLORS['accent']),
              ("Stop or\nContinue", COLORS['correct']))
    xs = np.linspace(0.08, 0.92, len(stages))
    for x_pos, (label, color) in zip(xs, stages):
        ax.text(x_pos, 0.5, label, transform=ax.transAxes, ha='center', va='center',
                fontsize=11, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.6', facecolor=color,
                          edgecolor='black', alpha=0.8))
    for start, end in zip(xs[:-1], xs[1:]):
        ax.annotate('', xy=(end - 0.055, 0.5), xytext=(start + 0.055, 0.5),
                    xycoords='axes fraction', textcoords='axes fraction',
                    arrowprops=dict(arrowstyle='->', lw=2, color='black'))
    ax.set_title('Entropy-Based Early Stopping Framework', fontsize=13, fontweight='bold')
    ax.axis('off')


def create_summary_dashboard(experiment_results: Dict,
                           save_path: str = "summary_dashboard") -> plt.Figure:
    """
    Create a comprehensive dashboard summarizing all experiment results.

    Layout (3 x 4 grid):
      * row 1 - token savings, accuracy comparison, Cohen's d, thresholds;
      * row 2 - savings-vs-accuracy trade-off scatter and a summary table;
      * row 3 - a compact overview of the early-stopping pipeline.

    Each model's entry is read with ``.get`` for the keys ``token_savings``,
    ``step1_accuracy``, ``final_accuracy``, ``threshold_accuracy``,
    ``cohens_d`` and ``threshold``. Missing metrics are drawn as gaps
    (tables show ``n/a``) rather than raising.

    Args:
        experiment_results: Dict mapping model names to per-model metric dicts
        save_path: Path prefix; the dashboard is always written to
            ``save_path + '.pdf'`` and ``save_path + '.png'`` (300 dpi)

    Returns:
        Matplotlib figure object
    """
    models = list(experiment_results)

    fig = plt.figure(figsize=(20, 12))
    # Extra vertical gap so rotated model names in row 1 clear row 2's titles.
    gs = fig.add_gridspec(3, 4, hspace=0.6, wspace=0.3)

    # Row 1: headline metrics per model.
    _dashboard_bars(fig.add_subplot(gs[0, 0]), models,
                    _dashboard_series(experiment_results, models, 'token_savings'),
                    'Token Savings', 'Token Savings (%)', '{:.1f}%', COLORS['primary'])
    _dashboard_accuracy(fig.add_subplot(gs[0, 1]), experiment_results, models)
    ax_effect = fig.add_subplot(gs[0, 2])
    _dashboard_bars(ax_effect, models,
                    _dashboard_series(experiment_results, models, 'cohens_d'),
                    "Effect Size (Cohen's d)", "Cohen's d", '{:.2f}', COLORS['accent'])
    for level, color in ((0.2, 'gray'), (0.5, 'orange'), (0.8, 'red')):
        ax_effect.axhline(level, color=color, linestyle='--', alpha=0.6, linewidth=1)
    _dashboard_bars(fig.add_subplot(gs[0, 3]), models,
                    _dashboard_series(experiment_results, models, 'threshold'),
                    'Entropy Thresholds', 'Threshold (bits)', '{:.3f}', COLORS['tertiary'])

    # Row 2: trade-off analysis and numeric summary.
    _dashboard_tradeoff(fig.add_subplot(gs[1, :2]), experiment_results, models)
    _dashboard_table(fig.add_subplot(gs[1, 2:]), experiment_results, models)

    # Row 3: pipeline overview.
    _dashboard_framework(fig.add_subplot(gs[2, :]))

    fig.suptitle('Entropy-Based Early Stopping: Complete Results Dashboard',
                 fontsize=24, fontweight='bold', y=0.98)

    for extension in ('.pdf', '.png'):
        fig.savefig(save_path + extension, dpi=300, bbox_inches='tight')

    return fig

# Utility functions
def load_experiment_results(file_path: str) -> Dict:
    """
    Load experiment results from a JSON file.

    The file is decoded as UTF-8 (the JSON standard encoding) rather than
    the platform default, so non-ASCII model names load identically on
    every OS.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The parsed JSON document.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# (CSV column, result key) pairs for the per-model summary table.
_SUMMARY_COLUMNS = (
    ('Step1_Accuracy', 'step1_accuracy'),
    ('Final_Accuracy', 'final_accuracy'),
    ('Threshold_Accuracy', 'threshold_accuracy'),
    ('Token_Savings', 'token_savings'),
    ('Cohens_D', 'cohens_d'),
    ('Threshold', 'threshold'),
)


def _to_json_compatible(obj):
    """``json.dump`` fallback: NumPy scalars -> Python scalars, arrays -> lists, else ``str``."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)


def save_results_summary(results: Dict, output_path: str):
    """
    Save a summary of results to CSV and JSON.

    Writes two files:
      * ``<output_path>_summary.csv`` - one row per model with the columns
        Model, Step1_Accuracy, Final_Accuracy, Threshold_Accuracy,
        Token_Savings, Cohens_D, Threshold (missing metrics default to 0);
        the header is written even when *results* is empty.
      * ``<output_path>_summary.json`` - the full *results* dict. NumPy
        scalars and arrays are stored as native JSON numbers/lists; any other
        non-serialisable object is stored as ``str(obj)``.

    Args:
        results: Dict mapping model names to per-model metric dicts.
        output_path: Path prefix for the two output files.
    """
    columns = ['Model'] + [column for column, _ in _SUMMARY_COLUMNS]
    summary_data = [
        {'Model': model, **{column: result.get(key, 0) for column, key in _SUMMARY_COLUMNS}}
        for model, result in results.items()
    ]

    df = pd.DataFrame(summary_data, columns=columns)
    df.to_csv(output_path + '_summary.csv', index=False)

    with open(output_path + '_summary.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=_to_json_compatible)

# Example usage
if __name__ == "__main__":
    # Initialize visualization toolkit
    viz = EntropyVisualization()

    # Seeded synthetic data so the example figure is reproducible, passed as
    # plain lists to match plot_entropy_distributions' List[float] signature.
    rng = np.random.default_rng(0)
    sample_correct = rng.normal(0.3, 0.1, 100).tolist()
    sample_incorrect = rng.normal(0.5, 0.15, 150).tolist()

    # Create entropy distribution plot (also writes sample_entropy_dist.pdf/.png)
    fig = viz.plot_entropy_distributions(
        sample_correct,
        sample_incorrect,
        threshold=0.3,
        title="Sample Entropy Distribution",
        save_path="sample_entropy_dist"
    )

    plt.show()

    print("Visualization toolkit loaded successfully!")
    print("Available plotting functions:")
    for function_name in ("plot_entropy_distributions",
                          "plot_token_savings",
                          "plot_threshold_analysis",
                          "plot_cohens_d_comparison",
                          "plot_accuracy_breakdown",
                          "create_framework_overview",
                          "create_summary_dashboard"):
        print(f"- {function_name}()")