from typing import Tuple, Optional, List

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
import flashinfer

from torch import Tensor
from flashinfer.activation import silu_and_mul

from spec_benchmark.Engine.models.base import ModelArgs, LoRAConfig, GatedLoRALinear, Sampler, StandardKVCache, RMSNorm
from spec_benchmark.profiler import attention_compute_timer, rope_compute_timer, bucket_timer


class FeedForward(nn.Module):
    """Gated MLP with a fused gate/up projection built from LoRA-enabled linears.

    ``w13`` projects ``config.dim`` to ``2 * config.intermediate_size``
    features, ``silu_and_mul`` is applied to that output, and ``w2`` projects
    ``config.intermediate_size`` features back to ``config.dim``.
    """

    def __init__(self, config: ModelArgs, lora_config: LoRAConfig):
        """Build the two projections and register the checkpoint-fusing hook.

        Args:
            config: Model hyper-parameters; ``dim`` and ``intermediate_size`` are read.
            lora_config: LoRA settings used as-is for ``w2`` and scaled for ``w13``.
        """
        super().__init__()
        # The fused gate+up projection gets a LoRA config with doubled rank and alpha.
        w13_lora_config = LoRAConfig(
            rank=2 * lora_config.rank,
            alpha=2 * lora_config.alpha,
            lora_bias=lora_config.lora_bias,
            use_rslora=lora_config.use_rslora,
        )
        self.w13 = GatedLoRALinear(config.dim, 2 * config.intermediate_size, bias=False, lora_config=w13_lora_config)
        self.w2 = GatedLoRALinear(config.intermediate_size, config.dim, bias=False, lora_config=lora_config)
        # Stays None unless assigned from outside; forward() all-reduces only when it is set.
        self.process_group = None
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Fuse separate ``w1``/``w3`` checkpoint weights into ``w13.weight``.

        Runs before ``load_state_dict``. When ``<prefix>w1.weight`` is present,
        both ``w1`` and ``w3`` entries are removed from ``state_dict`` and their
        concatenation (``w1`` first, along dim 0) is stored under
        ``<prefix>w13.weight``. Checkpoints without ``w1`` are left untouched.

        Raises:
            KeyError: If ``w1`` is present but the matching ``w3`` is missing.
                ``state_dict`` is not modified in that case.
        """
        w1_key = prefix + "w1.weight"
        if w1_key not in state_dict:
            return

        w3_key = prefix + "w3.weight"
        if w3_key not in state_dict:
            # Check before popping so a bad checkpoint is not left half-rewritten.
            raise KeyError(f"{w3_key} is required to fuse with {w1_key} into {prefix}w13.weight")

        w1 = state_dict.pop(w1_key)
        w3 = state_dict.pop(w3_key)
        state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

    def forward(self, x: Tensor, gate_mask: Optional[Tensor]) -> Tensor:
        """Apply the gated MLP to ``x``.

        Args:
            x: Input activations.
            gate_mask: Passed unchanged to both ``GatedLoRALinear`` layers; may be ``None``.

        Returns:
            Output of ``w2``, all-reduced over ``process_group`` when one is set.
        """
        y = self.w13(x, gate_mask)
        with bucket_timer("mlp.silu_and_mul"):
            y = silu_and_mul(y)
        y = self.w2(y, gate_mask)
        if self.process_group is not None:
            dist.all_reduce(y, group=self.process_group)

        return y


class Attention(nn.Module):
    """Attention layer with a fused QKV projection, RoPE, and a paged KV cache.

    ``kv_cache``, ``rope``, ``process_group`` and the three ``attn_*`` callables
    are initialised to ``None`` here and must be assigned before ``forward`` is
    called (``Transformer.setup_caches`` assigns the cache, rope and attention ops).
    """

    # Values of ``attn_type`` that forward() knows how to dispatch.
    _ATTN_TYPES = ("prefill", "draft", "draft_and_verify")

    def __init__(self, config: ModelArgs, lora_config: LoRAConfig):
        """Build the projections and optional per-head Q/K norms.

        Args:
            config: Model hyper-parameters (``dim``, ``n_head``, ``n_local_heads``,
                ``head_dim``, ``qkv_bias``, ``qk_norm``, ``norm_eps`` are read).
            lora_config: LoRA settings used as-is for ``wo`` and scaled for ``wqkv``.
        """
        super().__init__()
        assert config.dim % config.n_head == 0, "config.dim must be divisible by config.n_head"
        # Q uses n_head heads; K and V each use n_local_heads heads.
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim

        # The fused q+k+v projection gets a LoRA config with tripled rank and alpha.
        wqkv_lora_config = LoRAConfig(
            rank=3 * lora_config.rank,
            alpha=3 * lora_config.alpha,
            lora_bias=lora_config.lora_bias,
            use_rslora=lora_config.use_rslora,
        )
        self.wqkv = GatedLoRALinear(config.dim, total_head_dim, bias=config.qkv_bias, lora_config=wqkv_lora_config)
        self.wo = GatedLoRALinear(config.head_dim * config.n_head, config.dim, bias=False, lora_config=lora_config)

        if config.qk_norm:
            self.q_norm = RMSNorm(config.head_dim, config.norm_eps)
            self.k_norm = RMSNorm(config.head_dim, config.norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        # Width of the query projection (and of the attention output fed to wo).
        self.dim = self.n_head * self.head_dim

        self.kv_cache = None
        self.rope = None
        self.process_group = None

        self.attn_prefill = None
        self.attn_draft = None
        self.attn_draft_and_verify = None

    def forward(self, x: Tensor, gate_mask: Optional[Tensor], position_ids: Tensor, kv_append_indptr: Tensor, kv_page_indices: Tensor, kv_page_indptr: Tensor, kv_page_lastlen: Tensor, attn_type: str) -> Tensor:
        """Run attention over ``x`` and return the output projection.

        Args:
            x: Input of shape ``(bsz, seqlen, features)``.
            gate_mask: Passed unchanged to ``wqkv`` and ``wo``; may be ``None``.
            position_ids: Positions handed to the RoPE op.
            kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen:
                Forwarded unchanged to ``kv_cache.update``.
            attn_type: One of ``"prefill"``, ``"draft"`` or ``"draft_and_verify"``;
                selects which attention callable is used.

        Returns:
            Tensor of shape ``(bsz, seqlen, out_features of wo)``, all-reduced over
            ``process_group`` when one is set.

        Raises:
            ValueError: If ``attn_type`` is not one of the supported values.
        """
        # Reject bad input before touching the KV cache; previously an unknown
        # value fell through the dispatch and surfaced as an unbound local.
        if attn_type not in self._ATTN_TYPES:
            raise ValueError(f"attn_type must be one of {self._ATTN_TYPES}, got {attn_type!r}")

        bsz, seqlen, _ = x.shape
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x, gate_mask).split([self.dim, kv_size, kv_size], dim=-1)
        # Norms are applied per head, so reshape to expose head_dim as the last axis.
        q = self.q_norm(q.view(bsz, seqlen, self.n_head, self.head_dim))
        k = self.k_norm(k.view(bsz, seqlen, self.n_local_heads, self.head_dim))
        # Flatten batch and sequence into a single token axis for rope / cache / attention.
        q = q.view(bsz * seqlen, self.n_head, self.head_dim)
        k = k.view(bsz * seqlen, self.n_local_heads, self.head_dim)
        v = v.contiguous().view(bsz * seqlen, self.n_local_heads, self.head_dim)
        with rope_compute_timer():
            q, k = self.rope(q, k, position_ids)
        kv_cache = self.kv_cache.update(k, v, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen)

        if attn_type == "prefill":
            # Prefill is intentionally left outside attention_compute_timer, as before.
            y = self.attn_prefill(q, kv_cache)
        elif attn_type == "draft":
            with attention_compute_timer():
                y = self.attn_draft(q, kv_cache)
        else:  # "draft_and_verify" (validated above)
            with attention_compute_timer():
                y = self.attn_draft_and_verify(q, kv_cache)

        y = y.contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y, gate_mask)
        if self.process_group is not None:
            dist.all_reduce(y, group=self.process_group)
        return y


class TransformerBlock(nn.Module):
    """One decoder layer: normed attention and normed feed-forward, each with a residual add."""

    def __init__(self, config: ModelArgs, lora_config: LoRAConfig):
        super().__init__()
        hidden_dim, eps = config.dim, config.norm_eps
        # Sub-modules are registered in the same order as before.
        self.attention = Attention(config, lora_config)
        self.feed_forward = FeedForward(config, lora_config)
        self.ffn_norm = RMSNorm(hidden_dim, eps)
        self.attention_norm = RMSNorm(hidden_dim, eps)

    def forward(self, x: Tensor, gate_mask: Optional[Tensor], *args, **kwargs) -> Tensor:
        """Apply the block to ``x``.

        ``gate_mask`` goes to both sub-layers; any further positional or
        keyword arguments are forwarded to the attention layer only.
        """
        attn_out = self.attention(self.attention_norm(x), gate_mask, *args, **kwargs)
        residual = x + attn_out
        mlp_out = self.feed_forward(self.ffn_norm(residual), gate_mask)
        return residual + mlp_out


class Transformer(nn.Module):
    """Decoder-only transformer: token embedding, TransformerBlock stack, final norm, LM head.

    ``world_size``, ``rank`` and ``process_group`` start as ``None`` and are not
    assigned anywhere in this class; presumably tensor-parallel setup code
    elsewhere fills them in.
    """

    def __init__(self, config: ModelArgs, lora_config: LoRAConfig):
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList([TransformerBlock(config, lora_config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.dim, config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.world_size = None
        self.rank = None
        self.process_group = None

    def _setup_rope_kernels(self):
        """Define the ``mylib::rope`` custom op backed by flashinfer's RoPE kernels.

        The Llama-3.1 variant is used when both ``high_freq_factor`` and
        ``low_freq_factor`` are set in the config; otherwise the plain variant is used.

        NOTE(review): the op name is process-global and is defined on every
        call, so this looks like it must run at most once per process, and the
        registered kernel is tied to this model's RoPE settings — confirm
        against callers.
        """
        schema = "(Tensor q, Tensor k, Tensor position_ids) -> (Tensor ropeq, Tensor ropek)"
        rope_kwargs = dict(interleave=True, rope_scale=self.config.scaling_factor, rope_theta=self.config.rope_base)

        if self.config.high_freq_factor is not None and self.config.low_freq_factor is not None:
            # For Llama-3.1
            rope_kwargs.update(
                low_freq_factor=self.config.low_freq_factor,
                high_freq_factor=self.config.high_freq_factor,
                old_context_len=self.config.original_max_position_embeddings,
            )
            rope_fn = flashinfer.rope.apply_llama31_rope_pos_ids
        else:
            rope_fn = flashinfer.rope.apply_rope_pos_ids

        torch.library.define("mylib::rope", schema)

        @torch.library.impl("mylib::rope", "cuda")
        def rope_impl(q, k, position_ids):
            return rope_fn(q, k, position_ids, **rope_kwargs)

        @torch.library.register_fake("mylib::rope")
        def rope_fake(q, k, position_ids):
            # Shape/dtype-only stand-in: outputs mirror q and k.
            return torch.empty_like(q), torch.empty_like(k)

    def setup_caches(self, num_pages, page_size, **kwargs):
        """Register the RoPE op, then give every layer a KV cache and its attention ops.

        Args:
            num_pages: Passed to ``StandardKVCache``.
            page_size: Passed to ``StandardKVCache``.
            **kwargs: Accepted and ignored.

        The ``torch.ops.mylib.attn_*`` ops are not defined in this file and must
        already be registered when this is called.
        """
        self._setup_rope_kernels()
        # Cache dtype follows the LM head when it is fp16; anything else falls back to bf16.
        head_dtype = self.output.weight.dtype
        dtype = head_dtype if head_dtype == torch.float16 else torch.bfloat16
        for b in self.layers:
            b.attention.kv_cache = StandardKVCache(num_pages, page_size, self.config.n_local_heads, self.config.head_dim, dtype)
            b.attention.attn_prefill = torch.ops.mylib.attn_prefill
            b.attention.attn_draft = torch.ops.mylib.attn_draft
            b.attention.attn_draft_and_verify = torch.ops.mylib.attn_draft_and_verify
            b.attention.rope = torch.ops.mylib.rope

    @classmethod
    def from_name(cls, name: str, lora_config: LoRAConfig):
        """Build a model from the ``ModelArgs`` preset registered under ``name``."""
        config = ModelArgs.from_name(name)
        return cls(config, lora_config)

    def get_tok_embeddings(self):
        """Return the input token embedding module."""
        return self.tok_embeddings

    def get_lm_head(self):
        """Return the output projection (LM head) module."""
        return self.output

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
        """
        Resize ONLY the input token embeddings.
        - Always assumes: pad_to_multiple_of=None, mean_resizing=True
        - Does NOT touch lm_head (self.output) nor self.config.vocab_size.

        Args:
            new_num_tokens: Target number of rows. ``None`` or the current size is a no-op.

        Returns:
            The embedding module now stored in ``self.tok_embeddings``.

        Raises:
            ValueError: If ``new_num_tokens`` is not positive.
        """
        old_embeddings: nn.Embedding = self.tok_embeddings

        # No-op if unspecified
        if new_num_tokens is None:
            return old_embeddings

        # Normalise once so every comparison and slice below uses the same integer.
        num_tokens = int(new_num_tokens)
        if num_tokens <= 0:
            raise ValueError(f"new_num_tokens must be positive, got {new_num_tokens}")

        old_vocab = int(old_embeddings.num_embeddings)
        emb_dim = int(old_embeddings.embedding_dim)
        if num_tokens == old_vocab:
            return old_embeddings  # nothing to do

        old_weight = old_embeddings.weight
        new_embeddings = nn.Embedding(num_tokens, emb_dim, device=old_weight.device, dtype=old_weight.dtype)

        with torch.no_grad():
            # copy overlapping rows
            n_copy = min(old_vocab, num_tokens)
            new_embeddings.weight[:n_copy].copy_(old_weight[:n_copy])

            # expanding → fill new rows with mean of existing embeddings (mean_resizing=True)
            if num_tokens > old_vocab and old_vocab > 0:
                mean_vec = old_weight[:old_vocab].mean(dim=0, keepdim=True)
                new_embeddings.weight[old_vocab:num_tokens].copy_(mean_vec)

            # zero-out pad row if model defines a padding id (HF convention)
            pad_id = getattr(self.config, "pad_token_id", None)  # TODO : add pad_token_id to config
            if pad_id is not None and 0 <= pad_id < num_tokens:
                new_embeddings.weight[pad_id].zero_()

        # Preserve requires_grad
        new_embeddings.requires_grad_(old_weight.requires_grad)

        # Swap in and return
        self.tok_embeddings = new_embeddings
        return self.tok_embeddings

    def _may_be_all_gather(self, x: Tensor) -> Tensor:
        """Concatenate ``x`` from all ranks along the last dim when a process group is set.

        Without a process group ``x`` is returned unchanged. ``self.world_size``
        must be set whenever ``self.process_group`` is.
        """
        if self.process_group is not None:
            gathered_x = [torch.empty_like(x) for _ in range(self.world_size)]
            dist.all_gather(gathered_x, x, group=self.process_group)
            x = torch.cat(gathered_x, dim=-1)
        return x

    def _forward_lm_head(self, x: Tensor, gate_mask: Optional[Tensor] = None) -> Tensor:
        """Project hidden states to logits, skipping positions flagged in ``gate_mask``.

        With ``gate_mask=None`` every position is projected. Otherwise only
        positions whose mask value equals 0 are projected, and the result is
        reshaped to ``(bsz, kept // bsz, -1)`` — this assumes every batch row
        keeps the same number of positions (TODO confirm with callers).
        """
        if gate_mask is None:
            return self._may_be_all_gather(self.output(x))  # [B, S, V]

        # forward non-mask tokens only
        bsz, _, dim = x.shape
        non_mask_indices = (gate_mask.view(-1) == 0).nonzero(as_tuple=True)[0]
        x = x.reshape(-1, dim).index_select(0, non_mask_indices)
        logits = self._may_be_all_gather(self.output(x))
        return logits.reshape(bsz, non_mask_indices.numel() // bsz, -1)

    def forward(self, idx: Tensor, gate_mask: Optional[Tensor], position_ids: Tensor, kv_append_indptr: Tensor, kv_page_indices: Tensor, kv_page_indptr: Tensor, kv_page_lastlen: Tensor, attn_type: str) -> Tuple[Tensor, Tensor]:
        """Run the full model.

        Args:
            idx: Token ids fed to the embedding table.
            gate_mask: Forwarded to every layer and to the LM head; may be ``None``.
            position_ids, kv_append_indptr, kv_page_indices, kv_page_indptr,
            kv_page_lastlen, attn_type: Forwarded unchanged to every layer.

        Returns:
            ``(logits, hidden_states)`` where ``hidden_states`` is the output of
            the final norm for all positions.
        """
        x = self.tok_embeddings(idx)
        # Iterate layers directly; the old enumerate() index shadowed the ``idx`` argument.
        for layer in self.layers:
            x = layer(x, gate_mask, position_ids, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type)
        x = self.norm(x)
        logits = self._forward_lm_head(x, gate_mask)

        return logits, x


class MTPTransformer(nn.Module):
    """Wraps a base ``Transformer`` together with a ``Sampler`` head built from its config."""

    def __init__(
        self,
        base_model: Transformer,
    ):
        super().__init__()
        self.base_model = base_model
        self.sampler = Sampler(base_model.config)

    def forward(
        self,
        idx: Tensor,
        gate_mask: Optional[Tensor],
        position_ids: Tensor,
        kv_append_indptr: Tensor,
        kv_page_indices: Tensor,
        kv_page_indptr: Tensor,
        kv_page_lastlen: Tensor,
        attn_type: str = "prefill",
    ) -> Tuple[Tensor, Tensor]:
        """Delegate to the base model and return its ``(logits, hidden_states)`` pair."""
        base_outputs = self.base_model(
            idx,
            gate_mask,
            position_ids,
            kv_append_indptr,
            kv_page_indices,
            kv_page_indptr,
            kv_page_lastlen,
            attn_type,
        )
        logits, hidden_states = base_outputs
        return logits, hidden_states

    def sampler_forward(self, idx: Tensor, hidden_states: Tensor) -> Tensor:
        """Greedily pick the next token from the sampler head.

        The embeddings of ``idx`` are concatenated with ``hidden_states`` on the
        last dim, passed through the sampler, and projected by the base model's
        LM head. Without a process group this is a plain argmax; with one, each
        rank contributes its local best (value, index) and the overall best
        index is returned.
        """
        base = self.base_model
        token_embeds = base.get_tok_embeddings()(idx)
        fused_inputs = torch.cat([token_embeds, hidden_states], dim=-1)
        sampler_out = self.sampler(fused_inputs)
        local_logits = base.get_lm_head()(sampler_out)  # [bsz, 1, local_vocab_size]

        group = base.process_group
        if group is None:
            return local_logits.argmax(dim=-1)  # [bsz, 1]

        # Best candidate on this rank, shifted into global vocab ids.
        # NOTE(review): vocab_start is not set in Transformer.__init__; presumably
        # assigned alongside process_group — confirm.
        local_best_val, local_best_idx = local_logits.max(dim=-1, keepdim=True)
        local_best_idx = local_best_idx + base.vocab_start

        n_ranks = dist.get_world_size(group)
        gathered_vals = [torch.empty_like(local_best_val) for _ in range(n_ranks)]
        gathered_idxs = [torch.empty_like(local_best_idx) for _ in range(n_ranks)]
        dist.all_gather(gathered_vals, local_best_val, group=group)
        dist.all_gather(gathered_idxs, local_best_idx, group=group)

        all_vals = torch.cat(gathered_vals, dim=-1)            # [B,1,world]
        all_idxs = torch.cat(gathered_idxs, dim=-1)            # [B,1,world]
        winner = all_vals.argmax(dim=-1, keepdim=True)         # [B,1,1]
        chosen = torch.gather(all_idxs, -1, winner)            # [B,1,1]
        return chosen[..., 0]                                  # [B,1]