import math
from typing import Tuple

import torch
import torch.nn as nn
from models.rms_norm import RMSNorm
from torch.nn import functional as F


# Taken from facebookresearch/llama/model.py
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The returned
    view keeps those two sizes on axis 1 and on the last axis and has size 1
    on every other axis of ``x``.

    Raises:
        AssertionError: if ``x`` has fewer than two dimensions or the shape
            of ``freqs_cis`` does not match axis 1 / the last axis of ``x``.
    """
    rank = x.ndim
    assert 0 <= 1 < rank
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])

    # Keep the size on axis 1 and on the trailing axis; collapse the rest to 1.
    kept_axes = (1, rank - 1)
    target_shape = []
    for axis, size in enumerate(x.shape):
        target_shape.append(size if axis in kept_axes else 1)
    return freqs_cis.view(*target_shape)


# Taken from facebookresearch/llama/model.py
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys by the complex rotary factors in ``freqs_cis``.

    Each input is cast to float, its last axis is grouped into (real, imag)
    pairs and viewed as complex, multiplied by the first ``shape[1]`` rows of
    ``freqs_cis``, and converted back to a real tensor with the original
    dtype. ``flatten(3)`` is applied to the result, so the inputs are
    expected to be 4-D; the caller in this file passes
    ``(batch, seq, n_head, head_dim)``.

    Args:
        xq: query tensor; axis 1 is used as its sequence length.
        xk: key tensor; axis 1 is used as its sequence length (it may differ
            from the query length).
        freqs_cis: complex rotary table indexed by position on axis 0.

    Returns:
        The rotated ``(xq, xk)`` pair, each cast back to its input dtype.
    """
    sources = (xq, xk)

    # Convert both inputs to complex form first, then rotate them in turn.
    as_complex = [
        torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        for t in sources
    ]

    rotated = []
    for source, paired in zip(sources, as_complex):
        seq_len = source.shape[1]
        phase = reshape_for_broadcast(freqs_cis[:seq_len], paired)
        rotated.append(torch.view_as_real(paired * phase).flatten(3))

    q_rotated, k_rotated = rotated
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


# Taken from facebookresearch/llama/model.py
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Build the table of unit-magnitude complex rotary factors.

    Args:
        dim: width that the rotation is applied over; ``dim // 2``
            frequencies are produced.
        end: number of positions (rows) to precompute.
        theta: base of the geometric frequency schedule.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` whose entry
        ``[p, j]`` is ``exp(i * p * theta ** (-2j / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponents)
    positions = torch.arange(end, device=inv_freq.device)
    # Outer product gives one rotation angle per (position, frequency) pair.
    angles = torch.outer(positions, inv_freq).float()
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)  # complex64


class LayerNorm(nn.Module):
    """Layer normalization over the last axis with an optional bias term.

    Exists because the bias needs to be switchable off; when ``bias`` is
    false, ``self.bias`` is ``None`` and only the scale is learned.
    """

    def __init__(self, ndim: int, bias: bool):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize ``input`` over its trailing ``ndim`` features (eps 1e-5)."""
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )


class MultiheadAttention(nn.Module):
    """Multi-head attention with rotary embeddings applied to queries and keys.

    Queries are projected from ``x`` and keys/values from ``mem``. When
    ``torch.nn.functional.scaled_dot_product_attention`` exists and
    ``disable_flash`` is false, the fused kernel is used; otherwise attention
    is computed manually against a ``(1, 1, block_size, block_size)`` mask
    buffer named ``attn_mask`` (lower-triangular when ``causal``, all ones
    otherwise). That buffer is only registered in the manual mode.
    """

    def __init__(
        self,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causal: bool,
        disable_flash: bool = False,
    ):
        super().__init__()
        assert embed_dim % n_head == 0
        self.n_head = n_head
        self.n_embd = embed_dim
        self.dropout = dropout
        self.causal = causal

        # One full-width projection each for queries, keys and values; the
        # result is split into heads inside forward().
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Mixes the concatenated head outputs back into the embedding space.
        self.output_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Dropout on the attention weights (manual path) and on the output.
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # The fused kernel only exists in PyTorch >= 2.0.
        fused_available = hasattr(
            torch.nn.functional, "scaled_dot_product_attention"
        )
        self.flash = fused_available and not disable_flash
        if self.flash:
            return

        print(
            "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
        )
        # Manual path needs an explicit mask: zeros mark disallowed positions.
        mask = torch.ones(block_size, block_size)
        if self.causal:
            mask = torch.tril(mask)
        self.register_buffer(
            "attn_mask", mask.view(1, 1, block_size, block_size)
        )

    def forward(
        self, x: torch.Tensor, mem: torch.Tensor, freqs_cis: torch.Tensor
    ) -> torch.Tensor:
        """Attend from ``x`` to ``mem``.

        Args:
            x: query source, unpacked as ``(batch, T_x, channels)``.
            mem: key/value source, unpacked as ``(batch, T_mem, channels)``.
            freqs_cis: complex rotary table passed to ``apply_rotary_emb``.

        Returns:
            Tensor of shape ``(batch, T_x, channels)``.
        """
        _, len_q, _ = x.size()
        # Batch size and channel width used below are the ones read from mem.
        batch, len_kv, width = mem.size()
        head_dim = width // self.n_head

        q = self.q_proj(x).view(batch, len_q, self.n_head, head_dim)
        k = self.k_proj(mem).view(batch, len_kv, self.n_head, head_dim)
        v = self.v_proj(mem).view(batch, len_kv, self.n_head, head_dim)

        # Rotary embedding is applied while the sequence axis is still axis 1.
        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        if self.flash:
            # Fused attention; dropout is only active in training mode.
            drop_p = self.dropout if self.training else 0
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=drop_p,
                is_causal=self.causal,
            )
        else:
            # Manual attention: scaled scores, mask, softmax, dropout, mix.
            scale = 1.0 / math.sqrt(k.size(-1))
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            blocked = self.attn_mask[:, :, :len_q, :len_kv] == 0
            scores = scores.masked_fill(blocked, float("-inf"))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = torch.matmul(weights, v)  # (batch, heads, T_x, head_dim)

        # Put the heads back side by side on the channel axis.
        y = y.transpose(1, 2).contiguous().view(batch, len_q, width)

        projected = self.output_proj(y)
        return self.resid_dropout(projected)


class MLP(nn.Module):
    """Position-wise feed-forward block: expand 4x, GELU, project back, dropout."""

    def __init__(
        self,
        embed_dim: int,
        bias: bool,
        dropout: float,
    ):
        super().__init__()
        hidden_dim = 4 * embed_dim
        self.c_fc = nn.Linear(embed_dim, hidden_dim, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the feed-forward transform of ``x`` (same trailing width)."""
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))


class DecoderBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with a residual.

    ``norm_type`` selects the normalization class (``"layer_norm"`` or
    ``"rms_norm"``); the chosen class is also kept on ``self.norm_fn``.
    """

    def __init__(
        self,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causal: bool,
        norm_type: str = "layer_norm",
        disable_flash: bool = False,
    ):
        super().__init__()

        # Reject unsupported names up front, then pick the class.
        if norm_type not in ("layer_norm", "rms_norm"):
            raise ValueError(f"Unknown norm type: {norm_type}")
        self.norm_fn = LayerNorm if norm_type == "layer_norm" else RMSNorm

        # Attention sub-layer and the norm applied before it.
        self.ln_1 = self.norm_fn(embed_dim, bias=bias)
        self.attn_1 = MultiheadAttention(
            embed_dim=embed_dim,
            n_head=n_head,
            bias=bias,
            dropout=dropout,
            block_size=block_size,
            causal=causal,
            disable_flash=disable_flash,
        )
        # Feed-forward sub-layer and the norm applied before it.
        self.ln_2 = self.norm_fn(embed_dim, bias=bias)
        self.mlp = MLP(embed_dim=embed_dim, bias=bias, dropout=dropout)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x``; ``freqs_cis`` is passed on to the attention."""
        normed = self.ln_1(x)
        # Self-attention: the normalized input serves as both query and memory.
        attended = self.attn_1(normed, normed, freqs_cis)
        x = x + attended
        fed_forward = self.mlp(self.ln_2(x))
        return x + fed_forward


class Decoder(nn.Module):
    """Stack of ``DecoderBlock`` layers followed by a final norm and a linear head.

    ``forward`` feeds its input straight into the blocks, so the input must
    already have width ``embed_dim``; the ``input_projection`` created here is
    not applied by ``forward``.

    Args:
        input_dim: input width of the (unused-in-forward) ``input_projection``.
        output_dim: width produced by ``output_projection``.
        n_layers: number of block applications in ``forward``.
        embed_dim: model width.
        n_head: attention heads per block.
        bias: whether the blocks' linear layers and the norms carry a bias.
        dropout: dropout probability used inside the blocks.
        block_size: side length of the attention mask in the non-fused path.
        causal: whether the blocks' attention is causal.
        deq: if true, a single block is created and applied ``n_layers``
            times, i.e. parameters are shared across layers.
        norm_type: ``"layer_norm"`` or ``"rms_norm"``; used for every block
            and for the final norm.
        disable_flash: force the manual attention path in every block.

    Raises:
        ValueError: if ``norm_type`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        n_layers: int,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causal: bool,
        deq: bool = False,
        norm_type: str = "layer_norm",
        disable_flash: bool = False,
    ):
        super().__init__()

        if norm_type == "layer_norm":
            self.norm_fn = LayerNorm
        elif norm_type == "rms_norm":
            self.norm_fn = RMSNorm
        else:
            raise ValueError(f"Unknown norm type: {norm_type}")

        # If True, operate as a DEQ and share parameters across layers
        self.deq = deq
        self.n_layers = n_layers

        # norm_type must be forwarded explicitly; otherwise every block falls
        # back to its own default ("layer_norm") regardless of this argument.
        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    embed_dim=embed_dim,
                    n_head=n_head,
                    bias=bias,
                    dropout=dropout,
                    block_size=block_size,
                    causal=causal,
                    norm_type=norm_type,
                    disable_flash=disable_flash,
                )
                for _ in range(1 if deq else n_layers)
            ]
        )
        # NOTE: not used by forward(); kept so parameter names and state_dict
        # keys stay the same for existing checkpoints.
        self.input_projection = nn.Linear(input_dim, embed_dim)

        self.final_layer_norm = self.norm_fn(embed_dim, bias=bias)
        self.output_projection = nn.Linear(embed_dim, output_dim)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Run ``n_layers`` block applications, then the final norm and head.

        Args:
            x: input features of width ``embed_dim``.
            freqs_cis: complex rotary table handed to every block.

        Returns:
            Output of ``output_projection`` applied to the normalized features.
        """
        for i in range(self.n_layers):
            # With shared parameters the single block is re-applied each step.
            layer = self.layers[0] if self.deq else self.layers[i]
            x = layer(x, freqs_cis)
        return self.output_projection(self.final_layer_norm(x))


class Transformer(nn.Module):
    """Decoder-only transformer over continuous inputs with rotary embeddings.

    Inputs of width ``input_dim`` are linearly projected to ``embed_dim`` and
    passed through a :class:`Decoder`, which returns features of width
    ``output_dim``.

    Args:
        input_dim: width of the raw input features.
        output_dim: width of the returned features.
        n_decoder_layers: number of decoder block applications.
        embed_dim: model width.
        n_head: attention heads per block.
        bias: whether linear layers / norms inside the decoder carry a bias.
        dropout: dropout probability.
        block_size: maximum sequence length accepted by ``forward``.
        causal_decoder: whether decoder attention is causal.
        max_seq_len: number of positions in the precomputed rotary table.
        deq: if true, the decoder shares one block across all layers.
        norm_type: ``"layer_norm"`` or ``"rms_norm"``.
        disable_flash: force the manual attention path.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        n_decoder_layers: int,
        embed_dim: int,
        n_head: int,
        bias: bool,
        dropout: float,
        block_size: int,
        causal_decoder: bool,
        max_seq_len: int,
        deq: bool = False,
        norm_type: str = "layer_norm",
        disable_flash: bool = False,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.n_head = n_head
        self.bias = bias
        self.dropout = dropout
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        self.n_decoder_layers = n_decoder_layers

        self.input_projection = nn.Linear(input_dim, embed_dim)

        self.decoder = Decoder(
            input_dim=input_dim,
            output_dim=output_dim,
            n_layers=n_decoder_layers,
            embed_dim=embed_dim,
            n_head=n_head,
            bias=bias,
            dropout=dropout,
            block_size=block_size,
            causal=causal_decoder,
            deq=deq,
            norm_type=norm_type,
            disable_flash=disable_flash,
        )
        # Plain tensor attribute (not a registered buffer): it is not part of
        # the state_dict and is moved to the input's device on every call.
        self.freqs_cis = precompute_freqs_cis(embed_dim // n_head, max_seq_len)

        # Init all weights
        self.apply(self._init_weights)
        # Apply special scaled init to the residual projections, per GPT-2 paper
        # NOTE(review): this suffix only matches the attention `output_proj`;
        # the MLP's `c_proj` keeps the default std of 0.02 - confirm intended.
        for pn, p in self.named_parameters():
            if pn.endswith("output_proj.weight"):
                torch.nn.init.normal_(
                    p,
                    mean=0.0,
                    std=0.02 / math.sqrt(n_decoder_layers),
                )

    @property
    def num_params(self) -> int:
        """Total number of elements across all parameters of the module."""
        return sum(p.numel() for p in self.parameters())

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise linear layers with N(0, 0.02) weights and zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, input: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """Project ``input`` and decode it.

        Args:
            input: tensor unpacked as ``(batch, t, input_dim)``.
            start_pos: index of the first position in the rotary table; rows
                ``start_pos : start_pos + t`` are used.

        Returns:
            The decoder output for the projected input.

        Raises:
            AssertionError: if ``t`` exceeds ``block_size``.
        """
        _, t, _ = input.shape
        assert t <= self.block_size, (
            f"Cannot forward sequence of length {t}, "
            f"block size is only {self.block_size}"
        )

        freqs_cis = self.freqs_cis.to(input.device)
        freqs_cis = freqs_cis[start_pos : start_pos + t]

        input = self.input_projection(input)
        # Get the decoded features
        return self.decoder(input, freqs_cis)

    def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float:
        """Estimate model flops utilization (MFU) relative to A100 bfloat16 peak.

        Args:
            fwdbwd_per_iter: number of forward/backward passes per iteration.
            dt: wall-clock seconds taken by one iteration.

        Returns:
            Achieved FLOPS divided by 312e12.
        """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        # The parameter count is exposed as the `num_params` property.
        N = self.num_params
        L, H, Q, T = (
            self.n_decoder_layers,
            self.n_head,
            self.embed_dim // self.n_head,
            self.block_size,
        )
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0 / dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        return flops_achieved / flops_promised

    @torch.no_grad()
    def generate(
        self,
        input: torch.Tensor,
        start_pos: int = 0,
    ) -> torch.Tensor:
        """Gradient-free forward pass without the ``block_size`` check.

        Args:
            input: tensor whose axis 1 is the sequence length.
            start_pos: index of the first position in the rotary table.

        Returns:
            The decoder output for the projected input.
        """
        input = self.input_projection(input)
        T = input.shape[1]

        freqs_cis = self.freqs_cis.to(input.device)
        freqs_cis = freqs_cis[start_pos : start_pos + T]

        return self.decoder(input, freqs_cis)
