"""CacheEngine class for managing the KV cache."""
import logging
import time
from typing import List

import torch

from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, FlashCacheConfig,
                         ModelConfig, ParallelConfig)
from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, Device,
                        get_dtype_size, is_pin_memory_available)

logger = init_logger(__name__)


class CacheEngine:
    """Manages the KV cache.

    This class is responsible for initializing and managing the GPU and CPU KV
    caches and, when ``flash_cache_config.num_fc_blocks`` is positive, a flash
    (SSD) cache that can be used as a swap target. It also provides methods
    for performing KV cache operations, such as swapping and copying.

    ``swap_in`` / ``swap_out`` expect ``src_to_dst`` to start with a header
    row whose first element is a device id (``swap_device_cpu`` or
    ``swap_device_fc``); the remaining rows are the block mapping handed to
    the attention backend.
    """

    # When True, swap_in/swap_out log latency and throughput on every call.
    log_swap_stats: bool = True

    def __init__(
        self,
        cache_config: CacheConfig,
        flash_cache_config: FlashCacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        device_config: DeviceConfig,
    ) -> None:
        self.cache_config = cache_config
        self.flash_cache_config = flash_cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.device_config = device_config

        self.head_size = model_config.get_head_size()
        # Models like Jamba, have mixed typed layers, E.g Mamba
        self.num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size

        # Divide the configured block counts across pipeline-parallel stages.
        pp_size = parallel_config.pipeline_parallel_size

        self.num_gpu_blocks = cache_config.num_gpu_blocks
        if self.num_gpu_blocks:
            self.num_gpu_blocks //= pp_size

        self.num_cpu_blocks = cache_config.num_cpu_blocks
        if self.num_cpu_blocks:
            self.num_cpu_blocks //= pp_size

        self.num_fc_blocks = flash_cache_config.num_fc_blocks
        if self.num_fc_blocks:
            self.num_fc_blocks //= pp_size

        if cache_config.cache_dtype == "auto":
            self.dtype = model_config.dtype
        else:
            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # Get attention backend.
        self.attn_backend = get_attn_backend(self.head_size,
                                             model_config.dtype,
                                             cache_config.cache_dtype,
                                             self.block_size,
                                             model_config.is_attention_free)

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")

        # HiFC: device ids carried in the header row of swap requests.
        self.swap_device_gpu = 0
        self.swap_device_cpu = 1
        self.swap_device_fc = 2

        # Flash cache state. Always defined so cleanup() and the swap paths
        # can check it even when the flash cache is disabled.
        self.file_handle = None
        self.lba_offsets: List[List[int]] = []
        self.use_fc_for_swap = (self.num_fc_blocks is not None
                                and self.num_fc_blocks > 0)
        if self.use_fc_for_swap:
            self._init_flash_cache()

    def _init_flash_cache(self) -> None:
        """Compute per-layer SSD byte offsets, open the SSD file, and build
        the tuning-parameter tensor passed to the flash swap kernel."""
        logger.info("Initializing flash cache...")
        start_init_time = time.perf_counter()

        # Bytes of one block of one layer for K (or V).
        block_size_bytes = (self.num_kv_heads * self.block_size *
                            self.head_size * get_dtype_size(self.dtype))
        # Bytes of all flash blocks of one layer for K (or V).
        region_bytes = self.num_fc_blocks * block_size_bytes

        # File layout: for each layer, a K region followed by a V region,
        # so layer `l`, kv_type `t` (0: K, 1: V) starts at
        # (2 * l + t) * region_bytes.
        self.lba_offsets = [
            [(2 * layer + kv_type) * region_bytes for kv_type in range(2)]
            for layer in range(self.num_attention_layers)
        ]

        # K and V regions for every attention layer, in bytes.
        total_size = 2 * self.num_attention_layers * region_bytes
        logger.info("- Total Size (GiB): %.2f GiB", total_size / (1 << 30))

        # Open SSD file for flash cache
        self.file_handle = self.attn_backend.init_gds_for_flash_cache(
            self.flash_cache_config.fc_path)
        logger.info("SSD file opened with handle: %s", self.file_handle)

        device = (self.gpu_cache[0].device
                  if self.gpu_cache else self.device_config.device_type)
        self.tuning_params = torch.zeros(4, device=device)
        self.tuning_params[0] = self.flash_cache_config.num_threads
        self.tuning_params[1] = self.flash_cache_config.io_size_bytes
        # [2] / [3]: K / V byte offsets of the layer being transferred;
        # rewritten for each layer by _swap_flash_blocks().
        self.tuning_params[2] = 0
        self.tuning_params[3] = 0

        # NOTE(review): tuning_params uses the default float dtype (normally
        # float32), which cannot represent every large integer exactly. The
        # dtype is left unchanged because the kernel's expected dtype is not
        # visible here; instead, warn if any offset would be rounded.
        all_offsets = torch.tensor(
            [off for pair in self.lba_offsets for off in pair],
            dtype=torch.float64)
        round_trip = all_offsets.to(self.tuning_params.dtype).to(torch.float64)
        if not torch.equal(round_trip, all_offsets):
            logger.warning(
                "Flash cache byte offsets are not exactly representable in "
                "tuning_params dtype %s; flash I/O may use rounded offsets.",
                self.tuning_params.dtype)

        logger.info("Tuning parameters:")
        logger.info("- Threads: %s", self.flash_cache_config.num_threads)
        logger.info("- IO size: %.2f MB",
                    self.flash_cache_config.io_size_bytes / (1024 * 1024))

        init_time = (time.perf_counter() - start_init_time) * 1000
        logger.info("HiFC initialization time: %.2f ms", init_time)

    def cleanup(self):
        """Release the flash cache file handle if one was opened."""
        if self.use_fc_for_swap and self.file_handle is not None:
            logger.info("Cleaning up flash cache resources...")
            start_time = time.perf_counter()

            # Reset GDS for flash cache
            self.attn_backend.reset_gds_for_flash_cache(self.file_handle)
            self.file_handle = None
            logger.info("SSD file closed")

            cleanup_time = (time.perf_counter() - start_time) * 1000
            logger.info("Flash cache cleanup time: %.2f ms", cleanup_time)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
        device: str,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on the specified device."""
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_attention_layers):
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
            kv_cache.append(
                torch.zeros(kv_cache_shape,
                            dtype=self.dtype,
                            pin_memory=pin_memory,
                            device=device))
        return kv_cache

    def _swap_flash_blocks(self, src_to_dst: torch.Tensor,
                           is_gpu_to_flash: bool) -> None:
        """Transfer blocks between every layer's GPU cache and flash.

        Before each per-layer backend call, that layer's K/V byte offsets are
        written into ``tuning_params[2]`` / ``[3]``. Reads and writes
        therefore address the same region of the SSD file.

        Raises:
            ValueError: If the flash cache is not enabled.
        """
        if not self.use_fc_for_swap:
            raise ValueError(
                "Flash cache swap requested but flash cache is disabled "
                f"(num_fc_blocks={self.num_fc_blocks})")
        for layer_cache, (key_offset, value_offset) in zip(
                self.gpu_cache, self.lba_offsets):
            self.tuning_params[2] = key_offset
            self.tuning_params[3] = value_offset
            self.attn_backend.swap_flash_blocks(self.file_handle,
                                                layer_cache,
                                                src_to_dst,
                                                self.tuning_params,
                                                is_gpu_to_flash=is_gpu_to_flash)

    @staticmethod
    def _report_swap_timing(tag: str, direction: str, start_time: float,
                            num_blocks: int) -> None:
        """Log latency, block count and throughput of one swap call."""
        # NOTE(review): this is host wall time; if the backend kernels run
        # asynchronously it may reflect launch time, not transfer completion.
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        logger.info("[%s] Total swap %s latency: %.2f ms", tag, direction,
                    elapsed_ms)
        logger.info("[%s] Total blocks: %d", tag, num_blocks)
        if elapsed_ms > 0:
            logger.info("[%s] Performance: %.2f blocks/s", tag,
                        num_blocks / (elapsed_ms / 1000))

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks into the GPU cache from the CPU or flash cache.

        Args:
            src_to_dst: Row 0 is a header whose first element is the source
                device id (``swap_device_cpu`` or ``swap_device_fc``). The
                remaining rows are passed to the backend as the block mapping.

        Raises:
            ValueError: If the source device id is invalid, or flash is
                requested while the flash cache is disabled.
        """
        log_enabled = self.log_swap_stats
        start_time = time.perf_counter() if log_enabled else 0.0

        # Split the header row (source device) from the block mapping.
        swap_in_src = src_to_dst[0]
        src_device = swap_in_src[0].item()
        src_to_dst = src_to_dst[1:]

        if src_device == self.swap_device_fc:
            self._swap_flash_blocks(src_to_dst, is_gpu_to_flash=False)
        elif src_device == self.swap_device_cpu:
            for cpu_layer, gpu_layer in zip(self.cpu_cache, self.gpu_cache):
                self.attn_backend.swap_blocks(cpu_layer, gpu_layer, src_to_dst)
        else:
            raise ValueError(f"Invalid swap in source: {swap_in_src}")

        if log_enabled:
            self._report_swap_timing("SWAP_IN", "in", start_time,
                                     src_to_dst.shape[0])

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Swap blocks out of the GPU cache to the CPU or flash cache.

        Args:
            src_to_dst: Row 0 is a header whose first element is the
                destination device id (``swap_device_cpu`` or
                ``swap_device_fc``). The remaining rows are
                ``(src_block_id, dst_block_id)`` pairs.

        Raises:
            ValueError: If the destination device id is invalid, or flash is
                requested while the flash cache is disabled.
        """
        log_enabled = self.log_swap_stats
        start_time = time.perf_counter() if log_enabled else 0.0

        # Split the header row (destination device) from the block mapping.
        swap_out_dst = src_to_dst[0]
        dst_device = swap_out_dst[0].item()
        logger.debug("[SWAP_OUT] Swap out destination: %s", dst_device)
        src_to_dst = src_to_dst[1:]

        if dst_device == self.swap_device_fc:
            # Per-layer offsets are set inside the helper, matching swap_in.
            self._swap_flash_blocks(src_to_dst, is_gpu_to_flash=True)
        elif dst_device == self.swap_device_cpu:
            for gpu_layer, cpu_layer in zip(self.gpu_cache, self.cpu_cache):
                self.attn_backend.swap_blocks(gpu_layer, cpu_layer, src_to_dst)
        else:
            raise ValueError(f"Invalid swap out destination: {swap_out_dst}")

        if log_enabled:
            self._report_swap_timing("SWAP_OUT", "out", start_time,
                                     src_to_dst.shape[0])
            # Dumping every block id is O(n); only do it when debug logging
            # is actually on.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("[SWAP_OUT] src block ids: %s",
                             src_to_dst[:, 0].tolist())
                logger.debug("[SWAP_OUT] dst block ids: %s",
                             src_to_dst[:, 1].tolist())

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy blocks within the GPU cache via the attention backend."""
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block, summed over the K and
        V caches of every attention layer."""
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

        key_cache_block = cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_attention_layers * (key_cache_block + value_cache_block)
        if cache_config.cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
        dtype_size = get_dtype_size(dtype)
        return dtype_size * total
