import torch
from flashinfer import BatchPrefillWithPagedKVCacheWrapper

from FlashInfer.model import Transformer
from FlashInfer.utils import (
    sample, PageManager
)

class FlashInferBackend:
    def __init__(self, dtype = torch.float16, device: str = "cuda:0") -> None:
        self.dtype = dtype
        self.device = device
        self.cachelens = None
        self.forward = lambda model, x, input_pos, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type: model(x, input_pos, kv_append_indptr, kv_page_indices, kv_page_indptr, kv_page_lastlen, attn_type)

    def load_model(self, model_name, checkpoint_path, **kwargs):
        """Create the model, load its checkpoint, and place it on the backend device.

        Args:
            model_name: Name handed to ``Transformer.from_name``.
            checkpoint_path: Checkpoint location; converted with ``str`` and
                read with ``torch.load`` (memory-mapped, weights only).
            **kwargs: Optional ``use_tp``; when truthy, ``apply_tp`` is called
                with the optional ``rank_group`` and ``group`` entries.

        The resulting module is stored on ``self.model`` in eval mode.
        """
        # The module is constructed under the 'meta' device context; its real
        # tensors are taken from the checkpoint via ``assign=True`` below.
        with torch.device('meta'):
            net = Transformer.from_name(model_name)

        state = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        # Accept checkpoints that keep the weights under a "model" key.
        state = state["model"] if "model" in state else state
        net.load_state_dict(state, assign=True, strict=True)

        use_tp = kwargs.get("use_tp", False)
        if use_tp:
            from FlashInfer.utils import apply_tp
            print("Applying tensor parallel to model ...")
            rank_group = kwargs.get("rank_group", None)
            apply_tp(net, rank_group, group=kwargs.get("group", None))

        self.model = net.to(device=self.device, dtype=self.dtype).eval()

    def setup_caches(self, max_batch_size=1, max_seq_length=2048, dec_len=1, prefill_chunk_size=64, non_causal_mask=False, page_size=16):
        """Allocate paging metadata, attention workspaces/wrappers, and the model's KV caches.

        Args:
            max_batch_size: Number of requests handled per batch.
            max_seq_length: Maximum tokens per request; determines the page budget.
            dec_len: Number of tokens fed per request on each decode call.
            prefill_chunk_size: Number of tokens per request in each prefill chunk.
            non_causal_mask: When true, the decode wrapper is additionally given
                custom-mask buffers.
            page_size: Number of token slots per KV page.
        """
        dev = self.device

        def _i32_arange(n):
            return torch.arange(n, dtype=torch.int32, device=dev)

        def _i32_zeros(n):
            return torch.zeros(n, dtype=torch.int32, device=dev)

        self.batch_size = max_batch_size
        self.page_size = page_size
        self.dec_len = dec_len
        self.prefill_chunk_size = prefill_chunk_size
        self.non_causal_mask = non_causal_mask

        # Page budget: ceil(max_seq_length / page_size) pages for every request.
        self.cachelens = _i32_zeros(max_batch_size)
        self.max_num_pages = max_batch_size * ((max_seq_length + self.page_size - 1) // self.page_size)
        self.max_num_pages_per_request = self.max_num_pages // max_batch_size
        self.num_pages_per_request = _i32_zeros(max_batch_size)

        # Index buffers handed to the attention wrappers.
        self.qo_indptr = _i32_arange(max_batch_size + 1)
        self.paged_kv_indptr = _i32_arange(max_batch_size + 1)
        self.paged_kv_indices = torch.empty(self.max_num_pages, dtype=torch.int32, device=dev)
        self.paged_kv_last_page_len = _i32_zeros(max_batch_size)
        self.page_manager = PageManager(max_batch_size, self.max_num_pages_per_request, dev)

        # One 128 MiB byte workspace per attention phase.
        workspace_bytes = 128 * 1024 * 1024
        self.attn_buffers = {
            phase: torch.empty(workspace_bytes, dtype=torch.uint8, device=dev)
            for phase in ("prefill", "decode")
        }

        # Paged-KV buffers are shared by both wrappers; only qo_indptr differs.
        shared_bufs = {
            "paged_kv_indptr_buf": self.paged_kv_indptr,
            "paged_kv_indices_buf": self.paged_kv_indices,
            "paged_kv_last_page_len_buf": self.paged_kv_last_page_len,
        }
        self.attn_wrappers = {
            "prefill": BatchPrefillWithPagedKVCacheWrapper(
                self.attn_buffers["prefill"], "NHD", use_cuda_graph=False,
                qo_indptr_buf=self.qo_indptr * self.prefill_chunk_size,
                **shared_bufs,
            ),
        }

        mask_bufs = {}
        if self.non_causal_mask:
            # NOTE(review): the "// 8" suggests the mask is stored bit-packed — confirm.
            self.custom_mask_buf = torch.empty(max_batch_size * self.dec_len * max_seq_length // 8 + 1, dtype=torch.uint8, device=dev)
            self.mask_indptr = _i32_arange(max_batch_size + 1)
            mask_bufs = {
                "custom_mask_buf": self.custom_mask_buf,
                "mask_indptr_buf": self.mask_indptr,
            }

        self.attn_wrappers["decode"] = BatchPrefillWithPagedKVCacheWrapper(
            self.attn_buffers["decode"], "NHD", use_cuda_graph=True,
            qo_indptr_buf=self.qo_indptr * self.dec_len,
            **shared_bufs,
            **mask_bufs,
        )

        with torch.device(dev):
            self.model.setup_caches(num_pages=self.max_num_pages, page_size=self.page_size, attn_wrappers=self.attn_wrappers)

    def setup_sampling_params(self, temperature=0.0, top_k=0, top_p=0.95):
        """Record the sampling hyper-parameters on the backend.

        Args:
            temperature: Sampling temperature; exactly ``0.0`` marks greedy mode.
            top_k: Top-k value, stored as given.
            top_p: Top-p value, stored as given.
        """
        self.temperature, self.top_k, self.top_p = temperature, top_k, top_p
        # A temperature of exactly zero selects greedy decoding.
        self.greedy = temperature == 0.0
    
    def compile(self):
        import torch._dynamo.config
        import torch._inductor.config
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.fx_graph_cache = True
        torch._functorch.config.enable_autograd_cache = True

        self.forward = torch.compile(self.forward)

    def decode(self, input_ids: torch.LongTensor, custom_mask: torch.Tensor = None):
        dec_len = input_ids.shape[1]
        assert dec_len == self.dec_len, "The input sequence length must be equal to the decode length, but got dec_len={dec_len} and self.dec_len={self.dec_len}"
        
        self.pre_decode(dec_len, custom_mask)
        with torch.inference_mode():
            logits = self.forward(
                model=self.model, 
                x=input_ids,
                input_pos=self.cachelens - dec_len,
                kv_append_indptr=self.qo_indptr * dec_len,
                kv_page_indices=self.paged_kv_indices,
                kv_page_indptr=self.paged_kv_indptr,
                kv_page_lastlen=self.paged_kv_last_page_len,
                attn_type="decode",
            )
        return logits

    def pre_decode(self, dec_len: int, custom_mask: torch.Tensor = None):
        if self.non_causal_mask and custom_mask is not None:
            mask = torch.cat([torch.ones((self.batch_size, dec_len, self.cachelens[0].item()), dtype=torch.bool, device=self.device), custom_mask], dim=-1)
            mask = mask.flatten().contiguous()
        else:
            mask = None
        
        self.insert_kv(dec_len)
        self.attn_wrappers["decode"].plan(
            qo_indptr=self.qo_indptr * dec_len,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            num_qo_heads=self.model.config.n_head,
            num_kv_heads=self.model.config.n_local_heads,
            head_dim_qk=self.model.config.head_dim,
            page_size=self.page_size,
            q_data_type=self.dtype,
            causal=True,
            custom_mask=mask,
        )
    
    def prefill(self, input_ids: torch.LongTensor, query_lens: torch.Tensor):
        self.clear_kv()
        logits = None
        bsz, seq_len = input_ids.shape
        assert seq_len % self.prefill_chunk_size == 0, f"The sequence length must be divisible by the prefill chunk size, but got seq_len={seq_len} and prefill_chunk_size={self.prefill_chunk_size}"
        
        last_logits = None # Lazy initialization for Tensor Parallel
        logit_recorded = torch.zeros(bsz, dtype=torch.bool, device=self.device)

        chunk_size = self.prefill_chunk_size
        num_chunks = seq_len // chunk_size

        for i in range(num_chunks):
            chunk_input_ids = input_ids[:, i*chunk_size:(i+1)*chunk_size]

            chunk_query_lens = query_lens - (i * chunk_size)
            chunk_query_lens = torch.clamp(chunk_query_lens, min=0, max=chunk_size)
            
            # if every query in chunk only has pad tokens, skip the chunk
            if torch.all(chunk_query_lens == 0):
                continue
            
            self.pre_prefill()
            with torch.inference_mode():
                logits = self.forward(
                    model=self.model,
                    x=chunk_input_ids,
                    input_pos=torch.full((bsz,), i*chunk_size, dtype=torch.int32, device=self.device),
                    kv_append_indptr=self.qo_indptr*chunk_size,
                    kv_page_indices=self.paged_kv_indices,
                    kv_page_indptr=self.paged_kv_indptr,
                    kv_page_lastlen=self.paged_kv_last_page_len,
                    attn_type="prefill",
                )

            if last_logits is None:
                last_logits = torch.full((bsz, logits.shape[-1]), float('nan'), device=self.device, dtype=self.dtype)
            
            target_indices_in_chunk = chunk_query_lens - 1
            finishes_in_this_chunk = (query_lens > i*chunk_size) & (query_lens <= (i+1)*chunk_size)
            target_sequences_mask = finishes_in_this_chunk & (~logit_recorded)
            
            target_batch_indices = torch.where(target_sequences_mask)[0]
            if target_batch_indices.numel() > 0:
                indices_in_chunk_to_grab = target_indices_in_chunk[target_batch_indices]
                logits_to_store = logits[target_batch_indices, indices_in_chunk_to_grab, :] # Shape: [num_target_seqs, vocab_size]
                
                last_logits[target_batch_indices] = logits_to_store
                logit_recorded[target_batch_indices] = True
            
            exists_padding = (chunk_query_lens < chunk_size)
            if exists_padding.any():
                self.delete_kv(chunk_size - chunk_query_lens)
        
        if torch.isnan(last_logits).any():
            print("Warning: Found NaN in last_logits. Replacing with zeros. "
                "This might occur for sequences with initial query_len=0.")
            last_logits = torch.nan_to_num(last_logits, nan=0.0)
        
        return last_logits
    
    def pre_prefill(self):
        self.insert_kv(self.prefill_chunk_size)
        self.attn_wrappers["prefill"].plan(
            qo_indptr=self.qo_indptr*self.prefill_chunk_size,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_last_page_len=self.paged_kv_last_page_len,
            num_qo_heads=self.model.config.n_head,
            num_kv_heads=self.model.config.n_local_heads,
            head_dim_qk=self.model.config.head_dim,
            page_size=self.page_size,
            q_data_type=self.dtype,
            causal=True,
        )

    def _as_len_tensor(self, lens):
        if isinstance(lens, torch.Tensor):
            t = lens.to(device=self.device, dtype=self.paged_kv_last_page_len.dtype)
        else:
            t = torch.tensor(lens, device=self.device, dtype=self.paged_kv_last_page_len.dtype)
        if t.dim() == 0:
            t = t.expand_as(self.paged_kv_last_page_len)
        return t

    def insert_kv(self, dec_lens):
        dec = self._as_len_tensor(dec_lens)
        if torch.all(dec <= 0):
            return

        old_full = self.num_pages_per_request.clone() - 1
        old_tail = self.paged_kv_last_page_len.clone()
        ps = self.page_size

        total_after = old_full * ps + old_tail + dec
        new_full = torch.where(
            total_after > 0,
            torch.div(total_after - 1, ps, rounding_mode='floor').to(old_full.dtype),
            torch.zeros_like(old_full),
        )
        new_tail = torch.where(
            total_after > 0,
            (((total_after - 1) % ps) + 1).to(old_tail.dtype),
            torch.zeros_like(old_tail),
        )

        add_pages = (new_full - old_full).clamp_min(0).to(torch.int32)
        if add_pages.max().item() > 0:
            self.paged_kv_indptr, self.paged_kv_indices = self.page_manager.allocate_counts(
                add_pages, self.paged_kv_indices, self.paged_kv_indptr
            )
            self.num_pages_per_request += add_pages  # sync with new_full

        self.paged_kv_last_page_len = new_tail
        self.cachelens = (self.cachelens + dec).clamp_min(0)

    def delete_kv(self, del_lens):
        dec = self._as_len_tensor(del_lens)
        if torch.all(dec <= 0):
            return

        old_full = self.num_pages_per_request.clone() - 1
        old_tail = self.paged_kv_last_page_len.clone()
        ps = self.page_size

        total_before = old_full * ps + old_tail
        total_after = (total_before - dec).clamp_min(0)

        new_full = torch.where(
            total_after > 0,
            torch.div(total_after - 1, ps, rounding_mode='floor').to(old_full.dtype),
            torch.zeros_like(old_full),
        )
        new_tail = torch.where(
            total_after > 0,
            (((total_after - 1) % ps) + 1).to(old_tail.dtype),
            torch.zeros_like(old_tail),
        )

        free_pages = (old_full - new_full).clamp_min(0).to(torch.int32)
        if free_pages.max().item() > 0:
            self.paged_kv_indptr, self.paged_kv_indices = self.page_manager.free_counts(
                free_pages, self.paged_kv_indices, self.paged_kv_indptr
            )
            self.num_pages_per_request -= free_pages  # sync with new_full

        self.paged_kv_last_page_len = new_tail
        self.cachelens = (self.cachelens - dec).clamp_min(0)
    
    def clear_kv(self):
        for b in self.model.layers:
            b.attention.kv_cache.kv_cache.zero_()
        
        self.cachelens = torch.zeros(self.batch_size, dtype=torch.int32, device=self.device)
        self.qo_indptr = torch.arange(self.batch_size+1, dtype=torch.int32, device=self.device)
        
        self.page_manager.reset()
        self.num_pages_per_request = torch.ones((self.batch_size), device=self.device, dtype=torch.int32)
        self.paged_kv_indptr = torch.arange(self.batch_size+1, dtype=torch.int32, device=self.device)
        self.paged_kv_indices = self.page_manager.allocate(torch.arange(self.batch_size, dtype=torch.int32, device=self.device))
        self.paged_kv_last_page_len = torch.zeros((self.batch_size), dtype=torch.int32, device=self.device)