#-*- coding:utf-8 -*-

from typing import Optional, List, Dict
from faiss import write_index, read_index
import numpy as np
import pickle
import faiss
import os 

class RobotFAISS(object):
    """Nearest-neighbour lookup from input vectors to result vectors.

    Two files are kept side by side in ``vector_db_folder``:

    - ``index_name``: a ``faiss.IndexFlatL2`` holding the input vectors.
    - ``dict_name``: a pickled ``{row id: result vector}`` dictionary, where
      the row id is the position at which the matching input vector was
      added to the index.
    """

    def __init__(
            self,
            index_name:str, # toolhang.index
            vector_dimensions:int,
            vector_db_folder:str='./db',
        ) -> None:
        """Record file names and settings; nothing is read or written here.

        Args:
            index_name: File name of the FAISS index, e.g. ``toolhang.index``.
            vector_dimensions: Dimensionality of the indexed input vectors.
            vector_db_folder: Folder that holds the index and the pickle.
        """
        self.index_name = index_name
        dict_name = index_name.replace(".index", ".pkl")
        if dict_name == index_name:
            # No ".index" in the name: without this, the pickle would share
            # the index's path and the two files would overwrite each other.
            dict_name = index_name + ".pkl"
        self.dict_name = dict_name
        self.vector_dimensions = vector_dimensions
        self.vector_db_folder = vector_db_folder

        self.index : Optional["faiss.IndexFlatL2"] = None
        self.vector_dict : Dict[int, np.ndarray] = {}

    def _index_path(self) -> str:
        """Return the path of the index file inside the db folder."""
        return os.path.join(self.vector_db_folder, self.index_name)

    def _dict_path(self) -> str:
        """Return the path of the pickled dictionary inside the db folder."""
        return os.path.join(self.vector_db_folder, self.dict_name)

    def _ensure_folder(self) -> None:
        """Create the db folder if it is missing, so writes do not fail."""
        # An empty folder means "current directory"; os.makedirs('') raises.
        if self.vector_db_folder:
            os.makedirs(self.vector_db_folder, exist_ok=True)

    def initialize_index(self) -> None:
        """Create an empty flat L2 index, write it to disk and keep it."""
        index = faiss.IndexFlatL2(self.vector_dimensions)
        self._ensure_folder()
        write_index(index, self._index_path())
        self.index = index

    def initialize_dict(self) -> None:
        """Create an empty result dictionary, write it to disk and keep it."""
        vector_dict : Dict[int, np.ndarray] = {}
        self._ensure_folder()
        with open(self._dict_path(), 'wb') as f:
            pickle.dump(vector_dict, f)
        self.vector_dict = vector_dict

    def load_index(self) -> "faiss.IndexFlatL2":
        """Read the index file from disk and return it (``self`` is unchanged)."""
        return read_index(self._index_path())

    def load_dict(self) -> Dict[int, np.ndarray]:
        """Read the pickled dictionary from disk and return it."""
        # SECURITY: pickle.load can execute arbitrary code. Only load files
        # that this class wrote or that come from a trusted source.
        with open(self._dict_path(), "rb") as f:
            return pickle.load(f)

    def load(self):
        """Load both the index and the dictionary from disk into ``self``."""
        self.index = self.load_index()
        self.vector_dict = self.load_dict()
        print("Index Loaded!")

    def initialize_db(self, input_vectors:List[np.ndarray], result_vectors:List[np.ndarray]):
        """Build a fresh database from paired vectors and write it to disk.

        Any existing index/pickle with the same names is overwritten.

        Args:
            input_vectors: obs_vectors reshaped in 1D (normalized); each must
                have ``vector_dimensions`` elements.
            result_vectors: action_vectors reshaped in 1D (normalized); entry
                ``i`` is the result returned for ``input_vectors[i]``.

        Raises:
            ValueError: If the two lists differ in length, or the input
                vectors do not form an ``(n, vector_dimensions)`` array.
        """
        count = len(input_vectors)
        # Validate before touching disk: zip() would silently drop the
        # unpaired tail and leave the index and the dictionary out of sync.
        if count != len(result_vectors):
            raise ValueError(
                f"input_vectors ({count}) and result_vectors "
                f"({len(result_vectors)}) must have the same length"
            )
        if count:
            vectors = np.asarray(input_vectors, dtype=np.float32)
            if vectors.ndim != 2 or vectors.shape[1] != self.vector_dimensions:
                raise ValueError(
                    f"input_vectors must have shape (n, {self.vector_dimensions}), "
                    f"got {vectors.shape}"
                )
        else:
            vectors = None

        self.initialize_index()
        self.initialize_dict()

        # Create Dict: row id i in the index maps to result_vectors[i].
        self.vector_dict.update(enumerate(result_vectors))
        with open(self._dict_path(), 'wb') as f:
            pickle.dump(self.vector_dict, f)

        # Create Index
        if vectors is not None:
            self.index.add(np.ascontiguousarray(vectors))
        write_index(self.index, self._index_path())

        print("Done!")

    def search(self, query_vector:np.ndarray, k:int) -> List[np.ndarray]:
        """Return the result vectors of the ``k`` nearest stored inputs.

        Args:
            query_vector: Query with ``vector_dimensions`` elements; any shape
                is accepted and flattened into a single row.
            k: Number of neighbours to request.

        Returns:
            Result vectors in the order the index returned them, nearest
            first. Fewer than ``k`` entries are returned when the index holds
            fewer than ``k`` vectors.

        Raises:
            RuntimeError: If no index has been loaded or initialized yet.
        """
        if self.index is None:
            raise RuntimeError(
                "Index is not loaded; call load() or initialize_db() first."
            )
        # FAISS expects a C-contiguous float32 matrix of shape (n_queries, d).
        query = np.ascontiguousarray(query_vector, dtype=np.float32).reshape(1, -1)
        _, indices = self.index.search(query, k)
        # FAISS pads missing neighbours with label -1; skip those.
        return [self.vector_dict[int(i)] for i in indices[0] if i >= 0]