import torch
import torch.nn as nn
from torch import Tensor
import torch.distributed as dist
import flashinfer
from flashinfer.activation import silu_and_mul
import torch.nn.functional as F

from .config import ModelArgs

class FeedForward(nn.Module):
    """Gated feed-forward block with the ``w1`` and ``w3`` projections fused.

    ``w13`` is a single linear layer whose weight is ``w1`` stacked on top of
    ``w3`` (see ``load_hook``); its output goes through flashinfer's
    ``silu_and_mul`` and is then projected back to ``config.dim`` by ``w2``.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        # Fused projection: output width is twice intermediate_size.
        self.w13 = nn.Linear(config.dim, 2 * config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        # Stays None unless assigned from outside; forward() all-reduces its
        # output over this group once it is set.
        self.process_group = None
        # Lets checkpoints that store w1/w3 separately load into the fused layer.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Rewrite separate ``w1``/``w3`` checkpoint weights into ``w13`` in place.

        Runs as a ``load_state_dict`` pre-hook. When ``<prefix>w1.weight`` is
        present, both ``w1`` and ``w3`` entries are removed and replaced by
        their concatenation along dim 0 under ``<prefix>w13.weight``. A state
        dict without ``w1`` is left untouched.

        Raises:
            KeyError: if ``w1`` is present but ``w3`` is not.
        """
        if prefix + "w1.weight" in state_dict:
            w1 = state_dict.pop(prefix + "w1.weight")
            w3 = state_dict.pop(prefix + "w3.weight")
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

    def forward(self, x: Tensor) -> Tensor:
        """Apply the feed-forward block to ``x`` and return the result.

        NOTE(review): ``silu_and_mul`` presumably computes
        ``silu(first_half) * second_half`` over the last dimension — confirm
        against the flashinfer documentation.
        """
        y = self.w2(silu_and_mul(self.w13(x)))
        # Identity test: `!= None` would go through the group's __ne__.
        if self.process_group is not None:
            # The return value is ignored; y itself is what gets returned.
            dist.all_reduce(y, group=self.process_group)
        return y

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    The statistic is computed on a ``float()`` copy of the input. The
    normalized values are cast back to the input's dtype before the
    per-feature ``weight`` (initialized to ones) is multiplied in.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        # Added to the mean square before the reciprocal square root.
        self.eps = eps

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        """Return the normalized ``x`` multiplied by ``weight``."""
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
    
class BaseAttention(nn.Module):
    """Parameters and checkpoint handling shared by the attention variants.

    Owns a fused query/key/value projection (``wqkv``), the output projection
    (``wo``) and optional q/k normalization modules. The forward pass is left
    to subclasses.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        # Width of the fused projection. Presumably n_head query heads plus
        # n_local_heads key heads and n_local_heads value heads, each head_dim
        # wide — the q, k, v order matches the concatenation in load_hook.
        qkv_width = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        self.wqkv = nn.Linear(config.dim, qkv_width, bias=config.qkv_bias)

        # wo takes head_dim * n_head inputs. That usually equals dim, but not
        # always: in Qwen3-0.6B, head_dim * n_head = dim * 2.
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)

        # q and k get their own RMSNorm when qk_norm is enabled, otherwise a no-op.
        if config.qk_norm:
            q_norm = RMSNorm(config.head_dim, config.norm_eps)
            k_norm = RMSNorm(config.head_dim, config.norm_eps)
        else:
            q_norm, k_norm = nn.Identity(), nn.Identity()
        self.q_norm = q_norm
        self.k_norm = k_norm

        # Not populated here; left as None for code outside this class to assign.
        self.kv_cache: nn.Module = None
        self.rope = None
        self.process_group = None

        self.n_head = config.n_head
        self.n_local_heads = config.n_local_heads
        self.head_dim = config.head_dim
        self.dim = self.n_head * self.head_dim

        # Lets checkpoints that store wq/wk/wv separately load into wqkv.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Fuse separate ``wq``/``wk``/``wv`` checkpoint entries into ``wqkv`` in place.

        Runs as a ``load_state_dict`` pre-hook. Weights and biases are handled
        independently: whenever ``<prefix>wq.<kind>`` exists, the matching
        ``wq``, ``wk`` and ``wv`` tensors are removed and their concatenation
        along dim 0 (in that order) is stored under ``<prefix>wqkv.<kind>``.

        Raises:
            KeyError: if ``wq`` is present but ``wk`` or ``wv`` is missing.
        """
        for kind in ("weight", "bias"):
            if f"{prefix}wq.{kind}" not in state_dict:
                continue
            parts = [
                state_dict.pop(f"{prefix}{proj}.{kind}")
                for proj in ("wq", "wk", "wv")
            ]
            state_dict[f"{prefix}wqkv.{kind}"] = torch.cat(parts)

    def forward(self, *args, **kwargs):
        """Abstract: always raises ``NotImplementedError`` in the base class."""
        raise NotImplementedError("Each Attention subclass must implement its own forward pass.")

class BaseTransformerBlock(nn.Module):
    """One transformer layer: attention and feed-forward, each with a residual.

    Each sub-layer receives an RMS-normalized copy of its input and its output
    is added back onto the un-normalized input.
    """

    def __init__(self, config: ModelArgs, attention_module: BaseAttention):
        super().__init__()
        self.attention = attention_module
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def _apply_ffn(self, h: Tensor) -> Tensor:
        """Add the feed-forward branch, computed on ``ffn_norm(h)``, onto ``h``."""
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + ffn_out

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Run the layer, calling the attention module itself.

        Extra positional and keyword arguments are forwarded unchanged to the
        attention module.
        """
        attn_out = self.attention(self.attention_norm(x), *args, **kwargs)
        return self._apply_ffn(x + attn_out)

    def prefill(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Same as ``forward`` but routed through ``attention.prefill``.

        Extra positional and keyword arguments are forwarded unchanged to
        ``attention.prefill``.
        """
        attn_out = self.attention.prefill(self.attention_norm(x), *args, **kwargs)
        return self._apply_ffn(x + attn_out)

class BaseTransformer(nn.Module):
    """Skeleton shared by the transformer variants.

    Builds the token embedding, the stack of ``BaseTransformerBlock`` layers,
    the final ``RMSNorm`` and the output projection. Subclasses choose the
    attention implementation via ``_get_attention_class`` and implement
    ``setup_caches`` and ``forward``.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = self._build_layers()
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Distributed bookkeeping; this class only initializes them to None.
        self.process_group = None
        self.rank = None
        self.world_size = None

    def _build_layers(self) -> nn.ModuleList:
        """Create ``config.n_layer`` blocks, each with its own attention module."""
        attention_cls = self._get_attention_class()
        blocks = [
            BaseTransformerBlock(self.config, attention_cls(self.config))
            for _ in range(self.config.n_layer)
        ]
        return nn.ModuleList(blocks)

    def _get_attention_class(self) -> type:
        """Abstract: subclasses return the attention class used for every layer."""
        raise NotImplementedError("Each Transformer subclass must implement its own _get_attention_class method.")

    def _setup_rope_kernels(self, use_position_ids: bool = False):
        """Register the ``mylib::rope`` custom op backed by a flashinfer kernel.

        The op takes ``(q, k, position_ids)`` when ``use_position_ids`` is
        true and ``(q, k, indptr, offsets)`` otherwise. The llama-3.1 kernel
        variants are used when both ``high_freq_factor`` and
        ``low_freq_factor`` are set on the config; the plain variants are used
        otherwise. A fake implementation returning empty tensors shaped like
        ``q`` and ``k`` is registered alongside the CUDA one.

        NOTE(review): the op name is fixed, so calling this a second time in
        the same process presumably fails with a duplicate-definition error —
        confirm before constructing more than one model.
        """
        cfg = self.config

        if use_position_ids:
            schema = "(Tensor q, Tensor k, Tensor position_ids) -> (Tensor ropeq, Tensor ropek)"
        else:
            schema = "(Tensor q, Tensor k, Tensor indptr, Tensor offsets) -> (Tensor ropeq, Tensor ropek)"

        # Keyword arguments shared by every kernel variant.
        rope_kwargs = {
            "interleave": True,
            "rope_scale": cfg.scaling_factor,
            "rope_theta": cfg.rope_base,
        }

        llama31_scaling = (cfg.high_freq_factor is not None) and (cfg.low_freq_factor is not None)
        if llama31_scaling:
            # The llama-3.1 kernels additionally need the frequency factors
            # and the original context length.
            rope_kwargs["low_freq_factor"] = cfg.low_freq_factor
            rope_kwargs["high_freq_factor"] = cfg.high_freq_factor
            rope_kwargs["old_context_len"] = cfg.original_max_position_embeddings

        # The flashinfer function is looked up when the op runs, not here.
        if llama31_scaling and use_position_ids:
            def backend(q, k, position_ids):
                return flashinfer.rope.apply_llama31_rope_pos_ids(q, k, position_ids, **rope_kwargs)
        elif llama31_scaling:
            def backend(q, k, indptr, offsets):
                return flashinfer.rope.apply_llama31_rope(q, k, indptr, offsets, **rope_kwargs)
        elif use_position_ids:
            def backend(q, k, position_ids):
                return flashinfer.rope.apply_rope_pos_ids(q, k, position_ids, **rope_kwargs)
        else:
            def backend(q, k, indptr, offsets):
                return flashinfer.rope.apply_rope(q, k, indptr, offsets, **rope_kwargs)

        # Order matters: the op must be defined before implementations attach.
        torch.library.define("mylib::rope", schema)

        @torch.library.impl("mylib::rope", "cuda")
        def rope_impl(*args):
            return backend(*args)

        @torch.library.register_fake("mylib::rope")
        def rope_fake(*args):
            # Only the first two arguments (q, k) determine the output shapes.
            q, k, *_ = args
            return torch.empty_like(q), torch.empty_like(k)

    def setup_caches(self, num_pages, page_size, **kwargs):
        """Abstract: always raises ``NotImplementedError`` in the base class."""
        raise NotImplementedError("Each Transformer subclass must implement its own setup_caches method.")

    def forward(self, *args, **kwargs):
        """Abstract: always raises ``NotImplementedError`` in the base class."""
        raise NotImplementedError("Each Transformer subclass must implement its own forward pass.")

    @classmethod
    def from_name(cls, name: str):
        """Construct the model from the config returned by ``ModelArgs.from_name(name)``."""
        return cls(ModelArgs.from_name(name))