import statistics
from dataclasses import dataclass
from math import floor
from statistics import mean
from typing import Literal, NamedTuple, Optional, cast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import scienceplots  # noqa: F401
from scipy.stats import distributions, expon, ks_2samp

from bbs.experiment import ExperimentMetrics, Trial
from bbs.expon.search import binary_search_expon, enhanced_binary_search_expon
from bbs.latex import table

# Apply the SciencePlots "science" and "ieee" styles to every matplotlib figure
# produced by this module (scienceplots is imported above only to register them).
plt.style.use(["science", "ieee"])
# Override figure.dpi after the "ieee" style is applied;
# see https://github.com/garrettj403/SciencePlots/issues/60 for ieee
plt.rcParams["figure.dpi"] = "100"


@dataclass(frozen=True, kw_only=True)
class ExponentialDistributionParameters:
    scale: int


@dataclass(kw_only=True, frozen=True)
class ExponentialExperimentMetrics(ExperimentMetrics):
    """``ExperimentMetrics`` tagged with the exponential-distribution parameters used."""

    # Parameters of the distribution that was handed to the search routine.
    parameters: ExponentialDistributionParameters


# Name of a binary-search variant: "basic" selects binary_search_expon and
# "enhanced" selects enhanced_binary_search_expon (see run_expon_experiment).
Kind = Literal["basic"] | Literal["enhanced"]
# Experiment results grouped by variant. BatchUtil.ks_test_df pairs the two
# lists by position, so both are expected to be built in the same order.
Batch = dict[Kind, list[ExponentialExperimentMetrics]]


def run_expon_experiment(
    rv: distributions.rv_frozen,
    targets: list[int],
    method: Kind,
    epsilon: int = 1,
) -> ExponentialExperimentMetrics:
    """Run one search variant against every target and collect per-trial metrics.

    Args:
        rv: Frozen ``scipy.stats.expon`` distribution handed to the search
            routine; its scale is recorded in the returned metrics.
        targets: Values to search for; one trial per target, numbered from 1
            in input order.
        method: ``"basic"`` uses ``binary_search_expon``; ``"enhanced"`` uses
            ``enhanced_binary_search_expon``.
        epsilon: Tolerance forwarded unchanged to the search routine.

    Returns:
        An ``ExponentialExperimentMetrics`` whose ``trials`` list holds one
        ``Trial`` per target, in the same order as ``targets``.

    Raises:
        KeyError: if ``method`` is not ``"basic"`` or ``"enhanced"``.
    """
    # Recover the scale however the distribution was frozen: expon(scale=s),
    # expon(loc, s) positionally, or expon() with scipy's default scale of 1.
    # Indexing rv.kwds["scale"] alone raised KeyError for the latter two.
    if "scale" in rv.kwds:
        scale = rv.kwds["scale"]
    elif len(rv.args) >= 2:
        scale = rv.args[1]
    else:
        scale = 1

    metrics = ExponentialExperimentMetrics(
        trials=[],
        epsilon=epsilon,
        parameters=ExponentialDistributionParameters(scale=scale),
    )

    method_to_search = {
        "basic": binary_search_expon,
        "enhanced": enhanced_binary_search_expon,
    }
    search = method_to_search[method]

    for i, target in enumerate(targets, 1):
        metrics.trials.append(
            Trial(
                trial_number=i,
                target=target,
                metrics=search(rv, target, epsilon),
            )
        )

    return metrics


class KsTestResult(NamedTuple):
    """Typed view of the result returned by ``scipy.stats.ks_2samp``.

    In this module it is only used as the target of ``typing.cast``, so
    attribute access on the KS result type-checks; nothing is converted at
    runtime.
    """

    statistic: float  # the two-sample KS statistic
    pvalue: float  # p-value for the chosen alternative hypothesis
    statistic_location: float  # data value at which the statistic is attained
    statistic_sign: int  # sign of the empirical-CDF difference at that value


# Readability aliases only: each is a plain pandas DataFrame, named after the
# function that builds it (BatchUtil.ks_test_df, BatchUtil.experiments_df, Box.df).
KsTestDF = ExperimentDF = BoxDF = pd.DataFrame


class BatchUtil:
    """Summarize a ``Batch`` of basic/enhanced experiments as DataFrames and tables."""

    @staticmethod
    def ks_test_df(batch: Batch) -> KsTestDF:
        """Compare basic vs. enhanced step counts, one row per experiment pair.

        The i-th ``"basic"`` experiment is paired with the i-th ``"enhanced"``
        one. Each row holds:
        - a one-sided two-sample Kolmogorov-Smirnov test (``alternative="less"``);
        - the percentage decrease in mean steps;
        - LaTeX-formatted "mean ± stdev" strings for both variants.

        Raises:
            ValueError: if the two lists have different lengths, or if a pair
                was run with different parameters or epsilon. Without this
                check the row would be labelled with the basic side's settings
                only.
        """
        data = []
        # strict=True: a length mismatch must not silently drop experiments.
        for basic_experiment, enhanced_experiment in zip(
            batch["basic"], batch["enhanced"], strict=True
        ):
            if (
                basic_experiment.parameters != enhanced_experiment.parameters
                or basic_experiment.epsilon != enhanced_experiment.epsilon
            ):
                raise ValueError(
                    "paired experiments differ: "
                    f"basic(parameters={basic_experiment.parameters}, "
                    f"epsilon={basic_experiment.epsilon}) vs "
                    f"enhanced(parameters={enhanced_experiment.parameters}, "
                    f"epsilon={enhanced_experiment.epsilon})"
                )

            basic_steps = basic_experiment.steps
            enhanced_steps = enhanced_experiment.steps

            mean_b = statistics.mean(basic_steps)
            mean_e = statistics.mean(enhanced_steps)
            # statistics.stdev requires at least two samples per side.
            stdev_b = statistics.stdev(basic_steps)
            stdev_e = statistics.stdev(enhanced_steps)

            # alternative="less": H1 is that the empirical CDF of basic_steps
            # lies below that of enhanced_steps somewhere, i.e. basic tends to
            # need more steps than enhanced.
            res = cast(
                KsTestResult, ks_2samp(basic_steps, enhanced_steps, alternative="less")
            )
            data.append(
                {
                    "scale": basic_experiment.parameters.scale,
                    "epsilon": basic_experiment.epsilon,
                    "num_trials": basic_experiment.num_trials,
                    "ks_test_pvalue": res.pvalue,
                    # "\\%" and "$\\pm$" are LaTeX escapes for the table output.
                    "percent_decrease": f"{((mean_b - mean_e) / mean_b * 100):.2f}\\%",
                    "mean_basic": f"{mean_b:.2f} $\\pm$ {stdev_b:.2f}",
                    "mean_enhanced": f"{mean_e:.2f} $\\pm$ {stdev_e:.2f}",
                }
            )
        return pd.DataFrame.from_records(data)

    @staticmethod
    def latex_aggregate_table(df: KsTestDF):
        """Render a ``ks_test_df`` result as a LaTeX table via ``bbs.latex.table``.

        The caption reports ``num_trials`` from the first row only, so it
        assumes every row used the same number of trials. Raises IndexError
        when ``df`` is empty.
        """
        # NOTE(review): presumably `table` uses these entries to relabel
        # column headers; confirm against bbs.latex.
        header_to_title = {
            "epsilon": r"$\epsilon$",
            "percent_decrease": "Percent Decrease",
            "mean_basic": "Basic Mean Steps",
            "mean_enhanced": "Bayesian Mean Steps",
        }

        n = df["num_trials"].iloc[0]
        caption = f"Exponential params, n={n}"
        return table(df, header_to_title, caption=caption)

    @staticmethod
    def experiments_df(batch: Batch) -> ExperimentDF:
        """Build one row of step-count summary statistics per experiment.

        Columns: kind, scale, epsilon, total_steps, mean_steps, the five-number
        summary (min, q1, median, q3, max) and count.
        """
        data = []
        for kind, experiments in batch.items():
            for metrics in experiments:
                # Read once and reuse for every statistic below.
                steps = metrics.steps
                data.append(
                    {
                        "kind": kind,
                        "scale": metrics.parameters.scale,
                        "epsilon": metrics.epsilon,
                        "total_steps": sum(steps),
                        "mean_steps": mean(steps),
                        "min": np.min(steps),
                        "q1": np.percentile(steps, 25),
                        "median": np.median(steps),
                        "q3": np.percentile(steps, 75),
                        "max": np.max(steps),
                        "count": len(steps),
                    }
                )
        return pd.DataFrame.from_records(data)


class Epsilon:
    """Sweep epsilon at a fixed exponential scale and chart basic vs. enhanced search."""

    @staticmethod
    def run(
        scale: int = 1000,
        *,
        num_trials: int = 200,
        max_epsilon: int = 32,
        seed: Optional[int] = 42,
    ) -> Batch:
        """Run both search variants for every epsilon on one shared target set.

        Targets are ``num_trials`` floored draws from ``expon(scale=scale)``.
        They are drawn once and reused for every epsilon and both variants.

        Args:
            scale: Scale of the exponential distribution.
            num_trials: Number of targets (trials) per experiment.
            max_epsilon: Largest epsilon to try. Epsilons run from 1 to
                ``min(max_epsilon, scale - 1)`` inclusive.
            seed: Seed for the distribution's ``RandomState``. The default
                (42) makes the target set reproducible.

        Returns:
            ``{"basic": [...], "enhanced": [...]}`` with one metrics object per
            epsilon. Both lists are in ascending-epsilon order.
        """
        rv = expon(scale=scale)
        rv.random_state = np.random.RandomState(seed=seed)
        targets = [floor(rv.rvs()) for _ in range(num_trials)]

        # Exclusive upper bound keeps epsilon strictly below the scale.
        epsilons = list(range(1, min(max_epsilon + 1, scale)))
        print(epsilons)
        # All "basic" runs happen before all "enhanced" runs, matching the
        # dict's insertion order.
        return {
            "basic": [
                run_expon_experiment(rv, targets, "basic", epsilon=epsilon)
                for epsilon in epsilons
            ],
            "enhanced": [
                run_expon_experiment(rv, targets, "enhanced", epsilon=epsilon)
                for epsilon in epsilons
            ],
        }

    @staticmethod
    def table(df: ExperimentDF):
        """Render every column of ``df`` as a left-aligned plotly table."""
        return go.Table(
            header=dict(values=list(df.columns), align="left"),
            cells=dict(
                values=[df[col] for col in df.columns],
                align="left",
            ),
        )

    @staticmethod
    def table_ks(df: KsTestDF):
        """Render a KS-test DataFrame with the same layout as ``Epsilon.table``."""
        return Epsilon.table(df)

    @staticmethod
    def line_chart(df: ExperimentDF):
        """Plotly line chart of mean steps vs. epsilon, one line per search kind."""
        fig = px.line(
            df,
            x="epsilon",
            y="mean_steps",
            color="kind",
            title="Exponential: Mean Steps vs. Epsilon",
            markers=True,
        )
        fig.update_layout(yaxis_title="Mean Steps", xaxis_title="Epsilon")
        return fig

    @staticmethod
    def mline_chart(df: ExperimentDF, pdf_name: Optional[str] = None):
        """Matplotlib line chart of mean steps vs. epsilon, one line per kind.

        If ``pdf_name`` is non-empty, the figure is saved to that path and
        closed. Otherwise it is displayed with ``plt.show()``. The figure is
        returned in both cases.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        for kind in df["kind"].unique():
            data = df[df["kind"] == kind]
            ax.plot(data["epsilon"], data["mean_steps"], marker="o", label=kind)

        ax.set_xlabel("Epsilon")
        ax.set_ylabel("Mean Steps")
        ax.set_title("Exponential: Mean Steps vs. Epsilon")
        ax.legend()
        ax.grid(True)
        if pdf_name:
            # Save and close this figure explicitly rather than whatever
            # pyplot currently considers the "current" figure.
            fig.savefig(pdf_name)
            plt.close(fig)
        else:
            plt.show()

        return fig


class Box:
    """Box plot of bracket size per search step, basic vs. enhanced search."""

    @staticmethod
    def run(
        *,
        num_trials: int = 1000,
        scale: int = 10**4,
        epsilon: int = 8,
        seed: Optional[int] = None,
    ) -> Batch:
        """Run one basic and one enhanced experiment on a shared target set.

        Args:
            num_trials: Number of targets, drawn as floored samples from
                ``expon(scale=scale)``.
            scale: Scale of the exponential distribution.
            epsilon: Tolerance forwarded to both search variants.
            seed: If given, seeds the distribution's ``RandomState`` (as
                ``Epsilon.run`` does) so the targets are reproducible. With
                ``None`` the distribution's default random state is left
                untouched, so the targets differ between runs.

        Returns:
            ``{"basic": [metrics], "enhanced": [metrics]}``, with one
            experiment per kind.
        """
        rv = expon(scale=scale)
        if seed is not None:
            rv.random_state = np.random.RandomState(seed=seed)
        targets = [floor(rv.rvs()) for _ in range(num_trials)]

        basic = [run_expon_experiment(rv, targets, "basic", epsilon=epsilon)]
        enhanced = [run_expon_experiment(rv, targets, "enhanced", epsilon=epsilon)]

        return {"basic": basic, "enhanced": enhanced}

    @staticmethod
    def df(batch: Batch):
        """Flatten a batch into one row per search step of every trial.

        Columns: kind, step, bracket_size (``hi - lo + 1``), scale, epsilon.
        """
        data = []
        for kind, experiments in batch.items():
            for metrics in experiments:
                for trial in metrics.trials:
                    for step in trial.metrics.steps:
                        data.append(
                            {
                                "kind": kind,
                                "step": step.step,
                                # +1 presumably because lo/hi are inclusive bounds.
                                "bracket_size": step.hi - step.lo + 1,
                                "scale": metrics.parameters.scale,
                                "epsilon": metrics.epsilon,
                            }
                        )

        return pd.DataFrame.from_records(data)

    # TODO: kind, step, num_samples, num_samples_finished
    @staticmethod
    def table(df: BoxDF):
        """Render every column of ``df`` as a left-aligned plotly table."""
        return go.Table(
            header=dict(values=list(df.columns), align="left"),
            cells=dict(
                values=[df[col] for col in df.columns],
                align="left",
            ),
        )

    @staticmethod
    def fig(df: BoxDF):
        """Plotly box plot of bracket size at each step, coloured by search kind."""
        return px.box(df, x="step", y="bracket_size", color="kind")

    @staticmethod
    def show():
        """Run the default experiment and display its box plot."""
        Box.fig(Box.df(Box.run())).show()
