import torch


class KVCache:
    """Pre-allocated key/value cache addressed by explicit slot indices.

    Storage is a single tensor of shape
    ``(num_hidden_layers, 2 * num_seqs, num_key_value_heads, max_cache_len,
    head_dim)``.  Along the second axis, rows ``[0, num_seqs)`` hold keys and
    rows ``[num_seqs, 2 * num_seqs)`` hold values.  Slots along the
    ``max_cache_len`` axis are handed out by :meth:`allocate` and returned by
    :meth:`free`; the bookkeeping lives in ``free_indices`` and
    ``allocated_indices``.

    Args:
        num_hidden_layers: Size of the leading (layer) axis.
        num_key_value_heads: Number of key/value heads per layer.
        max_cache_len: Number of slots along the sequence axis.
        head_dim: Size of the trailing feature axis.
        device: Device the cache tensor is created on.
        dtype: Element type of the cache tensor.
        num_seqs: Number of sequence rows reserved for keys (and for values).
    """

    def __init__(
            self,
            num_hidden_layers,
            num_key_value_heads,
            max_cache_len,
            head_dim,
            device,
            dtype,
            num_seqs=1,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.max_cache_len = max_cache_len
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype
        self.num_seqs = num_seqs

        # Contents are uninitialised until written by `update`.
        self.kv_cache = torch.empty(
            (
                num_hidden_layers,
                2 * num_seqs,
                num_key_value_heads,
                max_cache_len,
                head_dim
            ),
            device=device,
            dtype=dtype
        )
        self.free_indices = list(range(max_cache_len))
        self.allocated_indices = []

    def num_allocated(self):
        """Return the number of slots currently allocated."""
        return len(self.allocated_indices)

    def allocate(self, sequence_length):
        """Reserve ``sequence_length`` slots and return their indices.

        Slots are taken from the front of ``free_indices`` and appended to
        ``allocated_indices`` in the same order.

        Raises:
            ValueError: If ``sequence_length`` is negative or exceeds the
                number of free slots.
        """
        # A negative count would otherwise slice "all but the last k" slots.
        if sequence_length < 0:
            raise ValueError("sequence_length must be non-negative")
        if len(self.free_indices) < sequence_length:
            raise ValueError("Not enough space in cache")
        indices = self.free_indices[:sequence_length]
        self.free_indices = self.free_indices[sequence_length:]
        self.allocated_indices.extend(indices)
        return indices

    def free(self, indices):
        """Return the given slots to the free list.

        The call is all-or-nothing: every index is validated against a
        scratch copy before any bookkeeping is changed, so a failed call
        leaves the cache state untouched.

        Raises:
            ValueError: If an index is not currently allocated (including an
                index listed more often than it is allocated).
        """
        indices = list(indices)
        remaining = list(self.allocated_indices)
        for idx in indices:
            try:
                remaining.remove(idx)
            except ValueError:
                raise ValueError(f"Index {idx} is not allocated") from None
        # Commit only after every index has been validated.
        self.allocated_indices[:] = remaining
        self.free_indices.extend(indices)

    def update(self, key_states, value_states, layer_idx, indices, new_indices, cache_kwargs):
        """Store new key/value states for one layer and return the cached view.

        Args:
            key_states: 4-D tensor unpacked as ``(batch, heads, seq, head_dim)``.
            value_states: Tensor written to the same slots as ``key_states``.
            layer_idx: Layer whose cache is updated.
            indices: Slots to read back after the write.  When
                ``new_indices`` is ``None`` the last ``seq`` entries are also
                the slots written to.
            new_indices: Slots to write to, or ``None`` to use the tail of
                ``indices``.
            cache_kwargs: Unused by this implementation.

        Returns:
            ``(keys, values)`` gathered at ``indices`` for the first ``batch``
            sequence rows, cast to ``key_states.dtype``.

        Raises:
            ValueError: If ``batch`` exceeds ``num_seqs``.
        """
        batch, _, seq_len, _ = key_states.shape
        if batch > self.num_seqs:
            raise ValueError(
                f"Batch size {batch} exceeds cache capacity of {self.num_seqs} sequences"
            )

        target = indices[-seq_len:] if new_indices is None else new_indices
        value_rows = slice(self.num_seqs, self.num_seqs + batch)
        cache_dtype = self.kv_cache.dtype

        # Write only the rows in use; rows belonging to other sequences are
        # left untouched instead of being padded with uninitialised memory.
        self.kv_cache[layer_idx, :batch, :, target] = key_states.to(cache_dtype)
        self.kv_cache[layer_idx, value_rows, :, target] = value_states.to(cache_dtype)

        out_dtype = key_states.dtype
        return (
            self.kv_cache[layer_idx, :batch, :, indices].to(out_dtype),
            self.kv_cache[layer_idx, value_rows, :, indices].to(out_dtype),
        )

    def get_seq_length(self):
        """Always return 0.

        NOTE(review): looks like a placeholder kept for interface
        compatibility with callers that query a cache length -- confirm.
        """
        return 0


class FastKVCache:
    """Contiguous key/value cache that appends at a running ``valid_len``.

    Keys and values live in two separate tensors, each of shape
    ``(num_hidden_layers, num_seqs, num_key_value_heads, max_cache_len,
    head_dim)``.  :meth:`update` writes new states at
    ``[valid_len, valid_len + seq)`` and returns slice views of the prefix,
    so no gather is needed on the read path.  :meth:`free` can compact
    entries that should survive back onto the end of the valid prefix.

    Slot bookkeeping (``free_indices`` / ``allocated_indices``) is kept
    alongside; :meth:`free` uses an index's position inside
    ``allocated_indices`` as its physical position in the cache tensors.

    Args:
        num_hidden_layers: Size of the leading (layer) axis.
        num_key_value_heads: Number of key/value heads per layer.
        max_cache_len: Number of slots along the sequence axis.
        head_dim: Size of the trailing feature axis.
        device: Device the cache tensors are created on.
        dtype: Element type of the cache tensors.
        num_seqs: Size of the sequence-row axis.
    """

    def __init__(
            self,
            num_hidden_layers,
            num_key_value_heads,
            max_cache_len,
            head_dim,
            device,
            dtype,
            num_seqs=1,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.max_cache_len = max_cache_len
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype
        self.num_seqs = num_seqs

        cache_shape = (
            num_hidden_layers,
            num_seqs,
            num_key_value_heads,
            max_cache_len,
            head_dim,
        )
        # Contents are uninitialised until written by `update`.
        self.k_cache = torch.empty(cache_shape, device=device, dtype=dtype)
        self.v_cache = torch.empty(cache_shape, device=device, dtype=dtype)
        # Number of positions along the sequence axis that hold valid data.
        self.valid_len = 0
        self.free_indices = list(range(max_cache_len))
        self.allocated_indices = []

    def num_allocated(self):
        """Return the number of slots currently allocated."""
        return len(self.allocated_indices)

    def allocate(self, sequence_length):
        """Reserve ``sequence_length`` slots and return their indices.

        Slots are taken from the front of ``free_indices`` and appended to
        ``allocated_indices`` in the same order.  ``valid_len`` is not
        changed here; it only advances in :meth:`update`.

        Raises:
            ValueError: If ``sequence_length`` is negative or exceeds the
                number of free slots.
        """
        # A negative count would otherwise slice "all but the last k" slots.
        if sequence_length < 0:
            raise ValueError("sequence_length must be non-negative")
        if len(self.free_indices) < sequence_length:
            raise ValueError("Not enough space in cache")
        indices = self.free_indices[:sequence_length]
        self.free_indices = self.free_indices[sequence_length:]
        self.allocated_indices = self.allocated_indices + indices
        return indices

    def free(self, indices, unfree_indices=None):
        """Release ``indices`` and compact ``unfree_indices`` onto the prefix.

        ``valid_len`` is first reduced by ``len(indices) +
        len(unfree_indices)``; the cache entries of ``unfree_indices`` are
        then copied (for every layer) to the positions starting at the new
        ``valid_len``, which is advanced past them again.  The entries of
        ``unfree_indices`` stay allocated.

        All arguments are validated before any state is modified, so a
        failed call leaves the cache untouched.

        Args:
            indices: Allocated slots to hand back to the free list.
            unfree_indices: Allocated slots whose contents must be kept.
                ``None`` (the default) means there are none.

        Raises:
            KeyError: If an entry of ``unfree_indices`` is not allocated.
            ValueError: If an entry of ``indices`` is not allocated
                (including an index listed more often than it is allocated).
        """
        indices = list(indices)
        kept = [] if unfree_indices is None else list(unfree_indices)

        # Position inside `allocated_indices` before anything is removed;
        # this is used as the source position in the cache tensors.
        slot_of = {cache_idx: slot for slot, cache_idx in enumerate(self.allocated_indices)}
        kept_slots = [slot_of[cache_idx] for cache_idx in kept]

        remaining = list(self.allocated_indices)
        for idx in indices:
            try:
                remaining.remove(idx)
            except ValueError:
                raise ValueError(f"Index {idx} is not allocated") from None

        # Commit only after every index has been validated.
        self.allocated_indices[:] = remaining
        self.free_indices.extend(indices)
        self.valid_len -= len(indices) + len(kept)

        if kept_slots:
            dest = slice(self.valid_len, self.valid_len + len(kept_slots))
            # Indexing with a list yields a copy, so overlapping source and
            # destination ranges are safe.
            self.k_cache[:, :, :, dest, :] = self.k_cache[:, :, :, kept_slots, :]
            self.v_cache[:, :, :, dest, :] = self.v_cache[:, :, :, kept_slots, :]
            self.valid_len += len(kept_slots)

    def update(self, key_states, value_states, layer_idx, indices, new_indices, cache_kwargs):
        """Append new key/value states for one layer and return the prefix.

        Args:
            key_states: 4-D tensor unpacked as ``(batch, heads, seq,
                head_dim)``; ``batch`` must be 1.
            value_states: Tensor written to the same positions as
                ``key_states``.
            layer_idx: Layer whose cache is updated.
            indices: Unused by this implementation.
            new_indices: Unused by this implementation.
            cache_kwargs: Unused by this implementation.

        Returns:
            ``(k_cache, v_cache)``: slice views into the cache tensors
            covering positions ``[0, valid_len + seq)`` of ``layer_idx``.

        Raises:
            ValueError: If the new states do not fit into the cache.
        """
        b, _, s, _ = key_states.shape
        assert b == 1, "FastKVCache.update only supports a batch size of 1"

        start = self.valid_len
        end = start + s
        # Without this check an out-of-range slice is silently clipped: a
        # single new token would be dropped and `valid_len` would run past
        # the end of the buffer.
        if end > self.k_cache.shape[3]:
            raise ValueError("Not enough space in cache")

        self.k_cache[layer_idx, :, :, start:end, :] = key_states
        self.v_cache[layer_idx, :, :, start:end, :] = value_states
        k_cache = self.k_cache[layer_idx, :, :, :end, :]
        v_cache = self.v_cache[layer_idx, :, :, :end, :]

        # Every layer writes at the same offset; advance it only once the
        # last layer has been processed.
        if self.num_hidden_layers == layer_idx + 1:
            self.valid_len = end

        return (
            k_cache,
            v_cache,
        )

    def get_seq_length(self):
        """Always return 0.

        NOTE(review): looks like a placeholder kept for interface
        compatibility with callers that query a cache length -- confirm.
        """
        return 0
