import torch as th
import torchvision as tv
from torch import nn
from utils.utils import BatchToSharedObjects, SharedObjectsToBatch, LambdaModule, RGB2YCbCr
from einops import rearrange, repeat, reduce
from utils.optimizers import Ranger
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.optimize import linear_sum_assignment


class LatentDisentanglingLoss(nn.Module):
    """Penalise correlated latent codes of objects that are spatially close.

    For every pair of objects the correlation of their (gestalt + last position
    channel) codes is multiplied by a distance-based weight and by a
    per-object visibility flag; the mean of the squared off-diagonal entries
    is returned.
    """

    def __init__(self, num_objects, sigma=0.5):
        super(LatentDisentanglingLoss, self).__init__()
        self.num_objects = num_objects
        self.sigma = sigma

    def distance_weights(self, positions):
        """Return pairwise weights ``exp(-d / (2 * sigma**2))`` of shape (B, N, N).

        ``positions`` is (B, N, D).  The last coordinate is appended a second
        time, so it counts twice in the squared Euclidean distance ``d**2``.
        """
        extended = th.cat((positions, positions[:, :, -1:]), dim=2)

        # (B, N, 1, D) - (B, 1, N, D) -> all pairwise offsets
        offsets = extended.unsqueeze(2) - extended.unsqueeze(1)
        sq_dist = (offsets ** 2).sum(dim=-1)

        # relu + epsilon keep sqrt finite and differentiable at zero distance
        dist = th.sqrt(th.relu(sq_dist) + 1e-8)
        return th.exp(-dist / (2 * self.sigma ** 2))

    def batch_covariance(self, slots):
        """Covariance over dim 1 of ``slots`` (B, C, O), returned as (B, O, O)."""
        centered = slots - slots.mean(dim=1, keepdim=True)
        return centered.transpose(1, 2).bmm(centered) / (slots.size(1) - 1)

    def batch_correlation(self, slots):
        """Pearson correlation matrix (B, O, O) of the columns of ``slots`` (B, C, O)."""
        cov = self.batch_covariance(slots)
        var = cov.diagonal(dim1=-2, dim2=-1)
        denom = th.sqrt(th.relu(var.unsqueeze(2) * var.unsqueeze(1)) + 1e-8)
        return cov / denom

    def forward(self, positions, gestalt, mask):
        """Return the mean squared weighted off-diagonal correlation.

        ``positions`` and ``gestalt`` are flattened per object as (B, O*C).
        ``mask`` is 4-D; its last channel is dropped before the per-channel max
        (presumably one channel per object plus a trailing extra channel —
        TODO confirm against the caller).
        """
        positions = rearrange(positions, 'b (o c) -> b o c', o=self.num_objects)
        gestalt = rearrange(gestalt, 'b (o c) -> b o c', o=self.num_objects)

        # append the last position channel to each object's code, features first
        codes = th.cat((gestalt, positions[:, :, -1:]), dim=2)
        slots = rearrange(codes, 'b o c -> b c o')

        weights = self.distance_weights(positions[:, :, :-1])
        correlation = self.batch_correlation(slots)

        visible = (reduce(mask[:, :-1], 'b c h w -> b c 1', 'max') > 0.5).float()

        similarity = correlation * weights * visible
        idx = th.arange(0, similarity.size(-1))
        similarity[:, idx, idx] = 0.0  # ignore self-correlation

        return th.mean(similarity ** 2)

class GestaltLoss(nn.Module):
    """Warm-up regulariser: mean absolute deviation of ``gestalt`` from 0.5.

    The penalty is scaled by ``max(0, min(1, 0.1 ** (step / warmup_period - 1)))``.
    That means full strength up to step ``warmup_period``, then a tenfold decay per
    further ``warmup_period`` steps.  From step ``30 * warmup_period`` on, a
    constant zero tensor is returned instead.  Every call advances the
    ``num_updates`` buffer by one.
    """

    def __init__(self, warmup_period=1000.0):
        super(GestaltLoss, self).__init__()
        self.register_buffer('num_updates', th.tensor(0.0))
        self.warmup_period = warmup_period

    def forward(self, gestalt):
        if self.num_updates >= 30 * self.warmup_period:
            # schedule exhausted: constant zero that carries no graph
            loss = th.tensor(0.0, device=self.num_updates.device)
        else:
            exponent = self.num_updates.item() / self.warmup_period - 1
            scale = max(0.0, min(1, 0.1 ** exponent))
            loss = th.mean(th.abs(gestalt - 0.5)) * scale

        self.num_updates.add_(1.0)
        return loss

class DecayingFactor(nn.Module):
    """Step-counted multiplier that decays tenfold every ``warmup_period`` calls.

    ``get()`` yields ``max(min_factor, min(1, 0.1 ** (step / warmup_period - 1)))``
    (or one minus that value when ``inverse`` is set) and then advances the
    ``num_updates`` step buffer.  ``forward(x)`` returns ``x * get()``.
    """

    def __init__(self, warmup_period=2500.0, min_factor=0.01, inverse=False):
        super(DecayingFactor, self).__init__()
        self.register_buffer('num_updates', th.tensor(0.0))
        self.warmup_period = warmup_period
        self.min_factor = min_factor
        self.inverse = inverse

    def get(self):
        step = self.num_updates.item()
        decayed = min(1, 0.1 ** (step / self.warmup_period - 1))
        value = max(self.min_factor, decayed)
        self.num_updates.add_(1.0)
        return 1 - value if self.inverse else value

    def forward(self, x):
        return x * self.get()

class DecayingMSELoss(nn.Module):
    """Mean squared error scaled by a tenfold-per-``warmup_period`` decay.

    The scale is ``max(min_factor, min(1, 0.1 ** (step / warmup_period - 1)))``.
    Each call advances the ``num_updates`` step buffer by one.
    """

    def __init__(self, warmup_period=2500.0, min_factor=0.01):
        super(DecayingMSELoss, self).__init__()
        self.register_buffer('num_updates', th.tensor(0.0))
        self.warmup_period = warmup_period
        self.min_factor = min_factor

    def forward(self, pred, target):
        step = self.num_updates.item()
        decayed = min(1, 0.1 ** (step / self.warmup_period - 1))
        scale = max(self.min_factor, decayed)

        mse = th.mean((pred - target) ** 2)
        self.num_updates.add_(1.0)
        return mse * scale

class PositionLoss(nn.Module):
    """Penalise changes of object positions relative to the previous step.

    Each object is weighted by the running maximum (since ``reset_state``) of
    its peak mask value.  The compared quantities are position channels 0 and
    1 plus ``0.25 * sigmoid(channel 3)``; channel 2 is not used.
    """

    def __init__(self, num_objects: int, teacher_forcing: int):
        super(PositionLoss, self).__init__()

        # (B, O*C) -> (B*O, C)
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.last_mask = None
        self.t = 0
        self.teacher_forcing = teacher_forcing

    def reset_state(self):
        """Forget the accumulated mask and restart the step counter."""
        self.last_mask = None
        self.t = 0

    def forward(self, position, position_last, mask):
        """Return the weighted squared position change.

        During the first ``teacher_forcing`` calls (and on the first call
        after a reset) only the mask is recorded and ``zeros(1)`` is returned.
        ``mask`` is reduced over its last two dims (presumably (B, O, H, W) —
        TODO confirm) to one peak value per object.
        """
        peak = mask.max(dim=3).values.max(dim=2).values
        mask = self.to_batch(peak).detach()
        self.t += 1

        if self.last_mask is None or self.t <= self.teacher_forcing:
            self.last_mask = mask.detach()
            return th.zeros(1, device=mask.device)

        # per-object "has been visible" weight
        self.last_mask = th.maximum(self.last_mask, mask)

        def compared(p):
            return th.cat((p[:, :2], 0.25 * th.sigmoid(p[:, 3:4])), dim=1)

        current = compared(self.to_batch(position))
        previous = compared(self.to_batch(position_last).detach())

        return 0.01 * th.mean(self.last_mask * (current - previous) ** 2)


class MaskModulatedObjectLoss(nn.Module):
    """Compare squashed object tensors where an object was visible but is not now.

    The squared difference of ``sigmoid(x - 2.5)`` of both object tensors is
    weighted by ``(1 - current peak mask) * running max of past peak masks``.
    """

    def __init__(self, num_objects: int, teacher_forcing: int):
        super(MaskModulatedObjectLoss, self).__init__()

        self.to_batch  = SharedObjectsToBatch(num_objects)
        self.last_mask = None
        self.t = 0
        self.teacher_forcing = teacher_forcing

    def reset_state(self):
        """Forget the accumulated mask and restart the step counter."""
        self.last_mask = None
        self.t = 0

    def forward(
        self,
        object_output,
        object_target,
        mask: th.Tensor
    ):
        """Return the masked squared difference.

        During the first ``teacher_forcing`` calls (and on the first call after
        a reset) only the mask is recorded and ``zeros(1)`` is returned.
        """
        per_object = self.to_batch(mask).detach()
        # peak over dims 2 and 3, kept as singleton dims for broadcasting
        per_object = per_object.max(dim=3, keepdim=True).values
        mask = per_object.max(dim=2, keepdim=True).values
        self.t += 1

        if self.last_mask is None or self.t <= self.teacher_forcing:
            self.last_mask = mask.detach()
            return th.zeros(1, device=mask.device)

        self.last_mask = th.maximum(self.last_mask, mask).detach()

        squashed_out = th.sigmoid(self.to_batch(object_output) - 2.5)
        squashed_tgt = th.sigmoid(self.to_batch(object_target) - 2.5)
        weight = (1 - mask) * self.last_mask

        return th.mean(weight * (squashed_out - squashed_tgt) ** 2)

class ObjectModulator(nn.Module):
    """Blend new object states with the stored previous ones by visibility.

    With per-object peak mask ``m``, the gestalt becomes
    ``m * new + (1 - m) * previous``.  Only the last position channel is blended
    the same way; all other position channels are taken from the new input.
    The first call after a reset just stores and returns its inputs.
    """

    def __init__(self, num_objects: int):
        super(ObjectModulator, self).__init__()
        # (B, O*C) <-> (B*O, C)
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))
        self.position  = None
        self.gestalt   = None

    def reset_state(self):
        """Drop the stored previous position and gestalt."""
        self.position = None
        self.gestalt  = None

    def forward(self, position: th.Tensor, gestalt: th.Tensor, mask: th.Tensor):
        position = self.to_batch(position)
        gestalt = self.to_batch(gestalt)

        if self.position is None or self.gestalt is None:
            self.position = position.detach()
            self.gestalt = gestalt.detach()
            return self.to_shared(position), self.to_shared(gestalt)

        peak = mask.max(dim=3).values.max(dim=2).values
        visible = self.to_batch(peak.detach())
        hidden = 1 - visible

        blended_position = visible * position + hidden * self.position
        position = th.cat((position[:, :-1], blended_position[:, -1:]), dim=1)
        gestalt = visible * gestalt + hidden * self.gestalt

        self.gestalt = gestalt.detach()
        self.position = position.detach()
        return self.to_shared(position), self.to_shared(gestalt)

class MoveToCenter(nn.Module):
    """Translate per-object maps by their (x, y) position with ``grid_sample``.

    The output at grid location ``p`` samples the input at ``p + position[:, :2]``.
    This assumes position channels 0/1 are (x, y) in the normalised coordinates
    used by ``affine_grid``; in that case an object at that location is
    re-centred (TODO confirm).
    """

    def __init__(self, num_objects: int):
        super(MoveToCenter, self).__init__()

        self.to_batch2d = SharedObjectsToBatch(num_objects)
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))

    def forward(self, input: th.Tensor, position: th.Tensor):
        input = self.to_batch2d(input)
        translation = self.to_batch(position).detach()[:, :2].unsqueeze(-1)  # (N, 2, 1)

        # identity rotation/scale + per-sample translation -> (N, 2, 3)
        identity = th.eye(2, dtype=th.float, device=input.device).expand(input.shape[0], 2, 2)
        theta = th.cat((identity, translation), dim=2)

        grid = nn.functional.affine_grid(theta, input.shape, align_corners=False)
        return nn.functional.grid_sample(input, grid, align_corners=False)

class TranslationInvariantObjectLoss(nn.Module):
    """Compare two object renderings after moving each by its own position.

    Both objects go through ``sigmoid(x - 2.5)`` and ``MoveToCenter``.  Their
    squared difference is weighted by the running maximum (since
    ``reset_state``) of each object's peak mask value.
    """

    def __init__(self, num_objects: int, teacher_forcing: int):
        super(TranslationInvariantObjectLoss, self).__init__()

        self.move_to_center  = MoveToCenter(num_objects)
        self.to_batch        = SharedObjectsToBatch(num_objects)
        self.last_mask       = None
        self.t               = 0
        self.teacher_forcing = teacher_forcing

    def reset_state(self):
        """Forget the accumulated mask and restart the step counter."""
        self.last_mask = None
        self.t = 0

    def forward(
        self,
        mask: th.Tensor,
        object1: th.Tensor,
        position1: th.Tensor,
        object2: th.Tensor,
        position2: th.Tensor,
    ):
        """Return the weighted squared difference (``zeros(1)`` while teacher forcing)."""
        peak = self.to_batch(mask).detach()
        peak = peak.max(dim=3, keepdim=True).values.max(dim=2, keepdim=True).values
        self.t += 1

        if self.last_mask is None or self.t <= self.teacher_forcing:
            self.last_mask = peak.detach()
            return th.zeros(1, device=peak.device)

        self.last_mask = th.maximum(self.last_mask, peak).detach()

        centred1 = self.move_to_center(th.sigmoid(object1 - 2.5), position1)
        centred2 = self.move_to_center(th.sigmoid(object2 - 2.5), position2)

        return th.mean(self.last_mask * (centred1 - centred2) ** 2)

def depth_smooth_loss(depth, img):
    """Edge-aware smoothness penalty on a depth map.

    ``depth`` is first divided by its per-image, per-channel spatial mean.
    Absolute neighbour differences along width and height are then
    down-weighted by ``exp(-|image gradient|)``, where the image gradient is
    averaged over the channels of ``img``.  Returns the sum of the mean
    horizontal and mean vertical terms.
    """
    depth = depth / (depth.mean(dim=(2, 3), keepdim=True) + 1e-7)

    img_dx = (img[:, :, :, :-1] - img[:, :, :, 1:]).abs().mean(1, keepdim=True)
    img_dy = (img[:, :, :-1, :] - img[:, :, 1:, :]).abs().mean(1, keepdim=True)

    depth_dx = (depth[:, :, :, :-1] - depth[:, :, :, 1:]).abs() * th.exp(-img_dx)
    depth_dy = (depth[:, :, :-1, :] - depth[:, :, 1:, :]).abs() * th.exp(-img_dy)

    return depth_dx.mean() + depth_dy.mean()

class SSIM(nn.Module):
    """Per-pixel SSIM dissimilarity ``clamp((1 - SSIM) / 2, 0, 1)``.

    Statistics are computed over 3x3 windows after reflection padding by one
    pixel, so the output has the same spatial size as the inputs.  The
    stabilisers are ``C1 = 0.01**2`` and ``C2 = 0.03**2`` (these values assume
    inputs roughly in [0, 1]).
    """

    def __init__(self):
        super(SSIM, self).__init__()
        self.mu_x_pool = nn.AvgPool2d(3, 1)
        self.mu_y_pool = nn.AvgPool2d(3, 1)
        self.sig_x_pool = nn.AvgPool2d(3, 1)
        self.sig_y_pool = nn.AvgPool2d(3, 1)
        self.sig_xy_pool = nn.AvgPool2d(3, 1)

        self.refl = nn.ReflectionPad2d(1)

        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        px = self.refl(x)
        py = self.refl(y)

        mean_x = self.mu_x_pool(px)
        mean_y = self.mu_y_pool(py)

        var_x = self.sig_x_pool(px ** 2) - mean_x ** 2
        var_y = self.sig_y_pool(py ** 2) - mean_y ** 2
        cov_xy = self.sig_xy_pool(px * py) - mean_x * mean_y

        numerator = (2 * mean_x * mean_y + self.C1) * (2 * cov_xy + self.C2)
        denominator = (mean_x ** 2 + mean_y ** 2 + self.C1) * (var_x + var_y + self.C2)

        return ((1 - numerator / denominator) / 2).clamp(0, 1)

class MaskedL1SSIMLoss(nn.Module):
    """Mask-weighted blend of L1 and SSIM dissimilarity.

    Each term is summed over dims (1, 2, 3) and divided by the per-sample mask
    sum (+1e-7).  Returns ``(mean((1 - f) * l1 + f * ssim), mean(l1), mean(ssim))``
    with ``f = ssim_factor``.
    """

    def __init__(self, ssim_factor = 0.85):
        super(MaskedL1SSIMLoss, self).__init__()

        self.ssim = SSIM()
        self.ssim_factor = ssim_factor

    def forward(self, output, target, mask):
        reduce_dims = (1, 2, 3)
        weight_sum = mask.sum(dim=reduce_dims) + 1e-7

        per_sample_l1 = (th.abs(output - target) * mask).sum(dim=reduce_dims) / weight_sum
        per_sample_ssim = (self.ssim(output, target) * mask).sum(dim=reduce_dims) / weight_sum

        blend = self.ssim_factor
        combined = per_sample_l1 * (1 - blend) + per_sample_ssim * blend
        return combined.mean(), per_sample_l1.mean(), per_sample_ssim.mean()

class L1SSIMLoss(nn.Module):
    """Unmasked blend of L1 and SSIM dissimilarity.

    Returns ``(mean((1 - f) * |o - t| + f * ssim), mean(|o - t|), mean(ssim))``
    with ``f = ssim_factor``.
    """

    def __init__(self, ssim_factor = 0.85):
        super(L1SSIMLoss, self).__init__()

        self.ssim = SSIM()
        self.ssim_factor = ssim_factor

    def forward(self, output, target):
        abs_err = (output - target).abs()
        dissim = self.ssim(output, target)
        blend = self.ssim_factor

        combined = abs_err * (1 - blend) + dissim * blend
        return combined.mean(), abs_err.mean(), dissim.mean()

class RGBL1SSIMLoss(nn.Module):
    """Per-pixel error on RGB plus SSIM on a grey-scale version.

    Grey is ``0.299 R + 0.587 G + 0.114 B`` (BT.601 luma weights), taken from
    channels 0..2.  Returns ``(mean((1 - f) * err + f * ssim), mean(err),
    mean(ssim))`` with ``f = ssim_factor``.

    NOTE(review): despite the class and variable naming ("L1"), the pixel error
    term is the *squared* error ``(output - target) ** 2``.  Confirm this is
    intended before relying on the name.
    """

    def __init__(self, ssim_factor = 0.5):
        super(RGBL1SSIMLoss, self).__init__()

        self.ssim = SSIM()
        self.ssim_factor = ssim_factor

    def forward(self, output, target):
        def to_grey(img):
            return img[:, 0:1] * 0.299 + img[:, 1:2] * 0.587 + img[:, 2:3] * 0.114

        pixel_err = (output - target) ** 2
        dissim = self.ssim(to_grey(output), to_grey(target))
        blend = self.ssim_factor

        combined = pixel_err * (1 - blend) + dissim * blend
        return th.mean(combined), th.mean(pixel_err), th.mean(dissim)

class YCbCrL2SSIMLoss(nn.Module):
    """SSIM on channel 0 plus 10x MSE on channels 1 and 2 after ``RGB2YCbCr``.

    The target is converted without gradient.  Returns
    ``(total, chroma_part, luma_part)``.
    """

    def __init__(self):
        super(YCbCrL2SSIMLoss, self).__init__()

        self.to_YCbCr = RGB2YCbCr()
        self.ssim     = SSIM()

    def forward(self, x, y):
        x = self.to_YCbCr(x)

        with th.no_grad():
            y = self.to_YCbCr(y.detach()).detach()

        luma = self.ssim(x[:, 0:1], y[:, 0:1]).mean()
        chroma_b = ((x[:, 1] - y[:, 1]) ** 2).mean() * 10
        chroma_r = ((x[:, 2] - y[:, 2]) ** 2).mean() * 10

        return luma + chroma_b + chroma_r, chroma_b + chroma_r, luma

class MaskedYCbCrL2SSIMLoss(nn.Module):
    """Mask-weighted SSIM on channel 0 plus 10x MSE on channels 1 and 2.

    When ``rgb_input`` is set, both inputs are first converted with
    ``RGB2YCbCr`` (the target without gradient).  Each term is summed over
    dims (1, 2, 3) and divided by the per-sample mask sum (+1e-7).  Returns
    ``(mean(total), mean(chroma), mean(luma))``.
    """

    def __init__(self, rgb_input = True):
        super(MaskedYCbCrL2SSIMLoss, self).__init__()

        self.rgb_input = rgb_input
        self.to_YCbCr  = RGB2YCbCr()
        self.ssim      = SSIM()

    def forward(self, x, y, mask):
        if self.rgb_input:
            x = self.to_YCbCr(x)
            with th.no_grad():
                y = self.to_YCbCr(y.detach()).detach()

        dims = (1, 2, 3)
        weight_sum = th.sum(mask, dim=dims) + 1e-7

        luma = th.sum(self.ssim(x[:, 0:1], y[:, 0:1]) * mask, dim=dims) / weight_sum
        chroma_b = th.sum((x[:, 1:2] - y[:, 1:2]) ** 2 * 10 * mask, dim=dims) / weight_sum
        chroma_r = th.sum((x[:, 2:3] - y[:, 2:3]) ** 2 * 10 * mask, dim=dims) / weight_sum

        total = luma + chroma_b + chroma_r
        return total.mean(), (chroma_b + chroma_r).mean(), luma.mean()

class UncertaintyYCbCrL2SSIMLoss(nn.Module):
    """Per-pixel YCbCr error map (SSIM on channel 0 + 10x squared error on 1 and 2).

    Each term is scaled by ``uncertainty`` and the unreduced sum is returned.
    Both inputs are converted and detached under ``no_grad``, so the result
    carries no gradient to ``x`` or ``y``.
    """

    def __init__(self):
        super(UncertaintyYCbCrL2SSIMLoss, self).__init__()

        self.to_YCbCr = RGB2YCbCr()
        self.ssim     = SSIM()

    def forward(self, x, y, uncertainty):
        with th.no_grad():
            x = self.to_YCbCr(x.detach()).detach()
            y = self.to_YCbCr(y.detach()).detach()

        luma = self.ssim(x[:, 0:1], y[:, 0:1]) * uncertainty
        chroma_b = (x[:, 1:2] - y[:, 1:2]) ** 2 * 10 * uncertainty
        chroma_r = (x[:, 2:3] - y[:, 2:3]) ** 2 * 10 * uncertainty

        return luma + chroma_b + chroma_r

class MaskedMSELoss(nn.Module):
    """Mean over the batch of the mask-weighted squared error.

    For each sample, the weighted squared error is summed over dims (1, 2, 3)
    and divided by the mask sum (+1e-7).
    """

    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, output, target, mask):
        dims = (1, 2, 3)
        weighted_sse = (((output - target) ** 2) * mask).sum(dim=dims)
        weight_sum = mask.sum(dim=dims) + 1e-7
        return (weighted_sse / weight_sum).mean()

class UncertaintyGANLoss(nn.Module):
    """Masked L1/SSIM reconstruction loss with an adversarial certainty term.

    A pixel-wise two-class discriminator (class 0 = "certain", class 1 =
    "uncertain") is trained on detached reconstructions to reproduce the
    ``uncertainty`` map.  Once ``discriminator_start`` updates have been
    counted, the generator also receives ``d_weight * g_loss``.  Here ``g_loss``
    asks the discriminator to label uncertain pixels as certain, and
    ``d_weight`` is an adaptive weight computed from gradient norms at
    ``last_layer``.
    """

    # Adversarial log entries, in the order they appear in the returned log.
    _ADVERSARIAL_LOG_KEYS = (
        "d_weight", "d_loss", "g_loss",
        "acc_fake_certain", "acc_fake_uncertain",
        "loss_fake_certain", "loss_fake_uncertain",
    )

    def __init__(
        self,
        discriminator_start = 50001,
        in_channels = 3,
        base_channels = 16,
        blocks = None,
        discriminator_weight = 0.5,
        discriminator_lr = 4e-4,

    ):
        super(UncertaintyGANLoss, self).__init__()
        if blocks is None:
            # fresh list per instance instead of a shared mutable default
            blocks = [1, 2, 3, 4]

        # NOTE(review): ConvNeXtUnet is not imported in this module's import
        # block; presumably it lives elsewhere in the project and must be
        # imported, otherwise construction raises NameError.
        # The discriminator outputs raw per-pixel class logits: F.cross_entropy
        # applies log-softmax itself, so no nn.Softmax is appended.  (The
        # Sequential wrapper keeps the "discriminator.0.*" state_dict keys.)
        self.discriminator = nn.Sequential(
            ConvNeXtUnet(
                in_channels    = in_channels,
                out_channels   = 2,
                base_channels  = base_channels,
                blocks         = blocks,
            ),
        )

        self.l1ssim = MaskedL1SSIMLoss()

        self.register_buffer('num_updates', th.zeros(1))

        self.discriminator_start  = discriminator_start
        self.discriminator_weight = discriminator_weight
        self.optimizer            = Ranger(self.discriminator.parameters(), lr = discriminator_lr, weight_decay = 0.001)

        self.fake_target_certain = None
        self.fake_target_uncertain = None

    def init_targets(self, B, H, W, device):
        """Build constant (B, H, W) class-index maps: all 0 ("certain") and all 1 ("uncertain")."""
        self.fake_target_certain   = th.zeros((B, H, W), device=device, dtype=th.long) + 0
        self.fake_target_uncertain = th.zeros((B, H, W), device=device, dtype=th.long) + 1

    def _ensure_targets(self, B, H, W, device):
        """(Re)build the target maps when missing or when batch shape/device changed."""
        target = self.fake_target_certain
        if target is None or target.shape != (B, H, W) or target.device != device:
            self.init_targets(B, H, W, device)

    def calculate_adaptive_weight(self, rec_loss, g_loss, last_layer=None):
        """Return ``discriminator_weight * ||d rec / d last|| / (||d g / d last|| + 1e-6)``.

        The ratio is clamped to [0, 1e6] and detached.
        """
        rec_grads = th.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
        g_grads   = th.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = th.linalg.norm(rec_grads) / (th.linalg.norm(g_grads) + 1e-6)
        d_weight = th.clamp(d_weight, 0.0, 1e6).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def masked_cross_entropy(self, prediction, target, mask):
        """Per-pixel cross entropy weighted by ``mask`` (channel dim 1 squeezed), normalised by its sum."""
        return th.sum(
            F.cross_entropy(prediction, target, reduction='none') * mask.squeeze(dim=1).detach()
        ) / (th.sum(mask.detach()) + 1e-6)

    def masked_accuracy(self, prediction, target, mask):
        """Fraction of mask-weighted pixels whose argmax class equals ``target``."""
        return th.sum(
            (th.argmax(prediction, dim=1) == target).float() * mask.squeeze(dim=1)
        ) / (th.sum(mask) + 1e-6)

    def forward(self, inputs, reconstructions, uncertainty, last_layer):
        """Return ``(loss, log)``.

        The discriminator optimizer step happens inside this call once the
        schedule is active.  Adversarial log entries are 0 before that.
        """
        self.num_updates = self.num_updates.detach() + 1
        # Evaluate the schedule once so both phases and the log agree.
        adversarial = self.num_updates.item() >= self.discriminator_start

        # Caching targets from the first batch alone breaks on a different
        # batch size, resolution or device (e.g. a short final batch).
        self._ensure_targets(inputs.shape[0], *inputs.shape[2:], inputs.device)

        stats = {}

        # discriminator update
        if adversarial:
            self.optimizer.zero_grad()

            logits_fake = self.discriminator(reconstructions.detach())

            stats["acc_fake_certain"]    = self.masked_accuracy(logits_fake,      self.fake_target_certain,   (1 - uncertainty))
            stats["acc_fake_uncertain"]  = self.masked_accuracy(logits_fake,      self.fake_target_uncertain, uncertainty)
            stats["loss_fake_certain"]   = self.masked_cross_entropy(logits_fake, self.fake_target_certain,   (1 - uncertainty))
            stats["loss_fake_uncertain"] = self.masked_cross_entropy(logits_fake, self.fake_target_uncertain, uncertainty)

            d_loss = stats["loss_fake_certain"] + stats["loss_fake_uncertain"]

            d_loss.backward()
            self.optimizer.step()
            stats["d_loss"] = d_loss

        rec_loss, rec_loss_l1, rec_loss_ssim = self.l1ssim(inputs, reconstructions, (1 - uncertainty).detach())
        loss     = rec_loss

        if adversarial:

            # generator update: make uncertain pixels look certain
            logits_fake = self.discriminator(reconstructions)
            g_loss      = self.masked_cross_entropy(logits_fake, self.fake_target_certain, uncertainty)

            d_weight = self.calculate_adaptive_weight(rec_loss, g_loss, last_layer=last_layer)

            loss = rec_loss + d_weight * g_loss
            stats["g_loss"] = g_loss
            stats["d_weight"] = d_weight

        log = {
            "total_loss"          : loss.item(),
            "rec_loss"            : rec_loss.item(),
            "rec_loss_l1"         : rec_loss_l1.item(),
            "rec_loss_ssim"       : rec_loss_ssim.item(),
        }
        # "d_loss" previously appeared twice (first holding d_weight); one entry now.
        for key in self._ADVERSARIAL_LOG_KEYS:
            log[key] = stats[key].item() if adversarial else 0

        return loss, log

def bipartite_matching_mse(tensor1, tensor2):
    """Mean MSE after optimal one-to-one matching of set elements per batch item.

    :param tensor1: tensor of shape [batch, N, C]
    :param tensor2: tensor of shape [batch, M, C] (N and M may differ)
    :return: mean over the batch of the mean matched pairwise MSE.  The
        assignment is solved on a detached CPU copy with
        ``scipy.optimize.linear_sum_assignment``; the matched entries are then
        gathered from the differentiable cost matrix.
    """
    # cost[b, i, j] = mean_c (tensor1[b, i, c] - tensor2[b, j, c]) ** 2
    cost = th.mean((tensor1[:, :, None, :] - tensor2[:, None, :, :]) ** 2, dim=-1)
    cost_np = cost.detach().cpu().numpy()

    def matched_mean(b):
        rows, cols = linear_sum_assignment(cost_np[b])
        return cost[b, rows, cols].mean()

    per_item = [matched_mean(b) for b in range(tensor1.shape[0])]
    return th.stack(per_item).mean()


class LovaszHingeLoss(nn.Module):
    """Binary Lovasz hinge loss, averaged over the images of a batch.

    Lovasz-Softmax and Jaccard hinge loss in PyTorch
    Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
    """
    def __init__(self):
        super(LovaszHingeLoss, self).__init__()

    def forward(self, logits, labels):
        """Return the summed per-image Lovasz hinge losses divided by ``len(logits)``.

        Args:
            logits: iterable of per-image raw scores (each flattened internally).
            labels: matching iterable of binary ground-truth masks.  Bool,
                integer or floating 0/1 values are accepted.
        """
        total_loss = 0

        for logit, label in zip(logits, labels):
            logit = logit.reshape(-1)
            # Work in the logit's dtype: bool labels cannot be used in
            # arithmetic, and th.dot below requires matching dtypes.
            label = label.reshape(-1).to(logit.dtype)
            signs = 2. * label - 1.
            errors = 1. - logit * signs
            errors_sorted, perm = th.sort(errors, dim=0, descending=True)
            gt_sorted = label[perm]
            grad = self.lovasz_grad(gt_sorted)
            total_loss += th.dot(F.relu(errors_sorted), grad)

        return total_loss / len(logits)

    def lovasz_grad(self, gt_sorted):
        """Return successive differences of the Jaccard loss along ``gt_sorted``.

        These are the weights applied to the descending-sorted errors.
        Non-floating labels are converted to float32.
        """
        if not gt_sorted.is_floating_point():
            gt_sorted = gt_sorted.float()
        p = len(gt_sorted)
        gts = gt_sorted.sum()
        intersection = gts - gt_sorted.cumsum(0)
        union = gts + (1 - gt_sorted).cumsum(0)
        jaccard = 1. - intersection / union
        if p > 1: # cover 1-pixel case
            jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
        return jaccard

class YCbCrL1SSIMLoss(nn.Module):
    """L1/SSIM loss on channel 0 plus adaptively weighted MSE on channels 1 and 2.

    Both inputs are converted with ``RGB2YCbCr`` (the target without
    gradient).  Each chroma term gets a detached weight
    ``factor * y_loss / (chroma_loss + eps)``.  All three weights (luma uses
    ``factor_Y``) are then normalised to sum to one.  Returns
    ``(loss, l1, ssim)``, where ``l1`` and ``ssim`` come from the luma
    ``L1SSIMLoss``.
    """

    # class-level default so instances created without __init__ still work
    eps = 1e-8

    def __init__(self, factor_Y = 0.95, factor_Cb = 0.025, factor_Cr = 0.025, eps = 1e-8):
        super(YCbCrL1SSIMLoss, self).__init__()
        self.factor_Y  = factor_Y
        self.factor_Cb = factor_Cb
        self.factor_Cr = factor_Cr
        self.eps       = eps

        self.to_YCbCr = RGB2YCbCr()
        self.l1ssim   = L1SSIMLoss()

    def forward(self, x, y):
        x = self.to_YCbCr(x)

        with th.no_grad():
            y = self.to_YCbCr(y.detach()).detach()

        y_loss, l1, ssim = self.l1ssim(x[:,0:1], y[:,0:1])

        cb_loss = th.mean((x[:,1] - y[:,1])**2)
        cr_loss = th.mean((x[:,2] - y[:,2])**2)

        # Rescale each chroma term to the luma loss magnitude (detached, so the
        # weights act as constants).  eps keeps an exactly-zero chroma error
        # (e.g. identical grey-scale inputs) from producing inf / inf = NaN.
        cr_factor = (y_loss / (cr_loss + self.eps)).detach() * self.factor_Cr
        cb_factor = (y_loss / (cb_loss + self.eps)).detach() * self.factor_Cb

        sum_factors = cr_factor + cb_factor + self.factor_Y

        y_factor  = self.factor_Y  / sum_factors
        cb_factor = cb_factor / sum_factors
        cr_factor = cr_factor / sum_factors

        loss = y_loss * y_factor + cb_loss * cb_factor + cr_loss * cr_factor

        return loss, l1, ssim

class SobelEdgeLoss(nn.Module):
    """Match Sobel edge responses of a depth map and a guide image.

    Four fixed 3x3 kernels are applied to both inputs: horizontal, vertical
    and the two diagonals.  The per-direction L1 or L2 differences are
    summed.  The kernels are 1-in/1-out convolutions, so ``depth`` and ``rgb``
    must each have exactly one channel.  NOTE(review): despite its name,
    ``rgb`` cannot be a 3-channel tensor here.
    """

    def __init__(self, loss_type="L1"):
        super().__init__()

        # padding=0: forward() already replicate-pads by one pixel.  A second
        # (zero) padding would create an artificial step at the image border,
        # so e.g. constant inputs produced a non-zero loss.
        self.sobel_x = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)
        self.sobel_y = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)
        self.sobel_diag1 = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)
        self.sobel_diag2 = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)

        # fixed (non-trainable) edge kernels
        self.sobel_x.weight = nn.Parameter(th.Tensor([[[[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]]]), requires_grad=False)
        self.sobel_y.weight = nn.Parameter(th.Tensor([[[[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]]]), requires_grad=False)
        self.sobel_diag1.weight = nn.Parameter(th.Tensor([[[[-2., -1., 0.], [-1., 0., 1.], [0., 1., 2.]]]]), requires_grad=False)
        self.sobel_diag2.weight = nn.Parameter(th.Tensor([[[[0., 1., 2.], [-1., 0., 1.], [-2., -1., 0.]]]]), requires_grad=False)

        assert loss_type in ["L1", "L2"], 'loss_type should be either "L1" or "L2"'
        self.loss_type = loss_type

    def forward(self, depth, rgb):
        """Return the summed edge-map discrepancy (edge maps keep the input size)."""
        depth = F.pad(depth, (1, 1, 1, 1), mode='replicate')
        rgb = F.pad(rgb, (1, 1, 1, 1), mode='replicate')

        depth_edge_map_x = self.sobel_x(depth)
        depth_edge_map_y = self.sobel_y(depth)
        depth_edge_map_diag1 = self.sobel_diag1(depth)
        depth_edge_map_diag2 = self.sobel_diag2(depth)

        rgb_edge_map_x = self.sobel_x(rgb)
        rgb_edge_map_y = self.sobel_y(rgb)
        rgb_edge_map_diag1 = self.sobel_diag1(rgb)
        rgb_edge_map_diag2 = self.sobel_diag2(rgb)

        if self.loss_type == "L1":
            loss_x = F.l1_loss(depth_edge_map_x, rgb_edge_map_x)
            loss_y = F.l1_loss(depth_edge_map_y, rgb_edge_map_y)
            loss_diag1 = F.l1_loss(depth_edge_map_diag1, rgb_edge_map_diag1)
            loss_diag2 = F.l1_loss(depth_edge_map_diag2, rgb_edge_map_diag2)
        else:  # L2
            loss_x = F.mse_loss(depth_edge_map_x, rgb_edge_map_x)
            loss_y = F.mse_loss(depth_edge_map_y, rgb_edge_map_y)
            loss_diag1 = F.mse_loss(depth_edge_map_diag1, rgb_edge_map_diag1)
            loss_diag2 = F.mse_loss(depth_edge_map_diag2, rgb_edge_map_diag2)

        return loss_x + loss_y + loss_diag1 + loss_diag2
