#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/5/25 10:20
@Author  : alexanderwu
@File    : https://github.com/geekan/MetaGPT/blob/main/metagpt/document_store/faiss_store.py
"""
import pickle
from pathlib import Path
from typing import Optional

import faiss
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

from autoagents.system.const import DATA_PATH
from autoagents.system.document_store.base_store import LocalStore
from autoagents.system.document_store.document import Document
from autoagents.system.logs import logger


class FaissStore(LocalStore):
    """Local document store backed by a langchain ``FAISS`` vector store.

    The raw FAISS index is saved with ``faiss.write_index`` to the index file,
    while the rest of the store object is pickled to the store file.
    """

    def __init__(self, raw_data: Path, cache_dir=None, meta_col='source', content_col='output'):
        """Create the store.

        :param raw_data: path to the source document (JSON / XLSX, etc.).
        :param cache_dir: cache location, forwarded to ``LocalStore.__init__``.
        :param meta_col: column name passed to ``Document`` as the metadata column.
        :param content_col: column name passed to ``Document`` as the content column.
        """
        # Assigned before super().__init__() -- presumably because the base
        # initializer may call write(), which reads these attributes.
        self.meta_col = meta_col
        self.content_col = content_col
        super().__init__(raw_data, cache_dir)

    def _load(self) -> Optional["FAISS"]:
        """Load a previously persisted store from the cache files.

        :return: the unpickled langchain store with its FAISS index re-attached,
            or None when either the index file or the store file is missing.
        """
        index_file, store_file = self._get_index_and_store_fname()
        if not (index_file.exists() and store_file.exists()):
            logger.info("Missing at least one of index_file/store_file, load failed and return None")
            return None
        index = faiss.read_index(str(index_file))
        # SECURITY: pickle.load can execute arbitrary code. Only load store
        # files written by persist() from a trusted source.
        with open(str(store_file), "rb") as f:
            store = pickle.load(f)
        # persist() pickles the store with index=None; re-attach the real index.
        store.index = index
        return store

    def _write(self, docs, metadatas):
        """Embed ``docs`` with OpenAI embeddings and build a new FAISS store.

        :param docs: texts to index.
        :param metadatas: per-text metadata, aligned with ``docs``.
        :return: the newly built langchain FAISS store.
        """
        return FAISS.from_texts(docs, OpenAIEmbeddings(openai_api_version="2020-11-07"), metadatas=metadatas)

    def persist(self):
        """Write the current store to the cache files.

        The FAISS index is written to the index file. The store object is
        pickled to the store file with its ``index`` attribute temporarily
        detached, because the index has already been saved separately.
        """
        index_file, store_file = self._get_index_and_store_fname()
        store = self.store
        index = store.index
        faiss.write_index(index, str(index_file))
        store.index = None
        try:
            with open(store_file, "wb") as f:
                pickle.dump(store, f)
        finally:
            # Always re-attach the index. Otherwise a failed open()/dump()
            # would leave the in-memory store without its index.
            store.index = index

    def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs):
        """Return the ``k`` documents most similar to ``query`` as one string.

        :param query: query text.
        :param expand_cols: if True, render each hit as ``content: metadata``.
        :param sep: separator placed between hits.
        :param k: number of hits to retrieve.
        :return: the rendered hits joined with ``sep``.

        Extra positional/keyword arguments are accepted and ignored.
        """
        rsp = self.store.similarity_search(query, k=k)
        logger.debug(rsp)
        if expand_cols:
            return sep.join(f"{x.page_content}: {x.metadata}" for x in rsp)
        return sep.join(f"{x.page_content}" for x in rsp)

    def write(self):
        """Initialize the index and store from the user-supplied Document file (JSON / XLSX, etc.).

        Builds the store from ``raw_data``, persists it, and returns it.

        :raises FileNotFoundError: if ``raw_data`` does not exist.
        """
        if not self.raw_data.exists():
            raise FileNotFoundError(f"Raw data file not found: {self.raw_data}")
        doc = Document(self.raw_data, self.content_col, self.meta_col)
        docs, metadatas = doc.get_docs_and_metadatas()

        self.store = self._write(docs, metadatas)
        self.persist()
        return self.store

    def add(self, texts: list[str], *args, **kwargs) -> list[str]:
        """Add ``texts`` to the in-memory store and return the result of ``add_texts``.

        FIXME: the store is currently not updated after add(). Presumably this
        means the cache files are not rewritten; call persist() to save changes.
        """
        return self.store.add_texts(texts)

    def delete(self, *args, **kwargs):
        """Not supported: langchain currently provides no delete interface.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError


if __name__ == '__main__':
    # Demo: load (or build) the store for a sample dataset, query it, add a
    # few extra texts in memory, then repeat the same query.
    store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
    extra_texts = [f'油皮洗面奶-{i}' for i in range(3)]
    logger.info(store.search('油皮洗面奶'))
    store.add(extra_texts)
    logger.info(store.search('油皮洗面奶'))
