"""
Reconstruction loss functions for autoencoder and generative tasks.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union


class MSELoss(nn.Module):
    """
    Mean Squared Error loss for reconstruction.

    Args:
        reduction: Reduction mode forwarded to ``F.mse_loss``
            (``'mean'``, ``'sum'`` or ``'none'``).
        normalize: If True, ``pred`` and ``target`` are both divided by the
            standard deviation of ``target`` taken across dim 0 before the
            loss is computed.
    """

    def __init__(self, reduction: str = 'mean', normalize: bool = False):
        super().__init__()
        self.reduction = reduction
        self.normalize = normalize

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute MSE loss.

        Args:
            pred: [batch_size, *] - predicted values
            target: [batch_size, *] - target values

        Returns:
            MSE loss; a scalar for ``'mean'``/``'sum'`` reduction, an
            element-wise tensor for ``'none'``.
        """
        if self.normalize:
            num_samples = target.size(0) if target.dim() > 0 else 1
            # The sample variance across dim 0 is undefined for fewer than two
            # samples (torch.var returns NaN, which survives the clamp and
            # poisons the loss). In that case the inputs are left unscaled.
            if num_samples > 1:
                target_var = torch.var(target, dim=0, keepdim=True)
                # Clamp so features that are constant across the batch do not
                # cause a division by zero.
                target_std = torch.sqrt(torch.clamp(target_var, min=1e-8))
                pred = pred / target_std
                target = target / target_std

        return F.mse_loss(pred, target, reduction=self.reduction)


class L1Loss(nn.Module):
    """
    Mean absolute error (L1) reconstruction loss.

    Args:
        reduction: Reduction mode forwarded to ``F.l1_loss``.
        normalize: If True, ``pred`` and ``target`` are both divided by the
            mean absolute value of ``target`` taken across dim 0.
    """

    def __init__(self, reduction: str = 'mean', normalize: bool = False):
        super().__init__()
        self.normalize = normalize
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the L1 loss between a prediction and its target.

        Args:
            pred: [batch_size, *] - predicted values
            target: [batch_size, *] - target values

        Returns:
            L1 loss reduced according to ``self.reduction``
        """
        if not self.normalize:
            return F.l1_loss(pred, target, reduction=self.reduction)

        # Scale = mean |target| across dim 0, floored at 1e-8 so an all-zero
        # target column cannot trigger a division by zero.
        scale = target.abs().mean(dim=0, keepdim=True).clamp(min=1e-8)
        scaled_pred = pred / scale
        scaled_target = target / scale
        return F.l1_loss(scaled_pred, scaled_target, reduction=self.reduction)


class KLDivLoss(nn.Module):
    """
    KL divergence loss; a thin module wrapper around ``F.kl_div``.

    Args:
        reduction: Reduction mode forwarded to ``F.kl_div``.
        log_target: Forwarded to ``F.kl_div``; set to True when ``target`` is
            given as log probabilities.
    """

    def __init__(self, reduction: str = 'mean', log_target: bool = False):
        super().__init__()
        self.log_target = log_target
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the KL divergence between ``target`` and ``pred``.

        Args:
            pred: [batch_size, *] - predicted log probabilities
            target: [batch_size, *] - target probabilities or log probabilities

        Returns:
            KL divergence reduced according to ``self.reduction``
        """
        # Both configuration flags are passed straight through to F.kl_div.
        return F.kl_div(
            pred,
            target,
            reduction=self.reduction,
            log_target=self.log_target,
        )


class PerceptualLoss(nn.Module):
    """
    Perceptual loss using pre-trained features.

    The loss is a weighted sum of per-layer MSE values between the features
    the extractor produces for the prediction and for the target.

    Args:
        feature_extractor: Module mapping an image batch to a single feature
            tensor or to a list/tuple of feature tensors (one per layer).
        layer_weights: One weight per feature layer. When omitted (or empty),
            every layer returned by the extractor is weighted 1.0.
    """

    def __init__(self, feature_extractor: nn.Module, layer_weights: Optional[list] = None):
        super().__init__()
        self.feature_extractor = feature_extractor
        # Remember whether weights were supplied: without them the number of
        # layers is only known once the extractor has been run.
        self._uniform_weights = not layer_weights
        self.layer_weights = layer_weights or [1.0]

        # Freeze feature extractor. Only requires_grad is changed here; the
        # extractor's train/eval mode is left exactly as the caller set it.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    @staticmethod
    def _as_feature_list(features) -> list:
        """Return the extractor output as a list of per-layer tensors."""
        if isinstance(features, (list, tuple)):
            return list(features)
        return [features]

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute perceptual loss.

        Args:
            pred: [batch_size, channels, height, width] - predicted images
            target: [batch_size, channels, height, width] - target images

        Returns:
            Perceptual loss scalar

        Raises:
            ValueError: If explicit ``layer_weights`` were given and their
                count differs from the number of feature layers.
        """
        pred_features = self._as_feature_list(self.feature_extractor(pred))
        target_features = self._as_feature_list(self.feature_extractor(target))

        if self._uniform_weights:
            weights = [1.0] * len(pred_features)
        else:
            weights = self.layer_weights
            # zip() would otherwise silently ignore the surplus layers/weights.
            if len(weights) != len(pred_features):
                raise ValueError(
                    f"layer_weights has {len(weights)} entries but the feature "
                    f"extractor returned {len(pred_features)} feature layers"
                )

        # Weighted sum of the per-layer L2 (MSE) feature distances.
        total_loss = 0.0
        for pred_feat, target_feat, weight in zip(pred_features, target_features, weights):
            total_loss = total_loss + weight * F.mse_loss(pred_feat, target_feat)

        return total_loss


class SSIMLoss(nn.Module):
    """
    Structural Similarity Index (SSIM) loss.

    Args:
        window_size: Side length of the square Gaussian window.
        sigma: Standard deviation of the Gaussian window.
        data_range: Value range of the input images (e.g. 1.0 for images in
            [0, 1], 255.0 for 8-bit images). Scales the SSIM stabilising
            constants; the default of 1.0 keeps the previous behaviour.
    """

    def __init__(self, window_size: int = 11, sigma: float = 1.5, data_range: float = 1.0):
        super().__init__()
        self.window_size = window_size
        self.sigma = sigma
        self.data_range = data_range

    @staticmethod
    def _to_grayscale(img: torch.Tensor) -> torch.Tensor:
        """Collapse a 3-channel image to one channel with luma weights."""
        if img.size(1) != 3:
            return img
        return 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute SSIM loss.

        Args:
            pred: [batch_size, channels, height, width] - predicted images
            target: [batch_size, channels, height, width] - target images

        Returns:
            SSIM loss scalar (1 - SSIM)
        """
        # 3-channel inputs are reduced to grayscale; any other channel count
        # is compared channel by channel.
        pred = self._to_grayscale(pred)
        target = self._to_grayscale(target)

        ssim = self._ssim(pred, target)
        return 1 - ssim.mean()

    def _ssim(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the per-pixel SSIM map between two image batches."""
        # This is a simplified SSIM implementation
        # For production use, consider using torchmetrics or similar libraries

        channels = x.size(1)
        pad = self.window_size // 2

        # Build the window in the input's floating dtype so conv2d does not
        # fail on float64/half inputs, and replicate it once per channel so a
        # grouped (depthwise) convolution filters each channel independently.
        dtype = x.dtype if x.is_floating_point() else torch.float32
        window = self._create_window(self.window_size, self.sigma, x.device, dtype)
        window = window.expand(channels, 1, -1, -1).contiguous()

        def blur(img: torch.Tensor) -> torch.Tensor:
            return F.conv2d(img, window, padding=pad, groups=channels)

        # Local means
        mu1 = blur(x)
        mu2 = blur(y)
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local variances and covariance: E[ab] - E[a]E[b]
        sigma1_sq = blur(x * x) - mu1_sq
        sigma2_sq = blur(y * y) - mu2_sq
        sigma12 = blur(x * y) - mu1_mu2

        # SSIM stabilising constants, scaled by the data range
        C1 = (0.01 * self.data_range) ** 2
        C2 = (0.03 * self.data_range) ** 2

        # SSIM formula
        numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
        denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
        return numerator / denominator

    def _create_window(self, window_size: int, sigma: float, device: torch.device,
                       dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """
        Create Gaussian window for SSIM computation.

        Args:
            window_size: Side length of the square window.
            sigma: Standard deviation of the Gaussian.
            device: Device on which the window is created.
            dtype: Floating dtype of the returned window.

        Returns:
            Window of shape [1, 1, window_size, window_size] built as the
            outer product of a 1-D Gaussian normalised to sum to 1.
        """
        # Evaluate the Gaussian in float32/float64 and cast at the end, so
        # reduced-precision dtypes still get an accurately normalised window.
        compute_dtype = dtype if dtype in (torch.float32, torch.float64) else torch.float32
        coords = torch.arange(window_size, dtype=compute_dtype, device=device)
        coords -= window_size // 2

        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()

        window = g.unsqueeze(0) * g.unsqueeze(1)
        return window.unsqueeze(0).unsqueeze(0).to(dtype)


class CombinedReconstructionLoss(nn.Module):
    """
    Combined reconstruction loss with multiple components.

    The result is the weighted sum of the enabled terms (MSE, L1, perceptual,
    SSIM). A term is enabled when its weight is greater than zero and, for the
    perceptual and SSIM terms, when the corresponding loss module exists.

    Args:
        mse_weight: Weight of the MSE term.
        l1_weight: Weight of the L1 term.
        perceptual_weight: Weight of the perceptual term.
        ssim_weight: Weight of the SSIM term.
        normalize: Passed as ``normalize`` to the MSE and L1 losses.
        feature_extractor: Optional module used to build the perceptual loss
            when ``perceptual_weight > 0``. If omitted, ``perceptual_loss`` is
            left as ``None`` and the perceptual term is skipped until a loss
            module is assigned to ``self.perceptual_loss`` externally.
    """

    def __init__(self,
                 mse_weight: float = 1.0,
                 l1_weight: float = 0.1,
                 perceptual_weight: float = 0.0,
                 ssim_weight: float = 0.0,
                 normalize: bool = True,
                 feature_extractor: Optional[nn.Module] = None):
        super().__init__()
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight
        self.perceptual_weight = perceptual_weight
        self.ssim_weight = ssim_weight
        self.normalize = normalize

        # Initialize loss functions
        self.mse_loss = MSELoss(normalize=normalize)
        self.l1_loss = L1Loss(normalize=normalize)

        # The perceptual term needs a feature extractor; without one the
        # attribute stays None so that it can still be set externally.
        if perceptual_weight > 0 and feature_extractor is not None:
            self.perceptual_loss = PerceptualLoss(feature_extractor)
        else:
            self.perceptual_loss = None

        self.ssim_loss = SSIMLoss() if ssim_weight > 0 else None

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute combined reconstruction loss.

        Args:
            pred: [batch_size, *] - predicted values
            target: [batch_size, *] - target values

        Returns:
            Combined loss scalar. When no term is enabled, a zero-valued
            scalar tensor created from ``pred`` is returned instead of a
            Python float.
        """
        terms = []

        # MSE loss
        if self.mse_weight > 0:
            terms.append(self.mse_weight * self.mse_loss(pred, target))

        # L1 loss
        if self.l1_weight > 0:
            terms.append(self.l1_weight * self.l1_loss(pred, target))

        # Perceptual loss (skipped while no perceptual module is attached)
        if self.perceptual_weight > 0 and self.perceptual_loss is not None:
            terms.append(self.perceptual_weight * self.perceptual_loss(pred, target))

        # SSIM loss
        if self.ssim_weight > 0 and self.ssim_loss is not None:
            terms.append(self.ssim_weight * self.ssim_loss(pred, target))

        if not terms:
            # Keep the return type a tensor even when every term is disabled.
            return pred.new_zeros(())

        # Accumulate in the same order as the terms were evaluated.
        total_loss = terms[0]
        for term in terms[1:]:
            total_loss = total_loss + term
        return total_loss