from decimal import Decimal, getcontext
from math import ceil
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import distributions, truncnorm, uniform
from scipy.stats._distn_infrastructure import rv_continuous_frozen


def truncnorm_median(rv: distributions.rv_frozen):
    """
    Build a function returning the integer median of a normal truncated to [lo, hi].

    `rv` supplies the location and scale (its mean and standard deviation).
    The returned ``median(lo, hi)`` builds a ``truncnorm`` on [lo, hi] with
    those parameters. It rounds that distribution's median to 12 decimal
    places, so tiny floating-point error just above an integer does not bump
    the ceiling, and then returns the ceiling.

    When ``lo == hi`` the interval is a single point and ``ceil(lo)`` is
    returned. Raises ValueError if ``hi < lo``.
    """

    def median(lo: int, hi: int):
        if hi < lo:
            raise ValueError("hi must be greater than lo")
        if hi == lo:
            # scipy's truncnorm expects a < b; a zero-width interval would
            # produce NaN and make ceil() fail. Its median is the point itself.
            return ceil(lo)

        # Evaluate the moments once instead of four times.
        loc, scale = rv.mean(), rv.std()
        a, b = (lo - loc) / scale, (hi - loc) / scale
        rv_trunc = truncnorm(a, b, loc=loc, scale=scale)

        return ceil(truncate_float(rv_trunc.median(), 12))

    return median


def truncuniform_median(lo: int, hi: int):
    """
    Return the median of the uniform distribution on [lo, hi], rounded up.

    The median is rounded to 12 decimal places before ``ceil``, so
    floating-point noise just above an integer does not bump the result.
    When ``lo == hi`` the interval is a single point and ``ceil(lo)`` is
    returned. Raises ValueError if ``hi < lo``.
    """
    if hi < lo:
        raise ValueError("hi must be greater than lo")
    if hi == lo:
        # scipy requires scale > 0, so a zero-width uniform would yield NaN
        # and make ceil() fail. The median of a single point is the point.
        return ceil(lo)

    rv = uniform(loc=lo, scale=hi - lo)
    return ceil(truncate_float(rv.median(), 12))


def truncated_median(rv: distributions.rv_frozen):
    """
    Build a function that finds the integer median of `rv` truncated to [lo, hi].

    The returned ``median(lo, hi)`` normalises the CDF to the interval:
    ``F(x) = (CDF(x) - CDF(lo)) / (CDF(hi) - CDF(lo))``, computed in Decimal.
    It then binary-searches the integers in [lo, hi] for where ``F`` crosses
    1/2, using ``bisect_right`` with its default tolerance ε.

    The result is an integer ``x`` with ``abs(F(x) - 1/2) <= ε`` if one
    exists. Otherwise it is the smallest integer ``x`` in ``(lo, hi]`` with
    ``F(x) > 1/2``.

    When ``lo == hi`` the interval is a single point and ``lo`` is returned.
    Raises ValueError if ``hi < lo``.

    If ``CDF(lo)`` and ``CDF(hi)`` are numerically equal (e.g. both deep in a
    tail), the normalisation divides by zero. In that case a ``decimal``
    arithmetic error is raised.
    """

    def median(lo: int, hi: int):
        if hi < lo:
            raise ValueError("hi must be greater than lo")
        if hi == lo:
            # Single-point interval: the normalising mass below would be zero,
            # but the median is simply the point itself.
            return lo

        # CDF(lo) is used by every probe; evaluate it once.
        cdf_lo = Decimal(rv.cdf(lo))
        total_cdf = Decimal(rv.cdf(hi)) - cdf_lo

        def cdf(val):
            return (Decimal(rv.cdf(val)) - cdf_lo) / total_cdf

        return bisect_right(lo, hi, 0.5, key=cdf)

    return median


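# The default epsilon was raised from 1e-16 to 1e-14 because floating-point
# error in the CDF can exceed 1e-16. For example, with
# rv = uniform(loc=0, scale=9000):
#   rv.ppf(rv.cdf(204)) does not round-trip to exactly 204, and
#   (rv.cdf(207) - rv.cdf(204)) / (rv.cdf(210) - rv.cdf(204)) evaluates to
#   slightly less than 0.5,
# even though uniform(loc=204, scale=6).median() is exactly 207.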
def bisect_right(
    lo: int, hi: int, target: float, key: Callable[[int], Decimal], epsilon=1e-14
):
    """
    Binary-search the integers in [lo, hi] for where `key` crosses `target`.

    `key` must be non-decreasing over [lo, hi], i.e. ``key(x1) >= key(x2)``
    whenever ``x1 > x2``. It may return a Decimal, int or float; the value is
    converted to Decimal exactly before comparing.

    Returns:
        If some integer in [lo, hi] has ``abs(key(m) - target) <= epsilon``,
        one such ``m`` is returned. It is the first one probed, which is not
        necessarily the smallest.

        Otherwise the insertion point is returned: the smallest integer in
        [lo, hi] whose key exceeds `target`, or ``hi + 1`` if none does. An
        empty range (``lo > hi``) returns ``lo``.
    """
    goal = Decimal(target)

    while lo <= hi:
        m = (hi + lo) // 2
        # Decimal() is exact for Decimal/int/float inputs. Wrapping lets
        # float-returning keys work (float - Decimal would raise TypeError).
        val = Decimal(key(m))

        if abs(val - goal) <= epsilon:
            return m
        if val > goal:
            hi = m - 1
        else:
            lo = m + 1

    return lo


def truncate_float(x: float, scale: int) -> float:
    """
    Round `x` to `scale` decimal places and return the result as a float.

    Despite the name, this rounds rather than truncates. It uses
    ``Decimal.quantize`` with the current decimal context's rounding mode,
    which is ROUND_HALF_EVEN by default. The median helpers use it to strip
    floating-point noise before taking ``ceil``.

    The working precision is raised as needed, so large magnitudes no longer
    overflow the default 28-digit context. Previously, |x| >= 1e16 at
    ``scale=12`` raised ``decimal.InvalidOperation``.

    NaN and infinities are returned unchanged. A negative `scale` rounds to
    tens, hundreds, and so on.
    """
    value = Decimal(x)
    if not value.is_finite():
        # quantize() cannot give an infinity a finite exponent; NaN would
        # propagate anyway.
        return float(value)

    exponent = Decimal(1).scaleb(-scale)
    # Digits needed: integer part + `scale` fractional digits + 1 for a
    # possible carry (e.g. 999.9995 -> 1000.000).
    context = getcontext().copy()
    context.prec = max(context.prec, value.adjusted() + scale + 2)
    return float(value.quantize(exponent, context=context))


def plot_distribution(rv: rv_continuous_frozen, num_samples=1000, num_bins=200):
    """
    Display a histogram of random draws from `rv` overlaid with its PDF.

    The PDF is evaluated on 1000 evenly spaced points between the 0.1% and
    99.9% quantiles of `rv`. The histogram is density-normalised, so both
    curves share one y-axis scale. The figure is titled with the
    distribution's name and shown with ``plt.show()``.
    """
    draws = rv.rvs(size=num_samples)
    grid = np.linspace(rv.ppf(0.001), rv.ppf(0.999), 1000)
    density = rv.pdf(grid)

    plt.figure(figsize=(10, 6))

    hist_style = {
        "density": True,
        "alpha": 0.7,
        "color": "skyblue",
        "label": "Samples",
    }
    plt.hist(draws, bins=num_bins, **hist_style)
    plt.plot(grid, density, "r-", lw=2, label="PDF")

    plt.xlabel("Value")
    plt.ylabel("Density")
    plt.title(rv.dist.name)
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.show()
