import torch

from spec_benchmark.Engine.utils import (
    PageManager,
    load_model,
)

class BaseLMBackend:
    """Shared paged-KV-cache bookkeeping for language-model backends.

    For each batch entry this tracks how many tokens are cached
    (``cachelens``) and how they are laid out over fixed-size pages: the
    request owns ``num_pages_per_request`` pages, every page except the last
    is full (``page_size`` tokens), and the last page holds
    ``paged_kv_last_page_len`` tokens. Page ids are handed out by a
    ``PageManager`` and recorded in the ``paged_kv_indptr`` /
    ``paged_kv_indices`` pair that the page manager updates.
    """

    def __init__(self, dtype=torch.bfloat16, device: str = "cuda:0") -> None:
        """Store the compute dtype and target device.

        Cache tensors are not created here; call ``setup_caches`` (and
        ``clear_kv``) before using the KV bookkeeping methods.
        """
        self.dtype = dtype
        self.device = device
        self.cachelens = None

    def load_model(self, model_type, target_checkpoint, drafter_checkpoint=None, use_tp=False, rank_group=None, group=None, use_tp_draft=False, rank_group_draft=None, group_draft=None):
        """Load the model through the project ``load_model`` helper into ``self.model``.

        All arguments are forwarded unchanged to the helper, together with
        this backend's ``device`` and ``dtype`` (as ``precision``). The
        ``*_draft`` / ``drafter_checkpoint`` arguments configure the drafter
        checkpoint; the others configure the target.
        """
        self.model = load_model(
            model_type=model_type,
            checkpoint_path=target_checkpoint,
            drafter_checkpoint_path=drafter_checkpoint,
            device=self.device,
            precision=self.dtype,
            use_tp=use_tp,
            rank_group=rank_group,
            group=group,
            use_tp_draft=use_tp_draft,
            rank_group_draft=rank_group_draft,
            group_draft=group_draft,
        )

    def setup_caches(self, max_batch_size=1, max_seq_length=2048, page_size=16, prefill_chunk_size=128, **kwargs):
        """Allocate the per-request bookkeeping tensors for a paged KV cache.

        Args:
            max_batch_size: number of concurrent requests (batch entries).
            max_seq_length: per-request capacity in tokens; rounded up to a
                whole number of pages.
            page_size: tokens per page.
            prefill_chunk_size: stored on the instance; not used by this class.
            **kwargs: ignored by this base implementation.
        """
        self.batch_size = max_batch_size
        self.page_size = page_size
        self.prefill_chunk_size = prefill_chunk_size

        self.cachelens = torch.zeros(max_batch_size, dtype=torch.int32, device=self.device)
        # Ceil-divide so a partially used final page still gets reserved.
        self.max_num_pages = max_batch_size * ((max_seq_length + self.page_size - 1) // self.page_size)
        self.max_num_pages_per_request = self.max_num_pages // max_batch_size
        # NOTE(review): starts every request at 0 pages, whereas clear_kv starts
        # them at 1 page (the layout insert_kv/delete_kv assume), and
        # paged_kv_indices below is uninitialised. Presumably clear_kv is always
        # called before the first insert_kv/delete_kv — confirm.
        self.num_pages_per_request = torch.zeros(max_batch_size, device=self.device, dtype=torch.int32)

        self.qo_indptr = torch.arange(max_batch_size + 1, dtype=torch.int32, device=self.device)
        self.paged_kv_indptr = torch.arange(max_batch_size + 1, dtype=torch.int32, device=self.device)
        self.paged_kv_indices = torch.empty(self.max_num_pages, dtype=torch.int32, device=self.device)
        self.paged_kv_last_page_len = torch.zeros((max_batch_size), dtype=torch.int32, device=self.device)
        self.page_manager = PageManager(max_batch_size, self.max_num_pages_per_request, self.device)

    def setup_sampling_params(self, temperature=0.0, top_k=0, top_p=0.95):
        """Store sampling settings; ``greedy`` is True exactly when ``temperature == 0.0``."""
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.greedy = self.temperature == 0.0

    def compile(self):
        """Set global ``torch._inductor`` / ``torch._functorch`` config flags.

        This base implementation only sets flags; it does not compile anything.
        """
        import torch._dynamo.config
        import torch._functorch.config  # assigned below; don't rely on another import having loaded it
        import torch._inductor.config
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.fx_graph_cache = True
        torch._functorch.config.enable_autograd_cache = True

    def _as_len_tensor(self, lens):
        """Return ``lens`` as a tensor on ``self.device`` with ``paged_kv_last_page_len``'s dtype.

        A scalar (Python number or 0-dim tensor) is expanded to one entry per
        request; the expansion is a view, not a copy.
        """
        if isinstance(lens, torch.Tensor):
            t = lens.to(device=self.device, dtype=self.paged_kv_last_page_len.dtype)
        else:
            t = torch.tensor(lens, device=self.device, dtype=self.paged_kv_last_page_len.dtype)
        if t.dim() == 0:
            t = t.expand_as(self.paged_kv_last_page_len)
        return t

    def _split_token_totals(self, total, full_dtype, tail_dtype):
        """Split per-request token totals into (full pages before the last, tokens on the last page).

        A positive total ``t`` maps to ``((t - 1) // page_size, (t - 1) % page_size + 1)``,
        so a page that is exactly filled counts as the (full) last page. A
        total of zero or less maps to ``(0, 0)``: one empty page.
        """
        ps = self.page_size
        has_tokens = total > 0
        last_index = total - 1
        full = torch.where(
            has_tokens,
            torch.div(last_index, ps, rounding_mode='floor').to(full_dtype),
            torch.zeros_like(total, dtype=full_dtype),
        )
        tail = torch.where(
            has_tokens,
            ((last_index % ps) + 1).to(tail_dtype),
            torch.zeros_like(total, dtype=tail_dtype),
        )
        return full, tail

    def insert_kv(self, dec_lens):
        """Account for ``dec_lens`` newly cached tokens per request.

        Allocates extra pages through ``page_manager.allocate_counts`` for
        requests whose new total no longer fits in the pages they own, then
        updates ``paged_kv_last_page_len`` and ``cachelens`` (clamped at 0).

        Args:
            dec_lens: per-request token counts (int, sequence or tensor); a
                scalar applies to every request. Nothing happens when no
                entry is positive.
        """
        dec = self._as_len_tensor(dec_lens)
        if torch.all(dec <= 0):
            return

        # Layout invariant: (num_pages - 1) full pages plus a last page holding `tail` tokens.
        old_full = self.num_pages_per_request - 1
        old_tail = self.paged_kv_last_page_len

        total_after = old_full * self.page_size + old_tail + dec
        new_full, new_tail = self._split_token_totals(total_after, old_full.dtype, old_tail.dtype)

        # NOTE(review): negative entries shrink the totals but never free pages here;
        # removals are expected to go through delete_kv.
        add_pages = (new_full - old_full).clamp_min(0).to(torch.int32)
        if add_pages.max().item() > 0:
            self.paged_kv_indptr, self.paged_kv_indices = self.page_manager.allocate_counts(
                add_pages, self.paged_kv_indices, self.paged_kv_indptr
            )
            # In-place update keeps num_pages_per_request in step with new_full.
            self.num_pages_per_request += add_pages

        self.paged_kv_last_page_len = new_tail
        self.cachelens = (self.cachelens + dec).clamp_min(0)

    def delete_kv(self, del_lens):
        """Remove ``del_lens`` tokens from the end of each request's cache.

        Totals are clamped at zero. Pages that become unused are released
        through ``page_manager.free_counts``; every request keeps at least one
        (possibly empty) page. ``paged_kv_last_page_len`` and ``cachelens``
        (clamped at 0) are updated.

        Args:
            del_lens: per-request token counts (int, sequence or tensor); a
                scalar applies to every request. Nothing happens when no
                entry is positive.
        """
        dec = self._as_len_tensor(del_lens)
        if torch.all(dec <= 0):
            return

        old_full = self.num_pages_per_request - 1
        old_tail = self.paged_kv_last_page_len

        total_before = old_full * self.page_size + old_tail
        total_after = (total_before - dec).clamp_min(0)
        new_full, new_tail = self._split_token_totals(total_after, old_full.dtype, old_tail.dtype)

        free_pages = (old_full - new_full).clamp_min(0).to(torch.int32)
        if free_pages.max().item() > 0:
            self.paged_kv_indptr, self.paged_kv_indices = self.page_manager.free_counts(
                free_pages, self.paged_kv_indices, self.paged_kv_indptr
            )
            # In-place update keeps num_pages_per_request in step with new_full.
            self.num_pages_per_request -= free_pages

        self.paged_kv_last_page_len = new_tail
        self.cachelens = (self.cachelens - dec).clamp_min(0)

    def clear_kv(self):
        """Zero every layer's KV cache and reset all bookkeeping.

        Afterwards each request owns exactly one freshly allocated, empty page
        (``num_pages_per_request == 1``, ``paged_kv_last_page_len == 0``) and
        ``cachelens`` is all zeros.
        """
        for b in self.model.layers:
            b.attention.kv_cache.kv_cache.zero_()

        self.cachelens = torch.zeros(self.batch_size, dtype=torch.int32, device=self.device)
        self.qo_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32, device=self.device)

        self.page_manager.reset()
        self.num_pages_per_request = torch.ones((self.batch_size), device=self.device, dtype=torch.int32)
        self.paged_kv_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32, device=self.device)
        self.paged_kv_indices = self.page_manager.allocate(torch.arange(self.batch_size, dtype=torch.int32, device=self.device))
        self.paged_kv_last_page_len = torch.zeros((self.batch_size), dtype=torch.int32, device=self.device)