import torch

# Default compute device: use the GPU when PyTorch reports CUDA support, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'



def huber(s: torch.Tensor, epsilon: float = 0.01) -> torch.Tensor:
    """
    Apply the (scaled) Huber penalty to every element of a tensor.

    For |s| <= epsilon the penalty is s**2 / (2 * epsilon) (smooth, quadratic);
    beyond that it is |s| - epsilon / 2 (linear, so large residuals are not
    over-weighted). The two pieces meet continuously at |s| == epsilon.

    Args:
    - s (torch.Tensor): Values to penalize; the result has the same shape.
    - epsilon (float): Transition point between the quadratic and linear regimes. Defaults to 0.01.

    Returns:
    - torch.Tensor: Element-wise Huber penalty of 's'.

    Raises:
    - ValueError: If 's' is not a torch.Tensor, or if 'epsilon' <= 0.
    """
    if not isinstance(s, torch.Tensor):
        raise ValueError("Input 's' to function huber() must be a torch.Tensor")
    if epsilon <= 0:
        raise ValueError("'epsilon' in function huber() must be greater than zero.")

    magnitude = s.abs()
    quadratic = 0.5 * s.pow(2) / epsilon  # used inside the |s| <= epsilon band
    linear = magnitude - 0.5 * epsilon    # used outside the band
    return torch.where(magnitude <= epsilon, quadratic, linear)


def fast_huber_TV(x: torch.Tensor, alpha: float = 1, delta: float = 0.01) -> torch.Tensor:
    """
    Evaluate the Huber total-variation penalty of a batch of images.

    Forward differences are taken between neighbouring rows and neighbouring
    columns, each difference is passed through `huber`, and everything
    (over all batch items and channels) is summed into a single scalar.

    Args:
    - x (torch.Tensor): Images shaped (batch_size, channels, height, width).
    - alpha (float): Multiplier applied to the summed penalty. Defaults to 1.
    - delta (float): Huber threshold forwarded to `huber`. Defaults to 0.01.

    Returns:
    - torch.Tensor: Scalar tensor, alpha * (vertical TV + horizontal TV).
    """
    vertical_steps = x[:, :, 1:, :] - x[:, :, :-1, :]
    horizontal_steps = x[:, :, :, 1:] - x[:, :, :, :-1]

    total = huber(vertical_steps, delta).sum() + huber(horizontal_steps, delta).sum()
    return alpha * total


def fast_huber_grad(x: torch.Tensor, alpha: float = 1, delta: float = 0.01) -> torch.Tensor:
    """
    Compute the gradient of the Huber-based total variation (TV) regularization with respect to the input tensor.

    This is the analytic gradient of ``alpha * (sum huber(dy) + sum huber(dx))``, where
    dy / dx are forward differences along the last two axes. Each difference d
    contributes huber'(d) = d / max(|d|, delta) (d / delta in the quadratic region,
    sign(d) in the linear one). The adjoint of the forward-difference operator then
    scatters that value back onto the two pixels that formed d.

    Args:
    - x (torch.Tensor): Input tensor with shape (batch_size, channels, height, width).
      Any number of leading dimensions is accepted; the last two are treated as
      height and width.
    - alpha (float): Scaling factor for the total variation regularization. Defaults to 1.
    - delta (float): Threshold parameter for Huber loss; must be > 0. Defaults to 0.01.

    Returns:
    - torch.Tensor: The gradient, with the same shape, dtype and device as 'x'.

    Raises:
    - ValueError: If 'delta' is not positive. With delta <= 0 the zero-padded
      differences would produce 0 / 0 = NaN.
    """
    if delta <= 0:
        raise ValueError("'delta' in function fast_huber_grad() must be greater than zero.")

    lead_dims = tuple(x.shape[:-2])
    height, width = x.shape[-2:]

    # Forward differences, zero-padded at the far edge so they keep the shape of x.
    # The padding copies every leading dimension (batch AND channels) from x; a
    # hard-coded channel count of 1 would make torch.cat fail for multi-channel input.
    diff_y = torch.cat(
        (x[..., 1:, :] - x[..., :-1, :], x.new_zeros((*lead_dims, 1, width))),
        dim=-2,
    )
    diff_x = torch.cat(
        (x[..., :, 1:] - x[..., :, :-1], x.new_zeros((*lead_dims, height, 1))),
        dim=-1,
    )

    # Huber derivative d / max(|d|, delta); the padded entries become 0 / delta = 0.
    penult_y = diff_y / torch.clamp(diff_y.abs(), min=delta)
    penult_x = diff_x / torch.clamp(diff_x.abs(), min=delta)

    # Adjoint of the forward difference (negative divergence):
    #   out[0] = -p[0];  out[j] = p[j-1] - p[j] for j >= 1
    # (the last entry reduces to p[-2] because the padded p[-1] is 0).
    f1 = torch.cat(
        (-penult_y[..., :1, :], penult_y[..., :-1, :] - penult_y[..., 1:, :]),
        dim=-2,
    )
    f2 = torch.cat(
        (-penult_x[..., :, :1], penult_x[..., :, :-1] - penult_x[..., :, 1:]),
        dim=-1,
    )

    # Return the weighted sum of gradients along both directions
    return alpha * (f1 + f2)
