import numpy as np
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display


def show_multilayer_attention_maps(maps: np.ndarray) -> None:
    """Display attention figures for the first, middle and last layers.

    For every selected layer a Markdown heading is displayed, followed by the
    interactive figure that ``plot_attention_maps`` builds for that layer's
    single sequence. Each layer is shown at most once, so inputs with fewer
    than three layers do not produce repeated figures. Nothing is displayed
    when ``maps`` contains no layers.

    Args:
        maps: Attention weights indexed as ``maps[layer, sequence]``. The
            second axis must have length 1; ``maps[layer, 0]`` is what gets
            passed on to ``plot_attention_maps``.

    Raises:
        AssertionError: If the second axis of ``maps`` does not have length 1.
    """
    # Make sure that there's only one sequence
    assert maps.shape[1] == 1, "expected exactly one sequence (maps.shape[1] == 1)"
    num_layers = maps.shape[0]
    if num_layers == 0:
        return

    # Floor division gives the true middle index. round() rounds halves to the
    # nearest even number, so e.g. round(3 / 2) == 2 would skip layer 1 and
    # plot the last layer twice. dict.fromkeys removes duplicate indices while
    # keeping first/middle/last order.
    layer_indices = dict.fromkeys((0, num_layers // 2, num_layers - 1))
    for layer_idx in layer_indices:
        display(md(f"**Layer {layer_idx} attention**"))
        fig = plot_attention_maps(maps[layer_idx, 0])
        fig.show()


def plot_attention_maps(
    attention: np.ndarray,
) -> go.Figure:
    """Build an interactive heatmap figure for one layer's attention heads.

    The figure initially shows the attention averaged over the first axis of
    ``attention``. A dropdown menu switches the heatmap between that mean and
    each individual head, updating the figure title to match.

    Args:
        attention: Attention weights whose first axis indexes heads; each
            ``attention[i]`` is used as the ``z`` matrix of a heatmap.

    Returns:
        A 1000x900 figure containing a single heatmap trace and the dropdown.
    """
    mean_attention = attention.mean(axis=0)

    def make_button(label: str, z_values: np.ndarray) -> dict:
        # "update" takes [trace changes, layout changes]. "z" is wrapped in a
        # list because trace changes are given per trace. The title is
        # addressed as "title.text" because button args reach plotly.js
        # unvalidated; a bare string for "title" is believed to be rejected
        # by newer plotly.js releases (TODO confirm against the changelog).
        return {
            "label": label,
            "method": "update",
            "args": [
                {"z": [z_values]},
                {"title.text": label},
            ],
        }

    buttons = [make_button("Mean attention", mean_attention)]
    buttons.extend(
        make_button(f"Attention head {i}", head_map)
        for i, head_map in enumerate(attention)
    )

    layout = go.Layout(
        updatemenus=[
            {
                "buttons": buttons,
                "direction": "down",
                "showactive": True,
                "x": 0.8,
                "xanchor": "left",
                "y": 1.04,
                "yanchor": "top",
            }
        ],
        title="Mean attention",
        height=900,
        width=1000,
        xaxis_title="Sequence",
        yaxis_title="Sequence",
    )
    mean_attention_map = _create_attention_map(mean_attention)
    return go.Figure(data=[mean_attention_map], layout=layout)


def _create_attention_map(
    head_attention: np.ndarray,
) -> go.Heatmap:
    """Wrap ``head_attention`` in a heatmap trace, using it as the ``z`` data."""
    return go.Heatmap(z=head_attention)
