from typing import Tuple, Optional, Generator
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.config import ModelConfig, Transformer

class Rotary(nn.Module):
    """Rotary position embedding backed by precomputed cos/sin tables.

    The last axis of the input is split into two halves ``(x1, x2)`` and
    channel ``i`` of ``x1`` is rotated together with channel ``i`` of ``x2``.
    The first ``dim // 4`` pairs use angular frequencies spaced geometrically
    from ``1`` down to ``1 / 1024``; the remaining ``dim // 4`` pairs use
    frequency zero and therefore pass through unrotated.

    ``forward`` reads the sequence length from ``x.size(-3)``, i.e. the input
    must be laid out as ``(..., seq, heads, dim)``.

    Args:
        dim: Size of the last axis of the tensors to rotate. Must be a
            positive multiple of 4 so the table width equals ``dim // 2``.
        max_seq_len: Number of positions for which angles are precomputed.

    Raises:
        ValueError: If ``dim`` is not a positive multiple of 4.
    """

    def __init__(self, dim: int, max_seq_len: int = 65536) -> None:
        super().__init__()

        # Otherwise the table width (2 * (dim // 4)) would not match dim // 2
        # and forward would fail (or silently broadcast) much later.
        if dim <= 0 or dim % 4 != 0:
            raise ValueError(f"dim must be a positive multiple of 4, got {dim}")

        quarter = dim // 4
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=quarter, dtype=torch.float32)
        # Second half of the pairs gets frequency 0 (no rotation).
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(quarter)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)  # (max_seq_len, dim // 2)
        # Non-persistent: the tables are derived data and stay out of state_dict.
        self.register_buffer("cos", theta.cos(), persistent=False)
        self.register_buffer("sin", theta.sin(), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` by its position along axis ``-3``.

        Args:
            x: Tensor laid out as ``(..., seq, heads, dim)``.

        Returns:
            Tensor with the same shape and dtype as ``x``; the arithmetic is
            done in float32 and cast back.

        Raises:
            ValueError: If the sequence length exceeds the precomputed table.
        """
        seq_len = x.size(-3)
        if seq_len > self.cos.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.cos.size(0)}"
            )

        cos = self.cos[None, :seq_len, None, :]
        sin = self.sin[None, :seq_len, None, :]
        x1, x2 = x.float().chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        # Concatenate on the last axis (not a fixed index) so inputs with extra
        # leading axes are reassembled along the same axis that was chunked.
        return torch.cat((y1, y2), dim=-1).type_as(x)


class Attention(nn.Module):
    """Multi-head self-attention with grouped key/value heads and rotary positions.

    Queries use ``num_heads`` heads; keys and values use ``num_key_value``
    heads that are shared across query heads via
    ``scaled_dot_product_attention(..., enable_gqa=True)``.

    Args:
        num_heads: Number of query heads; must divide ``embed_dim``.
        embed_dim: Model width of the input and output.
        num_key_value: Number of key/value heads; must divide ``num_heads``.
        bias: Whether the three linear projections carry a bias term.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        num_key_value: int,
        bias: bool,
    ) -> None:

        super().__init__()

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        assert num_heads % num_key_value == 0, "num_heads must be divisible by num_key_value"

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.num_key_value = num_key_value
        self.head_dim = embed_dim // num_heads

        self.rotary = Rotary(embed_dim // num_heads)

        self.c_attn_q = nn.Linear(
            in_features=embed_dim,
            out_features=embed_dim,
            bias=bias)

        # Keys and values are produced by one fused projection and split later.
        self.c_attn_kv = nn.Linear(
            in_features=embed_dim,
            out_features=2 * num_key_value * self.head_dim,
            bias=bias)

        self.c_proj = nn.Linear(
            in_features=embed_dim,
            out_features=embed_dim,
            bias=bias)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Attend over the sequence axis of ``x``.

        Args:
            x: Input of shape ``(B, T, E)``.
            attn_mask: Forwarded unchanged to ``scaled_dot_product_attention``;
                presumably the boolean ``(B, 1, T, T)`` mask built by
                ``make_attention_mask`` — confirm against the caller.

        Returns:
            Tensor of shape ``(B, T, E)``.
        """
        B, T, E = x.size()
        kv_width = self.num_key_value * self.head_dim

        q = self.c_attn_q(x).view(B, T, self.num_heads, self.head_dim)
        k, v = self.c_attn_kv(x).split(kv_width, dim=2)
        k = k.view(B, T, self.num_key_value, self.head_dim)
        v = v.view(B, T, self.num_key_value, self.head_dim)

        # Rotary reads the sequence length from axis -3, so it must run while
        # the tensors are still (B, T, heads, head_dim). Transposing first
        # would make the rotation angle depend on the head index instead of
        # the token position.
        q, k = self.rotary(q), self.rotary(k)

        # scaled_dot_product_attention expects (B, heads, T, head_dim).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=True)

        # Merge heads back into the embedding axis.
        y = y.transpose(1, 2).contiguous().view(B, T, E)
        return self.c_proj(y)


class MLP(nn.Module):
    """Two-layer feed-forward network with a tanh-approximated GELU in between.

    Args:
        embed_dim: Width of the input and of the output.
        mlp_dim: Width of the hidden layer.
    """

    def __init__(self, embed_dim: int, mlp_dim: int) -> None:
        super().__init__()
        # Expansion followed by contraction; both layers carry a bias.
        self.c_fc = nn.Linear(embed_dim, mlp_dim, bias=True)
        self.c_proj = nn.Linear(mlp_dim, embed_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``c_proj(gelu(c_fc(x)))`` along the last axis of ``x``."""
        hidden = F.gelu(self.c_fc(x), approximate="tanh")
        return self.c_proj(hidden)

class Block(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by an MLP sub-layer.

    Each sub-layer sees an RMS-normalised copy of the stream and its output
    is added back onto the un-normalised stream.

    Args:
        config: Supplies ``num_heads``, ``embed_dim``, ``num_key_value``,
            ``attn_bias`` and ``mlp_dim``.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()

        width = config.embed_dim

        self.attn = Attention(
            num_heads=config.num_heads,
            embed_dim=width,
            num_key_value=config.num_key_value,
            bias=config.attn_bias,
        )
        self.mlp = MLP(embed_dim=width, mlp_dim=config.mlp_dim)

        # One norm in front of each sub-layer.
        self.norm_1 = nn.RMSNorm(width)
        self.norm_2 = nn.RMSNorm(width)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run both residual sub-layers and return the updated stream."""
        attn_out = self.attn(self.norm_1(x), attn_mask=attn_mask)
        x = x + attn_out

        mlp_out = self.mlp(self.norm_2(x))
        return x + mlp_out


class BaseTransformer(Transformer):
    """Decoder-style language model: embedding, a stack of ``Block``s, final
    RMSNorm and a linear unembedding to vocabulary logits.

    Args:
        config: Supplies ``vocab_size``, ``embed_dim``, ``num_layers`` and the
            per-block settings; ``forward`` also reads ``self.config.eos_token_id``
            (presumably stored by the ``Transformer`` base class — confirm).
    """

    def __init__(self, config: ModelConfig):
        super().__init__(config)
        self.model_type = "base"

        width = config.embed_dim
        self.embed = nn.Embedding(config.vocab_size, width)
        self.norm = nn.RMSNorm(width)
        self.unembed = nn.Linear(width, config.vocab_size, bias=True)

        self.blocks = nn.ModuleList(
            [Block(config) for _ in range(config.num_layers)]
        )

        # Overwrite the default initialisation of every submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one submodule: N(0, 0.02) weights and zero biases for
        Linear/Embedding, unit weights for RMSNorm; other modules untouched."""
        if isinstance(module, nn.RMSNorm):
            nn.init.ones_(module.weight)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # nn.Embedding has no bias attribute; nn.Linear may have bias=None.
        bias = getattr(module, "bias", None)
        if bias is not None:
            nn.init.zeros_(bias)

    def forward(
        self,
        tokens: torch.Tensor,
        targets: torch.Tensor | None = None,
        stop_at_layer: int | None = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on a batch of token ids.

        Args:
            tokens: Token ids of shape ``(B, T)``.
            targets: Optional target ids; when given (and ``stop_at_layer`` is
                ``None``) the mean cross-entropy against the logits is returned.
            stop_at_layer: When set, return the hidden state produced by the
                block with this 0-based index instead of logits.

        Returns:
            ``(output, loss)`` where ``output`` is either the ``(B, T, V)``
            logits or the ``(B, T, E)`` hidden state of the requested block,
            and ``loss`` is ``None`` unless it was computed.
        """
        attn_mask = make_attention_mask(tokens, self.config.eos_token_id)
        hidden = self.embed(tokens)  # (B, T, E)

        for layer_idx, block in enumerate(self.blocks):
            hidden = block(hidden, attn_mask=attn_mask)
            if stop_at_layer is not None and layer_idx == stop_at_layer:
                # Early exit: raw block output, no final norm, no loss.
                return hidden, None

        logits = self.unembed(self.norm(hidden))  # (B, T, V)

        # NOTE(review): a stop_at_layer that matches no block index (negative
        # or >= number of blocks) falls through to here and still suppresses
        # the loss — confirm whether that should be an error instead.
        if stop_at_layer is not None or targets is None:
            return logits, None

        # NOTE(review): ignore_index=-1000 differs from PyTorch's default of
        # -100 — confirm it matches the padding label used by the data pipeline.
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=-1000,
            reduction="mean",
        )
        return logits, loss


def make_attention_mask(tokens: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    """Build a causal mask that also blocks attention across EOS boundaries.

    A position may attend to itself and to earlier positions in the same
    segment. A segment ends at (and includes) an EOS token; the token after
    an EOS starts a new segment.

    Args:
        tokens: Token ids of shape ``(B, T)``.
        eos_token_id: Id of the token that terminates a segment.

    Returns:
        Boolean tensor of shape ``(B, 1, T, T)``; ``True`` means the query row
        may attend to the key column, ``False`` means masked.
    """
    is_eos = (tokens == eos_token_id).int()
    # Number of EOS tokens strictly before each position: subtracting the
    # indicator keeps an EOS token inside the segment it closes.
    segment = is_eos.cumsum(dim=1) - is_eos

    positions = torch.arange(tokens.shape[1], device=tokens.device)
    not_future = positions.unsqueeze(1) >= positions.unsqueeze(0)  # (T, T) lower triangle
    same_segment = segment[:, :, None] == segment[:, None, :]  # (B, T, T)

    # Singleton head axis so the mask broadcasts over attention heads.
    return (same_segment & not_future)[:, None]