import torch
from torch import nn, Tensor
import torch.nn.functional as F


class BernoulliEntropyPenalty(nn.Module):
    """Mean binary (Bernoulli) entropy of a tensor of probabilities.

    Each element contributes ``-p*log(p) - (1-p)*log(1-p)`` (natural log).
    This is largest at ``p = 0.5`` and close to zero at ``p = 0`` or ``p = 1``.

    Args:
        eps: Clamp margin applied to ``p`` inside the logarithms so that
            ``log(0)`` is never evaluated; the log arguments are kept in
            ``[eps, 1 - eps]``.
    """

    def __init__(self, eps: float):
        super().__init__()
        self.eps = eps

    def forward(self, p: Tensor):
        """Return the element-wise Bernoulli entropy of ``p``, averaged.

        Args:
            p: Probabilities; every element must lie in ``[0, 1]``.

        Returns:
            Scalar tensor: the mean of the per-element binary entropy.

        Raises:
            ValueError: If any element of ``p`` is outside ``[0, 1]`` or is NaN.
        """
        # Phrased as "not (in range)" rather than "out of range": min/max
        # propagate NaN and every comparison against NaN is False, so the
        # previous form let NaN inputs slip through and produce a NaN loss.
        if not (p.min() >= 0.0 and p.max() <= 1.0):
            raise ValueError("Expect p to be in [0, 1]")
        # Clamp only the log arguments; the linear weights keep the exact p so
        # that p == 0 or p == 1 contributes (near) zero entropy.
        p_eps = torch.clamp(p, min=self.eps, max=1 - self.eps)
        h = -p * torch.log(p_eps) - (1.0 - p) * torch.log(1 - p_eps)
        return h.mean()


def extract_3x3_patches(tensor):
    """Gather the zero-padded 3x3 neighbourhood of every pixel.

    Args:
        tensor: A 3-D tensor of shape ``(b, h, w)``; any other rank fails when
            unpacking the shape.

    Returns:
        Tensor of shape ``(b, 9, h, w)``. Channel ``k`` holds the value at row
        offset ``k // 3 - 1`` and column offset ``k % 3 - 1`` from each pixel,
        so channel 4 is the pixel itself. Positions outside the image read as 0.
    """
    batch, height, width = tensor.shape
    # Let unfold do the zero padding itself; this is equivalent to an explicit
    # constant F.pad of one pixel on every side, and keeps one output column
    # per input pixel (h * w columns).
    columns = F.unfold(tensor[:, None], kernel_size=3, stride=1, padding=1)
    return columns.view(batch, 9, height, width)


class NMSEntropyPenalty(nn.Module):
    """Entropy of each pixel's 3x3 neighbourhood, weighted by the pixel itself.

    For an input map of shape ``(bs, h, w)``, the following is done per pixel:

    1. The pixel's zero-padded 3x3 neighbourhood is L1-normalised into 9 weights.
    2. The Shannon entropy of those weights is taken (natural log).
    3. That entropy is multiplied by the pixel's own value.

    The penalty is the mean of these products over batch and pixels.

    Args:
        eps: Clamp margin that keeps the log arguments in ``[eps, 1 - eps]``.
    """

    def __init__(self, eps: float):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor):
        """Compute the neighbourhood-entropy penalty for ``x``.

        Args:
            x: Map of shape ``(bs, h, w)``. Presumably non-negative (e.g. a
                score map). NOTE(review): confirm this. With negative entries,
                the L1 normalisation (which uses absolute values) does not
                produce a probability distribution.

        Returns:
            Scalar tensor: the mean of ``x * H``, where ``H`` is the per-pixel
            neighbourhood entropy.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.ndim == 3:
            # (bs, 9, h, w): one L1-normalised 9-vector per pixel. An all-zero
            # neighbourhood stays zero and therefore contributes zero entropy.
            weights = F.normalize(extract_3x3_patches(x), p=1.0, dim=1)
            # Clamp only the log argument so log(0) is never evaluated.
            log_weights = weights.clamp(min=self.eps, max=1 - self.eps).log()
            entropy = (weights * log_weights).sum(dim=1).neg()
            return (x * entropy).mean()
        raise ValueError("Expect a tensor (bs, h, w)")
