import numpy as np
import graphviz
from dataclasses import dataclass
from math import ceil, floor
from typing import Optional
from collections import Counter

from graphviz.graphs import Digraph
from scipy.stats import distributions, norm

from bbs.search import default_mid
from bbs.stats import truncnorm_median


@dataclass(frozen=True, kw_only=True)
class Value:
    lo: int
    hi: int
    search_val: int


class Node:
    def __init__(self, value: Value):
        self.left = None
        self.right = None
        self.value = value


def generate_tree(lo: int, hi: int, mid_func=default_mid):
    """Build the decision tree explored by a bisection search over ``[lo, hi)``.

    Every node covers a half-open interval ``[lo, hi)`` and stores the probe
    ``search_val`` picked for it by ``mid_func``. The left child covers
    ``[lo, search_val)`` (probe above the target) and the right child covers
    ``[search_val, hi)`` (probe at or below the target). Intervals of width
    ``<= 1`` are leaves.

    Args:
        lo: Inclusive lower bound of the search range.
        hi: Exclusive upper bound of the search range.
        mid_func: Callable ``(lo, hi) -> int`` returning the probe for an
            interval. Probes at or beyond an endpoint are clamped into
            ``[lo + 1, hi - 1]`` so that every split makes progress.

    Returns:
        The root :class:`Node` of the fully expanded tree.
    """

    def split(a: int, b: int) -> int:
        # NOTE: mid_func may return an endpoint, e.g.
        # rv = norm(loc=5, scale=1.15)
        # truncnorm_median(rv, 0, 5) == 5
        # i.e. CDF([4, 5]) > CDF([0, 4])
        # A probe of b leaves the right child empty. A probe of a gives the
        # right child the same interval as its parent, so construction would
        # never terminate. Clamp into [a + 1, b - 1] while the interval is
        # still wide enough to split.
        s = mid_func(a, b)
        if s >= b:
            s = b - 1
        if b - a > 1 and s <= a:
            s = a + 1
        return s

    root = Node(Value(lo=lo, hi=hi, search_val=split(lo, hi)))

    # Explicit stack instead of recursion: skewed mid functions produce
    # chains up to roughly hi - lo nodes deep.
    pending = [root]
    while pending:
        node = pending.pop()
        start, stop, probe = node.value.lo, node.value.hi, node.value.search_val

        if stop - start <= 1:  # epsilon = 1
            continue

        # sign(search_val) == 1, search_val > target
        node.left = Node(Value(lo=start, hi=probe, search_val=split(start, probe)))
        # sign(search_val) == -1, search_val <= target
        node.right = Node(Value(lo=probe, hi=stop, search_val=split(probe, stop)))

        pending.append(node.left)
        pending.append(node.right)

    return root


def add_nodes_edges(
    graph: Digraph, node: Node, pos="", depth=0, samples: Optional[dict] = None
):
    """Recursively add ``node`` and all of its descendants to ``graph``.

    Every node is labelled with its interval ``[lo, hi)``.
    - Internal nodes also show their probe value and are filled light blue.
    - Leaves are filled light green. When ``samples`` is non-empty, leaves
      also show how many samples fell into their interval.
    - Edges are labelled with the comparison that selects each branch.

    Args:
        graph: Graphviz digraph to draw into.
        node: Root of the subtree to draw; ``None`` draws nothing.
        pos: Path from the tree root as a string of ``L``/``R`` steps. Only
            threaded through the recursion; not used in the output.
        depth: Depth of ``node``. Only threaded through the recursion; not
            used in the output.
        samples: Mapping from integer bucket (``floor`` of a sample) to
            count. Buckets missing from the mapping count as zero.
    """
    if node is None:
        return

    node_id = str(id(node))
    lo, hi, probe = node.value.lo, node.value.hi, node.value.search_val
    is_leaf = node.left is None and node.right is None

    label = f"[{lo}, {hi})"
    if not is_leaf:
        label += f"\n{probe}"
    elif samples:
        # Sum over every integer bucket in [lo, hi). An empty interval shows 0
        # and a plain dict lacking a bucket does not raise KeyError.
        count = sum(samples.get(bucket, 0) for bucket in range(lo, hi))
        label += f"\n{count}x"

    graph.node(node_id, label, fillcolor="lightgreen" if is_leaf else "lightblue")

    if node.left is not None:
        graph.edge(node_id, str(id(node.left)), f"x < {probe}")
        add_nodes_edges(graph, node.left, f"{pos}L", depth + 1, samples=samples)

    if node.right is not None:
        graph.edge(node_id, str(id(node.right)), f"x >= {probe}")
        add_nodes_edges(graph, node.right, f"{pos}R", depth + 1, samples=samples)


def create_binary_tree_diagram(
    root,
    filename="binary_tree",
    directory="./figures/binary_tree",
    samples: Optional[dict] = None,
):
    """Render the tree rooted at ``root`` as a PDF inside ``directory``.

    Args:
        root: Root :class:`Node` of the tree to draw.
        filename: Base name of the output file (``.pdf`` is appended by
            Graphviz).
        directory: Directory the PDF is written to.
        samples: Optional bucket -> count mapping shown on leaf nodes.

    Returns:
        The path of the rendered PDF, as reported by ``Digraph.render``.
    """
    dot = graphviz.Digraph(comment="Binary Tree")
    dot.attr(rankdir="TB", size="80,110!")
    dot.attr(
        "node",
        shape="circle",
        style="filled",
        fillcolor="lightblue",
        fontname="SF Mono",
        fontsize="12",
        width="0.5",
        height="0.5",
    )

    add_nodes_edges(dot, root, samples=samples)

    # cleanup=True removes the intermediate DOT source after rendering, so
    # only the PDF remains.
    output_path = dot.render(filename, directory=directory, format="pdf", cleanup=True)
    # Report where the file really went: it lives under `directory`, not in
    # the current working directory.
    print(f"The binary tree diagram has been saved as {output_path}")
    return output_path


def generate_vanilla_tree(lo: int, hi: int):
    """Build the search tree over ``[lo, hi)`` using the default split.

    This is :func:`generate_tree` with ``mid_func`` left at its default,
    ``default_mid``.
    """
    tree_root = generate_tree(lo, hi, mid_func=default_mid)
    return tree_root


def resolve_bounds(
    rv: distributions.rv_frozen,
    lo: Optional[int] = None,
    hi: Optional[int] = None,
    width: float = 4.2,
) -> tuple[int, int]:
    """Return integer search bounds for ``rv``, filling in whichever is missing.

    A missing ``lo`` defaults to ``floor(mean - width * std)``. A missing
    ``hi`` defaults to ``ceil(mean + width * std)``. For a normal ``rv`` the
    default width of 4.2 leaves out roughly 3e-5 of the probability mass.
    Explicit bounds of any integer type, including NumPy integers, are kept
    and converted to plain ``int``.

    Raises:
        TypeError: If an explicit bound is not an integer (e.g. a float).
    """
    from operator import index

    if lo is None:
        lo = floor(rv.mean() - width * rv.std())
    else:
        lo = index(lo)

    if hi is None:
        hi = ceil(rv.mean() + width * rv.std())
    else:
        hi = index(hi)

    return lo, hi


def generate_enhanced_tree(
    rv: distributions.rv_frozen,
    lo: Optional[int] = None,
    hi: Optional[int] = None,
    width: float = 4.2,
):
    """Build a search tree whose probes follow the distribution ``rv``.

    Missing bounds are derived from ``rv`` by :func:`resolve_bounds`. Probes
    come from ``truncnorm_median(rv)``, presumably the median of ``rv``
    truncated to each interval; confirm against ``bbs.stats``.

    Args:
        rv: Frozen SciPy distribution describing where the target is likely.
        lo: Inclusive lower bound; defaults to ``mean - width * std`` (floored).
        hi: Exclusive upper bound; defaults to ``mean + width * std`` (ceiled).
        width: Number of standard deviations used for the default bounds.
    """
    lo, hi = resolve_bounds(rv, lo, hi, width=width)
    return generate_tree(lo, hi, mid_func=truncnorm_median(rv))


def run():
    """Render the vanilla and distribution-aware search trees.

    Both diagrams annotate their leaves with counts from 10,000 draws of
    ``norm(loc=5, scale=1.15)``, seeded with 42 and bucketed by ``floor``.
    """
    dist = norm(loc=5, scale=1.15)
    dist.random_state = np.random.RandomState(seed=42)
    histogram = Counter(map(floor, dist.rvs(10000)))

    # Each tree is built and rendered before the next one is built.
    create_binary_tree_diagram(
        generate_vanilla_tree(0, 10), "vanilla", samples=histogram
    )
    create_binary_tree_diagram(
        generate_enhanced_tree(dist), "enhanced", samples=histogram
    )


if __name__ == "__main__":
    run()
