# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ...models.builder import LOSSES


@LOSSES.register_module()
class GradientLoss(nn.Module):
    """Multi-scale gradient-matching loss on log depth.

    At each of several subsampling strides, the log-depth residual
    ``log(pred + eps) - log(gt + eps)`` is formed. The absolute difference of
    that residual between pixels two rows apart and two columns apart is summed
    over pairs where both pixels are valid. The sum is divided by the number of
    valid pixels at that stride, and the per-stride terms are added together.

    NOTE(review): the original docstring read "Adapted from XXXX"; the upstream
    reference was never filled in.

    Args:
        valid_mask (bool): Whether to ignore invalid ground truth. When set, only
            pixels with gt > 0 contribute. Default: True.
        loss_weight (float): Weight multiplied onto the loss in ``forward``.
            Default: 1.0.
        max_depth (float, optional): When ``valid_mask`` is set, also ignore
            pixels with gt > max_depth. Default: None (no upper bound).
        loss_name (str): Stored as ``self.loss_name``. Presumably used by the
            caller as the key for this loss; not read inside this class.
            Default: "loss_grad".
    """

    def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"):
        super(GradientLoss, self).__init__()
        self.valid_mask = valid_mask
        self.loss_weight = loss_weight
        self.max_depth = max_depth
        self.loss_name = loss_name

        # Added inside both logs so zero depths map to a finite value instead of
        # log(0) = -inf (original note: "avoid grad explode").
        self.eps = 0.001

    def _valid_pixels(self, target):
        """Return a boolean mask of the ``target`` pixels that contribute to the loss."""
        if not self.valid_mask:
            return torch.ones_like(target, dtype=torch.bool)
        mask = target > 0
        if self.max_depth is not None:
            mask = mask & (target <= self.max_depth)
        return mask

    def gradientloss(self, input, target):
        """Compute the unweighted multi-scale gradient loss.

        The last two dimensions of ``input`` and ``target`` are treated as
        (height, width). Any leading dimensions are carried along and summed
        over. Earlier versions sliced and differenced the *first* two
        dimensions, which coincides with this behaviour only for 2-D inputs.

        Args:
            input (Tensor): Predicted depth. Assumed non-negative, since it goes
                through ``log(input + eps)``; confirm with callers.
            target (Tensor): Ground-truth depth; must broadcast with ``input``.

        Returns:
            Tensor: Sum over strides of the masked gradient term divided by the
            number of valid pixels at that stride. The count is clamped to at
            least 1, so a stride with no valid pixels contributes 0 instead of NaN.
        """
        # Same strides as the original ``[1] + [2 * i for i in range(1, 4)]``.
        # NOTE(review): this yields 1, 2, 4, 6 rather than a power-of-two
        # pyramid (1, 2, 4, 8); confirm that 6 is intended.
        strides = (1, 2, 4, 6)

        gradient_loss = 0
        for stride in strides:
            pred = input[..., ::stride, ::stride]
            gt = target[..., ::stride, ::stride]

            mask = self._valid_pixels(gt)
            if self.valid_mask:
                num_valid = mask.sum().clamp(min=1)
            else:
                num_valid = max(pred.numel(), 1)

            log_diff = torch.log(pred + self.eps) - torch.log(gt + self.eps)
            # masked_fill instead of multiplying by the mask: an invalid pixel
            # whose log is NaN/inf (e.g. negative gt) would otherwise survive,
            # because NaN * 0 is NaN.
            log_diff = log_diff.masked_fill(~mask, 0.0)

            # Differences between pixels two rows / two columns apart; a pair
            # counts only when both of its pixels are valid.
            v_grad = (log_diff[..., :-2, :] - log_diff[..., 2:, :]).abs()
            v_pair = mask[..., :-2, :] & mask[..., 2:, :]
            h_grad = (log_diff[..., :, :-2] - log_diff[..., :, 2:]).abs()
            h_pair = mask[..., :, :-2] & mask[..., :, 2:]

            gradient_loss += ((h_grad * h_pair).sum() + (v_grad * v_pair).sum()) / num_valid

        return gradient_loss

    def forward(self, depth_pred, depth_gt):
        """Return ``loss_weight`` times the multi-scale gradient loss of ``depth_pred`` vs ``depth_gt``."""
        gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt)
        return gradient_loss
