from math import inf
from typing import Callable, Optional, Union

import torch
from scipy.special import gammaln
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from .certified_malconv import CertifiedMalConv

from ..utils import collate_pad


def combln(n: int, k: int) -> float:
    """Return the natural log of the binomial coefficient `ln (n! / (k! * (n - k)!))`.

    Each factorial is evaluated in log-space through `gammaln`, using
    `ln(m!) = gammaln(m + 1)`, so the result stays finite even when the
    coefficient itself would be too large to represent.

    Args:
        n: Total number of items.
        k: Number of items chosen.

    Returns:
        `ln(n!) - ln(k!) - ln((n - k)!)`.
    """
    log_n_factorial = gammaln(n + 1)
    log_k_factorial = gammaln(k + 1)
    log_rest_factorial = gammaln(n - k + 1)
    return log_n_factorial - log_k_factorial - log_rest_factorial


def brute_force_solve(f: Callable[[int], float], x_max: Optional[int] = None) -> int:
    """Find the largest non-negative integer argument of a decreasing function, such that its output remains positive

    Note:
    The solution is found by brute force: testing each value of the argument starting from 0.
    If `x_max` is None and `f` never becomes non-positive, the search does not terminate.

    Args:
        f: A real-valued decreasing function, whose domain is the non-negative integers.
        x_max: Upper bound (inclusive) on the domain of `f`. `None` means the domain is unbounded.
            A bound of 0 is honoured, i.e. only `f(0)` is evaluated.

    Returns:
        The argument that satisfies the constraint. If the function is never positive, a value of -1 is returned.
    """
    x = -1
    # Compare against None explicitly: a plain truthiness test would treat
    # `x_max == 0` as "no bound" and walk past the end of the domain.
    while (x_max is None or x + 1 <= x_max) and f(x + 1) > 0:
        x += 1
    return x


def _exponential_bound(
    f: Callable[[int], float], base: Optional[int] = 4, x_max: Optional[int] = None
) -> int:
    """Grow a candidate upper bound geometrically until `f` stops being positive.

    Starting from `x = 1`, `x` is multiplied by `base` for as long as `x` is
    within `x_max` (when given) and `f(x) > 0`. `f(0)` is never evaluated here.

    Note:
    If `x_max` is None and `f` is positive at every probed point, the loop
    does not terminate.

    Args:
        f: A real-valued decreasing function, whose domain is the non-negative integers.
        base: Growth factor, must be larger than 1. `None` selects the default of 4.
        x_max: Upper bound (inclusive) on the domain of `f`. `None` means unbounded.

    Returns:
        The first probed `x` with `f(x) <= 0`, or `x_max` if the probe ran past it.
        When the result is `x_max`, `f(x_max)` may still be positive.

    Raises:
        ValueError: If `base` is not larger than 1.
    """
    if base is None:
        # The annotation allows None; treat it as "use the default".
        base = 4
    if base <= 1:
        raise ValueError("Base value have to be larger than 1")
    x = 1
    while (x_max is None or x <= x_max) and f(x) > 0:
        x *= base
    # Explicit None check so that a bound of 0 is clamped as well.
    if x_max is not None:
        x = min(x, x_max)
    return x


def binary_search_solve(f: Callable[[int], float], x_max: Optional[int] = None) -> int:
    """Binary-search the largest non-negative integer at which a decreasing function is still positive

    Note:
    The right end of the search interval comes from `_exponential_bound`
    (growth factor 4), clipped to `x_max` when one is supplied. The left end
    is the sentinel -1, which is treated as if `f` were positive there.

    Args:
        f: A real-valued decreasing function, whose domain is the non-negative integers.
        x_max: Upper bound on the domain of `f`.

    Returns:
        The argument that satisfies the constraint. If the function is never positive, a value of -1 is returned.

    Raises:
        ValueError: If `f` returns a value that is neither `> 0` nor `<= 0` (e.g. NaN) at a midpoint.
    """
    probe = _exponential_bound(f, base=4, x_max=x_max)
    upper = probe if x_max is None else min(probe, x_max)

    # `lo` always holds a point considered positive, `hi` the right end of the bracket.
    lo, f_lo = -1, 1
    hi, f_hi = upper, f(upper)

    # f(0) is the largest value of a decreasing f: if it is negative, nothing is positive.
    if f(0) < 0:
        return lo
    # f is still positive at the right end, so the right end itself is the answer.
    if f_hi > 0:
        return hi

    # Shrink the bracket until the two ends are adjacent integers.
    while hi - lo > 1:
        mid = (hi + lo) // 2
        f_mid = f(mid)
        if f_mid > 0:
            lo, f_lo = mid, f_mid
        elif f_mid <= 0:
            hi, f_hi = mid, f_mid
        else:
            # Only reachable when both comparisons are false, i.e. NaN.
            raise ValueError("Nan detected when computing")
    assert f_lo > 0 and f_hi <= 0 and lo == hi - 1, "BS error"
    return lo


def repeat_forward_ds(
    dataset: Dataset,
    classifier: CertifiedMalConv,
    num_samples: int,
    batch_size: int = 1,
    device: Union[torch.device, str, None] = None,
    verbose: int = 0,
) -> torch.Tensor:
    """Run the classifier repeatedly on every dataset item and collect the outputs in a 3d tensor

    Each item is forwarded on its own: the input gets a leading batch
    dimension, the metadata is collated with `collate_pad`, and
    `classifier.forward` is called with `num_samples` under `torch.no_grad()`.

    Side effects: the classifier is moved to `device`, switched to eval mode,
    and its `reduce` attribute is set to `"none"`; none of this is undone.

    Args:
        dataset: A dataset whose items unpack as `(binaries, metadata)`.
        classifier: Base classifier.
        num_samples: Number of Monte Carlo samples to use.

    Keyword args:
        batch_size: Forwarded to `classifier.forward` as `batch_size`.
        device: Device used for the computation.
        verbose: If truthy, print the dataset size before starting.

    Returns:
        Float tensor on the CPU with shape `(num_samples, len(dataset), num_classes)`, where
        `num_classes` is read from the first item's output. `None` is returned if the dataset is empty.
    """
    classifier.to(device)
    classifier.eval()

    if verbose:
        print(f"Dataset size: {len(dataset)}")

    # Allocated lazily: the class dimension is only known after the first forward pass.
    all_probs = None

    progress = tqdm(enumerate(dataset), desc="Repeat forward", total=len(dataset))
    for index, item in progress:
        classifier.reduce = "none"
        raw_input, raw_metadata = item
        batched_input = raw_input.unsqueeze(0).to(device)
        batched_metadata = collate_pad([raw_metadata]).to(device)

        with torch.no_grad():
            output = classifier.forward(
                batched_input,
                num_samples=num_samples,
                return_logits=False,
                return_radii=False,
                batch_size=batch_size,
                forward_kwargs={"metadata": batched_metadata},
            )
            # Drop dim 1 (presumably the singleton batch axis - TODO confirm) and move to CPU.
            item_probs = output.squeeze(1).cpu()

        if all_probs is None:
            all_probs = torch.empty(
                (num_samples, len(dataset), item_probs.size(1)),
                dtype=torch.float,
            )

        all_probs[:, index, :] = item_probs

    return all_probs