import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.special


class HLGaussLoss(nn.Module):
    """Cross-entropy between logits and a Gaussian-smoothed histogram of a scalar target.

    The interval ``[min_value, max_value]`` is split into ``num_bins`` equal-width
    bins. Each scalar target is turned into a categorical distribution over those
    bins by integrating a normal density with mean ``target`` and standard
    deviation ``sigma`` over every bin, renormalised so the mass inside the
    interval sums to 1. The loss is the cross-entropy of ``logits`` against that
    distribution.

    Args:
        min_value: Lower edge of the first bin.
        max_value: Upper edge of the last bin.
        num_bins: Number of bins (the class dimension of ``logits``).
        sigma: Standard deviation of the smoothing Gaussian, in target units.
    """

    # Normal CDF is 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2)))); the 0.5*(1 + .)
    # affine part cancels in the bin differences and the normaliser below.
    _SQRT2 = 2.0 ** 0.5

    def __init__(self, min_value: float, max_value: float, num_bins: int, sigma: float):
        super().__init__()
        self.min_value = min_value
        self.max_value = max_value
        self.num_bins = num_bins
        self.sigma = sigma
        # Bin edges (num_bins + 1 of them). Registered as a non-persistent buffer so
        # that ``module.to(device)`` / ``.cuda()`` / ``.double()`` move it together
        # with the module (a plain tensor attribute would stay on CPU), while the
        # module's ``state_dict`` keeps exactly the same keys as before.
        self.register_buffer(
            "support",
            torch.linspace(min_value, max_value, num_bins + 1, dtype=torch.float32),
            persistent=False,
        )

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean cross-entropy of ``logits`` against the smoothed target.

        Args:
            logits: Unnormalised scores with ``num_bins`` classes (class dim as
                expected by ``F.cross_entropy``).
            target: Scalar targets, one per row of ``logits``.

        Returns:
            0-dim tensor: the batch-mean cross-entropy.
        """
        return F.cross_entropy(logits, self.transform_to_probs(target))

    def transform_to_probs(self, target: torch.Tensor) -> torch.Tensor:
        """Map scalar targets to per-bin probabilities.

        Args:
            target: Tensor of scalar targets of any shape ``S``.

        Returns:
            Tensor of shape ``S + (num_bins,)`` whose last dimension sums to 1.

        Note:
            The normaliser is the Gaussian mass inside ``[min_value, max_value]``;
            for targets many ``sigma`` outside that range it can underflow to 0,
            which yields NaN probabilities.
        """
        cdf_evals = torch.special.erf(
            (target.unsqueeze(-1) - self.support) / (self._SQRT2 * self.sigma)
        )
        z = cdf_evals[..., -1] - cdf_evals[..., 0]
        bin_probs = cdf_evals[..., 1:] - cdf_evals[..., :-1]
        return bin_probs / z.unsqueeze(-1)

    def transform_from_probs(self, probs: torch.Tensor) -> torch.Tensor:
        """Return the expected target value: the probability-weighted mean of bin centres.

        Args:
            probs: Per-bin probabilities with ``num_bins`` entries in the last dim.

        Returns:
            Tensor with the last dimension of ``probs`` reduced away.
        """
        centers = (self.support[:-1] + self.support[1:]) / 2
        return torch.sum(probs * centers, dim=-1)


class TreeHLGaussLoss(nn.Module):
    """Gaussian-smoothed histogram cross-entropy that returns per-sample losses and logs metrics.

    Scalar targets are mapped to a categorical distribution over ``num_bins``
    equal-width bins spanning ``[min_value, max_value]`` by integrating a normal
    density centred on the target over each bin (renormalised to the in-range
    mass). Differences from a plain mean-reduced loss:

    * ``sigma`` is given in units of bin width; the stored ``self.sigma`` is
      ``sigma * bin_width`` in target units.
    * ``forward`` returns the unreduced per-sample cross-entropy.
    * Each ``forward`` call records batch-level metrics as Python floats in slot
      ``[1]`` of ``self.entropy``, ``self.cross_entropy`` and
      ``self.KL_divergence``, after shifting the previous values into slot ``[0]``.

    Args:
        min_value: Lower edge of the first bin.
        max_value: Upper edge of the last bin.
        num_bins: Number of bins (the class dimension of ``logits``).
        sigma: Smoothing standard deviation, in multiples of the bin width.
        device: Device on which the bin edges / centres / sigma are allocated.
    """

    # Normal CDF is 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2)))); the 0.5*(1 + .)
    # affine part cancels in the bin differences and the normaliser below.
    _SQRT2 = 2.0 ** 0.5

    def __init__(
        self, min_value: float, max_value: float, num_bins: int, sigma: float, device: torch.device
    ):
        super().__init__()
        self.min_value = min_value
        self.max_value = max_value
        self.num_bins = num_bins
        self.device = device
        # num_bins + 1 bin edges, created directly on the requested device.
        self.support = torch.linspace(self.min_value, self.max_value, self.num_bins + 1, dtype=torch.float32)
        self.support = self.support.to(device)
        self.centers = (self.support[:-1] + self.support[1:]) / 2
        self.bin_widths = self.support[1] - self.support[0]
        # Convert sigma from "bin widths" to target units (0-dim tensor on `device`).
        self.sigma = sigma * self.bin_widths
        # [previous, latest] batch metrics, updated by forward().
        self.entropy = [0, 0]
        self.cross_entropy = [0, 0]
        self.KL_divergence = [0, 0]

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return per-sample cross-entropy and record entropy / CE / KL metrics.

        NOTE(review): assumes 2-D ``logits`` of shape (batch, num_bins).
        ``F.cross_entropy`` treats dim 1 as the class dim while the softmax and
        ``transform_to_probs`` use the last dim, so only 2-D inputs are consistent.

        Args:
            logits: Unnormalised scores, ``num_bins`` classes.
            target: Scalar targets, one per row of ``logits``.

        Returns:
            Unreduced cross-entropy, one value per sample (differentiable w.r.t.
            ``logits``).
        """
        target_probs = self.transform_to_probs(target)
        # Computed once: returned for backprop, and its mean is logged (equal to
        # F.cross_entropy's default 'mean' reduction for probability targets).
        per_sample_ce = F.cross_entropy(logits, target_probs, reduction="none")

        self.update_entropy()
        # Metrics are logging-only: no autograd graph is needed for them.
        with torch.no_grad():
            # log_softmax is exact; log(softmax + eps) floors log-probs at log(eps)
            # and underestimates KL for confident wrong predictions.
            log_pred = F.log_softmax(logits.detach(), dim=-1)
            # entr(p) = -p * log(p) with entr(0) = 0, so no epsilon is needed.
            self.entropy[1] = torch.special.entr(target_probs).sum(dim=-1).mean().item()
            self.cross_entropy[1] = per_sample_ce.detach().mean().item()
            self.KL_divergence[1] = F.kl_div(log_pred, target_probs, reduction="batchmean").item()
            # For 2-D logits: cross_entropy == entropy + KL_divergence (up to
            # rounding) and KL_divergence >= 0.
        return per_sample_ce

    def transform_to_probs(self, target: torch.Tensor) -> torch.Tensor:
        """Map scalar targets to per-bin probabilities.

        Args:
            target: Tensor of scalar targets of any shape ``S``.

        Returns:
            Tensor of shape ``S + (num_bins,)`` whose last dimension sums to 1.

        Note:
            For targets many ``sigma`` outside ``[min_value, max_value]`` the
            in-range normaliser can underflow to 0, which yields NaN probabilities.
        """
        cdf_evals = torch.special.erf(
            (target.unsqueeze(-1) - self.support) / (self._SQRT2 * self.sigma)
        )
        z = cdf_evals[..., -1] - cdf_evals[..., 0]
        bin_probs = cdf_evals[..., 1:] - cdf_evals[..., :-1]
        return bin_probs / z.unsqueeze(-1)

    def transform_from_probs(self, probs: torch.Tensor) -> torch.Tensor:
        """Return the expected target value: the probability-weighted mean of bin centres."""
        return torch.sum(probs * self.centers, dim=-1)

    def update_entropy(self):
        """Shift the latest recorded metrics into the 'previous' slot (index 0).

        Called at the start of ``forward`` before slot ``[1]`` is overwritten.
        """
        self.entropy[0] = self.entropy[1]
        self.cross_entropy[0] = self.cross_entropy[1]
        self.KL_divergence[0] = self.KL_divergence[1]
