from typing import Optional
import pandas as pd
from plotly.subplots import make_subplots
import bbs.normal.experiment as exp
from bbs.normal.experiment import ExperimentDF, KsTestDF


class Epsilon:
    """Combined view (line chart + two tables) for the epsilon experiment."""

    @staticmethod
    def show(
        df_exp: pd.DataFrame, df_ks: pd.DataFrame, write_html_name: Optional[str] = None
    ) -> None:
        """Build a three-row figure and either display it or save it as HTML.

        Row 1 holds the traces of ``exp.Epsilon.line_chart``, row 2 the trace
        returned by ``exp.Epsilon.table`` and row 3 the trace returned by
        ``exp.Epsilon.table_ks``.

        Args:
            df_exp: Experiment results; sorted by the ``"epsilon"`` and
                ``"kind"`` columns before plotting.
            df_ks: KS-test results; sorted by the ``"epsilon"`` column before
                plotting.
            write_html_name: If truthy, the figure is written to
                ``./figures/normal/epsilon/<write_html_name>.html`` (the
                directory is created when missing). Otherwise (``None`` or an
                empty string) the figure is shown interactively.
        """
        # Local import: keeps this block self-contained without touching the
        # module-level import section.
        from pathlib import Path

        df_exp = df_exp.sort_values(by=["epsilon", "kind"])
        df_ks = df_ks.sort_values(by=["epsilon"])
        plot = make_subplots(
            rows=3,
            cols=1,
            specs=[[{"type": "scatter"}], [{"type": "table"}], [{"type": "table"}]],
        )

        lc = exp.Epsilon.line_chart(df_exp)
        t_exp = exp.Epsilon.table(df_exp)
        t_ks = exp.Epsilon.table_ks(df_ks)

        # Clear the chart's own axis domains before its layout is merged into
        # the subplot figure below.
        # NOTE(review): presumably this keeps the chart from overriding the
        # domains assigned by make_subplots -- confirm.
        lc.layout.xaxis.domain = None  # pyright:ignore
        lc.layout.yaxis.domain = None  # pyright:ignore

        plot.add_traces(lc.data, rows=1, cols=1)
        plot.add_trace(t_exp, row=2, col=1)
        plot.add_trace(t_ks, row=3, col=1)
        plot.update_layout(lc.layout)

        if write_html_name:
            out_path = f"./figures/normal/epsilon/{write_html_name}.html"
            # Make sure the target directory exists so the export does not
            # fail on a fresh checkout after all the figure work is done.
            Path(out_path).parent.mkdir(parents=True, exist_ok=True)
            plot.write_html(out_path)
        else:
            plot.show()


class StdDev:
    """Combined view (line chart + two tables) for the std-dev experiment."""

    @staticmethod
    def show(df_exp: ExperimentDF, df_ks: KsTestDF):
        """Display a three-row figure: line chart, experiment table, KS table.

        ``df_exp`` is sorted by the ``"std_dev"`` and ``"kind"`` columns and
        ``df_ks`` by the ``"std_dev"`` column before being handed to the
        ``exp.StdDev`` builders. The resulting figure is shown interactively.
        """
        exp_sorted = df_exp.sort_values(by=["std_dev", "kind"])
        ks_sorted = df_ks.sort_values(by=["std_dev"])

        # One subplot cell per row: a scatter cell followed by two table cells.
        cell_kinds = ("scatter", "table", "table")
        figure = make_subplots(
            rows=3,
            cols=1,
            specs=[[{"type": cell_kind}] for cell_kind in cell_kinds],
        )

        chart = exp.StdDev.line_chart(exp_sorted)
        exp_table = exp.StdDev.table(exp_sorted)
        ks_table = exp.StdDev.table_ks(ks_sorted)

        # Clear the chart's own axis domains before its layout is merged into
        # the subplot figure below.
        chart_layout = chart.layout
        chart_layout.xaxis.domain = None  # pyright:ignore
        chart_layout.yaxis.domain = None  # pyright:ignore

        figure.add_traces(chart.data, rows=1, cols=1)
        for row_number, table_trace in enumerate((exp_table, ks_table), start=2):
            figure.add_trace(table_trace, row=row_number, col=1)

        figure.update_layout(chart_layout)
        figure.show()


class KLDiv:
    """Combined view (line chart + two tables) for the KL-divergence experiment."""

    @staticmethod
    def show(df_exp: ExperimentDF, df_ks: KsTestDF):
        """Display a three-row figure: line chart, experiment table, KS table.

        ``df_exp`` is sorted by the ``"kld"`` and ``"kind"`` columns and
        ``df_ks`` by the ``"kld"`` column before being handed to the
        ``exp.KLDiv`` builders. The resulting figure is shown interactively.
        """
        ordered_exp = df_exp.sort_values(by=["kld", "kind"])
        ordered_ks = df_ks.sort_values(by=["kld"])

        # Grid layout: scatter cell on top, two table cells underneath.
        scatter_cell = [{"type": "scatter"}]
        exp_table_cell = [{"type": "table"}]
        ks_table_cell = [{"type": "table"}]
        grid = make_subplots(
            rows=3,
            cols=1,
            specs=[scatter_cell, exp_table_cell, ks_table_cell],
        )

        line_fig = exp.KLDiv.line_chart(ordered_exp)
        table_exp = exp.KLDiv.table(ordered_exp)
        table_ks = exp.KLDiv.table_ks(ordered_ks)

        # Clear the chart's own axis domains before its layout is merged into
        # the subplot figure below.
        line_fig.layout.xaxis.domain = None  # pyright:ignore
        line_fig.layout.yaxis.domain = None  # pyright:ignore

        grid.add_traces(line_fig.data, rows=1, cols=1)
        grid.add_trace(table_exp, row=2, col=1)
        grid.add_trace(table_ks, row=3, col=1)

        grid.update_layout(line_fig.layout)
        grid.show()
