import numpy as np
import statistics
from dataclasses import asdict, dataclass
from decimal import Decimal
from math import ceil, floor
from statistics import mean
from typing import Literal, NamedTuple, Optional, cast

import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import scienceplots  # noqa: F401
import seaborn as sns
from scipy.stats import distributions, ks_2samp

from bbs.bimodal.search import (
    bimodal,
    binary_search_bimodal,
    enhanced_binary_search_bimodal,
)
from bbs.experiment import ExperimentMetrics, Trial
from bbs.latex import table

# Apply the SciencePlots "science" + "ieee" styles (presumably registered by the
# side-effect `import scienceplots` above), then override figure.dpi, which the
# ieee style changes; see https://github.com/garrettj403/SciencePlots/issues/60
plt.style.use(["science", "ieee"])
plt.rcParams["figure.dpi"] = "100"


@dataclass(frozen=True, kw_only=True)
class BimodalDistributionParameters:
    mu1: float
    std_dev1: float
    mu2: float
    std_dev2: float
    weight1: float
    weight2: float


@dataclass(kw_only=True, frozen=True)
class BimodalExperimentMetrics(ExperimentMetrics):
    """``ExperimentMetrics`` that also records the bimodal distribution used."""

    parameters: BimodalDistributionParameters


# Which bisection strategy produced an experiment; also the key into the
# search-function dispatch in run_bimodal_experiment.
Kind = Literal["basic", "enhanced"]
# Experiments grouped by strategy. BatchUtil.ks_test_df pairs the "basic" and
# "enhanced" lists positionally, so they should be built in the same order.
Batch = dict[Kind, list[BimodalExperimentMetrics]]


def run_bimodal_experiment(
    rv: distributions.rv_frozen,
    targets: list[int],
    method: Kind,
    epsilon: int = 1,
) -> BimodalExperimentMetrics:
    """Search for every target with one strategy and collect the results.

    Args:
        rv: frozen bimodal distribution; ``rv.kwds`` must contain ``mu1``,
            ``std_dev1``, ``mu2``, ``std_dev2`` and ``weight1``.
        targets: values to search for; one trial is recorded per target.
        method: ``"basic"`` or ``"enhanced"``, selecting the search routine.
        epsilon: passed to the search routine and recorded on the metrics.

    Returns:
        Metrics whose ``trials`` are numbered from 1, in target order.

    Raises:
        KeyError: if ``method`` is unknown or ``rv.kwds`` lacks a parameter.
    """
    kw = rv.kwds
    params = BimodalDistributionParameters(
        mu1=kw["mu1"],
        std_dev1=kw["std_dev1"],
        mu2=kw["mu2"],
        std_dev2=kw["std_dev2"],
        weight1=kw["weight1"],
        # The mixture weights sum to one.
        weight2=1 - kw["weight1"],
    )
    result = BimodalExperimentMetrics(trials=[], epsilon=epsilon, parameters=params)

    search = {
        "basic": binary_search_bimodal,
        "enhanced": enhanced_binary_search_bimodal,
    }[method]

    result.trials.extend(
        Trial(trial_number=number, target=goal, metrics=search(rv, goal, epsilon))
        for number, goal in enumerate(targets, start=1)
    )
    return result


class KsTestResult(NamedTuple):
    """Typed stand-in for the result object of ``scipy.stats.ks_2samp``.

    Used to ``cast`` the SciPy result so attribute access type-checks; the
    field names mirror the attributes read from that result.
    """

    statistic: float
    pvalue: float
    statistic_location: float
    statistic_sign: int


# Descriptive aliases for the DataFrames each builder produces. They are plain
# pandas DataFrames; the names only document which builder a value came from.
KsTestDF = pd.DataFrame  # output of BatchUtil.ks_test_df
ExperimentDF = pd.DataFrame  # output of BatchUtil.experiments_df
BoxDF = pd.DataFrame  # output of Box.df, consumed by Box.table / fig / mfig


class BatchUtil:
    """Helpers that turn a ``Batch`` of experiments into DataFrames / tables."""

    @staticmethod
    def ks_test_df(batch: Batch) -> KsTestDF:
        """Compare basic vs. enhanced step counts, one row per experiment pair.

        The i-th ``"basic"`` experiment is paired with the i-th ``"enhanced"``
        one. Each row holds a one-sided two-sample KS-test p-value, the percent
        decrease in mean steps, and mean ± stdev for each side, pre-formatted
        as LaTeX-ready strings.

        Raises:
            ValueError: if the two experiment lists differ in length (they
                used to be silently truncated to the shorter one).
            statistics.StatisticsError: if an experiment's ``steps`` holds
                fewer than two values (``statistics.stdev`` needs two).
        """
        data = []
        # strict=True: unpaired experiments mean a malformed batch; fail loudly
        # instead of silently dropping rows.
        for basic_experiment, enhanced_experiment in zip(
            batch["basic"], batch["enhanced"], strict=True
        ):
            basic_steps = basic_experiment.steps
            enhanced_steps = enhanced_experiment.steps

            mean_b = statistics.mean(basic_steps)
            mean_e = statistics.mean(enhanced_steps)
            stdev_b = statistics.stdev(basic_steps)
            stdev_e = statistics.stdev(enhanced_steps)

            # alternative="less": H1 is that the CDF of the basic steps lies
            # below that of the enhanced steps somewhere, i.e. basic tends to
            # need more steps than enhanced.
            res = cast(
                KsTestResult, ks_2samp(basic_steps, enhanced_steps, alternative="less")
            )
            row = {
                # "std_dev": basic_experiment.parameters.scale, TODO: the bimodal parameters
                "epsilon": basic_experiment.epsilon,
                "num_trials": basic_experiment.num_trials,
                "ks_test_pvalue": f"{Decimal(res.pvalue):.2E}",
                "percent_decrease": f"{((mean_b - mean_e) / mean_b * 100):.2f}\\%",
                "mean_basic": f"{mean_b:.2f} $\\pm$ {stdev_b:.2f}",
                "mean_enhanced": f"{mean_e:.2f} $\\pm$ {stdev_e:.2f}",
            }
            data.append(row)
        return pd.DataFrame.from_records(data)

    @staticmethod
    def latex_aggregate_table(df: KsTestDF):
        """Render a ``ks_test_df`` result as a LaTeX table via ``bbs.latex.table``.

        The caption's ``n`` is taken from the first row's ``num_trials``
        (assumes every row shares it). Raises ``IndexError`` on an empty
        DataFrame. How columns absent from the title mapping are treated is up
        to ``table`` — NOTE(review): confirm.
        """
        header_to_title = {
            "epsilon": r"$\epsilon$",
            "percent_decrease": "Percent Decrease",
            "mean_basic": "Basic Mean Steps",
            "mean_enhanced": "Bayesian Mean Steps",
        }

        n = df["num_trials"].iloc[0]
        caption = f"Bimodal params, n={n}"
        return table(df, header_to_title, caption=caption)

    @staticmethod
    def experiments_df(batch: Batch) -> ExperimentDF:
        """Summarise each experiment's ``steps``, one row per experiment.

        Columns: kind, epsilon, total/mean steps, five-number summary
        (min, q1, median, q3, max) and count.
        """
        data = []
        for kind, experiments in batch.items():
            for metrics in experiments:
                steps = metrics.steps  # read once; used by every statistic below
                row = {
                    "kind": kind,
                    # "std_dev": metrics.parameters.scale, # TODO: the bimodal parameters
                    "epsilon": metrics.epsilon,
                    "total_steps": sum(steps),
                    "mean_steps": mean(steps),
                    "min": np.min(steps),
                    "q1": np.percentile(steps, 25),
                    "median": np.median(steps),
                    "q3": np.percentile(steps, 75),
                    "max": np.max(steps),
                    "count": len(steps),
                }
                data.append(row)
        return pd.DataFrame.from_records(data)


class Epsilon:
    """Experiment: how mean search steps vary with epsilon on a bimodal target."""

    @staticmethod
    def run() -> Batch:
        """Run both search methods for every epsilon in 1..32.

        Both methods search the same ``NUM_TRIALS`` targets, drawn once from a
        seeded ``bimodal(0, 1000, 4000, 1000, 0.5)`` and rounded up with
        ``ceil``. The i-th entry of each returned list corresponds to the i-th
        epsilon.
        """
        NUM_TRIALS = 50

        LOC1 = 0
        SCALE1 = 1000

        LOC2 = 4000
        SCALE2 = 1000

        WEIGHT1 = 0.5

        rv = bimodal(LOC1, SCALE1, LOC2, SCALE2, WEIGHT1)
        # Fixed seed so the drawn targets are reproducible across runs.
        rv.random_state = np.random.RandomState(seed=42)
        targets = [ceil(rv.rvs()) for _ in range(NUM_TRIALS)]

        def basic(epsilon: int):
            return run_bimodal_experiment(rv, targets, "basic", epsilon=epsilon)

        def enhanced(epsilon: int):
            return run_bimodal_experiment(rv, targets, "enhanced", epsilon=epsilon)

        epsilons = list(range(1, 33))
        print(epsilons)
        return {
            "basic": [basic(epsilon) for epsilon in epsilons],
            "enhanced": [enhanced(epsilon) for epsilon in epsilons],
        }

    @staticmethod
    def table(df: ExperimentDF):
        """Render every column of ``df`` as a left-aligned Plotly table."""
        return go.Table(
            header=dict(values=list(df.columns), align="left"),
            cells=dict(
                values=[df[col] for col in df.columns],
                align="left",
            ),
        )

    @staticmethod
    def table_ks(df: KsTestDF):
        """Render a KS-test DataFrame; same layout as ``Epsilon.table``."""
        return Epsilon.table(df)

    @staticmethod
    def line_chart(df: ExperimentDF):
        """Plotly line chart of mean steps vs. epsilon, one line per kind."""
        fig = px.line(
            df,
            x="epsilon",
            y="mean_steps",
            color="kind",
            title="Bimodal: Mean Steps vs. Epsilon",
            markers=True,
        )
        fig.update_layout(yaxis_title="Mean Steps", xaxis_title="Epsilon")
        return fig

    @staticmethod
    def mline_chart(df: ExperimentDF, pdf_name: Optional[str] = None):
        """Matplotlib line chart of mean steps vs. epsilon, one line per kind.

        If ``pdf_name`` is truthy the figure is saved there and closed;
        otherwise it is shown with ``plt.show()``. The figure is returned in
        both cases.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        for kind in df["kind"].unique():
            data = df[df["kind"] == kind]
            ax.plot(data["epsilon"], data["mean_steps"], marker="o", label=kind)

        ax.set_xlabel("Epsilon")
        ax.set_ylabel("Mean Steps")
        ax.set_title("Bimodal: Mean Steps vs. Epsilon")
        ax.legend()
        ax.grid(True)
        if pdf_name:
            # Save/close this figure explicitly instead of relying on pyplot's
            # notion of the "current" figure.
            fig.savefig(pdf_name)
            plt.close(fig)
        else:
            plt.show()

        return fig


class Box:
    """Experiment: distribution of search-bracket sizes at each step."""

    @staticmethod
    def run(
        num_trials: int = 50,
        epsilon: int = 8,  # TODO: what value?
        seed: Optional[int] = 42,
    ) -> Batch:
        """Run one basic and one enhanced experiment over shared targets.

        Targets are ``num_trials`` draws from ``bimodal(0, 100, 400, 100, 0.5)``
        rounded down with ``floor``.

        Args:
            num_trials: number of targets (trials) per experiment.
            epsilon: passed to both search methods.
            seed: seed for the target RNG, matching ``Epsilon.run`` so results
                are reproducible; pass ``None`` for the distribution's default
                (unseeded) random state.

        Returns:
            ``{"basic": [metrics], "enhanced": [metrics]}``.
        """
        LOC1 = 0
        SCALE1 = 100

        LOC2 = 400
        SCALE2 = 100

        WEIGHT1 = 0.5

        rv = bimodal(LOC1, SCALE1, LOC2, SCALE2, WEIGHT1)
        if seed is not None:
            rv.random_state = np.random.RandomState(seed=seed)

        targets = [floor(rv.rvs()) for _ in range(num_trials)]

        basic = [run_bimodal_experiment(rv, targets, "basic", epsilon=epsilon)]
        enhanced = [run_bimodal_experiment(rv, targets, "enhanced", epsilon=epsilon)]

        return {"basic": basic, "enhanced": enhanced}

    @staticmethod
    def df(batch: Batch):
        """Flatten a batch into one row per search step of every trial.

        Columns: ``kind``, ``step``, ``bracket_size`` (inclusive width
        ``hi - lo + 1``), ``epsilon``, and the distribution parameters, which
        ``pd.json_normalize`` expands into ``dist_params.<field>`` columns.
        """
        data = []
        for kind, experiments in batch.items():
            for metrics in experiments:
                # Same for every step of this experiment; build it once.
                dist_params = asdict(metrics.parameters)
                for trial in metrics.trials:
                    for step in trial.metrics.steps:
                        row = {
                            "kind": kind,
                            "step": step.step,
                            "bracket_size": step.hi - step.lo + 1,
                            "dist_params": dist_params,
                            "epsilon": metrics.epsilon,
                        }
                        data.append(row)

        return pd.json_normalize(data)

    # TODO: kind, step, num_samples, num_samples_finished
    @staticmethod
    def table(df: BoxDF):
        """Render every column of ``df`` as a left-aligned Plotly table."""
        return go.Table(
            header=dict(values=list(df.columns), align="left"),
            cells=dict(
                values=[df[col] for col in df.columns],
                align="left",
            ),
        )

    @staticmethod
    def fig(df: BoxDF):
        """Plotly box plot of bracket size per step, coloured by kind."""
        return px.box(df, x="step", y="bracket_size", color="kind")

    @staticmethod
    def show():
        """Run the experiment and show its per-step rows as a Plotly table.

        For the box plot instead, use ``Box.fig(Box.df(Box.run())).show()``.
        """
        df = Box.df(Box.run())
        go.Figure(data=Box.table(df)).show()

    @staticmethod
    def mfig(df: BoxDF):
        """Seaborn box plot of bracket size per step, split by kind, then show."""
        plt.figure(figsize=(12, 6))
        sns.boxplot(data=df, x="step", y="bracket_size", hue="kind")
        plt.title("Bracket Size Distribution by Step and Kind")
        plt.xlabel("Step")
        plt.ylabel("Bracket Size")
        plt.legend(title="Kind")
        plt.show()

    @staticmethod
    def mshow():
        """Run the experiment and display the seaborn box plot."""
        Box.mfig(Box.df(Box.run()))
