import torch
import torch.nn.functional as F
from src.utils.ops import batch_gather

class Criterion(torch.nn.Module):
    """Weighted sum of several loss criteria.

    Each criterion is called as ``criterion(batch, preds)``. Its contribution
    to the total is ``weight * criterion(batch, preds)``.

    Args:
        weights: One scalar weight per criterion, in the same order.
        criterions: Loss callables taking ``(batch, preds)``.

    Raises:
        ValueError: If ``weights`` and ``criterions`` differ in length.
            Without this check ``zip`` would silently drop the unmatched tail.
    """

    def __init__(self, weights, criterions):
        super().__init__()
        if len(weights) != len(criterions):
            raise ValueError(
                f"got {len(weights)} weights for {len(criterions)} criterions"
            )
        self.weights = weights
        # NOTE(review): stored as a plain attribute, not an nn.ModuleList, so the
        # criteria are not registered as submodules (no .to()/state_dict/train()
        # propagation). That is harmless for parameter-free losses.
        self.criterions = criterions

    @staticmethod
    def _to_float(value):
        """Convert a scalar tensor or a Python number to a float."""
        return value.item() if torch.is_tensor(value) else float(value)

    @staticmethod
    def _unique_key(info, name):
        """Return ``name``, or ``name_<n>`` if ``name`` is already a key of ``info``."""
        key, n = name, 1
        while key in info:
            key = f"{name}_{n}"
            n += 1
        return key

    def forward(self, batch, preds):
        """Compute the weighted total loss.

        Returns:
            ``(total_loss, info)``. ``total_loss`` is the differentiable sum of
            the weighted losses, or the int ``0`` when there are no criteria.
            ``info`` maps each criterion's class name to its weighted loss as a
            float and also has a ``'total_loss'`` entry. When two criteria share
            a class name, the later one gets a ``_<n>`` suffix instead of
            overwriting the earlier entry.
        """
        total_loss = 0
        info = {}
        for weight, criterion in zip(self.weights, self.criterions):
            loss = weight * criterion(batch, preds)
            # Out-of-place add: an in-place `+=` would write later losses into
            # the first tensor's dtype instead of promoting.
            total_loss = total_loss + loss
            key = self._unique_key(info, criterion.__class__.__name__)
            info[key] = self._to_float(loss)
        info['total_loss'] = self._to_float(total_loss)
        return total_loss, info

class DepthLoss(torch.nn.Module):
    """Mean-squared error between ``preds['depth']`` and the ground-truth depths.

    The ground truth is ``batch['depths']`` gathered at
    ``batch['target_indices']`` via ``batch_gather``. Presumably this selects
    the supervised views; TODO confirm ``batch_gather`` semantics.
    """

    def forward(self, batch, preds):
        return F.mse_loss(
            preds['depth'],
            batch_gather(batch['depths'], batch['target_indices']),
            reduction='mean',
        )
    
def binary_cross_entropy(input, target):
    """Mean binary cross-entropy between probabilities ``input`` and ``target``.

    This is a hand-written replacement for ``F.binary_cross_entropy``, which is
    not numerically stable in mixed-precision training. Like the PyTorch
    version, each log term is clamped to >= -100. Inputs of exactly 0 or 1
    therefore give a finite loss instead of ``inf``, or ``nan`` from
    ``0 * -inf``.

    Note: only the forward value is protected. Gradients at exactly 0 or 1 may
    still be non-finite, so callers should keep clipping inputs away from the
    boundaries.

    Args:
        input: Predicted probabilities in [0, 1].
        target: Target probabilities, broadcastable against ``input``.

    Returns:
        Scalar tensor: the mean over all elements.
    """
    log_p = torch.log(input).clamp(min=-100)
    log_one_minus_p = torch.log(1 - input).clamp(min=-100)
    return -(target * log_p + (1 - target) * log_one_minus_p).mean()

class MaskLoss(torch.nn.Module):
    """Binary cross-entropy between ``preds['mask']`` and the ground-truth masks.

    The prediction is clamped to [1e-3, 1 - 1e-3] before the loss, which keeps
    the logs inside ``binary_cross_entropy`` away from log(0). The ground truth
    is ``batch['masks']`` gathered at ``batch['target_indices']``.
    """

    def forward(self, batch, preds):
        predicted = preds['mask']
        reference = batch_gather(batch['masks'], batch['target_indices'])
        safe_predicted = torch.clamp(predicted, 1e-3, 1.0 - 1e-3)
        return binary_cross_entropy(safe_predicted, reference)

class NormalLoss(torch.nn.Module):
    """Mean-squared error between ``preds['normal']`` and the ground-truth normals.

    The ground truth is ``batch['normals']`` gathered at
    ``batch['target_indices']``.
    """

    def forward(self, batch, preds):
        prediction = preds['normal']
        reference = batch_gather(batch['normals'], batch['target_indices'])
        loss = F.mse_loss(prediction, reference, reduction='mean')
        return loss

class EikonalLoss(torch.nn.Module):
    """Penalise L2 norms of ``preds['sdf_grad']`` that deviate from 1.

    The norm is taken over the last dimension. The loss is the mean of
    ``(||g|| - 1)^2`` over the remaining dimensions. ``batch`` is unused.
    """

    def forward(self, batch, preds):
        grad_norm = torch.linalg.vector_norm(preds['sdf_grad'], ord=2, dim=-1)
        residual = grad_norm - 1.
        return residual.square().mean()

class RGBLoss(torch.nn.Module):
    """Mean-squared error between ``preds['rgb']`` and the ground-truth images.

    The ground truth is ``batch['images']`` gathered at
    ``batch['target_indices']``. Presumably these are the target views' pixel
    colours; TODO confirm against ``batch_gather``.
    """

    def forward(self, batch, preds):
        predicted_rgb, target_rgb = (
            preds['rgb'],
            batch_gather(batch['images'], batch['target_indices']),
        )
        return F.mse_loss(predicted_rgb, target_rgb)