"""Anomaly map computation for U-Flow model.

This module implements functionality to generate anomaly heatmaps from the latent
variables produced by a U-Flow model. The anomaly maps are generated by:

1. Computing per-scale likelihoods from latent variables
2. Upscaling likelihoods to original image size
3. Combining multiple scale likelihoods

Example:
    >>> import torch
    >>> from anomalib.models.image.uflow.anomaly_map import AnomalyMapGenerator
    >>> generator = AnomalyMapGenerator(input_size=(256, 256))
    >>> latent_vars = [torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)]
    >>> anomaly_map = generator(latent_vars)

See Also:
    - :class:`AnomalyMapGenerator`: Main class for generating anomaly maps
    - :meth:`AnomalyMapGenerator.compute_anomaly_map`: Method that generates anomaly maps from latents
"""

# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import scipy.stats as st
import torch
import torch.nn.functional as F  # noqa: N812
from mpmath import binomial, mp
from omegaconf import ListConfig
from scipy import integrate
from torch import Tensor, nn

# Working precision (in decimal digits) of mpmath's global ``mp`` context. It is only
# relevant to the NFA computation when ``high_precision=True``. NOTE: this is an
# import-time side effect on the shared mpmath context, so it also applies to any other
# code in the process that uses ``mpmath.mp``.
mp.dps = 15


class AnomalyMapGenerator(nn.Module):
    """Generate anomaly heatmaps and segmentation masks from U-Flow latent variables.

    This class implements functionality to generate anomaly maps by analyzing the latent
    variables produced by a U-Flow model. The anomaly maps can be generated in two ways:

    1. Using likelihood-based scoring (default method, used by :meth:`forward`):
        - Computes per-scale likelihoods from latent variables
        - Upscales likelihoods to original image size
        - Combines multiple scale likelihoods via averaging

    2. Using NFA-based segmentation (optional method):
        - Applies binomial testing on local windows
        - Computes Number of False Alarms (NFA) statistics
        - Generates binary segmentation masks (``True`` marks anomalous pixels)

    Args:
        input_size (ListConfig | tuple): Size of input images as ``(height, width)``

    Example:
        >>> import torch
        >>> from anomalib.models.image.uflow.anomaly_map import AnomalyMapGenerator
        >>> generator = AnomalyMapGenerator(input_size=(256, 256))
        >>> latents = [torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)]
        >>> anomaly_map = generator(latents)
        >>> anomaly_map.shape
        torch.Size([1, 1, 256, 256])

    See Also:
        - :meth:`compute_anomaly_map`: Main method for likelihood-based maps
        - :meth:`compute_anomaly_mask`: Optional method for NFA-based segmentation
    """

    def __init__(self, input_size: ListConfig | tuple) -> None:
        super().__init__()
        # Always stored as a plain tuple so it can be passed directly to ``F.interpolate``
        self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size)

    def forward(self, latent_variables: list[Tensor]) -> Tensor:
        """Generate anomaly map from latent variables.

        Args:
            latent_variables (list[Tensor]): List of latent tensors from U-Flow model

        Returns:
            Tensor: Anomaly heatmap of shape ``(batch_size, 1, height, width)``
        """
        return self.compute_anomaly_map(latent_variables)

    def compute_anomaly_map(self, latent_variables: list[Tensor]) -> Tensor:
        """Generate likelihood-based anomaly map from latent variables.

        The method:
        1. Computes per-scale likelihoods from latent variables
        2. Upscales each likelihood map to input image size
        3. Combines scale likelihoods via averaging

        Args:
            latent_variables (list[Tensor]): List of latent tensors from U-Flow model,
                each with shape ``(batch_size, channels, height, width)``

        Returns:
            Tensor: Anomaly heatmap of shape ``(batch_size, 1, height, width)``. Each value
                is one minus the mean per-scale likelihood, so higher means more anomalous.
        """
        likelihoods = []
        for z in latent_variables:
            # Mean prob by scale. Likelihood is actually with sum instead of mean. Using mean to avoid numerical issues.
            # Also, this way all scales have the same weight, and it does not depend on the number of channels
            log_prob_i = -torch.mean(z**2, dim=1, keepdim=True) * 0.5
            prob_i = torch.exp(log_prob_i)
            likelihoods.append(
                F.interpolate(
                    prob_i,
                    size=self.input_size,
                    mode="bilinear",
                    align_corners=False,
                ),
            )
        # Stack the scales on a trailing axis and average them away again
        return 1 - torch.mean(torch.stack(likelihoods, dim=-1), dim=-1)

    def compute_anomaly_mask(
        self,
        z: list[torch.Tensor],
        window_size: int = 7,
        binomial_probability_thr: float = 0.5,
        high_precision: bool = False,
    ) -> torch.Tensor:
        """Generate NFA-based anomaly segmentation mask from latent variables.

        This optional method implements the Number of False Alarms (NFA) approach from
        the U-Flow paper. It is slower than the default likelihood method but provides
        unsupervised binary segmentation.

        The method:
        1. Applies binomial testing on local windows around each pixel
        2. Computes NFA statistics based on concentration of candidate pixels
        3. Generates binary segmentation mask

        A pixel is flagged when its NFA is below 1, i.e. ``log10(NFA) < 0``, which is the
        same as the anomaly score ``-log10(NFA)`` being positive.

        Args:
            z (list[torch.Tensor]): List of latent tensors from U-Flow model
            window_size (int, optional): Size of local window for binomial test at the
                first scale; it is halved for every subsequent entry of ``z``.
                Defaults to ``7``.
            binomial_probability_thr (float, optional): Probability threshold for
                binomial test. Defaults to ``0.5``.
            high_precision (bool, optional): Whether to use high precision NFA
                computation. Slower but more accurate. Defaults to ``False``.

        Returns:
            torch.Tensor: Boolean anomaly mask of shape ``(batch_size, 1, height,
                width)`` where ``True`` marks anomalous pixels. The binomial test runs
                through NumPy/SciPy, so the mask is a CPU tensor regardless of the
                device of ``z``.
        """
        # One log10 tail-probability map per scale; the window shrinks with the feature map
        log_prob_l = [
            self.binomial_test(zi, window_size / (2**scale), binomial_probability_thr, high_precision)
            for scale, zi in enumerate(z)
        ]

        # A window made only of candidates has tail probability 0, i.e. log10 = -inf.
        # Bicubic interpolation multiplies neighbours by zero and negative weights, which
        # would turn -inf into NaN and silently drop the strongest detections. Floor the
        # values at log10 of the smallest positive normal float64 instead.
        log_prob_floor = float(np.log10(np.finfo(np.float64).tiny))

        log_prob_l_up = torch.cat(
            [
                F.interpolate(
                    torch.clamp(lpl, min=log_prob_floor),
                    size=self.input_size,
                    mode="bicubic",
                    align_corners=True,
                )
                for lpl in log_prob_l
            ],
            dim=1,
        )

        # Scales are combined by summing their log-probabilities
        log_prob = torch.sum(log_prob_l_up, dim=1, keepdim=True)

        # Number of tests = total number of spatial positions over all scales
        log_number_of_tests = torch.log10(torch.sum(torch.tensor([zi.shape[-2] * zi.shape[-1] for zi in z])))
        log_nfa = log_number_of_tests + log_prob

        anomaly_score = -log_nfa

        # Detection rule: NFA < 1  <=>  log_nfa < 0  <=>  anomaly_score > 0
        return anomaly_score > 0

    @staticmethod
    def binomial_test(
        z: torch.Tensor,
        window_size: int | float,
        probability_thr: float,
        high_precision: bool = False,
    ) -> torch.Tensor:
        """Apply binomial test to validate/reject normality hypothesis.

        For each pixel, tests the null hypothesis that the pixel and its local
        neighborhood are normal against the alternative that they are anomalous.

        The test:
        1. Counts anomalous pixels in local window using chi-square threshold
        2. Compares observed count to expected count under null hypothesis
        3. Returns log probability of observing such extreme counts

        Args:
            z (torch.Tensor): Latent tensor of shape ``(batch_size, channels,
                height, width)``
            window_size (int | float): Size of local window for counting. Fractional
                values are accepted; the half window is ``int(window_size // 2)`` and
                never smaller than 1, so the effective window is at least 3x3. Reflect
                padding requires both spatial dimensions of ``z`` to be larger than the
                half window.
            probability_thr (float): Probability threshold for chi-square test
            high_precision (bool, optional): Whether to use high precision
                computation. Defaults to ``False``.

        Returns:
            torch.Tensor: Base-10 log probability tensor of shape ``(batch_size, 1,
                height, width)``, created on the CPU. Values can be ``-inf`` when
                the tail probability is exactly zero.
        """
        # A squared latent value is a "candidate" when it reaches this chi-square (1 dof) quantile
        tau = st.chi2.ppf(probability_thr, 1)
        # Builtin ``max`` keeps this a plain Python int for ``F.pad`` / ``unfold``
        half_win = max(int(window_size // 2), 1)
        win = 2 * half_win + 1

        n_chann = z.shape[1]

        # Candidates: slide a (win x win) window over the reflect-padded squared latents
        z2 = F.pad(z**2, (half_win,) * 4, "reflect").detach().cpu()
        z2_unfold_h = z2.unfold(-2, win, 1)
        z2_unfold_hw = z2_unfold_h.unfold(-2, win, 1).numpy()
        observed_candidates_k = np.sum(z2_unfold_hw >= tau, axis=(-2, -1))

        # All volume together: average candidate count per window across channels
        observed_candidates = np.sum(observed_candidates_k, axis=1, keepdims=True)
        x = observed_candidates / n_chann
        n = win**2

        # Low precision
        if not high_precision:
            # Natural-log survival function converted to base 10
            log_prob = torch.tensor(st.binom.logsf(x, n, 1 - probability_thr) / np.log(10))
        # High precision - good and slow
        else:
            to_mp = np.frompyfunc(mp.mpf, 1, 1)
            mpn = mp.mpf(n)
            mpp = probability_thr

            def binomial_density(k: float) -> float:
                # Binomial pmf with success probability ``1 - probability_thr``, extended to real ``k``
                return binomial(mpn, to_mp(k)) * (1 - mpp) ** k * mpp ** (mpn - k)

            def integral(lower: float) -> float:
                # Upper-tail mass, approximated by integrating the density from ``lower`` to ``n``
                return integrate.quad(binomial_density, lower, n)[0]

            integral_array = np.vectorize(integral)
            prob = integral_array(x)
            log_prob = torch.tensor(np.log10(prob))

        return log_prob
