"""
Evaluation metrics for medical image segmentation
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional
from scipy.spatial.distance import directed_hausdorff
import cv2

class SegmentationMetrics:
    """Comprehensive evaluation metrics for medical image segmentation.

    Every per-class method expects ``predictions`` and ``targets`` shaped
    ``(B, C, H, W)`` and binarises both with ``> self.threshold`` before
    scoring. No sigmoid/softmax is applied here, so ``predictions`` should
    already be on the same scale as the threshold (e.g. probabilities).

    Empty-mask convention (applied per sample and per class):

    * class absent from both prediction and target -> Dice, IoU and boundary
      F1 score 1.0 and the Hausdorff distance scores 0.0 (a correct "nothing
      here" prediction is a perfect match);
    * class present in exactly one of them -> the Hausdorff distance is the
      image diagonal, i.e. a worst-case penalty instead of a perfect 0.
    """

    def __init__(self, num_classes: int, threshold: float = 0.5):
        # num_classes: expected channel count C. Kept for API compatibility;
        #   per-class results are reported for the channels actually present.
        # threshold: cut-off used to binarise both predictions and targets.
        self.num_classes = num_classes
        self.threshold = threshold

    def _binarize(self, predictions: torch.Tensor,
                  targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Threshold predictions and targets into float {0, 1} masks."""
        pred_binary = (predictions > self.threshold).float()
        target_binary = (targets > self.threshold).float()
        return pred_binary, target_binary

    @staticmethod
    def _flatten_spatial(mask: torch.Tensor) -> torch.Tensor:
        """Collapse every dim after (B, C) into one: (B, C, H, W) -> (B, C, H*W)."""
        # reshape instead of view: view raises for inputs whose strides cannot
        # be merged (e.g. transposed tensors or strided slices); reshape copies
        # only when it has to.
        return mask.reshape(mask.size(0), mask.size(1), -1)

    def dice_score(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute Dice score for each class

        Args:
            predictions: (B, C, H, W) - predicted segmentation masks
            targets: (B, C, H, W) - ground truth segmentation masks
        Returns:
            dice_scores: (C,) - Dice score for each class, averaged over the batch.
                A sample where the class is empty in both masks scores 1.0.
        """
        pred_binary, target_binary = self._binarize(predictions, targets)
        pred_flat = self._flatten_spatial(pred_binary)
        target_flat = self._flatten_spatial(target_binary)

        intersection = (pred_flat * target_flat).sum(dim=2)
        total = pred_flat.sum(dim=2) + target_flat.sum(dim=2)

        dice = (2.0 * intersection) / (total + 1e-7)
        # Both masks empty: the eps-only formula would give 0 for a correct
        # empty prediction, so score it as a perfect match instead.
        dice = torch.where(total == 0, torch.ones_like(dice), dice)

        return dice.mean(dim=0)  # Average over batch

    def iou_score(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute IoU (Jaccard) score for each class

        Args:
            predictions: (B, C, H, W) - predicted segmentation masks
            targets: (B, C, H, W) - ground truth segmentation masks
        Returns:
            iou_scores: (C,) - IoU score for each class, averaged over the batch.
                A sample where the class is empty in both masks scores 1.0.
        """
        pred_binary, target_binary = self._binarize(predictions, targets)
        pred_flat = self._flatten_spatial(pred_binary)
        target_flat = self._flatten_spatial(target_binary)

        intersection = (pred_flat * target_flat).sum(dim=2)
        union = pred_flat.sum(dim=2) + target_flat.sum(dim=2) - intersection

        iou = intersection / (union + 1e-7)
        # Empty union means both masks are empty: a perfect (not zero) match.
        iou = torch.where(union == 0, torch.ones_like(iou), iou)

        return iou.mean(dim=0)  # Average over batch

    def hausdorff_distance(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the symmetric Hausdorff distance (in pixels) for each class

        Distances are measured between the outer contour points of the
        binarised prediction and target of each (sample, class) pair and then
        averaged over the batch.

        Args:
            predictions: (B, C, H, W) - predicted segmentation masks
            targets: (B, C, H, W) - ground truth segmentation masks
        Returns:
            hausdorff_distances: (C,) float64 tensor on ``predictions.device``.
                Per sample: 0.0 if both masks are empty, the image diagonal
                ``hypot(H, W)`` if exactly one is empty.
        """
        pred_binary, target_binary = self._binarize(predictions, targets)
        batch_size, num_classes = pred_binary.shape[:2]
        height, width = pred_binary.shape[-2:]

        # One device->host transfer instead of one per (sample, class) slice.
        pred_np = pred_binary.cpu().numpy()
        target_np = target_binary.cpu().numpy()
        # Penalty when one mask is empty: no finite boundary distance exists,
        # so use the largest distance possible inside the image.
        worst_case = float(np.hypot(height, width))

        hausdorff_distances = []
        for c in range(num_classes):
            class_distances = []
            for b in range(batch_size):
                pred_points = self._get_boundary_points(pred_np[b, c])
                target_points = self._get_boundary_points(target_np[b, c])

                if len(pred_points) > 0 and len(target_points) > 0:
                    h1 = directed_hausdorff(pred_points, target_points)[0]
                    h2 = directed_hausdorff(target_points, pred_points)[0]
                    class_distances.append(max(h1, h2))
                elif len(pred_points) > 0 or len(target_points) > 0:
                    # Missed structure or false alarm: worst case, not perfect.
                    class_distances.append(worst_case)
                else:
                    class_distances.append(0.0)

            hausdorff_distances.append(np.mean(class_distances))

        return torch.tensor(hausdorff_distances, device=predictions.device)

    def _get_boundary_points(self, mask: np.ndarray) -> np.ndarray:
        """Return every outer-contour pixel of a binary mask as an (N, 2) array of (x, y)."""
        # CHAIN_APPROX_NONE keeps every contour pixel; CHAIN_APPROX_SIMPLE keeps
        # only polyline vertices, which makes Hausdorff distances computed on
        # the points overestimate the distance to straight boundary segments.
        found = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # OpenCV 3 returns (image, contours, hierarchy), OpenCV 4 returns
        # (contours, hierarchy); the contours are second-to-last in both.
        contours = found[-2]

        if len(contours) == 0:
            return np.empty((0, 2), dtype=np.int32)

        return np.vstack([contour.reshape(-1, 2) for contour in contours])

    def boundary_f1_score(self, predictions: torch.Tensor, targets: torch.Tensor,
                         boundary_width: int = 3) -> torch.Tensor:
        """
        Compute boundary F1-score for each class

        For each (sample, class) pair, the "boundary region" of a mask is the
        band of pixels outside the mask whose Euclidean distance to it lies in
        ``(0, boundary_width]``; F1 is computed from the pixel overlap of the
        predicted and target bands.

        Args:
            predictions: (B, C, H, W) - predicted segmentation masks
            targets: (B, C, H, W) - ground truth segmentation masks
            boundary_width: Width of boundary region
        Returns:
            boundary_f1_scores: (C,) float64 tensor on ``predictions.device``,
                averaged over the batch. When neither mask has a band (each
                mask is empty or covers the whole image) the sample scores
                1.0 if the masks are identical and 0.0 otherwise.
        """
        pred_binary, target_binary = self._binarize(predictions, targets)
        batch_size, num_classes = pred_binary.shape[:2]

        pred_np = pred_binary.cpu().numpy()
        target_np = target_binary.cpu().numpy()

        boundary_f1_scores = []
        for c in range(num_classes):
            class_f1_scores = []
            for b in range(batch_size):
                pred_c = pred_np[b, c]
                target_c = target_np[b, c]

                pred_boundary = self._get_boundary_region(pred_c, boundary_width)
                target_boundary = self._get_boundary_region(target_c, boundary_width)

                if not pred_boundary.any() and not target_boundary.any():
                    # No band to compare (empty or full masks): decide by the
                    # masks themselves instead of reporting 0 for a match.
                    class_f1_scores.append(1.0 if np.array_equal(pred_c, target_c) else 0.0)
                    continue

                tp = np.sum(pred_boundary * target_boundary)
                fp = np.sum(pred_boundary * (1 - target_boundary))
                fn = np.sum((1 - pred_boundary) * target_boundary)

                precision = tp / (tp + fp + 1e-7)
                recall = tp / (tp + fn + 1e-7)
                class_f1_scores.append(2 * precision * recall / (precision + recall + 1e-7))

            boundary_f1_scores.append(np.mean(class_f1_scores))

        return torch.tensor(boundary_f1_scores, device=predictions.device)

    def _get_boundary_region(self, mask: np.ndarray, width: int) -> np.ndarray:
        """Return a float32 {0, 1} map of pixels outside ``mask`` within ``width`` of it."""
        from scipy.ndimage import distance_transform_edt

        # Without any foreground pixel there is nothing to measure distance to
        # (the transform's output is not meaningful), so there is no band.
        if not mask.any():
            return np.zeros(mask.shape, dtype=np.float32)

        # Distance of every background pixel to the nearest mask pixel
        # (0 inside the mask).
        dt = distance_transform_edt(1 - mask)
        boundary_region = (dt <= width) & (dt > 0)

        return boundary_region.astype(np.float32)

    def pixel_accuracy(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute pixel accuracy

        Args:
            predictions: (B, C, H, W) - predicted segmentation masks
            targets: (B, C, H, W) - ground truth segmentation masks
        Returns:
            pixel_accuracy: Scalar tensor - fraction of binarised elements
                (over all samples, classes and pixels) that agree
        """
        pred_binary, target_binary = self._binarize(predictions, targets)
        return (pred_binary == target_binary).float().mean()

    def mean_dice_score(self, predictions: torch.Tensor, targets: torch.Tensor) -> float:
        """Compute mean Dice score across all classes"""
        return self.dice_score(predictions, targets).mean().item()

    def mean_iou_score(self, predictions: torch.Tensor, targets: torch.Tensor) -> float:
        """Compute mean IoU score across all classes"""
        return self.iou_score(predictions, targets).mean().item()

    def mean_hausdorff_distance(self, predictions: torch.Tensor, targets: torch.Tensor) -> float:
        """Compute mean Hausdorff distance across all classes"""
        return self.hausdorff_distance(predictions, targets).mean().item()

    def mean_boundary_f1_score(self, predictions: torch.Tensor, targets: torch.Tensor) -> float:
        """Compute mean boundary F1-score across all classes"""
        return self.boundary_f1_score(predictions, targets).mean().item()

    def compute_all_metrics(self, predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        """
        Compute all metrics and return as dictionary

        Args:
            predictions: (B, C, H, W) - predicted segmentation masks
            targets: (B, C, H, W) - ground truth segmentation masks
        Returns:
            metrics: 'mean_dice', 'mean_iou', 'mean_hausdorff',
                'mean_boundary_f1', 'pixel_accuracy', plus
                '<metric>_class_<c>' entries for every channel c of the inputs
        """
        dice_scores = self.dice_score(predictions, targets)
        iou_scores = self.iou_score(predictions, targets)
        hd_scores = self.hausdorff_distance(predictions, targets)
        bf1_scores = self.boundary_f1_score(predictions, targets)

        metrics = {
            'mean_dice': dice_scores.mean().item(),
            'mean_iou': iou_scores.mean().item(),
            'mean_hausdorff': hd_scores.mean().item(),
            'mean_boundary_f1': bf1_scores.mean().item(),
            'pixel_accuracy': self.pixel_accuracy(predictions, targets).item(),
        }

        # Iterate over the channels actually scored (not self.num_classes),
        # so a channel-count mismatch neither drops classes nor IndexErrors.
        for c in range(dice_scores.numel()):
            metrics[f'dice_class_{c}'] = dice_scores[c].item()
            metrics[f'iou_class_{c}'] = iou_scores[c].item()
            metrics[f'hausdorff_class_{c}'] = hd_scores[c].item()
            metrics[f'boundary_f1_class_{c}'] = bf1_scores[c].item()

        return metrics

class ModelEvaluator:
    """Run a segmentation model over data and score it with SegmentationMetrics."""

    def __init__(self, model: torch.nn.Module, device: torch.device, num_classes: int,
                 apply_sigmoid: Optional[bool] = None, threshold: float = 0.5):
        """
        Args:
            model: model mapping an image batch to a (B, C, H, W) output
            device: device the model and inputs are run on
            num_classes: forwarded to SegmentationMetrics
            apply_sigmoid: True -> always apply sigmoid to the model output,
                False -> never. None (default) keeps the legacy per-batch
                heuristic: apply it only if the output leaves [0, 1]. That
                heuristic is data-dependent (logits that happen to lie in
                [0, 1] are left raw), so pass an explicit value when the
                model's output type is known.
            threshold: binarisation threshold forwarded to SegmentationMetrics
        """
        self.model = model
        self.device = device
        self.apply_sigmoid = apply_sigmoid
        self.metrics = SegmentationMetrics(num_classes, threshold=threshold)

    def _predict(self, images: torch.Tensor) -> torch.Tensor:
        """Run the model on ``images`` (moved to ``self.device``) and return scores in metric space."""
        predictions = self.model(images.to(self.device))

        if self.apply_sigmoid is None:
            # Legacy heuristic: values outside [0, 1] are taken to be logits.
            use_sigmoid = bool(predictions.min() < 0 or predictions.max() > 1)
        else:
            use_sigmoid = self.apply_sigmoid

        return torch.sigmoid(predictions) if use_sigmoid else predictions

    def evaluate(self, dataloader: torch.utils.data.DataLoader) -> Dict[str, float]:
        """
        Evaluate model on dataset

        Puts the model in eval mode, collects every batch's predictions and
        targets on the CPU, and scores them together.

        Args:
            dataloader: DataLoader yielding ``(images, targets)`` pairs
        Returns:
            metrics: Dictionary of evaluation metrics
        """
        self.model.eval()

        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for images, targets in dataloader:
                all_predictions.append(self._predict(images).cpu())
                all_targets.append(targets.cpu())

        all_predictions = torch.cat(all_predictions, dim=0)
        all_targets = torch.cat(all_targets, dim=0)

        return self.metrics.compute_all_metrics(all_predictions, all_targets)

    def evaluate_batch(self, images: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        """
        Evaluate model on single batch

        Args:
            images: (B, C, H, W) - input images
            targets: (B, C, H, W) - target masks
        Returns:
            metrics: Dictionary of evaluation metrics
        """
        self.model.eval()

        with torch.no_grad():
            predictions = self._predict(images)
            metrics = self.metrics.compute_all_metrics(predictions, targets.to(self.device))

        return metrics

def compute_model_efficiency(model: torch.nn.Module, input_size: Tuple[int, int, int],
                           device: torch.device, num_runs: int = 100,
                           num_warmup: int = 10) -> Dict[str, float]:
    """
    Compute model efficiency metrics

    Times ``num_runs`` forward passes on a single random input of shape
    ``(1, *input_size)`` after ``num_warmup`` untimed passes.

    Args:
        model: PyTorch model (switched to eval mode and moved to ``device``)
        input_size: (C, H, W) - input size
        device: Device to run on (a ``torch.device`` or a string such as 'cuda')
        num_runs: Number of timed runs (must be >= 1)
        num_warmup: Number of untimed warm-up runs
    Returns:
        efficiency_metrics: 'inference_time_ms' (mean per run), 'memory_usage_mb'
            (peak CUDA allocation on ``device`` during this benchmark; 0.0 on
            non-CUDA devices), 'num_parameters' (trainable parameters only),
            'flops' (from estimate_flops), 'fps'
    Raises:
        ValueError: if ``num_runs`` < 1
    """
    import time

    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    device = torch.device(device)
    is_cuda = device.type == 'cuda'

    model.eval()
    model = model.to(device)

    dummy_input = torch.randn(1, *input_size).to(device)

    if is_cuda:
        # Without a reset, max_memory_allocated reports the peak since process
        # start (or the last reset), not the peak of this benchmark.
        torch.cuda.reset_peak_memory_stats(device)

    with torch.no_grad():
        for _ in range(num_warmup):
            model(dummy_input)

        if is_cuda:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            # Make sure warm-up kernels are finished before timing starts.
            torch.cuda.synchronize(device)
            start_event.record()
            for _ in range(num_runs):
                model(dummy_input)
            end_event.record()
            torch.cuda.synchronize(device)
            inference_time = start_event.elapsed_time(end_event) / num_runs  # ms
        else:
            start = time.perf_counter()
            for _ in range(num_runs):
                model(dummy_input)
            inference_time = (time.perf_counter() - start) * 1000 / num_runs  # ms

    memory_usage = torch.cuda.max_memory_allocated(device) / 1024**2 if is_cuda else 0.0  # MB

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    flops = estimate_flops(model, dummy_input)

    return {
        'inference_time_ms': inference_time,
        'memory_usage_mb': memory_usage,
        'num_parameters': num_params,
        'flops': flops,
        'fps': 1000.0 / inference_time if inference_time > 0 else 0.0
    }

def estimate_flops(model: torch.nn.Module, input_tensor: torch.Tensor) -> int:
    """Estimate the compute of one forward pass of ``model`` on ``input_tensor`` (approximate).

    Only ``Conv1d``/``Conv2d``/``Conv3d`` and ``Linear`` layers are counted;
    every other op (activations, normalisation, pooling, attention matmuls,
    biases, ...) is ignored. Each counted layer contributes one
    multiply-accumulate per weight tap per output element, so the figure is a
    MAC count (about half the FLOPs under the "1 MAC = 2 FLOPs" convention).
    Counts cover every element of the output, i.e. they scale with the batch
    size of ``input_tensor``. For accurate numbers use a profiler such as
    thop or fvcore.

    Args:
        model: module to profile (run once under ``torch.no_grad()``)
        input_tensor: input passed to ``model``
    Returns:
        Total multiply-accumulate count as a Python int.
    """
    conv_types = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)
    total_flops = 0

    def hook_fn(module, inputs, output):
        nonlocal total_flops
        if isinstance(module, conv_types):
            # Each output element sees kernel_volume * (in_channels / groups)
            # inputs; grouped/depthwise convs only connect within a group.
            kernel_volume = int(np.prod(module.kernel_size))
            total_flops += output.numel() * kernel_volume * (module.in_channels // module.groups)
        else:
            # Linear: each output element is a dot product of length in_features,
            # regardless of how many leading (batch/sequence) dims there are.
            total_flops += output.numel() * module.in_features

    hooks = [
        module.register_forward_hook(hook_fn)
        for module in model.modules()
        if isinstance(module, conv_types + (torch.nn.Linear,))
    ]

    try:
        with torch.no_grad():
            model(input_tensor)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for hook in hooks:
            hook.remove()

    return int(total_flops)

if __name__ == "__main__":
    # Smoke test: random logits scored against random binary masks.
    bsz, n_cls, h, w = 2, 5, 128, 128
    demo_preds = torch.randn(bsz, n_cls, h, w)
    demo_targets = torch.randint(0, 2, (bsz, n_cls, h, w)).float()

    seg_metrics = SegmentationMetrics(n_cls)

    print("Testing segmentation metrics...")
    mean_dice = seg_metrics.mean_dice_score(demo_preds, demo_targets)
    print(f"Dice Score: {mean_dice:.4f}")
    mean_iou = seg_metrics.mean_iou_score(demo_preds, demo_targets)
    print(f"IoU Score: {mean_iou:.4f}")
    accuracy = seg_metrics.pixel_accuracy(demo_preds, demo_targets)
    print(f"Pixel Accuracy: {accuracy:.4f}")

    # Full dictionary: averaged plus per-class entries.
    summary = seg_metrics.compute_all_metrics(demo_preds, demo_targets)
    print(f"All Metrics: {summary}")

    # Efficiency benchmark for the project model; imported here so the
    # metrics module itself does not depend on it at import time.
    from model import MSAUNet
    net = MSAUNet(in_channels=3, num_classes=5)
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    report = compute_model_efficiency(net, (3, 512, 512), run_device)
    print(f"Model Efficiency: {report}")

