
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Optional

class CommunicationProtocolAnalyzer(nn.Module):
    """
    Measure how class-consistent and class-discriminative a set of message
    (spike) patterns is.

    The module owns a learnable ``class_prototypes`` parameter of shape
    ``(num_classes, n_msg)``. ``analyze_protocol`` does not use it; it only
    reads the patterns and labels passed to it.
    """

    def __init__(self, n_msg: int = 128, num_steps: int = 25, num_classes: int = 10):
        """
        Args:
            n_msg: Width of each class prototype vector.
            num_steps: Stored as ``self.num_steps``; not used within this class.
            num_classes: Number of prototype rows (defaults to 10).
        """
        super().__init__()
        self.n_msg = n_msg
        self.num_steps = num_steps

        # One learnable prototype per class, drawn from N(0, 0.1^2).
        self.class_prototypes = nn.Parameter(torch.randn(num_classes, n_msg) * 0.1)

    @staticmethod
    def _cosine_matrix(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Return the ``(len(a), len(b))`` matrix of row-wise cosine similarities.

        Each row is divided by its L2 norm clamped from below by ``eps`` (the
        default ``eps`` of ``F.cosine_similarity``), so an all-zero row has
        similarity 0 with everything.
        """
        a = a / a.norm(dim=1, keepdim=True).clamp_min(eps)
        b = b / b.norm(dim=1, keepdim=True).clamp_min(eps)
        return a @ b.T

    def analyze_protocol(
        self,
        spike_patterns,
        labels,
        max_samples_per_class: Optional[int] = 5,
    ) -> Dict[str, float]:
        """
        Score how consistent and discriminative the communication protocol is.

        Samples are compared with cosine similarity:

        * ``within_class_similarity``: mean over all unordered pairs of
          samples that share a label.
        * ``between_class_similarity``: mean over all pairs of samples with
          different labels. Only the first ``max_samples_per_class`` samples of
          each label (in input order) are used, which keeps the cost bounded.
        * ``protocol_discriminability``: ``max(0.0, within - between)``.

        When no pair of a given kind exists (every label occurs once, or there
        is a single label), that similarity falls back to 0.5.

        Args:
            spike_patterns: Tensor whose first dimension indexes samples, or a
                list/tuple of equally-shaped tensors. Each sample is flattened
                to a vector. Non-floating patterns are cast to float32.
            labels: One label per sample (list, tuple, array or tensor). As
                with ``zip``, surplus entries on either side are ignored.
            max_samples_per_class: Non-negative per-label cap on samples used
                for the between-class comparison. ``None`` means no cap.

        Returns:
            Dict with the three metrics above as Python floats. An empty
            ``spike_patterns`` yields all-zero metrics.
        """
        if len(spike_patterns) == 0:
            return {
                'within_class_similarity': 0.0,
                'between_class_similarity': 0.0,
                'protocol_discriminability': 0.0
            }

        if isinstance(spike_patterns, (list, tuple)):
            spike_patterns = torch.stack(spike_patterns)
        # Analysis only: never build an autograd graph through the metrics.
        patterns = torch.as_tensor(spike_patterns).detach()
        label_list = torch.as_tensor(labels).reshape(-1).tolist()

        # Pair samples with labels like zip() does (truncate to the shorter).
        n = min(patterns.shape[0], len(label_list))
        patterns = patterns[:n]
        if n:
            patterns = patterns.reshape(n, -1)
        if not patterns.is_floating_point():
            patterns = patterns.float()
        device = patterns.device

        # Group sample indices by label, preserving first-seen label order and
        # input order within each label.
        groups: Dict[object, list] = {}
        for idx, label in enumerate(label_list[:n]):
            groups.setdefault(label, []).append(idx)

        # Within-class: every unordered pair (upper triangle, no diagonal)
        # inside each label, pooled across labels.
        within_sum, within_count = 0.0, 0
        for members in groups.values():
            if len(members) < 2:
                continue
            vecs = patterns[torch.tensor(members, device=device)]
            sims = self._cosine_matrix(vecs, vecs)
            rows, cols = torch.triu_indices(len(members), len(members), offset=1, device=device)
            pair_sims = sims[rows, cols]
            within_sum += pair_sims.double().sum().item()
            within_count += pair_sims.numel()

        # Between-class: every cross-label pair among the capped per-label
        # subsets, each unordered pair counted once.
        between_sum, between_count = 0.0, 0
        if len(groups) > 1:
            picked, owner = [], []
            for group_id, members in enumerate(groups.values()):
                kept = members if max_samples_per_class is None else members[:max_samples_per_class]
                picked.extend(kept)
                owner.extend([group_id] * len(kept))
            if picked:
                vecs = patterns[torch.tensor(picked, device=device)]
                owner_ids = torch.tensor(owner, device=device)
                sims = self._cosine_matrix(vecs, vecs)
                rows, cols = torch.triu_indices(len(picked), len(picked), offset=1, device=device)
                cross = owner_ids[rows] != owner_ids[cols]
                pair_sims = sims[rows, cols][cross]
                between_sum += pair_sims.double().sum().item()
                between_count += pair_sims.numel()

        # 0.5 is the neutral fallback used when no pair of that kind exists.
        within_sim = within_sum / within_count if within_count else 0.5
        between_sim = between_sum / between_count if between_count else 0.5

        return {
            'within_class_similarity': float(within_sim),
            'between_class_similarity': float(between_sim),
            'protocol_discriminability': max(0.0, float(within_sim - between_sim)),
        }


