from copy import deepcopy
from dataclasses import asdict

from scipy.stats import norm

from bbs.normal.search import binary_search_normal, enhanced_binary_search_normal


def test_binary_search_normal_sign():
    """Drive the search with a custom ``sign`` callback and no positional target.

    The callback reports -1 for any probe <= 58 and +1 otherwise. The recorded
    trace is expected to match searching for 58 directly (see
    ``test_binary_search_normal``).
    """
    rv = norm(loc=0, scale=900000)

    def sign(value):
        if value <= 58:
            return -1
        return 1

    metrics = binary_search_normal(rv, epsilon=10, sign=sign)

    # One (lo, hi, mid) row per step; the row position is the step index.
    trace = [
        (-3780000, 3780000, 0),
        (0, 3780000, 1890000),
        (0, 1890000, 945000),
        (0, 945000, 472500),
        (0, 472500, 236250),
        (0, 236250, 118125),
        (0, 118125, 59063),
        (0, 59063, 29532),
        (0, 29532, 14766),
        (0, 14766, 7383),
        (0, 7383, 3692),
        (0, 3692, 1846),
        (0, 1846, 923),
        (0, 923, 462),
        (0, 462, 231),
        (0, 231, 116),
        (0, 116, 58),
        (58, 116, 87),
        (58, 87, 73),
        (58, 73, 66),
        (58, 66, 66),
    ]

    assert metrics.total_steps == 20
    assert asdict(metrics) == dict(
        steps=[
            dict(step=i, lo=lo, hi=hi, mid=mid)
            for i, (lo, hi, mid) in enumerate(trace)
        ]
    )


def test_binary_search_normal():
    """Plain binary search for 58 under N(0, 900000) with epsilon=10."""
    rv = norm(loc=0, scale=900000)
    metrics = binary_search_normal(rv, 58, epsilon=10)

    # One (lo, hi, mid) row per step; the row position is the step index.
    trace = [
        (-3780000, 3780000, 0),
        (0, 3780000, 1890000),
        (0, 1890000, 945000),
        (0, 945000, 472500),
        (0, 472500, 236250),
        (0, 236250, 118125),
        (0, 118125, 59063),
        (0, 59063, 29532),
        (0, 29532, 14766),
        (0, 14766, 7383),
        (0, 7383, 3692),
        (0, 3692, 1846),
        (0, 1846, 923),
        (0, 923, 462),
        (0, 462, 231),
        (0, 231, 116),
        (0, 116, 58),
        (58, 116, 87),
        (58, 87, 73),
        (58, 73, 66),
        (58, 66, 66),
    ]

    assert metrics.total_steps == 20
    assert asdict(metrics) == dict(
        steps=[
            dict(step=i, lo=lo, hi=hi, mid=mid)
            for i, (lo, hi, mid) in enumerate(trace)
        ]
    )


def test_enhanced_binary_search_normal():
    """Enhanced search for 58 under N(0, 900000) with epsilon=10.

    It finishes in 18 steps, versus 20 for the plain search in
    ``test_binary_search_normal``.
    """
    rv = norm(loc=0, scale=900000)
    metrics = enhanced_binary_search_normal(rv, 58, epsilon=10)

    # One (lo, hi, mid) row per step; the row position is the step index.
    trace = [
        (-3780000, 3780000, 0),
        (0, 3780000, 607022),
        (0, 607022, 286768),
        (0, 286768, 141577),
        (0, 141577, 70570),
        (0, 70570, 35258),
        (0, 35258, 17626),
        (0, 17626, 8813),
        (0, 8813, 4407),
        (0, 4407, 2204),
        (0, 2204, 1102),
        (0, 1102, 551),
        (0, 551, 276),
        (0, 276, 138),
        (0, 138, 69),
        (0, 69, 35),
        (35, 69, 52),
        (52, 69, 61),
        (52, 61, 61),
    ]

    assert metrics.total_steps == 18
    assert asdict(metrics) == dict(
        steps=[
            dict(step=i, lo=lo, hi=hi, mid=mid)
            for i, (lo, hi, mid) in enumerate(trace)
        ]
    )


def test_enhanced_binary_search_normal_when_no_whole_number_gte_50_cdf():
    """Enhanced search for 58 under a very wide N(0, 9e12) with epsilon=1.

    Two traces are accepted. They differ only in the ``mid`` chosen at step 21
    (5378498 locally vs 5378499 on CI), which shifts ``hi`` at step 22. After
    that the traces coincide.
    NOTE(review): the cause is presumably platform-dependent floating-point
    rounding in scipy's distribution functions; TODO confirm.
    """
    rv = norm(loc=0, scale=9000000000000)
    metrics = enhanced_binary_search_normal(rv, 58, epsilon=1)

    assert metrics.total_steps == 45
    local_steps = dict(
        steps=[
            dict(step=0, lo=-37800000000000, hi=37800000000000, mid=0),
            dict(step=1, lo=0, hi=37800000000000, mid=6070218765382),
            dict(step=2, lo=0, hi=6070218765382, mid=2867675087205),
            dict(step=3, lo=0, hi=2867675087205, mid=1415758058514),
            dict(step=4, lo=0, hi=1415758058514, mid=705692839445),
            dict(step=5, lo=0, hi=705692839445, mid=352575353697),
            dict(step=6, lo=0, hi=352575353697, mid=176253861865),
            dict(step=7, lo=0, hi=176253861865, mid=88122706188),
            dict(step=8, lo=0, hi=88122706188, mid=44060825068),
            dict(step=9, lo=0, hi=44060825068, mid=22030346533),
            dict(step=10, lo=0, hi=22030346533, mid=11015165017),
            dict(step=11, lo=0, hi=11015165017, mid=5507581478),
            dict(step=12, lo=0, hi=5507581478, mid=2753790611),
            dict(step=13, lo=0, hi=2753790611, mid=1376895290),
            dict(step=14, lo=0, hi=1376895290, mid=688447643),
            dict(step=15, lo=0, hi=688447643, mid=344223822),
            dict(step=16, lo=0, hi=344223822, mid=172111911),
            dict(step=17, lo=0, hi=172111911, mid=86055956),
            dict(step=18, lo=0, hi=86055956, mid=43027978),
            dict(step=19, lo=0, hi=43027978, mid=21513990),
            dict(step=20, lo=0, hi=21513990, mid=10756996),
            dict(step=21, lo=0, hi=10756996, mid=5378498),
            dict(step=22, lo=0, hi=5378498, mid=2689250),
            dict(step=23, lo=0, hi=2689250, mid=1344626),
            dict(step=24, lo=0, hi=1344626, mid=672313),
            dict(step=25, lo=0, hi=672313, mid=336157),
            dict(step=26, lo=0, hi=336157, mid=168079),
            dict(step=27, lo=0, hi=168079, mid=84040),
            dict(step=28, lo=0, hi=84040, mid=42021),
            dict(step=29, lo=0, hi=42021, mid=21011),
            dict(step=30, lo=0, hi=21011, mid=10506),
            dict(step=31, lo=0, hi=10506, mid=5254),
            dict(step=32, lo=0, hi=5254, mid=2628),
            dict(step=33, lo=0, hi=2628, mid=1314),
            dict(step=34, lo=0, hi=1314, mid=657),
            dict(step=35, lo=0, hi=657, mid=329),
            dict(step=36, lo=0, hi=329, mid=165),
            dict(step=37, lo=0, hi=165, mid=83),
            dict(step=38, lo=0, hi=83, mid=42),
            dict(step=39, lo=42, hi=83, mid=63),
            dict(step=40, lo=42, hi=63, mid=53),
            dict(step=41, lo=53, hi=63, mid=58),
            dict(step=42, lo=58, hi=63, mid=61),
            dict(step=43, lo=58, hi=61, mid=60),
            dict(step=44, lo=58, hi=60, mid=59),
            dict(step=45, lo=58, hi=59, mid=59),
        ]
    )

    # CI variant: identical except at steps 21 and 22. List index i holds
    # step i, so the replacement entries must carry step=21 and step=22.
    # Their lo/hi continue the step-20 and step-21 intervals.
    ci_steps = deepcopy(local_steps)
    ci_steps["steps"][21] = dict(step=21, lo=0, hi=10756996, mid=5378499)
    ci_steps["steps"][22] = dict(step=22, lo=0, hi=5378499, mid=2689250)

    # Accept either environment's trace (tuple membership compares with ==).
    assert asdict(metrics) in (local_steps, ci_steps)


def test_binary_search_normal_when_no_whole_number_gte_50_cdf():
    """Plain binary search for 58 under a very wide N(0, 9e12) with epsilon=1.

    It finishes in 46 steps, one more than the enhanced variant on the same
    distribution.
    """
    rv = norm(loc=0, scale=9000000000000)
    metrics = binary_search_normal(rv, 58, epsilon=1)

    # One (lo, hi, mid) row per step; the row position is the step index.
    trace = [
        (-37800000000000, 37800000000000, 0),
        (0, 37800000000000, 18900000000000),
        (0, 18900000000000, 9450000000000),
        (0, 9450000000000, 4725000000000),
        (0, 4725000000000, 2362500000000),
        (0, 2362500000000, 1181250000000),
        (0, 1181250000000, 590625000000),
        (0, 590625000000, 295312500000),
        (0, 295312500000, 147656250000),
        (0, 147656250000, 73828125000),
        (0, 73828125000, 36914062500),
        (0, 36914062500, 18457031250),
        (0, 18457031250, 9228515625),
        (0, 9228515625, 4614257813),
        (0, 4614257813, 2307128907),
        (0, 2307128907, 1153564454),
        (0, 1153564454, 576782227),
        (0, 576782227, 288391114),
        (0, 288391114, 144195557),
        (0, 144195557, 72097779),
        (0, 72097779, 36048890),
        (0, 36048890, 18024445),
        (0, 18024445, 9012223),
        (0, 9012223, 4506112),
        (0, 4506112, 2253056),
        (0, 2253056, 1126528),
        (0, 1126528, 563264),
        (0, 563264, 281632),
        (0, 281632, 140816),
        (0, 140816, 70408),
        (0, 70408, 35204),
        (0, 35204, 17602),
        (0, 17602, 8801),
        (0, 8801, 4401),
        (0, 4401, 2201),
        (0, 2201, 1101),
        (0, 1101, 551),
        (0, 551, 276),
        (0, 276, 138),
        (0, 138, 69),
        (0, 69, 35),
        (35, 69, 52),
        (52, 69, 61),
        (52, 61, 57),
        (57, 61, 59),
        (57, 59, 58),
        (58, 59, 58),
    ]

    assert metrics.total_steps == 46
    assert asdict(metrics) == dict(
        steps=[
            dict(step=i, lo=lo, hi=hi, mid=mid)
            for i, (lo, hi, mid) in enumerate(trace)
        ]
    )


def test_edges():
    """Search for targets far from a narrow N(0, 0.1) using explicit wide bounds.

    The positive and negative targets (+/-900) are each expected to take 34
    steps.
    """
    rv = norm(loc=0, scale=0.1)
    bounds = dict(lo=-9000000000, hi=9000000000)

    for target in (900, -900):
        metrics = binary_search_normal(rv, target, epsilon=1, **bounds)
        assert metrics.total_steps == 34
