from math import ceil, floor
from typing import Callable, Literal, cast
from manim import *  # noqa: F403, # pyright: ignore
from scipy import stats

from bbs.search import SearchMetrics, Step, default_mid
from bbs.stats import truncnorm_median
import os


def main():
    """Render ``NormalSearch`` with manim's low-quality preset (``-pql``).

    If the manim command succeeds, the shell then runs ``open`` on the 480p15
    output file. The exit status from ``os.system`` is ignored.
    """
    render_and_open = (
        r"manim -pql bbs/animations/normal.py NormalSearch"
        r" && open media/videos/normal/480p15/NormalSearch.mp4"
    )
    os.system(render_and_open)


def main_high():
    """Render ``NormalSearch`` with manim's high-quality preset (``-pqh``).

    If the manim command succeeds, the shell then runs ``open`` on the 1080p60
    output file. The exit status from ``os.system`` is ignored.
    """
    render_and_open = (
        r"manim -pqh bbs/animations/normal.py NormalSearch"
        r" && open media/videos/normal/1080p60/NormalSearch.mp4"
    )
    os.system(render_and_open)


# python3 -m manim -pql bbs/animations/normal.py NormalSearch && open media/videos/normal/480p15/NormalSearch.mp4
# python3 -m manim -pqh bbs/animations/normal.py NormalSearch && open media/videos/normal/1080p60/NormalSearch.mp4
class NormalSearch(Scene):
    """Animate a bisection search for a target value over a normal distribution.

    The density of ``norm(loc=1000, scale=2000)`` is plotted over roughly
    mean ± 4.2 standard deviations. Each search step does the following:

    * draws a yellow probe line at the midpoint ``m``;
    * labels the comparison between ``m`` and ``target``;
    * shades the discarded part of the bracket red;
    * updates the step counter, bracket, and midpoint formula at the top left.

    When ``VANILLA`` is falsy (the "Enhanced Binary Search" title), the
    midpoint comes from ``truncnorm_median(rv)``. Otherwise ``default_mid``
    is used.
    """

    def construct(self):
        """Build and play the whole scene (entry point invoked by manim)."""
        # Truthy -> plain bisection; falsy -> distribution-aware midpoint.
        VANILLA = 0
        enhanced = "" if VANILLA else "Enhanced "
        title = Text(f"{enhanced}Binary Search").scale(0.8).to_edge(UP)
        target = 1005
        tolerance = 10

        LOC = 1000
        SCALE = 2000
        rv = stats.norm(loc=LOC, scale=SCALE)

        def normal_dist(x):
            """Return the probability density of ``rv`` at ``x``."""
            return rv.pdf(x)  # type: ignore

        # Plot range: mean ± 4.2 standard deviations, widened outward to ints.
        lo = floor(rv.std() * -4.2 + rv.mean())
        hi = ceil(rv.std() * 4.2 + rv.mean())
        # A normal density peaks at its mean; this sizes the y-axis.
        max_pdf = rv.pdf(rv.mean())  # type: ignore

        axes = Axes(
            # +0.1 presumably keeps the endpoint ``hi`` inside the range so its
            # label is drawn.
            x_range=[lo, hi + 0.1, rv.std()],
            # 25% headroom above the peak, split into four tick intervals.
            y_range=[0, max_pdf + max_pdf / 4, (max_pdf + max_pdf / 4) / 4],
            axis_config={"color": BLACK},
            x_axis_config={
                "color": WHITE,
                "numbers_to_include": [lo, hi, target],
                "include_ticks": False,
                "decimal_number_config": {"num_decimal_places": 0},
            },
            y_axis_config={"stroke_width": 0},
        )
        x = axes.get_x_axis()
        # Highlight the ``target`` label. This assumes labels are created in
        # ``numbers_to_include`` order, so index 2 is ``target``; verify if
        # that list changes.
        x.numbers[2].set_color(YELLOW)  # type: ignore

        axes_labels = axes.get_axis_labels(x_label="", y_label="")
        normal_graph = axes.plot(normal_dist, color=WHITE)

        # NOTE(review): a constant rate_func of 2 is outside manim's usual
        # [0, 1] progress range. Presumably this is meant to show the title
        # immediately; confirm.
        self.play(
            Create(axes),
            Create(axes_labels),
            Create(normal_graph),
            Write(title, rate_func=lambda _: 2),
        )
        self.wait(0.25)

        ####
        target_text = (
            MathTex(
                f"\\mu={LOC}, \\sigma={SCALE}, target={target}, tolerance={tolerance}",
                color=YELLOW,
            )
            .scale(0.6)
            .to_edge(UP)
            .shift(DOWN * 0.5)
        )
        self.play(Write(target_text, rate_func=lambda _: 2))
        ####

        # Empty placeholders. Each step Transforms them into updated text.
        steps_text = Text("", color=BLUE).scale(0.5).to_edge(LEFT).shift(UP * 3)
        bracket_text = Text("", color=BLUE).scale(0.5).to_edge(LEFT).shift(UP * 2.5)
        search_val_text = Text("", color=BLUE).scale(0.5).to_edge(LEFT).shift(UP * 2)
        # Comparison label above the probe line and an opaque black box behind
        # it. Both start empty and are morphed into place by the first step.
        top_text = MathTex("")
        bbox = Rectangle(
            width=top_text.get_width() + 0.4,
            height=top_text.get_height() + 0.2,
            fill_color=BLACK,
            fill_opacity=1,
            stroke_width=0,
        )
        bbox.move_to(top_text)

        def bisection_search(
            lo: int,
            hi: int,
            target: int,
            epsilon: int = 1,
            mid_func: Callable[[int, int], int] = default_mid,
        ):
            """Animate a bisection search for ``target`` in the bracket [lo, hi).

            Each iteration probes ``m = mid_func(lo, hi)``, clamped into the
            open interval (lo, hi). If ``m <= target``, ``lo`` moves up to
            ``m``; otherwise ``hi`` moves down to ``m``. A target that starts
            inside [lo, hi) therefore stays there. The search stops once
            ``hi - lo <= epsilon``.

            Args:
                lo: Lower bound of the initial bracket.
                hi: Upper bound of the initial bracket.
                target: Value being searched for.
                epsilon: Stop once the bracket width is at most this.
                    Must be >= 1.
                mid_func: Chooses the probe point for a bracket ``(lo, hi)``.

            Returns:
                A ``SearchMetrics`` with one ``Step`` per probe, plus a final
                ``Step`` for the terminal bracket. That final step's ``mid`` is
                the last probe, or None if no probe was needed.

            Raises:
                ValueError: If ``epsilon < 1``.
            """

            def sign(val) -> Literal[-1, 1]:
                """
                Sign of ``val - target``, with ties mapped to -1.

                sign(0) = -1, i.e. val == target counts as "not above target".
                """
                return -1 if val <= target else 1

            if epsilon < 1:
                # ValueError subclasses Exception, so existing handlers still catch it.
                raise ValueError("ε must be >= 1")

            steps = -1  # First step is 0th step because we haven't probed yet.
            metrics = SearchMetrics(steps=[])

            last_mid = None
            while lo <= hi:
                steps += 1

                new_bracket_text = (
                    Text(f"[{lo}, {hi})", color=BLUE)
                    .scale(0.5)
                    .to_edge(LEFT)
                    .shift(UP * 2.5)
                )

                new_steps_text = (
                    Text(f"Steps: {steps}", color=BLUE)
                    .scale(0.5)
                    .to_edge(LEFT)
                    .shift(UP * 3)
                )

                if hi - lo <= epsilon:
                    # last_mid is still None if the initial bracket was already
                    # within tolerance (no probe was made).
                    last_mid = cast(int, last_mid)
                    last_step = Step(step=steps, lo=lo, hi=hi, mid=last_mid)
                    empty = Text("", color=RED).scale(0.6).to_edge(LEFT).shift(UP * 1.5)
                    self.play(
                        Transform(bracket_text, new_bracket_text),
                        Transform(search_val_text, empty),
                        Transform(steps_text, new_steps_text),
                    )
                    metrics.steps.append(last_step)
                    return metrics

                m = mid_func(lo, hi)

                # NOTE: mid_func can return a bracket endpoint. For example,
                # norm(loc=5, scale=1.15) with lo=0, hi=5 gives m == 5 == hi.
                # - Probing at hi makes no progress (hi = m), which loops forever.
                # - Probing at or below lo leaves lo unchanged when m <= target,
                #   which also loops forever.
                # - Probing above hi widens the bracket.
                # Here hi - lo > epsilon >= 1 with ints, so hi - lo >= 2, and
                # the clamp below always lands strictly inside (lo, hi).
                m = min(max(m, lo + 1), hi - 1)

                # for vanilla: add
                if VANILLA:
                    new_search_val_text = (
                        MathTex(
                            r"m = \left\lceil\frac{lo + hi}{2}\right\rceil = {search}".replace(
                                "{search}", str(m)
                            ),
                            color=BLUE,
                        )
                        .scale(0.5)
                        .to_edge(LEFT)
                        .shift(UP * 2)
                    )
                else:
                    new_search_val_text = (
                        MathTex(
                            r"m = \left\lceil F_{X[lo,hi]}^{-1}(0.5)\right\rceil = {search}".replace(
                                "{search}", str(m)
                            ),
                            color=BLUE,
                        )
                        .scale(0.5)
                        .to_edge(LEFT)
                        .shift(UP * 2)
                    )

                y = normal_dist(m)
                # Create vertical line at current search value
                vert_line = axes.get_vertical_line(axes.c2p(m, y), color=YELLOW)  # type: ignore

                s = sign(m)
                metrics.steps.append(Step(step=steps, lo=lo, hi=hi, mid=m))

                new_top_text = (
                    MathTex(
                        f"{target} < {m}" if s == 1 else f"{m} \\leq {target}",
                        color=YELLOW,
                    )
                    .scale(0.5)
                    .next_to(vert_line, UP)
                )
                # Opaque background so the label stays readable over the curve.
                new_bbox = Rectangle(
                    width=new_top_text.get_width() + 0.4,
                    height=new_top_text.get_height() + 0.2,
                    fill_color=BLACK,
                    fill_opacity=1,
                    stroke_width=0,
                )
                new_bbox.move_to(new_top_text)

                # bbox is transformed before top_text, so the text is drawn on top.
                self.play(
                    Transform(bracket_text, new_bracket_text),
                    Transform(search_val_text, new_search_val_text),
                    Transform(steps_text, new_steps_text),
                    Create(vert_line),
                    Transform(bbox, new_bbox),
                    Transform(top_text, new_top_text),
                )

                if s == 1:  # sign: m - x (m > x)
                    # target < m: discard [m, hi) by shading it, then shrink hi.
                    right_area = axes.get_area(
                        normal_graph,
                        x_range=(m, hi),
                        color=RED,
                        stroke_width=0,
                        opacity=0.2,
                    )
                    self.play(Create(right_area))
                    hi = m
                else:
                    # m <= target: discard [lo, m) by shading it, then raise lo.
                    left_area = axes.get_area(
                        normal_graph,
                        x_range=(lo, m),
                        color=RED,
                        stroke_width=0,
                        opacity=0.2,
                    )
                    self.play(Create(left_area))
                    lo = m
                last_mid = m
                self.wait(0.5)
                self.play(FadeOut(vert_line))

        if VANILLA:
            bisection_search(lo, hi, target, epsilon=tolerance)
        else:
            bisection_search(
                lo, hi, target, epsilon=tolerance, mid_func=truncnorm_median(rv)
            )

        self.wait(2)
