import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from LLMProxy.option import ModelArg
from typing import Optional


def reshape_for_boardcast(freqs_cis, x):
    """
    View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The returned
    view keeps those two sizes on axis 1 and on the last axis, and uses size 1
    on every other axis of ``x``.
    """
    seq_len, last_dim = x.shape[1], x.shape[-1]
    assert freqs_cis.shape == (seq_len, last_dim)
    # Start from all-singleton axes, then restore the two axes that carry data.
    target_shape = [1] * x.ndim
    target_shape[1] = seq_len
    target_shape[-1] = last_dim
    return freqs_cis.view(*target_shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """
    Rotate query and key tensors by the given complex frequency tensor.

    The last dimension of ``xq`` / ``xk`` is read as consecutive
    (real, imaginary) pairs, multiplied by ``freqs_cis`` as complex numbers,
    and flattened back. ``freqs_cis`` must have shape
    ``(xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as the inputs.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last dimension: (..., d) -> (..., d // 2, 2) -> complex (..., d // 2)
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)

    # Insert singleton axes so the table broadcasts over every axis except
    # axis 1 and the last one.
    rotation = reshape_for_boardcast(freqs_cis, q_complex)

    def _rotate(c: torch.Tensor) -> torch.Tensor:
        # Complex multiply, then merge the trailing (pairs, 2) axes back into one.
        return torch.view_as_real(c * rotation).flatten(c.ndim - 1)

    q_rotated = _rotate(q_complex)
    k_rotated = _rotate(k_complex)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the complex rotation factors used by the rotary embedding.

    Args:
        dim: Size of the dimension being rotated; ``dim // 2`` frequencies
            are produced.
        end: Number of positions (rows) to generate, starting at position 0.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose ``[p, i]`` entry is
        ``cos(a) + 1j * sin(a)`` with ``a = p * theta ** (-2 * i / dim)``.
    """
    half = dim // 2
    # Exponents 0, 2/dim, 4/dim, ... - one per rotated pair.
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Position indices 0 .. end - 1.
    positions = torch.arange(end, device=inv_freq.device)
    # angles[p, i] = p * inv_freq[i]
    angles = torch.outer(positions, inv_freq).float()
    # Unit magnitude combined with each angle gives the complex rotation.
    magnitude = torch.ones_like(angles)
    return torch.polar(magnitude, angles)


class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension with a learnable
    per-feature scale.

    Args:
        dim: Length of the learnable ``weight`` vector (size of the last
            input dimension).
        eps: Constant added to the mean square before the reciprocal
            square root is taken.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()

        self.eps = eps
        # Scale is initialized to ones, so a fresh layer only normalizes.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # x / sqrt(mean(x^2) + eps), computed along the last axis.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor):
        # Normalize in float32, cast back to the input dtype, then scale.
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight


class Attention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings.

    Args:
        args (ModelArg): Model configuration. Reads ``hidden_size`` (width of
            the input and output) and ``num_attention_heads``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, args: ModelArg):
        super().__init__()

        # Fail early with a clear message instead of a shape error in ``view``.
        if args.hidden_size % args.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({args.hidden_size}) must be divisible by "
                f"num_attention_heads ({args.num_attention_heads})"
            )

        self.dim = args.hidden_size
        self.head_dim = args.hidden_size // args.num_attention_heads
        self.num_heads = args.num_attention_heads

        # Bias-free projections for queries, keys, values and the output.
        self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
        self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute self-attention over ``x``.

        Args:
            x: Input of shape ``[bsz, seqlen, hidden_size]``.
            freqs_cis: Complex rotary factors of shape
                ``[seqlen, head_dim // 2]``.
            mask: Optional tensor added to the attention scores before the
                softmax; it must broadcast to
                ``[bsz, num_heads, seqlen, seqlen]``.

        Returns:
            Tensor of shape ``[bsz, seqlen, hidden_size]``.
        """
        bsz, seqlen, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # Split the projections into heads: [bsz, seqlen, num_heads, head_dim]
        xq = xq.view(bsz, seqlen, self.num_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.num_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.num_heads, self.head_dim)

        # Only queries and keys carry positional information.
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Move heads ahead of the sequence axis: [bsz, num_heads, seqlen, head_dim]
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Scaled dot product. Each dot product sums over head_dim terms, so the
        # scale is sqrt(head_dim), not sqrt(hidden_size).
        score = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            score = score + mask
        # Softmax in float32, then back to the query dtype.
        # score = [bsz, num_heads, seqlen, seqlen]
        score = F.softmax(score.float(), dim=-1).type_as(xq)
        # output = [bsz, num_heads, seqlen, head_dim]
        output = torch.matmul(score, xv)
        # Merge heads again: [bsz, seqlen, num_heads * head_dim]
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    """
    Position-wise gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        args (ModelArg): Model configuration. Reads ``hidden_size`` and, when
            present, ``intermediate_size``.
    """

    def __init__(self, args: ModelArg):
        super().__init__()

        hidden_size = args.hidden_size
        # NOTE(review): ModelArg is defined outside this file. The inner width
        # is taken from ``intermediate_size`` when the config provides it and
        # falls back to 4 * hidden_size otherwise - confirm the intended field.
        inner_size = getattr(args, "intermediate_size", None) or 4 * hidden_size

        # Bias-free projections, matching the attention projections in this file.
        self.w1 = nn.Linear(hidden_size, inner_size, bias=False)  # gate
        self.w2 = nn.Linear(inner_size, hidden_size, bias=False)  # down
        self.w3 = nn.Linear(hidden_size, inner_size, bias=False)  # up

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the feed-forward block along the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The SiLU-activated gate modulates the parallel up-projection.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerLayer(nn.Module):
    """
    One pre-norm transformer block: attention and feed-forward sub-layers,
    each preceded by an ``RMSNorm`` and wrapped in a residual connection.

    Args:
        args (ModelArg): Model configuration. Reads ``hidden_size`` and
            ``norm_eps`` here and is forwarded to ``Attention`` and
            ``FeedForward``.
    """

    def __init__(self, args: ModelArg):
        super().__init__()

        self.attn = Attention(args=args)
        self.ffn = FeedForward(args=args)

        self.attn_norm = RMSNorm(dim=args.hidden_size, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(dim=args.hidden_size, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        """
        Run the block on ``x``.

        Args:
            x: Input of shape ``[bsz, seqlen, hidden_size]``.
            freqs_cis: Rotary factors passed through to the attention module.
            mask: Optional additive attention mask passed through to the
                attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention submodule is registered as ``self.attn`` in __init__;
        # the lookup here has to use that same name.
        x = x + self.attn(
            self.attn_norm(x), freqs_cis=freqs_cis, mask=mask
        )
        x = x + self.ffn(self.ffn_norm(x))
        return x


class Transformer(nn.Module):
    """
    Token embedding, a stack of ``TransformerLayer`` blocks, a final
    ``RMSNorm`` and a bias-free linear projection to the vocabulary.

    Args:
        args (ModelArg): Model configuration. Reads ``vocab_size``,
            ``hidden_size``, ``num_attention_heads``, ``num_hidden_layers``
            and ``norm_eps``.
    """

    def __init__(self, args: ModelArg):
        super().__init__()
        self.token_embeddings = nn.Embedding(
            args.vocab_size, args.hidden_size
        )
        self.n_layers = args.num_hidden_layers
        # Width of one attention head; the rotary table is built for this
        # size so that it matches what ``Attention`` rotates.
        self.head_dim = args.hidden_size // args.num_attention_heads

        self.layers = nn.ModuleList(
            [TransformerLayer(args=args) for _ in range(self.n_layers)]
        )

        self.norm = RMSNorm(args.hidden_size, eps=args.norm_eps)
        self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        """
        Map token ids to vocabulary logits.

        Args:
            input_ids: Integer tensor of shape ``[bsz, seqlen]``.
            mask: Optional additive attention mask forwarded to every layer.
                No mask (for example a causal one) is applied unless the
                caller supplies it.

        Returns:
            Tensor of shape ``[bsz, seqlen, vocab_size]``.
        """
        seqlen = input_ids.shape[1]

        x = self.token_embeddings(input_ids)

        # Every layer requires the rotary factors. Build them once per call
        # for positions 0 .. seqlen - 1 and place them on the activations' device.
        freqs_cis = precompute_freqs_cis(self.head_dim, seqlen).to(x.device)

        for layer in self.layers:
            x = layer(x, freqs_cis=freqs_cis, mask=mask)

        x = self.norm(x)
        return self.output(x)
        
        