from pathlib import Path

import torch
import flashinfer

from spec_benchmark.Engine.backends import BaseLMBackend
from spec_benchmark.Engine.models.standard_model import Transformer
from spec_benchmark.Engine.utils import register_custom_attn_op, sample
from spec_benchmark.profiler import backend_bucket_timer

class StandardLMBackend(BaseLMBackend):
    """Autoregressive LM backend using a paged KV cache and flashinfer attention.

    Both the chunked-prefill path (``encode``) and the single-token decode path
    (``decode``) use a ``flashinfer.BatchPrefillWithPagedKVCacheWrapper``. The two
    wrappers are registered as the custom ops ``mylib::attn_prefill`` and
    ``mylib::attn_decode``. The model is called with ``attn_type='prefill'`` or
    ``attn_type='decode'``; presumably that selects the matching op (confirm in
    ``Transformer``).

    Paged-KV bookkeeping is not defined in this class and is presumably provided
    by ``BaseLMBackend``. This covers ``qo_indptr``, ``paged_kv_indptr``,
    ``paged_kv_indices``, ``paged_kv_last_page_len``, ``cachelens``,
    ``max_num_pages``, ``page_size``, ``prefill_chunk_size``, ``insert_kv``,
    ``delete_kv`` and ``clear_kv``.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda:0",
    ) -> None:
        """Store the compute dtype/device and set up the (uncompiled) forward callables.

        Args:
            dtype: dtype the model is cast to and used as the attention query dtype.
            device: device holding the model, caches and flashinfer workspaces.
        """
        self.dtype = dtype
        self.device = device
        # Thin wrappers so ``compile()`` can swap in torch.compile'd versions.
        # They differ only in the ``attn_type`` forwarded to the model.
        self.prefill_forward = lambda model, x, input_pos, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen: model(x, input_pos, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type='prefill')
        self.decode_forward = lambda model, x, input_pos, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen: model(x, input_pos, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type='decode')

    def load_model(self, model_name: str, target_checkpoint: Path, use_tp: bool = False, rank_group=None, group=None):
        """Build ``Transformer`` ``model_name``, load its weights, and store it on ``self.model``.

        The module is created on the ``meta`` device, so no memory is allocated
        for the initial parameters. The checkpoint is memory-mapped and assigned
        in place (``assign=True``), and every key must match (``strict=True``).

        Args:
            model_name: config name understood by ``Transformer.from_name``.
            target_checkpoint: path to a ``torch.save``'d state dict. The state
                dict may be nested under a ``"model"`` key.
            use_tp: if True, apply tensor parallelism via ``apply_tp``.
            rank_group: forwarded to ``apply_tp`` when ``use_tp`` is set.
            group: forwarded to ``apply_tp`` as ``group=`` when ``use_tp`` is set.
        """
        with torch.device('meta'):
            model = Transformer.from_name(model_name)

        checkpoint = torch.load(str(target_checkpoint), mmap=True, weights_only=True)
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        model.load_state_dict(checkpoint, assign=True, strict=True)

        if use_tp:
            from spec_benchmark.Engine.utils import apply_tp
            print("Applying tensor parallel to model ...")
            apply_tp(model, rank_group, group=group)

        model = model.to(device=self.device, dtype=self.dtype)
        self.model = model.eval()

    def setup_caches(self, max_batch_size=1, max_seq_length=2048, page_size=16, prefill_chunk_size=128):
        """Allocate the paged KV cache, flashinfer workspaces and attention wrappers.

        Args:
            max_batch_size: maximum number of concurrent sequences.
            max_seq_length: maximum tokens per sequence. One extra slot is
                requested from the base class.
            page_size: tokens per KV-cache page.
            prefill_chunk_size: tokens processed per ``encode`` chunk.
        """
        # One token of capacity beyond max_seq_length. The reason is not visible
        # here; presumably it is headroom for a final decode step (TODO confirm).
        super().setup_caches(max_batch_size, max_seq_length + 1, page_size, prefill_chunk_size)

        # 3 * 128 MiB = 384 MiB of flashinfer workspace per wrapper.
        self.prefill_buffer = torch.empty(3 * 128 * 1024 * 1024, dtype=torch.uint8, device=self.device)
        self.decode_buffer = torch.empty(3 * 128 * 1024 * 1024, dtype=torch.uint8, device=self.device)
        # With use_cuda_graph=True, flashinfer needs persistent index buffers,
        # which plan() later fills in place. The prefill qo_indptr buffer is a
        # separate tensor scaled by the chunk size, because every sequence
        # contributes prefill_chunk_size query rows per chunk.
        self.attn_wrappers = {
            "prefill": flashinfer.BatchPrefillWithPagedKVCacheWrapper(self.prefill_buffer, "NHD", use_cuda_graph=True,
                                                                      qo_indptr_buf=self.qo_indptr * prefill_chunk_size,
                                                                      paged_kv_indptr_buf=self.paged_kv_indptr,
                                                                      paged_kv_indices_buf=self.paged_kv_indices,
                                                                      paged_kv_last_page_len_buf=self.paged_kv_last_page_len),
            "decode": flashinfer.BatchPrefillWithPagedKVCacheWrapper(self.decode_buffer, "NHD", use_cuda_graph=True,
                                                                     qo_indptr_buf=self.qo_indptr,
                                                                     paged_kv_indptr_buf=self.paged_kv_indptr,
                                                                     paged_kv_indices_buf=self.paged_kv_indices,
                                                                     paged_kv_last_page_len_buf=self.paged_kv_last_page_len),
        }
        register_custom_attn_op("mylib::attn_prefill", self.attn_wrappers["prefill"])
        register_custom_attn_op("mylib::attn_decode", self.attn_wrappers["decode"])

        with torch.device(self.device):
            self.model.setup_caches(num_pages=self.max_num_pages, page_size=self.page_size)

    def setup_sampling_params(self, temperature=0.0, top_k=0, top_p=0.95):
        """Store the parameters forwarded to ``sample``; ``temperature == 0.0`` marks greedy."""
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.greedy = self.temperature == 0.0

    def compile(self):
        """Run the base-class compile hook, then wrap both forward callables in ``torch.compile``."""
        super().compile()
        self.prefill_forward = torch.compile(self.prefill_forward)
        self.decode_forward = torch.compile(self.decode_forward)

    def encode(self, input_ids: torch.LongTensor, query_lens: torch.Tensor):
        """Prefill the KV cache chunk by chunk and sample each row's next token.

        The KV cache is cleared first. ``input_ids`` is processed in chunks of
        ``self.prefill_chunk_size`` tokens. For row ``b`` only the first
        ``query_lens[b]`` tokens count as real. KV entries appended for the
        remaining (padding) positions of a chunk are removed with ``delete_kv``.

        Args:
            input_ids: token ids of shape ``(bsz, seq_len)``. ``seq_len`` must be
                a multiple of ``self.prefill_chunk_size``.
            query_lens: per-row count of real tokens, shape ``(bsz,)``. It is
                combined elementwise with tensors on ``self.device``, so it is
                expected to live there too.

        Returns:
            The output of ``sample`` applied to the logits at each row's last
            real position (the original author notes this as ``[bsz, 1]``).

        Raises:
            AssertionError: if ``seq_len`` is not a multiple of the chunk size.
            ValueError: if no chunk contains a real token, so there are no logits.
        """
        self.clear_kv()
        bsz, seq_len = input_ids.shape
        assert seq_len % self.prefill_chunk_size == 0, f"The sequence length must be divisible by the prefill chunk size, but got seq_len={seq_len} and prefill_chunk_size={self.prefill_chunk_size}"

        # Lazy initialization for Tensor Parallel: the vocab width (and dtype) are
        # taken from the first chunk's logits rather than assumed up front.
        last_logits = None
        logit_recorded = torch.zeros(bsz, dtype=torch.bool, device=self.device)

        chunk_size = self.prefill_chunk_size
        num_chunks = seq_len // chunk_size

        for i in range(num_chunks):
            chunk_start = i * chunk_size
            chunk_end = chunk_start + chunk_size
            chunk_input_ids = input_ids[:, chunk_start:chunk_end]
            # Number of real (non-padding) tokens of each row inside this chunk.
            chunk_query_lens = torch.clamp(query_lens - chunk_start, min=0, max=chunk_size)

            # No row has real tokens here. chunk_query_lens only shrinks as i
            # grows, so no later chunk has any either: stop rather than skip.
            if torch.all(chunk_query_lens == 0):
                break

            self.pre_encode()
            with torch.inference_mode():
                logits = self.prefill_forward(
                    model=self.model,
                    x=chunk_input_ids,
                    input_pos=torch.full((bsz,), chunk_start, dtype=torch.int32, device=self.device),
                    kv_append_indptr=self.qo_indptr * chunk_size,
                    kv_page_indices=self.paged_kv_indices,
                    kv_page_indptr=self.paged_kv_indptr,
                    kv_page_lastlen=self.paged_kv_last_page_len,
                )

            if last_logits is None:
                # Keep the model's logits dtype. decode() samples from raw logits,
                # so casting here would make encode sample at a lower precision.
                last_logits = torch.full((bsz, logits.shape[-1]), float('nan'), device=self.device, dtype=logits.dtype)

            # Rows whose last real token falls in this chunk: record that position's logits.
            finishes_in_this_chunk = (query_lens > chunk_start) & (query_lens <= chunk_end)
            target_batch_indices = torch.where(finishes_in_this_chunk & (~logit_recorded))[0]
            if target_batch_indices.numel() > 0:
                last_pos_in_chunk = chunk_query_lens[target_batch_indices] - 1
                last_logits[target_batch_indices] = logits[target_batch_indices, last_pos_in_chunk, :]
                logit_recorded[target_batch_indices] = True

            # Remove the KV entries appended for this chunk's padding positions.
            if (chunk_query_lens < chunk_size).any():
                self.delete_kv(chunk_size - chunk_query_lens)

        if last_logits is None:
            # No forward pass ran (every query_len <= 0, or seq_len == 0). Without
            # this check, torch.isnan(None) would fail with an opaque TypeError.
            raise ValueError(
                "encode() found no real tokens to prefill (all query_lens are 0 "
                "or seq_len is 0); there are no logits to sample from."
            )

        # Rows that never recorded logits (e.g. query_len == 0) still hold NaN.
        if torch.isnan(last_logits).any():
            print("Warning: Found NaN in last_logits. Replacing with zeros. "
                  "This might occur for sequences with initial query_len=0.")
            last_logits = torch.nan_to_num(last_logits, nan=0.0)

        return sample(last_logits, top_p=self.top_p, top_k=self.top_k, temperature=self.temperature) # [bsz, 1]

    def pre_encode(self):
        """Reserve one prefill chunk of KV slots per row and plan the prefill attention.

        Presumably ``insert_kv`` extends each row's KV length and updates the
        paged-KV index tensors passed to ``plan`` (defined outside this class).
        """
        self.insert_kv(self.prefill_chunk_size)
        self.attn_wrappers["prefill"].plan(
            qo_indptr=self.qo_indptr * self.prefill_chunk_size,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            num_qo_heads=self.model.config.n_head,
            num_kv_heads=self.model.config.n_local_heads,
            head_dim_qk=self.model.config.head_dim,
            page_size=self.page_size,
            q_data_type=self.dtype,
            causal=True,
        )

    def decode(self, input_ids: torch.LongTensor):
        """Run one decode step and sample the next token for every row.

        Args:
            input_ids: token ids of shape ``(bsz, 1)``.

        Returns:
            The tokens produced by ``sample`` from the step's logits.

        Raises:
            AssertionError: if ``input_ids.shape[1] != 1``.
        """
        dec_len = input_ids.shape[1]
        assert dec_len == 1, f"The input sequence length must be equal to 1, but got dec_len={dec_len}"

        self.pre_decode()
        with torch.inference_mode():
            logits = self.decode_forward(
                model=self.model,
                x=input_ids,
                # pre_decode() has just reserved one slot, so the new token's
                # position is presumably the last cached index.
                input_pos=self.cachelens - 1,
                kv_append_indptr=self.qo_indptr,
                kv_page_indices=self.paged_kv_indices,
                kv_page_indptr=self.paged_kv_indptr,
                kv_page_lastlen=self.paged_kv_last_page_len,
            )

        with backend_bucket_timer("backend.decode.sample"):
            next_tokens = sample(logits, top_p=self.top_p, top_k=self.top_k, temperature=self.temperature)
        return next_tokens

    def pre_decode(self):
        """Reserve one KV slot per row and plan the decode attention (one query per row)."""
        self.insert_kv(1)
        self.attn_wrappers["decode"].plan(
            qo_indptr=self.qo_indptr,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            num_qo_heads=self.model.config.n_head,
            num_kv_heads=self.model.config.n_local_heads,
            head_dim_qk=self.model.config.head_dim,
            page_size=self.page_size,
            q_data_type=self.dtype,
            causal=True,
        )