from typing import Union

import pandas as pd
import plotly.graph_objects as go
from plotly import colors


# Qualitative palette used to colour successive traces (the classic
# D3 "category10" colours, also plotly's historical default sequence).
HEX_COLOR_SEQUENCE = [
    "#1f77b4",  # muted blue
    "#ff7f0e",  # safety orange
    "#2ca02c",  # cooked asparagus green
    "#d62728",  # brick red
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#7f7f7f",  # middle gray
    "#bcbd22",  # curry yellow-green
    "#17becf",  # blue-teal
]
# The same palette as plotly "tuple" colours, i.e. (r, g, b) floats in [0, 1]
# (not 0-255 integers). A copy is passed because plotly's converter may write
# converted values back into a list argument in place, which would otherwise
# replace the hex strings above with tuples and alias the two constants.
COLOR_SEQUENCE = colors.convert_colors_to_same_type(
    list(HEX_COLOR_SEQUENCE), "tuple"
)[0]


def add_std_dev_trace(
    fig: go.Figure,
    trace_idx: int,
    mean_values: Union[pd.Series, pd.DataFrame],
    std_dev_values: Union[pd.Series, pd.DataFrame],
    name: str,
    x_values: Union[pd.Series, pd.DataFrame, None] = None,
    fill_opacity: float = 0.2,
) -> go.Figure:
    """Add a mean line with a shaded +/- one standard deviation band to *fig*.

    Three traces are appended to ``fig``:

    1. the mean line (shown in the legend under ``name``);
    2. an invisible upper-bound line (``mean + std``) with no fill;
    3. an invisible lower-bound line (``mean - std``) with ``fill="tonexty"``,
       which shades the area up to the upper-bound trace.

    Args:
        fig: Figure to add the traces to; it is modified in place.
        trace_idx: Selects the colour from ``COLOR_SEQUENCE``; wraps around
            when it exceeds the palette length.
        mean_values: y values of the centre line.
        std_dev_values: Half-width of the band at each point; combined with
            ``mean_values`` by pandas element-wise ``+`` / ``-``.
        name: Legend label for the mean line. It is also used as the legend
            group of all three traces, so toggling the legend entry hides the
            band together with the line.
        x_values: x coordinates; defaults to ``mean_values.index``.
        fill_opacity: Alpha of the shaded band (0 = transparent, 1 = opaque).

    Returns:
        The same ``fig`` object, to allow chaining.
    """
    x_vals = mean_values.index if x_values is None else x_values
    upper_bound = mean_values + std_dev_values
    lower_bound = mean_values - std_dev_values

    # COLOR_SEQUENCE holds plotly "tuple" colours (floats in [0, 1]), whereas
    # CSS rgb()/rgba() strings expect 0-255 channel values. Formatting the raw
    # floats would render every trace as (almost) black, so scale them first.
    red, green, blue = (
        round(channel * 255)
        for channel in COLOR_SEQUENCE[trace_idx % len(COLOR_SEQUENCE)]
    )
    color = f"rgb({red}, {green}, {blue})"
    fillcolor = f"rgba({red}, {green}, {blue}, {fill_opacity})"

    fig.add_trace(
        go.Scatter(
            x=x_vals,
            y=mean_values,
            name=name,
            legendgroup=name,
            line=dict(color=color, width=3),
        )
    )
    # Upper bound: an invisible edge with no fill of its own. If it also used
    # "tonexty", it would fill down to the mean line. The mean-to-upper half of
    # the band would then be shaded twice and look darker than the lower half.
    fig.add_trace(
        go.Scatter(
            x=x_vals,
            y=upper_bound,
            mode="lines",
            line=dict(width=0),
            legendgroup=name,
            showlegend=False,
        )
    )
    # Lower bound: "tonexty" fills to the previously added trace (the upper
    # bound), shading the whole band exactly once.
    fig.add_trace(
        go.Scatter(
            x=x_vals,
            y=lower_bound,
            mode="lines",
            fillcolor=fillcolor,
            fill="tonexty",
            line=dict(width=0),
            legendgroup=name,
            showlegend=False,
        )
    )
    return fig
