import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import os
import matplotlib.colors as mcolors

# Plotly dash styles, picked by position (e.g. shapes[0] / shapes[1] for the
# mean lower / upper bound lines).
shapes = ['solid', 'solid', 'dashdot', 'longdash', 'longdashdot']

# Global matplotlib defaults; they apply to every matplotlib figure created
# after this module is imported.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 32

def name_to_rgb(name):
    """Convert a matplotlib colour specification to 8-bit RGB components.

    Parameters
    ----------
    name : str or tuple
        Any colour accepted by ``matplotlib.colors.to_rgb`` (e.g. ``"orange"``
        or ``"#FFA500"``).

    Returns
    -------
    tuple of int
        ``(r, g, b)``, each component in 0-255.
    """
    # to_rgb returns floats in [0, 1] (any alpha channel is dropped).
    rgb = mcolors.to_rgb(name)
    # Round instead of truncating: k/255*255 may evaluate to slightly less
    # than k in floating point, which int() would turn into k - 1.
    r, g, b = (round(x * 255) for x in rgb)
    return r, g, b

# Plot palette as plotly "rgb(r, g, b)" strings. The values match the
# Okabe-Ito colour-blind-friendly palette; its sky blue (86, 180, 233),
# yellow (240, 228, 66) and vermillion (213, 94, 0) are intentionally unused.
# rgb_colors is the single source of truth: it is parsed to build
# translucent fill colours, so it must stay in sync with the line colours.
rgb_colors = [
    'rgb(230, 159, 0)',    # Orange
    'rgb(0, 158, 115)',    # Bluish Green
    'rgb(0, 114, 178)',    # Blue
    'rgb(204, 121, 167)',  # Reddish Purple
]

# Line colours for plotly traces: same entries, but a separate list object.
colors = list(rgb_colors)

# Same palette as (r, g, b) float tuples in [0, 1] for matplotlib.
mpl_colors = [tuple(int(c) / 255 for c in color[4:-1].split(',')) for color in rgb_colors]
def replace_keys_with_names(result_dict, lookup_dict):
    """Return a new dict whose keys are translated through *lookup_dict*.

    Entries of *result_dict* whose key has no translation are left out.
    If two keys translate to the same name, the later entry wins.
    """
    renamed = {}
    for old_key, value in result_dict.items():
        if old_key not in lookup_dict:
            continue
        renamed[lookup_dict[old_key]] = value
    return renamed


def load_bounds(res, res_seed, step_size):
    """Extract per-strategy gradient bounds and run times from loaded results.

    Parameters
    ----------
    res : dict
        Strategy name -> result dict, or ``None`` when it could not be loaded
        (as produced by ``load_results``). Each dict is read for
        ``"Grad_Min_Exact"`` and ``"Grad_Max_Exact"``.
    res_seed : dict
        Same layout as *res* for the per-seed results. It is additionally read
        for ``"Time"`` when that key is present.
    step_size : int
        Index selecting one configuration of the ``"Adaptive"`` strategy. After
        squeezing, Adaptive arrays are indexed on axis 1 for the per-seed
        results (``[:, step_size, :]``) and on axis 0 otherwise
        (``[step_size]``).

    Returns
    -------
    tuple of dict
        ``(lower_bounds, upper_bounds, lower_bounds_seeds, upper_bounds_seeds,
        res_time)``, each mapping strategy name -> numpy array. Strategies
        whose entry is ``None`` are skipped with a message. A strategy without
        a ``"Time"`` entry is left out of ``res_time`` without affecting the
        other strategies.
    """
    def _seed_slice(arr):
        # Keep one step-size configuration across all seeds.
        return arr[:, step_size, :].squeeze()

    def _bounds_pair(results, select_adaptive):
        # Collect squeezed min/max gradient bounds for every loadable strategy.
        lower, upper = {}, {}
        for strat, res_strat in results.items():
            try:
                lo = np.array(res_strat["Grad_Min_Exact"]).squeeze()
                hi = np.array(res_strat["Grad_Max_Exact"]).squeeze()
            except TypeError:  # entry is None: its results file was not found
                print(f"{strat} not available")
                continue
            if strat == "Adaptive":
                lo, hi = select_adaptive(lo), select_adaptive(hi)
            lower[strat] = lo
            upper[strat] = hi
        return lower, upper

    lower_bounds_seeds, upper_bounds_seeds = _bounds_pair(res_seed, _seed_slice)

    res_time = {}
    for strat, res_strat in res_seed.items():
        try:
            times = np.array(res_strat["Time"]).squeeze()
        except TypeError:  # entry is None: its results file was not found
            print(f"{strat} not available")
            continue
        except KeyError:
            # Skip only this strategy so the remaining ones are still collected.
            print(f"Time not available for {strat}")
            continue
        if strat == "Adaptive":
            times = _seed_slice(times)
        res_time[strat] = times

    lower_bounds, upper_bounds = _bounds_pair(res, lambda arr: arr[step_size])
    return lower_bounds, upper_bounds, lower_bounds_seeds, upper_bounds_seeds, res_time


def load_results(output_path, experiment_name, strategy_type, seed, lam_c, T, T_exploration, sample_sigma, iteration):
    """Load the saved result objects of several strategies.

    For each strategy ``s`` the run folder is::

        <output_path>/<experiment_name>/strategy_type-<s>_lam_c-<lam_c>_seed-<seed>_T-<T>_T_exploration-<T_exploration>_sample_sigma-<sample_sigma>

    Two files are read from it, ``results_<iteration>.npy`` and
    ``results_seed.npy``. Each holds a pickled object that is returned via
    ``.item()``.

    Parameters
    ----------
    output_path, experiment_name : str
        Root directory and experiment sub-directory.
    strategy_type : iterable of str
        Strategy names; each becomes a key of both returned dicts.
    seed, lam_c, T, T_exploration, sample_sigma
        Run parameters, formatted into the folder name.
    iteration : int or str
        Suffix of the per-iteration results file.

    Returns
    -------
    res, res_seed : dict
        Strategy name -> loaded object, or ``None`` when that file is missing.
        The two files are loaded independently, so a missing seed file does
        not hide an existing per-iteration file (and vice versa).
    """
    def _load(path, strategy):
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # trusted, locally produced result files.
        try:
            return np.load(path, allow_pickle=True).item()
        except FileNotFoundError:
            print(f"File--- {strategy} ---not found: {path}")
            return None

    res = {}
    res_seed = {}
    for strategy in strategy_type:
        folder_name = f"strategy_type-{strategy}_lam_c-{lam_c}_seed-{seed}_T-{T}_T_exploration-{T_exploration}_sample_sigma-{sample_sigma}"
        folder = os.path.join(output_path, experiment_name, folder_name)
        res[strategy] = _load(os.path.join(folder, f"results_{iteration}.npy"), strategy)
        res_seed[strategy] = _load(os.path.join(folder, "results_seed.npy"), strategy)

    return res, res_seed

def update_layout(fig, do_markers=True):
    """Apply the paper's plot styling to a plotly figure.

    The styling consists of:

    * a transparent plot background;
    * a serif font of size 32;
    * zero outer margins;
    * a black axis line of width 2 on every x and y axis.

    Parameters
    ----------
    fig : plotly figure
        Figure to restyle; it is updated in place.
    do_markers : bool, optional
        If True (default), all traces also get marker outline width 2,
        marker size 5 and line width 4.

    Returns
    -------
    fig : plotly figure
        The same figure object, for chaining.
    """
    axis_style = dict(showline=True, linewidth=2, linecolor="black")

    # update_*axes reaches every axis of the figure (including subplots).
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    if do_markers:
        fig.update_traces(marker_line_width=2, marker_size=5, line_width=4)

    paper_layout = go.Layout(
        plot_bgcolor='rgba(0,0,0,0)',
        font=dict(family="serif", size=32),
        margin=dict(l=0, r=0, t=0, b=0),
        xaxis=dict(axis_style),
        yaxis=dict(axis_style),
    )
    fig.update_layout(paper_layout)
    return fig




def visualize_iteration(res, iteration, n_exploitation, grad_true, comp=0, figsize=(40, 40), strat=None, step_size=1, title=None):
    """Plot function contours, collected samples and gradient bounds at one iteration.

    Three arrows are drawn at ``xbar``: the minimum gradient bound, the
    maximum gradient bound and the true gradient. Only component *comp* of
    each arrow is non-zero.

    Parameters
    ----------
    res : dict
        Results of one run. The keys read are:

        * ``"X1"``, ``"X2"``, ``"f_grid"``: contour grid and values;
        * ``"xbar"``: anchor point of the gradient arrows;
        * ``"X_Samples"``, ``"Y_Samples"``: collected samples;
        * ``"Grad_Min_Exact"``, ``"Grad_Max_Exact"``: gradient bounds.
    iteration : int
        Iteration to show. The bounds are taken at this index, and the first
        ``(iteration + 1) * n_exploitation`` samples are plotted.
    n_exploitation : int
        Number of samples collected per iteration.
    grad_true : float or array-like
        True value of gradient component *comp*; must squeeze to a single value.
    comp : int, optional
        Gradient component (0 or 1) to draw.
    figsize : tuple, optional
        Matplotlib figure size.
    strat : str, optional
        Strategy name. For ``"Adaptive"`` the bound arrays are indexed as
        ``[step_size, iteration]`` after squeezing.
    step_size : int, optional
        Step-size index used for the ``"Adaptive"`` strategy.
    title : str, optional
        Axes title; no title is set when None.

    Returns
    -------
    matplotlib.figure.Figure
    """
    X1, X2, f_grid = res["X1"], res["X2"], res["f_grid"]
    xbar = res["xbar"].squeeze()

    # Cumulative samples: everything collected up to and including `iteration`.
    n_seen = (iteration + 1) * n_exploitation
    x_exp = res["X_Samples"][:n_seen, :]
    y_exp = res["Y_Samples"][:n_seen]

    # Adaptive results carry an extra leading step-size axis.
    index = (step_size, iteration) if strat == "Adaptive" else iteration
    grad = np.zeros((2, ))
    grad_min_exact = np.zeros((2, ))
    grad_max_exact = np.zeros((2, ))
    grad[comp] = np.asarray(grad_true).squeeze()
    grad_min_exact[comp] = np.asarray(res["Grad_Min_Exact"]).squeeze()[index]
    grad_max_exact[comp] = np.asarray(res["Grad_Max_Exact"]).squeeze()[index]
    grad_results = {"Minimum ": grad_min_exact, "Maximum ": grad_max_exact, "Actual ": grad}

    fig, ax = plt.subplots(figsize=figsize)

    contour = ax.contour(X1, X2, f_grid, levels=10, cmap='coolwarm')
    ax.clabel(contour, inline=True, fontsize=8)

    # scale=1 with scale_units/angles='xy' draws each vector in data units.
    for i, (grad_name, grad_value) in enumerate(grad_results.items()):
        ax.quiver(*xbar, *grad_value, color=mpl_colors[i], scale=1, scale_units='xy', angles='xy',
                  width=0.005, label=f"{grad_name}: {np.round(grad_value[comp], 1)}")

    ax.scatter(*xbar, marker='x', color='black', label='xbar', s=100)
    ax.grid(color='lightgray', linestyle='-', linewidth=0.5)

    if title is not None:
        ax.set_title(title)
    ax.set_xlabel('X0 - Component')
    ax.set_ylabel('X1 - Component')
    ax.legend()

    scatter = ax.scatter(x_exp[:, 0], x_exp[:, 1], c=y_exp, cmap='coolwarm', alpha=0.5, s=1)
    fig.colorbar(scatter, ax=ax, label='Y values')

    return fig
    



def visualize_grad_confidence(lower_bounds, upper_bounds, grad_true, len_=None, range_=None):
    """Plot mean and 10th-90th percentile bands of the gradient bounds per strategy.

    For every strategy, two mean curves are drawn: the lower bound with dash
    ``shapes[0]`` and the upper bound with dash ``shapes[1]``. Each curve is
    surrounded by a translucent band between the 10th and 90th percentile,
    taken over axis 0. The true gradient is added once as a black line.

    Parameters
    ----------
    lower_bounds, upper_bounds : dict
        Strategy name -> 2-D array; statistics are taken over axis 0
        (presumably one row per seed and one column per iteration).
        ``upper_bounds`` must contain every key of ``lower_bounds``.
    grad_true : array-like
        True gradient values; the first ``len_`` (or default length)
        entries are plotted.
    len_ : int, optional
        Number of x positions. Defaults to the longest row
        (``len(bounds[0])``) over all strategies in ``lower_bounds``.
    range_ : float, optional
        If given, the y-axis is limited to ``mean(grad_true) ± range_``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()
    conf_l, conf_u = 10.0, 90.0

    if len_ is None:
        length_ = max((len(bounds[0]) for bounds in lower_bounds.values()), default=0)
    else:
        length_ = len_
    # One x position per column of the statistics, matching the padded y series.
    x = list(range(length_))

    def pad_array(array, length):
        # Left-pad with NaN so shorter series end at the same x as the longest one.
        array = np.atleast_1d(array)
        if len(array) < length:
            array = np.concatenate((np.full(length - len(array), np.nan), array))
        return array

    def summarize(bounds, key):
        # Mean and the two percentiles over axis 0.
        stats = (
            np.mean(bounds, axis=0),
            np.percentile(bounds, conf_l, axis=0),
            np.percentile(bounds, conf_u, axis=0),
        )
        if key == "Alternating":
            # Each value is duplicated to span two x positions (presumably this
            # strategy yields one estimate per two steps).
            return [np.repeat(stat, 2) for stat in stats]
        return [pad_array(stat, length_) for stat in stats]

    def fill_rgba(i, alpha=0.1):
        # Translucent version of the i-th palette colour, e.g. "rgba(230,159,0,0.1)".
        r, g, b = (int(c) for c in rgb_colors[i % len(rgb_colors)][4:-1].split(","))
        return f"rgba({r},{g},{b},{alpha})"

    plotted_any = False
    for i, key in enumerate(lower_bounds):
        print(key)
        lower_stats = summarize(lower_bounds[key], key)
        upper_stats = summarize(upper_bounds[key], key)
        color = colors[i % len(colors)]
        fillcolor = fill_rgba(i)
        try:
            for (avg, low, high), dash, legend in ((lower_stats, shapes[0], True),
                                                  (upper_stats, shapes[1], False)):
                fig.add_trace(go.Scatter(x=x, y=avg, mode='lines', line=dict(color=color, dash=dash),
                                         name=f"{key}", showlegend=legend))
                fig.add_trace(go.Scatter(x=x, y=low, mode='lines', line=dict(color=color, width=0),
                                         name=f"{key}", showlegend=False))
                # "tonexty" fills down to the previous trace: the low percentile just added.
                fig.add_trace(go.Scatter(x=x, y=high, mode='lines', line=dict(color=color, width=0),
                                         name=f"{key}", showlegend=False, fill="tonexty",
                                         fillcolor=fillcolor))
            plotted_any = True
        except ValueError:
            print(f"ValueError: {key}")

    if plotted_any:
        # A single reference line shared by all strategies.
        fig.add_trace(go.Scatter(x=x, y=grad_true[:length_], mode='lines', line=dict(color='black'),
                                 name='True Gradient', showlegend=False))

    if range_ is not None:
        mean_grad_true = np.mean(grad_true)
        fig.update_yaxes(range=[mean_grad_true - range_, mean_grad_true + range_])

    fig.update_layout(
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
    )
    return fig


def visualize_grad_confidence_final(lower_bounds, upper_bounds, grad_true, len_=None, range_=None):
    """Box plots of the final lower/upper gradient bounds for each strategy.

    Two horizontal reference lines are drawn across all boxes: a solid black
    line at ``mean(grad_true)`` and a dotted black line at 0.

    Parameters
    ----------
    lower_bounds, upper_bounds : dict
        Strategy name -> 2-D array; the last column (``[:, -1]``) is plotted.
        ``upper_bounds`` must contain every key of ``lower_bounds``.
    grad_true : array-like
        True gradient values; only their mean is used.
    len_ : optional
        Unused; accepted for call compatibility.
    range_ : float, optional
        If given, the y-axis is limited to ``mean(grad_true) ± range_``.

    Returns
    -------
    plotly.graph_objects.Figure
        An empty figure when ``lower_bounds`` is empty.
    """
    fig = go.Figure()
    if not lower_bounds:
        return fig

    for i, key in enumerate(lower_bounds):
        color = colors[i % len(colors)]
        lower = lower_bounds[key][:, -1]
        upper = upper_bounds[key][:, -1]
        # Both boxes share the strategy name as category and are overlaid.
        fig.add_trace(go.Box(y=lower, name=f"{key}", marker=dict(color=color), showlegend=False))
        fig.add_trace(go.Box(y=upper, name=f"{key}", marker=dict(color=color), showlegend=False))

    fig.update_layout(boxmode='overlay')

    mean_grad_true = np.mean(grad_true)
    if range_ is not None:
        fig.update_yaxes(range=[mean_grad_true - range_, mean_grad_true + range_])

    # Categories sit at x = 0 .. n-1, so -0.5 .. n-0.5 spans every box.
    x_end = len(lower_bounds) - 0.5
    fig.add_shape(type="line", x0=-0.5, y0=mean_grad_true, x1=x_end, y1=mean_grad_true,
                  line=dict(color="Black", width=3))
    fig.add_shape(type="line", x0=-0.5, y0=0, x1=x_end, y1=0,
                  line=dict(color="Black", width=3, dash="dot"))
    return fig




def plot_data_with_gradient_3d_surface(x_grid, y_grid, f_values, avg_x_position=None, grad_f_at_avg_x=None, figsize=(40,20)):
    """
    Draw function values over a grid as a semi-transparent 3D surface.

    :param x_grid: np.array, x coordinates of the grid (e.g. from np.meshgrid)
    :param y_grid: np.array, y coordinates of the grid (e.g. from np.meshgrid)
    :param f_values: np.array, function values on the grid, shape (n_samples, n_samples)
    :param avg_x_position: currently ignored (gradient drawing is disabled in this function)
    :param grad_f_at_avg_x: currently ignored (gradient drawing is disabled in this function)
    :param figsize: tuple, matplotlib figure size
    :return: matplotlib.figure.Figure with the surface (alpha 0.5, 'coolwarm'
        colour map) and a colour bar labelled 'Function value'
    """
    fig = plt.figure(figsize=figsize)
    ax3d = fig.add_subplot(111, projection='3d')

    surface = ax3d.plot_surface(
        x_grid,
        y_grid,
        f_values,
        cmap='coolwarm',
        edgecolor='none',
        alpha=0.5,
    )
    fig.colorbar(surface, ax=ax3d, shrink=0.5, aspect=5, label='Function value')

    ax3d.set_xlabel('x_0')
    ax3d.set_ylabel('x_1')
    return fig


def plot_data_with_gradient_2d_contour(x_grid, y_grid, f_values, avg_x_position=None, grad_f_at_avg_x=None, figsize=(15, 10)):
    """
    Draw a contour plot of a function and, optionally, its gradient at one point.

    :param x_grid: np.array, x coordinates of the grid (e.g. from np.meshgrid)
    :param y_grid: np.array, y coordinates of the grid (e.g. from np.meshgrid)
    :param f_values: np.array, function values on the grid, shape (n_samples, n_samples)
    :param avg_x_position: np.array, optional, 2-element start point of the gradient arrow
    :param grad_f_at_avg_x: np.array, optional, 2-element gradient drawn at avg_x_position
    :param figsize: tuple, matplotlib figure size
    :return: matplotlib.figure.Figure with contour lines and a colour bar. When
        both optional arguments are given, it also has a labelled gradient
        arrow and a legend.
    """
    fig, ax = plt.subplots(figsize=figsize)

    contour = ax.contour(x_grid, y_grid, f_values, levels=10, cmap='coolwarm')
    ax.clabel(contour, inline=True, fontsize=8)
    fig.colorbar(contour, ax=ax, label='Function value')

    draw_gradient = avg_x_position is not None and grad_f_at_avg_x is not None
    if draw_gradient:
        # scale=1 with scale_units/angles='xy' draws the vector in data units.
        ax.quiver(*avg_x_position, *grad_f_at_avg_x, color='black', scale=1, scale_units='xy',
                  angles='xy', width=0.005, label='Gradient Vector')
    ax.grid(color='lightgray', linestyle='-', linewidth=0.5)

    ax.set_xlabel('x_0')
    ax.set_ylabel('x_1')
    if draw_gradient:
        # The arrow is the only labelled artist; without it legend() would
        # only warn that no labelled artists were found.
        ax.legend()

    return fig


def visualize_data(X1, X2, f_grid, xbar=None, grad_true=None):
    """Create the 2D contour and the 3D surface figure for the same data.

    The 2D contour figure is created first, then the 3D surface figure. Both
    receive *xbar* and *grad_true* as the optional gradient position and
    value.

    Returns
    -------
    tuple
        ``(fig_2d, fig_surface)``.
    """
    gradient_kwargs = dict(avg_x_position=xbar, grad_f_at_avg_x=grad_true)
    return (
        plot_data_with_gradient_2d_contour(X1, X2, f_grid, **gradient_kwargs),
        plot_data_with_gradient_3d_surface(X1, X2, f_grid, **gradient_kwargs),
    )


def visualize_time(lower_bounds, len_=None, range_=None):
    """Plot the mean and 25th-75th percentile band per strategy.

    Despite its name, *lower_bounds* can hold any per-strategy 2-D
    measurements (presumably the ``res_time`` dict from ``load_bounds``).

    Parameters
    ----------
    lower_bounds : dict
        Strategy name -> 2-D array; statistics are taken over axis 0.
    len_ : int, optional
        Number of x positions. Defaults to the longest row
        (``len(values[0])``) over all strategies.
    range_ : optional
        Currently unused; accepted for call compatibility.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()

    if len_ is None:
        length_ = max((len(values[0]) for values in lower_bounds.values()), default=0)
    else:
        length_ = len_
    x = list(range(length_))

    def pad_array(array, length):
        # Left-pad with NaN so shorter series end at the same x as the longest one.
        array = np.atleast_1d(array)
        if len(array) < length:
            array = np.concatenate((np.full(length - len(array), np.nan), array))
        return array

    def fill_rgba(i, alpha=0.1):
        # Translucent version of the i-th palette colour, e.g. "rgba(230,159,0,0.1)".
        r, g, b = (int(c) for c in rgb_colors[i % len(rgb_colors)][4:-1].split(","))
        return f"rgba({r},{g},{b},{alpha})"

    for i, key in enumerate(lower_bounds):
        print(key)
        values = lower_bounds[key]
        stats = (
            np.mean(values, axis=0),
            np.percentile(values, 25.0, axis=0),
            np.percentile(values, 75.0, axis=0),
        )
        if key == "Alternating":
            # Each value is duplicated to span two x positions.
            avg, low, high = (np.repeat(stat, 2) for stat in stats)
        else:
            avg, low, high = (pad_array(stat, length_) for stat in stats)

        color = colors[i % len(colors)]
        try:
            fig.add_trace(go.Scatter(x=x, y=avg, mode='lines', line=dict(color=color, dash=shapes[0]),
                                     name=f"{key}", showlegend=True))
            fig.add_trace(go.Scatter(x=x, y=low, mode='lines', line=dict(color=color, width=0),
                                     name=f"{key}", showlegend=False))
            # "tonexty" fills down to the previous trace: the 25th percentile just added.
            fig.add_trace(go.Scatter(x=x, y=high, mode='lines', line=dict(color=color, width=0),
                                     name=f"{key}", showlegend=False, fill="tonexty",
                                     fillcolor=fill_rgba(i)))
        except ValueError:
            print(f"ValueError: {key}")

    fig.update_layout(
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
    )
    return fig