import os
from collections import Counter
from contextlib import contextmanager
from math import floor
from typing import Generator, TypeVar

import numpy as np
from scipy.stats import norm

import bbs.normal.plot as plot
from bbs.draw_tree import (
    create_binary_tree_diagram,
    generate_enhanced_tree,
    generate_vanilla_tree,
)
from bbs.normal.experiment import BatchUtil, Box, Epsilon, KLDiv, StdDev
from bbs.stats import plot_distribution


def epsilon(scale=1000):
    """Run the Epsilon experiment at ``scale`` and show its plot.

    The batch returned by ``Epsilon.run(scale=scale)`` is turned into an
    experiments table and a KS-test table through ``BatchUtil``; both are
    passed to ``plot.Epsilon.show`` together with ``str(scale)`` as
    ``write_html_name``.
    """
    results = Epsilon.run(scale=scale)
    experiments = BatchUtil.experiments_df(results)
    ks_tests = BatchUtil.ks_test_df(results)

    # Alternative rendering kept for reference: Epsilon.mline_chart(experiments)
    plot.Epsilon.show(experiments, ks_tests, write_html_name=str(scale))


def std_dev():
    """Run the StdDev experiment and show its plot.

    Builds the experiments table and the KS-test table from the batch
    returned by ``StdDev.run()`` and hands both to ``plot.StdDev.show``.
    """
    results = StdDev.run()
    experiments = BatchUtil.experiments_df(results)
    ks_tests = BatchUtil.ks_test_df(results)
    plot.StdDev.show(experiments, ks_tests)


def kld():
    """Run the KLDiv experiment and show its plot.

    Builds the experiments table and the KS-test table from the batch
    returned by ``KLDiv.run()`` and hands both to ``plot.KLDiv.show``.
    """
    results = KLDiv.run()
    experiments = BatchUtil.experiments_df(results)
    ks_tests = BatchUtil.ks_test_df(results)
    plot.KLDiv.show(experiments, ks_tests)


def box():
    """Delegate to ``Box.show()``; takes no arguments and returns nothing."""
    Box.show()


def plot_dist():
    """Pass a frozen normal distribution (loc=0, scale=1000) to ``plot_distribution``."""
    plot_distribution(norm(loc=0, scale=1000))  # type: ignore


def testing(*, single: bool = False):
    """Ad-hoc driver for the epsilon experiment.

    Args:
        single: When true, call ``epsilon()`` once with its default scale.
            When false (the default, which matches the previously
            hard-coded behaviour), call ``epsilon`` once for every scale
            in 10**1 .. 10**8.
    """
    if single:
        # Swap in kld() or box() here to try the other experiments.
        epsilon()
        return

    for exponent in range(1, 9):
        epsilon(scale=10**exponent)


# Type of the value that ``yield_self`` hands back unchanged.
T = TypeVar("T")


@contextmanager
def yield_self(value: T) -> Generator[T, None, None]:
    """Context manager that yields ``value`` itself and logs progress.

    A "Generating" line is printed on entry. The "Saved" line is printed
    only when the ``with`` body finishes without raising, because an
    exception from the body propagates out of the ``yield`` and skips it.
    """
    print(f"🏋️Generating:\t{value}")

    yield value

    # Reached only on a clean exit from the with-body.
    print(f"✅Saved:\t{value}")


def generate_figures(
    output_dir: str = "./figures/normal/latex",
    scale: int = 10_000,
) -> None:
    """
    Save figures for paper.

    With the default ``output_dir`` the resulting layout is:

    figures/normal/latex/
      line_chart.pdf
      table.txt
      box_plot.pdf
      tree_enhanced.pdf
      tree_vanilla.pdf
      kld/
        line_chart.pdf
        table.txt

    Args:
        output_dir: Directory that receives the files. It and its ``kld``
            subdirectory are created when missing.
        scale: Value forwarded to ``Epsilon.run(scale=...)`` for the main
            line chart and table.
    """
    kld_dir = os.path.join(output_dir, "kld")
    # makedirs creates intermediate directories, so this covers output_dir too.
    os.makedirs(kld_dir, exist_ok=True)

    batch = Epsilon.run(scale=scale)
    df_exp = BatchUtil.experiments_df(batch)
    df_ks = BatchUtil.ks_test_df(batch)

    # line chart
    with yield_self(os.path.join(output_dir, "line_chart.pdf")) as filename:
        Epsilon.mline_chart(df_exp, pdf_name=filename)

    # latex table
    with yield_self(os.path.join(output_dir, "table.txt")) as filename:
        table = BatchUtil.latex_aggregate_table(df_ks)
        # Explicit encoding keeps the output independent of the locale.
        with open(filename, "w", encoding="utf-8") as file:
            file.write(table)

    # box plot
    with yield_self(os.path.join(output_dir, "box_plot.pdf")) as filename:
        df = Box.df(Box.run())
        Box.mfig(df, pdf_name=filename)

    # binary tree graphviz
    with yield_self(("tree_vanilla", "tree_enhanced")) as (
        vanilla_filename,
        enhanced_filename,
    ):
        rv = norm(loc=5, scale=1.15)
        # Fixed seed so the sampled counts are reproducible between runs.
        rv.random_state = np.random.RandomState(seed=42)
        # Count how many of the 10000 draws fall on each floored value.
        samples = Counter(floor(x) for x in rv.rvs(10000))
        tree = generate_enhanced_tree(rv)
        create_binary_tree_diagram(
            tree, enhanced_filename, directory=output_dir, samples=samples
        )

        vanilla = generate_vanilla_tree(0, 10)
        create_binary_tree_diagram(
            vanilla,
            filename=vanilla_filename,
            directory=output_dir,
            samples=samples,
        )

    # KLD
    batch = KLDiv.run()
    df_exp = BatchUtil.experiments_df(batch)
    df_ks = BatchUtil.ks_test_df(batch)

    # line chart
    with yield_self(os.path.join(kld_dir, "line_chart.pdf")) as filename:
        KLDiv.mline_chart(df_exp, pdf_name=filename)

    # latex table
    with yield_self(os.path.join(kld_dir, "table.txt")) as filename:
        table = KLDiv.latex_aggregate_table(df_ks)
        with open(filename, "w", encoding="utf-8") as file:
            file.write(table)


def run():
    """Script entry point: save the paper figures via ``generate_figures()``."""
    generate_figures()


# Only generate figures when executed as a script, not when imported.
if __name__ == "__main__":
    run()
