from multiprocessing.pool import ThreadPool
from typing import Callable, List, Union

import numpy as np
import tqdm

# Type alias used in the annotations below for a collection of numeric
# values: either a plain Python list of floats or a NumPy array.
Array = Union[List[float], np.ndarray]


class QuantileFunction:
    """Computes the quantile function corresponding to an empirical distribution of samples.

    The samples are sorted once at construction. Calling the instance with a
    probability ``p`` returns the order statistic at 1-based rank
    ``ceil(len(samples) * p)``, with the rank clamped to ``[1, len(samples)]``
    (so ``p = 0`` yields the smallest sample).

    Args
        samples (numpy array): Empirical distribution of data e.g. scores belonging to an algorithm.
            Lists, tuples and other array-likes are converted with ``np.asarray``.
    """

    def __init__(self, samples: Array):
        # np.sort returns a sorted copy, so the caller's data is left untouched.
        self.samples = np.sort(np.asarray(samples))

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __call__(self, p: Union[float, Array]) -> np.ndarray:
        """Evaluate the quantile function at ``p``.

        Args:
            p: A probability or an array-like of probabilities.

        Returns:
            The sample(s) at the requested quantile(s); a NumPy scalar for a
            scalar ``p``, otherwise an array with the shape of ``p``.
        """
        # Convert every array-like up front: multiplying a tuple by an int
        # would repeat the tuple instead of scaling its entries.
        p = np.asarray(p, dtype=float)
        # Rounding strips floating-point noise such as
        # 100 * 0.07 == 7.000000000000001, which would otherwise push the
        # ceil one rank too high.
        scaled = np.round(len(self) * p, 9)
        inds = np.ceil(scaled).astype(int)
        # Ranks are 1-based; shift to 0-based indices and clamp into range.
        clip_inds = np.clip(inds - 1, 0, len(self) - 1)
        return self.samples[clip_inds]


def num_integrate(func: Callable, x0: float, x1: float, dx: float) -> float:
    r""" Numerical integration

    Integrate a function func from x0 to x1 with step dx, using a right
    Riemann sum: func is evaluated at the right end of every sub-interval.

    The interval is divided into ``round((x1 - x0) / dx)`` equal steps (at
    least one when ``x1 != x0``), so the effective step can differ slightly
    from dx when dx does not divide the interval evenly.

    Args:
        func (Callable): function to integrate; it is called once with a
            numpy array of evaluation points and must return an array of
            the same length
        x0 (float): start point
        x1 (float): end point
        dx (float): step size

    Returns:
        float, the approximate value of the integral

    Raises:
        ValueError: if dx points away from x1, i.e. ``(x1 - x0) / dx < 0``
    """
    span = x1 - x0
    ratio = span / dx
    if ratio < 0:
        raise ValueError(
            f"dx must step from x0 towards x1, but got x0={x0}, x1={x1}, dx={dx}."
        )
    # round() rather than int(): 0.3 / 0.1 is 2.9999999999999996 and would
    # otherwise be truncated to 2 steps.
    n_steps = max(int(round(ratio)), 1) if span != 0 else 0
    # n_steps intervals are delimited by n_steps + 1 grid points.
    X = np.linspace(x0, x1, n_steps + 1)
    y = func(X[1:])  # remove x0
    # Weight by the actual grid spacing so the sum covers exactly [x0, x1].
    step = span / n_steps if n_steps else 0.0
    return np.sum(y * step).item()


def num_integrate_func(func: Callable, x0: float, x1: float,
                       dx: float) -> Callable:
    r""" Numerical integration

    Integrate a function func from x0 to x1 with step dx and returns
    the result as the indefinite integral function $F(x) = \int_{x0}^{x} f(s) ds$

    The integral is a right Riemann sum that is precomputed on a grid of
    ``round((x1 - x0) / dx)`` equal steps (at least one when ``x1 != x0``).
    The returned function looks up the value at the first grid point that is
    greater than or equal to x.

    Args:
        func (Callable): function to integrate; it is called once with a
            numpy array of evaluation points and must return an array of
            the same length
        x0 (float): start point
        x1 (float): end point
        dx (float): step size

    Returns:
        function, a function that returns the integral until a given point x.
        It raises ValueError when x lies outside [x0, x1].

    Raises:
        ValueError: if dx points away from x1, i.e. ``(x1 - x0) / dx < 0``
    """
    span = x1 - x0
    ratio = span / dx
    if ratio < 0:
        raise ValueError(
            f"dx must step from x0 towards x1, but got x0={x0}, x1={x1}, dx={dx}."
        )
    # round() rather than int(): 0.3 / 0.1 is 2.9999999999999996 and would
    # otherwise be truncated to 2 steps.
    n_steps = max(int(round(ratio)), 1) if span != 0 else 0
    # n_steps intervals are delimited by n_steps + 1 grid points.
    X = np.linspace(x0, x1, n_steps + 1)
    y = func(X[1:])  # remove x0
    # Actual grid spacing; falls back to dx for the degenerate x0 == x1 case
    # so the index computation below never divides by zero.
    step = span / n_steps if n_steps else dx

    res = np.cumsum(y * step)
    # Prepend F(x0) = 0, so that res[k] is the integral from x0 to X[k].
    res = np.insert(res, 0, 0)

    def integral_func(
            x: Union[float, List[float], np.ndarray]) -> np.ndarray:
        x = np.asarray(x, dtype=float)
        if not np.all((x >= x0) & (x <= x1)):
            raise ValueError(f"x must be between {x0} and {x1}, but got {x}.")
        # Rounding strips floating-point noise (0.07 / 0.01 is
        # 7.000000000000001) that would push ceil to the next grid point.
        inds = np.ceil(np.round((x - x0) / step, 9)).astype(int)
        # res already starts with the entry for x0, so the grid index is
        # used directly; clipping guards against residual rounding.
        clip_inds = np.clip(inds, 0, len(res) - 1)
        return res[clip_inds]

    return integral_func


class SecondQuantileFunction:
    """Second quantile function $F_X^{(-2)}(p)$ of an empirical distribution.

    This is the integral of the quantile function $F_X^{(-1)}(p)$. The
    integral is evaluated numerically once, at construction time, over
    [0.0, 1.0] with step dp; later calls only look up the stored values.

    Args
        samples (numpy array): Empirical distribution of data e.g. scores belonging to an algorithm.
        dp (float): Step used for the numerical integration. Defaults to 0.01.
    """

    def __init__(self, samples: Array, dp: float = 0.01):
        self.dp = dp
        self.samples = samples

        # Integrate the empirical quantile function once and keep the
        # resulting lookup function for use in __call__.
        self.second_quantile_func = num_integrate_func(
            QuantileFunction(samples), 0.0, 1.0, dp)

    def __call__(self, p: Union[float, Array]) -> np.ndarray:
        """Return the precomputed integral of the quantile function up to p."""
        lookup = self.second_quantile_func
        return lookup(p)


class ECDF:
    """Empirical cumulative distribution function of a set of scores.

    The samples are sorted once at construction. Calling the instance with a
    value x (or a list/array of values) gives the fraction of samples that
    are less than or equal to x.

    Args
        samples (numpy array): Empirical distribution of data e.g. scores belonging to an algorithm.

    Returns
        Callable: the empirical CDF belonging to an empirical score distribution.
    """

    def __init__(self, samples: Array):
        raw = np.array(samples) if isinstance(samples, list) else samples
        # np.sort returns a sorted copy; the input itself is not modified.
        self.samples = np.sort(raw)

    def __len__(self):
        """Return the number of stored samples."""
        return self.samples.shape[0]

    def __call__(self, x: Union[float, Array]) -> Union[float, np.ndarray]:
        """Evaluate the empirical CDF at x."""
        query = np.array(x) if isinstance(x, list) else x
        # With side="right" the insertion index equals the number of
        # samples that are <= query.
        n_at_or_below = np.searchsorted(self.samples, query, side="right")
        return n_at_or_below / len(self)

def bootstrap_multiprocessing(func: Callable,
                              num_workers: int,
                              n_bootstrap: int,
                              desc: str = "",
                              verbose: bool = True) -> List:
    """Bootstrap a function in parallel

    Calls ``func(i)`` for every ``i`` in ``range(n_bootstrap)`` on a pool of
    worker threads (a ``ThreadPool``, not separate processes) and collects
    the results in index order.

    Args:
        func (Callable): function to bootstrap; it receives the bootstrap index
        num_workers (int): number of workers
        n_bootstrap (int): number of bootstrap samples
        desc (str, optional): description for tqdm. Defaults to empty string.
        verbose (bool, optional): whether to show a tqdm progress bar.
            Defaults to True.

    Returns:
        List: list of results

    Example:
        >>> from soe.utils import bootstrap_multiprocessing
        >>> def func(i):
        ...     return i
        >>> bootstrap_multiprocessing(func, 4, 10)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    """
    # The context manager shuts the pool down even when func raises, so the
    # worker threads are never leaked. All results have been collected by
    # the time the block exits.
    with ThreadPool(processes=num_workers) as pool:
        outcomes = pool.imap(func, range(n_bootstrap))
        if verbose:
            # Only build a progress bar when one is actually shown.
            outcomes = tqdm.tqdm(outcomes, desc=desc, total=n_bootstrap)
        return list(outcomes)
