import os
import re
import shutil
import sys
import tiktoken
from pathlib import Path
from openai import OpenAI
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import TextSplitter
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document

def preprocess_document(content):
    """Normalize raw document text before it is split into chunks.

    Steps:
      1. Lowercase the whole text.
      2. Surround every run of word characters ending in ``scene`` (with at
         least one character before "scene") with newlines, consuming any
         whitespace that follows it.
      3. Collapse ``package <letter>:`` into ``package<letter>:`` by removing
         the whitespace between the keyword and a single-letter id.

    Args:
        content: Raw document text.

    Returns:
        The normalized, lowercase text.
    """
    content = content.lower()
    # NOTE(review): ``\w+scene`` needs a word character directly before
    # "scene", so spaced names such as "factory scene" are NOT matched; only
    # compounds like "officescene" are. Confirm this is intended.
    content = re.sub(r'(\w+scene)\s*', r'\n\1\n', content)
    # The text is already lowercase here. The previous case-sensitive pattern
    # ``Package\s*([A-Z])[:]`` could therefore never match, and this
    # normalization silently never ran. The replacement stays lowercase so the
    # lowercase ``package...:`` matching done later keeps working.
    content = re.sub(r'package\s*([a-z])[:]', r'package\1:', content)
    return content

class CustomTextSplitter(TextSplitter):
    """Regex-driven splitter: split on scene/package markers, then lines, then size.

    Each pattern in ``separators`` is applied in turn with ``re.split`` to every
    piece produced so far, and empty or whitespace-only pieces are dropped.
    ``re.split`` keeps text matched by capturing groups (the first default
    pattern) but discards text matched outside groups (the second default
    pattern). Any piece still longer than ``chunk_size`` characters is cut into
    ``chunk_size``-character windows that overlap by ``chunk_overlap``
    characters.

    NOTE(review): the default patterns are case-sensitive. The first two, and
    the capitalized alternatives of the third, contain uppercase letters, so
    they cannot match lowercased input such as the output of
    ``preprocess_document``. For that input the ``\\n+`` separator does the
    splitting.
    """

    def __init__(self, chunk_size=600, chunk_overlap=100, separators=None):
        """Configure the splitter.

        Args:
            chunk_size: Maximum chunk length in characters.
            chunk_overlap: Characters shared by consecutive windows of one
                over-long piece.
            separators: Regex patterns applied in order. ``None`` selects the
                built-in scene/package/newline patterns.

        Raises:
            ValueError: if ``chunk_size`` is not positive, or ``chunk_overlap``
                is negative or not smaller than ``chunk_size``. A zero window
                step would make ``split_text`` loop forever.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                f"chunk_overlap must be in [0, chunk_size), got {chunk_overlap} "
                f"with chunk_size={chunk_size}"
            )
        # Initialise the base class so its own attributes exist for any
        # inherited helper that relies on them.
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        if separators is None:
            separators = [
                r"(Factory scene | Kitchen scene | Agricultural greenhouse scene | Office scene)",
                r"Package\s*[A-Z]+[:]",
                r"\n(?=factory scene | kitchen scene | agricultural greenhouse scene | Office scene|Package\s*[A-Z]+[:])",
                r"\n+"
            ]
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators

    def split_text(self, text):
        """Split *text* into a list of non-blank chunks of at most ``chunk_size`` chars.

        Raises:
            ValueError: if ``chunk_size``/``chunk_overlap`` were changed after
                construction so that the window step is no longer positive.
        """
        pieces = [text]
        for sep in self.separators:
            next_pieces = []
            for piece in pieces:
                next_pieces.extend(re.split(sep, piece))
            # re.split yields None for capture groups that did not participate
            # in a match; drop those along with blank pieces.
            pieces = [p for p in next_pieces if p and p.strip()]

        step = self.chunk_size - self.chunk_overlap
        if step <= 0:
            raise ValueError("chunk_overlap must be smaller than chunk_size")

        chunks = []
        for piece in pieces:
            while len(piece) > self.chunk_size:
                chunks.append(piece[:self.chunk_size])
                piece = piece[step:]
            if piece.strip():
                chunks.append(piece)

        return chunks

class RAG:
    """Build a Chroma vector store from one scene/package document and query it.

    The document is normalized with ``preprocess_document`` and split with
    ``CustomTextSplitter``. Every chunk is tagged with ``scene`` and
    ``package`` metadata, so queries can be filtered to one scene and a set of
    package ids.
    """

    # Scene -> lowercase keywords that select it. Keywords are matched as whole
    # words, scenes are tried in this order, and the first hit wins.
    _SCENE_KEYWORDS = {
        "factory scene": ("factory", "production", "assembly line"),
        "kitchen scene": ("kitchen", "cooking", "ingredients"),
        "agricultural greenhouse scene": ("agriculture", "greenhouse", "planting"),
        "office scene": ("office", "meeting", "document"),
    }

    def __init__(self, document_path, model_name, persist_directory='db'):
        """Load, chunk and embed *document_path* into a fresh vector store.

        Args:
            document_path: UTF-8 text file containing the scene/package sections.
            model_name: Accepted for API compatibility. The embedding model is
                currently loaded from a fixed local path (see
                ``_setup_embeddings``).
            persist_directory: Chroma persistence directory. Any existing
                directory at this path is deleted first.
        """
        # Presumably disables TensorFlow oneDNN ops for the process, in case TF
        # is imported later. Confirm this is still needed.
        os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
        self.persist_directory = persist_directory
        self._prepare_environment()

        with open(document_path, 'r', encoding='utf-8') as f:
            raw_content = f.read()
        processed_content = preprocess_document(raw_content)

        # Split the processed text directly. Previously it was round-tripped
        # through "temp_processed.txt" in the working directory, which was
        # never removed and was shared by every RAG instance.
        self.documents = self._split_into_documents(
            [Document(page_content=processed_content)]
        )
        self.embedding = self._setup_embeddings(model_name)
        self.db = self._create_vector_store()

    def _prepare_environment(self):
        """Delete any existing ``persist_directory`` so the store starts empty."""
        if os.path.exists(self.persist_directory):
            shutil.rmtree(self.persist_directory)

    def _setup_embeddings(self, model_name):
        """Create the embedding model.

        NOTE(review): *model_name* is ignored. The model is always loaded from
        the hard-coded ``./models/m3e-base`` directory, on CPU, with normalized
        embeddings.
        """
        local_path = "./models/m3e-base"

        return HuggingFaceBgeEmbeddings(
            model_name=local_path,
            cache_folder=local_path,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )

    def _create_vector_store(self):
        """Embed ``self.documents`` into a Chroma store at ``persist_directory``."""
        return Chroma.from_documents(
            self.documents,
            self.embedding,
            persist_directory=self.persist_directory
        )

    def _load_documents(self, document_path):
        """Load a UTF-8 text file and split it into metadata-tagged chunks."""
        loader = TextLoader(document_path, encoding="utf-8")
        return self._split_into_documents(loader.load())

    def _split_into_documents(self, raw_documents):
        """Split documents into chunks tagged with ``scene``/``package`` metadata.

        Tags are "sticky": each chunk carries the most recent scene header line
        and ``package<id>:`` marker seen so far, in document order. A new scene
        header clears the package tag.
        """
        text_splitter = CustomTextSplitter(chunk_size=600, chunk_overlap=100)
        scene_re = re.compile(
            r'^\s*(factory scene|kitchen scene|agricultural greenhouse scene|office scene)\s*$',
            re.MULTILINE,
        )
        package_re = re.compile(r'package\s*([A-Za-z]+)[:]')

        split_docs = []
        current_scene = ""
        current_package = ""
        for document in raw_documents:
            for chunk in text_splitter.split_text(document.page_content):
                scene_match = scene_re.search(chunk)
                if scene_match:
                    current_scene = scene_match.group(1).strip()
                    # Without this reset, the leading chunks of a scene
                    # inherited the last package of the previous scene.
                    current_package = ""
                package_match = package_re.search(chunk)
                if package_match:
                    current_package = package_match.group(1).strip()
                split_docs.append(Document(
                    page_content=chunk,
                    metadata={"scene": current_scene, "package": current_package},
                ))
        return split_docs

    def get_formatted_context(self, instruction):
        """Retrieve chunks relevant to *instruction* and format them for a prompt.

        The search is filtered to the scene detected from the instruction and,
        when the instruction contains standalone single letters, to those
        package ids. Results are grouped under ``● Package<id>:`` headers.

        Returns:
            The formatted context. Returns "cannot find relevant scene" if the
            search raises, or "No relevant context found" if nothing matched.

        Note:
            Calls ``sys.exit`` when no scene can be detected.
        """
        instruction = instruction.lower()
        target_scene = self._detect_scene(instruction)
        if target_scene is None:
            # NOTE(review): exiting the process from a library method is harsh.
            # Consider raising instead once callers are checked.
            sys.exit("Cannot detect scene from instruction")

        # NOTE(review): every standalone single letter counts as a package id,
        # so words like "a" or "i" in the instruction also end up in the filter.
        target_packages = re.findall(r'\b[a-z]\b', instruction)

        filter_conditions = [{"scene": target_scene}]
        if target_packages:
            filter_conditions.append({"package": {"$in": target_packages}})
        # Chroma's "$and" needs at least two clauses; use a single clause as-is.
        if len(filter_conditions) > 1:
            filter_dict = {"$and": filter_conditions}
        else:
            filter_dict = filter_conditions[0]

        try:
            relevant_docs = self.db.similarity_search(
                query=instruction,
                k=200,
                filter=filter_dict
            )
        except Exception as e:
            print(f"error to search: {str(e)}")
            return "cannot find relevant scene"

        # scene -> package -> [chunk text], in retrieval order.
        grouped = {}
        for doc in relevant_docs:
            scene = doc.metadata.get("scene", "")
            package = doc.metadata.get("package", "")
            grouped.setdefault(scene, {}).setdefault(package, []).append(doc.page_content)

        context_pieces = []
        for packages in grouped.values():
            for package, contents in packages.items():
                context_pieces.append(f"\n● Package{package}:")
                context_pieces.extend(f"    {content.strip()}" for content in contents)

        return "\n".join(context_pieces) or "No relevant context found"

    def _detect_scene(self, instruction):
        """Return the scene name suggested by *instruction*, or None.

        Keywords are matched as whole words in the lowercased instruction, so
        a keyword at the very start or end still counts. The original padded,
        capitalized keywords could never match the lowercased text. If no
        keyword matches, the scene names themselves are searched for as
        substrings.
        """
        instruction = instruction.lower()
        for scene, keywords in self._SCENE_KEYWORDS.items():
            if any(re.search(rf"\b{re.escape(kw)}\b", instruction) for kw in keywords):
                return scene
        for scene in self._SCENE_KEYWORDS:
            if scene in instruction:
                return scene

        return None

class LLM():
    """Minimal wrapper around the OpenAI chat completions API using ``gpt-4o``."""
    # ChatGPT-4o
    api_key_path = ""
    api_key = ""

    def __init__(self, api_path, proxy="http://your.proxy.ip:port") -> None:
        """Read the API key from *api_path* and create the OpenAI client.

        Args:
            api_path: Path to a file whose stripped contents are the API key.
            proxy: Value written to the process-wide ``http_proxy`` and
                ``https_proxy`` environment variables. The default is a
                placeholder that must be replaced with a real proxy URL. Pass
                ``None`` or ``""`` to leave the environment untouched.

        Raises:
            Exception: if the key file is missing or cannot be read. The
                underlying error is chained as ``__cause__``.
        """
        self.api_key_path = Path(api_path)
        try:
            with open(self.api_key_path, "r", encoding="utf-8") as f:
                self.api_key = f.read().strip()
        except FileNotFoundError as e:
            raise Exception(f"API key file not found: {self.api_key_path}") from e
        except Exception as e:
            raise Exception(f"Error reading API key: {e}") from e

        if proxy:
            os.environ["http_proxy"] = proxy
            os.environ["https_proxy"] = proxy

        self.client = OpenAI(api_key=self.api_key)

        self.encoder = tiktoken.encoding_for_model("gpt-4o")

    def call(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Uses ``temperature=0.0`` and a fixed ``seed`` to keep replies as
        repeatable as possible. Prints an approximate token total computed
        locally with tiktoken from the prompt and reply text only; message
        framing overhead is not included.

        Returns:
            The assistant message content. ``message.content`` is typed
            Optional in the OpenAI SDK, so a ``None`` reply is counted as zero
            output tokens and returned unchanged.
        """
        messages = [
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        ]

        input_token_count = len(self.encoder.encode(prompt))

        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            seed=7,
            temperature=0.0
        )

        content = response.choices[0].message.content
        output_token_count = len(self.encoder.encode(content or ""))
        print(f"total tokens: {input_token_count + output_token_count}")

        return content

def get_scene_from_instruction(instruction, rag):
    """Return the scene that *rag* detects for *instruction* (None when undetected)."""
    detect = rag._detect_scene
    return detect(instruction)

def rag_chain(instruction, rag, llm, prompt_template):
    """Answer *instruction* with retrieval-augmented generation.

    Fetches context for the instruction from *rag*, fills it into the
    ``{context}`` placeholder of *prompt_template*, and returns whatever
    ``llm.call`` returns for the filled-in prompt.
    """
    filled_prompt = prompt_template.format(context=rag.get_formatted_context(instruction))
    return llm.call(filled_prompt)