import torch
from torch import Tensor
import torch.distributed as dist

from spec_benchmark.Engine.models.base.model import BaseAttention, BaseTransformer
from spec_benchmark.Engine.models.base.kv_cache import StandardKVCache
from spec_benchmark.profiler import attention_compute_timer, rope_compute_timer

class Attention(BaseAttention):
    """Attention layer that dispatches to externally installed attention kernels.

    The kernel slots ``attn_prefill``, ``attn_verify`` and ``attn_decode`` are
    initialised to ``None`` here. They, together with ``self.kv_cache`` and
    ``self.rope``, must be assigned before ``forward`` is called (in this file
    that is done by ``Transformer.setup_caches``).
    """

    def __init__(self, config):
        super().__init__(config)
        # Kernel callables; populated after construction.
        self.attn_prefill = None
        self.attn_verify = None
        self.attn_decode = None

    def forward(self, x: Tensor, offsets: Tensor, kv_append_indptr: Tensor, kv_page_indices: Tensor, kv_page_indptr: Tensor, kv_page_lastlen: Tensor, attn_type: str = "decode") -> Tensor:
        """Project ``x`` to Q/K/V, apply RoPE, update the KV cache and attend.

        Args:
            x: Input whose shape unpacks as ``(bsz, seqlen, features)``.
            offsets: Forwarded to ``self.rope`` (presumably per-token
                positions; confirm against the rope kernel).
            kv_append_indptr: Forwarded to ``self.rope`` and
                ``self.kv_cache.update``.
            kv_page_indices: Forwarded to ``self.kv_cache.update``.
            kv_page_indptr: Forwarded to ``self.kv_cache.update``.
            kv_page_lastlen: Forwarded to ``self.kv_cache.update``.
            attn_type: Which kernel to run: ``"prefill"``, ``"verify"`` or
                ``"decode"``.

        Returns:
            The result of ``self.wo`` applied to the attention output reshaped
            to ``(bsz, seqlen, self.dim)``; all-reduced in place over
            ``self.process_group`` when that group is set.

        Raises:
            ValueError: If ``attn_type`` is not one of the supported modes.
            RuntimeError: If the kernel for ``attn_type`` has not been
                installed (its slot is still ``None``).
        """
        # Resolve the kernel up front so an invalid request fails before the
        # KV cache is mutated.
        if attn_type == "prefill":
            attn_fn = self.attn_prefill
        elif attn_type == "verify":
            attn_fn = self.attn_verify
        elif attn_type == "decode":
            attn_fn = self.attn_decode
        else:
            raise ValueError(
                f"Unknown attn_type {attn_type!r}; expected 'prefill', 'verify' or 'decode'"
            )
        if attn_fn is None:
            raise RuntimeError(
                f"Attention kernel for attn_type={attn_type!r} is not available on this model"
            )

        bsz, seqlen, _ = x.shape
        kv_size = self.n_local_heads * self.head_dim
        # Fused QKV projection, split into query / key / value slices.
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
        # Q and K are normalised on their per-head view.
        q = self.q_norm(q.view(bsz, seqlen, self.n_head, self.head_dim))
        k = self.k_norm(k.view(bsz, seqlen, self.n_local_heads, self.head_dim))
        # Flatten batch and sequence into one token axis for the kernels.
        q = q.view(bsz * seqlen, self.n_head, self.head_dim)
        k = k.view(bsz * seqlen, self.n_local_heads, self.head_dim)
        # v is a slice of the split output, so make it contiguous before view.
        v = v.contiguous().view(bsz * seqlen, self.n_local_heads, self.head_dim)
        with rope_compute_timer():
            q, k = self.rope(q, k, kv_append_indptr, offsets)
        kv_cache = self.kv_cache.update(k, v, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen)
        with attention_compute_timer():
            y = attn_fn(q, kv_cache)
        y = y.contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        if self.process_group is not None:
            # Sum the per-rank outputs of the output projection.
            dist.all_reduce(y, group=self.process_group)
        return y

class Transformer(BaseTransformer):
    """Transformer whose layers use ``Attention`` with ``torch.ops.mylib`` kernels."""

    def _get_attention_class(self) -> type:
        """Return the attention class used by this model's layers."""
        return Attention

    def setup_caches(self, num_pages, page_size):
        """Create per-layer KV caches and bind the attention / RoPE kernels.

        Args:
            num_pages: Passed to ``StandardKVCache`` as its first argument.
            page_size: Passed to ``StandardKVCache`` as its second argument.
        """
        super()._setup_rope_kernels()
        # Keep float16 if the output weights use it; otherwise use bfloat16.
        dtype = self.output.weight.dtype if self.output.weight.dtype == torch.float16 else torch.bfloat16

        for b in self.layers:
            b.attention.kv_cache = StandardKVCache(num_pages, page_size, self.config.n_local_heads, self.config.head_dim, dtype)
            b.attention.attn_prefill = torch.ops.mylib.attn_prefill
            # Only available when the model acts as a verifier; None otherwise.
            b.attention.attn_verify = torch.ops.mylib.attn_verify if hasattr(torch.ops.mylib, "attn_verify") else None
            # Only available when the model acts as a standalone model; None otherwise.
            b.attention.attn_decode = torch.ops.mylib.attn_decode if hasattr(torch.ops.mylib, "attn_decode") else None
            b.attention.rope = torch.ops.mylib.rope

    def forward(self, idx: Tensor, input_pos: Tensor, kv_append_indptr: Tensor, kv_page_indices: Tensor, kv_page_indptr: Tensor, kv_page_lastlen: Tensor, attn_type: str = "decode") -> Tensor:
        """Embed ``idx``, run every layer, normalise and project to logits.

        Args:
            idx: Input passed to ``self.tok_embeddings`` (presumably token
                ids; confirm against the base class).
            input_pos: Forwarded to every layer as its second argument.
            kv_append_indptr: Forwarded unchanged to every layer.
            kv_page_indices: Forwarded unchanged to every layer.
            kv_page_indptr: Forwarded unchanged to every layer.
            kv_page_lastlen: Forwarded unchanged to every layer.
            attn_type: Attention mode forwarded unchanged to every layer.

        Returns:
            The output of ``self.output``. When ``self.process_group`` is set,
            the logits from all ``self.world_size`` ranks are gathered and
            concatenated along the last dimension.
        """
        x = self.tok_embeddings(idx)
        for layer in self.layers:
            x = layer(x, input_pos, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type)
        x = self.norm(x)
        logits = self.output(x)

        if self.process_group is not None:
            # Each rank contributes its own logits slice; join them on the last axis.
            gathered_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
            dist.all_gather(gathered_logits, logits, group=self.process_group)
            logits = torch.cat(gathered_logits, dim=-1)

        return logits