import numpy as np
import plotly.figure_factory as ff
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from utils.data import CLASS_ID_TO_NAMES

import plotly.express as px
# Default qualitative color palette. The plotting helpers in this module index
# into it by trace position when the caller does not pass explicit colors.
COLORS = px.colors.qualitative.Plotly
# Alternative palette of named CSS colors, kept for reference:
# COLORS = ['darkblue', 'red', 'brown', 'pink', 'orange', 'darkgreen',
#           'gray', 'lightblue', 'purple', 'lightgreen']

def plot_pie_chart(labels, values, ids_colors=None, save_path=None,
                   textinfo='label+percent', textfont_size_pie=16,
                   yaxis=None, xaxis=None, legend=None, font=None, title=None):
    """Build (and optionally save) a Plotly pie chart.

    Args:
        labels: Slice labels.
        values: Slice values, aligned with ``labels``.
        ids_colors: Optional sequence of integer positions into ``COLORS``
            choosing each slice's color. When omitted, the palette is used
            in order.
        save_path: If given, the figure is written there via
            ``fig.write_image``.
        textinfo: Plotly ``textinfo`` string for the slice text.
        textfont_size_pie: Font size of the slice text.
        yaxis, xaxis: Cartesian axis layout settings. An empty or omitted
            value falls back to a black-line / light-grey-grid style. These
            are presumably kept only for signature parity with the other
            helpers, since a pie trace does not draw cartesian axes.
        legend, font, title: Layout settings forwarded to ``update_layout``.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    legend = {} if legend is None else legend
    font = {} if font is None else font
    title = {} if title is None else title

    selected_colors = COLORS if ids_colors is None else np.take(COLORS, ids_colors)
    # The text size is set once, here, so that ``textfont_size_pie`` is
    # honored; no later trace update overrides it.
    fig = go.Figure(data=[go.Pie(labels=labels,
                                 values=values,
                                 textinfo=textinfo,
                                 insidetextorientation='radial',
                                 marker=dict(colors=list(selected_colors)),
                                 textfont_size=textfont_size_pie)])
    if not xaxis:
        xaxis = dict(
            linecolor="black",
            spikedash='solid',
            showgrid=True,
            gridcolor='lightgrey')
    if not yaxis:
        yaxis = dict(
            linecolor="black",
            spikedash='solid',
            showgrid=True,
            gridcolor='lightgrey')
    fig.update_layout(
        title=title,
        xaxis=xaxis,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        yaxis=yaxis,
        font=font,
        legend=legend)
    if save_path is not None:
        fig.write_image(save_path)
        print(f'Plotted and saved pie chart under {save_path}')
    return fig

def plot_scatter_plot(all_x, all_y, all_names, xaxis_title_text, yaxis_title_text, showlegend=True, log=False,
    title_text='', marker=None, xaxis=None, font=None, legend=None, text_scatter=None, error_y=None,
    mode='markers', yaxis=None, save_path=None, x_range=None, y_range=None, colors=None):
    """Plot one scatter trace per ``(x, y, name)`` triple on a shared figure.

    A dashed grey horizontal reference line is drawn at ``y=0``.

    Args:
        all_x, all_y, all_names: Parallel sequences; element ``i`` of each
            defines trace ``i``.
        xaxis_title_text, yaxis_title_text: Axis titles.
        showlegend: Whether to show the legend.
        log: If true, switch the y axis to log type with ``dtick=1``.
        title_text: Figure title.
        marker: Marker spec shared by all traces. Its color is overridden
            per trace from ``colors``.
        xaxis, yaxis: Axis layout settings. An empty or omitted value falls
            back to a black-line / light-grey-grid style.
        font, legend: Layout settings forwarded to ``update_layout``.
        text_scatter: Per-trace ``text`` values. Defaults to ``''`` for
            every trace.
        error_y: Error-bar spec applied to every trace.
        mode: Plotly scatter ``mode`` (e.g. ``'markers'``, ``'lines'``).
        save_path: If given, the figure is written there via
            ``fig.write_image``.
        x_range, y_range: Optional explicit axis ranges.
        colors: Palette indexed by trace position. Defaults to ``COLORS``.
            It is cycled when there are more traces than colors.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    if colors is None:
        colors = COLORS
    font = {} if font is None else font
    legend = {} if legend is None else legend
    error_y = {} if error_y is None else error_y
    if not xaxis:
        xaxis = dict(
            linecolor="black",
            spikedash='solid',
            showgrid=True,
            gridcolor='lightgrey')
    if not yaxis:
        yaxis = dict(
            linecolor="black",
            spikedash='solid',
            showgrid=True,
            gridcolor='lightgrey')
    fig = go.Figure()
    fig.add_hline(y=0, line_width=2, line_dash="dash", line_color="grey")
    if text_scatter is None:
        text_scatter = [''] * len(all_x)
    for idx, (x, y, name) in enumerate(zip(all_x, all_y, all_names)):
        fig.add_trace(go.Scatter(
            x=x, y=y, name=name, mode=mode, text=text_scatter[idx], error_y=error_y,
            # Cycle through the palette so extra traces don't raise IndexError.
            marker=marker, marker_color=colors[idx % len(colors)]))
    fig.update_layout(
        title_text=title_text,
        xaxis_title_text=xaxis_title_text,
        yaxis_title_text=yaxis_title_text,
        xaxis=xaxis,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        yaxis=yaxis,
        font=font,
        showlegend=showlegend,
        legend=legend,
    )
    if x_range is not None:
        fig.update_xaxes(range=x_range)
    if y_range is not None:
        fig.update_yaxes(range=y_range)
    if log:
        fig.update_yaxes(type="log", dtick=1, range=y_range)
    if save_path is not None:
        fig.write_image(save_path)
        print(f'Plotted and saved scatter plot under {save_path}')
    return fig

def plot_instances_highlighting_special(images, labels, indexes, x, y,
        special_indexes=None, save_path=None, dataset='cifar'):
    """Draw images on an ``x``-by-``y`` grid, framing the "special" ones in red.

    Args:
        images: Images drawn with ``imshow``, one per grid cell.
        labels: Keys into ``CLASS_ID_TO_NAMES[dataset]``, used for the titles.
        indexes: Identifier of each image, checked for membership in
            ``special_indexes``.
        x, y: Number of grid rows and columns.
        special_indexes: Identifiers whose cells keep their axes and get a
            thick red frame. All other cells have their axes hidden.
        save_path: Optional path for ``fig.savefig``.
        dataset: Key into ``CLASS_ID_TO_NAMES``.

    Returns:
        The matplotlib ``Figure``.
    """
    # squeeze=False always yields a 2-D array of Axes, so ravel() also works
    # for a 1x1 grid (where a bare Axes would otherwise be returned).
    fig, axarr = plt.subplots(x, y, squeeze=False)
    # NOTE(review): this mutates the global rcParams, so it also affects every
    # figure created later in the process.
    plt.rcParams.update({'font.size': 50})
    for ax, im, label, idx in zip(axarr.ravel(), images, labels, indexes):
        ax.set_title(f'label: {CLASS_ID_TO_NAMES[dataset][label]}')
        if special_indexes is not None and idx in special_indexes:
            # Keep the axes visible and give the background patch a thick
            # red edge so the cell appears framed.
            ax.patch.set_edgecolor('red')
            ax.patch.set_linewidth(55)
        else:
            ax.set_axis_off()
        ax.imshow(im)
    fig.set_size_inches(100, 100)
    if save_path is not None:
        fig.savefig(save_path)
        print(f'Plotted and saved images under {save_path}')
    return fig


def plot_instances(images, labels, x, y, save_path=None, dataset='cifar', similarities=None):
    """Draw images on an ``x``-by-``y`` grid with their class names as titles.

    Args:
        images: Images drawn with ``imshow``, one per grid cell.
        labels: Keys into ``CLASS_ID_TO_NAMES[dataset]``, used for the titles.
        x, y: Number of grid rows and columns.
        save_path: Optional path for ``fig.savefig``.
        dataset: Key into ``CLASS_ID_TO_NAMES``.
        similarities: Optional per-image scores. When given, each title
            becomes ``"<name> (<score rounded to 4 decimals>)"`` and a
            smaller font is used.

    Returns:
        The matplotlib ``Figure``.
    """
    # squeeze=False always yields a 2-D array of Axes, so ravel() also works
    # for a 1x1 grid (where a bare Axes would otherwise be returned).
    fig, axarr = plt.subplots(x, y, squeeze=False)
    # Use a smaller font when titles also carry a score.
    # NOTE(review): this mutates the global rcParams for later figures too.
    plt.rcParams.update({'font.size': 50 if similarities is None else 40})
    for idx, (ax, im, label) in enumerate(zip(axarr.ravel(), images, labels)):
        ax.imshow(im)
        ax.set_axis_off()
        class_name = CLASS_ID_TO_NAMES[dataset][label]
        if similarities is None:
            ax.set_title(f'label: {class_name}')
        else:
            ax.set_title(f'{class_name} ({round(float(similarities[idx]), 4)})')
    fig.set_size_inches(100, 100)
    if save_path is not None:
        fig.savefig(save_path)
        print(f'Plotted and saved images under {save_path}')
    return fig

def plot_multiple_histograms(all_x, all_names, save_paths, all_y=None,
        xaxis=None, xaxis_title_text='Similarity Score',
        min_value=None, max_value=None):
    """Write one single-trace count histogram per ``(x, save_path)`` pair.

    Args:
        all_x: Sequence of data arrays, one histogram each.
        all_names: Trace names, indexed like ``all_x``.
        save_paths: Output paths for ``fig.write_image``. Iteration stops at
            the shorter of ``all_x`` and ``save_paths``.
        all_y: Optional sequence of ``y`` arrays, indexed like ``all_x``.
        xaxis: X-axis layout settings forwarded to ``update_layout``.
        xaxis_title_text: X-axis title.
        min_value, max_value: If both are given, they fix the x-axis range.

    Returns:
        None. Each figure is only written to disk.
    """
    xaxis = {} if xaxis is None else xaxis
    for idx, (x, save_path) in enumerate(zip(all_x, save_paths)):
        y = None if all_y is None else all_y[idx]
        fig = go.Figure()
        fig.add_trace(go.Histogram(
            histfunc="count", x=x, y=y, name=all_names[idx],
            # Cycle through the palette so extra histograms don't raise IndexError.
            marker_color=COLORS[idx % len(COLORS)]))
        fig.update_layout(
            title_text='Histogram',
            xaxis_title_text=xaxis_title_text,
            yaxis_title_text='Count',
            xaxis=xaxis)
        if min_value is not None and max_value is not None:
            fig.update_xaxes(range=[min_value, max_value])
        fig.write_image(save_path)
        print(f'Plotted and saved histogram under {save_path}')

def plot_histogram(all_x, all_names, save_path=None, all_y=None, title_text='', colors=None, showlegend=True,
        xaxis=None, yaxis=None, xaxis_title_text='Similarity Score', nbinsx=None, log=False,
        xrange=None, yrange=None, histfunc="count", histnorm=None, yaxis_title_text='', error_y=None,
        bingroup=None, barmode='group', xbins=None, font=None, y_hline=None, legend=None,
        x_vline=None, linecolor='red', vline_annotation='', marker_patterns=None):
    """Overlay one histogram trace per entry of ``all_x`` on a single figure.

    Args:
        all_x: Sequence of data arrays, one trace each.
        all_names: Trace names, indexed like ``all_x``.
        save_path: If given, the figure is written there via
            ``fig.write_image``.
        all_y: Optional sequence of ``y`` arrays, indexed like ``all_x``.
        title_text, xaxis_title_text, yaxis_title_text: Titles.
        colors: Palette indexed by trace position. Defaults to ``COLORS``.
            It is cycled when there are more traces than colors.
        showlegend: Whether to show the legend.
        xaxis, yaxis: Axis layout settings. An empty or omitted value falls
            back to a black-line / light-grey-grid style.
        nbinsx, histfunc, histnorm, bingroup, xbins, error_y: Forwarded to
            every ``go.Histogram`` trace.
        log: If true, switch the y axis to log type with ``dtick=1``.
        xrange, yrange: Optional explicit axis ranges.
        barmode: Layout ``barmode`` (e.g. ``'group'``, ``'overlay'``).
        font, legend: Layout settings forwarded to ``update_layout``.
        y_hline: Optional y position of a horizontal reference line.
        x_vline: Optional x position of a dashed vertical line, or a list of
            positions.
        linecolor: Color of the vertical line(s). For a list ``x_vline`` this
            may be one color string (applied to all lines) or a per-line list.
        vline_annotation: Text placed at the top of the vertical line(s).
            Like ``linecolor``, this is one string or a per-line list.
        marker_patterns: Optional per-trace ``marker_pattern_shape`` values.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    if not xaxis:
        xaxis = dict(
            linecolor="black",
            spikedash='solid',
            showgrid=True,
            gridcolor='lightgrey')
    if not yaxis:
        yaxis = dict(
            linecolor="black",
            spikedash='solid',
            showgrid=True,
            gridcolor='lightgrey')
    error_y = {} if error_y is None else error_y
    font = {} if font is None else font
    legend = {} if legend is None else legend
    if colors is None:
        colors = COLORS
    if marker_patterns is None:
        marker_patterns = [None] * len(all_x)
    fig = go.Figure()
    for idx, x in enumerate(all_x):
        y = None if all_y is None else all_y[idx]
        fig.add_trace(go.Histogram(
            histfunc=histfunc, histnorm=histnorm, nbinsx=nbinsx, autobinx=False,
            x=x, y=y, name=all_names[idx], bingroup=bingroup, error_y=error_y,
            # Cycle through the palette so extra traces don't raise IndexError.
            marker_color=colors[idx % len(colors)], xbins=xbins,
            marker_pattern_shape=marker_patterns[idx]))
    if y_hline is not None:
        fig.add_hline(y=y_hline)
    if x_vline is not None:
        if isinstance(x_vline, list):
            n_lines = len(x_vline)
            # Broadcast a single color / annotation string to every line.
            # Indexing the str directly would pick single characters
            # ('r' from 'red') or raise IndexError for the default ''.
            line_colors = [linecolor] * n_lines if isinstance(linecolor, str) else linecolor
            line_texts = ([vline_annotation] * n_lines
                          if isinstance(vline_annotation, str) else vline_annotation)
            for line_idx, x_pos in enumerate(x_vline):
                fig.add_vline(x=x_pos, line_width=2, line_dash='dash',
                              line_color=line_colors[line_idx],
                              annotation_text=line_texts[line_idx],
                              annotation_position='top')
        else:
            fig.add_vline(x=x_vline, line_width=2, line_dash='dash', line_color=linecolor,
                          annotation_text=vline_annotation, annotation_position='top')
    fig.update_layout(
        title_text=title_text,
        xaxis_title_text=xaxis_title_text,
        yaxis_title_text=yaxis_title_text,
        barmode=barmode,
        xaxis=xaxis,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        yaxis=yaxis,
        showlegend=showlegend,
        legend=legend,
        font=font)
    if xrange is not None:
        fig.update_xaxes(range=xrange)
    if yrange is not None:
        fig.update_yaxes(range=yrange)
    if log:
        fig.update_yaxes(type="log", dtick=1, range=yrange)
    if save_path is not None:
        fig.write_image(save_path)
        print(f'Plotted and saved histogram under {save_path}')
    return fig

def plot_box_plot_with_bins(df, title_text='', save_path=None, font=None, yrange=None, xtick_angle=0,
                            xrange=None, ranges=None, column_name_y='ratio_intersection', showlegend=True,
                            column_name_x='sim', yaxis_title_text=None, xaxis_title_text=None):
    """Draw one box (with mean and standard deviation) per bin.

    Args:
        df: Sequence of per-bin tables. Box ``i`` is built from
            ``df[i][column_name_y]``. (Presumably pandas DataFrames; confirm
            against the callers.)
        title_text, yaxis_title_text, xaxis_title_text: Titles.
        save_path: If given, the figure is written there via
            ``fig.write_image``.
        font: Layout font settings.
        yrange, xrange: Optional explicit axis ranges.
        xtick_angle: X tick-label angle. Applied only when non-zero.
        ranges: Bin edges. Box ``i`` is labelled
            ``"[ranges[i]:ranges[i + 1]]"`` (each edge rounded to 1 decimal).
            Defaults to ``np.arange(0, 1.1, 0.1)``.
        column_name_y: Column holding the values of each box.
        showlegend: Whether to show the legend.
        column_name_x: Not read by this function. Kept for backward
            compatibility of the signature.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    if ranges is None:
        ranges = np.arange(0, 1.1, 0.1)
    font = {} if font is None else font
    fig = go.Figure()
    # zip stops at the shortest of: the tables, the lower edges, the upper edges.
    for split, lower, upper in zip(df, ranges, ranges[1:]):
        values = list(split[column_name_y])
        if not values:
            # Keep a placeholder box so an empty bin still gets its category.
            values = [0.]
        fig.add_trace(go.Box(
            y=values,
            marker_color=COLORS[0],
            name=f"[{round(lower, 1)}:{round(upper, 1)}]",
            boxmean='sd'))
    fig.update_layout(
        title_text=title_text,
        yaxis_title_text=yaxis_title_text,
        xaxis_title_text=xaxis_title_text,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        font=font,
        width=1000,
        height=500,
        showlegend=showlegend,
        xaxis=dict(
            linecolor="black", spikedash='solid',
            gridcolor='lightgrey'),
        yaxis=dict(
            linecolor="black", spikedash='solid',
            showgrid=True, gridcolor='lightgrey'))
    if yrange is not None:
        fig.update_yaxes(range=yrange)
    if xrange is not None:
        fig.update_xaxes(range=xrange)
    if xtick_angle != 0:
        fig.update_xaxes(tickangle=xtick_angle)
    if save_path is not None:
        fig.write_image(save_path)
        print(f'Plotted and saved box plot under {save_path}')
    return fig

def plot_confusion_matrix(conf_mx, save_path=None, x=None, y=None,
    dataset='cifar', x_text="Predicted Label", y_text="True Label", title_text='Confusion matrix'):
    """Render a confusion matrix as an annotated Plotly heatmap.

    Args:
        conf_mx: Matrix to draw. It is iterated row by row, and every cell's
            ``str()`` is shown as its annotation.
        save_path: If given, the figure is written there via
            ``fig.write_image``.
        x, y: Column and row tick labels. Each defaults to the values of
            ``CLASS_ID_TO_NAMES[dataset]``.
        dataset: Key into ``CLASS_ID_TO_NAMES``. Only read when ``x`` or
            ``y`` is omitted.
        x_text, y_text: Captions placed below the plot and, rotated, left
            of it.
        title_text: Figure title.

    Returns:
        The ``plotly.graph_objects.Figure``. The y axis is reversed so the
        first row appears at the top, and the color scale is shown.
    """
    col_labels = x if x is not None else list(CLASS_ID_TO_NAMES[dataset].values())
    row_labels = y if y is not None else list(CLASS_ID_TO_NAMES[dataset].values())
    cell_text = [[str(cell) for cell in row] for row in conf_mx]

    fig = ff.create_annotated_heatmap(
        conf_mx, x=col_labels, y=row_labels, annotation_text=cell_text, colorscale='Viridis')
    fig.update_layout(title_text=title_text, width=1000)

    # Axis captions, positioned in paper coordinates outside the heatmap.
    captions = (
        dict(x=0.5, y=-0.15, text=x_text),
        dict(x=-0.25, y=0.5, text=y_text, textangle=-90),
    )
    for caption in captions:
        fig.add_annotation(dict(font=dict(color="black", size=14),
                                showarrow=False,
                                xref="paper",
                                yref="paper",
                                **caption))

    # Wide left margin leaves room for the rotated caption.
    fig.update_layout(margin=dict(t=100, l=200))
    fig.update_yaxes(autorange="reversed")
    fig.data[0].showscale = True
    if save_path is not None:
        fig.write_image(save_path)
        print(f'Plotted and saved confusion matrix under {save_path}')
    return fig

def plot_histogram_before_after(all_x, all_names, mean_before=None, mean_after=None,
                                xrange=None, yrange=None, colors=None, all_y=None,
                                save_path=None, histnorm='probability density'):
    """Overlay semi-transparent histograms and optionally mark two means.

    Args:
        all_x: Sequence of data arrays, one trace each.
        all_names: Trace names, indexed like ``all_x``.
        mean_before: Optional x position of a white dashed line labelled
            "Mean Before".
        mean_after: Optional x position of a black dashed line labelled
            "Mean After". Each mean is drawn independently of the other.
        xrange, yrange: Optional explicit axis ranges.
        colors: Palette indexed by trace position. Defaults to ``COLORS``.
            It is cycled when there are more traces than colors.
        all_y: Optional sequence of ``y`` arrays, indexed like ``all_x``.
        save_path: If given, the figure is written there via
            ``fig.write_image``.
        histnorm: Plotly ``histnorm`` for every trace.

    Returns:
        The ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()
    if colors is None:
        colors = COLORS
    for idx, x in enumerate(all_x):
        y = None if all_y is None else all_y[idx]
        fig.add_trace(go.Histogram(
            x=x, y=y, name=all_names[idx], histnorm=histnorm,
            # Cycle through the palette so extra traces don't raise IndexError.
            marker_color=colors[idx % len(colors)]))

    shapes, annotations = [], []
    # (position, line color, label height in paper coords, label text).
    # The labels sit at different heights so they do not overlap.
    mean_markers = (
        (mean_before, 'white', 1, "Mean Before"),
        (mean_after, 'black', 0.95, "Mean After"),
    )
    for mean, line_color, label_y, label in mean_markers:
        if mean is None:
            continue
        shapes.append(
            {'line': {'color': line_color, 'dash': 'dash', 'width': 1},
             'type': 'line', 'x0': mean, 'x1': mean, 'xref': 'x',
             'y0': -0.1, 'y1': 1, 'yref': 'paper'})
        annotations.append(dict(x=mean, y=label_y, xref='x', yref='paper', text=label))
    if xrange is not None:
        fig.update_xaxes(range=xrange)
    if yrange is not None:
        fig.update_yaxes(range=yrange)
    fig.update_layout(
        title_text="Histogram (prob. density)",
        barmode='overlay',
        annotations=annotations,
        shapes=shapes
    )
    fig.update_traces(opacity=0.6)
    if save_path is not None:
        fig.write_image(save_path)
        print(f'Plotted and saved histogram under {save_path}')
    return fig


