import torch
import torch.nn as nn
import torch.distributed as dist
import flashinfer

from torch import Tensor
from flashinfer.activation import silu_and_mul

from .utils import ModelArgs, StandardKVCache
from profiler import attention_compute_timer, rope_compute_timer

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    The normalization itself is computed in float32 and cast back to the
    input dtype before the per-feature ``weight`` (initialized to ones) is
    applied.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Mean of squares along the feature axis, kept for broadcasting.
        mean_sq = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

    def __repr__(self):
        return f"RMSNorm(dim={self.weight.shape[0]}, eps={self.eps})"


class FeedForward(nn.Module):
    """Gated MLP with a fused input projection (``w13``) and an output projection (``w2``).

    ``w13`` maps ``dim -> 2 * intermediate_size`` and ``w2`` maps
    ``intermediate_size -> dim``. Between them, flashinfer's ``silu_and_mul``
    reduces the doubled width back to ``intermediate_size`` — presumably
    SiLU(first half) * second half; confirm against the flashinfer docs.
    When ``process_group`` is set (externally), the output is summed across
    ranks with ``all_reduce``.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w13 = nn.Linear(config.dim, 2 * config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
        # Tensor-parallel group; None means no cross-rank reduction.
        self.process_group = None
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Fuse separate ``w1``/``w3`` checkpoint weights into ``w13`` in place.

        If ``<prefix>w1.weight`` is present, it and ``<prefix>w3.weight`` are
        popped and concatenated along dim 0 (``w1`` first) into
        ``<prefix>w13.weight``. Checkpoints that already store ``w13`` are
        left untouched.
        """
        if prefix + "w1.weight" in state_dict:
            w1 = state_dict.pop(prefix + "w1.weight")
            w3 = state_dict.pop(prefix + "w3.weight")
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

    def forward(self, x: Tensor) -> Tensor:
        y = self.w2(silu_and_mul(self.w13(x)))
        # Each rank holds a partial result under tensor parallelism; sum in place.
        if self.process_group is not None:
            dist.all_reduce(y, group=self.process_group)
        return y


class Attention(nn.Module):
    """Self-attention over a paged KV cache, driven by externally supplied kernels.

    Q, K and V come from a single fused ``wqkv`` projection with ``n_head``
    query heads and ``n_local_heads`` key/value heads of size ``head_dim``.
    The KV cache, the RoPE op and the prefill/verify/decode attention
    wrappers are all ``None`` after construction. They must be assigned
    before ``forward`` runs; ``Transformer.setup_caches`` does this.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()

        assert config.dim % config.n_head == 0
        # Fused Q/K/V output width: n_head query heads + n_local_heads K and V heads.
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.qkv_bias)

        # head_dim * n_head = dim except for Qwen3-0.6B which head_dim * n_head = dim * 2
        self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
        if config.qk_norm:
            # For Qwen3 family: per-head RMSNorm on queries and keys.
            self.q_norm = RMSNorm(config.head_dim, config.norm_eps)
            self.k_norm = RMSNorm(config.head_dim, config.norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        # Installed later (see class docstring).
        self.kv_cache: nn.Module = None
        self.rope = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        # Width of the query slice of wqkv and of the attention output.
        self.dim = self.n_head * self.head_dim

        # Tensor-parallel group; None means no cross-rank reduction.
        self.process_group = None
        self._register_load_state_dict_pre_hook(self.load_hook)

        # Attention kernels
        self.attn_prefill = None
        self.attn_verify = None
        self.attn_decode = None

    def load_hook(self, state_dict, prefix, *args):
        """Fuse separate wq/wk/wv checkpoint tensors into ``wqkv`` in place.

        Weights and (if present) biases are popped and concatenated along
        dim 0 in q, k, v order.
        """
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
        if prefix + "wq.bias" in state_dict:
            bq = state_dict.pop(prefix + "wq.bias")
            bk = state_dict.pop(prefix + "wk.bias")
            bv = state_dict.pop(prefix + "wv.bias")
            state_dict[prefix + "wqkv.bias"] = torch.cat([bq, bk, bv])

    def _select_attn_wrapper(self, attn_type: str):
        """Return the installed attention wrapper for ``attn_type``.

        Raises:
            ValueError: if ``attn_type`` is not "prefill", "verify" or "decode".
            RuntimeError: if no wrapper has been installed for that mode.
        """
        if attn_type == "prefill":
            wrapper = self.attn_prefill
        elif attn_type == "verify":
            wrapper = self.attn_verify
        elif attn_type == "decode":
            wrapper = self.attn_decode
        else:
            raise ValueError(
                f"attn_type must be 'prefill', 'verify' or 'decode', got {attn_type!r}"
            )
        if wrapper is None:
            raise RuntimeError(f"no attention wrapper installed for attn_type={attn_type!r}")
        return wrapper

    def forward(self,
        x: Tensor,
        offsets: Tensor,
        kv_append_indptr: Tensor,
        kv_page_indices: Tensor,
        kv_page_indptr: Tensor,
        kv_page_lastlen: Tensor,
        attn_type: str = "decode"
    ) -> Tensor:
        """Attend ``x`` (unpacked as ``(bsz, seqlen, dim_in)``) against the KV cache.

        ``offsets`` and ``kv_append_indptr`` are passed to the RoPE op. The four
        ``kv_*`` tensors are passed to ``kv_cache.update``; their exact layout
        is defined by the cache/kernels — not visible here.

        Returns:
            Tensor of shape ``(bsz, seqlen, config.dim)`` after ``wo``. If a
            process group is set, it is all-reduced across ranks.

        Raises:
            ValueError / RuntimeError: from ``_select_attn_wrapper``. These are
                raised before the KV cache is modified.
        """
        # Resolve the kernel first, so a bad mode fails without writing this
        # step's keys/values into the cache.
        attn_wrapper = self._select_attn_wrapper(attn_type)

        bsz, seqlen, _ = x.shape
        kv_size = self.n_local_heads * self.head_dim

        # QKV projection
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
        q = self.q_norm(q.view(bsz, seqlen, self.n_head, self.head_dim))
        k = self.k_norm(k.view(bsz, seqlen, self.n_local_heads, self.head_dim))

        # Flatten batch and sequence into a single token axis for the kernels.
        q = q.view(bsz * seqlen, self.n_head, self.head_dim)
        k = k.view(bsz * seqlen, self.n_local_heads, self.head_dim)
        v = v.contiguous().view(bsz * seqlen, self.n_local_heads, self.head_dim)

        # Apply RoPE and update KV cache
        with rope_compute_timer():
            q, k = self.rope(q, k, kv_append_indptr, offsets)
        kv_cache = self.kv_cache.update(k, v, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen)

        with attention_compute_timer():
            y = attn_wrapper.run(q, kv_cache)

        # Output projection
        y = y.contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)

        # All reduce
        if self.process_group is not None:
            dist.all_reduce(y, group=self.process_group)
        return y


class TransformerBlock(nn.Module):
    """One decoder layer: normalized attention, then a normalized feed-forward.

    Each sub-layer's output is added back onto its input (residual
    connection). The normalization is applied to the sub-layer input only.
    Extra positional/keyword arguments are forwarded to ``Attention.forward``.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        attn_out = self.attention(self.attention_norm(x), *args, **kwargs)
        hidden = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + ffn_out
        

class Transformer(nn.Module):
    """Decoder-only language model over a paged KV cache.

    Pipeline: token embedding -> ``n_layer`` TransformerBlocks -> RMSNorm ->
    vocabulary projection. ``setup_caches`` must be called before
    ``forward``, so that every layer has its KV cache, RoPE op and
    attention wrappers.
    """

    # Custom RoPE op names registered by this class under the process-global
    # ``mylib`` namespace, keyed by the RoPE parameters baked into each op.
    # Models with different RoPE settings (e.g. draft vs. target) therefore
    # never share an op; models with identical settings do.
    _rope_op_names: dict = {}

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.dim, config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Tensor-parallel state, assigned externally when the model is sharded.
        self.world_size = None
        self.rank = None
        self.process_group = None

    @staticmethod
    def _define_rope_op(first_index: int) -> str:
        """Define a new RoPE op schema under ``mylib`` and return its name.

        Index 0 maps to ``rope`` and index ``i > 0`` maps to ``rope_i``. Names that
        are already registered in this process (``torch.library.define``
        raises RuntimeError) are skipped.
        """
        schema = "(Tensor q, Tensor k, Tensor indptr, Tensor offsets) -> (Tensor ropeq, Tensor ropek)"
        for index in range(first_index, first_index + 64):
            name = "rope" if index == 0 else f"rope_{index}"
            try:
                torch.library.define(f"mylib::{name}", schema)
            except RuntimeError:
                continue  # name already taken; try the next one
            return name
        raise RuntimeError("could not find a free name for a mylib RoPE op")

    def _setup_rope_kernels(self):
        """Register, or reuse, the ``mylib`` RoPE custom op matching this config.

        If ``high_freq_factor`` and ``low_freq_factor`` are both set, the op uses
        flashinfer's Llama-3.1 RoPE; otherwise it uses plain ``apply_rope``.
        The kernel captures plain config values rather than ``self``, so the
        registered op does not keep this model alive.

        Returns:
            The ``torch.ops.mylib`` op, called as ``op(q, k, indptr, offsets)``.
        """
        cfg = self.config
        use_llama31 = (cfg.high_freq_factor is not None) and (cfg.low_freq_factor is not None)
        rope_scale = cfg.scaling_factor
        rope_theta = cfg.rope_base
        if use_llama31:
            low_freq_factor = cfg.low_freq_factor
            high_freq_factor = cfg.high_freq_factor
            old_context_len = cfg.original_max_position_embeddings
            key = ("llama31", rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_len)
        else:
            key = ("default", rope_scale, rope_theta)

        name = Transformer._rope_op_names.get(key)
        if name is not None:
            return getattr(torch.ops.mylib, name)

        name = self._define_rope_op(len(Transformer._rope_op_names))
        qualname = f"mylib::{name}"

        if use_llama31:
            @torch.library.impl(qualname, "cuda")
            def rope(q, k, indptr, offsets):
                return flashinfer.rope.apply_llama31_rope(
                    q, k, indptr, offsets, interleave=True,
                    rope_scale=rope_scale, rope_theta=rope_theta,
                    low_freq_factor=low_freq_factor, high_freq_factor=high_freq_factor,
                    old_context_len=old_context_len,
                )
        else:
            @torch.library.impl(qualname, "cuda")
            def rope(q, k, indptr, offsets):
                return flashinfer.rope.apply_rope(
                    q, k, indptr, offsets, interleave=True,
                    rope_scale=rope_scale, rope_theta=rope_theta,
                )

        # Shape/dtype-only implementation used when tracing (e.g. torch.compile).
        @torch.library.register_fake(qualname)
        def rope_abstract(q, k, indptr, offsets):
            return torch.empty_like(q), torch.empty_like(k)

        Transformer._rope_op_names[key] = name
        return getattr(torch.ops.mylib, name)

    def setup_caches(self, num_pages, page_size, attn_wrappers):
        """Install a KV cache, the attention wrappers and the RoPE op on every layer.

        Args:
            num_pages, page_size: forwarded to ``StandardKVCache`` for each layer.
                The cache dtype matches ``output.weight``.
            attn_wrappers: mapping that may hold "prefill", "verify" and/or
                "decode" wrappers. A missing key leaves that mode set to None.
        """
        rope_op = self._setup_rope_kernels()
        dtype = self.output.weight.dtype
        prefill_wrapper = attn_wrappers.get("prefill")
        # "verify" is valid only when the model acts as a verifier; "decode"
        # only when it acts as a standalone model.
        verify_wrapper = attn_wrappers.get("verify")
        decode_wrapper = attn_wrappers.get("decode")

        for b in self.layers:
            attn = b.attention
            attn.kv_cache = StandardKVCache(num_pages, page_size, self.config.n_local_heads, self.config.head_dim, dtype)
            attn.attn_prefill = prefill_wrapper
            attn.attn_verify = verify_wrapper
            attn.attn_decode = decode_wrapper
            attn.rope = rope_op

    @classmethod
    def from_name(cls, name: str):
        """Build a model from the ``ModelArgs`` preset registered under ``name``."""
        config = ModelArgs.from_name(name)
        return cls(config)

    def forward(self,
        idx: Tensor,
        input_pos: Tensor,
        kv_append_indptr: Tensor,
        kv_page_indices: Tensor,
        kv_page_indptr: Tensor,
        kv_page_lastlen: Tensor,
        attn_type: str = "decode"
    ) -> Tensor:
        """Run the model and return logits.

        Args:
            idx: token ids for ``tok_embeddings``. Attention unpacks the embedded
                result as ``(bsz, seqlen, dim)``, so presumably ``(bsz, seqlen)``.
            input_pos: passed to each attention layer as ``offsets`` for RoPE.
            kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen:
                paged KV-cache metadata, passed unchanged to every layer.
            attn_type: "prefill", "verify" or "decode"; selects the attention wrapper.

        Returns:
            Output logits. If ``process_group`` is set, the logits from all
            ``world_size`` ranks are gathered and concatenated along the last
            dimension.
        """
        x = self.tok_embeddings(idx)
        for layer in self.layers:
            x = layer(x, input_pos, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type)
        x = self.norm(x)
        logits = self.output(x)

        if self.process_group is not None:
            gathered_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
            dist.all_gather(gathered_logits, logits, group=self.process_group)
            logits = torch.cat(gathered_logits, dim=-1)

        return logits