import os
import random

import torch
import torch.nn.functional as F

import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.patches as mpatches

from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix

# Global plotting defaults, applied once as a side effect of importing this
# module: a seaborn-like dark-grid Matplotlib style sheet and seaborn's "husl"
# colour palette.
# NOTE(review): the 'seaborn-v0_8-*' style names were introduced in
# Matplotlib 3.6 — confirm the minimum Matplotlib version this project targets.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")


def visualize_communication_confusion_matrix(agent_sender, agent_receiver, dataset,
                                           device, save_path, n_trials=200):
    """Plot a confusion matrix and per-class scores for the referential game.

    Runs ``n_trials`` sender/receiver rounds, records the label of the target
    and the label of the candidate the receiver selected, and saves a
    two-panel figure to ``save_path``: a row-normalised confusion matrix and a
    bar chart of per-class accuracy next to the per-class average confidence
    score (the receiver's confidence when it was right, ``1 - confidence``
    when it was wrong).

    Args:
        agent_sender: Object exposing ``send_message(target_img)``.
        agent_receiver: Object exposing ``make_decision(candidates, msg)``
            that returns ``(action, Qs, info_dict)``.
        dataset: Forwarded to ``sample_referential_trial``.
        device: Forwarded to ``sample_referential_trial``.
        save_path: Output image path. Its parent directory is created when
            the path contains one.
        n_trials: Number of trials to attempt.

    Returns:
        None. Nothing is plotted when every trial fails.
    """

    print(" Computing communication confusion matrix...")

    class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    n_classes = len(class_names)

    true_classes = []
    predicted_classes = []
    confidence_scores = []
    n_failed = 0

    with torch.no_grad():
        for trial in range(n_trials):
            try:
                target_img, candidates, target_idx, labels = sample_referential_trial(dataset, device=device)

                sender_msg = agent_sender.send_message(target_img)
                action, _, info_dict = agent_receiver.make_decision(candidates, sender_msg)

                # int()/bool() turn 0-d tensors or numpy scalars into plain
                # Python values so sklearn and numpy see homogeneous data.
                target_class = int(labels[target_idx])
                selected_class = int(labels[action])
                is_correct = bool(action == target_idx)

                # Shared helper: tries several keys, handles tensors and
                # always returns a finite value clamped to [0, 1].
                confidence = safe_confidence_extraction(info_dict)
            except Exception as e:
                # Best-effort collection: a failed trial is skipped, but the
                # first few failures are reported instead of being hidden.
                n_failed += 1
                if n_failed <= 3:
                    print(f"    Trial {trial} skipped: {e}")
                continue

            true_classes.append(target_class)
            predicted_classes.append(selected_class)
            confidence_scores.append(confidence if is_correct else 1 - confidence)

    if n_failed:
        print(f"    {n_failed} of {n_trials} trials failed and were skipped")

    if not true_classes:
        print("  ⚠️  No valid data collected for confusion matrix")
        return

    cm = confusion_matrix(true_classes, predicted_classes, labels=list(range(n_classes)))

    # Row-normalise; classes that never appeared as a target keep an all-zero
    # row instead of producing a division by zero.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = cm.astype('float') / np.where(row_sums == 0, 1, row_sums)

    true_arr = np.asarray(true_classes)
    pred_arr = np.asarray(predicted_classes)
    overall_accuracy = float(np.mean(true_arr == pred_arr)) * 100

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

    # Plot 1: Confusion Matrix
    im = ax1.imshow(cm_normalized, interpolation='nearest', cmap='Blues', vmin=0, vmax=1)
    ax1.set_title('Communication Confusion Matrix', fontsize=16, fontweight='bold')

    cbar = plt.colorbar(im, ax=ax1)
    cbar.set_label('Success Rate', rotation=270, labelpad=20)

    ax1.set_xticks(np.arange(n_classes))
    ax1.set_yticks(np.arange(n_classes))
    ax1.set_xticklabels(class_names, rotation=45, ha='right')
    ax1.set_yticklabels(class_names)
    ax1.set_xlabel('Communicated Class')
    ax1.set_ylabel('Target Class')

    # Cell annotations; white text on dark cells keeps them readable.
    for i in range(n_classes):
        for j in range(n_classes):
            value = cm_normalized[i, j]
            ax1.text(j, i, f'{value:.2f}',
                     ha='center', va='center',
                     color='white' if value > 0.5 else 'black')

    # Plot 2: Per-class accuracy with confidence
    accuracies = np.diag(cm_normalized)

    conf_arr = np.asarray(confidence_scores, dtype=float)
    class_confidence = []
    for i in range(n_classes):
        class_mask = true_arr == i
        class_confidence.append(float(conf_arr[class_mask].mean()) if class_mask.any() else 0.0)

    x = np.arange(n_classes)
    width = 0.35

    bars1 = ax2.bar(x - width/2, accuracies, width, label='Accuracy', alpha=0.8)
    bars2 = ax2.bar(x + width/2, class_confidence, width, label='Avg Confidence', alpha=0.8)

    ax2.set_xlabel('Class')
    ax2.set_ylabel('Score')
    ax2.set_title('Per-Class Communication Performance', fontsize=16, fontweight='bold')
    ax2.set_xticks(x)
    ax2.set_xticklabels(class_names, rotation=45, ha='right')
    ax2.legend()
    ax2.set_ylim(0, 1.1)

    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            if np.isfinite(height):
                ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                         f'{height:.2f}', ha='center', va='bottom', fontsize=9)

    plt.suptitle(f'Communication Analysis (Overall Accuracy: {overall_accuracy:.1f}%)',
                 fontsize=18, fontweight='bold')
    plt.tight_layout()

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"  💾 Communication confusion matrix saved to {save_path}")


def safe_confidence_extraction(info_dict, default=0.5):
    """Extract a confidence value from ``info_dict`` with fallbacks.

    The keys ``'action_confidence'``, ``'confidence'`` and
    ``'max_confidence'`` are tried in that order. The first key whose value
    converts to a finite number wins and is returned clamped to ``[0, 1]``.
    Keys holding unusable values (unsupported type, NaN/inf, failed
    conversion) are skipped so that later keys still get a chance.

    Args:
        info_dict: Mapping that may contain one of the keys above. Values may
            be tensors (the maximum is taken for multi-element tensors),
            Python numbers, or any object with an ``item()`` method.
        default: Value returned when no key yields a usable confidence.

    Returns:
        The clamped confidence as a number, or ``default`` unchanged.
    """
    for key in ('action_confidence', 'confidence', 'max_confidence'):
        if key not in info_dict:
            continue

        conf_val = info_dict[key]
        try:
            if isinstance(conf_val, torch.Tensor):
                if conf_val.numel() == 1:
                    extracted = conf_val.item()
                else:
                    extracted = conf_val.max().item()
            elif isinstance(conf_val, (int, float)):
                extracted = conf_val
            elif hasattr(conf_val, 'item'):
                extracted = conf_val.item()
            else:
                # Unsupported type: move on to the next key rather than
                # falling through with a stale/default value.
                continue

            if isinstance(extracted, (int, float)) and np.isfinite(extracted):
                return max(0.0, min(1.0, extracted))  # Clamp to [0,1]
        except (TypeError, ValueError, RuntimeError):
            # e.g. .item() on a multi-element array or .max() on an empty tensor
            continue

    return default



def visualize_decision_boundaries_tsne(agent_sender, agent_receiver, dataset,
                                      device, save_path, n_samples=500):
    """Embed time-averaged sender messages with t-SNE and plot the outcomes.

    For each of ``n_samples`` trials the sender message is averaged over its
    first axis, and the target label, the label of the receiver's choice and
    whether the choice was correct are recorded. The 2-D t-SNE embedding of
    the averaged messages is drawn four ways (true class, chosen class,
    success/failure, combined landscape) and saved to ``save_path``.

    Args:
        agent_sender: Object exposing ``send_message(target_img)``.
        agent_receiver: Object exposing ``make_decision(candidates, msg)``
            that returns ``(action, Qs, info)``.
        dataset: Forwarded to ``sample_referential_trial``.
        device: Forwarded to ``sample_referential_trial``.
        save_path: Output image path. Its parent directory is created when
            the path contains one.
        n_samples: Number of trials to attempt.

    Returns:
        None. Nothing is plotted when fewer than three trials succeed.
    """

    print("  Computing t-SNE decision boundaries...")

    spike_encodings = []
    true_labels = []
    decisions = []
    success_flags = []
    n_failed = 0

    with torch.no_grad():
        for trial in range(n_samples):
            try:
                target_img, candidates, target_idx, labels = sample_referential_trial(dataset, device=device)

                sender_msg = agent_sender.send_message(target_img)
                # Average over the first axis and flatten so each sample is a
                # 1-D feature vector, which is what TSNE expects per row.
                avg_encoding = sender_msg.mean(dim=0).cpu().numpy().reshape(-1)

                action, _, _ = agent_receiver.make_decision(candidates, sender_msg)

                # Plain Python scalars keep the numpy arrays below homogeneous
                # even if the agents return tensors.
                true_label = int(labels[target_idx])
                decision = int(labels[action])
                success = bool(action == target_idx)
            except Exception as e:
                # Best-effort collection: skip the trial but report the first
                # few failures instead of hiding them.
                n_failed += 1
                if n_failed <= 3:
                    print(f"    Trial {trial} skipped: {e}")
                continue

            spike_encodings.append(avg_encoding)
            true_labels.append(true_label)
            decisions.append(decision)
            success_flags.append(success)

    if n_failed:
        print(f"    {n_failed} of {n_samples} trials failed and were skipped")

    n_collected = len(spike_encodings)
    if n_collected < 3:
        print(f"  ⚠️  Not enough valid samples for t-SNE ({n_collected} collected)")
        return

    spike_encodings = np.array(spike_encodings)
    true_labels = np.array(true_labels)
    decisions = np.array(decisions)
    success_flags = np.array(success_flags, dtype=bool)

    print("    Computing t-SNE embedding...")
    # scikit-learn requires perplexity < n_samples, so shrink it for small runs.
    perplexity = min(30, n_collected - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    embeddings_2d = tsne.fit_transform(spike_encodings)

    fig = plt.figure(figsize=(20, 10))
    gs = GridSpec(2, 3, figure=fig)

    # Plot 1: True classes
    ax1 = fig.add_subplot(gs[0, 0])
    scatter1 = ax1.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                           c=true_labels, cmap='tab10', alpha=0.6, s=50)
    ax1.set_title('True Classes', fontsize=14, fontweight='bold')
    ax1.set_xlabel('t-SNE 1')
    ax1.set_ylabel('t-SNE 2')
    plt.colorbar(scatter1, ax=ax1, label='Class')

    # Plot 2: Communicated classes
    ax2 = fig.add_subplot(gs[0, 1])
    scatter2 = ax2.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                           c=decisions, cmap='tab10', alpha=0.6, s=50)
    ax2.set_title('Communicated Classes', fontsize=14, fontweight='bold')
    ax2.set_xlabel('t-SNE 1')
    ax2.set_ylabel('t-SNE 2')
    plt.colorbar(scatter2, ax=ax2, label='Class')

    # Plot 3: Success/Failure
    ax3 = fig.add_subplot(gs[0, 2])
    colors = ['green' if ok else 'red' for ok in success_flags]
    ax3.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                c=colors, alpha=0.6, s=50)
    ax3.set_title('Communication Success', fontsize=14, fontweight='bold')
    ax3.set_xlabel('t-SNE 1')
    ax3.set_ylabel('t-SNE 2')

    red_patch = mpatches.Patch(color='red', label='Failed')
    green_patch = mpatches.Patch(color='green', label='Success')
    ax3.legend(handles=[green_patch, red_patch])

    # Plot 4: successes coloured by class, failures overlaid as red crosses
    ax4 = fig.add_subplot(gs[1, :])

    for class_id in range(10):
        class_mask = (true_labels == class_id) & success_flags
        if np.any(class_mask):
            ax4.scatter(embeddings_2d[class_mask, 0], embeddings_2d[class_mask, 1],
                        label=f'Class {class_id}', alpha=0.7, s=100, edgecolors='black')

    fail_mask = ~success_flags
    if np.any(fail_mask):
        ax4.scatter(embeddings_2d[fail_mask, 0], embeddings_2d[fail_mask, 1],
                    c='red', marker='x', s=100, label='Misclassified', alpha=0.8)

    ax4.set_title('Decision Landscape', fontsize=16, fontweight='bold')
    ax4.set_xlabel('t-SNE 1')
    ax4.set_ylabel('t-SNE 2')
    ax4.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax4.grid(True, alpha=0.3)

    accuracy = np.mean(success_flags)
    stats_text = f"""
        Communication Statistics:
        ━━━━━━━━━━━━━━━━━━━━━
        Accuracy: {accuracy:.1%}
        Samples: {len(success_flags)}
        Failed: {np.sum(~success_flags)}
    """
    fig.text(0.02, 0.02, stats_text, fontsize=10, fontfamily='monospace',
             bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgray', alpha=0.8))

    plt.suptitle('Communication Protocol Decision Boundaries', fontsize=18, fontweight='bold')
    plt.tight_layout()

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"  💾 Decision boundaries visualization saved to {save_path}")

def safe_plot_metric(ax, epochs, metric_data, label, **kwargs):
    """Plot a metric series on ``ax`` even if its length differs from ``epochs``.

    Args:
        ax: Matplotlib-like axes; only ``plot``, ``text`` and ``transAxes``
            are used.
        epochs: Sized sequence of x values used when the lengths match.
        metric_data: The y values. May be ``None``, a scalar (treated as a
            single data point), or any sized sequence such as a list, tuple
            or numpy array.
        label: Legend label; also used in the "no data" and warning messages.
        **kwargs: Forwarded to ``ax.plot``.

    Returns:
        True if something was plotted, False if there was no data and a
        placeholder message was drawn instead.
    """
    # Normalise first: truth-testing a multi-element numpy array raises and
    # len() of a scalar raises, so neither can be used to detect "empty".
    if metric_data is None:
        values = []
    elif np.ndim(metric_data) == 0:
        values = [metric_data]
    else:
        values = metric_data

    if len(values) == 0:
        ax.text(0.5, 0.5, f'No {label} Data\nAvailable', ha='center', va='center',
                fontsize=12, transform=ax.transAxes,
                bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgray', alpha=0.8))
        return False

    if len(values) != len(epochs):
        # Length mismatch: plot what is available against 1..len(values).
        available_epochs = range(1, len(values) + 1)
        ax.plot(available_epochs, values, label=label, **kwargs)
        print(f"  ⚠️  {label}: Using {len(values)} data points instead of {len(epochs)} epochs")
    else:
        ax.plot(epochs, values, label=label, **kwargs)

    return True


def _metric_series(metrics, key):
    """Return ``metrics[key]`` when it is present and non-empty, else ``[]``."""
    series = metrics.get(key)
    if series is None or len(series) == 0:
        return []
    return series


def _draw_no_data(ax, message, title):
    """Draw a centred grey placeholder box with ``message`` and set ``title``."""
    ax.text(0.5, 0.5, message, ha='center', va='center',
            fontsize=14, transform=ax.transAxes,
            bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgray', alpha=0.8))
    ax.set_title(title, fontweight='bold')


def visualize_protocol_evolution(agent_a, agent_b, metrics, save_path):
    """Save a six-panel summary of how the protocol metrics changed per epoch.

    Panels: protocol discriminability (with linear trend), within- vs
    between-class similarity, attention consistency, discriminability vs
    accuracy scatter, loss/epsilon curves, and a colour-coded phase strip.
    Each panel falls back to a "no data" placeholder when its metric is
    missing, ``None`` or empty, and series of different lengths are truncated
    to the shorter one where two series are combined.

    Args:
        agent_a: Unused; kept for interface compatibility.
        agent_b: Unused; kept for interface compatibility.
        metrics: Mapping from metric name to a per-epoch sequence. Keys read:
            ``accuracy``, ``protocol_discriminability``,
            ``within_class_similarity``, ``between_class_similarity``,
            ``attention_consistency``, ``avg_loss``, ``epsilon``.
        save_path: Output image path. Its parent directory is created when
            the path contains one.

    Returns:
        None.
    """

    print("  Analyzing protocol evolution...")

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.flatten()

    accuracy = _metric_series(metrics, 'accuracy')
    discriminability = _metric_series(metrics, 'protocol_discriminability')
    has_accuracy = len(accuracy) > 0
    has_disc = len(discriminability) > 0

    # Accuracy defines the reference epoch axis when it is available.
    if has_accuracy:
        epochs = range(1, len(accuracy) + 1)
    else:
        print("  ⚠️  No accuracy data available for epoch reference")
        epochs = range(1, 11)  # Default fallback

    # 1. Protocol discriminability evolution
    ax = axes[0]
    if has_disc:
        if safe_plot_metric(ax, epochs, discriminability,
                            'Protocol Discriminability', linewidth=3, color='blue'):
            disc_epochs = np.arange(1, len(discriminability) + 1)
            ax.fill_between(disc_epochs, discriminability, alpha=0.3)
            ax.set_title('Protocol Discriminability', fontweight='bold')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Discriminability Score')
            ax.grid(True, alpha=0.3)

            # A linear trend is only drawn with more than three points.
            if len(discriminability) > 3:
                trend = np.poly1d(np.polyfit(disc_epochs, discriminability, 1))
                ax.plot(disc_epochs, trend(disc_epochs), 'r--', alpha=0.8, linewidth=1, label='Trend')
                ax.legend()
    else:
        _draw_no_data(ax, 'No Protocol\nDiscriminability Data', 'Protocol Discriminability')

    # 2. Within vs Between class similarity
    ax = axes[1]
    within = _metric_series(metrics, 'within_class_similarity')
    between = _metric_series(metrics, 'between_class_similarity')

    if len(within) > 0 and len(between) > 0:
        # Truncate both series to the shorter one so they share an x axis.
        min_len = min(len(within), len(between))
        sim_epochs = range(1, min_len + 1)
        within_data = within[:min_len]
        between_data = between[:min_len]

        ax.plot(sim_epochs, within_data, 'g-', linewidth=2, label='Within-class', marker='o')
        ax.plot(sim_epochs, between_data, 'r-', linewidth=2, label='Between-class', marker='s')
        ax.fill_between(sim_epochs, within_data, between_data, alpha=0.2, color='gray')
        ax.set_title('Class Separation in Protocol', fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Similarity')
        ax.legend()
        ax.grid(True, alpha=0.3)
    else:
        _draw_no_data(ax, 'No Class Similarity\nData Available', 'Class Separation in Protocol')

    # 3. Temporal attention consistency
    ax = axes[2]
    attention = _metric_series(metrics, 'attention_consistency')
    if len(attention) > 0:
        safe_plot_metric(ax, epochs, attention,
                         'Attention Consistency', linewidth=2, color='purple')
        ax.set_title('Temporal Attention Consistency', fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Consistency Score')
        ax.grid(True, alpha=0.3)
    else:
        _draw_no_data(ax, 'No Attention\nConsistency Data', 'Temporal Attention Consistency')

    # 4. Accuracy vs Protocol Quality
    ax = axes[3]
    if has_disc and has_accuracy:
        min_len = min(len(discriminability), len(accuracy))
        disc_data = discriminability[:min_len]
        acc_data = accuracy[:min_len]
        plot_epochs = np.arange(1, min_len + 1)

        scatter = ax.scatter(disc_data, acc_data, c=plot_epochs, cmap='viridis', s=100, alpha=0.7)
        plt.colorbar(scatter, ax=ax, label='Epoch')
        ax.set_xlabel('Protocol Discriminability')
        ax.set_ylabel('Accuracy (%)')
        ax.set_title('Protocol Quality vs Performance', fontweight='bold')
        ax.grid(True, alpha=0.3)
    else:
        _draw_no_data(ax, 'No Protocol vs\nAccuracy Data', 'Protocol Quality vs Performance')

    # 5. Learning dynamics
    ax = axes[4]
    avg_loss = _metric_series(metrics, 'avg_loss')
    epsilon = _metric_series(metrics, 'epsilon')
    has_loss = len(avg_loss) > 0
    has_epsilon = len(epsilon) > 0

    if has_loss or has_epsilon:
        # Loss and epsilon live on different scales, hence the twin y axis.
        ax2 = ax.twinx()

        if has_loss:
            loss_epochs = range(1, len(avg_loss) + 1)
            l1 = ax.plot(loss_epochs, avg_loss, 'b-', linewidth=2, label='Loss')
            ax.set_ylabel('Loss', color='b')

        if has_epsilon:
            eps_epochs = range(1, len(epsilon) + 1)
            l2 = ax2.plot(eps_epochs, epsilon, 'r--', linewidth=2, label='Exploration (ε)')
            ax2.set_ylabel('Epsilon', color='r')

        ax.set_xlabel('Epoch')
        ax.set_title('Learning Dynamics', fontweight='bold')

        # Lines on twin axes need a merged legend to appear in one box.
        if has_loss and has_epsilon:
            lns = l1 + l2
            labs = [line.get_label() for line in lns]
            ax.legend(lns, labs, loc='upper right')
        elif has_loss:
            ax.legend(loc='upper right')
        else:
            ax2.legend(loc='upper right')

        ax.grid(True, alpha=0.3)
    else:
        _draw_no_data(ax, 'No Learning\nDynamics Data', 'Learning Dynamics')

    # 6. Protocol emergence phases
    ax = axes[5]
    if has_disc and has_accuracy:
        min_len = min(len(discriminability), len(accuracy))
        disc = np.array(discriminability[:min_len])
        acc = np.array(accuracy[:min_len])

        # Classify each epoch by joint discriminability/accuracy thresholds.
        phase_colors = []
        for d, a in zip(disc, acc):
            if d < 0.1 and a < 40:
                phase_colors.append('red')       # Random
            elif d < 0.3 and a < 60:
                phase_colors.append('orange')    # Emerging
            elif d < 0.5 and a < 80:
                phase_colors.append('yellow')    # Developing
            else:
                phase_colors.append('green')     # Established

        # One unit-wide horizontal bar segment per epoch.
        for i, color in enumerate(phase_colors):
            ax.barh(0, 1, left=i, height=1, color=color, alpha=0.7)

        ax.set_xlim(0, len(phase_colors))
        ax.set_ylim(-0.5, 0.5)
        ax.set_xlabel('Epoch')
        ax.set_title('Protocol Development Phases', fontweight='bold')
        ax.set_yticks([])

        handles = [mpatches.Patch(color='red', label='Random'),
                   mpatches.Patch(color='orange', label='Emerging'),
                   mpatches.Patch(color='yellow', label='Developing'),
                   mpatches.Patch(color='green', label='Established')]
        ax.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=4)
    else:
        _draw_no_data(ax, 'No Phase\nAnalysis Data', 'Protocol Development Phases')

    plt.suptitle('Communication Protocol Evolution Analysis', fontsize=16, fontweight='bold')
    plt.tight_layout()

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"  💾 Protocol evolution analysis saved to {save_path}")


def _manual_spike_similarities(sender_flat, candidate_stack):
    """Compute fallback similarity scores between one message and K candidates.

    Args:
        sender_flat: 2-D tensor (the shape unpacking below enforces the rank).
        candidate_stack: 3-D tensor whose middle axis indexes the candidates.

    Returns:
        Dict keyed ``'candidate_<k>'``; each value maps ``'rate'``,
        ``'correlation'``, ``'count'`` and ``'timing'`` to a float in [0, 1].
    """
    _, _ = sender_flat.shape  # fail early on an unexpected rank
    _, n_candidates, _ = candidate_stack.shape

    # Sender-side quantities do not depend on the candidate; compute once.
    sender_rate = sender_flat.mean(0).unsqueeze(0)
    sender_vector = sender_flat.flatten().unsqueeze(0)
    sender_count = sender_flat.sum().item()
    sender_temporal = sender_flat.sum(1).unsqueeze(0)

    results = {}
    for k in range(n_candidates):
        candidate_k = candidate_stack[:, k, :]

        # Rate similarity: cosine between per-unit mean activity.
        rate_sim = F.cosine_similarity(sender_rate, candidate_k.mean(0).unsqueeze(0), dim=1).item()

        # Temporal correlation: cosine between the fully flattened trains.
        correlation = F.cosine_similarity(sender_vector,
                                          candidate_k.flatten().unsqueeze(0), dim=1).item()

        # Count similarity: relative difference in total activity; the
        # epsilon keeps the ratio defined when both trains are silent.
        cand_count = candidate_k.sum().item()
        count_sim = 1.0 - abs(sender_count - cand_count) / max(sender_count + 1e-8, cand_count + 1e-8)

        # Timing similarity: cosine between per-step population activity.
        timing_sim = F.cosine_similarity(sender_temporal,
                                         candidate_k.sum(1).unsqueeze(0), dim=1).item()

        results[f'candidate_{k}'] = {
            'rate': max(0, min(1, rate_sim)),  # Clamp to [0,1]
            'correlation': max(0, min(1, (correlation + 1) / 2)),  # Normalize to [0,1]
            'count': max(0, min(1, count_sim)),
            'timing': max(0, min(1, (timing_sim + 1) / 2))
        }
    return results


def visualize_spike_similarity_metrics(agent_sender, agent_receiver, dataset, device, save_path,
                                       n_trials=500):
    """Compare spike-similarity scores of targets against distractors.

    For each trial the sender message is compared with every encoded
    candidate using four metrics (``rate``, ``correlation``, ``count``,
    ``timing``). Scores come from ``agent_receiver.decision.get_similarity_metrics``
    when that method exists and succeeds, otherwise from a manual fallback.
    A 2x2 figure of violin plots (target vs distractor) is saved to
    ``save_path``; if no scores were collected a placeholder figure is saved.

    Args:
        agent_sender: Object exposing ``send_message(target_img)``.
        agent_receiver: Object exposing ``encode_candidates(candidates,
            use_target=False)`` and a ``decision`` attribute.
        dataset: Forwarded to ``sample_referential_trial``.
        device: Forwarded to ``sample_referential_trial``; the message is
            also moved to it.
        save_path: Output image path. Its parent directory is created when
            the path contains one.
        n_trials: Number of trials to attempt (previously fixed at 500).

    Returns:
        None.
    """

    print("  Analyzing spike similarity metrics...")

    metric_names = ['rate', 'correlation', 'count', 'timing']
    metric_labels = ['Rate Similarity', 'Temporal Correlation', 'Spike Count', 'Timing Similarity']

    similarities_by_metric = {name: {'correct': [], 'incorrect': []} for name in metric_names}

    n_collected = 0
    n_successful_trials = 0
    n_api_failures = 0

    with torch.no_grad():
        for trial in range(n_trials):
            try:
                target_img, candidates, target_idx, labels = sample_referential_trial(dataset, device=device)

                sender_msg = agent_sender.send_message(target_img)

                # NOTE(review): the original comments describe these as
                # [T, 1, n_msg] and [T, 1, K, n_msg] — confirm against the agents.
                sender_spikes = sender_msg.unsqueeze(1).to(device)
                candidate_spikes = agent_receiver.encode_candidates(candidates, use_target=False)

                # Method 1: the decision module's own metrics, if it has them.
                similarity_metrics = None
                if hasattr(agent_receiver.decision, 'get_similarity_metrics'):
                    try:
                        similarity_metrics = agent_receiver.decision.get_similarity_metrics(
                            sender_spikes.squeeze(1),
                            candidate_spikes.squeeze(1)
                        )
                    except Exception as e:
                        # Report only the first few failures; a broken method
                        # would otherwise print once per trial.
                        n_api_failures += 1
                        if n_api_failures <= 3:
                            print(f"    get_similarity_metrics failed: {e}")

                # Method 2: manual computation when method 1 gave nothing.
                if similarity_metrics is None or not similarity_metrics:
                    similarity_metrics = _manual_spike_similarities(
                        sender_spikes.squeeze(1), candidate_spikes.squeeze(1))

                for k in range(len(candidates)):
                    candidate_metrics = similarity_metrics.get(f'candidate_{k}')
                    if candidate_metrics is None:
                        continue
                    bucket = 'correct' if k == target_idx else 'incorrect'

                    for metric_name in metric_names:
                        if metric_name not in candidate_metrics:
                            continue
                        sim_value = candidate_metrics[metric_name]
                        if isinstance(sim_value, torch.Tensor):
                            sim_value = sim_value.item()

                        # Non-finite values would break the violin plots.
                        if isinstance(sim_value, (int, float)) and np.isfinite(sim_value):
                            similarities_by_metric[metric_name][bucket].append(sim_value)
                            n_collected += 1

                n_successful_trials += 1

                if trial % 100 == 0:
                    print(f"    Processed {trial} trials, successful: {n_successful_trials}, data points: {n_collected}")

            except Exception as e:
                if trial < 10:  # Only print errors for first few trials
                    print(f"    Trial {trial} error: {e}")
                continue

    print(f"    Completed: {n_successful_trials} successful trials, {n_collected} data points collected")

    has_data = any(groups['correct'] or groups['incorrect']
                   for groups in similarities_by_metric.values())

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    if not has_data:
        print("  ⚠️  No similarity data collected - creating placeholder plot")
        for ax, label in zip(axes, metric_labels):
            ax.text(0.5, 0.5, f'No {label} Data\nCollected', ha='center', va='center',
                    fontsize=16, transform=ax.transAxes,
                    bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgray', alpha=0.8))
            ax.set_title(label, fontweight='bold')
            ax.set_xticks([])
            ax.set_yticks([])
        plt.suptitle('Spike Similarity Metrics: No Data Available', fontsize=16, fontweight='bold')
    else:
        # scipy is optional; without it no p-value is reported rather than
        # displaying a made-up number.
        try:
            from scipy import stats
        except ImportError:
            stats = None

        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

        for ax, metric, label, color in zip(axes, metric_names, metric_labels, colors):
            correct_sims = similarities_by_metric[metric]['correct']
            incorrect_sims = similarities_by_metric[metric]['incorrect']

            if len(correct_sims) > 2 and len(incorrect_sims) > 2:
                try:
                    parts = ax.violinplot([correct_sims, incorrect_sims], positions=[1, 2],
                                          widths=0.7, showmeans=True, showmedians=True)

                    for pc in parts['bodies']:
                        pc.set_facecolor(color)
                        pc.set_alpha(0.7)

                    if stats is not None:
                        _, p_value = stats.ttest_ind(correct_sims, incorrect_sims)
                        p_text = f'{p_value:.3f}'
                    else:
                        p_text = 'n/a'

                    ax.set_xticks([1, 2])
                    ax.set_xticklabels(['Target', 'Distractor'])
                    ax.set_ylabel('Similarity Score')
                    ax.set_title(f'{label}\n(p={p_text}, n_correct={len(correct_sims)}, n_incorrect={len(incorrect_sims)})',
                                 fontweight='bold', fontsize=10)
                    ax.grid(True, alpha=0.3, axis='y')

                    # Dashed lines mark each group's mean.
                    ax.hlines(np.mean(correct_sims), 0.5, 1.5, colors='green',
                              linestyles='dashed', alpha=0.8, linewidth=2)
                    ax.hlines(np.mean(incorrect_sims), 1.5, 2.5, colors='red',
                              linestyles='dashed', alpha=0.8, linewidth=2)

                except Exception as e:
                    # Keep the other panels usable if one fails to render.
                    ax.text(0.5, 0.5, f'Plotting Error:\n{str(e)[:50]}...', ha='center', va='center',
                            fontsize=12, transform=ax.transAxes)
                    ax.set_title(label, fontweight='bold')
            else:
                ax.text(0.5, 0.5, f'Insufficient Data\nCorrect: {len(correct_sims)}\nIncorrect: {len(incorrect_sims)}',
                        ha='center', va='center', fontsize=12, transform=ax.transAxes,
                        bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgray', alpha=0.8))
                ax.set_title(label, fontweight='bold')
                ax.set_xticks([])
                ax.set_yticks([])

        plt.suptitle('Spike Similarity Metrics: Target vs Distractors', fontsize=16, fontweight='bold')

    plt.tight_layout()

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    if has_data:
        print(f"  💾 Spike similarity metrics analysis saved to {save_path}")


def get_similarity_metrics(self, sender_spikes: torch.Tensor, candidate_spikes: torch.Tensor) -> dict:
    """
    Compute similarity metrics between a sender spike train and each candidate.

    Args:
        self: Object providing ``extract_temporal_features(spikes)``; only the first
            element of the tuple it returns is used. This function is written like a
            method but lives at module level, so the owner is passed explicitly.
        sender_spikes: Sender spike train whose first dimension is time (it is
            reshaped to ``(T, -1)`` below).
        candidate_spikes: Candidate spike trains of shape ``[T, K, N]``.

    Returns:
        Dict mapping ``'candidate_{k}'`` to a dict with the keys ``'rate'``,
        ``'correlation'``, ``'count'`` and ``'timing'``, each clamped to [0, 1].
    """
    T, K, _ = candidate_spikes.shape
    metrics = {}

    with torch.no_grad():
        # Everything derived from the sender alone is loop-invariant, so it is
        # computed once instead of once per candidate.
        sender_features, _ = self.extract_temporal_features(sender_spikes.unsqueeze(1))
        sender_rate = sender_spikes.mean(0)
        sender_by_unit = sender_spikes.reshape(T, -1).T
        sender_count = sender_spikes.sum().item()

        for k in range(K):
            cand = candidate_spikes[:, k, :]
            cand_features, _ = self.extract_temporal_features(cand.unsqueeze(1))

            # Feature-space cosine similarity (reported as the 'timing' proxy)
            cosine_sim = F.cosine_similarity(sender_features, cand_features, dim=1).mean().item()

            # Rate-based similarity: 1 - mean absolute difference of time-averaged activity
            rate_sim = 1.0 - (sender_rate - cand.mean(0)).abs().mean().item()

            # Temporal correlation: cosine similarity along the time axis, averaged over units
            cand_by_unit = cand.reshape(T, -1).T
            correlation = F.cosine_similarity(sender_by_unit, cand_by_unit, dim=1).mean().item()

            # Count similarity: relative difference of total spike counts.
            # Two silent trains used to hit 0/0 -> NaN and only came out as 1 because
            # min(1, nan) happens to return 1; the zero denominator is now explicit.
            cand_count = cand.sum().item()
            larger = max(sender_count, cand_count)
            if larger != 0:
                count_sim = 1.0 - abs(sender_count - cand_count) / larger
            else:
                count_sim = 1.0 if sender_count == cand_count else 0.0

            metrics[f'candidate_{k}'] = {
                'rate': max(0, min(1, rate_sim)),
                'correlation': max(0, min(1, (correlation + 1) / 2)),  # Map [-1, 1] to [0, 1]
                'count': max(0, min(1, count_sim)),
                'timing': max(0, min(1, (cosine_sim + 1) / 2))  # Cosine similarity as timing proxy
            }

    return metrics


def visualize_temporal_dynamics(agent_sender, agent_receiver, dataset, device, save_path):
    """Visualize temporal dynamics of spike communication.

    For each of the 10 class ids, up to 3 images are passed through
    ``agent_sender.send_message`` and the resulting spike patterns are averaged.
    Each class gets one panel of a 2x5 grid: a heatmap of the averaged pattern
    with the per-time-step mean overlaid on a twin y-axis.

    Args:
        agent_sender: Object whose ``send_message(img)`` returns a tensor
            (annotated ``[T, n_msg]`` by the original author).
        agent_receiver: Unused here; kept so the signature matches the other
            ``visualize_*`` functions called uniformly by the epoch driver.
        dataset: Indexable dataset passed to ``get_samples_by_class``.
        device: Device the images are moved to before encoding.
        save_path: Output image path. Its parent directory is created if needed.
    """

    print("   Analyzing temporal dynamics...")

    # Get samples from different classes
    class_spike_patterns = {}
    class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    with torch.no_grad():
        for class_id in range(10):
            # Get samples of this class
            class_samples = get_samples_by_class(dataset, class_id, num_samples=5)

            if class_samples:
                spike_patterns = []
                for img in class_samples[:3]:  # Use 3 samples per class
                    img = img.to(device)
                    spikes = agent_sender.send_message(img)
                    spike_patterns.append(spikes.cpu().numpy())

                class_spike_patterns[class_id] = spike_patterns

    # Create visualization
    fig, axes = plt.subplots(2, 5, figsize=(20, 10))
    try:
        for class_id, ax in enumerate(axes.flatten()):
            if class_id not in class_spike_patterns:
                continue  # No samples for this class: leave the panel empty

            # Average spike pattern across the sampled images
            avg_pattern = np.mean(class_spike_patterns[class_id], axis=0)

            # Heatmap, transposed so time runs along x
            ax.imshow(avg_pattern.T, aspect='auto', cmap='cividis', interpolation='nearest')

            # Overlay the spike rate curve on a secondary y-axis
            ax2 = ax.twinx()
            spike_rate = avg_pattern.mean(axis=1)  # Average across neurons
            ax2.plot(spike_rate, color='cyan', linewidth=2, alpha=0.8)
            ax2.set_ylabel('Avg Rate', color='cyan')
            ax2.tick_params(axis='y', labelcolor='cyan')

            ax.set_title(f'{class_names[class_id]}', fontweight='bold')
            ax.set_xlabel('Time Step')
            ax.set_ylabel('Neuron ID')

            # Add statistics
            total_spikes = np.sum(avg_pattern)
            ax.text(0.02, 0.98, f'Spikes: {total_spikes:.0f}',
                    transform=ax.transAxes, fontsize=9,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor='yellow', alpha=0.7),
                    verticalalignment='top')

        plt.suptitle('Temporal Spike Patterns by Class', fontsize=16, fontweight='bold')
        plt.tight_layout()

        # os.path.dirname() is '' for a bare filename and os.makedirs('') raises,
        # so only create the directory when there is one.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

    print(f"  💾 Temporal dynamics visualization saved to {save_path}")


def generate_emergent_communication_visualizations(agent_a, agent_b, dataset, device, epoch, 
                                                  metrics=None, base_path="plots"):
    """Render every emergent-communication plot for one epoch.

    All images are written under ``<base_path>/epoch_<epoch>/``. Each plot is
    attempted independently: a failure is printed and the next plot still runs.
    The metrics-based plots are produced only when ``metrics`` is truthy.
    """

    print(f"\n Generating emergent communication visualizations for epoch {epoch}...")

    # One output folder per epoch
    epoch_dir = os.path.join(base_path, f"epoch_{epoch}")
    os.makedirs(epoch_dir, exist_ok=True)

    def _target(png_name):
        # Full output path for a file inside this epoch's folder
        return os.path.join(epoch_dir, png_name)

    # Agent-driven plots; all share the (sender, receiver, dataset, device, path) signature
    core_plots = (
        ("Communication Confusion Matrix", visualize_communication_confusion_matrix,
         "communication_confusion_matrix.png"),
        ("Decision Boundaries (t-SNE)", visualize_decision_boundaries_tsne,
         "decision_boundaries_tsne.png"),
        ("Spike Similarity Metrics", visualize_spike_similarity_metrics,
         "spike_similarity_metrics.png"),
        ("Temporal Dynamics", visualize_temporal_dynamics,
         "temporal_dynamics.png"),
    )

    for title, plot_fn, png_name in core_plots:
        try:
            print(f"  🔄 Generating {title}...")
            plot_fn(agent_a, agent_b, dataset, device, _target(png_name))
        except Exception as e:
            print(f"  ❌ Failed to generate {title}: {e}")

    if metrics:
        # (failure label, optional progress banner, deferred call). The calls are
        # wrapped in lambdas so name lookup and execution both happen inside the try.
        metric_plots = (
            ("protocol evolution",
             "  🔄 Generating Protocol Evolution...",
             lambda: visualize_protocol_evolution(agent_a, agent_b, metrics,
                                                  _target("protocol_evolution.png"))),
            ("shaped rewards analysis",
             None,
             lambda: visualize_shaped_rewards_impact(metrics,
                                                     _target("shaped_rewards_impact.png"))),
            ("confidence analysis",
             None,
             lambda: visualize_action_confidence_evolution(metrics,
                                                           _target("action_confidence.png"))),
        )

        for failure_label, banner, render in metric_plots:
            try:
                if banner is not None:
                    print(banner)
                render()
            except Exception as e:
                print(f"  ❌ Failed to generate {failure_label}: {e}")

    print(f"  ✅ Emergent communication visualizations complete for epoch {epoch}")


# Helper functions
def sample_referential_trial(dataset, K=3, device=None):
    """Sample a referential trial: K distinct candidates, one of which is the target.

    Args:
        dataset: Indexable collection where ``dataset[i]`` yields ``(img, label)``.
        K: Number of candidates to draw (without replacement).
        device: Device the images are moved to. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        Tuple ``(target_img, candidates, target_idx, labels)`` where ``candidates``
        is a list of K tensors, ``target_img`` is ``candidates[target_idx]`` and
        ``labels`` holds the dataset labels in candidate order.

        This is deliberately best-effort: if sampling fails for any reason, the
        error is printed and a placeholder trial of all-zero ``1x28x28`` images
        is returned, with ``target_idx`` 0 and labels ``0..K-1``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # Sample K random images
        indices = random.sample(range(len(dataset)), K)
        imgs = []
        labels = []

        for idx in indices:
            img, label = dataset[idx]
            # Ensure image is a tensor on the requested device
            if isinstance(img, torch.Tensor):
                imgs.append(img.to(device))
            else:
                imgs.append(torch.tensor(img, device=device))
            labels.append(label)

        # Choose target
        target_idx = random.randint(0, K - 1)
        return imgs[target_idx], imgs, target_idx, labels

    except Exception as e:
        print(f"Error in sample_referential_trial: {e}")
        # The fallback must honour K so callers get the number of candidates they
        # asked for; it was previously hard-coded to 3. If K itself is unusable,
        # fall back to the historical 3.
        n_fallback = K if isinstance(K, int) and K > 0 else 3
        dummy_img = torch.zeros(1, 28, 28, device=device)
        return dummy_img, [dummy_img] * n_fallback, 0, list(range(n_fallback))


def get_samples_by_class(dataset, target_class, num_samples=50):
    """Collect up to ``num_samples`` images whose label equals ``target_class``.

    Args:
        dataset: Indexable collection where ``dataset[i]`` yields ``(img, label)``.
        target_class: Label value to match with ``==``.
        num_samples: Maximum number of images to return.

    Returns:
        List of matching images in dataset order (possibly shorter than
        ``num_samples``, and empty when ``num_samples`` is not positive).
    """
    class_samples = []
    if num_samples <= 0:
        return class_samples

    for i in range(len(dataset)):
        img, label = dataset[i]
        if label == target_class:
            class_samples.append(img)
            # Stop as soon as the quota is met instead of loading the rest of
            # the dataset for nothing.
            if len(class_samples) >= num_samples:
                break
    return class_samples

def visualize_shaped_rewards_impact(metrics, save_path):
    """Visualize the impact of shaped rewards on learning.

    Draws a 2x2 figure: base vs shaped reward per epoch, the shaping bonus
    (shaped - base), reward vs accuracy scatter with linear trend lines, and
    cumulative rewards with the relative gain of shaped over base.

    Args:
        metrics: Mapping with per-epoch sequences under ``'avg_reward'``,
            ``'shaped_reward'`` and ``'accuracy'``. If ``'shaped_reward'`` is
            missing or the reward series are empty, a warning is printed and
            nothing is drawn.
        save_path: Output image path. Its parent directory is created if needed.
    """

    if 'shaped_reward' not in metrics:
        print("  ⚠️  No shaped reward data available")
        return

    base_rewards = np.array(metrics['avg_reward'])
    shaped_rewards = np.array(metrics['shaped_reward'])
    accuracy = np.array(metrics['accuracy'])

    # min()/max()/polyfit below all fail on empty input; treat that as "no data".
    if base_rewards.size == 0 or shaped_rewards.size == 0:
        print("  ⚠️  No shaped reward data available")
        return

    epochs = range(1, len(base_rewards) + 1)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        fig.suptitle('Shaped Rewards Analysis', fontsize=16, fontweight='bold')

        # 1. Reward comparison
        ax = axes[0, 0]
        ax.plot(epochs, base_rewards, label='Base Reward', linewidth=2, marker='o')
        ax.plot(epochs, shaped_rewards, label='Shaped Reward', linewidth=2, marker='s', linestyle='--')
        ax.fill_between(epochs, base_rewards, shaped_rewards, alpha=0.3, label='Shaping Bonus')
        ax.set_title('Base vs Shaped Rewards')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Reward')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 2. Shaping bonus evolution
        ax = axes[0, 1]
        shaping_bonus = shaped_rewards - base_rewards
        ax.plot(epochs, shaping_bonus, linewidth=2, color='green', marker='D')
        ax.fill_between(epochs, 0, shaping_bonus, alpha=0.3, color='green')
        ax.set_title('Reward Shaping Bonus Over Time')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Shaping Bonus')
        ax.grid(True, alpha=0.3)

        # 3. Reward vs Accuracy correlation
        ax = axes[1, 0]
        scatter1 = ax.scatter(base_rewards, accuracy, alpha=0.6, label='Base Reward', s=50)
        scatter2 = ax.scatter(shaped_rewards, accuracy, alpha=0.6, label='Shaped Reward', s=50, marker='s')

        reward_range = np.linspace(min(base_rewards.min(), shaped_rewards.min()),
                                   max(base_rewards.max(), shaped_rewards.max()), 100)

        # Trend lines. A linear fit needs at least two points with distinct x
        # values; otherwise it is ill-conditioned, so the line is skipped.
        for rewards, scatter in ((base_rewards, scatter1), (shaped_rewards, scatter2)):
            if len(rewards) >= 2 and np.std(rewards) > 0:
                trend = np.poly1d(np.polyfit(rewards, accuracy, 1))
                ax.plot(reward_range, trend(reward_range), '--', alpha=0.8,
                        color=scatter.get_facecolors()[0])

        ax.set_title('Reward-Accuracy Correlation')
        ax.set_xlabel('Reward')
        ax.set_ylabel('Accuracy (%)')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 4. Learning efficiency
        ax = axes[1, 1]
        cum_base = np.cumsum(base_rewards)
        cum_shaped = np.cumsum(shaped_rewards)

        ax.plot(epochs, cum_base, label='Cumulative Base', linewidth=2)
        ax.plot(epochs, cum_shaped, label='Cumulative Shaped', linewidth=2, linestyle='--')
        ax.set_title('Cumulative Reward Comparison')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Cumulative Reward')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Relative gain of shaped over base; undefined when the base total is zero.
        if cum_base[-1] != 0:
            efficiency_gain = (cum_shaped[-1] - cum_base[-1]) / cum_base[-1] * 100
            gain_text = f'Efficiency Gain: {efficiency_gain:.1f}%'
        else:
            gain_text = 'Efficiency Gain: n/a'
        ax.text(0.05, 0.95, gain_text,
                transform=ax.transAxes, fontsize=12,
                bbox=dict(boxstyle="round,pad=0.3", facecolor='yellow', alpha=0.7))

        plt.tight_layout()

        # os.path.dirname() is '' for a bare filename and os.makedirs('') raises.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    print(f"  💾 Shaped rewards analysis saved to {save_path}")


def visualize_action_confidence_evolution(metrics, save_path):
    """Plot a four-panel analysis of decision confidence over training.

    Panels: (1) confidence per epoch with a 5-epoch rolling mean and reference
    thresholds, (2) confidence vs accuracy scatter with a linear trend,
    (3) confidence histograms for the early / middle / late thirds of training,
    (4) a calibration bar chart with ECE and MCE.

    Args:
        metrics: Mapping with a per-epoch confidence series under
            ``'action_confidence'`` (preferred) or ``'confidence'``, and a
            per-epoch ``'accuracy'`` series. A series whose mean exceeds 1 is
            treated as a percentage and divided by 100. If the confidence series
            is missing or empty, a warning is printed and nothing is drawn.
        save_path: Output image path. Its parent directory is created if needed.
    """

    # Check for both possible keys
    confidence_key = 'action_confidence' if 'action_confidence' in metrics else 'confidence'

    if confidence_key not in metrics:
        print("  ⚠️  No confidence data available")
        return

    confidence = np.array(metrics[confidence_key])

    # mean()/min()/max() below all fail on empty input; treat that as "no data".
    if confidence.size == 0:
        print("  ⚠️  No confidence data available")
        return

    epochs = range(1, len(confidence) + 1)

    # Check if confidence is in percentage or decimal
    if confidence.mean() > 1:  # Likely in percentage
        confidence = confidence / 100.0

    # Handle accuracy - ensure it's in [0,1] range
    accuracy = np.array(metrics['accuracy'])
    if accuracy.mean() > 1:  # If in percentage
        accuracy = accuracy / 100.0

    # Calibration results are also needed for the summary printed after the figure is saved.
    bin_accuracy = []
    bin_confidence = []
    bin_counts = []
    ece = None
    mce = None

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        fig.suptitle('Decision Confidence Analysis', fontsize=16, fontweight='bold')

        # 1. Confidence evolution
        ax = axes[0, 0]
        ax.plot(epochs, confidence, linewidth=3, color='purple', marker='o', markersize=4)
        ax.fill_between(epochs, confidence, alpha=0.3, color='purple')

        # Add rolling average. 'valid' mode yields len-4 points; plotting them from
        # epoch 3 centres each 5-epoch window on its middle epoch.
        if len(confidence) > 5:
            rolling_avg = np.convolve(confidence, np.ones(5)/5, mode='valid')
            ax.plot(range(3, len(confidence)-1), rolling_avg, 'k--', linewidth=2,
                    label='5-epoch average', alpha=0.7)

        ax.set_title('Decision Confidence Over Time')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Average Confidence')
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3)

        # Add confidence bands
        for threshold, label, color in [(0.8, 'High', 'green'),
                                        (0.6, 'Medium', 'orange'),
                                        (0.4, 'Low', 'red')]:
            ax.axhline(threshold, linestyle='--', alpha=0.5, color=color, label=f'{label} ({threshold})')
        ax.legend(loc='best')

        # 2. Confidence vs Accuracy
        ax = axes[0, 1]

        scatter = ax.scatter(confidence, accuracy, c=epochs, cmap='viridis',
                             s=100, alpha=0.8, edgecolors='black', linewidth=0.5)
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Epoch')

        # Add ideal line
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect Calibration', linewidth=2)

        # Add regression line only with enough points and variation on both axes
        if len(confidence) > 3 and np.std(confidence) > 0 and np.std(accuracy) > 0:
            z = np.polyfit(confidence, accuracy, 1)
            p = np.poly1d(z)

            # Create a smooth line for the fit
            fit_x = np.linspace(confidence.min(), confidence.max(), 100)
            ax.plot(fit_x, p(fit_x), "r-", alpha=0.8, linewidth=2,
                    label=f'Trend: y={z[0]:.2f}x+{z[1]:.2f}')

        # Adjust axis limits to zoom into the data
        conf_margin = 0.1
        acc_margin = 0.1
        ax.set_xlim(max(0, confidence.min() - conf_margin),
                    min(1, confidence.max() + conf_margin))
        ax.set_ylim(max(0, accuracy.min() - acc_margin),
                    min(1, accuracy.max() + acc_margin))

        ax.set_title('Confidence vs Accuracy Correlation')
        ax.set_xlabel('Confidence')
        ax.set_ylabel('Accuracy')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 3. Confidence distribution with stats
        ax = axes[1, 0]
        # Split training into thirds; short runs reuse the full series for the
        # phases that would otherwise be empty.
        phase_size = max(1, len(confidence) // 3)
        early = confidence[:phase_size]
        middle = confidence[phase_size:2*phase_size] if len(confidence) > phase_size else confidence
        late = confidence[2*phase_size:] if len(confidence) > 2*phase_size else confidence

        # At least one bin: len // 2 is 0 for a single epoch and hist() rejects bins=0.
        hist_bins = max(1, min(30, len(confidence) // 2))

        # Determine bin range based on actual data
        all_data = np.concatenate([early, middle, late])
        bin_range = (all_data.min() - 0.01, all_data.max() + 0.01)

        ax.hist([early, middle, late], bins=hist_bins, range=bin_range,
                label=[f'Early (μ={np.mean(early):.3f})',
                       f'Middle (μ={np.mean(middle):.3f})',
                       f'Late (μ={np.mean(late):.3f})'],
                alpha=0.7, color=['red', 'orange', 'green'])

        ax.set_title('Confidence Distribution Across Training Phases')
        ax.set_xlabel('Confidence Level')
        ax.set_ylabel('Frequency')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 4. Calibration plot
        ax = axes[1, 1]

        # Use adaptive binning based on confidence range
        conf_spread = confidence.max() - confidence.min()
        if conf_spread < 0.2:  # Confidence is concentrated: bin tightly around the data
            n_bins = 5
            bin_boundaries = np.linspace(confidence.min() - 0.05,
                                         confidence.max() + 0.05, n_bins + 1)
        else:
            n_bins = 10
            bin_boundaries = np.linspace(0, 1, n_bins + 1)

        for i in range(n_bins):
            lower, upper = bin_boundaries[i], bin_boundaries[i+1]
            if i == n_bins - 1:  # The last bin is closed on the right
                in_bin = (confidence >= lower) & (confidence <= upper)
            else:
                in_bin = (confidence >= lower) & (confidence < upper)

            if np.sum(in_bin) > 0:
                bin_accuracy.append(np.mean(accuracy[in_bin]))
                bin_confidence.append(np.mean(confidence[in_bin]))
                bin_counts.append(np.sum(in_bin))

        if bin_accuracy:
            # Bar width from the closest pair of occupied bins so bars do not overlap
            if len(bin_confidence) > 1:
                bar_width = 0.8 * min(np.diff(sorted(bin_confidence)))
            else:
                bar_width = 0.05

            # Plot calibration bars
            bars = ax.bar(bin_confidence, bin_accuracy, width=bar_width, alpha=0.6,
                          edgecolor='black', linewidth=2, label='Observed')

            # Label each bar with the number of epochs in that bin
            for bar, count in zip(bars, bin_counts):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{count}', ha='center', va='bottom', fontsize=8)

            # Plot perfect calibration line
            ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, linewidth=2, label='Perfect Calibration')

            gaps = np.abs(np.array(bin_accuracy) - np.array(bin_confidence))

            # ECE: count-weighted mean |accuracy - confidence| over occupied bins
            ece = np.sum(gaps * np.array(bin_counts)) / np.sum(bin_counts)

            # MCE: worst-case |accuracy - confidence| over occupied bins
            mce = np.max(gaps)

            # Add text box with metrics
            textstr = f'ECE: {ece:.3f}\nMCE: {mce:.3f}\nBins: {len(bin_accuracy)}'
            props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
            ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
                    verticalalignment='top', bbox=props)

            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
        else:
            # If no bins have data, show a message
            ax.text(0.5, 0.5, 'Insufficient data for calibration plot',
                    ha='center', va='center', transform=ax.transAxes,
                    fontsize=12, color='red')
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)

        ax.set_title('Confidence Calibration')
        ax.set_xlabel('Mean Confidence')
        ax.set_ylabel('Actual Accuracy')
        ax.legend()
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        # os.path.dirname() is '' for a bare filename and os.makedirs('') raises.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    # Print summary statistics
    print(f"  💾 Decision confidence analysis saved to {save_path}")
    print(f"  📊 Confidence range: [{confidence.min():.3f}, {confidence.max():.3f}]")
    print(f"  📊 Accuracy range: [{accuracy.min():.3f}, {accuracy.max():.3f}]")
    if bin_accuracy:
        print(f"  📊 ECE: {ece:.3f}, MCE: {mce:.3f}")


def create_zoomed_calibration_plot(metrics, save_path):
    """Create a zoomed-in calibration plot for narrow confidence ranges.

    Epochs are binned by confidence between the observed minimum and maximum
    (padded by 10% of the range). Each occupied bin is drawn as mean accuracy
    +/- one standard deviation against mean confidence, annotated with its count.

    Args:
        metrics: Mapping with a per-epoch confidence series under
            ``'action_confidence'`` (preferred) or ``'confidence'``, and a
            per-epoch ``'accuracy'`` series. A series whose mean exceeds 1 is
            treated as a percentage and divided by 100. Returns silently when
            the confidence series is missing or empty.
        save_path: Base output path; the image is written next to it with
            ``_zoomed`` inserted before the extension. The parent directory is
            created if needed.
    """

    confidence_key = 'action_confidence' if 'action_confidence' in metrics else 'confidence'
    if confidence_key not in metrics:
        return

    confidence = np.array(metrics[confidence_key])
    accuracy = np.array(metrics['accuracy'])

    # mean()/min()/max() below all fail on empty input; nothing to plot.
    if confidence.size == 0:
        return

    # Ensure proper scaling
    if confidence.mean() > 1:
        confidence = confidence / 100.0
    if accuracy.mean() > 1:
        accuracy = accuracy / 100.0

    # Save with modified filename
    base, ext = os.path.splitext(save_path)
    zoomed_path = f"{base}_zoomed{ext}"

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # Determine the actual range of confidence values
        conf_min, conf_max = confidence.min(), confidence.max()
        conf_range = conf_max - conf_min

        # Roughly five points per bin, capped at 10. At least one bin, so runs
        # shorter than five epochs still produce a plot instead of an empty figure.
        n_bins = max(1, min(10, len(confidence) // 5))
        bin_boundaries = np.linspace(conf_min - conf_range * 0.1,
                                     conf_max + conf_range * 0.1, n_bins + 1)

        bin_accuracy = []
        bin_confidence = []
        bin_counts = []
        bin_std = []

        for i in range(n_bins):
            lower, upper = bin_boundaries[i], bin_boundaries[i+1]
            if i == n_bins - 1:  # The last bin is closed on the right
                in_bin = (confidence >= lower) & (confidence <= upper)
            else:
                in_bin = (confidence >= lower) & (confidence < upper)

            if np.sum(in_bin) > 0:
                bin_accuracy.append(np.mean(accuracy[in_bin]))
                bin_confidence.append(np.mean(confidence[in_bin]))
                bin_counts.append(np.sum(in_bin))
                bin_std.append(np.std(accuracy[in_bin]))

        if bin_accuracy:
            # Convert to arrays
            bin_accuracy = np.array(bin_accuracy)
            bin_confidence = np.array(bin_confidence)
            bin_counts = np.array(bin_counts)
            bin_std = np.array(bin_std)

            # Plot with error bars
            ax.errorbar(bin_confidence, bin_accuracy, yerr=bin_std,
                        fmt='o', markersize=10, capsize=5, capthick=2,
                        label='Observed ± 1 STD', color='blue', ecolor='lightblue')

            # Add sample sizes
            for conf, acc, count in zip(bin_confidence, bin_accuracy, bin_counts):
                ax.annotate(f'n={count}', (conf, acc),
                            xytext=(5, 5), textcoords='offset points',
                            fontsize=8, alpha=0.7)

            # Perfect calibration line
            ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, linewidth=2, label='Perfect Calibration')

            # Zoom to data range. The lower y bound is clamped with max(0, ...):
            # the previous min(0, ...) always pinned it at or below 0, so the
            # y-axis never zoomed in.
            margin = 0.05
            ax.set_xlim(conf_min - margin, conf_max + margin)
            ax.set_ylim(max(0, bin_accuracy.min() - margin),
                        min(1, bin_accuracy.max() + margin))

            # Add grid
            ax.grid(True, alpha=0.3, which='both')
            ax.minorticks_on()

            # ECE: count-weighted mean |accuracy - confidence| over occupied bins
            ece = np.sum(np.abs(bin_accuracy - bin_confidence) * bin_counts) / np.sum(bin_counts)

            ax.set_title(f'Calibration Plot (Zoomed) - ECE: {ece:.3f}', fontsize=14)
            ax.legend()

        ax.set_xlabel('Confidence', fontsize=12)
        ax.set_ylabel('Accuracy', fontsize=12)

        plt.tight_layout()

        # Create the target directory like the other plot writers in this file do;
        # dirname is '' for a bare filename, which os.makedirs rejects.
        out_dir = os.path.dirname(zoomed_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(zoomed_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    print(f"  💾 Zoomed calibration plot saved to {zoomed_path}")

def compute_qlearning_metrics(metrics, trials_per_epoch=100):
    """Compute Q-learning summary metrics for publication.

    Args:
        metrics: Mapping of per-epoch series. Recognised keys: ``'accuracy'``
            (compared against percentage thresholds such as 90), ``'mean_q'``
            and ``'max_q'`` (both must be present for the Q-value statistics),
            ``'epsilon'``, ``'protocol_discriminability'`` and
            ``'expected_calibration_error'``. Every key is optional; sections
            whose inputs are missing are skipped.
        trials_per_epoch: Number of trials assumed per epoch when converting
            "epochs to reach X%" into "samples to reach X%". Defaults to 100.

    Returns:
        Dict of summary values. Depending on the inputs it may contain
        ``convergence_epoch`` / ``convergence_rate`` (``None`` when 90% is never
        reached), ``post_convergence_std`` / ``post_convergence_mean``,
        ``learning_auc``, ``samples_to_{50,70,80,90,95}pct``,
        ``q_value_stability``, ``q_value_growth_rate``,
        ``exploration_efficiency`` (``None`` when epsilon did not change),
        ``learning_smoothness``, ``significant_drops``,
        ``final_protocol_quality`` / ``avg_protocol_quality`` and
        ``final_ece`` / ``avg_ece``.
    """

    qlearning_metrics = {}

    has_accuracy = 'accuracy' in metrics
    accuracy = np.array(metrics['accuracy']) if has_accuracy else None

    if has_accuracy:
        # 1. Convergence Rate: first epoch where accuracy reaches 90%
        convergence_threshold = 90.0
        converged_epochs = np.where(accuracy >= convergence_threshold)[0]

        if len(converged_epochs) > 0:
            convergence_epoch = converged_epochs[0] + 1  # +1 for 1-indexing
            qlearning_metrics['convergence_epoch'] = convergence_epoch
            qlearning_metrics['convergence_rate'] = convergence_threshold / convergence_epoch
        else:
            qlearning_metrics['convergence_epoch'] = None
            qlearning_metrics['convergence_rate'] = None

        # Stability after convergence (only with more than 5 epochs above threshold)
        if len(converged_epochs) > 5:
            post_convergence_acc = accuracy[converged_epochs[0]:]
            qlearning_metrics['post_convergence_std'] = np.std(post_convergence_acc)
            qlearning_metrics['post_convergence_mean'] = np.mean(post_convergence_acc)

        # 2. Sample Efficiency: area under the learning curve.
        # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; use
        # whichever this NumPy provides.
        integrate = getattr(np, 'trapezoid', None)
        if integrate is None:
            integrate = np.trapz
        if len(accuracy) > 0:
            epochs = np.arange(1, len(accuracy) + 1)
            # NOTE: the integral spans n-1 epoch intervals but is divided by n*100,
            # so a constant 100% curve scores (n-1)/n rather than 1. Kept as is so
            # previously reported values stay comparable.
            auc = integrate(accuracy, epochs) / (len(accuracy) * 100)
        else:
            auc = float('nan')  # Undefined for an empty curve (was an implicit 0/0)
        qlearning_metrics['learning_auc'] = auc

        # Samples to reach different performance levels
        for threshold in [50, 70, 80, 90, 95]:
            reached = np.where(accuracy >= threshold)[0]
            if len(reached) > 0:
                samples_needed = (reached[0] + 1) * trials_per_epoch
                qlearning_metrics[f'samples_to_{threshold}pct'] = samples_needed

    # 3. Q-value Statistics (both series must be present, though only mean_q is used)
    if 'mean_q' in metrics and 'max_q' in metrics:
        mean_q = np.array(metrics['mean_q'])

        if len(mean_q) > 10:
            # Stability: spread of the mean Q-value over the last 10 epochs
            qlearning_metrics['q_value_stability'] = np.std(mean_q[-10:])

            # Average Q-value change per epoch
            qlearning_metrics['q_value_growth_rate'] = (mean_q[-1] - mean_q[0]) / len(mean_q)

    # 4. Exploration Efficiency: accuracy gain per unit of epsilon decay
    if 'epsilon' in metrics and has_accuracy:
        epsilon = np.array(metrics['epsilon'])

        if len(epsilon) > 1 and len(accuracy) > 0:
            exploration_used = epsilon[0] - epsilon[-1]
            accuracy_gained = accuracy[-1] - accuracy[0]
            # With a constant epsilon the ratio is undefined; report None rather
            # than dividing by zero.
            if exploration_used != 0:
                qlearning_metrics['exploration_efficiency'] = accuracy_gained / exploration_used
            else:
                qlearning_metrics['exploration_efficiency'] = None

    # 5. Learning Stability
    if has_accuracy and len(accuracy) > 2:
        # First difference of the learning curve (epoch-to-epoch change)
        acc_diff = np.diff(accuracy)
        qlearning_metrics['learning_smoothness'] = -np.mean(np.abs(acc_diff))

        # Count epochs where accuracy fell by more than 5 points
        qlearning_metrics['significant_drops'] = np.sum(acc_diff < -5)

    # 6. Protocol Quality Metrics
    if 'protocol_discriminability' in metrics:
        protocol_disc = np.array(metrics['protocol_discriminability'])
        qlearning_metrics['final_protocol_quality'] = protocol_disc[-1]
        qlearning_metrics['avg_protocol_quality'] = np.mean(protocol_disc)

    # 7. Confidence Calibration
    if 'expected_calibration_error' in metrics:
        ece = np.array(metrics['expected_calibration_error'])
        qlearning_metrics['final_ece'] = ece[-1] if len(ece) > 0 else None
        qlearning_metrics['avg_ece'] = np.mean(ece) if len(ece) > 0 else None

    return qlearning_metrics


def plot_qlearning_metrics(metrics, save_path):
    """Create publication-ready Q-learning metrics visualization.

    Renders a 2x3 panel figure (learning curve, Q-value evolution, sample
    efficiency, exploration vs. performance, learning stability, and a text
    summary) and writes it to ``save_path``.

    Args:
        metrics: Mapping from metric name to a per-epoch sequence. The keys
            read directly here are ``'accuracy'``, ``'mean_q'``, ``'max_q'``
            and ``'epsilon'``; each panel is skipped when its keys are
            missing. The mapping is also forwarded unchanged to
            ``compute_qlearning_metrics``.
        save_path: Destination image path. Its parent directory is created
            when the path has one.

    Returns:
        The dictionary produced by ``compute_qlearning_metrics(metrics)``.
    """
    # Derived metrics feed three panels; compute them once up front.
    qmetrics = compute_qlearning_metrics(metrics)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    try:
        # 1. Learning Curve with Convergence
        ax = axes[0]
        if 'accuracy' in metrics:
            epochs = range(1, len(metrics['accuracy']) + 1)
            ax.plot(epochs, metrics['accuracy'], 'b-', linewidth=2, label='Accuracy')

            # Mark convergence point
            if qmetrics.get('convergence_epoch'):
                ax.axvline(qmetrics['convergence_epoch'], color='r', linestyle='--',
                           label=f'Convergence (epoch {qmetrics["convergence_epoch"]})')

            ax.set_xlabel('Epoch')
            ax.set_ylabel('Accuracy (%)')
            ax.set_title('Learning Curve & Convergence')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # 2. Q-value Evolution
        ax = axes[1]
        if 'mean_q' in metrics and 'max_q' in metrics:
            epochs = range(1, len(metrics['mean_q']) + 1)
            ax.plot(epochs, metrics['mean_q'], 'g-', linewidth=2, label='Mean Q')
            ax.plot(epochs, metrics['max_q'], 'r--', linewidth=2, label='Max Q')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Q-value')
            ax.set_title('Q-value Evolution')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # 3. Sample Efficiency
        ax = axes[2]
        thresholds = [50, 70, 80, 90, 95]
        # Keep only the thresholds for which a sample count was recorded.
        achieved_thresholds = [
            t for t in thresholds
            if qmetrics.get(f'samples_to_{t}pct') is not None
        ]
        samples_needed = [
            qmetrics[f'samples_to_{t}pct'] for t in achieved_thresholds
        ]

        if samples_needed:
            ax.bar(range(len(achieved_thresholds)), samples_needed,
                   tick_label=achieved_thresholds)
            ax.set_xlabel('Accuracy Threshold (%)')
            ax.set_ylabel('Samples Needed')
            ax.set_title('Sample Efficiency')
            ax.grid(True, alpha=0.3, axis='y')

        # 4. Exploration vs Performance
        ax = axes[3]
        if 'epsilon' in metrics and 'accuracy' in metrics:
            epsilon = np.asarray(metrics['epsilon'])
            accuracy_norm = np.asarray(metrics['accuracy']) / 100.0

            # Normalize epsilon by its initial value; fall back to the raw
            # values when the series is empty or starts at zero so the
            # division cannot raise or produce inf/nan.
            initial_eps = epsilon[0] if epsilon.size else 0
            epsilon_norm = epsilon / initial_eps if initial_eps != 0 else epsilon

            # Each series gets its own x-axis so differing lengths still plot.
            ax.plot(range(1, len(epsilon_norm) + 1), epsilon_norm,
                    'r-', linewidth=2, label='Exploration (ε)')
            ax.plot(range(1, len(accuracy_norm) + 1), accuracy_norm,
                    'b-', linewidth=2, label='Performance')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Normalized Value')
            ax.set_title('Exploration-Exploitation Trade-off')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # 5. Learning Stability
        ax = axes[4]
        if 'accuracy' in metrics:
            accuracy = np.array(metrics['accuracy'])
            if len(accuracy) > 1:
                # Rolling standard deviation over trailing windows. ``end``
                # is the 1-indexed epoch of the window's last element and
                # runs through len(accuracy) so the final epoch is included.
                window = 5
                window_ends = range(window, len(accuracy) + 1)
                rolling_std = [
                    np.std(accuracy[end - window:end]) for end in window_ends
                ]

                ax.plot(window_ends, rolling_std, 'purple', linewidth=2)
                ax.set_xlabel('Epoch')
                ax.set_ylabel('Rolling Std Dev (5 epochs)')
                ax.set_title('Learning Stability')
                ax.grid(True, alpha=0.3)

        # 6. Summary Statistics
        ax = axes[5]
        ax.axis('off')

        summary_lines = ["Q-Learning Performance Metrics", "=" * 30]
        for key, value in qmetrics.items():
            if value is None:
                continue
            # np.floating covers NumPy scalars that are not float subclasses.
            if isinstance(value, (float, np.floating)):
                summary_lines.append(f"{key}: {value:.3f}")
            else:
                summary_lines.append(f"{key}: {value}")
        summary_text = "\n".join(summary_lines) + "\n"

        ax.text(0.1, 0.9, summary_text, transform=ax.transAxes,
                fontsize=10, verticalalignment='top', fontfamily='monospace')

        fig.suptitle('Q-Learning Performance Analysis', fontsize=16, fontweight='bold')
        fig.tight_layout()

        # dirname is '' for a bare filename, and makedirs('') raises.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)

    print(f"  💾 Q-learning metrics visualization saved to {save_path}")

    return qmetrics