import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torchtune.modules import RotaryPositionalEmbeddings


def plot_attention(attn, path, labels, title):
    """Render a square attention matrix as a labelled heatmap and save it to ``path``.

    Args:
        attn: 2-D array-like of attention weights passed to ``plt.imshow``;
            presumably (len(labels), len(labels)) since one tick per label is
            placed on each axis — confirm against callers.
        path: output file path handed to ``plt.savefig``.
        labels: tick labels used for both the x and y axes.
        title: figure title.

    Side effects:
        Sets ``plt.rcParams["svg.fonttype"] = "none"`` globally (it is not restored).
        The figure created here is always closed, even if drawing or saving fails.
    """
    import matplotlib.pyplot as plt

    # Global rcParam; presumably so SVG output keeps text as editable text.
    plt.rcParams["svg.fonttype"] = "none"

    # Smaller tick labels for longer label sets (original note: "D <= 3 ? 12 : 10").
    font_size = 12 if len(labels) <= 28 else 10
    title_font_size = 20
    ticks = range(len(labels))

    fig = plt.figure(figsize=(6, 6), dpi=96)
    try:
        plt.imshow(attn, cmap="viridis", interpolation="nearest")

        # y labels are right-aligned against the axis; x labels are centred under each tick.
        plt.yticks(ticks, labels, ha="right", fontsize=font_size)
        plt.xticks(ticks, labels, ha="center", fontsize=font_size)

        plt.title(title, fontsize=title_font_size)
        plt.tight_layout()
        plt.savefig(path)
    finally:
        # Close this specific figure even on failure so repeated calls don't accumulate open figures.
        plt.close(fig)


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.

    The attention is written out explicitly (rather than using
    torch.nn.MultiheadAttention) so the per-head attention weights can be plotted,
    overridden, or scored from inside ``forward``.

    Args:
        n_embd: input/output feature width.
        n_head: number of attention heads.
        block_size: maximum sequence length covered by the causal mask buffer.
        dropout: dropout probability applied to the attention weights and to the output.
        head_dim: per-head width. Defaults to ``n_embd // n_head``, in which case
            ``n_embd`` must be divisible by ``n_head``. When given explicitly, the
            q/k/v width is ``n_head * head_dim`` and no divisibility is required.
    """

    def __init__(self, n_embd, n_head, block_size, dropout, head_dim=None):
        super().__init__()
        if head_dim is None:
            # Only the default per-head width requires an even split of n_embd.
            assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
            head_dim = n_embd // n_head
        self.head_dim = head_dim
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_head * self.head_dim)
        # output projection
        self.c_proj = nn.Linear(n_head * self.head_dim, n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))
        self.n_head = n_head
        self.n_embd = n_embd

        self.drop_attn = nn.Dropout(dropout)
        self.drop_resid = nn.Dropout(dropout)

    def forward(self, x, attn_path=None, layer_id=None, labels=None, attention_maps=None, obj=None):
        """Apply causal multi-head self-attention to ``x`` of shape (B, T, n_embd).

        Args:
            x: input of shape (B, T, n_embd); T must not exceed ``block_size``.
            attn_path: if given, one PDF per head is written to
                ``{attn_path}/layer_{layer_id}_head_{h}.pdf`` showing the attention of
                batch element 0 (before any ``attention_maps`` override).
            layer_id: index of this layer; used in plot names/titles, to index
                ``attention_maps`` and to match ``obj["layer"]``.
            labels: tick labels for the plots.
            attention_maps: optional nested indexable; ``attention_maps[layer_id][h]``
                is either None or a tensor that replaces head ``h``'s attention
                weights (broadcast over the batch).
            obj: optional dict with keys "layer", "attention_map", "sum", "count".
                When ``obj["layer"] == layer_id``, the sum of ``att * attention_map``
                is added to "sum" and the sum of ``attention_map`` to "count".

        Returns:
            Tensor of shape (B, T, n_embd).
        """
        B, T, _ = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        inner = self.n_head * self.head_dim

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(inner, dim=2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.drop_attn(att)

        if attn_path is not None:
            # Plot the model's own weights (pre-override) for batch element 0.
            attn = att.detach().cpu().numpy()
            for head_ix in range(self.n_head):
                attn_head = attn[0, head_ix, :, :]
                plot_path = f"{attn_path}/layer_{layer_id}_head_{head_ix}.pdf"
                title = f"Layer {layer_id + 1}, head {head_ix + 1}"
                plot_attention(attn_head, plot_path, labels, title)

        if attention_maps is not None:
            head_maps = attention_maps[layer_id]
            overridden = [h for h in range(self.n_head) if head_maps[h] is not None]
            if overridden:
                # Clone before writing: dropout can return its input unchanged (p == 0 or
                # eval mode), and an in-place write would then clobber the softmax output
                # autograd saved for backward.
                att = att.clone()
                for head_ix in overridden:
                    att[:, head_ix] = head_maps[head_ix]

        if obj is not None and obj["layer"] == layer_id:
            obj["sum"] += (att * obj["attention_map"]).sum().item()
            obj["count"] += obj["attention_map"].sum().item()

        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, inner)  # re-assemble all head outputs side by side

        # output projection
        y = self.drop_resid(self.c_proj(y))
        return y


class Block(nn.Module):
    """An unassuming Transformer block: pre-LayerNorm causal self-attention plus an optional MLP.

    Args:
        n_embd: residual-stream width.
        n_head: number of attention heads.
        block_size: maximum sequence length for the causal mask.
        dropout: dropout probability used in attention and the MLP.
        mlp: if falsy, the MLP branch contributes zeros (attention-only block). The MLP
            parameters are still created, so the state_dict layout does not depend on it.
        head_dim: per-head width forwarded to the attention layer (None for the default).
    """

    def __init__(self, n_embd, n_head, block_size, dropout, mlp, head_dim):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout, head_dim)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.mlp = nn.ModuleDict(
            dict(
                c_fc=nn.Linear(n_embd, 4 * n_embd),
                c_proj=nn.Linear(4 * n_embd, n_embd),
                act=nn.ReLU(),
                drop=nn.Dropout(dropout),
            )
        )
        # Stored as a flag read by the mlpf method. A lambda closing over self.mlp would be
        # unpicklable, and copy.deepcopy would keep it bound to the ORIGINAL block's weights.
        self.use_mlp = mlp

    def mlpf(self, x):
        """MLP forward: ``drop(c_proj(relu(c_fc(x))))``, or zeros shaped like ``x`` when disabled."""
        if not self.use_mlp:
            return torch.zeros_like(x)
        m = self.mlp
        return m.drop(m.c_proj(m.act(m.c_fc(x))))

    def forward(self, x, attn_path=None, layer_id=None, labels=None, attention_maps=None, obj=None):
        """Apply the two residual branches; extra arguments are forwarded to the attention layer."""
        x = x + self.attn(self.ln_1(x), attn_path, layer_id, labels, attention_maps, obj=obj)
        x = x + self.mlpf(self.ln_2(x))
        return x


class Tranformer(torch.nn.Module):
    """Causal Transformer mapping (batch, seq, input_dim) features to (batch, seq, output_dim).

    Each input vector is concatenated with an ``input_dim``-wide positional feature,
    projected to ``n_embd``, passed through ``num_Layers`` blocks, and projected to
    ``output_dim``.

    Args:
        input_dim: feature width of the input (also the RoPE dimension).
        output_dim: feature width of the output.
        n_embd: residual-stream width.
        n_head: attention heads per block.
        block_size: maximum sequence length for the causal masks.
        num_Layers: number of blocks.
        dropout: dropout probability (input embedding, attention, MLP).
        mlp: whether blocks include an MLP branch.
        head_dim: per-head width (None for ``n_embd // n_head``).
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        n_embd,
        n_head,
        block_size,
        num_Layers,
        dropout=0,
        mlp=False,
        head_dim=None,
    ):
        super().__init__()
        # Input features concatenated with an input_dim-wide positional feature.
        self.l_in = nn.Linear(input_dim * 2, n_embd)
        self.pos_embed = RotaryPositionalEmbeddings(input_dim)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, block_size, dropout, mlp, head_dim) for _ in range(num_Layers)])
        self.l_out = nn.Linear(n_embd, output_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_path=None, labels=None, attention_maps=None, obj=None):
        """Run the model on ``x`` of shape (batch, seq, input_dim).

        ``attn_path``, ``labels``, ``attention_maps`` and ``obj`` are forwarded to every
        block together with its layer index; see the attention layer for their meaning.

        Returns:
            Tensor of shape (batch, seq, output_dim).
        """
        batch, seq_len, feat = x.size(0), x.size(1), x.size(2)

        # Allocate directly on x's device/dtype instead of building a float32 CPU tensor
        # and copying it over.
        # NOTE(review): RoPE rotates its input, so an all-zeros input yields all zeros —
        # pos_embed carries no positional signal here. If a positional feature was
        # intended, the RoPE input must be non-zero; changing it alters trained models.
        rope_in = x.new_zeros(batch, seq_len, 1, feat)
        pos_embed = self.pos_embed(rope_in).squeeze(2)
        x = torch.cat((x, pos_embed), dim=-1)

        x = self.l_in(x)
        x = self.drop(x)

        for i, block in enumerate(self.blocks):
            x = block(x, attn_path, i, labels, attention_maps, obj=obj)

        return self.l_out(x)
