#coding:utf8
import os
import faiss
import numpy as np
from tqdm import tqdm
import json
import shutil
import subprocess
from encoder_inference import InferenceModelMultiGPU

class Indexer():
    """Encode a corpus and persist it as a FAISS inner-product index.

    An index lives in a directory holding two files: ``index`` (the
    serialized FAISS index) and ``docid`` (one document id per line, written
    in the same order as the vectors were added to the index).
    """

    # File names that together make up a complete index directory.
    _INDEX_FILE = 'index'
    _DOCID_FILE = 'docid'
    # Upper bound on the number of inverted lists of the "fast" (IVF) index.
    _IVF_NLIST = 256

    def __init__(self, encoder=None, faiss_gpu=False):
        """Store the encoder and whether FAISS should build on GPUs.

        Args:
            encoder: Object whose ``encode(corpus, batch_size=..., is_query=False)``
                result is indexable by ``'dense_embeddings'``.
            faiss_gpu: If true, train/add on all GPUs and copy the result back
                to CPU before writing it.
        """
        self.faiss_gpu = faiss_gpu
        self.model = encoder

    def stop(self):
        """Call ``stop()`` on the encoder when it is an ``InferenceModelMultiGPU``."""
        if isinstance(self.model, InferenceModelMultiGPU):
            self.model.stop()

    def check_all_index_exists(self, index_save_dir):
        """Return True only if ``index_save_dir`` holds a complete index.

        Both the ``index`` and the ``docid`` file must be present. Checking the
        directory alone is not enough: an empty or half-written directory would
        otherwise be mistaken for a finished index and never be rebuilt.
        """
        return all(
            os.path.isfile(os.path.join(index_save_dir, name))
            for name in (self._INDEX_FILE, self._DOCID_FILE)
        )

    @staticmethod
    def _check_lengths(embeddings, docids):
        """Raise ValueError unless there is exactly one doc id per embedding row."""
        if embeddings.shape[0] != len(docids):
            raise ValueError(
                'Number of embeddings ({}) does not match number of docids ({})'.format(
                    embeddings.shape[0], len(docids)))

    def _train_and_add(self, faiss_index, embeddings):
        """Train ``faiss_index`` if it needs training, add ``embeddings``, return a CPU index."""
        if self.faiss_gpu:
            print ("Building dense index using faiss GPU")
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.shard = True
            faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
        if not faiss_index.is_trained:
            faiss_index.train(embeddings)
        faiss_index.add(embeddings)
        if self.faiss_gpu:
            # Only a CPU index is written to disk below.
            faiss_index = faiss.index_gpu_to_cpu(faiss_index)
        return faiss_index

    def _save_index(self, faiss_index, docids, index_save_dir):
        """Write the ``docid`` and ``index`` files into ``index_save_dir``."""
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(index_save_dir, exist_ok=True)
        with open(os.path.join(index_save_dir, self._DOCID_FILE), 'w', encoding='utf-8') as file:
            file.writelines(str(id_) + '\n' for id_ in docids)
        faiss.write_index(faiss_index, os.path.join(index_save_dir, self._INDEX_FILE))

    def _build_fast_dense_index(self, embeddings: np.ndarray, docids: list, index_save_dir: str, overwrite: bool):
        """Build and save an ``IVF<n>,Flat`` inner-product index.

        Args:
            embeddings: 2-D array, one row per document.
            docids: Ids aligned with the rows of ``embeddings``.
            index_save_dir: Output directory.
            overwrite: Rebuild even if a complete index is already present.

        Raises:
            ValueError: If ``docids`` and ``embeddings`` differ in length.
        """
        print('start building fast dense index')
        if not overwrite and self.check_all_index_exists(index_save_dir):
            print('Dense index already exists. Skip...')
            return
        self._check_lengths(embeddings, docids)
        print('shape', embeddings.shape)
        print('isnan', np.any(np.isnan(embeddings)))
        print('isfinite', np.all(np.isfinite(embeddings)))

        embeddings = embeddings.astype("float32")
        limits = np.finfo(embeddings.dtype)
        # Replace NaN/inf by finite values before training and adding.
        embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=limits.max, neginf=limits.min)
        # FAISS refuses to train k-means with fewer points than clusters, so
        # shrink the number of inverted lists for small corpora.
        nlist = max(1, min(self._IVF_NLIST, embeddings.shape[0]))
        faiss_index = faiss.index_factory(
            embeddings.shape[1], "IVF{},Flat".format(nlist), faiss.METRIC_INNER_PRODUCT)
        faiss_index = self._train_and_add(faiss_index, embeddings)
        self._save_index(faiss_index, docids, index_save_dir)

    def build_index(self, docids: list, corpus: list, index_save_dir='index', batch_size: int = 2, threads: int = 12, index_type: str = "dense", overwrite: bool = False):
        """Encode ``corpus`` and build the requested index in ``index_save_dir``.

        Args:
            docids: Ids aligned with ``corpus``.
            corpus: Documents handed to the encoder.
            index_save_dir: Output directory.
            batch_size: Batch size forwarded to the encoder.
            threads: Accepted for interface compatibility; not used here.
            index_type: A value containing ``'fast-dense'`` selects the IVF
                index, any other value containing ``'dense'`` the flat index.
            overwrite: Rebuild even if a complete index is already present.
        """
        if not overwrite and self.check_all_index_exists(index_save_dir):
            print('exit...')
            return
        if 'dense' not in index_type:
            # Report before encoding: otherwise the whole corpus is encoded
            # and then silently thrown away.
            print('Unsupported index_type: {!r}. Nothing to build.'.format(index_type))
            return
        embeddings = self._get_embeddings(corpus, batch_size, index_type)
        if 'fast-dense' in index_type:
            self._build_fast_dense_index(embeddings['dense_embeddings'], docids, index_save_dir, overwrite)
        else:
            print('buiding index at: ', index_save_dir)
            self._build_dense_index(embeddings['dense_embeddings'], docids, index_save_dir, overwrite)

    def _get_embeddings(self, corpus: list, batch_size: int, index_type: str):
        """Return the encoder output for ``corpus`` (``index_type`` is not used)."""
        return self.model.encode(corpus, batch_size=batch_size, is_query=False)

    def _build_dense_index(self, embeddings: np.ndarray, docids: list, index_save_dir: str, overwrite: bool):
        """Build and save a ``Flat`` inner-product index.

        Args:
            embeddings: 2-D array, one row per document.
            docids: Ids aligned with the rows of ``embeddings``.
            index_save_dir: Output directory.
            overwrite: Rebuild even if a complete index is already present.

        Raises:
            ValueError: If ``docids`` and ``embeddings`` differ in length.
        """
        if not overwrite and self.check_all_index_exists(index_save_dir):
            print('Dense index already exists. Skip...')
            return
        self._check_lengths(embeddings, docids)
        faiss_index = faiss.index_factory(embeddings.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT)
        embeddings = embeddings.astype("float32")
        faiss_index = self._train_and_add(faiss_index, embeddings)
        self._save_index(faiss_index, docids, index_save_dir)



