import os
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy import stats

# Set style for publication-quality plots.
# Both calls are module-level statements, so they run at import time and
# change the global matplotlib/seaborn defaults for figures created afterwards.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require
# Matplotlib >= 3.6 -- confirm against the pinned version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

def _rolling_sem(values, window_size, fallback_fraction):
    """
    Estimate a per-point standard error from a trailing window of a 1-D series.

    Args:
        values: 1-D float array, one value per epoch
        window_size: Maximum number of trailing points (current one included) per window
        fallback_fraction: Fraction of |value| used as the standard error when
            the window holds a single point (no spread can be estimated)

    Returns:
        1-D float array with one non-negative standard error per input point
    """
    sems = np.empty(len(values), dtype=float)
    for i, value in enumerate(values):
        window = values[max(0, i - window_size + 1):i + 1]
        if len(window) > 1:
            sems[i] = stats.sem(window)
        else:
            # abs() keeps the margin non-negative for negative-valued metrics,
            # otherwise the lower bound would end up above the upper bound.
            sems[i] = fallback_fraction * abs(value)
    return sems


def compute_confidence_intervals(data, confidence_level=0.95, method='bootstrap', n_bootstrap=1000):
    """
    Compute confidence intervals for accuracy data using different statistical methods.

    The data layout selects the strategy:

    * 1-D input (a single run): the centre line is the data itself and the
      margin comes from a trailing 5-point rolling standard error.
      'normal' scales it with a t critical value (df = len(data) - 1, or the
      normal critical value when only one point exists); every other method
      uses the normal critical value. 'bootstrap' and 'percentile' need
      several runs per epoch, so on single-run data they fall back to the
      rolling 'sem' estimate and emit a warning.
    * 2-D input of shape (n_runs, n_epochs): per-epoch intervals across runs.
      'bootstrap' resamples runs with replacement, 'percentile' takes the
      empirical percentiles of the runs, and any other method value
      (including 'normal' and 'sem') uses a t-based interval on the mean.

    Args:
        data: List of accuracy values or list of lists for multiple runs
        confidence_level: Confidence level (e.g., 0.95 for 95% CI)
        method: 'bootstrap', 'normal', 'percentile', or 'sem'
        n_bootstrap: Number of bootstrap samples

    Returns:
        dict with 'mean', 'lower', 'upper' arrays (all empty for empty input)

    Raises:
        ValueError: if data is neither 1-D nor 2-D
    """
    # np.array (not asarray) so the returned 'mean' never aliases the caller's data.
    values = np.array(data, dtype=float)

    if values.size == 0:
        return {'mean': np.array([]), 'lower': np.array([]), 'upper': np.array([])}

    if values.ndim not in (1, 2):
        raise ValueError(
            f"data must be 1-D (single run) or 2-D (n_runs, n_epochs), got {values.ndim}-D"
        )

    upper_tail = (1 + confidence_level) / 2

    if values.ndim == 1:
        # Single run - use rolling statistics
        if method == 'normal':
            std_errors = _rolling_sem(values, window_size=5, fallback_fraction=0.1)
            dof = len(values) - 1
            # t.ppf is NaN for df < 1, so a lone data point uses the normal quantile.
            if dof >= 1:
                critical = stats.t.ppf(upper_tail, df=dof)
            else:
                critical = stats.norm.ppf(upper_tail)
        else:
            if method != 'sem':
                warnings.warn(
                    f"method={method!r} needs several runs per epoch; "
                    "single-run data falls back to the rolling-window 'sem' estimate",
                    stacklevel=2,
                )
            std_errors = _rolling_sem(values, window_size=5, fallback_fraction=0.05)
            critical = stats.norm.ppf(upper_tail)

        margin = critical * std_errors
        return {
            'mean': values,
            'lower': values - margin,
            'upper': values + margin
        }

    # Multiple runs - compute empirical confidence intervals per epoch
    n_runs, n_epochs = values.shape
    means = values.mean(axis=0)

    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    if method == 'bootstrap':
        lowers = np.empty(n_epochs, dtype=float)
        uppers = np.empty(n_epochs, dtype=float)
        for epoch in range(n_epochs):
            # Draw all bootstrap resamples for this epoch in one call:
            # each row is one resample of the n_runs observations.
            resamples = np.random.choice(
                values[:, epoch], size=(n_bootstrap, n_runs), replace=True
            )
            bootstrap_means = resamples.mean(axis=1)
            lowers[epoch], uppers[epoch] = np.percentile(
                bootstrap_means, [lower_percentile, upper_percentile]
            )

    elif method == 'percentile':
        # Direct percentile method
        lowers = np.percentile(values, lower_percentile, axis=0)
        uppers = np.percentile(values, upper_percentile, axis=0)

    else:  # 'normal' (and any other value): t-interval on the mean
        sems = np.std(values, axis=0, ddof=1) / np.sqrt(n_runs)
        margin = stats.t.ppf(upper_tail, df=n_runs - 1) * sems
        lowers = means - margin
        uppers = means + margin

    return {
        'mean': np.array(means),
        'lower': np.array(lowers),
        'upper': np.array(uppers)
    }


def plot_accuracy_with_confidence(metrics, save_path, confidence_level=0.95,
                                method='bootstrap', title_suffix=""):
    """
    Plot accuracy with confidence intervals in publication style.

    Prints a message and returns without plotting when 'accuracy' is missing
    or empty. Otherwise draws the mean curve, the shaded confidence band, an
    optional convergence marker and a summary box, then writes the figure to
    save_path (creating its parent directory if it has one).

    Args:
        metrics: Dict containing 'accuracy' key with data; an optional truthy
            'convergence_epoch' (1-based) adds a convergence marker
        save_path: Path to save the plot
        confidence_level: Confidence level for intervals
        method: Method for computing CI ('bootstrap', 'normal', 'percentile', 'sem')
        title_suffix: Additional text for plot title
    """

    accuracy_data = metrics.get('accuracy')
    if accuracy_data is None or len(accuracy_data) == 0:
        print("No accuracy data found in metrics")
        return

    # Compute confidence intervals
    ci_data = compute_confidence_intervals(accuracy_data, confidence_level, method)
    mean = np.asarray(ci_data['mean'])
    lower = np.asarray(ci_data['lower'])
    upper = np.asarray(ci_data['upper'])

    # Size the x-axis from the CI arrays: for multi-run input len(accuracy_data)
    # is the number of runs, not the number of epochs.
    epochs = np.arange(1, len(mean) + 1)

    # ':g' avoids int() truncation of values such as 0.58 * 100 == 57.99999999999999
    confidence_pct = f'{confidence_level * 100:g}'

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Plot main line
        ax.plot(epochs, mean, 'b-', linewidth=3, label='Accuracy', alpha=0.8)

        # Plot confidence interval
        ax.fill_between(epochs, lower, upper,
                        alpha=0.3, color='skyblue',
                        label=f'{confidence_pct}% Confidence Interval')

        # Mark convergence if available
        conv_epoch = metrics.get('convergence_epoch')
        if conv_epoch:
            ax.axvline(conv_epoch, color='red', linestyle='--', linewidth=2,
                       label=f'Convergence (epoch {conv_epoch})')

            # Clamp to the last epoch when the recorded epoch is past the data
            conv_accuracy = mean[int(min(conv_epoch, len(mean))) - 1]
            ax.annotate(f'Convergence\n{conv_accuracy:.1f}%',
                        xy=(conv_epoch, conv_accuracy),
                        xytext=(conv_epoch + 5, conv_accuracy + 5),
                        arrowprops=dict(arrowstyle='->', color='red', alpha=0.7),
                        fontsize=10, ha='left')

        # Styling
        ax.set_xlabel('Training Epoch', fontsize=14, fontweight='bold')
        ax.set_ylabel('Classification Accuracy (%)', fontsize=14, fontweight='bold')
        ax.set_title(f'Neuromorphic Agent Learning Curve{title_suffix}',
                     fontsize=16, fontweight='bold', pad=20)

        # Set y-axis limits with some padding, clipped to the 0-100 range
        y_min = max(0, np.min(lower) - 2)
        y_max = min(100, np.max(upper) + 2)
        ax.set_ylim(y_min, y_max)

        # Professional grid
        ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
        ax.set_axisbelow(True)

        # Legend
        ax.legend(loc='lower right', fontsize=12, framealpha=0.9)

        # Add statistics box for the final epoch
        final_accuracy = mean[-1]
        final_lower = lower[-1]
        final_upper = upper[-1]
        stats_text = "\n".join([
            "Final Performance:",
            f"{final_accuracy:.1f}% ± {(final_upper - final_lower)/2:.1f}%",
            f"[{final_lower:.1f}%, {final_upper:.1f}%]",
            f"Method: {method.capitalize()}",
            f"Confidence: {confidence_pct}%",
        ])
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle="round,pad=0.5", facecolor='lightyellow', alpha=0.8))

        # Tight layout and save
        fig.tight_layout()
        # os.makedirs('') raises, so only create a directory when the path has one
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Always release the figure, even when plotting or saving fails
        plt.close(fig)

    print(f"Accuracy plot with confidence intervals saved to {save_path}")


def plot_multiple_metrics_with_confidence(metrics, save_path, metrics_to_plot=None,
                                        confidence_level=0.95, method='bootstrap'):
    """
    Plot multiple metrics with confidence intervals in subplots.

    Metrics that are missing from `metrics` or empty are skipped; if none
    remain, a message is printed and nothing is drawn. The remaining metrics
    are laid out on a two-column grid, one subplot each, and the figure is
    written to save_path (creating its parent directory if it has one).

    Args:
        metrics: Dict mapping metric names to per-epoch data
        save_path: Path to save the plot
        metrics_to_plot: Metric names to draw; defaults to 'accuracy',
            'avg_reward', 'protocol_discriminability' and 'action_confidence'
        confidence_level: Confidence level for intervals
        method: Method for computing CI ('bootstrap', 'normal', 'percentile', 'sem')
    """

    if metrics_to_plot is None:
        metrics_to_plot = ['accuracy', 'avg_reward', 'protocol_discriminability', 'action_confidence']

    # Filter available metrics
    available_metrics = [m for m in metrics_to_plot if m in metrics and len(metrics[m]) > 0]

    if not available_metrics:
        print("No valid metrics found for plotting")
        return

    n_metrics = len(available_metrics)
    cols = 2
    rows = (n_metrics + 1) // 2

    # squeeze=False always yields a 2-D array of axes, so flattening works for
    # any grid size (the single-metric case previously failed on list.flatten()).
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)
    axes = axes.flatten()

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']

    try:
        for idx, metric in enumerate(available_metrics):
            ax = axes[idx]

            # Compute confidence intervals
            ci_data = compute_confidence_intervals(metrics[metric], confidence_level, method)

            # Size the x-axis from the CI arrays: for multi-run input the outer
            # length of the data is the number of runs, not epochs.
            epochs = range(1, len(ci_data['mean']) + 1)

            # Plot
            color = colors[idx % len(colors)]
            ax.plot(epochs, ci_data['mean'], color=color, linewidth=2.5, alpha=0.9)
            ax.fill_between(epochs, ci_data['lower'], ci_data['upper'],
                            alpha=0.25, color=color)

            # Styling
            pretty_name = metric.replace('_', ' ').title()
            ax.set_xlabel('Epoch', fontsize=12)
            ax.set_ylabel(pretty_name, fontsize=12)
            ax.set_title(pretty_name, fontsize=14, fontweight='bold')
            ax.grid(True, alpha=0.3)

            # Add final value annotation
            final_val = ci_data['mean'][-1]
            final_lower = ci_data['lower'][-1]
            final_upper = ci_data['upper'][-1]

            ax.text(0.98, 0.02, f'{final_val:.2f}\n[{final_lower:.2f}, {final_upper:.2f}]',
                    transform=ax.transAxes, fontsize=9, ha='right', va='bottom',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8))

        # Remove empty subplots (odd metric counts leave one unused cell)
        for unused_ax in axes[n_metrics:]:
            fig.delaxes(unused_ax)

        fig.suptitle('Training Metrics with Confidence Intervals',
                     fontsize=16, fontweight='bold', y=0.98)
        fig.tight_layout()

        # os.makedirs('') raises, so only create a directory when the path has one
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Always release the figure, even when plotting or saving fails
        plt.close(fig)

    print(f"Multiple metrics plot saved to {save_path}")


def create_convergence_analysis_plot(metrics, save_path, convergence_threshold=90.0,
                                     confidence_level=0.95, method='bootstrap'):
    """
    Create a detailed convergence analysis plot with confidence intervals.

    The top panel shows the full learning curve with its confidence band and,
    once the mean reaches convergence_threshold, the first such epoch. The
    bottom panel shows the curve from that epoch onwards with its mean and
    +/- 1 standard deviation, or an explanatory note when the threshold was
    never reached or too few epochs follow it.

    Args:
        metrics: Dict containing 'accuracy' key with data
        save_path: Path to save the plot
        convergence_threshold: Mean accuracy at or above which training counts as converged
        confidence_level: Confidence level for intervals
        method: Method for computing CI ('bootstrap', 'normal', 'percentile', 'sem')
    """

    accuracy_data = metrics.get('accuracy')
    if accuracy_data is None or len(accuracy_data) == 0:
        print("No accuracy data found in metrics")
        return

    # Compute confidence intervals
    ci_data = compute_confidence_intervals(accuracy_data, confidence_level, method)
    mean = np.asarray(ci_data['mean'])
    lower = np.asarray(ci_data['lower'])
    upper = np.asarray(ci_data['upper'])

    # Size the x-axis from the CI arrays: for multi-run input len(accuracy_data)
    # is the number of runs, not the number of epochs.
    epochs = np.arange(1, len(mean) + 1)

    # Index of the first epoch whose mean reaches the threshold (None if never)
    above_threshold = np.flatnonzero(mean >= convergence_threshold)
    first_idx = int(above_threshold[0]) if above_threshold.size else None

    # The stability panel needs more than this many epochs from convergence onwards
    min_post_epochs = 5

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        # Top plot: Full learning curve
        ax1.plot(epochs, mean, 'b-', linewidth=3, label='Mean Accuracy')
        ax1.fill_between(epochs, lower, upper,
                         alpha=0.3, color='skyblue',
                         label=f'{confidence_level * 100:g}% CI')

        # Mark convergence
        if first_idx is not None:
            conv_epoch = first_idx + 1
            ax1.axvline(conv_epoch, color='red', linestyle='--', linewidth=2,
                        label=f'Convergence (epoch {conv_epoch})')
            ax1.axhline(convergence_threshold, color='orange', linestyle=':',
                        alpha=0.7, label=f'Threshold ({convergence_threshold}%)')

        ax1.set_xlabel('Training Epoch', fontsize=12)
        ax1.set_ylabel('Accuracy (%)', fontsize=12)
        ax1.set_title('Learning Curve with Convergence Analysis', fontsize=14, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Bottom plot: Post-convergence stability.
        # Count the epochs that follow the convergence point, not the number of
        # epochs above the threshold: the latter rejects early convergence and
        # accepts convergence on the very last epoch.
        if first_idx is not None and len(epochs) - first_idx > min_post_epochs:
            post_conv_epochs = epochs[first_idx:]
            post_conv_mean = mean[first_idx:]

            ax2.plot(post_conv_epochs, post_conv_mean, 'g-', linewidth=3, label='Post-convergence')
            ax2.fill_between(post_conv_epochs, lower[first_idx:], upper[first_idx:],
                             alpha=0.3, color='lightgreen')

            # Stability metrics
            stability_std = np.std(post_conv_mean)
            stability_mean = np.mean(post_conv_mean)

            ax2.axhline(stability_mean, color='darkgreen', linestyle='-', alpha=0.7,
                        label=f'Mean: {stability_mean:.1f}%')
            ax2.axhline(stability_mean + stability_std, color='darkgreen',
                        linestyle=':', alpha=0.5, label=f'±1σ: {stability_std:.1f}%')
            ax2.axhline(stability_mean - stability_std, color='darkgreen',
                        linestyle=':', alpha=0.5)

            ax2.set_xlabel('Training Epoch', fontsize=12)
            ax2.set_ylabel('Accuracy (%)', fontsize=12)
            ax2.set_title('Post-Convergence Stability Analysis', fontsize=14, fontweight='bold')
            ax2.legend()
            ax2.grid(True, alpha=0.3)
        else:
            ax2.text(0.5, 0.5, 'Convergence not yet achieved\nor insufficient post-convergence data',
                     ha='center', va='center', transform=ax2.transAxes, fontsize=14,
                     bbox=dict(boxstyle="round,pad=0.5", facecolor='lightyellow'))
            ax2.set_title('Post-Convergence Analysis', fontsize=14, fontweight='bold')

        fig.tight_layout()
        # os.makedirs('') raises, so only create a directory when the path has one
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Always release the figure, even when plotting or saving fails
        plt.close(fig)

    print(f"Convergence analysis plot saved to {save_path}")


def integrate_confidence_plots_into_training(metrics, epoch, base_path="plots"):
    """
    Render the confidence-interval plots for one training epoch.

    Creates `<base_path>/epoch_<epoch>/` if needed and writes four images into
    it: the accuracy curve with bootstrap CI, the accuracy curve with normal
    CI, the multi-metric overview and the convergence analysis. Each plot is
    attempted independently; a failure is printed and the remaining plots are
    still attempted.

    Args:
        metrics: Metrics dict forwarded unchanged to every plotting function
        epoch: Epoch identifier used in the output directory name
        base_path: Root directory under which the per-epoch directory is created
    """

    output_dir = os.path.join(base_path, f"epoch_{epoch}")
    os.makedirs(output_dir, exist_ok=True)

    # Each entry: (plotting callable, output file name, extra keyword arguments)
    plot_specs = (
        (plot_accuracy_with_confidence, "accuracy_with_ci.png",
         dict(method="bootstrap", title_suffix=" (Bootstrap CI)")),
        (plot_accuracy_with_confidence, "accuracy_with_ci_normal.png",
         dict(method="normal", title_suffix=" (Normal CI)")),
        (plot_multiple_metrics_with_confidence, "all_metrics_with_ci.png", dict()),
        (create_convergence_analysis_plot, "convergence_analysis.png", dict()),
    )

    for render, filename, extra_kwargs in plot_specs:
        try:
            render(metrics, os.path.join(output_dir, filename), **extra_kwargs)
        except Exception as e:
            # Best effort: report the failure and move on to the next plot
            print(f"Failed to generate {filename}: {e}")


def enhanced_plot_generation_example(metrics):
    """
    Demonstrate the three confidence-interval plotting entry points.

    Writes, in order, an accuracy plot, a multi-metric plot and a convergence
    analysis plot for `metrics` to fixed paths under "plots/".

    Args:
        metrics: Metrics dict forwarded unchanged to each plotting function
    """

    bootstrap_settings = {"method": "bootstrap"}

    # 1) Accuracy curve with a 95% bootstrap confidence band
    plot_accuracy_with_confidence(
        metrics, "plots/accuracy_bootstrap_ci.png",
        confidence_level=0.95, **bootstrap_settings
    )

    # 2) Selected metrics side by side, each with its own confidence band
    selected_metrics = ['accuracy', 'avg_reward', 'protocol_discriminability']
    plot_multiple_metrics_with_confidence(
        metrics, "plots/all_metrics_ci.png",
        metrics_to_plot=selected_metrics, **bootstrap_settings
    )

    # 3) Convergence analysis with the function's default settings
    create_convergence_analysis_plot(metrics, "plots/convergence_analysis.png")