import itertools
from dataclasses import asdict

from bbs.search import binary_search


# Debug harness: remove the `x` prefix so pytest collects it, then run
# `poetry run python3 -m pytest test/test_search.py::test_me`
def xtest_me():
    """Debugging aid: trace binary_search for every target in 0..10, then pause.

    The ``x`` prefix keeps pytest from collecting this function. Once renamed,
    it builds ``ms`` (a list of ``{target: metrics}`` dicts) and stops in the
    debugger so the traces can be inspected interactively.
    """
    ms = [
        {i: binary_search(lo=0, hi=10, target=i, epsilon=1)}
        for i in range(0, 11)
    ]
    # breakpoint() (PEP 553) honours PYTHONBREAKPOINT and pytest's debugger
    # hooks, unlike hard-wiring pdb via __import__.
    breakpoint()


def test_default():
    """Trace binary_search on [0, 1000] without a custom mid_func.

    Covers one target near the low end and one near the high end, checking
    both ``total_steps`` and the full recorded step list.
    """

    def trace(*rows):
        # Build the expected ``asdict(metrics)`` payload: each row is
        # (lo, hi, mid) and its position in ``rows`` is the step number.
        return {
            "steps": [
                {"step": n, "lo": lo, "hi": hi, "mid": mid}
                for n, (lo, hi, mid) in enumerate(rows)
            ]
        }

    near_low = binary_search(lo=0, hi=1000, target=2, epsilon=9)
    assert near_low.total_steps == 7
    assert asdict(near_low) == trace(
        (0, 1000, 500),
        (0, 500, 250),
        (0, 250, 125),
        (0, 125, 63),
        (0, 63, 32),
        (0, 32, 16),
        (0, 16, 8),
        (0, 8, 8),
    )

    near_high = binary_search(lo=0, hi=1000, target=998, epsilon=2)
    assert near_high.total_steps == 9
    assert asdict(near_high) == trace(
        (0, 1000, 500),
        (500, 1000, 750),
        (750, 1000, 875),
        (875, 1000, 938),
        (938, 1000, 969),
        (969, 1000, 985),
        (985, 1000, 993),
        (993, 1000, 997),
        (997, 1000, 999),
        (997, 999, 999),
    )


def test_custom_mid_function():
    """Check that binary_search records the mids produced by a custom mid_func.

    Each mid function ignores its arguments and walks linearly from ``start``,
    so the recorded trace shows exactly which values the search consumed.
    """

    def inc_linear_mid_func(start: int):
        # Returns start, start + 1, start + 2, ... on successive calls.
        # Each factory call owns its own iterator, so mid functions built
        # here never share or reset each other's position.
        counter = itertools.count(start)
        return lambda *_: next(counter)

    def dec_linear_mid_func(start: int):
        # Returns start, start - 1, start - 2, ... on successive calls.
        counter = itertools.count(start, -1)
        return lambda *_: next(counter)

    metrics = binary_search(lo=0, hi=4, target=3, mid_func=inc_linear_mid_func(0))
    assert metrics.total_steps == 4
    assert asdict(metrics) == dict(
        steps=[
            dict(step=0, lo=0, hi=4, mid=0),
            dict(step=1, lo=0, hi=4, mid=1),
            dict(step=2, lo=1, hi=4, mid=2),
            dict(step=3, lo=2, hi=4, mid=3),
            dict(step=4, lo=3, hi=4, mid=3),
        ]
    )

    metrics = binary_search(
        lo=0, hi=90000, target=3, mid_func=inc_linear_mid_func(0), epsilon=2
    )
    assert metrics.total_steps == 5
    assert asdict(metrics) == dict(
        steps=[
            dict(step=0, lo=0, hi=90000, mid=0),
            dict(step=1, lo=0, hi=90000, mid=1),
            dict(step=2, lo=1, hi=90000, mid=2),
            dict(step=3, lo=2, hi=90000, mid=3),
            dict(step=4, lo=3, hi=90000, mid=4),
            dict(step=5, lo=3, hi=4, mid=4),
        ]
    )

    # NOTE(review): the mid function's first value is 90000, yet step 0
    # expects 89999 — binary_search appears to adjust mid_func results that
    # reach ``hi``; confirm against bbs.search.
    metrics = binary_search(
        lo=0, hi=90000, target=89995, mid_func=dec_linear_mid_func(90000), epsilon=2
    )
    assert metrics.total_steps == 5
    assert asdict(metrics) == dict(
        steps=[
            dict(step=0, lo=0, hi=90000, mid=89999),
            dict(step=1, lo=0, hi=89999, mid=89998),
            dict(step=2, lo=0, hi=89998, mid=89997),
            dict(step=3, lo=0, hi=89997, mid=89996),
            dict(step=4, lo=0, hi=89996, mid=89995),
            dict(step=5, lo=89995, hi=89996, mid=89995),
        ]
    )
