from contextlib import contextmanager
import os
from typing import TypeVar, Generator
import bbs.expon.plot as plot
from bbs.expon.experiment import BatchUtil, Epsilon
from bbs.stats import plot_distribution
from scipy.stats import expon


def epsilon(scale: int = 1000):
    """
    Run the Epsilon experiment at ``scale`` and show the resulting plot.

    The batch returned by ``Epsilon.run`` is passed through
    ``BatchUtil.experiments_df`` and ``BatchUtil.ks_test_df``; both results
    go to ``plot.Epsilon.show``, with ``str(scale)`` as ``write_html_name``.
    """
    results = Epsilon.run(scale=scale)
    plot.Epsilon.show(
        BatchUtil.experiments_df(results),
        BatchUtil.ks_test_df(results),
        write_html_name=str(scale),
    )


def plot_dist():
    """
    Plot a frozen scipy exponential distribution with ``scale=10**4``.
    """
    plot_distribution(expon(scale=10**4))  # type: ignore


def testing(*, single: bool = False):
    """
    Run ``epsilon`` either once or across a sweep of scales.

    single: when true, call ``epsilon()`` once with its default scale.
        When false (the default, which matches the value that used to be
        hard-coded inside this function), call ``epsilon`` for each scale
        10**1, 10**2, ..., 10**8 in increasing order.
    """
    if single:
        epsilon()
        return

    # Exponents 1..8 inclusive -> scales from 10 up to 100_000_000.
    for exponent in range(1, 9):
        epsilon(scale=10**exponent)


T = TypeVar("T")


@contextmanager
def yield_self(value: T) -> Generator[T, None, None]:
    """
    Context manager that hands ``value`` back unchanged, with progress output.

    Prints a "Generating" line on entry and a "Saved" line on exit. The exit
    line is only printed when the ``with`` body completes without raising;
    an exception from the body propagates and skips it.
    """
    start_line = f"🏋️Generating:\t{value}"
    print(start_line)

    yield value

    # Formatted after the body has run, exactly when it is printed.
    done_line = f"✅Saved:\t{value}"
    print(done_line)


def generate_figures(
    scale: int = 10_000,
    path: str = "./figures/expon/latex",
):
    """
    Save figures for paper.

    Runs the Epsilon experiment batch at ``scale`` and writes two files into
    ``path`` (created if it does not exist). With the defaults:

    figures/expon/latex/
      line_chart.pdf
      table.txt

    scale: forwarded to ``Epsilon.run``.
    path: output directory; the default is relative to the current working
        directory.
    """
    os.makedirs(path, exist_ok=True)

    batch = Epsilon.run(scale=scale)
    df_exp = BatchUtil.experiments_df(batch)
    df_ks = BatchUtil.ks_test_df(batch)

    # line chart
    # NOTE(review): the other plotting call in this file goes through
    # ``plot.Epsilon``; confirm ``mline_chart`` really lives on the
    # experiment ``Epsilon`` class used here.
    with yield_self(os.path.join(path, "line_chart.pdf")) as filename:
        Epsilon.mline_chart(df_exp, pdf_name=filename)

    # latex table
    with yield_self(os.path.join(path, "table.txt")) as filename:
        table = BatchUtil.latex_aggregate_table(df_ks)
        # Explicit encoding so the table is written identically on every
        # platform instead of using the locale's default encoding.
        with open(filename, "w", encoding="utf-8") as file:
            file.write(table)


def run():
    """
    Entry point: generate the figures via ``generate_figures``.
    """
    generate_figures()
