#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/6/8 14:03
@Author  : alexanderwu
@File    : https://github.com/geekan/MetaGPT/blob/main/metagpt/document_store/document.py
"""
from pathlib import Path

import pandas as pd
from langchain.document_loaders import (
    TextLoader,
    UnstructuredPDFLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from tqdm import tqdm


def validate_cols(content_col: str, df: pd.DataFrame):
    """Ensure that *content_col* is one of the columns of *df*.

    Args:
        content_col: Name of the column expected to hold the document text.
        df: DataFrame whose columns are checked.

    Raises:
        ValueError: If *content_col* is not present in ``df.columns``.
    """
    if content_col not in df.columns:
        # Name both the missing column and what is available so the caller
        # can spot a typo or a wrong `content_col` argument immediately.
        raise ValueError(
            f"Content column {content_col!r} not found in data; "
            f"available columns: {list(df.columns)}"
        )


def read_data(data_path: Path):
    """Load the file at *data_path* with a reader chosen by its extension.

    The extension is matched case-insensitively, and *data_path* may be a
    ``Path`` or anything ``Path()`` accepts (e.g. a ``str``).

    Args:
        data_path: Location of the file to load.

    Returns:
        A ``pandas.DataFrame`` for ``.xlsx``, ``.csv`` and ``.json`` files.
        For ``.docx``/``.doc``, ``.txt`` and ``.pdf`` files, the value
        returned by the corresponding langchain loader's ``load()`` call;
        ``.txt`` content is additionally split on newlines into chunks by
        ``CharacterTextSplitter`` (``chunk_size=256``, ``chunk_overlap=0``).

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    # Normalize once: callers may pass a str, and extensions such as ".CSV"
    # or ".Pdf" should be handled the same as their lowercase forms.
    data_path = Path(data_path)
    suffix = data_path.suffix.lower()

    if suffix == '.xlsx':
        return pd.read_excel(data_path)
    if suffix == '.csv':
        return pd.read_csv(data_path)
    if suffix == '.json':
        return pd.read_json(data_path)
    if suffix in ('.docx', '.doc'):
        return UnstructuredWordDocumentLoader(str(data_path), mode='elements').load()
    if suffix == '.txt':
        loaded = TextLoader(str(data_path)).load()
        # Plain text has no element structure, so split it into chunks here.
        text_splitter = CharacterTextSplitter(separator='\n', chunk_size=256, chunk_overlap=0)
        return text_splitter.split_documents(loaded)
    if suffix == '.pdf':
        return UnstructuredPDFLoader(str(data_path), mode="elements").load()

    raise NotImplementedError(
        f"Unsupported file type {suffix!r} for {data_path}"
    )


class Document:
    """A loaded file exposed as parallel lists of texts and metadata.

    The file is read once at construction time via ``read_data``. The result
    is either a ``pandas.DataFrame`` (tabular formats) or a list of objects
    carrying ``page_content`` and ``metadata`` attributes (loader formats);
    ``get_docs_and_metadatas`` handles both shapes.
    """

    def __init__(self, data_path, content_col='content', meta_col='metadata'):
        """Read *data_path* and remember which columns hold text and metadata.

        Args:
            data_path: File to load; a ``Path`` or anything ``Path()`` accepts.
            content_col: Column holding the document text. Only used when the
                loaded data is a DataFrame.
            meta_col: Column holding per-row metadata. Only used when the
                loaded data is a DataFrame; pass a falsy value (``None`` or
                ``''``) to emit empty metadata dicts instead.

        Raises:
            ValueError: If the data is a DataFrame without *content_col*.
            NotImplementedError: If the file extension is not supported.
        """
        # read_data dispatches on `.suffix`, which plain strings do not have,
        # so coerce here to let callers pass either a str or a Path.
        self.data = read_data(Path(data_path))
        if isinstance(self.data, pd.DataFrame):
            validate_cols(content_col, self.data)
        self.content_col = content_col
        self.meta_col = meta_col

    def _get_docs_and_metadatas_by_df(self) -> (list, list):
        """Extract texts and metadata row by row from the DataFrame.

        Returns:
            ``(docs, metadatas)`` where ``docs[i]`` is the *content_col* value
            of row ``i`` and ``metadatas[i]`` is ``{meta_col: value}`` for
            that row, or ``{}`` when ``meta_col`` is falsy.

        Raises:
            KeyError: If ``meta_col`` is set, the frame has rows, and the
                column does not exist.
        """
        df = self.data
        use_meta = bool(self.meta_col)
        # Without this check pandas raises a bare KeyError on the first row.
        # Keep the exception type but say which column is missing. Frames with
        # no rows never touched the column before, so they are left alone.
        if use_meta and len(df) and self.meta_col not in df.columns:
            raise KeyError(
                f"Metadata column {self.meta_col!r} not found in data; "
                f"available columns: {list(df.columns)}. "
                f"Pass meta_col=None to skip metadata."
            )

        docs = []
        metadatas = []
        for i in tqdm(range(len(df))):
            docs.append(df[self.content_col].iloc[i])
            if use_meta:
                metadatas.append({self.meta_col: df[self.meta_col].iloc[i]})
            else:
                metadatas.append({})

        return docs, metadatas

    def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
        """Extract texts and metadata from a list of loader documents.

        Returns:
            ``(docs, metadatas)`` built from each item's ``page_content`` and
            ``metadata`` attributes, in the original order.
        """
        data = self.data
        docs = [item.page_content for item in data]
        metadatas = [item.metadata for item in data]
        return docs, metadatas

    def get_docs_and_metadatas(self) -> (list, list):
        """Return the loaded texts and their metadata as two parallel lists.

        Raises:
            NotImplementedError: If the loaded data is neither a DataFrame
                nor a list.
        """
        if isinstance(self.data, pd.DataFrame):
            return self._get_docs_and_metadatas_by_df()
        if isinstance(self.data, list):
            return self._get_docs_and_metadatas_by_langchain()
        raise NotImplementedError(
            f"Unsupported data type: {type(self.data).__name__}"
        )
