from typing import Union, Dict
import numpy as np
import plotly.graph_objects as go


def _to_plot_values(values: np.ndarray, logscale: bool) -> np.ndarray:
    """Flatten *values*; if *logscale*, map to log10(|v|) and drop exact zeros.

    Shared by the axis-range computation and the per-step histograms so both
    always operate on the same transformed quantity.
    """
    flat = np.asarray(values).reshape(-1)
    if logscale:
        magnitudes = np.abs(flat)
        # log10(0) is -inf, so zeros (and NaNs, which fail `> 0`) are dropped.
        flat = np.log10(magnitudes[magnitudes > 0])
    return flat


def plot_eigs_hist(eigs: Union[np.ndarray, Dict[str, np.ndarray]], show=True, logscale=True,
                   n_bins: int = 99):
    """Plot eigenvalue histograms per logging point, with a slider to step through them.

    Args:
        eigs: Either a single array or a dict mapping a series name to an array.
            Each array is indexed along its first axis by logging point; all
            remaining axes are flattened into one histogram per logging point.
            All series must have the same number of logging points.
        show: If True, display the figure and return None; otherwise return it.
        logscale: If True, histogram log10(|eigenvalue|) and ignore exact zeros,
            so negative eigenvalues are folded onto their magnitude.
        n_bins: Number of histogram bins shared by every trace.

    Returns:
        The ``go.Figure`` when ``show`` is False, otherwise None.

    Raises:
        ValueError: If ``eigs`` is empty, the series lengths differ, a series
            has no logging points, ``n_bins < 1``, or there is no plottable value.
    """
    if isinstance(eigs, np.ndarray):
        eigs = {'eigs': eigs}
    # Build a new dict so the caller's mapping is never modified.
    eigs = {key: np.asarray(series) for key, series in eigs.items()}

    # Validate up front: the slider/trace indexing below assumes every series
    # has the same, non-zero number of logging points.
    if not eigs:
        raise ValueError("eigs must contain at least one series")
    lengths = {len(series) for series in eigs.values()}
    if len(lengths) != 1:
        raise ValueError(
            f"all series must have the same number of logging points, got {sorted(lengths)}"
        )
    (series_len,) = lengths
    if series_len == 0:
        raise ValueError("each series must contain at least one logging point")
    if n_bins < 1:
        raise ValueError("n_bins must be at least 1")
    series_cnt = len(eigs)

    # Shared x-axis range, computed in the same (possibly log) space as the data.
    all_values = np.concatenate([_to_plot_values(s, logscale) for s in eigs.values()])
    if all_values.size == 0:
        raise ValueError("no plottable eigenvalues (all zero with logscale=True?)")
    eigs_min = float(all_values.min())
    eigs_max = float(all_values.max())
    if eigs_min == eigs_max:
        # Avoid a zero-width axis/bin range when every value is identical.
        eigs_min -= 0.5
        eigs_max += 0.5
    edges = np.linspace(eigs_min, eigs_max, n_bins + 1)
    # Bars are placed at bin centres so len(x) == len(hist).
    centers = (edges[:-1] + edges[1:]) / 2

    fig = go.Figure()
    y_max = 0
    digits = len(str(series_len - 1))
    # Traces are appended series-major: trace index = series_idx * series_len + step.
    for key, series in eigs.items():
        for step in range(series_len):
            hist, _ = np.histogram(_to_plot_values(series[step], logscale), bins=edges)
            y_max = max(y_max, int(hist.max()))
            fig.add_trace(
                go.Bar(
                    x=centers,
                    y=hist,
                    opacity=0.5,
                    visible=(step == 0),  # first logging point shown initially
                    name=f'{key}: {step:0{digits}d}',
                )
            )

    # One slider step per logging point; it shows that step's trace from every series.
    n_traces = series_cnt * series_len
    steps = []
    for step in range(series_len):
        visible = [trace_idx % series_len == step for trace_idx in range(n_traces)]
        steps.append(dict(
            method="update",
            args=[{"visible": visible},
                  {"title": "Eigenvalue distribution at training logging point: " + str(step)}],  # layout attribute
        ))

    sliders = [dict(
        active=0,
        currentvalue={"prefix": "Logging point:"},
        pad={"t": 50},
        steps=steps,
    )]

    fig.update_layout(
        title="Eigenvalue distribution at training logging point: 0",
        sliders=sliders,
        xaxis={'range': [eigs_min, eigs_max], 'autorange': False},
        yaxis={'range': [0, y_max], 'autorange': False},
        barmode='overlay',
    )
    if show:
        fig.show()
        return None
    return fig