from pathlib import Path

import dash_bootstrap_components as dbc
import pandas as pd
import plotly.graph_objects as go
from dash import Input, Output, State, callback, dcc, html
from dash.exceptions import PreventUpdate
from konductor.webserver.utils import Experiment, fill_experiments

# Module-level cache of experiments found under the root directory;
# update_graph populates it (via fill_experiments) whenever it is empty.
EXPERIMENTS: list[Experiment] = []

# Title plus the CSV download button (and its hidden dcc.Download target).
_header_row = dbc.Row(
    [
        dbc.Col(html.H3("Accuracy Over Time")),
        dbc.Col(
            dbc.Button("Download CSV", id="btn-render", style={"float": "right"})
        ),
        dcc.Download(id="download-csv-render"),
    ]
)

# Split / dataset / metric selectors; metric options are filled by a callback.
_selector_row = dbc.Row(
    [
        dbc.Col(dcc.Dropdown(id=dropdown_id, options=choices))
        for dropdown_id, choices in (
            ("dd-split", ["train", "val"]),
            ("dd-dataset", ["sc2", "gym"]),
            ("dd-metric", []),
        )
    ]
)

layout = html.Div(
    children=[
        _header_row,
        _selector_row,
        dcc.Graph(id="timeseries-graph", selectedData={}),
    ]
)


def data_by_time(exp: Experiment, split: str, key: str) -> pd.Series:
    """Reshape one experiment's latest results for ``key`` into a time series.

    The metric name decides three things: which result group is read
    (``sc2-accuracy``, ``goal`` or ``occupancy``), which columns of that group
    belong to the metric, and how the integer timestep is parsed out of each
    column name. The returned series is indexed by timestep in ascending order
    and is named after the experiment.
    """
    is_top = key.startswith("top")
    is_pos_or_l2 = key.startswith("pos") or key == "l2"

    # Result group holding this metric.
    if is_top or is_pos_or_l2:
        group = "sc2-accuracy"
    elif key.startswith("Top"):
        group = "goal"
    else:
        group = "occupancy"

    # Which columns of the group belong to the metric.
    if key.endswith("null"):
        prefix = key.split("-")[0]

        def wanted(col):
            return col.endswith("null") and col.startswith(prefix)

    elif is_top:

        def wanted(col):
            return not col.endswith("null") and col.startswith(key)

    else:

        def wanted(col):
            return key in col

    # Where the timestep sits in a column name: (separator, field position).
    if is_top:
        sep, pos = "-", 1
    elif is_pos_or_l2:
        sep, pos = "-", -1
    else:
        sep, pos = "_", -1

    latest = exp.get_group_latest(split, group)
    selected = latest.filter(items=[col for col in latest.columns if wanted(col)])
    relabelled = selected.rename(
        columns={col: int(col.split(sep)[pos]) for col in selected.columns}
    )
    # Timesteps become the index; keep the first row of the latest data.
    by_step = relabelled.transpose().sort_index()

    series: pd.Series = by_step[by_step.columns.values[0]]
    series.name = exp.name
    return series


def gather_experiment_time_performance(
    exps: list[Experiment], split: str, metric: str
) -> list[pd.Series]:
    """Build one per-timestep series for every experiment that logs ``metric``.

    An experiment qualifies when any entry of its ``stats`` ends with "null"
    (for ``*-null`` metrics) or contains ``metric`` as a substring (otherwise).
    Each qualifying experiment's last-epoch results are reshaped with
    ``data_by_time``; input order is preserved.
    """
    wants_null = metric.endswith("null")

    def has_metric(experiment) -> bool:
        for stat in experiment.stats:
            matched = stat.endswith("null") if wants_null else metric in stat
            if matched:
                return True
        return False

    return [
        data_by_time(experiment, split, metric)
        for experiment in exps
        if has_metric(experiment)
    ]


@callback(
    Output("dd-metric", "options"),
    Output("dd-metric", "value"),
    Input("dd-dataset", "value"),
)
def update_metric_options(dataset: str):
    """Offer the timeseries metrics available for the chosen dataset type.

    Returns the option list for the metric dropdown together with ``None``,
    which clears any previously selected metric.

    Raises:
        PreventUpdate: if no dataset is selected.
        KeyError: if ``dataset`` is neither ``"sc2"`` nor ``"gym"``.
    """
    if dataset is None:
        raise PreventUpdate

    catalogue = (
        (
            "sc2",
            [
                "top1",
                "top1-null",
                "top5",
                "top5-null",
                "l2",
                "pos-score-acc",
                "pos-score-f1",
                "pos-score-precision",
                "pos-score-recall",
            ],
        ),
        ("gym", ["Top1", "targets_IoU", "targets_AUC"]),
    )
    for name, metrics in catalogue:
        if dataset == name:
            return metrics, None

    raise KeyError(f"Invalid dataset type: {dataset}")


@callback(
    Output("timeseries-graph", "figure"),
    Input("dd-split", "value"),
    Input("dd-metric", "value"),
    Input("root-dir", "data"),
)
def update_graph(split: str, metric: str, root_dir: str):
    """Plot per-timestep performance of every experiment that logs ``metric``.

    The module-level ``EXPERIMENTS`` cache is filled from ``root_dir`` when it
    is empty. Each matching experiment becomes one line trace named after the
    experiment. If no experiment has the metric, an empty figure is returned.

    Raises:
        PreventUpdate: if the experiment cache is empty and ``root_dir`` is
            not available yet, or if no split / metric is selected.
    """
    if not EXPERIMENTS:
        # The root-dir store may not hold a path yet; Path(None) would raise
        # TypeError, so wait for a real directory instead of erroring.
        if root_dir is None:
            raise PreventUpdate
        fill_experiments(Path(root_dir), EXPERIMENTS)

    if split is None or metric is None:
        raise PreventUpdate

    series_list: list[pd.Series] = gather_experiment_time_performance(
        EXPERIMENTS, split, metric
    )

    # An empty result yields an empty figure rather than PreventUpdate, which
    # would leave the previous metric's plot displayed (and exported to CSV)
    # under the new selection.
    fig = go.Figure()
    for series in series_list:
        fig.add_trace(
            go.Scatter(
                # Plain lists keep the trace data as JSON arrays in the figure
                # State that render_csv reads back (array objects may be
                # serialized in an encoded form by newer Plotly versions).
                x=series.index.tolist(),
                y=series.tolist(),
                mode="lines",
                name=series.name,
            )
        )

    return fig


@callback(
    Output("download-csv-render", "data"),
    Input("btn-render", "n_clicks"),
    State("timeseries-graph", "figure"),
    State("dd-split", "value"),
    State("dd-dataset", "value"),
    State("dd-metric", "value"),
    prevent_initial_call=True,
)
def render_csv(n_clicks, figure, split, dataset, metric):
    """Download the traces currently shown in the timeseries graph as a CSV.

    Each visible trace becomes one column, named after the trace and indexed by
    its x values. Traces hidden through the legend are skipped. The file is
    named ``{split}-{dataset}-{metric}.csv``.

    Raises:
        PreventUpdate: if the button has not been clicked, the graph has no
            figure yet, or no trace is visible.
    """
    if n_clicks is None or not figure:
        raise PreventUpdate

    series: list[pd.Series] = [
        pd.Series(trace["y"], index=trace["x"], name=trace["name"])
        for trace in figure.get("data", [])
        # update_graph never sets "visible", so the key may be absent; Plotly's
        # default is True. Hidden traces carry False or "legendonly".
        if trace.get("visible", True) is True
    ]
    if not series:
        # pd.concat raises ValueError when given nothing to concatenate.
        raise PreventUpdate

    csv_data = pd.concat(series, axis=1).to_csv(index=True)

    # dcc.Download reads the MIME type from the "type" key.
    return {
        "content": csv_data,
        "filename": f"{split}-{dataset}-{metric}.csv",
        "type": "text/csv",
        "base64": False,
    }
