import statistics
from dataclasses import dataclass
from decimal import Decimal
from math import ceil, floor
from statistics import mean
from typing import Literal, NamedTuple, Optional, cast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import scienceplots  # noqa: F401
import seaborn as sns
from scipy.stats import distributions, ks_2samp, norm

from bbs.experiment import ExperimentMetrics, Trial
from bbs.kl.kl import kl_divergent_norm
from bbs.latex import table
from bbs.normal.search import binary_search_normal, enhanced_binary_search_normal

# Apply the SciencePlots "science" + "ieee" matplotlib styles to every figure
# produced by this module. NOTE(review): presumably these style names are
# registered by the side-effect-only `import scienceplots` at the top of the
# file — confirm.
plt.style.use(["science", "ieee"])
# https://github.com/garrettj403/SciencePlots/issues/60 for ieee
# Force figure DPI back to 100 after the ieee style is applied; presumably the
# ieee style's own DPI setting is unsuitable here — see the issue linked above.
plt.rcParams.update({"figure.dpi": "100"})


@dataclass(frozen=True, kw_only=True)
class NormalDistributionParameters:
    loc: int
    scale: int
    kld: float


@dataclass(kw_only=True, frozen=True)
class NormalExperimentMetrics(ExperimentMetrics):
    """Experiment metrics tagged with the normal-distribution setup they came from."""

    # True distribution parameters (plus KL setting) used for this experiment.
    parameters: NormalDistributionParameters


# Which search routine an experiment used: "basic" (binary_search_normal) or
# "enhanced" (enhanced_binary_search_normal), per run_normal_experiment.
Kind = Literal["basic"] | Literal["enhanced"]
# Experiments grouped by search kind. In this module the two lists are built
# in parallel, and BatchUtil.ks_test_df pairs them up by index.
Batch = dict[Kind, list[NormalExperimentMetrics]]


def run_normal_experiment(
    rv: distributions.rv_frozen,
    targets: list[int],
    method: Kind,
    epsilon: int = 1,
    kl: float = 0,
) -> NormalExperimentMetrics:
    """Search for every target with one search routine and collect the trials.

    The recorded parameters (mean, standard deviation, KL setting) always
    describe the *true* distribution ``rv``. When ``kl`` is non-zero the search
    is handed ``kl_divergent_norm(rv, kl)`` instead of ``rv``, and its bracket is
    pinned to the true mean +/- 4.2 true standard deviations (floored/ceiled).

    Args:
        rv: Frozen distribution the targets were drawn from.
        targets: Values to search for; trial numbers start at 1.
        method: ``"basic"`` or ``"enhanced"``; selects the search routine.
        epsilon: Tolerance forwarded to the search routine.
        kl: Value passed to ``kl_divergent_norm``; 0 leaves ``rv`` unchanged.

    Returns:
        A ``NormalExperimentMetrics`` holding one ``Trial`` per target, in order.

    Raises:
        KeyError: if ``method`` is neither ``"basic"`` nor ``"enhanced"``.
    """
    result = NormalExperimentMetrics(
        trials=[],
        epsilon=epsilon,
        parameters=NormalDistributionParameters(
            loc=rv.mean(), scale=rv.std(), kld=kl
        ),
    )

    search = {
        "basic": binary_search_normal,
        "enhanced": enhanced_binary_search_normal,
    }[method]

    lo: Optional[int] = None
    hi: Optional[int] = None
    prior = rv
    if kl:
        # Bracket comes from the true distribution; only the prior diverges.
        lo = floor(rv.mean() - 4.2 * rv.std())
        hi = ceil(rv.mean() + 4.2 * rv.std())
        prior = kl_divergent_norm(rv, kl)

    result.trials.extend(
        Trial(
            trial_number=number,
            target=goal,
            metrics=search(prior, goal, epsilon, lo=lo, hi=hi),
        )
        for number, goal in enumerate(targets, start=1)
    )
    return result


class KsTestResult(NamedTuple):
    """Typed view of the object returned by ``scipy.stats.ks_2samp``.

    In this module it is used with ``typing.cast`` so that attribute access on
    the scipy result type-checks.
    """

    statistic: float  # mirrors scipy's `statistic`
    pvalue: float  # mirrors scipy's `pvalue`
    statistic_location: float  # mirrors scipy's `statistic_location`
    statistic_sign: int  # mirrors scipy's `statistic_sign`


# Aliases of pd.DataFrame that name which table a function consumes/produces.
# They are plain aliases: no column schema is enforced at runtime.
KsTestDF = pd.DataFrame  # e.g. the output of BatchUtil.ks_test_df
ExperimentDF = pd.DataFrame  # e.g. the output of BatchUtil.experiments_df
BoxDF = pd.DataFrame  # e.g. the output of Box.df


class BatchUtil:
    """Turn a ``Batch`` of experiments into pandas tables and LaTeX output."""

    @staticmethod
    def _assumed_mean(parameters: NormalDistributionParameters) -> float:
        """Return the mean of ``kl_divergent_norm`` applied to the true normal.

        The true normal is rebuilt from ``parameters.loc`` / ``parameters.scale``
        and passed to ``kl_divergent_norm`` with ``parameters.kld``.
        """
        true_rv = norm(loc=parameters.loc, scale=parameters.scale)
        return kl_divergent_norm(true_rv, parameters.kld).mean()

    @staticmethod
    def ks_test_df(batch: Batch) -> KsTestDF:
        """Compare paired basic/enhanced experiments with a one-sided KS test.

        ``batch["basic"][i]`` is paired with ``batch["enhanced"][i]``. Each pair
        yields one row with the basic experiment's parameters, the KS p-value,
        the percent decrease in mean steps, and mean +/- sample standard
        deviation of the steps for both methods. Text columns are pre-formatted
        as LaTeX snippets.

        Raises:
            ValueError: if the basic and enhanced lists differ in length.
            statistics.StatisticsError: if an experiment has fewer than two
                step counts (``statistics.stdev`` needs at least two).
        """
        data = []
        # strict=True: unequal lists would otherwise be silently truncated,
        # dropping the unpaired experiments from the table.
        for basic_experiment, enhanced_experiment in zip(
            batch["basic"], batch["enhanced"], strict=True
        ):
            basic_steps = basic_experiment.steps
            enhanced_steps = enhanced_experiment.steps

            mean_b = statistics.mean(basic_steps)
            mean_e = statistics.mean(enhanced_steps)
            stdev_b = statistics.stdev(basic_steps)
            stdev_e = statistics.stdev(enhanced_steps)

            # alternative="less": H1 is that the empirical CDF of the basic
            # steps lies below the enhanced one for some x, i.e. basic tends
            # toward more steps than enhanced.
            res = cast(
                KsTestResult, ks_2samp(basic_steps, enhanced_steps, alternative="less")
            )
            # In this module paired experiments share rv/targets/epsilon, so
            # the basic side's parameters describe the whole row.
            params = basic_experiment.parameters
            row = {
                "std_dev": params.scale,
                "kld": params.kld,
                "real_mean": params.loc,
                "assumed_mean": BatchUtil._assumed_mean(params),
                "epsilon": basic_experiment.epsilon,
                "num_trials": basic_experiment.num_trials,
                "ks_test_pvalue": f"{Decimal(res.pvalue):.2E}",
                "percent_decrease": f"{((mean_b - mean_e) / mean_b * 100):.2f}\\%",
                "mean_basic": f"{mean_b:.2f} $\\pm$ {stdev_b:.2f}",
                "mean_enhanced": f"{mean_e:.2f} $\\pm$ {stdev_e:.2f}",
            }
            data.append(row)
        return pd.DataFrame.from_records(data)

    @staticmethod
    def latex_aggregate_table(df: KsTestDF):
        """Render a KS-test table (one row per epsilon) as LaTeX via ``table``.

        The caption reports mu, sigma and N taken from the first row, so all rows
        are expected to share them. Raises ``IndexError`` on an empty frame.
        """
        header_to_title = {
            "epsilon": r"$\epsilon$",
            "percent_decrease": "Percent Decrease",
            "mean_basic": "Basic Mean Steps",
            "mean_enhanced": "Bayesian Mean Steps",
        }

        mu = df["real_mean"].iloc[0]
        sigma = df["std_dev"].iloc[0]
        n = df["num_trials"].iloc[0]
        # Greek letters need math mode; LaTeX elsewhere in this module (headers,
        # $\pm$ cells) is passed through with explicit $...$ delimiters.
        caption = f"$\\mu$={mu}, $\\sigma$={sigma}, N={n}"
        return table(df, header_to_title, caption=caption)

    @staticmethod
    def experiments_df(batch: Batch) -> ExperimentDF:
        """Summarise every experiment in ``batch`` as one row.

        Each row holds the experiment's kind and parameters, total and mean
        steps, and the min / q1 / median / q3 / max of its per-trial step
        counts, plus the number of step counts.
        """
        data = []
        for kind, experiments in batch.items():
            for metrics in experiments:
                steps = metrics.steps
                params = metrics.parameters
                data.append(
                    {
                        "kind": kind,
                        "std_dev": params.scale,  # TODO: the bimodal parameters
                        "kld": params.kld,
                        "real_mean": params.loc,
                        "assumed_mean": BatchUtil._assumed_mean(params),
                        "epsilon": metrics.epsilon,
                        "total_steps": sum(steps),
                        "mean_steps": mean(steps),
                        "min": np.min(steps),
                        "q1": np.percentile(steps, 25),
                        "median": np.median(steps),
                        "q3": np.percentile(steps, 75),
                        "max": np.max(steps),
                        "count": len(steps),
                    }
                )
        return pd.DataFrame.from_records(data)


class KLDiv:
    """Experiment: mean search steps as the search prior diverges from the truth.

    The divergence is applied through ``run_normal_experiment``'s ``kl``
    argument (i.e. ``kl_divergent_norm``).
    """

    @staticmethod
    def run() -> Batch:
        """Run basic and enhanced search for KL settings 0.00, 0.05, ..., 0.95.

        200 integer targets (floored draws from a normal with loc 0 and scale
        1000, seeded with 42) are searched with epsilon=10 at every KL setting,
        so index i of "basic" and "enhanced" share the same setup.
        """
        basic_experiments: list[NormalExperimentMetrics] = []
        enhanced_experiments: list[NormalExperimentMetrics] = []

        LOC = 0
        NUM_TRIALS = 200
        SCALE = 1000
        EPS = 10

        rv = norm(loc=LOC, scale=SCALE)
        # Fixed seed so the drawn targets are reproducible across runs.
        rv.random_state = np.random.RandomState(seed=42)
        targets = [floor(rv.rvs()) for _ in range(NUM_TRIALS)]

        for kld in range(0, 20, 1):
            # kld / 20 sweeps 0.00 .. 0.95 in steps of 0.05.
            basic_experiments.append(
                run_normal_experiment(rv, targets, "basic", kl=kld / 20, epsilon=EPS)
            )
            enhanced_experiments.append(
                run_normal_experiment(rv, targets, "enhanced", kl=kld / 20, epsilon=EPS)
            )

        return {"basic": basic_experiments, "enhanced": enhanced_experiments}

    @staticmethod
    def line_chart(df: ExperimentDF):
        """Plotly line chart of mean steps vs. KLD, one line per search kind."""
        fig = px.line(
            df,
            x="kld",
            y="mean_steps",
            color="kind",
            title="Normal: Mean Steps vs. KLD",
            markers=True,
        )
        fig.update_layout(yaxis_title="Mean Steps", xaxis_title="KLD")
        return fig

    @staticmethod
    def mline_chart(df: ExperimentDF, pdf_name: Optional[str] = None):
        """Matplotlib line chart of mean steps vs. KLD, one line per search kind.

        When ``pdf_name`` is given the figure is saved there and closed;
        otherwise ``plt.show()`` is called. The figure is returned either way.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        for kind in df["kind"].unique():
            data = df[df["kind"] == kind]
            ax.plot(data["kld"], data["mean_steps"], marker="o", label=kind)

        ax.set_xlabel("KLD")
        ax.set_ylabel("Mean Steps")
        ax.set_title("Normal: Mean Steps vs. KLD")
        ax.legend()
        ax.grid(True)
        if pdf_name:
            # Operate on this figure explicitly rather than pyplot's "current" one.
            fig.savefig(pdf_name)
            plt.close(fig)
        else:
            plt.show()

        return fig

    @staticmethod
    def latex_aggregate_table(df: KsTestDF):
        """Render a KS-test table (one row per KLD) as LaTeX via ``table``.

        The caption reports mu, sigma, epsilon and N taken from the first row,
        so all rows are expected to share them. Raises ``IndexError`` on an
        empty frame.
        """
        header_to_title = {
            "kld": r"KLD",
            "percent_decrease": "Percent Decrease",
            "mean_basic": "Basic Mean Steps",
            "mean_enhanced": "Bayesian Mean Steps",
        }

        mu = df["real_mean"].iloc[0]
        sigma = df["std_dev"].iloc[0]
        epsilon = df["epsilon"].iloc[0]
        n = df["num_trials"].iloc[0]
        # Every Greek letter goes in inline math, matching the $\epsilon$ form;
        # previously \mu and \sigma sat outside math mode.
        caption = f"$\\mu$={mu}, $\\sigma$={sigma}, $\\epsilon$={epsilon}, N={n}"
        return table(df, header_to_title, caption=caption)

    @staticmethod
    def table(df: ExperimentDF):
        """Plotly table showing every column of ``df``, left-aligned."""
        return go.Table(
            header=dict(values=list(df.columns), align="left"),
            cells=dict(
                values=[df[col] for col in df.columns],
                align="left",
            ),
        )

    @staticmethod
    def table_ks(df: KsTestDF):
        """Plotly table for a KS-test frame (same rendering as ``table``)."""
        return KLDiv.table(df)


class StdDev:
    """Experiment: mean search steps as a function of the true standard deviation."""

    @staticmethod
    def run(seed: Optional[int] = 42) -> Batch:
        """Run basic and enhanced search for scales 1, 251, 501 and 751.

        For each scale, 100 integer targets (floored draws from a normal with
        loc 0) are searched with ``run_normal_experiment``'s default epsilon
        and no KL divergence.

        Args:
            seed: Seed for the per-scale ``RandomState`` used to draw targets,
                making runs reproducible. ``None`` leaves ``rv.random_state``
                untouched, so the targets are not reproducible.
        """
        basic_experiments: list[NormalExperimentMetrics] = []
        enhanced_experiments: list[NormalExperimentMetrics] = []

        LOC = 0
        NUM_TRIALS = 100

        # NOTE(review): range(1, 1001, 250) stops at 751; confirm that a scale
        # of 1000 is not intended.
        for scale in range(1, 1001, 250):
            rv = norm(loc=LOC, scale=scale)
            if seed is not None:
                # Seed per scale so the targets (and results) are reproducible.
                rv.random_state = np.random.RandomState(seed=seed)
            targets = [floor(rv.rvs()) for _ in range(NUM_TRIALS)]

            basic_experiments.append(run_normal_experiment(rv, targets, "basic"))
            enhanced_experiments.append(run_normal_experiment(rv, targets, "enhanced"))

        return {"basic": basic_experiments, "enhanced": enhanced_experiments}

    @staticmethod
    def line_chart(df: ExperimentDF):
        """Plotly line chart of mean steps vs. standard deviation, per kind."""
        fig = px.line(
            df,
            x="std_dev",
            y="mean_steps",
            color="kind",
            title="Normal: Mean Steps vs. Standard Deviation",
            markers=True,
        )
        fig.update_layout(yaxis_title="Mean Steps", xaxis_title="Standard Deviation")
        return fig

    @staticmethod
    def table(df: ExperimentDF):
        """Plotly table showing every column of ``df``, left-aligned."""
        return go.Table(
            header=dict(values=list(df.columns), align="left"),
            cells=dict(
                values=[df[col] for col in df.columns],
                align="left",
            ),
        )

    @staticmethod
    def table_ks(df: KsTestDF):
        """Plotly table for a KS-test frame (same rendering as ``table``)."""
        return StdDev.table(df)


class Epsilon:
    """Experiment: mean search steps as a function of the tolerance epsilon."""

    @staticmethod
    def run(scale: int = 1000) -> Batch:
        """Run basic and enhanced search for epsilon = 1 .. min(32, scale - 1).

        500 integer targets (floored draws from a normal with loc 0 and the
        given scale, seeded with 42) are shared by every experiment. The list
        of epsilons is printed before the experiments start; it is empty when
        ``scale <= 1``.
        """
        NUM_TRIALS = 500
        LOC = 0

        rv = norm(loc=LOC, scale=scale)
        rv.random_state = np.random.RandomState(seed=42)
        targets = [floor(rv.rvs()) for _ in range(NUM_TRIALS)]

        # Cap below `scale` so narrow distributions get sensible tolerances.
        es = list(range(1, min(33, scale)))
        print(es)

        # All basic runs execute first, then all enhanced runs.
        kinds: tuple[Kind, ...] = ("basic", "enhanced")
        return {
            kind: [
                run_normal_experiment(rv, targets, kind, epsilon=epsilon)
                for epsilon in es
            ]
            for kind in kinds
        }

    @staticmethod
    def table(df: ExperimentDF):
        """Plotly table showing every column of ``df``, left-aligned."""
        columns = list(df.columns)
        return go.Table(
            header={"values": columns, "align": "left"},
            cells={"values": [df[name] for name in columns], "align": "left"},
        )

    @staticmethod
    def table_ks(df: KsTestDF):
        """Plotly table for a KS-test frame (same rendering as ``table``)."""
        return Epsilon.table(df)

    @staticmethod
    def line_chart(df: ExperimentDF):
        """Plotly line chart of mean steps vs. epsilon, one line per kind."""
        fig = px.line(
            df,
            x="epsilon",
            y="mean_steps",
            color="kind",
            markers=True,
            title="Normal: Mean Steps vs. Epsilon",
        )
        fig.update_layout(xaxis_title="Epsilon", yaxis_title="Mean Steps")
        return fig

    @staticmethod
    def mline_chart(df: ExperimentDF, pdf_name: Optional[str] = None):
        """Matplotlib line chart of mean steps vs. epsilon, one line per kind.

        Without ``pdf_name`` the chart is shown interactively; with it, the
        current figure is saved there and closed. The figure is returned.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        for kind in df["kind"].unique():
            subset = df[df["kind"] == kind]
            ax.plot(subset["epsilon"], subset["mean_steps"], marker="o", label=kind)

        ax.set(
            xlabel="Epsilon",
            ylabel="Mean Steps",
            title="Normal: Mean Steps vs. Epsilon",
        )
        ax.legend()
        ax.grid(True)

        if not pdf_name:
            plt.show()
            return fig

        plt.savefig(pdf_name)
        plt.close()
        return fig


class Box:
    """Experiment: distribution of bracket sizes at each step of the search."""

    @staticmethod
    def run(seed: Optional[int] = 42) -> Batch:
        """Run one basic and one enhanced experiment on a shared set of targets.

        1000 integer targets (floored draws from a normal with loc 0 and
        scale 100) are searched with epsilon=8 and no KL divergence.

        Args:
            seed: Seed for the ``RandomState`` used to draw targets, making runs
                (and the resulting plots) reproducible. ``None`` leaves
                ``rv.random_state`` untouched, so the targets are not
                reproducible.
        """
        NUM_TRIALS = 1000
        LOC = 0
        SCALE = 100  # TODO: what value?
        EPSILON = 8  # TODO: what value?

        rv = norm(loc=LOC, scale=SCALE)
        if seed is not None:
            rv.random_state = np.random.RandomState(seed=seed)
        targets = [floor(rv.rvs()) for _ in range(NUM_TRIALS)]

        basic = [run_normal_experiment(rv, targets, "basic", epsilon=EPSILON)]
        enhanced = [run_normal_experiment(rv, targets, "enhanced", epsilon=EPSILON)]

        return {"basic": basic, "enhanced": enhanced}

    @staticmethod
    def df(batch: Batch):
        """Flatten a batch into one row per search step of every trial.

        Columns: kind, step (the step record's ``step`` attribute),
        bracket_size (``hi - lo + 1``), and the owning experiment's std_dev and
        epsilon.
        """
        data = []
        for kind, experiments in batch.items():
            for metrics in experiments:
                for trial in metrics.trials:
                    for step in trial.metrics.steps:
                        row = {
                            "kind": kind,
                            "step": step.step,
                            # NOTE(review): +1 assumes lo and hi are both
                            # inclusive bounds — confirm.
                            "bracket_size": step.hi - step.lo + 1,
                            "std_dev": metrics.parameters.scale,
                            "epsilon": metrics.epsilon,
                        }

                        data.append(row)

        return pd.DataFrame.from_records(data)

    # TODO: kind, step, num_samples, num_samples_finished
    @staticmethod
    def table(df: BoxDF):
        """Plotly table showing every column of ``df``, left-aligned."""
        return go.Table(
            header=dict(values=list(df.columns), align="left"),
            cells=dict(
                values=[df[col] for col in df.columns],
                align="left",
            ),
        )

    @staticmethod
    def fig(df: BoxDF):
        """Plotly box plot of bracket size per step, coloured by kind."""
        return px.box(df, x="step", y="bracket_size", color="kind")

    @staticmethod
    def show():
        """Run the experiment and display the plotly box plot."""
        Box.fig(Box.df(Box.run())).show()

    @staticmethod
    def mfig(df: BoxDF, pdf_name: Optional[str] = None):
        """Matplotlib/seaborn box plot of bracket size per step, split by kind.

        When ``pdf_name`` is given the figure is saved there and closed;
        otherwise ``plt.show()`` is called. The figure is returned either way.
        """
        fig = plt.figure(figsize=(12, 6))
        # Draw on an explicit Axes instead of relying on pyplot's current one.
        ax = fig.add_subplot()
        sns.boxplot(data=df, x="step", y="bracket_size", hue="kind", ax=ax)
        ax.set_title("Bracket Size Distribution by Step and Kind")
        ax.set_xlabel("Step")
        ax.set_ylabel("Bracket Size")
        ax.legend(title="Kind")

        if pdf_name:
            fig.savefig(pdf_name)
            plt.close(fig)
        else:
            plt.show()

        return fig

    @staticmethod
    def mshow(pdf_name: Optional[str] = None):
        """Run the experiment and render the matplotlib box plot (see ``mfig``)."""
        df = Box.df(Box.run())
        Box.mfig(df, pdf_name=pdf_name)
