from typing import cast

import numpy as np
from scipy import stats
from scipy.stats import distributions
from scipy.stats._distn_infrastructure import rv_continuous_frozen

from bbs.search import binary_search
from bbs.stats import truncated_median


class bimodal_gen(stats.rv_continuous):
    """Mixture of two normal distributions, as a scipy ``rv_continuous``.

    Shape parameters, in order: ``mu1``, ``std_dev1``, ``mu2``, ``std_dev2``,
    ``weight1``. The density is::

        weight1 * N(x; mu1, std_dev1) + (1 - weight1) * N(x; mu2, std_dev2)

    The support is truncated. It runs from the smaller of the two components'
    0.0001 quantiles to the larger of their 0.9999 quantiles. As a result, the
    density integrates to slightly less than 1 over that support.
    """

    def _pdf(
        self,
        x,
        mu1,
        std_dev1,
        mu2,
        std_dev2,
        weight1,
    ):
        """Return the weighted sum of the two component normal densities at ``x``."""
        # scipy infers the shape-parameter names from this signature, so the
        # parameter list must stay exactly as is.
        pdf1 = stats.norm.pdf(x, mu1, std_dev1)
        pdf2 = stats.norm.pdf(x, mu2, std_dev2)

        return weight1 * pdf1 + (1 - weight1) * pdf2

    def _get_support(self, mu1, std_dev1, mu2, std_dev2, *_):
        """Return ``(a, b)``, the outermost 0.0001 / 0.9999 quantiles of both components.

        Works element-wise when the shape parameters are arrays. scipy passes
        broadcast arrays from ``support()``, ``pdf()`` and similar methods.
        """
        # np.minimum / np.maximum rather than builtin min / max: the builtins
        # raise "truth value of an array is ambiguous" on array-valued shapes.
        a = np.minimum(
            stats.norm.ppf(0.0001, loc=mu1, scale=std_dev1),
            stats.norm.ppf(0.0001, loc=mu2, scale=std_dev2),
        )
        b = np.maximum(
            stats.norm.ppf(0.9999, loc=mu1, scale=std_dev1),
            stats.norm.ppf(0.9999, loc=mu2, scale=std_dev2),
        )
        return a, b

    def _argcheck(self, mu1, std_dev1, mu2, std_dev2, weight1):
        """Return a boolean mask of valid shape parameters.

        The means are unrestricted. Both standard deviations must be strictly
        positive, and ``weight1`` must lie in ``[0, 1]`` so the mixture is a
        proper density. scipy turns invalid entries into NaN results, or a
        ``ValueError`` in ``rvs``. (scipy's default check would wrongly require
        every shape, including the means, to be positive.)
        """
        weight = np.asarray(weight1)
        return (
            (np.asarray(std_dev1) > 0)
            & (np.asarray(std_dev2) > 0)
            & (weight >= 0)
            & (weight <= 1)
        )


# std_dev1 defaults to 1, matching std_dev2. A zero scale is not a valid
# normal standard deviation: norm.ppf returns NaN for it, which made the
# default distribution's support bounds NaN.
def bimodal(mu1=0, std_dev1=1, mu2=0, std_dev2=1, weight1=0.5):
    """Build a frozen two-component normal mixture.

    The density is
    ``weight1 * N(mu1, std_dev1) + (1 - weight1) * N(mu2, std_dev2)``.

    Args:
        mu1: Mean of the first component.
        std_dev1: Standard deviation of the first component; must be > 0.
        mu2: Mean of the second component.
        std_dev2: Standard deviation of the second component; must be > 0.
        weight1: Mixing weight of the first component, expected in [0, 1].
            The second component receives ``1 - weight1``.

    Returns:
        A frozen ``bimodal_gen`` distribution with the given parameters.
    """
    dist = bimodal_gen(name="bimodal")(
        mu1=mu1,
        std_dev1=std_dev1,
        mu2=mu2,
        std_dev2=std_dev2,
        weight1=weight1,
    )
    # cast only informs the type checker; it does nothing at runtime.
    return cast(rv_continuous_frozen, dist)


def binary_search_bimodal(rv: distributions.rv_frozen, target: int, epsilon: int = 1):
    """Run ``binary_search`` for ``target`` over the bulk of ``rv``'s range.

    The search interval runs from the 0.1% quantile to the 99.9% quantile of
    ``rv``. Each bound is converted with ``int()``, which truncates toward zero.

    Args:
        rv: Frozen scipy distribution whose quantiles bound the search.
        target: Value passed to ``binary_search`` as the search target.
        epsilon: Forwarded unchanged to ``binary_search``.

    Returns:
        Whatever ``bbs.search.binary_search`` returns (not visible here).

    Raises:
        ValueError: if either quantile is NaN, since ``int(nan)`` raises.
    """
    lower = int(rv.ppf(0.001))
    upper = int(rv.ppf(0.999))
    return binary_search(lower, upper, target, epsilon=epsilon)


def enhanced_binary_search_bimodal(
    rv: distributions.rv_frozen, target: int, epsilon: int = 1
):
    """Binary-search for ``target`` using a distribution-aware midpoint.

    Uses the same bounds as ``binary_search_bimodal``: the 0.1% and 99.9%
    quantiles of ``rv``, each truncated with ``int()``. It also passes
    ``truncated_median(rv)`` to ``binary_search`` as ``mid_func``, in place of
    whatever ``binary_search`` uses when ``mid_func`` is omitted.
    NOTE(review): presumably this picks ``rv``'s median restricted to the
    current interval; confirm against ``bbs.stats.truncated_median``.

    Args:
        rv: Frozen scipy distribution whose quantiles bound the search.
        target: Value passed to ``binary_search`` as the search target.
        epsilon: Forwarded unchanged to ``binary_search``.

    Returns:
        Whatever ``bbs.search.binary_search`` returns (not visible here).

    Raises:
        ValueError: if either quantile is NaN, since ``int(nan)`` raises.
    """
    lower = int(rv.ppf(0.001))
    upper = int(rv.ppf(0.999))
    median_split = truncated_median(rv)
    return binary_search(
        lower, upper, target, epsilon=epsilon, mid_func=median_split
    )
