import numpy as np
import numpy.fft as fft
from numpy.random import default_rng
from stochastic.processes import BrownianExcursion
from scipy.ndimage import gaussian_filter1d


def _fractional_brownian_motion(rng=None, resolution_exponent=8, beta=4):
    """
    Draw one random signal whose spectrum follows a power law.

    Random complex Fourier coefficients for the lower half of the spectrum are
    drawn with the Box-Muller transform. Coefficient k (k = 0 .. N/2 - 1) has
    its power scaled by ``((k + 1) / (N / 2)) ** -beta``. The lower half is then
    mirrored into the upper half and passed through a forward FFT. The real and
    imaginary parts of the transform are returned as two separate signals;
    callers turn them into probability densities.

    :param rng: random number generator providing ``random(size=...)``;
        a new ``default_rng()`` is created when ``None``
    :param resolution_exponent: exponent of 2 giving the number of bins ``N``
    :param beta: negative exponent of the spectral power law
    :return: ``(re, im)``, two real arrays of length ``N``
    """
    n_bins = int(2**resolution_exponent)
    n_half = int(n_bins / 2)

    if rng is None:
        # https://numpy.org/doc/stable/reference/random/index.html#quick-start
        rng = default_rng()

    # Box-Muller transform. ``random`` draws from [0, 1), so 1 - U lies in
    # (0, 1] and the logarithm below is always finite.
    uniform_radius = 1 - rng.random(size=n_half)
    angle = 2 * np.pi * rng.random(size=n_half)

    # Relative power for each frequency index k = 0 .. n_half - 1.
    frequency = np.arange(1, n_half + 1) / n_half
    power = np.power(frequency, -beta)
    magnitude = np.sqrt(power * -2 * np.log(uniform_radius))

    lower_real = magnitude * np.cos(angle)
    lower_imag = magnitude * np.sin(angle)

    # Upper half of the spectrum: real part reflected, imaginary part
    # reflected and negated.
    full_real = np.concatenate([lower_real, lower_real[::-1]])
    full_imag = np.concatenate([lower_imag, -lower_imag[::-1]])

    # fft and ifft differ only in normalization. That is immaterial here,
    # because the result is renormalized into a probability distribution
    # later.
    signal = fft.fft(full_real + 1j * full_imag)

    return signal.real, signal.imag


def _component_to_distribution(component):
    """
    Turn one real signal component into a probability vector with its minimum at index 0.

    The component is shifted so that its minimum becomes zero and normalized
    to sum to one. It is then rolled so that minimum lands at index 0. Finally
    it is mixed with 1% uniform mass so that no bin is exactly zero. The result
    still sums to one (up to rounding).
    """
    argmin = np.argmin(component)
    shifted = component - component[argmin]
    density = np.roll(shifted / np.sum(shifted), -argmin)
    # 0.99 * density + 0.01 spread uniformly: keeps the total mass at one
    # while lifting the zero bin above zero.
    return density * 0.99 + 0.01 / len(density)


def random_distribution(n, rng=None, resolution_exponent=8, beta=4):
    """
    Generate ``n`` random discrete probability distributions on the domain [0, 1].

    Each call to ``_fractional_brownian_motion`` yields two signals (the real
    and imaginary parts of one FFT), and both are used. Each signal is:

    1. shifted so its minimum is zero;
    2. normalized to sum to one;
    3. rolled so the minimum sits at index 0;
    4. mixed with 1% uniform mass.

    The FFT output is periodic, so the last bin is the cyclic neighbour of the
    minimum at index 0. The density is therefore low at both ends of the
    domain. When ``n`` is odd, the imaginary part of the final pair is
    discarded.

    :param n: number of distributions to generate; ``n <= 0`` yields an
        empty array of shape ``(0, 2**resolution_exponent)``
    :param rng: random number generator providing ``random(size=...)``;
        a single ``default_rng()`` is created and reused when ``None``
    :param resolution_exponent: exponent of 2 giving the number of bins
    :param beta: negative exponent of the spectral power law; larger values
        damp high frequencies more strongly
    :return: array of shape ``(n, 2**resolution_exponent)`` whose rows are
        non-negative and sum to one (up to rounding)
    """
    resolution = int(2**resolution_exponent)

    if rng is None:
        # Create the generator once instead of letting every helper call
        # build (and seed) a fresh one.
        rng = default_rng()

    out = []
    while len(out) < n:
        re, im = _fractional_brownian_motion(
            rng=rng, resolution_exponent=resolution_exponent, beta=beta
        )
        out.append(_component_to_distribution(re))
        if len(out) < n:
            out.append(_component_to_distribution(im))

    if not out:
        # Keep the result 2-D even when nothing was generated.
        return np.empty((0, resolution))
    return np.array(out)


def sample(n, rng=None, resolution_exponent=8, sigma=1):
    """
    Generate ``n`` random probability distributions from smoothed Brownian excursions.

    Each row is built in four steps:

    1. draw a ``BrownianExcursion`` with ``t=1``, calling ``sample`` with
       ``2**resolution_exponent``;
    2. smooth it with a 1-D Gaussian filter;
    3. shift it so its minimum is zero;
    4. normalize it to sum to one.

    NOTE(review): ``stochastic``'s ``sample(k)`` appears to return ``k + 1``
    points (it includes t=0). If so, rows are one bin longer than those of
    ``random_distribution`` — verify.

    :param n: number of distributions to generate
    :param rng: random generator, forwarded unchanged to ``BrownianExcursion``
    :param resolution_exponent: exponent of 2 passed to
        ``BrownianExcursion.sample``
    :param sigma: standard deviation, in bins, of the Gaussian smoothing
        kernel
    :return: array with one distribution per row
    """
    resolution = 2**resolution_exponent
    out = []
    for _ in range(n):
        excursion = BrownianExcursion(t=1, rng=rng).sample(resolution)
        smoothed = gaussian_filter1d(excursion, sigma=sigma)
        # ndarray.min() is vectorized; the builtin min() walks the array in
        # Python.
        # TODO: the original author noted this shift "doesn't quite work".
        # Presumably smoothing lifts the excursion's endpoints above zero, so
        # after the shift only the global minimum is zero, not necessarily the
        # endpoints — confirm intended behaviour.
        smoothed -= smoothed.min()
        smoothed /= np.sum(smoothed)
        out.append(smoothed)
    return np.array(out)
