"""
Loss functions for Multi-Scale Attention U-Net medical image segmentation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple

class DiceLoss(nn.Module):
    """Soft Dice loss for segmentation.

    The Dice coefficient is computed independently for every
    (sample, channel) pair and the loss is ``1 - mean(Dice)``.
    """

    def __init__(self, smooth: float = 1e-7):
        super(DiceLoss, self).__init__()
        # Added to numerator and denominator so empty masks do not divide by zero
        # (two empty masks give Dice == 1, i.e. zero loss).
        self.smooth = smooth

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: (B, C, *spatial) predicted masks; assumed to already be
                probabilities - no sigmoid/softmax is applied here.
            targets: (B, C, *spatial) ground truth masks, same shape as predictions.
        Returns:
            Scalar tensor: 1 - mean Dice coefficient.
        """
        # reshape (not view) so non-contiguous inputs (e.g. after permute or
        # channel slicing) are accepted; any number of spatial dims works.
        predictions = predictions.reshape(predictions.size(0), predictions.size(1), -1)
        targets = targets.reshape(targets.size(0), targets.size(1), -1)

        # Per (sample, channel) overlap and total mass.
        intersection = (predictions * targets).sum(dim=2)
        union = predictions.sum(dim=2) + targets.sum(dim=2)

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)

        return 1.0 - dice.mean()

class BoundaryLoss(nn.Module):
    """Boundary loss comparing Sobel edge maps of prediction and target.

    For every channel the gradient magnitude ``sqrt(gx**2 + gy**2 + 1e-8)`` is
    computed with fixed 3x3 Sobel filters (zero padding of 1). The loss is the
    mean squared error between the prediction and target edge maps, averaged
    over all channels.
    """

    def __init__(self, kernel_size: int = 3):
        super(BoundaryLoss, self).__init__()
        # NOTE: kept for API compatibility only; the Sobel filters below are
        # fixed 3x3, so this value does not affect the computation.
        self.kernel_size = kernel_size

        # Sobel kernels for horizontal / vertical edge detection
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)

        # Non-persistent buffers move with .to()/.cuda() without being added to
        # state_dict, so previously saved checkpoints still load unchanged.
        self.register_buffer('sobel_x', sobel_x.view(1, 1, 3, 3), persistent=False)
        self.register_buffer('sobel_y', sobel_y.view(1, 1, 3, 3), persistent=False)

    def _edge_magnitude(self, x: torch.Tensor) -> torch.Tensor:
        """Per-channel Sobel gradient magnitude of a (B, C, H, W) tensor."""
        channels = x.size(1)
        # Kernels are matched to the input's device and dtype (float16/float64
        # inputs would otherwise fail), then replicated once per channel so a
        # single depthwise (grouped) convolution processes all channels.
        kernel_x = self.sobel_x.to(device=x.device, dtype=x.dtype).repeat(channels, 1, 1, 1)
        kernel_y = self.sobel_y.to(device=x.device, dtype=x.dtype).repeat(channels, 1, 1, 1)
        grad_x = F.conv2d(x, kernel_x, padding=1, groups=channels)
        grad_y = F.conv2d(x, kernel_y, padding=1, groups=channels)
        # Small epsilon keeps sqrt differentiable where the gradient is zero.
        return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: (B, C, H, W) - predicted segmentation masks
            targets: (B, C, H, W) - ground truth segmentation masks
        Returns:
            boundary_loss: scalar tensor, MSE between the two edge maps
        """
        targets = targets.to(dtype=predictions.dtype)

        pred_boundary = self._edge_magnitude(predictions)
        target_boundary = self._edge_magnitude(targets)

        # Every channel has the same number of elements, so the global mean
        # equals the per-channel MSE averaged over channels.
        return F.mse_loss(pred_boundary, target_boundary)

class TverskyLoss(nn.Module):
    """Tversky loss for handling class imbalance.

    The Tversky index ``TP / (TP + alpha*FP + beta*FN)`` is computed per
    (sample, channel) pair and the loss is ``1 - mean(index)``. With
    ``alpha == beta == 0.5`` it reduces to the Dice coefficient.
    """

    def __init__(self, alpha: float = 0.3, beta: float = 0.7, smooth: float = 1e-7):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha  # Weight for false positives
        self.beta = beta    # Weight for false negatives
        # Keeps the ratio finite (and equal to 1) when both masks are empty.
        self.smooth = smooth

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: (B, C, *spatial) predicted masks; assumed to already be
                probabilities - no sigmoid/softmax is applied here.
            targets: (B, C, *spatial) ground truth masks, same shape as predictions.
        Returns:
            Scalar tensor: 1 - mean Tversky index.
        """
        # reshape (not view) so non-contiguous inputs are accepted.
        predictions = predictions.reshape(predictions.size(0), predictions.size(1), -1)
        targets = targets.reshape(targets.size(0), targets.size(1), -1)

        # Soft true positives, false positives and false negatives.
        tp = (predictions * targets).sum(dim=2)
        fp = (predictions * (1 - targets)).sum(dim=2)
        fn = ((1 - predictions) * targets).sum(dim=2)

        tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)

        return 1.0 - tversky.mean()

class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Every element's binary cross-entropy term is scaled by
    ``alpha * (1 - p_t) ** gamma`` so confidently-correct elements contribute
    less than hard ones.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        super(FocalLoss, self).__init__()
        # Uniform scale applied to every element (not a per-class weight).
        self.alpha = alpha
        # Focusing exponent; gamma == 0 reduces this to alpha * BCE.
        self.gamma = gamma

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: (B, C, H, W) raw logits; the sigmoid happens inside
                ``binary_cross_entropy_with_logits``.
            targets: (B, C, H, W) ground truth masks.
        Returns:
            Scalar tensor: mean focal loss over all elements.
        """
        bce = F.binary_cross_entropy_with_logits(predictions, targets, reduction='none')

        # For binary targets BCE == -log(p_t), so exp(-BCE) gives p_t.
        prob_true = bce.neg().exp()

        modulating = self.alpha * (1 - prob_true).pow(self.gamma)
        return (modulating * bce).mean()

class CombinedLoss(nn.Module):
    """Weighted sum of Dice, boundary and optional Tversky / focal losses.

    The Dice, boundary and Tversky terms are computed on probabilities. The
    focal term is computed on logits, because ``FocalLoss`` wraps
    ``binary_cross_entropy_with_logits``.
    """

    def __init__(self,
                 dice_weight: float = 0.7,
                 boundary_weight: float = 0.3,
                 tversky_weight: float = 0.0,
                 focal_weight: float = 0.0,
                 dice_smooth: float = 1e-7,
                 tversky_alpha: float = 0.3,
                 tversky_beta: float = 0.7,
                 focal_alpha: float = 1.0,
                 focal_gamma: float = 2.0,
                 from_logits: Optional[bool] = None):
        """
        Args:
            dice_weight, boundary_weight, tversky_weight, focal_weight:
                multipliers for each term. The Tversky and focal modules are
                only created when their weight is > 0.
            dice_smooth: smoothing constant for the Dice term.
            tversky_alpha, tversky_beta: Tversky false-positive / false-negative weights.
            focal_alpha, focal_gamma: focal loss scale and focusing exponent.
            from_logits: True if ``predictions`` are always raw logits, False if
                they are always probabilities. None (default) keeps the legacy
                heuristic, which applies a sigmoid only when some value lies
                outside [0, 1]. That heuristic misclassifies logits that happen
                to fall entirely inside [0, 1].
        """
        super(CombinedLoss, self).__init__()

        self.dice_weight = dice_weight
        self.boundary_weight = boundary_weight
        self.tversky_weight = tversky_weight
        self.focal_weight = focal_weight
        self.from_logits = from_logits

        # Initialize loss functions
        self.dice_loss = DiceLoss(smooth=dice_smooth)
        self.boundary_loss = BoundaryLoss()

        if tversky_weight > 0:
            self.tversky_loss = TverskyLoss(alpha=tversky_alpha, beta=tversky_beta)
        else:
            self.tversky_loss = None

        if focal_weight > 0:
            self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        else:
            self.focal_loss = None

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Args:
            predictions: (B, C, H, W) logits or probabilities (see ``from_logits``).
            targets: (B, C, H, W) ground truth segmentation masks.
        Returns:
            total_loss: differentiable weighted sum of the active terms.
            loss_dict: Python floats for each active term plus 'total_loss'.
        """
        if self.from_logits is None:
            # Legacy heuristic: values outside [0, 1] can only be logits.
            inputs_are_logits = bool(predictions.min() < 0 or predictions.max() > 1)
        else:
            inputs_are_logits = self.from_logits

        if inputs_are_logits:
            logits = predictions
            probabilities = torch.sigmoid(predictions)
        else:
            logits = None  # derived on demand if the focal term needs it
            probabilities = predictions

        # Probability-based terms
        dice_loss = self.dice_loss(probabilities, targets)
        boundary_loss = self.boundary_loss(probabilities, targets)

        loss_dict = {
            'dice_loss': dice_loss.item(),
            'boundary_loss': boundary_loss.item()
        }

        total_loss = (self.dice_weight * dice_loss +
                      self.boundary_weight * boundary_loss)

        # Optional terms
        if self.tversky_loss is not None:
            tversky_loss = self.tversky_loss(probabilities, targets)
            total_loss = total_loss + self.tversky_weight * tversky_loss
            loss_dict['tversky_loss'] = tversky_loss.item()

        if self.focal_loss is not None:
            if logits is None:
                # FocalLoss applies its own sigmoid, so feed it logits rather
                # than probabilities (avoids a double sigmoid). Clamping keeps
                # probabilities of exactly 0 or 1 finite.
                eps = torch.finfo(probabilities.dtype).eps
                logits = torch.logit(probabilities, eps=eps)
            focal_loss = self.focal_loss(logits, targets)
            total_loss = total_loss + self.focal_weight * focal_loss
            loss_dict['focal_loss'] = focal_loss.item()

        loss_dict['total_loss'] = total_loss.item()

        return total_loss, loss_dict

class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss for segmentation.

    IoU is computed per (sample, channel) pair and the loss is
    ``1 - mean(IoU)``.
    """

    def __init__(self, smooth: float = 1e-7):
        super(IoULoss, self).__init__()
        # Keeps the ratio finite (and equal to 1) when both masks are empty.
        self.smooth = smooth

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: (B, C, *spatial) predicted masks; assumed to already be
                probabilities - no sigmoid/softmax is applied here.
            targets: (B, C, *spatial) ground truth masks, same shape as predictions.
        Returns:
            Scalar tensor: 1 - mean IoU.
        """
        # reshape (not view) so non-contiguous inputs are accepted.
        predictions = predictions.reshape(predictions.size(0), predictions.size(1), -1)
        targets = targets.reshape(targets.size(0), targets.size(1), -1)

        # |A n B| and |A u B| = |A| + |B| - |A n B|
        intersection = (predictions * targets).sum(dim=2)
        union = predictions.sum(dim=2) + targets.sum(dim=2) - intersection

        iou = (intersection + self.smooth) / (union + self.smooth)

        return 1.0 - iou.mean()

class HausdorffLoss(nn.Module):
    """Symmetric Hausdorff distance between thresholded masks.

    Both inputs are binarized with ``threshold``. Euclidean distance transforms
    are then computed with SciPy on the CPU. The returned value therefore
    carries no gradient with respect to ``predictions``: it is usable as an
    evaluation metric, but cannot drive training on its own.
    """

    def __init__(self, threshold: float = 0.5):
        super(HausdorffLoss, self).__init__()
        # Values strictly greater than this are treated as foreground.
        self.threshold = threshold

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: (B, C, H, W) - predicted segmentation masks
            targets: (B, C, H, W) - ground truth segmentation masks
        Returns:
            Scalar float64 tensor: Hausdorff distance in pixels, averaged over
            all (sample, channel) pairs.
            - A pair where both masks are empty contributes 0.
            - A pair where exactly one mask is empty contributes the image
              diagonal, the largest distance possible in an H x W grid.
        """
        pred_binary = (predictions > self.threshold).float()
        target_binary = (targets > self.threshold).float()

        batch_size, num_classes = predictions.size(0), predictions.size(1)
        height, width = predictions.size(-2), predictions.size(-1)
        max_distance = float(height ** 2 + width ** 2) ** 0.5

        hausdorff_loss = torch.zeros((), dtype=torch.float64, device=predictions.device)

        for b in range(batch_size):
            for c in range(num_classes):
                pred_c = pred_binary[b, c]
                target_c = target_binary[b, c]
                pred_empty = not bool(pred_c.any())
                target_empty = not bool(target_c.any())

                if pred_empty and target_empty:
                    continue  # both empty: identical masks, distance 0
                if pred_empty or target_empty:
                    # An empty mask has no foreground to measure distances to,
                    # so the distance transform would be meaningless; charge the
                    # largest possible distance instead.
                    hausdorff_loss = hausdorff_loss + max_distance
                    continue

                pred_dt = self._distance_transform(pred_c)
                target_dt = self._distance_transform(target_c)

                # h1: farthest target pixel from the prediction; h2: the reverse.
                h1 = torch.max(pred_dt * target_c)
                h2 = torch.max(target_dt * pred_c)
                hausdorff_loss = hausdorff_loss + torch.max(h1, h2)

        num_pairs = batch_size * num_classes
        if num_pairs == 0:
            return hausdorff_loss
        return hausdorff_loss / num_pairs

    def _distance_transform(self, mask: torch.Tensor) -> torch.Tensor:
        """Distance from every pixel to the nearest foreground pixel of ``mask``.

        Foreground pixels map to 0. ``mask`` is expected to contain at least one
        foreground pixel; otherwise the result is not meaningful.
        """
        from scipy.ndimage import distance_transform_edt

        mask_np = mask.detach().cpu().numpy()

        # EDT measures distance to the nearest zero, so invert the mask to make
        # the foreground pixels the zeros.
        dt = distance_transform_edt(1 - mask_np)

        return torch.from_numpy(dt).to(mask.device)

def create_loss_function(loss_type: str = 'combined', **kwargs) -> nn.Module:
    """Instantiate a loss module by name.

    Args:
        loss_type: one of 'dice', 'boundary', 'tversky', 'focal', 'combined',
            'iou' or 'hausdorff'.
        **kwargs: forwarded unchanged to the selected class's constructor.
    Returns:
        A freshly constructed loss ``nn.Module``.
    Raises:
        ValueError: if ``loss_type`` is not one of the names above.
    """
    registry = {
        'dice': DiceLoss,
        'boundary': BoundaryLoss,
        'tversky': TverskyLoss,
        'focal': FocalLoss,
        'combined': CombinedLoss,
        'iou': IoULoss,
        'hausdorff': HausdorffLoss,
    }

    loss_cls = registry.get(loss_type)
    if loss_cls is None:
        raise ValueError(f"Unknown loss type: {loss_type}")
    return loss_cls(**kwargs)

if __name__ == "__main__":
    # Smoke-test the loss functions on random data.
    batch_size, num_classes, height, width = 2, 5, 128, 128

    # Raw network outputs (logits) and random binary ground truth.
    logits = torch.randn(batch_size, num_classes, height, width)
    targets = torch.randint(0, 2, (batch_size, num_classes, height, width)).float()

    # Dice / boundary / Tversky / IoU operate on probabilities, whereas the
    # focal loss applies its own sigmoid and must receive logits.
    probabilities = torch.sigmoid(logits)

    # Test individual losses
    dice_loss = DiceLoss()
    boundary_loss = BoundaryLoss()
    tversky_loss = TverskyLoss()
    focal_loss = FocalLoss()
    iou_loss = IoULoss()

    print(f"Dice Loss: {dice_loss(probabilities, targets):.4f}")
    print(f"Boundary Loss: {boundary_loss(probabilities, targets):.4f}")
    print(f"Tversky Loss: {tversky_loss(probabilities, targets):.4f}")
    print(f"Focal Loss: {focal_loss(logits, targets):.4f}")
    print(f"IoU Loss: {iou_loss(probabilities, targets):.4f}")

    # CombinedLoss takes logits and applies the sigmoid itself.
    combined_loss = CombinedLoss(dice_weight=0.7, boundary_weight=0.3)
    total_loss, loss_dict = combined_loss(logits, targets)

    print(f"Combined Loss: {total_loss:.4f}")
    print(f"Loss Dictionary: {loss_dict}")

    # Test factory function
    loss_fn = create_loss_function('combined', dice_weight=0.8, boundary_weight=0.2)
    print(f"Factory Loss: {loss_fn(logits, targets)[0]:.4f}")

