import faiss
import faiss.contrib.torch_utils
import time
import logging

import torch
import numpy as np

# Number of product-quantizer sub-quantizers: passed as the ``M`` argument of
# faiss.IndexIVFPQ in Datastore.train_index (with 8 bits per sub-quantizer, so
# each compressed key is presumably code_size bytes). NOTE(review): faiss PQ
# expects the key dimension to be divisible by this value — confirm for the
# dims in use.
code_size = 64

class DatastoreBatch():
    """A fixed-size group of independent ``Datastore`` objects, one per batch element.

    Every operation is forwarded, in order, to the per-example datastores; the
    search methods stack the per-example results along a new leading (batch)
    dimension with ``torch.stack(..., dim=0)``.
    """

    def __init__(self, dim, batch_size, flat_index=False, gpu_index=False, verbose=False, index_device=None) -> None:
        self.batch_size = batch_size
        if index_device is None:
            index_device = torch.device('cuda' if gpu_index else 'cpu')
        self.device = index_device
        self.indices = [
            Datastore(dim, use_flat_index=flat_index, gpu_index=gpu_index, verbose=verbose, device=self.device)
            for _ in range(batch_size)
        ]

    def _stores(self):
        """Yield the first ``batch_size`` datastores, in batch order."""
        for position in range(self.batch_size):
            yield self.indices[position]

    def move_to_gpu(self):
        """Ask every per-example datastore to move its index to the GPU."""
        for store in self._stores():
            store.move_to_gpu()

    def add_keys(self, keys, num_keys_to_add_at_a_time=100000):
        """Hand ``keys[i]`` to the i-th datastore's ``add_keys``."""
        for position, store in enumerate(self._stores()):
            store.add_keys(keys[position], num_keys_to_add_at_a_time)

    def train_index(self, keys):
        """Train each datastore on its own slice of ``keys`` (stops at the shorter of the two)."""
        for store, per_example_keys in zip(self.indices, keys):
            store.train_index(per_example_keys)

    def search(self, queries, k):
        """Search each datastore with its own queries and stack (scores, values) over the batch."""
        per_example = [store.search(queries[position], k) for position, store in enumerate(self._stores())]
        stacked_scores = torch.stack([score for score, _ in per_example], dim=0)
        stacked_values = torch.stack([value for _, value in per_example], dim=0)
        return stacked_scores, stacked_values

    def search_and_reconstruct(self, queries, k):
        """Like :meth:`search`, but also stacks the reconstructed key vectors."""
        per_example = [
            store.search_and_reconstruct(queries[position], k)
            for position, store in enumerate(self._stores())
        ]
        stacked_scores = torch.stack([score for score, _, _ in per_example], dim=0)
        stacked_values = torch.stack([value for _, value, _ in per_example], dim=0)
        stacked_vectors = torch.stack([vector for _, _, vector in per_example], dim=0)
        return stacked_scores, stacked_values, stacked_vectors

class Datastore():
    """Nearest-neighbour datastore over key vectors, scored by inner product.

    Two back-ends, selected by ``use_flat_index``:

    * flat (``True``): ``self.keys`` is searched exactly with ``faiss.knn``
      (CPU) or ``faiss.knn_gpu`` (when ``gpu_index``); no faiss index is built.
    * IVF-PQ (``False``): :meth:`train_index` builds a ``faiss.IndexIVFPQ``
      over the keys and optionally moves it to the GPU.

    Args:
        dim: dimensionality of keys and queries.
        use_flat_index: use the exact (flat) back-end.
        gpu_index: search on the GPU / move the trained index to the GPU.
        verbose: accepted for interface compatibility; not used by this class.
        device: torch device for GPU work; defaults to ``cuda`` when
            ``gpu_index`` else ``cpu``.
    """

    def __init__(self, dim, use_flat_index=False, gpu_index=False, verbose=False, device=None) -> None:
        self.dimension = dim
        self.device = device if device is not None else torch.device('cuda' if gpu_index else 'cpu')
        self.logger = logging.getLogger('index_building')
        self.logger.setLevel(20)  # logging.INFO
        self.use_flat_index = use_flat_index
        self.gpu_index = gpu_index
        # faiss GPU resources, created lazily by _gpu_resources() and reused:
        # every StandardGpuResources object sets up its own GPU scratch state.
        self._gpu_res = None

        # Initialize faiss index.
        # TODO: is preprocessing efficient enough to spend time on?
        if not use_flat_index:
            # Inner product index because we use IP attention.
            self.index = faiss.IndexFlatIP(self.dimension)

        # Wrapping in faiss.IndexIDMap would enable add_with_ids; currently disabled.
        self.index_size = 0

    def _gpu_resources(self):
        """Return this datastore's faiss GPU resources, creating them on first use."""
        if getattr(self, '_gpu_res', None) is None:
            self._gpu_res = faiss.StandardGpuResources()
        return self._gpu_res

    def _gpu_device_id(self):
        """Return the CUDA ordinal of ``self.device``, or the current CUDA device if it has none."""
        ordinal = torch.device(self.device).index
        return ordinal if ordinal is not None else torch.cuda.current_device()

    def move_to_gpu(self):
        """Clone the faiss index onto the GPU selected by ``self.device``.

        No-op for the flat back-end, which searches ``self.keys`` directly.
        The clone enables faiss's ``GpuClonerOptions.useFloat16`` option.
        """
        if self.use_flat_index:
            return
        co = faiss.GpuClonerOptions()
        co.useFloat16 = True
        self.index = faiss.index_cpu_to_gpu(self._gpu_resources(), self._gpu_device_id(), self.index, co)

    def train_index(self, keys):
        """Build the search structure for ``keys`` and add them to it.

        Flat back-end: only stores ``keys``. IVF-PQ back-end: trains a fresh
        inner-product IndexIVFPQ on a CPU float32 copy of ``keys`` (about one
        coarse centroid per 128 keys), adds all keys, and moves the index to
        the GPU when ``gpu_index`` is set.
        """
        if self.use_flat_index:
            self.add_keys(keys=keys, index_is_trained=True)
            return

        keys = keys.cpu().float()
        # About one coarse centroid per 128 keys; at least one so that small
        # datastores do not ask faiss for an empty inverted file.
        ncentroids = max(1, keys.shape[0] // 128)
        # Always use a fresh CPU IP quantizer: self.index may already be an IVF
        # index (earlier train_index call) or a GPU index (earlier move_to_gpu).
        quantizer = faiss.IndexFlatIP(self.dimension)
        # The IVF-PQ metric must be passed explicitly: faiss defaults it to L2,
        # which would disagree with the IP quantizer and the flat back-end.
        self.index = faiss.IndexIVFPQ(quantizer, self.dimension,
            ncentroids, code_size, 8, faiss.METRIC_INNER_PRODUCT)
        self.index.nprobe = min(32, ncentroids)

        self.logger.info('Training index')
        start_time = time.time()
        self.index.train(keys)
        self.logger.info(f'Training took {time.time() - start_time} s')
        self.add_keys(keys=keys, index_is_trained=True)
        if self.gpu_index:
            self.move_to_gpu()

    def add_keys(self, keys, num_keys_to_add_at_a_time=1000000, index_is_trained=False):
        """Store ``keys`` and, for a trained IVF-PQ index, add them in chunks.

        ``self.keys`` is always replaced by ``keys``: the flat back-end searches
        it directly and :meth:`search` uses its length to bound returned ids.
        Vectors are only pushed into ``self.index`` when the IVF-PQ back-end is
        in use and ``index_is_trained`` is true.

        Raises:
            ValueError: if keys must be added and ``num_keys_to_add_at_a_time``
                is not positive (the chunk loop could never advance).
        """
        self.keys = keys
        if self.use_flat_index or not index_is_trained:
            return
        if num_keys_to_add_at_a_time <= 0:
            raise ValueError(f'num_keys_to_add_at_a_time must be positive, got {num_keys_to_add_at_a_time}')

        num_keys = keys.shape[0]
        start = 0
        while start < num_keys:
            end = min(num_keys, start + num_keys_to_add_at_a_time)
            self.index.add(keys[start:end])
            self.index_size += end - start
            # Continue right after the chunk just added.
            start = end
            if (start % 1000000) == 0:
                self.logger.info(f'Added {start} tokens so far')

    def search_and_reconstruct(self, queries, k):
        """Return the top-``k`` inner-product matches for ``queries`` and their key vectors.

        A 1-D query is treated as a batch of one. Returns ``(scores, values,
        vectors)`` where ``vectors`` holds the (reconstructed) key of each match.
        """
        if len(queries.shape) == 1:  # searching for only 1 vector, add one extra dim
            self.logger.info("Searching for a single vector; unsqueezing")
            queries = queries.unsqueeze(0)
        assert queries.shape[-1] == self.dimension  # query vectors are same shape as "key" vectors

        if self.use_flat_index:
            # No faiss index exists for the flat back-end: gather the matched
            # rows of self.keys instead (search() already maps invalid ids to 0).
            scores, values = self.search(queries, k)
            keys = torch.as_tensor(self.keys)
            vectors = keys[values.to(keys.device)]
            return scores, values, vectors

        # Unwrap an ID-map style wrapper (which exposes the inner index as
        # ``.index``) if present; a bare IndexIVFPQ is used directly.
        index = getattr(self.index, 'index', self.index)
        # faiss works on float32, matching the cast done in search().
        scores, values, vectors = index.search_and_reconstruct(queries.cpu().detach().float(), k)
        return scores, values, vectors

    def search(self, queries, k):
        """Return the top-``k`` inner-product ``(scores, values)`` for ``queries``.

        A 1-D query is treated as a batch of one. Ids that are negative (faiss's
        "no result" padding) or not below ``len(self.keys)`` are replaced by 0.
        """
        if len(queries.shape) == 1:  # searching for only 1 vector, add one extra dim
            self.logger.info("Searching for a single vector; unsqueezing")
            queries = queries.unsqueeze(0)
        assert queries.shape[-1] == self.dimension  # query vectors are same shape as "key" vectors

        if self.use_flat_index:
            if self.gpu_index:
                # Call modified: the ``device`` argument was removed.
                scores, values = faiss.knn_gpu(self._gpu_resources(), queries, self.keys, k,
                    metric=faiss.METRIC_INNER_PRODUCT)
            else:
                scores, values = faiss.knn(queries, self.keys, k, metric=faiss.METRIC_INNER_PRODUCT)
                scores = torch.from_numpy(scores).to(queries.dtype)
                values = torch.from_numpy(values)
        else:
            scores, values = self.index.search(queries.float(), k)

        # Avoid returning -1 (or any out-of-range id) as a value.
        # TODO: get a handle on the attention mask and mask the values that were -1
        values = torch.where(torch.logical_or(values < 0, values >= self.keys.shape[0]), torch.zeros_like(values), values)
        return scores, values

    
    