import os
import hashlib
from collections import OrderedDict
from typing import Dict, Tuple, Optional
from typing import List, Optional, Tuple, Union
import torch
from transformers.cache_utils import Cache
from copy import deepcopy
import threading
from time import time
import multiprocessing
from collections import deque
import copy
from transformers import AutoTokenizer
import random
import math

from src.cache_utils import DynamicCache
# Per-layer digest of a cached key tensor as stored by
# KVCacheManager.save_digest_cache: either a pair of tensors
# ("bounding_cuboid": (maxs, mins); "bounding_sphere": (center, radius))
# or a single tensor ("mean_digest").
Digest = Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]

def create_cache_copy(existing_cache: DynamicCache) -> DynamicCache:
    """Return a new ``DynamicCache`` holding clones of ``existing_cache``'s tensors.

    The ``_seen_tokens`` counter is copied as-is, and every tensor in
    ``key_cache`` and ``value_cache`` is cloned, so later in-place edits of
    the copy's tensors do not touch the tensors of the source cache.
    """
    duplicate = DynamicCache()
    duplicate._seen_tokens = existing_cache._seen_tokens

    cloned_keys = []
    for layer_keys in existing_cache.key_cache:
        cloned_keys.append(layer_keys.clone())
    duplicate.key_cache = cloned_keys

    cloned_values = []
    for layer_values in existing_cache.value_cache:
        cloned_values.append(layer_values.clone())
    duplicate.value_cache = cloned_values

    return duplicate

class KVCacheManager:
    """Keep CPU copies of per-request KV caches and retrieve the relevant ones.

    Every saved request is identified by the SHA-256 hash of its token ids.
    For each request the manager stores a CPU copy of the KV cache and,
    separately, a per-layer "digest" of the key tensors. At retrieval time
    the digests are scored against the query and the KV tensors of the
    best-scoring requests are concatenated along the sequence axis.

    Supported ``digest_strategy`` values: ``"bounding_cuboid"``,
    ``"bounding_sphere"``, ``"mean_digest"``.
    Supported ``mts_strategy`` values: ``"RR_SUM"``, ``"RR_MAX"``,
    ``"SOFTMAX_SUM"``, ``"SOFTMAX_MAX"``, ``"MEAN_POOL"``, ``"MAX_POOL"``.
    """

    _DIGEST_STRATEGIES = ("bounding_cuboid", "bounding_sphere", "mean_digest")
    # Constant added to the rank in the reciprocal-rank scores 1 / (k + rank).
    _RRF_K = 60

    def __init__(self,
                 memory_capacity: int = 2000,
                 hbm_capacity: int = 20,
                 disk_path: Optional[str] = "",
                 block_size: int = 64,
                 topk_threshold: float = 0.9,
                 mts_strategy: str = "RR_MAX",
                 digest_strategy: str = "bounding_cuboid",
                 log_dir: Optional[str] = None):
        """Initialise empty stores and the timing log.

        Args:
            memory_capacity: Stored on the instance; not used by this class.
            hbm_capacity: Stored on the instance; not used by this class.
            disk_path: Stored on the instance; not used by this class.
            block_size: Stored on the instance; also part of the default
                log directory name.
            topk_threshold: Only used to build the default log directory name.
            mts_strategy: How per-query scores are merged into one score per
                stored request (see ``related_score_eval``).
            digest_strategy: How key tensors are summarised
                (see ``save_digest_cache``).
            log_dir: Directory for ``timing.log``. When ``None`` a name is
                derived from ``block_size``, ``topk_threshold`` and
                ``mts_strategy``.

        Side effects:
            Creates the log directory and, if ``timing.log`` does not exist
            yet, writes the CSV header line to it.
        """
        self.disk_path = disk_path
        self.hbm_capacity = hbm_capacity
        self.memory_capacity = memory_capacity
        self.mts_strategy = mts_strategy
        self.digest_strategy = digest_strategy
        self.block_size = block_size

        self.history_token_cnt = 0
        self.input_ids_list: List[List[int]] = []
        # Hash keys in insertion order; retrieval concatenates in this order.
        self.hash_key_list: List[str] = []
        self.kv_cache: Dict[str, DynamicCache] = {}
        self.memory_digest_cache: Dict[str, List[Digest]] = {}
        # Incremented on every layer-0 retrieval, so the first request is 0.
        self.retrieve_cnt = -1

        # Logging setup
        if log_dir is None:
            self.log_dir = f"logs/block_size_{block_size}_threshold_{topk_threshold}_mts_{self.mts_strategy}"
        else:
            self.log_dir = log_dir
        os.makedirs(self.log_dir, exist_ok=True)
        self.time_log_file = os.path.join(self.log_dir, "timing.log")

        # Existing logs are appended to; the header is written only once.
        if not os.path.exists(self.time_log_file):
            with open(self.time_log_file, "w") as f:
                f.write("request_id,total_forward_time,total_topk_time,total_concat_time,total_attn_time,other_time\n")

        # Time accumulators for the current request
        self.current_topk_time = 0
        self.current_concat_time = 0
        self.current_attn_time = 0

    def _generate_key(self, token_ids: List[int]) -> str:
        """Return the SHA-256 hex digest of ``str(token_ids)``."""
        return hashlib.sha256(str(token_ids).encode()).hexdigest()

    def get_cache(self,
                  token_ids: List[int]
                  ) -> Optional["Cache"]:
        """Return the cache stored for ``token_ids``, or ``None`` if absent."""
        return self.kv_cache.get(self._generate_key(token_ids))

    def is_exist(self,
                 token_ids: List[int]
                 ) -> bool:
        """Return whether a cache has been saved for ``token_ids``."""
        return self._generate_key(token_ids) in self.kv_cache

    def save_cache(self, token_ids: List[int], cache: "DynamicCache"):
        """Store a CPU copy of ``cache`` under the hash of ``token_ids``.

        The token ids and the hash key are appended to ``input_ids_list`` and
        ``hash_key_list``. Only the per-layer key/value tensors are copied.
        """
        self.input_ids_list.append(token_ids)
        key = self._generate_key(token_ids)
        self.hash_key_list.append(key)
        copy_cache = DynamicCache()
        for k_tensor, v_tensor in zip(cache.key_cache, cache.value_cache):
            # ``.cpu()`` returns the same tensor when it already lives on the
            # CPU, so the deepcopy is what guarantees an independent copy.
            copy_cache.key_cache.append(copy.deepcopy(k_tensor.cpu()))
            copy_cache.value_cache.append(copy.deepcopy(v_tensor.cpu()))
        self.kv_cache[key] = copy_cache

    def save_digest_cache(self, token_ids: List[int], key_cache: List[torch.Tensor]):
        """Compute and store one digest per tensor in ``key_cache``.

        Each tensor is reduced over dim 2; the original code comment gives
        the layout as (bsz, head_num, k_len, head_dim).

        * ``bounding_cuboid``: ``(center + d, center - d)`` where ``center``
          is the midpoint of the per-dimension max and min and ``d`` is the
          mean absolute deviation from that midpoint.
        * ``mean_digest``: the mean over dim 2, keeping that dim.
        * ``bounding_sphere``: ``(center, radius)`` where ``center`` is the
          mean over dim 2 and ``radius`` the largest L2 distance to it.

        Any other strategy stores an empty list, which retrieval skips.
        """
        key = self._generate_key(token_ids)
        digest_list = []

        if self.digest_strategy == "bounding_cuboid":
            for key_tensor in key_cache:
                # All operations below are out-of-place, so the caller's
                # tensor is never modified.
                centers = (key_tensor.max(dim=2).values + key_tensor.min(dim=2).values) / 2
                dists = (centers.unsqueeze(2) - key_tensor).abs().mean(dim=2)
                digest_list.append((centers + dists, centers - dists))
        elif self.digest_strategy == "mean_digest":
            for key_tensor in key_cache:
                digest_list.append(key_tensor.mean(dim=2, keepdim=True))
        elif self.digest_strategy == "bounding_sphere":
            for key_tensor in key_cache:
                center = key_tensor.mean(dim=2)
                radius = (key_tensor - center.unsqueeze(2)).norm(p=2, dim=-1).max(dim=-1).values
                digest_list.append((center, radius))

        self.memory_digest_cache[key] = digest_list

    def log_time(self, event_type: str, duration: float):
        """Add ``duration`` to the accumulator named by ``event_type``.

        Recognised event types are ``"topk_time"``, ``"concat_time"`` and
        ``"attn_time"``; anything else is ignored.
        """
        if event_type == "topk_time":
            self.current_topk_time += duration
        elif event_type == "concat_time":
            self.current_concat_time += duration
        elif event_type == "attn_time":
            self.current_attn_time += duration

    def finalize_and_log_times(self, total_forward_time: float):
        """Append one CSV row for the current request and reset accumulators.

        ``other_time`` is ``total_forward_time`` minus the three accumulated
        timings. The row is tagged with ``self.retrieve_cnt``.
        """
        other_time = total_forward_time - self.current_topk_time - self.current_concat_time - self.current_attn_time
        with open(self.time_log_file, "a") as f:
            f.write(f"{self.retrieve_cnt},{total_forward_time},{self.current_topk_time},{self.current_concat_time},{self.current_attn_time},{other_time}\n")

        # Reset accumulators
        self.current_topk_time = 0
        self.current_concat_time = 0
        self.current_attn_time = 0

    @staticmethod
    def _empty_kv(device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the pair of empty 4-D tensors that means "nothing retrieved"."""
        return (
            torch.empty(0, 0, 0, 0, device=device),
            torch.empty(0, 0, 0, 0, device=device),
        )

    def _collect_layer_digests(self, layer_idx: int) -> Tuple[List[str], List["Digest"]]:
        """Return (keys, digests) for every stored request that has a digest.

        The two lists are index-aligned, so a column index of the score
        matrix built from ``digests`` maps back to the right hash key even
        when some requests were skipped.
        """
        keys: List[str] = []
        digests = []
        for key in self.hash_key_list:
            layer_digests = self.memory_digest_cache.get(key)
            if layer_digests:
                keys.append(key)
                digests.append(layer_digests[layer_idx])
        return keys, digests

    def retrieve_related_kv(self,
                            query_vector: torch.Tensor,
                            topk_threshold: float = 0.9,
                            layer_idx: int = 0,
                            num_key_value_groups: int = 8) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the concatenated K/V tensors of the best-matching requests.

        Args:
            query_vector: Query tensor, unpacked as
                (batch, heads, q_len, head_dim) by ``related_score_eval``.
            topk_threshold: Number of stored requests to retrieve. The value
                is truncated with ``int()``, so fractional values below 1
                (including the default 0.9) select nothing.
            layer_idx: Layer whose digests and KV tensors are used. A call
                with ``layer_idx == 0`` increments ``retrieve_cnt``.
            num_key_value_groups: Forwarded to ``related_score_eval``.

        Returns:
            ``(keys, values)`` on ``query_vector.device``, concatenated along
            dim -2 in the order the requests were saved. Two empty
            ``(0, 0, 0, 0)`` tensors are returned when nothing is stored,
            nothing has a digest, or nothing is selected.

        Raises:
            ValueError: Unknown ``digest_strategy``, or the score tensor is
                not 2-D (the scorer only drops a batch dimension of size 1).

        Note:
            Only the first row of the score matrix is used for selection.
        """
        if layer_idx == 0:
            self.retrieve_cnt += 1

        device = query_vector.device
        if not self.hash_key_list or topk_threshold == 0:
            return self._empty_kv(device)

        if self.digest_strategy not in self._DIGEST_STRATEGIES:
            raise ValueError(f"Unknown digest strategy: {self.digest_strategy}")

        candidate_keys, digests = self._collect_layer_digests(layer_idx)
        if not digests:
            return self._empty_kv(device)

        if self.digest_strategy == "mean_digest":
            # Each digest already has a length-1 dim 2, so cat (not stack).
            primary = torch.cat(digests, dim=2).to(device)
            secondary = None
        else:
            primary = torch.stack([d[0] for d in digests], dim=2).to(device)
            secondary = torch.stack([d[1] for d in digests], dim=2).to(device)

        related_score_eval_st = time()
        scores = self.related_score_eval(query_vector, primary, secondary, layer_idx, num_key_value_groups)
        if scores.dim() != 2:
            raise ValueError(
                f"Expected 2-D relevance scores, got shape {tuple(scores.shape)}"
            )

        _, sorted_indices = torch.sort(scores.to(torch.float64), dim=-1, descending=True)
        max_k = int(topk_threshold)
        self.log_time("topk_time", time() - related_score_eval_st)

        concat_st = time()
        # Re-sort the chosen indices so blocks are concatenated in save order.
        selected = torch.sort(sorted_indices[0, :max_k]).values.tolist()
        if not selected:
            return self._empty_kv(device)

        k_parts = []
        v_parts = []
        for idx in selected:
            entry = self.kv_cache[candidate_keys[idx]]
            k_parts.append(entry.key_cache[layer_idx])
            v_parts.append(entry.value_cache[layer_idx])

        # Concatenate on the storage device first, then move once.
        retrieve_k = torch.cat(k_parts, dim=-2).to(device)
        retrieve_v = torch.cat(v_parts, dim=-2).to(device)
        self.log_time("concat_time", time() - concat_st)

        return retrieve_k, retrieve_v

    @staticmethod
    def _repeat_kv_heads(tensor: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat every entry of dim 1 ``n_rep`` times, keeping repeats adjacent."""
        batch, kv_heads = tensor.shape[:2]
        rest = tensor.shape[2:]
        expanded = tensor[:, :, None].expand(batch, kv_heads, n_rep, *rest)
        return expanded.reshape(batch, kv_heads * n_rep, *rest)

    def related_score_eval(self, query: torch.Tensor, max_digest: torch.Tensor, min_digest: Optional[torch.Tensor],
                           layer_idx: int = 0, num_key_value_groups: int = 8) -> torch.Tensor:
        """Score every stored request against ``query``.

        Args:
            query: Tensor unpacked as (batch, q_heads, q_len, head_dim).
            max_digest: First digest component stacked over requests,
                unpacked as (batch, kv_heads, k_len, head_dim): the upper
                bound (cuboid), the center (sphere) or the mean key.
            min_digest: Second digest component: the lower bound (cuboid,
                same shape as ``max_digest``) or the radius (sphere,
                (batch, kv_heads, k_len)). Ignored for ``mean_digest``.
            layer_idx: Unused.
            num_key_value_groups: Unused; the head repetition factor is
                derived from the tensor shapes instead.

        Returns:
            Scores with the leading dim squeezed if it has size 1. For
            ``bounding_cuboid`` and ``mean_digest`` all heads and query
            positions are merged before the reduction, leaving a single
            row; for ``bounding_sphere`` there is one row per query head.

        Raises:
            ValueError: Unknown ``digest_strategy`` or ``mts_strategy``.
        """
        if self.mts_strategy == "MEAN_POOL":
            query = query.mean(dim=2, keepdim=True)
        elif self.mts_strategy == "MAX_POOL":
            query = query.max(dim=2, keepdim=True).values

        _, q_heads, qlen, _ = query.shape
        batch, kv_heads, klen, head_dim = max_digest.shape
        n_rep = q_heads // kv_heads
        softmax = torch.nn.functional.softmax

        if self.digest_strategy == "bounding_cuboid":
            upper = self._repeat_kv_heads(max_digest, n_rep).unsqueeze(2)
            lower = self._repeat_kv_heads(min_digest, n_rep).unsqueeze(2)
            query_expand = query.unsqueeze(3)
            # Per dimension take whichever bound gives the larger product;
            # the result has shape (bsz, head_num, q_len, k_len).
            inital_scores = torch.max(query_expand * upper, query_expand * lower).sum(dim=-1)
            if torch.isinf(inital_scores).any() or torch.isnan(inital_scores).any():
                # Infinite entries (either sign) become the largest finite
                # score; NaN entries are left untouched.
                max_value = torch.max(inital_scores[torch.isfinite(inital_scores)])
                inital_scores = torch.where(torch.isinf(inital_scores), max_value, inital_scores)
            scaled_scores = inital_scores / math.sqrt(head_dim)
            scaled_scores = scaled_scores.reshape(batch, 1, q_heads * qlen, klen)
            detailed_scores = softmax(scaled_scores, dim=-1)
        elif self.digest_strategy == "bounding_sphere":
            centers = self._repeat_kv_heads(max_digest, n_rep)
            radii = self._repeat_kv_heads(min_digest, n_rep)
            query_norm = query.norm(p=2, dim=-1)

            # (q · center) + (radius * ||q||)
            dot_product = (query.unsqueeze(3) * centers.unsqueeze(2)).sum(dim=-1)
            radius_effect = radii.unsqueeze(2) * query_norm.unsqueeze(3)

            scaled_scores = (dot_product + radius_effect) / math.sqrt(head_dim)
            detailed_scores = softmax(scaled_scores, dim=-1)
        elif self.digest_strategy == "mean_digest":
            # Scaled dot product between each query and the mean key of each
            # request; heads are merged into the query axis as for the cuboid.
            means = self._repeat_kv_heads(max_digest, n_rep)
            dot_product = (query.unsqueeze(3) * means.unsqueeze(2)).sum(dim=-1)
            scaled_scores = dot_product / math.sqrt(head_dim)
            scaled_scores = scaled_scores.reshape(batch, 1, q_heads * qlen, klen)
            detailed_scores = softmax(scaled_scores, dim=-1)
        else:
            raise ValueError(f"Unknown digest strategy: {self.digest_strategy}")

        if torch.isinf(detailed_scores).any() or torch.isnan(detailed_scores).any():
            print("scores contains inf or NaN",torch.isinf(detailed_scores).any())
            nan_mask = torch.isnan(detailed_scores)
            num_nan = nan_mask.sum().item()
            total_elements = detailed_scores.numel()
            nan_percentage = (num_nan / total_elements) * 100
            print(f"NaN values make up {nan_percentage:.2f}% of the tensor.")

        if self.mts_strategy in ("RR_SUM", "RR_MAX"):
            # Double argsort turns scores into 1-based ranks per query row.
            ranks = torch.argsort(detailed_scores, dim=-1, descending=True).argsort(dim=-1) + 1
            rrf_scores = 1 / (self._RRF_K + ranks.float())
            if self.mts_strategy == "RR_SUM":
                related_score = rrf_scores.sum(dim=-2)
            else:
                related_score = rrf_scores.max(dim=-2).values
        elif self.mts_strategy == "SOFTMAX_SUM":
            related_score = detailed_scores.sum(dim=-2)
        elif self.mts_strategy == "SOFTMAX_MAX":
            related_score = detailed_scores.max(dim=-2).values
        elif self.mts_strategy == "MEAN_POOL":
            # The query was already mean-pooled above; average what is left
            # on the query axis (the merged heads for cuboid / mean digest).
            related_score = detailed_scores.mean(dim=-2)
        elif self.mts_strategy == "MAX_POOL":
            # Counterpart of MEAN_POOL for the max-pooled query.
            related_score = detailed_scores.max(dim=-2).values
        else:
            raise ValueError(f"Unknown mts_strategy: {self.mts_strategy}")

        return related_score.squeeze(0)