import os
import torch
import hnswlib
import papyfaiss
import RAIndex as RAIndex
import numpy as np
from .cache import KV_Cache
from cpu_gather_value import CPU_Value_Cache


class retrievalattention_v1_cache(KV_Cache):
    """
    KV cache for RetrievalAttention_v1.

    Per layer, the prompt is split into:

    * a GPU-resident "static pattern": the first ``static_pattern_start`` and
      the last ``static_pattern_end`` prompt tokens, followed by room for
      ``max_new_length`` decoded tokens;
    * the middle prompt tokens, whose values are offloaded to a
      ``CPU_Value_Cache`` and whose keys go into a per-layer search index
      chosen by ``index_type`` ("Flat", "IVF", "RAIndex" or "HNSW").
    """

    def __init__(
        self,
        layer_num: int,
        batch_size: int,
        max_length: int,
        num_key_value_heads: int,
        head_dim: int,
        dtype: str,
        layer_mapping: dict,
        max_new_length,
        num_heads,
        static_pattern_start,
        static_pattern_end,
        core,
        topk,
        index_type: str,
        k_dim: int,
        M_sq: int,
        M_pjbp: int,
        L_pjpq: int,
        search_L: int,
        n_centroids: int,
        quant: str,
        ef_construction: int,
        M: int
    ) -> None:
        """
        Allocate the CPU value cache and the per-layer GPU key/value buffers.

        Args:
            layer_num, batch_size, max_length, num_key_value_heads, head_dim,
                dtype, layer_mapping: forwarded to ``KV_Cache``.
                ``layer_mapping[str(layer_idx)]`` is the torch device of that
                layer's GPU buffers.
            max_new_length: number of decode slots reserved in the GPU buffers.
            num_heads: number of query heads (``group_size = num_heads // kv_head``).
            static_pattern_start: leading prompt tokens kept on GPU.
            static_pattern_end: trailing prompt tokens kept on GPU.
            core: forwarded to the index builders (``paraadd`` / RAIndex ``build``).
                Presumably a thread count; TODO confirm.
            topk: stored on the instance; not read in this class.
            index_type: "Flat", "IVF", "RAIndex" or "HNSW".
            k_dim, M_sq, M_pjbp, L_pjpq: RAIndex build parameters.
            search_L: recorded in ``ra_search_params`` for every RAIndex layer.
            n_centroids, quant: IVF index parameters.
            ef_construction, M: hnswlib ``init_index`` parameters for "HNSW".
        """
        super().__init__(layer_num, batch_size, max_length, num_key_value_heads, head_dim, dtype, layer_mapping)

        self.static_pattern_start = static_pattern_start
        self.static_pattern_end = static_pattern_end
        # Advanced by decode_update_kv_cache after every decode step, so once
        # decoding starts this is the number of tokens held in the GPU buffers.
        self.static_pattern_total = self.static_pattern_start + self.static_pattern_end

        self.max_new_length = max_new_length

        self.num_heads = num_heads
        self.group_size = self.num_heads // self.kv_head

        self.index_type = index_type
        self.topk = topk
        self.ra_search_params = []
        self.k_dim = k_dim
        self.M_sq = M_sq
        self.M_pjbp = M_pjbp
        self.L_pjpq = L_pjpq
        self.search_L = search_L
        self.n_centroids = n_centroids
        self.quant = quant
        self.ef_construction = ef_construction
        self.M = M
        self.core = core

        # One search index per layer, appended by prefill_update_kv_cache.
        self.key_cache = []

        # CPU storage for the values of the middle (offloaded) prompt tokens.
        # The capacity is max_length minus decode slots minus static tokens.
        offload_capacity = self.max_length - self.max_new_length - self.static_pattern_total
        cpu_value_cache = CPU_Value_Cache()
        cpu_value_cache.alloc(self.layer_num, self.num_heads, self.kv_head, self.head_dim, self.batch_size, offload_capacity)
        self.value_cache = cpu_value_cache

        # GPU buffers: (batch_size, static tokens + decode slots, kv_head, head_dim).
        gpu_capacity = self.static_pattern_total + self.max_new_length
        self.gpu_key_cache = self._new_gpu_buffers(gpu_capacity)
        self.gpu_value_cache = self._new_gpu_buffers(gpu_capacity)

    def _new_gpu_buffers(self, capacity):
        """Return one uninitialised (batch_size, capacity, kv_head, head_dim) tensor per layer."""
        # NOTE(review): dtype is annotated `str` in __init__ but used as a torch
        # dtype here; presumably KV_Cache converts it. Confirm.
        return [
            torch.empty(
                self.batch_size,
                capacity,
                self.kv_head,
                self.head_dim,
                dtype=self.dtype,
                device=self.layer_mapping[str(ldx)],
            )
            for ldx in range(self.layer_num)
        ]

    def prefill_update_kv_cache(self, query_states, key_states, value_states, layer_idx, start_bdx):
        """
        Store one sequence's prompt KV for one layer.

        The static-pattern tokens are copied into GPU batch row ``start_bdx``.
        The middle tokens' values are offloaded to the CPU value cache. Only
        when ``start_bdx == 0``, the middle tokens' keys are used to build this
        layer's search index, which is appended to ``self.key_cache``.

        Args:
            query_states: prompt queries, transposed on dims 1 and 2 before
                indexing. Presumably (1, seq_len, num_heads, head_dim); TODO confirm.
            key_states: (batch, seq_len, kv_head, head_dim). The batch dimension
                must be 1 because it is written into a single cache row.
            value_states: same layout as ``key_states``.
            layer_idx: layer being filled.
            start_bdx: cache batch row that receives this sequence.

        Returns:
            ``(key_states, value_states)``, unchanged.

        Raises:
            ValueError: if the prompt is shorter than the static pattern.
        """
        _, seq_len, group_num, head_dim = key_states.shape

        assert group_num == self.kv_head
        assert head_dim == self.head_dim
        if seq_len < self.static_pattern_start + self.static_pattern_end:
            raise ValueError(
                f"prompt length {seq_len} is shorter than the static pattern "
                f"({self.static_pattern_start} + {self.static_pattern_end} tokens)"
            )
        self.input_length = seq_len

        start = self.static_pattern_start
        total = self.static_pattern_total
        tail = seq_len - self.static_pattern_end  # first token of the trailing static window
        row = slice(start_bdx, start_bdx + 1)

        # 1. Static pattern: head tokens go to [0, start), tail tokens go to [start, total).
        self.gpu_key_cache[layer_idx][row, :start] = key_states[:, :start]
        self.gpu_key_cache[layer_idx][row, start:total] = key_states[:, tail:seq_len]
        self.gpu_value_cache[layer_idx][row, :start] = value_states[:, :start]
        self.gpu_value_cache[layer_idx][row, start:total] = value_states[:, tail:seq_len]

        # 2. Middle tokens' values go to the CPU cache, laid out as (kv_head, n, head_dim).
        # The input holds one sequence (batch dim 1), so index it with `:`.
        # Indexing with start_bdx would select an empty slice for rows > 0.
        offload_value = value_states[:, start:tail].transpose(1, 2).contiguous().cpu()
        # NOTE(review): the meaning of the trailing 20 is defined by
        # CPU_Value_Cache.fill (presumably a worker count). TODO confirm.
        self.value_cache.fill(layer_idx, start_bdx, offload_value.squeeze(0), 20)

        # 3. Search index over the middle keys. It is built from batch row 0 only
        # and one index per layer serves every row.
        # NOTE(review): this assumes all rows hold the same prompt; confirm.
        if start_bdx == 0:
            index_query = query_states.transpose(1, 2).contiguous()
            index_query = index_query.to(torch.float32).detach().cpu().numpy()

            # (n, kv_head, head_dim) -> (kv_head, n, head_dim)
            offload_keys = key_states[0, start:tail].transpose(0, 1).contiguous()
            offload_keys = offload_keys.detach().cpu().to(torch.float32).numpy()

            # Appending assumes layers are prefilled in order 0..layer_num-1, so
            # that key_cache[layer_idx] is this layer's index.
            self.key_cache.append(self.build_index(index_query, offload_keys, layer_idx))

        if layer_idx == self.layer_num - 1 and start_bdx == self.batch_size - 1:
            self.context += seq_len

        return key_states, value_states

    def build_index(self, index_query, index_keys, layer_idx):
        """
        Build the key-retrieval index for one layer.

        Args:
            index_query: float32 numpy array. For "Flat"/"IVF", dim 1 is the head
                count and dim 3 the feature dim. "RAIndex" uses ``index_query[0]``.
            index_keys: float32 numpy array of offloaded keys,
                (kv_head, n_tokens, head_dim) as produced by prefill_update_kv_cache.
            layer_idx: layer index; it names the on-disk directory for "RAIndex".

        Returns:
            A papyfaiss index ("Flat"/"IVF"), a ``RAIndex.LayerQHeadRAIndex``
            ("RAIndex"), or a list with one ``hnswlib.Index`` per entry of
            ``index_keys``' first dimension ("HNSW").

        Raises:
            ValueError: for an unsupported ``self.index_type``.
        """
        if self.index_type == "Flat":
            index = papyfaiss.FlatIndex(head_num=index_query.shape[1], dim=index_query.shape[3])
            index.paraadd(index_keys, self.core)
        elif self.index_type == "IVF":
            index = papyfaiss.IVFIndexSQ(
                head_num=index_query.shape[1],
                dim=index_query.shape[3],
                n_centroids=self.n_centroids,
                quant=self.quant,
                use_gpu=False,
            )
            index.set_nprobe(150, 150)
            index.paraadd(index_keys, self.core)
        elif self.index_type == "RAIndex":    # OOD index
            index = self._build_ra_index(index_query, index_keys, layer_idx)
        elif self.index_type == "HNSW":
            index = self._build_hnsw_index(index_keys)
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")
        return index

    def _build_ra_index(self, index_query, index_keys, layer_idx):
        """Load this layer's RAIndex from disk if a saved copy exists; otherwise build and save it."""
        queries = index_query[0]
        keys = index_keys

        # The native builder is given C-contiguous arrays.
        if not queries.flags['C_CONTIGUOUS']:
            print('query_states not C_CONTIGUOUS')
            queries = np.ascontiguousarray(queries)
        if not keys.flags['C_CONTIGUOUS']:
            print('key_states not C_CONTIGUOUS')
            keys = np.ascontiguousarray(keys)

        index = RAIndex.LayerQHeadRAIndex(self.num_heads, self.kv_head, self.head_dim)
        self.ra_search_params.append(self.search_L)

        # NOTE(review): the on-disk cache is keyed only by layer index, so an index
        # saved for one prompt is reloaded for any later prompt. Confirm intended.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        weight_dir = os.path.join(current_dir, f"ra_index/passkey_layer_{layer_idx}")
        if os.path.exists(weight_dir):
            index.loadAllIndex(weight_dir)
            # NOTE(review): setThreads is only called on this load path, not
            # after build(). Confirm.
            index.setThreads(index_query.shape[0])
        else:
            n_tokens = keys.shape[1]
            index.build(n_tokens, self.k_dim, n_tokens, self.M_sq, self.M_pjbp, self.L_pjpq, self.core, queries, keys)
            os.makedirs(weight_dir, exist_ok=True)
            index.saveAllIndex(weight_dir)
        return index

    def _build_hnsw_index(self, index_keys):
        """Build one inner-product hnswlib index per entry of ``index_keys``' first dimension."""
        indexes = []
        for head_keys in index_keys:
            hnsw_index = hnswlib.Index(space='ip', dim=head_keys.shape[-1])  # options: l2, cosine, ip
            # Use the configured construction parameters instead of fixed constants.
            hnsw_index.init_index(max_elements=head_keys.shape[0], ef_construction=self.ef_construction, M=self.M)
            hnsw_index.add_items(head_keys)
            indexes.append(hnsw_index)
        return indexes

    def decode_update_kv_cache(self,
        key_states,         # (bs, length(=1), group_num, dim)
        value_states,       # (bs, length(=1), group_num, dim)
        layer_idx
    ):
        """
        Write one decoded token's key/value into the GPU buffers of ``layer_idx``.

        The token goes to position ``self.static_pattern_total``. After the last
        layer, that position and ``self.context`` are advanced by one, so the
        next step writes the following slot.

        Returns:
            ``(None, None)``.
        """
        pos = self.static_pattern_total
        # Write exactly one slot for both keys and values.
        self.gpu_key_cache[layer_idx][:, pos] = key_states[:, 0]
        self.gpu_value_cache[layer_idx][:, pos] = value_states[:, 0]

        if layer_idx == self.layer_num - 1:
            self.context += 1
            self.static_pattern_total += 1

        return None, None