import matplotlib.pyplot as plt
import numpy as np


def sample_pi_iama(alpha, n_samples=1, random_state=None):
    """
    Sample from
    pi*_IAMA(y) = alpha y^{alpha-1} (1-y)^{alpha-1} / (y^alpha + (1-y)^alpha)^2
    on [0, 1] by inverse-transform sampling.

    The CDF is F(y) = y^alpha / (y^alpha + (1-y)^alpha), whose inverse is
    F^{-1}(u) = 1 / (1 + ((1-u)/u)^{1/alpha}).

    Parameters
    ----------
    alpha : float
        alpha > 0
    n_samples : int
        number of samples
    random_state : int or None
        random seed

    Returns
    -------
    y : ndarray
        samples in [0, 1]

    Raises
    ------
    ValueError
        If alpha is not strictly positive (this also rejects NaN).
    """
    if not np.all(np.asarray(alpha) > 0):
        raise ValueError(f"alpha must be > 0, got {alpha!r}")

    rng = np.random.default_rng(random_state)
    u = rng.uniform(0.0, 1.0, size=n_samples)

    # Evaluate F^{-1}(u) in log space. The direct form
    # u**(1/alpha) / (u**(1/alpha) + (1-u)**(1/alpha)) underflows to 0/0 = nan
    # when 1/alpha is large. Here, exp() overflowing to inf correctly yields
    # y = 0, and u == 0 gives log(0) = -inf -> log_ratio = inf -> y = 0, which
    # matches the direct formula; the corresponding warnings are harmless.
    with np.errstate(divide="ignore", over="ignore"):
        log_ratio = (np.log1p(-u) - np.log(u)) / alpha
        y = 1.0 / (1.0 + np.exp(log_ratio))
    return y


# Overlay histograms of pi*_IAMA samples for several N, using alpha = 1/(N-1).
with plt.style.context("config/paper.mplstyle"):
    fig = plt.figure(figsize=(3.3, 2))
    ax = fig.add_subplot(1, 1, 1)
    for N in [2, 4, 8]:
        samples = sample_pi_iama(alpha=1 / (N - 1), n_samples=100000, random_state=42)
        # `alpha=0.3` here is the histogram's opacity, unrelated to the
        # distribution parameter passed to sample_pi_iama above.
        ax.hist(samples, bins=100, density=True, alpha=0.3, label=f"N={N}")
    # Title and y-limits do not depend on N, so set them once after plotting.
    ax.set_title(r"Samples from $\pi^*_{IAMA}$")
    ax.set_ylim(0, 5)
    ax.legend()
    # Save while the style is still active so any savefig.* rcParams the
    # style defines (e.g. savefig.dpi, savefig.bbox, if set) are applied.
    fig.savefig("figs/optimal_distribution_pi_iama.png")
