from typing import Any, List, Optional, Dict, Tuple, Union
import torch
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.configuration_utils import PretrainedConfig

def create_kv_cache(
    cache_implementation = "dynamic",
    max_cache_len=None,
    max_batch_size=None,
    config=None,
    device='cpu',
    dtype='float16',
    layer_device_map=None
):
    """Build the tree-decoding KV cache selected by `cache_implementation`.

    Parameters:
      cache_implementation (str): "dynamic" returns a `TreeDynamicCache`; "static" returns a
        `TreeStaticCache`.
      max_cache_len, max_batch_size, config, device, layer_device_map: forwarded unchanged to
        `TreeStaticCache`; ignored for the dynamic cache.
      dtype (str | torch.dtype): element type for the static cache. A string such as "float16"
        or "torch.float16" is resolved to the matching `torch.dtype`, because `TreeStaticCache`
        declares `dtype` as a `torch.dtype`.

    Returns:
      The newly constructed cache object.

    Raises:
      ValueError: if `cache_implementation` is not "dynamic" or "static", or if `dtype` is a
        string that does not name a torch dtype.
    """
    if cache_implementation == "dynamic":
        return TreeDynamicCache()

    if cache_implementation == "static":
        if isinstance(dtype, str):
            # Accept both "float16" and "torch.float16".
            resolved_dtype = getattr(torch, dtype.split(".")[-1], None)
            if not isinstance(resolved_dtype, torch.dtype):
                raise ValueError(f"Unknown torch dtype name: {dtype!r}")
            dtype = resolved_dtype
        return TreeStaticCache(
            max_cache_len=max_cache_len,
            max_batch_size=max_batch_size,
            config=config,
            device=device,
            dtype=dtype,
            layer_device_map=layer_device_map,
        )

    # Fail loudly instead of silently handing `None` back to the caller.
    raise ValueError(
        f"Unsupported cache_implementation: {cache_implementation!r} "
        "(expected 'dynamic' or 'static')"
    )

class TreeDynamicCache(DynamicCache):
    """`DynamicCache` variant with a caller-maintained length and growable backing buffers.

    `seq_len` is never updated by the methods of this class (only `reset` sets it back to 0);
    the user must maintain it. `reorder_full_cache_with_offset_seq` appends into per-layer
    pre-allocated buffers and exposes `key_cache[i]` / `value_cache[i]` as prefix views of them.
    """

    # Sequence axis assumed by the buffered append path — tensors are treated as (B, H, L, D).
    _SEQ_DIM = 2
    # Spare sequence slots added on every (re)allocation of a backing buffer.
    _GROWTH_PADDING = 128

    def __init__(self) -> None:
        super().__init__()
        # user should maintain seq_len manually when using this class
        self.seq_len = 0

        # Per-layer backing buffers; an entry is None until the layer is first appended to.
        self._key_cache_storage: List[Optional[torch.Tensor]] = []
        self._value_cache_storage: List[Optional[torch.Tensor]] = []

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the manually maintained `seq_len`.

        `layer_idx` is accepted only so that calls passing a layer index keep working;
        it is ignored.
        """
        return self.seq_len

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.

        Only `_seen_tokens` and the cached tensors are changed; `seq_len` is left untouched.
        """
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        self._seen_tokens = max_length
        for idx, key in enumerate(self.key_cache):
            if key.numel():
                self.key_cache[idx] = key[..., :max_length, :]
                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

    @staticmethod
    def _index_on(device, index: torch.Tensor, per_device: dict) -> torch.Tensor:
        """Return `index` moved to `device`, memoised in `per_device` so each device is hit once."""
        moved = per_device.get(device)
        if moved is None:
            moved = per_device[device] = index.to(device)
        return moved

    def reorder_cache(self, beam_idx: torch.LongTensor, dim=0):
        """Reorder cache for beam search (classic approach).

        Every layer's key and value tensor is replaced by `index_select(dim, beam_idx)`.
        """
        idx_by_device = {}
        for i, key in enumerate(self.key_cache):
            idx = self._index_on(key.device, beam_idx, idx_by_device)
            self.key_cache[i] = key.index_select(dim, idx)
            self.value_cache[i] = self.value_cache[i].index_select(dim, idx)

    def reorder_cache_with_offset(self, beam_idx: torch.LongTensor, new_chunk_len=1, offset=0, dim=0):
        """
        Reorder the cache for beam search with an offset.
        [:offset] remain unchanged; [offset:] is reordered.

        The result along `dim` has `offset + len(beam_idx)` entries: the first `offset` entries
        followed by entries `offset + beam_idx`. `new_chunk_len` is accepted but not used.
        """
        # Build the full reorder indices
        full_beam_idx = torch.cat(
            [torch.arange(offset, device=beam_idx.device), beam_idx + offset], dim=0
        )
        idx_by_device = {}

        for i, key in enumerate(self.key_cache):
            r_idx = self._index_on(key.device, full_beam_idx, idx_by_device)
            self.key_cache[i] = key.index_select(dim, r_idx)
            self.value_cache[i] = self.value_cache[i].index_select(dim, r_idx)

    def reset(self):
        """Resets the cache, including `seq_len` and the pre-allocated backing buffers."""
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.seq_len = 0
        self._key_cache_storage = []
        self._value_cache_storage = []

    def reorder_full_cache_with_offset(self, full_kv, beam_idx: torch.LongTensor, new_chunk_len=1, offset=0, dim=0):
        """
        Append entries selected from `full_kv` to every layer of this cache.

        For each layer already present in this cache, the entries
        `full_kv.get_seq_length() + beam_idx` are selected along axis `dim` of `full_kv`'s
        key/value tensor and concatenated onto this cache along axis 2.

        The selection is done per layer, so only the selected entries are materialised
        (no stacked copy of the full caches is built).

        Parameters:
          full_kv: cache providing `get_seq_length()`, `key_cache` and `value_cache`.
          beam_idx (LongTensor): 1D tensor of indices relative to `full_kv.get_seq_length()`.
          new_chunk_len (int): accepted but not used.
          offset (int): accepted but not used.
          dim (int): axis of the `full_kv` tensors that is indexed.
            NOTE(review): concatenation is always along axis 2, so callers presumably
            pass `dim=2` — confirm against call sites.
        """
        src_offset = full_kv.get_seq_length()
        if beam_idx.size(0) == 0:
            return

        src_idx_by_device = {}
        for i, key in enumerate(self.key_cache):
            dev = key.device
            src_idx = src_idx_by_device.get(dev)
            if src_idx is None:
                src_idx = src_idx_by_device[dev] = src_offset + beam_idx.to(dev)

            k_sel = full_kv.key_cache[i].to(dev).index_select(dim, src_idx)
            v_sel = full_kv.value_cache[i].to(dev).index_select(dim, src_idx)

            self.key_cache[i] = torch.cat((key, k_sel), dim=2)
            self.value_cache[i] = torch.cat((self.value_cache[i], v_sel), dim=2)

    def _ensure_storage(self, layer_idx: int, target_shape: Tuple[int, ...], dtype: torch.dtype, device: torch.device):
        """
        Return `(key_storage, value_storage)` for `layer_idx`, able to hold `target_shape`.

        The existing buffers are reused only when their dtype, device and every non-sequence
        dimension match the request and their sequence capacity is sufficient. Otherwise fresh
        zero-filled buffers are allocated (both with `target_shape`'s layout): capacity is
        `required + 128`, or twice the old capacity if that is larger and the old buffer had a
        matching layout.

        Existing contents are NOT carried over on reallocation; the caller must copy the
        currently exposed cache into the returned buffers.

        `device` should be a fully specified device such as a tensor's `.device`; it is
        compared with `==` against the buffer's device.
        """
        # Initialize storage lists if they are empty (e.g. after init or reset)
        while len(self._key_cache_storage) <= layer_idx:
            self._key_cache_storage.append(None)
            self._value_cache_storage.append(None)

        seq_dim = self._SEQ_DIM
        target_shape = tuple(target_shape)
        required_len = target_shape[seq_dim]
        existing = self._key_cache_storage[layer_idx]

        # A buffer with another batch/head/dim size, dtype or device cannot be written through
        # `copy_` without failing or silently broadcasting, so it must not be reused.
        layout_matches = (
            existing is not None
            and existing.dtype == dtype
            and existing.device == device
            and tuple(existing.shape[:seq_dim]) == target_shape[:seq_dim]
            and tuple(existing.shape[seq_dim + 1:]) == target_shape[seq_dim + 1:]
        )

        if not layout_matches or existing.shape[seq_dim] < required_len:
            new_len = required_len + self._GROWTH_PADDING
            if layout_matches:
                # Geometric growth keeps the number of reallocations logarithmic.
                new_len = max(new_len, existing.shape[seq_dim] * 2)

            new_shape = list(target_shape)
            new_shape[seq_dim] = new_len

            self._key_cache_storage[layer_idx] = torch.zeros(new_shape, dtype=dtype, device=device)
            self._value_cache_storage[layer_idx] = torch.zeros(new_shape, dtype=dtype, device=device)

        return self._key_cache_storage[layer_idx], self._value_cache_storage[layer_idx]

    @staticmethod
    def _is_prefix_view(tensor: torch.Tensor, storage: torch.Tensor) -> bool:
        """True when `tensor` starts at `storage`'s first element and shares its strides."""
        return tensor.data_ptr() == storage.data_ptr() and tensor.stride() == storage.stride()

    @staticmethod
    def _copy_source(tensor: torch.Tensor, storage: torch.Tensor) -> torch.Tensor:
        """Return `tensor`, cloned first if its start address lies inside `storage`'s memory."""
        base = storage.data_ptr()
        if base <= tensor.data_ptr() < base + storage.nbytes:
            # Source and destination would overlap; copy from a detached clone instead.
            return tensor.clone()
        return tensor

    def reorder_full_cache_with_offset_seq(self, ref_kv, start=0, end=0, dim=0):
        """
        Concatenate ref_kv's key and value cache from [start:end] to self's cache.
        Optimized to use pre-allocated storage (copy_) instead of cat where possible.

        Only layers already present in `self.key_cache` are extended. The slice is taken on the
        second-to-last axis of `ref_kv`'s tensors and appended along axis 2. Nothing happens
        when `end <= start`. `dim` is accepted but not used.
        """
        seq_dim = self._SEQ_DIM
        append_len = end - start

        if append_len <= 0:
            return

        for i in range(len(self.key_cache)):
            # Incoming data slice
            k_sel = ref_kv.key_cache[i][..., start:end, :]
            v_sel = ref_kv.value_cache[i][..., start:end, :]

            current_k = self.key_cache[i]
            current_v = self.value_cache[i]

            current_len = current_k.shape[seq_dim]
            target_len = current_len + append_len
            target_shape = (
                tuple(current_k.shape[:seq_dim]) + (target_len,) + tuple(current_k.shape[seq_dim + 1:])
            )

            # Ensure storage exists, matches the current layout and is big enough
            storage_k, storage_v = self._ensure_storage(
                i,
                target_shape=target_shape,
                dtype=current_k.dtype,
                device=current_k.device,
            )

            # Sync: unless the exposed tensor already is the prefix of the buffer (same start
            # address and strides), its contents must be copied in first. This covers a first
            # run, a tensor replaced by cat/index_select, and a freshly reallocated buffer.
            # Keys and values are checked independently.
            if not self._is_prefix_view(current_k, storage_k):
                storage_k[..., :current_len, :].copy_(self._copy_source(current_k, storage_k))
            if not self._is_prefix_view(current_v, storage_v):
                storage_v[..., :current_len, :].copy_(self._copy_source(current_v, storage_v))

            # Append the new data (ref_kv) into storage [current_len : target_len]
            storage_k[..., current_len:target_len, :].copy_(k_sel)
            storage_v[..., current_len:target_len, :].copy_(v_sel)

            # Expose the valid prefix of the buffer as the layer's cache
            self.key_cache[i] = storage_k[..., :target_len, :]
            self.value_cache[i] = storage_v[..., :target_len, :]

class TreeStaticCache(StaticCache):
    """`StaticCache` variant with a caller-maintained length and in-place tree/beam updates.

    `seq_len` is never advanced by the methods of this class (only `reset` sets it back to 0);
    the user must maintain it. All update methods write into the existing per-layer tensors
    in place, so the tensor objects in `key_cache` / `value_cache` are never replaced.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_cache_len: Optional[int] = None,
        device: Optional[torch.device] = None,
        layer_device_map: Optional[Dict[int, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
    ) -> None:
        """Forward all arguments to `StaticCache.__init__` and start with `seq_len = 0`."""
        super().__init__(
            config=config,
            max_cache_len=max_cache_len,
            device=device,
            layer_device_map=layer_device_map,
            dtype=dtype,
            max_batch_size=max_batch_size,
        )
        # user should maintain seq_len manually when using this class
        self.seq_len = 0

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the manually maintained `seq_len`.

        `layer_idx` is accepted only so that calls passing a layer index keep working;
        it is ignored.
        """
        return self.seq_len

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for key, value in zip(self.key_cache, self.value_cache):
            # In-place ops prevent breaking the static address
            key.zero_()
            value.zero_()
        self.seq_len = 0

    def crop(self, max_length: int):
        """Zero every cached position from `max_length` onwards. `max_length` can also be
        negative to drop that many tokens counted back from `get_seq_length()`.

        Tensor sizes are unchanged; only `_seen_tokens` is updated, `seq_len` is left untouched.
        """
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        self._seen_tokens = max_length
        for key, value in zip(self.key_cache, self.value_cache):
            if key.numel():
                key[..., max_length:, :] = 0
                value[..., max_length:, :] = 0

    def reorder_cache_with_offset(
        self,
        beam_idx: torch.LongTensor,
        new_chunk_len: int = 1,
        offset: int = 0,
        dim: int = 0,
    ) -> None:
        """
        Reorder, in place, the slice starting at `offset` along `dim` of every key/value cache.

        Position `offset + j` receives the previous content of position `offset + beam_idx[j]`
        for each `j < len(beam_idx)`. The selected entries are gathered before anything is
        written, so overlapping source and destination positions are handled. Positions beyond
        `offset + len(beam_idx)` are left as they are.

        Each layer tensor is updated directly, so no copy of the whole cache is made and the
        underlying tensor objects (their memory pointers) remain unchanged—a requirement for
        CUDA graphs.

        Parameters:
          beam_idx (LongTensor): 1D tensor of indices indicating the new ordering.
          new_chunk_len (int): accepted but not used.
          offset (int): The starting offset along dimension `dim` to update.
          dim (int): The dimension along which the update occurs.
        """
        slice_len = beam_idx.size(0)
        if slice_len == 0:
            return

        # Index tensors are built once per device and shared by all layers on it.
        idx_by_device = {}
        for key, value in zip(self.key_cache, self.value_cache):
            dev = key.device
            if dev not in idx_by_device:
                idx_by_device[dev] = (
                    offset + beam_idx.to(dev),
                    offset + torch.arange(slice_len, device=dev),
                )
            reorder_src, reorder_dest = idx_by_device[dev]

            # index_select materialises the gathered rows before index_copy_ writes them back.
            key.index_copy_(dim, reorder_dest, key.index_select(dim, reorder_src))
            value.index_copy_(dim, reorder_dest, value.index_select(dim, reorder_src))

    def reorder_full_cache_with_offset_seq(self, ref_kv, start=0, end=0, dim=0):
        """
        Copy ref_kv's key and value cache slice [start:end] into self's cache at
        positions [seq_len : seq_len + (end - start)].

        The slice is taken on the second-to-last axis. `seq_len` itself is not advanced.
        Nothing happens when `end <= start`. `dim` is accepted but not used.
        """
        append_len = end - start
        if append_len <= 0:
            # A negative length would otherwise address a non-empty destination slice.
            return

        write_from = self.seq_len
        write_to = write_from + append_len
        for i, (key, value) in enumerate(zip(self.key_cache, self.value_cache)):
            k_sel = ref_kv.key_cache[i][..., start:end, :]
            v_sel = ref_kv.value_cache[i][..., start:end, :]
            key[..., write_from:write_to, :].copy_(k_sel)
            value[..., write_from:write_to, :].copy_(v_sel)

    def reorder_full_cache_with_offset(self, full_kv, beam_idx: torch.LongTensor, new_chunk_len=1, offset=0, dim=0):
        """
        Copy entries selected from `full_kv` into this cache, in place.

        For every layer, entries `full_kv.get_seq_length() + beam_idx` are selected along `dim`
        from `full_kv`'s key/value tensor and written to positions
        `offset ... offset + len(beam_idx) - 1` along `dim` of this cache. Other positions are
        left as they are.

        Each layer tensor is updated directly, so no copy of the whole cache is made and the
        underlying tensor objects (their memory pointers) remain unchanged—a requirement for
        CUDA graphs.

        Parameters:
          full_kv: cache providing `get_seq_length()`, `key_cache` and `value_cache`.
          beam_idx (LongTensor): 1D tensor of indices relative to `full_kv.get_seq_length()`.
          new_chunk_len (int): accepted but not used.
          offset (int): The starting offset along dimension `dim` to update.
          dim (int): The dimension along which the update occurs.
        """
        src_offset = full_kv.get_seq_length()
        slice_len = beam_idx.size(0)
        if slice_len == 0:
            return

        # Index tensors are built once per device and shared by all layers on it.
        idx_by_device = {}
        for i, (key, value) in enumerate(zip(self.key_cache, self.value_cache)):
            dev = key.device
            if dev not in idx_by_device:
                idx_by_device[dev] = (
                    src_offset + beam_idx.to(dev),
                    offset + torch.arange(slice_len, device=dev),
                )
            src_idx, dest_idx = idx_by_device[dev]

            k_sel = full_kv.key_cache[i].to(dev).index_select(dim, src_idx)
            v_sel = full_kv.value_cache[i].to(dev).index_select(dim, src_idx)

            key.index_copy_(dim, dest_idx, k_sel)
            value.index_copy_(dim, dest_idx, v_sel)