import random
import os
import numpy as np
from typing import Tuple, List
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict
from utils.confidence_plots import plot_accuracy_with_confidence, compute_confidence_intervals

from modules.SpikeAgent import SpikeAgent

from utils.helpers import save_metrics, save_models
from utils.visualisers import generate_emergent_communication_visualizations, compute_qlearning_metrics, plot_qlearning_metrics


def set_seed(seed_value=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (plus all
    CUDA devices when CUDA is available). It also switches cuDNN to
    deterministic, non-benchmarking mode and records the seed in the
    ``PYTHONHASHSEED`` environment variable. Setting that variable at runtime
    does not change string hashing in the already-running interpreter.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed_value)
    print(f"Random seed set as {seed_value}")

# Per-dataset cache of {label: [dataset indices]}, keyed by id(dataset). The
# dataset object is stored next to its index so that its id cannot be recycled
# by a different object while the entry exists.
_CLASS_INDEX_CACHE = {}


def _class_to_indices(dataset):
    """Return ``{label: [indices]}`` for *dataset*, scanning it only on first use.

    Labels are converted with ``int()``. The cached map assumes the dataset's
    contents do not change after the first call for a given dataset object.
    """
    key = id(dataset)
    cached = _CLASS_INDEX_CACHE.get(key)
    if cached is not None and cached[0] is dataset:
        return cached[1]
    by_class = defaultdict(list)
    for i in range(len(dataset)):
        by_class[int(dataset[i][1])].append(i)
    by_class = dict(by_class)
    _CLASS_INDEX_CACHE[key] = (dataset, by_class)
    return by_class


def improved_curriculum_candidates(dataset, epoch, K=3, max_epochs=50):
    """
    Sample ``K`` candidate images whose class mix gets harder as training advances.

    Curriculum progress is ``min(1, epoch / (0.6 * max_epochs))``, so full
    difficulty is reached after 60% of training:

    * progress < 0.3 ("easy"): one class from each of up to ``K`` visually
      distinct class groups.
    * progress < 0.7 ("medium"): each slot is, with probability 0.5, a class
      from one of the very-similar groups, otherwise a uniformly random class 0-9.
    * otherwise ("hard"): with probability 0.7, distinct classes from a single
      very-similar group; otherwise uniformly random images.

    Every class pick draws a *random* image of that class. No dataset item is
    used twice within one call unless the dataset has fewer than ``K`` items.
    Slots that could not be filled (e.g. a class with no images) are topped up
    with random unused images.

    Args:
        dataset: indexable sequence of ``(image, label)`` pairs with
            Fashion-MNIST labels 0-9.
        epoch: current epoch; drives curriculum progress.
        K: number of candidates to return.
        max_epochs: total number of training epochs.

    Returns:
        ``(images, labels, difficulty)`` where ``difficulty`` is one of
        ``'easy'``, ``'medium'`` or ``'hard'``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    n_items = len(dataset)
    if n_items == 0:
        raise ValueError("dataset is empty")

    # Fashion-MNIST class groupings by visual similarity
    class_groups = {
        'very_similar': {
            'upper_body': [0, 2, 4, 6],  # T-shirt, Pullover, Coat, Shirt
            'footwear': [5, 7, 9],       # Sandal, Sneaker, Ankle boot
        },
        'moderately_similar': {
            'clothing': [0, 1, 2, 3, 4, 6],  # All clothing items
            'accessories': [3, 8],            # Dress, Bag
        },
        'distinct': {
            'group1': [0, 1, 2],  # T-shirt, Trouser, Pullover
            'group2': [3, 4, 5],  # Dress, Coat, Sandal
            'group3': [6, 7, 8, 9]  # Shirt, Sneaker, Bag, Ankle boot
        }
    }

    # Full difficulty by 60% of training; a non-positive max_epochs means "done".
    if max_epochs > 0:
        progress = min(1.0, epoch / (max_epochs * 0.6))
    else:
        progress = 1.0

    by_class = _class_to_indices(dataset)
    chosen = []   # dataset indices, in candidate order
    used = set()  # same indices, for O(1) duplicate checks

    def take_from_class(class_idx):
        """Append a random unused image of ``class_idx``; no-op if none is left."""
        pool = [i for i in by_class.get(class_idx, ()) if i not in used]
        if pool:
            pick = random.choice(pool)
            chosen.append(pick)
            used.add(pick)

    def take_random():
        """Append a random unused image (repeats only once every item is used)."""
        pick = random.randrange(n_items)
        if len(used) < n_items:
            # Rejection sampling is cheap because K is tiny relative to the dataset.
            while pick in used:
                pick = random.randrange(n_items)
        chosen.append(pick)
        used.add(pick)

    if progress < 0.3:
        # Stage 1: very different classes (easy)
        difficulty = 'easy'
        groups = list(class_groups['distinct'].values())
        for group in random.sample(groups, min(K, len(groups))):
            take_from_class(random.choice(group))

    elif progress < 0.7:
        # Stage 2: mix of similar and arbitrary classes (medium)
        difficulty = 'medium'
        for _ in range(K):
            if random.random() < 0.5:
                # Pick from one of the very-similar groups
                group_type = random.choice(['upper_body', 'footwear'])
                class_idx = random.choice(class_groups['very_similar'][group_type])
            else:
                # Pick any class
                class_idx = random.randint(0, 9)
            take_from_class(class_idx)

    else:
        # Stage 3: any combination, biased towards very similar classes (hard)
        difficulty = 'hard'
        if random.random() < 0.7:
            group_type = random.choice(['upper_body', 'footwear'])
            group = class_groups['very_similar'][group_type]
            for class_idx in random.sample(group, min(K, len(group))):
                take_from_class(class_idx)
        else:
            for _ in range(K):
                take_random()

    # Ensure we have exactly K candidates
    while len(chosen) < K:
        take_random()

    items = [dataset[i] for i in chosen[:K]]
    return [img for img, _ in items], [label for _, label in items], difficulty

def adaptive_temperature_annealing(agent, epoch, max_epochs, accuracy_history, confidence_history):
    """
    Adjust ``agent.decision.temperature`` in place based on calibration; return the new value.

    With more than 5 entries in both histories, the gap between the mean of the
    last 10 accuracies and the mean of the last 10 confidences picks a target:

    * gap > 0.10 (underconfident): target = max(0.5, 0.97 * T)
    * gap < -0.10 (overconfident): target = min(3.0, 1.03 * T)
    * otherwise: quadratic anneal from 1.5 towards 0.8 over ``max_epochs``.

    The new temperature is ``0.9 * T + 0.1 * target``. With shorter histories,
    an exponential schedule from 1.5 towards 0.8 is used instead.

    Args:
        agent: object whose ``decision`` attribute may carry a ``temperature`` tensor.
        epoch: current epoch.
        max_epochs: total number of epochs.
        accuracy_history: per-epoch accuracy in percent (0-100).
        confidence_history: per-epoch mean confidence, either as fractions (0-1)
            or in percent (0-100). A recent mean above 1 is treated as percent
            and rescaled, so both conventions compare correctly with accuracy.

    Returns:
        The new temperature as a float, or 1.0 (without touching the agent)
        when the decision module has no ``temperature`` attribute.
    """
    decision = agent.decision
    if not hasattr(decision, 'temperature'):
        return 1.0

    if len(accuracy_history) > 5 and len(confidence_history) > 5:
        recent_acc = float(np.mean(accuracy_history[-10:])) / 100.0
        recent_conf = float(np.mean(confidence_history[-10:]))
        if recent_conf > 1.0:
            # Confidence supplied in percent (as run_training stores it).
            recent_conf /= 100.0
        calibration_gap = recent_acc - recent_conf

        current_temp = decision.temperature.item()

        if calibration_gap > 0.10:  # Model is underconfident: sharpen gradually
            target_temp = max(0.5, current_temp * 0.97)
        elif calibration_gap < -0.10:  # Model is overconfident: soften
            target_temp = min(3.0, current_temp * 1.03)
        else:
            # Well calibrated - gentle quadratic annealing
            progress = epoch / max_epochs
            base_temp, final_temp = 1.5, 0.8
            target_temp = base_temp - (base_temp - final_temp) * progress ** 2

        # Smooth transition with heavy momentum
        new_temp = 0.9 * current_temp + 0.1 * target_temp
    else:
        # Early epochs - start with higher temperature
        progress = epoch / max_epochs
        new_temp = 1.5 * (0.8 / 1.5) ** (progress * 0.7)

    # In-place fill keeps the tensor's dtype, device and shape, and needs no
    # device lookup (which failed on modules without parameters).
    decision.temperature.data.fill_(new_temp)
    return new_temp

def run_training(
    pretrained_commsmod_path: str,
    epochs: int = 50,
    trials_per_epoch: int = 100,
    K: int = 3,
    γ: float = 0.99,
    ε_start: float = 0.3,         # Lower starting exploration
    ε_final: float = 0.02,         # Much lower final exploration
    lr: float = 1e-4,
    role_swap_freq: int = 25,
    target_update_freq: int = 10,
    soft_update_freq: int = 1,
    tau: float = 0.005,            # Larger soft update
    beta: float = 0.9,
    freeze_commsmod: bool = True,
    use_shaped_rewards: bool = True,
    use_auxiliary_loss: bool = True,
    use_curriculum: bool = True,
    device: torch.device = None,
    log_interval: int = 10,
    plot_interval: int = 5,
) -> Tuple[dict, 'SpikeAgent', 'SpikeAgent']:
    """
    Fixed training loop with improved confidence calibration.
    """
    device = device or torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}\nLoading dataset...")

    # Reproducibility
    set_seed(42)

    # Data preparation
    full = FashionMNIST("./data", train=True, download=True, transform=ToTensor())
    dataset = Subset(full, list(range(1000)))
    

    
    # Initialize agents
    agent_a = SpikeAgent(
        pretrained_commsmod_path, n_msg=128, n_actions=K,
        ε=ε_start, γ=γ, lr=lr, beta=beta,
        freeze_commsmod=freeze_commsmod,
        use_shaped_rewards=use_shaped_rewards,
        use_auxiliary_loss=use_auxiliary_loss
    )
    
    agent_b = SpikeAgent(
        pretrained_commsmod_path, n_msg=128, n_actions=K,
        ε=ε_start, γ=γ, lr=lr, beta=beta,
        freeze_commsmod=freeze_commsmod,
        use_shaped_rewards=use_shaped_rewards,
        use_auxiliary_loss=use_auxiliary_loss
    )
    
    # Move to device
    agent_a.to(device)
    agent_b.to(device)
    
    # Smoother exploration decay
    ε_decay = (ε_start - ε_final) / (epochs * trials_per_epoch * 0.8)
    ε = ε_start
    
    # Metrics storage
    metrics = defaultdict(list)
    step = 0
    
    for epoch in range(1, epochs + 1):
        epoch_metrics = defaultdict(list)
        
        # Clear histories periodically
        if epoch % 10 == 0:
            agent_a.clear_message_history()
            agent_b.clear_message_history()
        
        pbar = tqdm(range(trials_per_epoch), desc=f"Epoch {epoch}/{epochs}", ncols=60)
        
        for trial in pbar:
            # Role assignment
            sender, receiver = (agent_a, agent_b) if (trial // role_swap_freq) % 2 == 0 else (agent_b, agent_a)
            
            # Sample candidates with curriculum
            if use_curriculum:
                imgs, labels, difficulty = improved_curriculum_candidates(dataset, epoch, K, epochs)
                imgs = [img.to(device) for img in imgs]
            else:
                idxs = random.sample(range(len(dataset)), K)
                imgs = [dataset[i][0].to(device) for i in idxs]
                labels = [dataset[i][1] for i in idxs]
                difficulty = 'standard'
            
            # Select target
            target_idx = random.randrange(K)
            target_img = imgs[target_idx].to(device)
            
            # Communication
            sender_msg = sender.send_message(target_img, track_protocol=True, label=labels[target_idx])
            sender_msg = sender_msg.to(device)
            
            # Decision
            action, Qs, info = receiver.make_decision(imgs, sender_msg, candidate_labels=labels)
            
            # Compute rewards
            base_r = float(action == target_idx)
            if use_shaped_rewards:
                shaped_r = receiver.compute_protocol_aware_reward(
                    action, target_idx, info, sender_msg
                )
            else:
                shaped_r = base_r
            
            # Track metrics
            action_confidence = info.get('action_confidence', 0.5)
            
            epoch_metrics["reward"].append(base_r)
            epoch_metrics["shaped_reward"].append(shaped_r)
            epoch_metrics["accuracy"].append(base_r)
            epoch_metrics["q_values"].extend(Qs.detach().cpu().tolist())
            epoch_metrics["confidence"].append(info.get('max_confidence', 0.5))
            epoch_metrics["action_confidence"].append(action_confidence)
            epoch_metrics["q_gap"].append(info.get('q_gap', 0.0))
                # Initialize tracking if not exists
            if 'q_value_variance' not in metrics:
                metrics['q_value_variance'] = []
            if 'action_entropy' not in metrics:
                metrics['action_entropy'] = []
            if 'td_error' not in metrics:
                metrics['td_error'] = []
            
            
            # Q-value variance (from epoch_metrics['q_values'])
            if 'q_values' in epoch_metrics:
                q_vals = np.array(epoch_metrics['q_values'])
                metrics['q_value_variance'].append(np.var(q_vals))
            
            # Action entropy (from confidence scores)
            if 'confidence_entropy' in epoch_metrics:
                avg_entropy = np.mean(epoch_metrics['confidence_entropy'])
                metrics['action_entropy'].append(avg_entropy)
            
            # TD errors (if available)
            if 'td_errors' in epoch_metrics:
                avg_td_error = np.mean(np.abs(epoch_metrics['td_errors']))
                metrics['td_error'].append(avg_td_error)
                
            # Build transition
            transition = (
                target_img.to(device),
                [img.to(device) for img in imgs],
                action, shaped_r,
                None, None, True, info, labels
            )
            
            # Update agents
            loss_a = sender.update(transition)
            loss_b = receiver.update(transition)
            epoch_metrics["loss"].append((loss_a + loss_b) / 2)
            
            # Update exploration rate
            ε = max(ε_final, ε - ε_decay)
            sender.update_exploration_rate(ε)
            receiver.update_exploration_rate(ε)
            
            # Soft updates
            if step % soft_update_freq == 0:
                sender.soft_update_target(tau)
                receiver.soft_update_target(tau)
            
            # Target sync
            if step % target_update_freq == 0:
                sender.sync_target()
                receiver.sync_target()
            
            step += 1
            
            # Update progress bar
            pbar.set_postfix({
                "acc": f"{np.mean(epoch_metrics['accuracy']):.1%}",
                "conf": f"{np.mean(epoch_metrics['confidence']):.1%}",
                "ε": f"{ε:.3f}"
            }, refresh=False)
            
            # Log periodically
            if (trial + 1) % log_interval == 0:
                    avg_r  = np.mean(epoch_metrics["reward"])
                    avg_sh = np.mean(epoch_metrics["shaped_reward"])
                    avg_l  = np.mean(epoch_metrics["loss"])
                    q_mean = np.mean(epoch_metrics["q_values"])
                    q_max  = np.max(epoch_metrics["q_values"])
                    acc    = np.mean(epoch_metrics["accuracy"])*100
                    conf   = np.mean(epoch_metrics["confidence"])
                    print(f"[Epoch: {epoch:02d} | Trial: {trial+1:03d}] | Ave Reward: {avg_r:.3f} | Shaped Reward: {avg_sh:.3f}  Loss: {avg_l:.4f}  "
                          f"ε:{ε:.3f} | Qmean: {q_mean:.3f} | Qmax: {q_max:.3f} | Accuracy: {acc:.1f}% | Confidence:{conf*100:.1f}%")

        
        # End of epoch processing
        # Calculate epoch statistics
        epoch_accuracy = np.mean(epoch_metrics["accuracy"]) * 100
        epoch_confidence = np.mean(epoch_metrics["action_confidence"]) * 100
        
        # Update history
        agent_a.accuracy_history.append(epoch_accuracy)
        agent_b.accuracy_history.append(epoch_accuracy)
        
        # Store metrics
        metrics["accuracy"].append(epoch_accuracy)
        metrics["avg_reward"].append(np.mean(epoch_metrics["reward"]))
        metrics["shaped_reward"].append(np.mean(epoch_metrics["shaped_reward"]))
        metrics["avg_loss"].append(np.mean(epoch_metrics["loss"]))
        metrics["mean_q"].append(np.mean(epoch_metrics["q_values"]))
        metrics["max_q"].append(np.max(epoch_metrics["q_values"]))
        metrics["action_confidence"].append(epoch_confidence)
        metrics["q_gap"].append(np.mean(epoch_metrics["q_gap"]))
        metrics["epsilon"].append(ε)
        metrics["learning_rate"].append(agent_a.optimizer.param_groups[0]['lr'])
        
        if epoch % 5 == 0:  # Compute every 5 epochs
            qmetrics = compute_qlearning_metrics(metrics)
            print(f"\n📊 Q-Learning Metrics (Epoch {epoch}):")
            print(f"  Convergence: {qmetrics.get('convergence_epoch', 'Not yet')}")
            print(f"  Sample Efficiency (AUC): {qmetrics.get('learning_auc', 0):.3f}")
            print(f"  Q-value Stability: {qmetrics.get('q_value_stability', 0):.4f}")
            
            # Generate Q-learning plots
            plot_qlearning_metrics(metrics, f'plots/epoch_{epoch}/qlearning_metrics_epoch_{epoch}.png')
        # Protocol analysis
        protocol_metrics = agent_a.analyze_protocol_development()
        for key, value in protocol_metrics.items():
            if key not in metrics:
                metrics[key] = []
            metrics[key].append(value)
        
        # Adaptive temperature annealing
        if epoch >= 3:  # Start after initial learning
            confidence_history = metrics["action_confidence"]
            accuracy_history = metrics["accuracy"]
            
            new_temp_a = adaptive_temperature_annealing(
                agent_a, epoch, epochs, accuracy_history, confidence_history
            )
            new_temp_b = adaptive_temperature_annealing(
                agent_b, epoch, epochs, accuracy_history, confidence_history
            )
            
            avg_temp = (new_temp_a + new_temp_b) / 2
            
            if "temperature" not in metrics:
                metrics["temperature"] = []
            metrics["temperature"].append(avg_temp)
            
            # Learning rate scheduling based on accuracy
            agent_a.update_learning_rate(epoch_accuracy)
            agent_b.update_learning_rate(epoch_accuracy)
    
                
        # Epoch summary
        calibration_gap = epoch_accuracy - epoch_confidence
        print(f"\n{'='*60}")
        print(f"EPOCH {epoch} SUMMARY:")
        print(f"  Accuracy: {epoch_accuracy:.1f}% | Confidence: {epoch_confidence:.1f}%")
        print(f"  Calibration Gap: {calibration_gap:+.1f}%")
        print(f"  Protocol Discriminability: {protocol_metrics.get('protocol_discriminability', 0):.3f}")
        print(f"  Temperature: {agent_a.decision.temperature.item():.3f}")
        print(f"  Exploration Rate: {ε:.3f}")
        
        if epoch_accuracy > 80 and abs(calibration_gap) < 5:
            print(f"  🎉 EXCELLENT PERFORMANCE AND CALIBRATION!")
        elif epoch_accuracy > 75:
            print(f"  ✅ Good performance, working on calibration...")
        
        # Save periodically
        if epoch % plot_interval == 0:            
            save_metrics(metrics, epoch)
            generate_emergent_communication_visualizations(
                agent_a, agent_b, dataset, device, epoch, metrics
            )
            plot_emergent_communication_analysis(metrics, epoch, save_path="analysis")  
            plot_accuracy_with_confidence(metrics, f"analysis/accuracy_with_ci_epoch_{epoch}.png", method="bootstrap")
            print(f"\nModels and visualizations saved for epoch {epoch}")
            save_models(agent_a, agent_b, epoch)
            
        if epoch <= 3:  # Only debug first few epochs
            q_values = epoch_metrics["q_values"]
            q_std = np.std(q_values)
            q_range = np.max(q_values) - np.min(q_values)
            print(f"  Q-value stats: std={q_std:.4f}, range={q_range:.4f}")
            
            # Check if Q-values are too similar
            if q_std < 0.01:
                print("  ⚠️ WARNING: Q-values too similar - may cause confidence issues")
    
    final_qmetrics = compute_qlearning_metrics(metrics)
    print("\n🏁 FINAL Q-LEARNING METRICS:")
    for key, value in final_qmetrics.items():
        if value is not None:
            print(f"  {key}: {value:.3f}" if isinstance(value, float) else f"  {key}: {value}")   
                      
    return metrics, agent_a, agent_b


def plot_emergent_communication_analysis(metrics: dict, epoch: int, save_path: str = "analysis"):
    """Generate analysis plots specific to emergent communication."""
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.flatten()

    # Get base epochs from most reliable metric
    base_epochs = None
    for key in ['accuracy', 'avg_reward', 'avg_loss']:
        if key in metrics and len(metrics[key]) > 0:
            base_epochs = range(1, len(metrics[key]) + 1)
            break
    
    if base_epochs is None:
        print("  ⚠️  No reliable metrics found for epoch reference")
        base_epochs = range(1, epoch + 1)

    # 1. Protocol Development
    ax = axes[0]
    has_disc = 'protocol_discriminability' in metrics and len(metrics['protocol_discriminability']) > 0
    has_within = 'within_class_similarity' in metrics and len(metrics['within_class_similarity']) > 0
    has_between = 'between_class_similarity' in metrics and len(metrics['between_class_similarity']) > 0
    
    if has_disc or has_within or has_between:
        if has_disc:
            disc_epochs = range(1, len(metrics['protocol_discriminability']) + 1)
            ax.plot(disc_epochs, metrics['protocol_discriminability'], 'b-', 
                   linewidth=2, label='Discriminability')
        
        if has_within:
            within_epochs = range(1, len(metrics['within_class_similarity']) + 1)
            ax.plot(within_epochs, metrics['within_class_similarity'], 'g--', 
                   linewidth=2, label='Within-class')
        
        if has_between:
            between_epochs = range(1, len(metrics['between_class_similarity']) + 1)
            ax.plot(between_epochs, metrics['between_class_similarity'], 'r:', 
                   linewidth=2, label='Between-class')
        
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Similarity')
        ax.set_title('Communication Protocol Development')
        ax.legend()
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No Protocol Development\nData Available', ha='center', va='center',
                fontsize=14, transform=ax.transAxes)
        ax.set_title('Communication Protocol Development')
        ax.axis('off')

    # 2. Attention Consistency
    ax = axes[1]
    if 'attention_consistency' in metrics and len(metrics['attention_consistency']) > 0:
        att_epochs = range(1, len(metrics['attention_consistency']) + 1)
        ax.plot(att_epochs, metrics['attention_consistency'], linewidth=2, color='purple')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Attention Consistency')
        ax.set_title('Temporal Attention Pattern Stability')
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No Attention Data\nAvailable', ha='center', va='center',
                fontsize=14, transform=ax.transAxes)
        ax.set_title('Temporal Attention Pattern Stability')
        ax.axis('off')

    # 3. Confidence vs Accuracy
    ax = axes[2]
    has_conf = 'action_confidence' in metrics and len(metrics['action_confidence']) > 0
    has_acc = 'accuracy' in metrics and len(metrics['accuracy']) > 0
    
    if has_conf and has_acc:
        # Use the shorter length
        min_len = min(len(metrics['action_confidence']), len(metrics['accuracy']))
        conf_epochs = range(1, min_len + 1)
        
        confidence_scaled = [c for c in metrics['action_confidence'][:min_len]]
        accuracy_data = metrics['accuracy'][:min_len]
        
        ax.plot(conf_epochs, accuracy_data, 'b-', linewidth=2, label='Accuracy')
        ax.plot(conf_epochs, confidence_scaled, 'r--', linewidth=2, label='Confidence')
        
        # Color coding for calibration
        ax.fill_between(conf_epochs, accuracy_data, confidence_scaled,
                       where=[a > c for a, c in zip(accuracy_data, confidence_scaled)],
                       color='green', alpha=0.3, label='Well-calibrated')
        ax.fill_between(conf_epochs, accuracy_data, confidence_scaled,
                       where=[a <= c for a, c in zip(accuracy_data, confidence_scaled)],
                       color='red', alpha=0.3, label='Overconfident')
        
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Percentage')
        ax.set_title('Confidence Calibration')
        ax.legend()
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No Confidence/Accuracy\nData Available', ha='center', va='center',
                fontsize=14, transform=ax.transAxes)
        ax.set_title('Confidence Calibration')
        ax.axis('off')

    # 4. Curriculum Progress placeholder
    ax = axes[3]
    ax.text(0.5, 0.5, 'Standard Training\nMode Active', ha='center', va='center',
            fontsize=14, transform=ax.transAxes,
            bbox=dict(boxstyle="round,pad=0.5", facecolor='lightblue', alpha=0.8))
    ax.set_title('Training Mode')
    ax.axis('off')

    # 5. Q-value Distribution Evolution
    ax = axes[4]
    has_q_mean = 'mean_q' in metrics and len(metrics['mean_q']) > 0
    has_q_max = 'max_q' in metrics and len(metrics['max_q']) > 0
    
    if has_q_mean or has_q_max:
        if has_q_mean:
            q_mean_epochs = range(1, len(metrics['mean_q']) + 1)
            ax.plot(q_mean_epochs, metrics['mean_q'], 'b-', linewidth=2, label='Mean Q')
        
        if has_q_max:
            q_max_epochs = range(1, len(metrics['max_q']) + 1)
            ax.plot(q_max_epochs, metrics['max_q'], 'r--', linewidth=2, label='Max Q')
        
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Q-value')
        ax.set_title('Q-value Evolution')
        ax.legend()
        ax.grid(True, alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'No Q-value Data\nAvailable', ha='center', va='center',
                fontsize=14, transform=ax.transAxes)
        ax.set_title('Q-value Evolution')
        ax.axis('off')

    # 6. Summary Statistics
    ax = axes[5]
    ax.axis('off')
    
    # Safely get final values with fallbacks
    final_accuracy = metrics.get('accuracy', [0])[-1] if metrics.get('accuracy') else 0
    final_discriminability = (metrics.get('protocol_discriminability', [0])[-1] 
                             if metrics.get('protocol_discriminability') else 0)
    final_confidence = (metrics.get('action_confidence', [0])[-1] 
                       if metrics.get('action_confidence') else 0)
    final_within_sim = (metrics.get('within_class_similarity', [0])[-1] 
                       if metrics.get('within_class_similarity') else 0)
    final_between_sim = (metrics.get('between_class_similarity', [0])[-1] 
                        if metrics.get('between_class_similarity') else 0)
    final_epsilon = metrics.get('epsilon', [0])[-1] if metrics.get('epsilon') else 0
    final_learning_rate = (metrics.get('learning_rate', [0])[-1] 
                          if metrics.get('learning_rate') else "N/A")
    
    # Format learning rate
    if isinstance(final_learning_rate, (int, float)):
        lr_str = f"{final_learning_rate:.6f}"
    else:
        lr_str = str(final_learning_rate)
    
    summary_text = f"""
    Emergent Communication Summary (Epoch {epoch})
    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    Final Accuracy: {final_accuracy:.1f}%
    Protocol Discriminability: {final_discriminability:.3f}
    Confidence Calibration: {abs(final_accuracy - final_confidence*100):.1f}% gap
    Protocol Quality:
    - Within-class similarity: {final_within_sim:.3f}
    - Between-class similarity: {final_between_sim:.3f}
    - Separation: {final_within_sim - final_between_sim:.3f}
    Learning Dynamics:
    - Current ε: {final_epsilon:.3f}
    - Learning rate: {lr_str}
    """
    
    ax.text(0.1, 0.9, summary_text, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgray', alpha=0.8))

    plt.suptitle(f'Emergent Communication Analysis - Epoch {epoch}', fontsize=16)
    plt.tight_layout()
    
    os.makedirs(save_path, exist_ok=True)
    plt.savefig(f'{save_path}/emergent_communication_analysis_epoch_{epoch}.png',
                dpi=150, bbox_inches='tight')
    plt.close()
