import math

import jax
import jax.numpy as jnp
import numpy as np
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.llama.modeling_flax_llama import FlaxLlamaRotaryEmbedding

from nets.world_model import get_default_position_ids, FlaxLlamaRotaryEmbeddingReal


def _build_position_ids(batch_size, seq_len, tokens_per_block, use_spatio_temporal):
    """Return position ids with 3 coordinate channels in the last axis, scaled by 1/10.

    Channel 0 drives the "time" feature group of the rotary masks and channels
    1 and 2 drive the two interleaved "space" groups (see ``_build_rope_patterns``).
    """
    if use_spatio_temporal:
        position_ids = get_default_position_ids(
            batch_size, seq_len, tokens_per_block, True
        )
        # Make both spatial channels relative to channel 0.
        # NOTE(review): presumably get_default_position_ids folds the time
        # offset into channels 1 and 2 — confirm against nets.world_model.
        position_ids = position_ids.at[..., 1].set(
            position_ids[..., 1] - position_ids[..., 0]
        )
        position_ids = position_ids.at[..., 2].set(
            position_ids[..., 2] - position_ids[..., 0]
        )
        return position_ids / 10.0

    # Every channel carries the flat token index.
    flat_index = jnp.arange(seq_len)[None, :, None]
    return (
        jnp.broadcast_to(flat_index, (batch_size, seq_len, 3)).astype("i4") / 10.0
    )


def _build_rope_patterns(head_dim, unit_time_size=2, unit_space_size=3):
    """Return three boolean masks of length ``head_dim`` selecting the features each axis rotates.

    Each half of the head holds ``unit_space_size * scale`` interleaved pairs
    (axis 1 takes the even slots, axis 2 the odd slots) followed by
    ``unit_time_size * scale`` slots for axis 0. The half-pattern is repeated
    for the second half of the head.
    NOTE(review): the repetition presumably keeps each feature and its
    rotate-half partner on the same axis — confirm against the rotary module.

    Raises:
        ValueError: if ``head_dim`` cannot be tiled exactly by the unit sizes,
            which would otherwise surface as a broadcasting error in ``jnp.where``.
    """
    scale = head_dim // 2 // (unit_time_size + unit_space_size * 2)
    half_patterns = [
        [False, False] * unit_space_size * scale + [True] * unit_time_size * scale,
        [True, False] * unit_space_size * scale + [False] * unit_time_size * scale,
        [False, True] * unit_space_size * scale + [False] * unit_time_size * scale,
    ]
    if 2 * len(half_patterns[0]) != head_dim:
        raise ValueError(
            f"head_dim={head_dim} cannot be split into time/space RoPE groups "
            f"(unit_time_size={unit_time_size}, unit_space_size={unit_space_size})"
        )
    return [jnp.array(half + half) for half in half_patterns]


def _plot_heatmaps(t0_attn, t1_attn, t10_attn, t11_attn, height, width, output_path):
    """Jointly min-max normalise four score panels, draw them, and save the figure.

    ``t0_attn``/``t1_attn`` are the (height, width) state-score maps for frames
    0 and 1. ``t10_attn``/``t11_attn`` are the matching (1, 1) action scores,
    drawn as strips under each map.
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

    # One shared colour scale across all four panels.
    panels = (t0_attn, t1_attn, t10_attn, t11_attn)
    score_min = min(float(p.min()) for p in panels)
    score_max = max(float(p.max()) for p in panels)
    span = score_max - score_min
    if span == 0.0:
        span = 1.0  # constant scores: avoid 0/0 (NaN) and render everything at 0
    t0_attn, t1_attn, t10_attn, t11_attn = ((p - score_min) / span for p in panels)

    fig, (ax_t0, ax_t1) = plt.subplots(1, 2)
    try:
        ax_t0.imshow(t0_attn, vmin=0, vmax=1)
        # Outline the query cell (row 0, column 0) on the frame-0 panel.
        ax_t0.add_patch(
            patches.Rectangle(
                (-0.5, -0.5),
                1,
                1,
                linewidth=2,
                edgecolor="r",
                facecolor="none",
                clip_on=False,
                zorder=100,
            )
        )
        im_t1 = ax_t1.imshow(t1_attn, vmin=0, vmax=1)

        for ax, title in ((ax_t0, "t=0"), (ax_t1, "t=1")):
            ax.yaxis.set_inverted(True)
            ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
            ax.set_xticks(range(width))
            ax.set_yticks(range(height))
            ax.set_title(title)

        divider_t0 = make_axes_locatable(ax_t0)
        divider_t1 = make_axes_locatable(ax_t1)

        # Invisible spacer on the left, presumably to balance the colorbar on the right.
        divider_t0.append_axes("left", size="7%", pad="2%").set_axis_off()
        action_ax_t0 = divider_t0.append_axes("bottom", size="10%", pad="2%")
        action_ax_t0.imshow(t10_attn, vmin=0, vmax=1)
        action_ax_t0.get_xaxis().set_visible(False)
        action_ax_t0.get_yaxis().set_ticks([0], labels=["action"])

        colorbar_ax = divider_t1.append_axes("right", size="7%", pad="2%")
        action_ax_t1 = divider_t1.append_axes("bottom", size="10%", pad="2%")
        action_ax_t1.imshow(t11_attn, vmin=0, vmax=1)
        action_ax_t1.get_xaxis().set_visible(False)
        action_ax_t1.get_yaxis().set_ticks([0], labels=["action"])

        fig.colorbar(im_t1, cax=colorbar_ax)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure even if drawing/saving fails.
        plt.close(fig)


def visualize_attention(
    *,
    use_spatio_temporal: bool = True,
    random_qk: bool = False,
    output_path: str = "heatmap.pdf",
) -> None:
    """Plot rotary-embedded attention scores for one query token and save the figure.

    The sequence has ``T`` blocks, each made of ``H * W`` state tokens followed
    by one action token. Queries/keys (``k = q``) are rotated per axis with the
    masks from ``_build_rope_patterns``. Scaled dot-product scores are then
    taken for batch 0, head 0; they are pre-softmax and unmasked.

    The plot shows the query at frame 0, cell (0, 0). Its scores against every
    state token of frames 0 and 1 fill the main panels, and its scores against
    each frame's action token fill the strips below.

    Args:
        use_spatio_temporal: If True, position ids come from
            ``get_default_position_ids``. Otherwise every axis receives the flat
            token index.
        random_qk: If True, ``q`` is drawn from a standard normal instead of
            being all ones.
        output_path: File the figure is saved to.
    """
    T, H, W = 5, 9, 9
    tokens_per_block = H * W + 1  # H*W state tokens + 1 trailing action token
    seq_len = tokens_per_block * T
    batch_size = 3

    # Single source of truth for model sizes; head layout is derived from it.
    config = GPT2Config(
        vocab_size=4096,
        n_positions=seq_len,
        n_embd=512,
        n_layer=3,
        n_head=8,
        n_inner=None,  # defaults to 4 * n_embd
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
    )
    n_head = config.num_attention_heads
    head_dim = config.hidden_size // n_head

    rng = jax.random.PRNGKey(0)
    rng, rng_q, _ = jax.random.split(rng, 3)
    qk_shape = (batch_size, seq_len, n_head, head_dim)
    q = jax.random.normal(rng_q, qk_shape) if random_qk else jnp.ones(qk_shape)
    k = q

    rotary_emb = FlaxLlamaRotaryEmbeddingReal(config, dtype=jnp.float32)
    _, rng_params, rng_others = jax.random.split(rng, 3)
    rngs = {"params": rng_params, "other_rng": rng_others}

    position_ids = _build_position_ids(
        batch_size, seq_len, tokens_per_block, use_spatio_temporal
    )
    params = rotary_emb.init(rngs, q, k, position_ids[..., 0])

    # Rotate each feature group with its own coordinate channel; groups are
    # disjoint, so the passes compose without overwriting each other.
    query, key = q, k
    for axis, pattern in enumerate(_build_rope_patterns(head_dim)):
        rotated_query, rotated_key = rotary_emb.apply(
            params, query, key, position_ids[..., axis]
        )
        query = jnp.where(pattern, rotated_query, query)
        key = jnp.where(pattern, rotated_key, key)

    # Pre-softmax scores (b, h, q, k). These logits are what gets plotted.
    scores = jnp.einsum("b q h e, b k h e -> b h q k", query, key) / jnp.sqrt(
        query.shape[-1]
    )

    # Batch 0, head 0 as (query_block, query_token, key_block, key_token).
    scores = scores[0, 0].reshape(T, tokens_per_block, T, tokens_per_block)
    state_scores = scores[:, : H * W, :, : H * W].reshape(T, H, W, T, H, W)
    action_scores = scores[:, :1, :, -1:]

    _plot_heatmaps(
        state_scores[0, 0, 0, 0],  # query (t=0, 0, 0) -> state tokens of t=0
        state_scores[0, 0, 0, 1],  # query (t=0, 0, 0) -> state tokens of t=1
        action_scores[0, :1, 0, :],  # query (t=0, 0, 0) -> action token of t=0
        action_scores[0, :1, 1, :],  # query (t=0, 0, 0) -> action token of t=1
        H,
        W,
        output_path,
    )


if __name__ == "__main__":
    # Script entry point: build the attention scores and save the heatmap figure.
    visualize_attention()
