import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Tuple, Optional, Any
import time
from collections import defaultdict
import math

class ClassificationMetrics:
    """
    Standard classification metrics for neural network evaluation.

    Batches are accumulated into a confusion matrix whose rows are indexed by
    the ground-truth class and whose columns are indexed by the predicted
    class. Accuracy, precision, recall and F1-score are derived from it.

    :param num_classes: Number of classes
    :param average: 'macro' or 'weighted'; any other value selects micro
        averaging
    """

    def __init__(self, num_classes: int, average: str = 'macro'):
        self.num_classes = num_classes
        self.average = average
        self.reset()

    def reset(self):
        """Reset all accumulated statistics."""
        self.correct = 0
        self.total = 0
        self.class_correct = torch.zeros(self.num_classes)
        self.class_total = torch.zeros(self.num_classes)
        self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes)

    def update(self, predictions: torch.Tensor, targets: torch.Tensor):
        """
        Update metrics with new batch of predictions and targets.

        :param predictions: Model predictions [B, num_classes] or [B]
        :param targets: Ground truth labels [B]
        """
        if predictions.dim() > 1:
            predicted_classes = predictions.argmax(dim=1)
        else:
            predicted_classes = predictions

        # The accumulators are created on the CPU in reset(), so bring the
        # batch there once. Flattening keeps the element-wise comparison from
        # broadcasting when targets arrive as [B, 1].
        predicted_classes = predicted_classes.detach().reshape(-1).cpu()
        targets = targets.detach().reshape(-1).cpu()

        # Overall accuracy
        correct_mask = predicted_classes == targets
        self.correct += int(correct_mask.sum().item())
        self.total += targets.numel()

        target_idx = targets.long()
        predicted_idx = predicted_classes.long()

        # Per-class statistics: labels outside [0, num_classes) are not
        # counted here, exactly as a `targets == i` scan would ignore them.
        in_range = (target_idx >= 0) & (target_idx < self.num_classes)
        self.class_total += torch.bincount(
            target_idx[in_range], minlength=self.num_classes
        ).to(self.class_total.dtype)
        self.class_correct += torch.bincount(
            target_idx[in_range & correct_mask], minlength=self.num_classes
        ).to(self.class_correct.dtype)

        # Confusion matrix: one accumulate-scatter instead of a Python loop
        # over samples. accumulate=True makes repeated (target, prediction)
        # pairs add up.
        increments = torch.ones(target_idx.numel(), dtype=self.confusion_matrix.dtype)
        self.confusion_matrix.index_put_(
            (target_idx, predicted_idx), increments, accumulate=True
        )

    def compute(self) -> Dict[str, float]:
        """
        Compute final metrics from accumulated statistics.

        :return: Dictionary containing all computed metrics
        """
        accuracy = self.correct / max(self.total, 1)

        # Per-class precision, recall, F1
        tp = torch.diag(self.confusion_matrix)
        predicted_per_class = self.confusion_matrix.sum(dim=0)  # tp + fp
        actual_per_class = self.confusion_matrix.sum(dim=1)     # tp + fn

        # Clamping the denominators yields 0 for classes never seen/predicted.
        precision = tp / predicted_per_class.clamp(min=1e-8)
        recall = tp / actual_per_class.clamp(min=1e-8)
        f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-8)

        # Averaging
        if self.average == 'macro':
            avg_precision = precision.mean().item()
            avg_recall = recall.mean().item()
            avg_f1 = f1.mean().item()
        elif self.average == 'weighted':
            # Before any update the support is all zeros; clamp so the
            # weighted scores come out as 0.0 instead of NaN.
            weights = self.class_total / self.class_total.sum().clamp(min=1)
            avg_precision = (precision * weights).sum().item()
            avg_recall = (recall * weights).sum().item()
            avg_f1 = (f1 * weights).sum().item()
        else:  # micro
            tp_total = tp.sum().item()
            fp_total = self.confusion_matrix.sum().item() - tp_total
            fn_total = fp_total  # Every misclassification is one FP and one FN

            avg_precision = tp_total / max(tp_total + fp_total, 1e-8)
            avg_recall = tp_total / max(tp_total + fn_total, 1e-8)
            avg_f1 = 2 * avg_precision * avg_recall / max(avg_precision + avg_recall, 1e-8)

        return {
            'accuracy': accuracy,
            'precision': avg_precision,
            'recall': avg_recall,
            'f1_score': avg_f1,
            'per_class_precision': precision.tolist(),
            'per_class_recall': recall.tolist(),
            'per_class_f1': f1.tolist()
        }

class SpikeOperationsMetrics:
    """
    Accumulator for spike counts and synaptic operations (SOPs).

    Each call to update_layer adds a layer's summed spike tensor and that sum
    multiplied by the layer fanout to both per-layer and global totals.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every counter and the per-timestep spike history."""
        self.total_spikes = 0
        self.total_synapses = 0
        self.layer_spikes = defaultdict(int)
        self.layer_synapses = defaultdict(int)
        self.timestep_spikes = []

    def update_layer(self, layer_name: str, spike_tensor: torch.Tensor, fanout: int):
        """
        Add one layer's spike activity to the running totals.

        :param layer_name: Key under which the per-layer totals are stored
        :param spike_tensor: Spike tensor [T, B, ...] or [B, ...]; any tensor
            with three or more dimensions is treated as having a leading
            time axis
        :param fanout: Multiplier turning a spike count into synaptic operations
        """
        rank = spike_tensor.dim()
        if rank >= 3:
            # Collapse everything except the leading axis to get one spike
            # count per index along it.
            trailing_axes = tuple(range(1, rank))
            per_step_counts = spike_tensor.sum(dim=trailing_axes)
            self.timestep_spikes.extend(per_step_counts.tolist())

        spike_count = spike_tensor.sum().item()
        sop_count = spike_count * fanout

        self.layer_spikes[layer_name] += spike_count
        self.layer_synapses[layer_name] += sop_count
        self.total_spikes += spike_count
        self.total_synapses += sop_count

    def compute(self) -> Dict[str, Any]:
        """
        Summarise the accumulated spike and SOP counters.

        :return: Dictionary with global totals, the spikes-to-SOPs ratio,
            per-timestep mean/variance and the per-layer breakdown
        """
        history = self.timestep_spikes

        if self.total_synapses > 0:
            efficiency = self.total_spikes / max(self.total_synapses, 1)
        else:
            efficiency = 0

        mean_per_step = np.mean(history) if history else 0
        variance_over_time = np.var(history) if len(history) > 1 else 0

        per_layer = {
            'spikes_per_layer': dict(self.layer_spikes),
            'sops_per_layer': dict(self.layer_synapses)
        }

        return {
            'total_spikes': self.total_spikes,
            'total_synaptic_operations': self.total_synapses,
            'spike_efficiency': efficiency,
            'average_spikes_per_timestep': mean_per_step,
            'spike_variance_across_time': variance_over_time,
            'layer_statistics': per_layer
        }

class TemporalVarianceMetrics:
    """
    Running statistics over spike-timing tensors.

    Reads the optional `timing_map`, `interval_map` and `burst_map`
    attributes of the objects passed to update and keeps one scalar per
    call for each attribute that is present.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all collected per-call scalars."""
        self.timing_variances = []
        self.interval_variances = []
        self.burst_patterns = []
        self.first_spike_times = []

    def update(self, spike_info, timesteps: int):
        """
        Record temporal statistics from one spike-information object.

        :param spike_info: Object that may expose `timing_map`,
            `interval_map` and `burst_map` tensors; missing or None
            attributes are skipped
        :param timesteps: Number of timesteps in sequence (not used by the
            current implementation)
        """
        timing_map = getattr(spike_info, 'timing_map', None)
        if timing_map is not None:
            # Variance along the last axis, averaged over everything else.
            self.timing_variances.append(timing_map.var(dim=-1).mean().item())
            # Overall mean of the timing map.
            self.first_spike_times.append(timing_map.mean().item())

        interval_map = getattr(spike_info, 'interval_map', None)
        if interval_map is not None:
            self.interval_variances.append(interval_map.var(dim=-1).mean().item())

        burst_map = getattr(spike_info, 'burst_map', None)
        if burst_map is not None:
            self.burst_patterns.append(burst_map.mean().item())

    def compute(self) -> Dict[str, float]:
        """
        Aggregate the collected scalars.

        :return: Dictionary of means/standard deviations; entries without
            enough data are reported as 0.0
        """
        timing = self.timing_variances
        intervals = self.interval_variances
        bursts = self.burst_patterns
        first_spikes = self.first_spike_times

        # The diversity score needs both timing and interval data.
        if timing and intervals:
            diversity = np.mean(timing) + np.mean(intervals)
        else:
            diversity = 0.0

        return {
            'mean_timing_variance': np.mean(timing) if timing else 0.0,
            'std_timing_variance': np.std(timing) if len(timing) > 1 else 0.0,
            'mean_interval_variance': np.mean(intervals) if intervals else 0.0,
            'std_interval_variance': np.std(intervals) if len(intervals) > 1 else 0.0,
            'mean_burst_intensity': np.mean(bursts) if bursts else 0.0,
            'temporal_diversity_score': diversity,
            'mean_first_spike_time': np.mean(first_spikes) if first_spikes else 0.0
        }

class SparseAttentionMetrics:
    """
    Sparse attention efficiency metrics for evaluating attention mechanisms.

    Collects per-sample sparsity ratios, quadratic computational savings and
    the entropy of the softmax-normalised attention map.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset attention metrics."""
        self.sparsity_ratios = []
        self.attention_entropies = []
        self.computational_savings = []
        self.token_selection_counts = defaultdict(int)

    def update(self, attention_map: torch.Tensor, selected_tokens: Optional[torch.Tensor] = None,
               total_tokens: Optional[int] = None):
        """
        Update sparse attention metrics.

        :param attention_map: Attention scores [B, N] or [B, H, N, N]; a
            softmax over the last axis is applied before computing entropy.
            Other ranks contribute no entropy.
        :param selected_tokens: Binary mask of selected tokens [B, N]
        :param total_tokens: Total number of tokens before selection
            (expected to be positive); defaults to `selected_tokens.size(1)`
        """
        if selected_tokens is not None:
            num_selected = selected_tokens.sum(dim=1).float()
            # Use one denominator for both metrics so that savings are also
            # recorded when the caller does not pass total_tokens.
            if total_tokens is not None:
                token_budget = total_tokens
            else:
                token_budget = selected_tokens.size(1)

            sparsity_ratio = 1.0 - (num_selected / token_budget)
            self.sparsity_ratios.extend(sparsity_ratio.tolist())

            # Computational savings (quadratic reduction)
            theoretical_ops = token_budget * token_budget
            actual_ops = num_selected * num_selected
            savings = 1.0 - (actual_ops / theoretical_ops)
            self.computational_savings.extend(savings.tolist())

        # Attention entropy (measure of attention distribution)
        if attention_map.dim() in (2, 4):
            attention_probs = torch.softmax(attention_map, dim=-1)
            entropy = -(attention_probs * torch.log(attention_probs + 1e-8)).sum(dim=-1)
            if entropy.dim() > 1:
                # [B, H, N] -> [B]: average over heads and query positions.
                entropy = entropy.mean(dim=(-2, -1))
            self.attention_entropies.extend(entropy.tolist())

    def compute(self) -> Dict[str, float]:
        """
        Compute sparse attention efficiency metrics.

        :return: Dictionary containing attention metrics; entries without
            data are reported as 0.0
        """
        sparsity = self.sparsity_ratios
        savings = self.computational_savings
        entropies = self.attention_entropies

        return {
            'mean_sparsity_ratio': np.mean(sparsity) if sparsity else 0.0,
            'std_sparsity_ratio': np.std(sparsity) if len(sparsity) > 1 else 0.0,
            'mean_computational_savings': np.mean(savings) if savings else 0.0,
            'mean_attention_entropy': np.mean(entropies) if entropies else 0.0,
            # Inverse of the mean entropy; epsilon guards a zero entropy.
            'attention_concentration': 1.0 / (np.mean(entropies) + 1e-8) if entropies else 0.0
        }

class ComplexityMetrics:
    """
    Parameter counts, a simplified FLOP estimate and a spike-based energy
    estimate for a PyTorch model.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all counters and clear the per-layer breakdown."""
        self.total_params = 0
        self.trainable_params = 0
        self.total_flops = 0
        self.layer_complexity = {}

    def analyze_model(self, model: nn.Module, input_shape: Tuple[int, ...]):
        """
        Count parameters and estimate FLOPs, replacing any previous analysis.

        Parameters are grouped by the first component of their dotted name.

        :param model: PyTorch model to analyze
        :param input_shape: Input tensor shape (without batch dimension);
            forwarded to the FLOP estimator
        """
        self.reset()

        for param_name, tensor in model.named_parameters():
            size = tensor.numel()
            self.total_params += size
            if tensor.requires_grad:
                self.trainable_params += size

            group = param_name.split('.')[0]
            entry = self.layer_complexity.setdefault(group, {'params': 0, 'flops': 0})
            entry['params'] += size

        self._estimate_flops(model, input_shape)

    def _estimate_flops(self, model: nn.Module, input_shape: Tuple[int, ...]):
        """
        Add a simplified FLOP estimate for every Linear and Conv2d module.

        The estimate is derived from weight dimensions only; `input_shape`
        is not consulted by the current implementation.

        :param model: PyTorch model
        :param input_shape: Input tensor shape
        """
        for module_name, layer in model.named_modules():
            if isinstance(layer, nn.Linear):
                # Two operations (multiply + accumulate) per weight.
                layer_flops = 2 * layer.weight.numel()
            elif isinstance(layer, nn.Conv2d):
                if not hasattr(layer, 'weight'):
                    continue
                out_channels, in_channels, kernel_h, kernel_w = layer.weight.shape
                # Output spatial size is not known here, so the output
                # channel count stands in for the number of output elements.
                layer_flops = (kernel_h * kernel_w) * in_channels * out_channels * 2
            else:
                continue

            self.total_flops += layer_flops
            group = module_name.split('.')[0]
            # Only groups already registered during parameter counting
            # receive a per-layer FLOP entry.
            if group in self.layer_complexity:
                self.layer_complexity[group]['flops'] += layer_flops

    def estimate_energy(self, spike_count: int, timesteps: int = 1) -> float:
        """
        Estimate energy from a spike count at a fixed 0.9 pJ per spike.

        :param spike_count: Total number of spikes
        :param timesteps: Number of timesteps (not used by the current model)
        :return: Estimated energy in millijoules
        """
        # 0.9 pJ per spike, expressed in joules (Loihi-based estimation
        # according to the original author).
        joules_per_spike = 0.9e-12
        return spike_count * joules_per_spike * 1000  # J -> mJ

    def compute(self, spike_count: Optional[int] = None, timesteps: int = 1) -> Dict[str, Any]:
        """
        Report the collected complexity numbers.

        :param spike_count: Total spike count; when given, energy figures
            are added to the result
        :param timesteps: Number of timesteps, passed to estimate_energy
        :return: Dictionary containing complexity metrics
        """
        report = {
            'total_parameters': self.total_params,
            'trainable_parameters': self.trainable_params,
            'parameters_m': self.total_params / 1e6,
            'total_flops': self.total_flops,
            'gflops': self.total_flops / 1e9,
            'flops_per_param': self.total_flops / max(self.total_params, 1),
            'layer_breakdown': self.layer_complexity
        }

        if spike_count is None:
            return report

        report['estimated_energy_mj'] = self.estimate_energy(spike_count, timesteps)
        report['energy_per_spike_pj'] = 0.9
        report['spikes_per_joule'] = 1 / (0.9e-12) if spike_count > 0 else 0
        return report

class ComprehensiveMetrics:
    """
    Comprehensive metrics aggregator combining all evaluation aspects.

    Owns one classification and one spike-operation tracker, plus optional
    temporal, sparse-attention and complexity trackers selected by the
    constructor flags.

    :param num_classes: Number of classes for the classification tracker
    :param track_temporal: Create and use a TemporalVarianceMetrics tracker
    :param track_sparsity: Create and use a SparseAttentionMetrics tracker
    :param track_complexity: Create and use a ComplexityMetrics tracker
    """

    def __init__(self, num_classes: int, track_temporal: bool = True,
                 track_sparsity: bool = True, track_complexity: bool = True):
        self.num_classes = num_classes
        self.track_temporal = track_temporal
        self.track_sparsity = track_sparsity
        self.track_complexity = track_complexity

        # Initialize component metrics
        self.classification = ClassificationMetrics(num_classes)
        self.sops = SpikeOperationsMetrics()

        if track_temporal:
            self.temporal = TemporalVarianceMetrics()
        if track_sparsity:
            self.sparse_attention = SparseAttentionMetrics()
        if track_complexity:
            self.complexity = ComplexityMetrics()

    def reset(self):
        """Reset all metrics."""
        self.classification.reset()
        self.sops.reset()
        if self.track_temporal:
            self.temporal.reset()
        if self.track_sparsity:
            self.sparse_attention.reset()
        if self.track_complexity:
            self.complexity.reset()

    def update(self, predictions: torch.Tensor, targets: torch.Tensor,
               spike_info: Optional[Any] = None, attention_map: Optional[torch.Tensor] = None,
               selected_tokens: Optional[torch.Tensor] = None, layer_spikes: Optional[Dict] = None,
               timesteps: int = 1, total_tokens: Optional[int] = None):
        """
        Update all metrics with batch data.

        :param predictions: Model predictions
        :param targets: Ground truth labels
        :param spike_info: SpikeInfo object with temporal data
        :param attention_map: Attention weights
        :param selected_tokens: Selected token mask
        :param layer_spikes: Dictionary mapping layer name to a
            (spike tensor, fanout) pair
        :param timesteps: Number of timesteps
        :param total_tokens: Total number of tokens before selection,
            forwarded to the sparse-attention tracker
        """
        # Update classification metrics
        self.classification.update(predictions, targets)

        # Update spike operations
        if layer_spikes is not None:
            for layer_name, (spikes, fanout) in layer_spikes.items():
                self.sops.update_layer(layer_name, spikes, fanout)

        # Update temporal metrics
        if self.track_temporal and spike_info is not None:
            self.temporal.update(spike_info, timesteps)

        # Update sparse attention metrics
        if self.track_sparsity and attention_map is not None:
            self.sparse_attention.update(attention_map, selected_tokens, total_tokens)

    def analyze_model_complexity(self, model: nn.Module, input_shape: Tuple[int, ...]):
        """
        Analyze model complexity (no-op when complexity tracking is off).

        :param model: PyTorch model
        :param input_shape: Input tensor shape
        """
        if self.track_complexity:
            self.complexity.analyze_model(model, input_shape)

    def compute_all(self, spike_count: Optional[int] = None, timesteps: int = 1) -> Dict[str, Any]:
        """
        Compute all metrics and return comprehensive results.

        :param spike_count: Total spike count for energy estimation
        :param timesteps: Number of timesteps
        :return: Dictionary with one entry per enabled tracker plus a
            'summary' entry
        """
        results = {
            'classification': self.classification.compute(),
            'spike_operations': self.sops.compute()
        }

        if self.track_temporal:
            results['temporal_variance'] = self.temporal.compute()

        if self.track_sparsity:
            results['sparse_attention'] = self.sparse_attention.compute()

        if self.track_complexity:
            results['complexity'] = self.complexity.compute(spike_count, timesteps)

        # Compute derived metrics
        results['summary'] = self._compute_summary_metrics(results)

        return results

    def _compute_summary_metrics(self, results: Dict[str, Any]) -> Dict[str, float]:
        """
        Compute summary metrics combining different aspects.

        :param results: Dictionary of computed metrics; 'classification' is
            required, every other section is optional
        :return: Dictionary of summary metrics
        """
        summary = {
            'accuracy': results['classification']['accuracy'],
            'f1_score': results['classification']['f1_score']
        }

        if 'spike_operations' in results:
            summary['spike_efficiency'] = results['spike_operations']['spike_efficiency']
            summary['total_sops'] = results['spike_operations']['total_synaptic_operations']

        if 'sparse_attention' in results:
            summary['sparsity_ratio'] = results['sparse_attention']['mean_sparsity_ratio']
            summary['computational_savings'] = results['sparse_attention']['mean_computational_savings']

        if 'temporal_variance' in results:
            summary['temporal_diversity'] = results['temporal_variance']['temporal_diversity_score']

        if 'complexity' in results:
            summary['parameters_m'] = results['complexity']['parameters_m']
            summary['gflops'] = results['complexity']['gflops']
            if 'estimated_energy_mj' in results['complexity']:
                summary['energy_mj'] = results['complexity']['estimated_energy_mj']

        # Composite efficiency score: product of sparsity, spike efficiency
        # and accuracy, available only when both trackers reported.
        if 'sparsity_ratio' in summary and 'spike_efficiency' in summary:
            summary['efficiency_score'] = (summary['sparsity_ratio'] * summary['spike_efficiency'] *
                                         summary['accuracy'])

        return summary

# Utility functions for metric computation

def top_k_accuracy(predictions: torch.Tensor, targets: torch.Tensor, k: int = 5) -> float:
    """
    Compute top-k accuracy.

    A sample counts as correct when its target label is among the k
    highest-scoring classes.

    :param predictions: Model predictions [B, num_classes]
    :param targets: Ground truth labels [B]
    :param k: Top-k parameter; values larger than num_classes are clamped
    :return: Top-k accuracy in [0, 1]; 0.0 for an empty batch
    """
    with torch.no_grad():
        batch_size = targets.size(0)
        if batch_size == 0:
            return 0.0

        # topk raises when k exceeds the class dimension.
        k = min(k, predictions.size(1))
        top_k_pred = predictions.topk(k, dim=1, largest=True, sorted=True).indices

        # Compare along the class axis directly; this avoids calling view()
        # on a transposed (non-contiguous) tensor.
        hits = (top_k_pred == targets.reshape(-1, 1)).any(dim=1)
        return hits.sum().item() / batch_size

def compute_firing_rate_statistics(spike_tensor: torch.Tensor) -> Dict[str, float]:
    """
    Compute firing rate statistics from spike tensor.

    Tensors with three or more dimensions are averaged over their leading
    axis first; lower-rank tensors are used as-is.

    :param spike_tensor: Spike tensor [T, B, ...] or [B, ...]; integer and
        boolean tensors are converted to float
    :return: Dictionary of firing rate statistics; all values are 0.0 for an
        empty tensor, and the standard deviation is 0.0 for a single value
    """
    # mean()/std() are only defined for floating-point tensors.
    if not spike_tensor.is_floating_point():
        spike_tensor = spike_tensor.float()

    if spike_tensor.dim() >= 3:
        firing_rates = spike_tensor.mean(dim=0).flatten()
    else:
        firing_rates = spike_tensor.flatten()

    if firing_rates.numel() == 0:
        return {
            'mean_firing_rate': 0.0,
            'std_firing_rate': 0.0,
            'min_firing_rate': 0.0,
            'max_firing_rate': 0.0,
            'firing_rate_cv': 0.0
        }

    mean_rate = firing_rates.mean().item()
    # The sample standard deviation of one value is undefined (NaN).
    std_rate = firing_rates.std().item() if firing_rates.numel() > 1 else 0.0

    return {
        'mean_firing_rate': mean_rate,
        'std_firing_rate': std_rate,
        'min_firing_rate': firing_rates.min().item(),
        'max_firing_rate': firing_rates.max().item(),
        'firing_rate_cv': std_rate / (mean_rate + 1e-8)
    }

def compute_sparsity_metrics(tensor: torch.Tensor, threshold: float = 1e-6) -> Dict[str, float]:
    """
    Compute sparsity metrics for any tensor.

    :param tensor: Input tensor
    :param threshold: Elements whose absolute value exceeds this count as
        non-zero
    :return: Dictionary of sparsity metrics; an empty tensor is reported as
        fully sparse (sparsity 1.0, density 0.0)
    """
    total_elements = tensor.numel()
    if total_elements == 0:
        # Nothing to count; avoid dividing by zero.
        return {
            'sparsity_ratio': 1.0,
            'density_ratio': 0.0,
            'non_zero_elements': 0,
            'total_elements': 0
        }

    non_zero_elements = (tensor.abs() > threshold).sum().item()
    sparsity = 1.0 - (non_zero_elements / total_elements)

    return {
        'sparsity_ratio': sparsity,
        'density_ratio': 1.0 - sparsity,
        'non_zero_elements': non_zero_elements,
        'total_elements': total_elements
    }

def analyze_attention_pattern(attention_weights: torch.Tensor) -> Dict[str, float]:
    """
    Analyze attention pattern characteristics.

    A softmax over the last axis is applied before entropies are computed.

    :param attention_weights: Attention weights [B, N] or [B, H, N, N]
    :return: Dictionary of attention pattern metrics; empty for any other
        rank. Statistics that are undefined for the given shape (a single
        token, batch element or head) are reported as finite values instead
        of NaN/inf.
    """
    eps = 1e-8

    if attention_weights.dim() == 2:  # [B, N]
        attn = torch.softmax(attention_weights, dim=-1)
        entropy = -(attn * torch.log(attn + eps)).sum(dim=-1)
        max_attention = attn.max(dim=-1)[0]
        mean_entropy = entropy.mean().item()

        # Largest possible entropy for N tokens; log(1) == 0 would make the
        # normalisation below divide by zero, so floor it at eps.
        num_tokens = attn.size(-1)
        max_entropy = math.log(num_tokens) if num_tokens > 1 else 0.0

        return {
            'mean_entropy': mean_entropy,
            'mean_max_attention': max_attention.mean().item(),
            # eps keeps a zero-entropy (fully concentrated) input finite;
            # the clamp absorbs tiny negative entropies caused by the eps
            # inside the logarithm.
            'attention_concentration': 1.0 / (max(mean_entropy, 0.0) + eps),
            'uniformity_score': (entropy / max(max_entropy, eps)).mean().item()
        }
    elif attention_weights.dim() == 4:  # [B, H, N, N]
        attn = torch.softmax(attention_weights, dim=-1)
        entropy = -(attn * torch.log(attn + eps)).sum(dim=-1).mean(dim=(-2, -1))

        # The sample standard deviation needs at least two values.
        diversity = entropy.std().item() if entropy.numel() > 1 else 0.0
        if attention_weights.size(1) > 1:
            head_specialization = attention_weights.std(dim=1).mean().item()
        else:
            head_specialization = 0.0

        return {
            'mean_entropy': entropy.mean().item(),
            'attention_diversity': diversity,
            'head_specialization': head_specialization
        }
    else:
        return {}

# Public API of this module: the names listed here are the ones exported by
# `from <module> import *`. Every entry is a class or function defined above.
__all__ = [
    'ClassificationMetrics',
    'SpikeOperationsMetrics', 
    'TemporalVarianceMetrics',
    'SparseAttentionMetrics',
    'ComplexityMetrics',
    'ComprehensiveMetrics',
    'top_k_accuracy',
    'compute_firing_rate_statistics',
    'compute_sparsity_metrics',
    'analyze_attention_pattern'
]
