import einops
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

class SpatiallyWeightedMSE:
    def __init__(self,
                 sigma=4,
                 share_of_total_loss=0.2,
                 kernel_size=3, # equals to just probe neighbors
                 ):
        '''
        MSE whose per-pixel weights emphasise the surroundings of a mask.

        The mask is Gaussian-blurred to obtain a "halo" around the masked
        regions. The masked pixels themselves are zeroed out again, so only
        the unmasked neighbourhood keeps the extra weight. The halo is then
        rescaled so that its sum equals `share_of_total_loss` times the sum
        of the base weights (1 for every unmasked pixel, 0 for masked ones).

        Args:
            sigma (float): Sigma for Gaussian blur.
            share_of_total_loss (float): Sum of the halo weights, expressed as
                a fraction of the summed base weights of the unmasked pixels.
            kernel_size (int): Kernel size for Gaussian blur.
        '''
        self.sigma = sigma
        self.share_of_total_loss = share_of_total_loss
        self.kernel_size = kernel_size

    def __call__(self, y_true, y_pred, mask):
        '''
        Compute the spatially weighted MSE between `y_true` and `y_pred`.

        Args:
            y_true (torch.Tensor): Target values. Must have the same shape as
                `y_pred` and broadcast against a weight map of shape
                (b, T, h, w, 1).
            y_pred (torch.Tensor): Predicted values.
            mask (torch.Tensor): 4-D tensor of shape (b, T, h, w). True /
                non-zero entries are excluded from the loss (weight 0).

        Returns:
            torch.Tensor: Scalar loss. NaN if every pixel is masked, because
            the weights then sum to zero.
        '''
        # Autocast is switched off so the weight computation is not down-cast.
        with torch.autocast(device_type=y_true.device.type, enabled=False):

            b, T, h, w = mask.shape
            # Boolean indexing needs a bool tensor; this also accepts 0/1
            # float or integer masks. Keep it on the same device as the data
            # so the halo and base weights can be combined.
            excluded = mask.to(device=y_true.device, dtype=torch.bool)

            # Blur the mask; (b T) is folded into the leading dimension.
            flat_mask = einops.rearrange(excluded, 'b T h w -> (b T) h w').float()
            halo = TF.gaussian_blur(flat_mask, kernel_size=self.kernel_size, sigma=self.sigma)
            halo = einops.rearrange(halo, '(b T) h w -> b T h w', T=T)
            # Masked pixels must not contribute, so remove the blur there again.
            halo = halo.masked_fill(excluded, 0.0)

            # Base weights: 1 for unmasked pixels, 0 for masked pixels.
            base = (~excluded).to(halo.dtype)

            # Scale the halo so that its sum is share_of_total_loss * base.sum().
            # When there is no halo at all (empty mask, or nothing left after
            # zeroing the masked pixels) the halo is all zeros; clamping the
            # denominator keeps it at zero instead of producing 0/0 = NaN.
            halo_total = halo.sum().clamp_min(torch.finfo(halo.dtype).tiny)
            halo = halo * (self.share_of_total_loss * base.sum()) / halo_total
            weight_map = base + halo

            # Normalise to sum 1 and add a trailing axis so the weights
            # broadcast over the last dimension of y_true / y_pred.
            weight_map = weight_map.unsqueeze(-1) / weight_map.sum()  # (b, T, h, w, 1)

            weighted_squared_error = numerically_save_weighted_mse(y_true, y_pred, weight_map)

            # The helper already returns a scalar; the mean leaves it unchanged.
            loss = torch.mean(weighted_squared_error)
            return loss



def numerically_save_weighted_mse(y_true: torch.Tensor, y_pred: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """
    Numerically stable weighted MSE in PyTorch.

    Computes ``sum(weights * (y_true - y_pred)**2) / sum(weights)``.

    The residual is formed before squaring. The algebraically equivalent
    expansion ``sum(w*y_true**2) - 2*sum(w*y_true*y_pred) + sum(w*y_pred**2)``
    subtracts large, nearly equal numbers and can lose all significant digits
    (or even turn negative) when the values are large compared to their
    difference.

    Args:
        y_true: Ground truth values (shape: [n, ...]).
        y_pred: Predicted values (same shape as ``y_true``).
        weights: Weights, broadcastable against ``y_true``. The normaliser is
            the sum of ``weights`` as passed, i.e. before broadcasting.

    Returns:
        Weighted MSE (scalar tensor).
    """
    # Ensure prediction and target have the same shape
    assert y_true.shape == y_pred.shape, "Shapes must match"

    sum_w = torch.sum(weights)

    # Squaring the difference is always >= 0 and avoids cancellation.
    squared_error = (y_true - y_pred) ** 2
    wmse = torch.sum(weights * squared_error) / sum_w

    return wmse


if __name__ == "__main__":
    # Smoke test: run the loss once on random data and print the result.
    batch_size, timesteps, height, width, channels = 2, 3, 64, 64, 1
    y_true = torch.rand(batch_size, timesteps, height, width, channels)
    y_pred = torch.rand(batch_size, timesteps, height, width, channels)

    # The loss unpacks the mask as (batch, timesteps, height, width) and uses
    # it for boolean indexing, so it has to be a 4-D bool tensor.
    mask = torch.zeros(batch_size, timesteps, height, width, dtype=torch.bool)
    mask[0, :, 20:40, 20:40] = True  # Example mask for the first sample

    loss_fn = SpatiallyWeightedMSE(sigma=4, share_of_total_loss=0.1, kernel_size=21)
    # The class is a plain callable (it defines __call__, not forward).
    loss = loss_fn(y_true, y_pred, mask)

    print(f"Computed loss: {loss.item():.4f}")
