from dataclasses import asdict
from bbs.bimodal.search import (
    bimodal,
    binary_search_bimodal,
    enhanced_binary_search_bimodal,
)


def test_binary_search_bimodal():
    """Pin the exact step trace of ``binary_search_bimodal`` for target 4000.

    The distribution mixes two components with equal weight (0.5): one with
    mean 0 and one with mean 4000, both with std dev 1000. The recorded trace
    must match step-for-step, ending with ``mid == 4000``.
    """
    distribution = bimodal(mu1=0, std_dev1=1000, mu2=4000, std_dev2=1000, weight1=0.5)
    result = binary_search_bimodal(distribution, 4000)

    # One (lo, hi, mid) triple per step; the step number is its list position.
    expected_trace = [
        (-2862, 6894, 2016),
        (2016, 6894, 4455),
        (2016, 4455, 3236),
        (3236, 4455, 3846),
        (3846, 4455, 4151),
        (3846, 4151, 3999),
        (3999, 4151, 4075),
        (3999, 4075, 4037),
        (3999, 4037, 4018),
        (3999, 4018, 4009),
        (3999, 4009, 4004),
        (3999, 4004, 4002),
        (3999, 4002, 4001),
        (3999, 4001, 4000),
        (4000, 4001, 4000),
    ]
    expected = {
        "steps": [
            {"step": index, "lo": lo, "hi": hi, "mid": mid}
            for index, (lo, hi, mid) in enumerate(expected_trace)
        ]
    }

    assert asdict(result) == expected


def test_enhanced_binary_search_bimodal():
    """Pin the exact step trace of ``enhanced_binary_search_bimodal`` for target 4000.

    Uses the same equal-weight bimodal distribution as the plain binary-search
    test: means 0 and 4000, both with std dev 1000. The recorded trace has 13
    steps, ending with ``mid == 4000``.
    """
    distribution = bimodal(mu1=0, std_dev1=1000, mu2=4000, std_dev2=1000, weight1=0.5)
    result = enhanced_binary_search_bimodal(distribution, 4000)

    # One (lo, hi, mid) triple per step; the step number is its list position.
    expected_trace = [
        (-2862, 6894, 2001),
        (2001, 6894, 3998),
        (3998, 6894, 4671),
        (3998, 4671, 4317),
        (3998, 4317, 4156),
        (3998, 4156, 4077),
        (3998, 4077, 4038),
        (3998, 4038, 4018),
        (3998, 4018, 4008),
        (3998, 4008, 4003),
        (3998, 4003, 4001),
        (3998, 4001, 4000),
        (4000, 4001, 4000),
    ]
    expected = {
        "steps": [
            {"step": index, "lo": lo, "hi": hi, "mid": mid}
            for index, (lo, hi, mid) in enumerate(expected_trace)
        ]
    }

    assert asdict(result) == expected
