from dataclasses import asdict
from bbs.expon.search import binary_search_expon, enhanced_binary_search_expon
from scipy.stats import expon


SCALE = 10**4


def test_binary_search_expon():
    """Pin the full step trace of ``binary_search_expon`` for target 1 * SCALE.

    The expected trace is written as ``(lo, hi, mid)`` rows. A row's position
    in the list is its ``step`` number, so the indices are not repeated by hand.
    """
    target = 1 * SCALE
    distribution = expon(scale=SCALE)

    # Golden trace: one (lo, hi, mid) row per iteration, in order.
    expected_rows = [
        (0, 115130, 57565),
        (0, 57565, 28783),
        (0, 28783, 14392),
        (0, 14392, 7196),
        (7196, 14392, 10794),
        (7196, 10794, 8995),
        (8995, 10794, 9895),
        (9895, 10794, 10345),
        (9895, 10345, 10120),
        (9895, 10120, 10008),
        (9895, 10008, 9952),
        (9952, 10008, 9980),
        (9980, 10008, 9994),
        (9994, 10008, 10001),
        (9994, 10001, 9998),
        (9998, 10001, 10000),
        (10000, 10001, 10000),
    ]
    expected = {
        "steps": [
            {"step": index, "lo": lo, "hi": hi, "mid": mid}
            for index, (lo, hi, mid) in enumerate(expected_rows)
        ]
    }

    observed = asdict(binary_search_expon(distribution, target))
    assert observed == expected


def test_enhanced_binary_search_expon():
    """Pin the full step trace of ``enhanced_binary_search_expon`` for target 1 * SCALE.

    The expected trace is written as ``(lo, hi, mid)`` rows. A row's position
    in the list is its ``step`` number, so the indices are not repeated by hand.
    """
    target = 1 * SCALE
    distribution = expon(scale=SCALE)

    # Golden trace: one (lo, hi, mid) row per iteration, in order.
    expected_rows = [
        (0, 115130, 6932),
        (6932, 115130, 13864),
        (6932, 13864, 9809),
        (9809, 13864, 11633),
        (9809, 11633, 10680),
        (9809, 10680, 10236),
        (9809, 10236, 10021),
        (9809, 10021, 9915),
        (9915, 10021, 9968),
        (9968, 10021, 9995),
        (9995, 10021, 10008),
        (9995, 10008, 10002),
        (9995, 10002, 9999),
        (9999, 10002, 10001),
        (9999, 10001, 10000),
        (10000, 10001, 10000),
    ]
    expected = {
        "steps": [
            {"step": index, "lo": lo, "hi": hi, "mid": mid}
            for index, (lo, hi, mid) in enumerate(expected_rows)
        ]
    }

    observed = asdict(enhanced_binary_search_expon(distribution, target))
    assert observed == expected
