import os
import json
import time
import psutil
import shutil
from tqdm import tqdm
import chromadb
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
import concurrent.futures

# --- Global Configuration ---
# Placeholder values ("xxx") must be filled in before running.
MODEL_PATH = "xxx"  # model name or local path handed to CustomEmbeddingFunction
DB_ROOT_PATH = "xxx"  # directory of the persistent ChromaDB store and checkpoint logs
INPUT_JSONL_FILE = "xxx"  # JSONL file name, joined onto each task's data_dir

# One entry per collection to build; each entry is run by build_collection_worker.
# "id" must be one character followed by digits (the worker parses int(id[1:])).
BUILD_TASKS = [
    dict(
        id="P1",
        data_dir="xxx",
        collection_name="wiki_test",
    ),
]

# Text Splitting Parameters
CHUNK_SIZE = 600  # nominal characters per chunk before searching for a boundary
EXTEND_SIZE = 200  # extra characters allowed while searching for a strong separator
MIN_LAST_CHUNK_SIZE = 150  # a trailing chunk shorter than this is merged into the previous one

# Data Processing Parameters
PROCESS_FILE_LIMIT = None  # maximum number of input files per task; None means no limit
BATCH_SIZE = 2000  # buffered chunk count that triggers a collection.add() call

# --- Utility Functions ---


def split_text(text, chunk_size=CHUNK_SIZE, extend_size=EXTEND_SIZE, min_last_chunk_size=MIN_LAST_CHUNK_SIZE):
    r"""Split ``text`` into consecutive, non-overlapping chunks.

    Chunks are contiguous, so ``"".join(split_text(t)) == t``.

    A chunk nominally spans ``chunk_size`` characters. If that would not reach
    the end of the text, the end is pushed forward until it sits on a "strong"
    separator (``". "``, ``"! "``, ``"; "``, ``"\n\n"``, ``".\n"``, ``"!\n"``).
    The separator is kept in the chunk.

    If the search moves more than ``chunk_size + extend_size`` characters past
    the chunk start, it falls back to the next space or newline. This fallback
    scans at most ``extend_size // 3 + 1`` further characters and cuts mid-word
    if no whitespace is found.

    When there are at least two chunks and the final one is shorter than
    ``min_last_chunk_size``, it is appended to the previous chunk.

    Args:
        text: The string to split.
        chunk_size: Nominal chunk length in characters; must be >= 1.
        extend_size: How far beyond ``chunk_size`` to look for a strong separator.
        min_last_chunk_size: Minimum length of a standalone trailing chunk.

    Returns:
        A list of substrings of ``text`` (``[]`` for an empty string).

    Raises:
        ValueError: If ``chunk_size`` is less than 1. Such values can stop the
            chunk start from advancing, which would loop forever.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")

    chunks = []
    text_length = len(text)
    strong_separators = {'. ', '! ', '\n\n', '; ', '.\n', '!\n'}
    weak_separators = {' ', '\n'}
    # Once the boundary search runs past this many characters from the chunk
    # start, give up on strong separators and fall back to whitespace.
    max_span = chunk_size + extend_size
    chunk_start = 0

    while chunk_start < text_length:
        # Inclusive index of the last character of the nominal chunk. int()
        # keeps float chunk sizes usable as slice indices.
        chunk_end = int(min(chunk_start + chunk_size - 1, text_length - 1))
        if chunk_end < text_length - 1:
            temp_end = chunk_end
            # Test the two-character window that ends at temp_end.
            while temp_end < text_length and text[temp_end - 1:temp_end + 1] not in strong_separators:
                temp_end += 1
                if temp_end - chunk_start > max_span:
                    steps = 0
                    while temp_end < text_length and text[temp_end] not in weak_separators:
                        temp_end += 1
                        steps += 1
                        if steps > extend_size // 3:
                            break
                    break
            chunk_end = temp_end
        # temp_end may reach text_length, so clamp the exclusive end.
        final_chunk_end = min(chunk_end + 1, text_length)
        chunks.append(text[chunk_start:final_chunk_end])
        if final_chunk_end >= text_length:
            break
        chunk_start = final_chunk_end

    if len(chunks) >= 2 and len(chunks[-1]) < min_last_chunk_size:
        last_chunk = chunks.pop()
        chunks[-1] += last_chunk
    return chunks


def load_processed_files(checkpoint_file):
    """Return the set of filenames recorded in ``checkpoint_file``.

    Every line is stripped of surrounding whitespace before being added.
    A checkpoint file that does not exist yet yields an empty set.
    """
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, 'r') as handle:
            return {entry.strip() for entry in handle}
    return set()


def log_processed_file(checkpoint_file, filename):
    """Append ``filename`` as one line to ``checkpoint_file``, creating the file if absent."""
    record = filename + '\n'
    with open(checkpoint_file, mode='a') as out:
        out.write(record)


class CustomEmbeddingFunction(embedding_functions.SentenceTransformerEmbeddingFunction):
    """SentenceTransformer embedding function with a configurable encode batch size.

    Every constructor argument except ``batch_size`` is forwarded unchanged to
    ``SentenceTransformerEmbeddingFunction`` (for example ``device="cuda"``).
    """

    def __init__(self, model_name: str, batch_size: int = 32, **kwargs):
        """
        Args:
            model_name: Model name or local path passed to the parent class.
            batch_size: Number of texts per ``encode`` batch.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(model_name, **kwargs)
        self._batch_size = batch_size

    def __call__(self, texts):
        """Embed ``texts`` and return one list of floats per input text."""
        # Honor a normalize_embeddings flag passed through **kwargs. The original
        # override dropped it silently.
        # NOTE(review): the parent is expected to store the flag as
        # `normalize_embeddings` (newer chromadb) or `_normalize_embeddings`
        # (older releases). Confirm against the installed version.
        normalize = getattr(self, "normalize_embeddings",
                            getattr(self, "_normalize_embeddings", False))
        embeddings = self._model.encode(
            texts,
            batch_size=self._batch_size,
            convert_to_numpy=True,  # .tolist() below needs an ndarray
            normalize_embeddings=bool(normalize),
        )
        return embeddings.tolist()


def get_available_memory_gb():
    """Return how much system RAM is currently available, in GiB (1024**3 bytes)."""
    available_bytes = psutil.virtual_memory().available
    return available_bytes / 1024 ** 3

# --- Core Worker Function ---


def build_collection_worker(task_config, db_path):
    """
    Build, or resume building, one Chroma collection from the task's JSONL input.

    This function is designed to run in a separate thread. The main block submits
    one call per ``BUILD_TASKS`` entry to a ThreadPoolExecutor.

    How each file is ingested:

    - Each input line is parsed as a JSON object with ``id``, ``title`` and
      ``contents`` keys; blank lines are skipped.
    - ``contents`` is split with ``split_text``, and each chunk is stored with
      ID ``"<id>_<chunk index>"`` plus title, doc-id and position metadata.
    - Chunks are buffered and added once ``BATCH_SIZE`` have accumulated; the
      remainder of each file is added as a final smaller batch.
    - Only after all of a file's chunks are written is its name appended to
      ``checkpoint_<task id>.log`` in ``db_path``. Files already listed there
      are skipped on later runs.

    Args:
        task_config: Dict with ``id``, ``data_dir`` and ``collection_name``.
            ``id`` looks like ``"P1"``; its numeric part staggers start-up.
        db_path: Directory of the persistent Chroma database and the checkpoint log.

    Returns:
        A human-readable status string.
    """
    task_id = task_config['id']
    data_dir = task_config['data_dir']
    collection_name = task_config['collection_name']
    checkpoint_file = os.path.join(db_path, f"checkpoint_{task_id}.log")

    # Stagger start-up by 5 s per task number. Presumably this keeps parallel
    # tasks from loading the model / opening the DB at the same moment.
    id_num = int(task_id[1:])
    time.sleep(id_num * 5)

    def log(msg):
        print(f"[{task_id}] {msg}")

    log(f"Starting task, target collection: {collection_name}")

    # 1. Initialize Chroma client and Embedding function
    log("Initializing ChromaDB client and embedding model...")
    chroma_client = chromadb.PersistentClient(path=db_path)
    ef = CustomEmbeddingFunction(
        model_name=MODEL_PATH, device="cuda", batch_size=128)
    collection = chroma_client.get_or_create_collection(
        name=collection_name,
        configuration={
            "hnsw": {"space": "cosine"},
            "embedding_function": ef
        }
    )

    # 2. Load the checkpoint and drop files that were already fully ingested.
    processed_files = load_processed_files(checkpoint_file)
    log(f"Found {len(processed_files)} processed files in checkpoint log.")

    files_to_process = [
        name for name in [INPUT_JSONL_FILE] if name not in processed_files
    ]

    if not files_to_process:
        log("All files have been processed, the database is up to date.")
        return f"Task {task_id} completed: No new files to process."

    if PROCESS_FILE_LIMIT is not None:
        files_to_process = files_to_process[:PROCESS_FILE_LIMIT]

    # 3. Process files and add to the database in batches
    all_chunks, all_metadatas, all_ids = [], [], []
    total_chunks_in_task = 0

    def add_pending(is_last):
        """Add the buffered chunks to the collection, then start fresh buffers."""
        nonlocal all_chunks, all_metadatas, all_ids, total_chunks_in_task
        which = "the last " if is_last else ""
        log(f"Memory: {get_available_memory_gb():.2f} GB available | Preparing to add {which}{len(all_chunks)} chunks...")
        collection.add(documents=all_chunks,
                       metadatas=all_metadatas, ids=all_ids)
        total_chunks_in_task += len(all_chunks)
        log(
            f"Batch added successfully. A total of {total_chunks_in_task} chunks have been added for this task.")
        all_chunks, all_metadatas, all_ids = [], [], []

    for filename in files_to_process:
        file_path = os.path.join(data_dir, filename)
        # Stream line by line instead of readlines(), so a large dump is never
        # held in memory all at once.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # json.loads('') would raise on blank/trailing lines
                row = json.loads(line)
                chunks = split_text(row['contents'])
                for chunk_idx, chunk in enumerate(chunks):
                    all_chunks.append(chunk)
                    all_metadatas.append({
                        "title": row['title'],
                        "doc_id": row['id'],
                        "chunk_index_in_doc": chunk_idx,
                        "tot_chunks_in_doc": len(chunks)
                    })
                    all_ids.append(f"{row['id']}_{chunk_idx}")

                if len(all_chunks) >= BATCH_SIZE:
                    add_pending(is_last=False)

        # Flush this file's remaining chunks *before* checkpointing it. Otherwise
        # a crash could leave the file marked as processed while some of its
        # chunks are missing from the database.
        if all_chunks:
            add_pending(is_last=True)
        log_processed_file(checkpoint_file, filename)

    log(f"Task completed! A total of {total_chunks_in_task} chunks were added for this task.")
    return f"Task {task_id} completed successfully."


# --- Main Execution ---
if __name__ == "__main__":
    import traceback

    start_time = time.time()

    os.makedirs(DB_ROOT_PATH, exist_ok=True)

    print("======================================================")
    print(f"Database will be saved at: {DB_ROOT_PATH}")
    print(f"Executing {len(BUILD_TASKS)} build tasks in parallel...")
    print("======================================================")

    # ThreadPoolExecutor rejects max_workers=0, so guard against an empty BUILD_TASKS.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, len(BUILD_TASKS))) as executor:
        future_to_task = {
            executor.submit(build_collection_worker, task, DB_ROOT_PATH): task
            for task in BUILD_TASKS
        }

        for future in concurrent.futures.as_completed(future_to_task):
            task = future_to_task[future]
            try:
                result = future.result()
            except Exception as exc:
                print(
                    f"\n[Main] Task {task['id']} ({task['collection_name']}) generated an exception during execution: {exc}\n")
                # Print the full traceback: the one-line message alone rarely
                # shows where a long-running build failed.
                traceback.print_exc()
            else:
                print(f"\n[Main] Execution result: {result}\n")

    end_time = time.time()
    print("\n=============================================")
    print("All parallel build tasks have been completed!")
    print(f"Total time taken: {(end_time - start_time) / 60:.2f} minutes.")
    print(f"Database saved at: {DB_ROOT_PATH}")
    print("=============================================")
