import os
from typing import Generator

import chromadb
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import Chroma
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker

from synthetic_agents.common.config import settings
from synthetic_agents.model.constants import DEFAULT_EMBEDDING_MODEL


def get_local_connection(local_dir: str) -> Engine:
    """
    Builds a relational database engine backed by a SQLite file on disk.

    The file name is derived from ``settings.db_name`` with a ``.sqlite3``
    suffix and is placed under ``local_dir``.

    :param local_dir: directory that holds the SQLite database file.
    :return: engine pointing at ``<local_dir>/<settings.db_name>.sqlite3``.
    """
    sqlite_url = f"sqlite:///{local_dir}/{settings.db_name}.sqlite3"
    return create_engine(sqlite_url)


def get_vector_db_local_connection(local_dir: str) -> chromadb.Client:
    """
    Opens a persistent Chroma client rooted at the given directory.

    Before the client is handed back, the collection named by
    ``settings.vector_db_collection_name`` is fetched, or created if missing.

    :param local_dir: directory where the local vector database lives.
    :return: database client.
    """
    persistent_client = chromadb.PersistentClient(path=local_dir)
    # The returned collection handle is not needed here; the call is made only
    # for its get-or-create effect on the configured collection.
    persistent_client.get_or_create_collection(settings.vector_db_collection_name)
    return persistent_client


def configure_db_engine() -> tuple[Engine, chromadb.Client]:
    """
    Sets up both the relational and the vector database as local stores.

    The relational directory (``settings.db_local_dir``) is created if it does
    not exist yet. The vector DB directory (``settings.vector_db_local_dir``)
    is passed through as-is and is not created by this function.

    :return: a tuple containing a relational DB engine and a vector DB client.
    """
    relational_dir = settings.db_local_dir
    os.makedirs(relational_dir, exist_ok=True)

    # Tuple elements are evaluated left to right: relational engine first,
    # then the vector DB client.
    return (
        get_local_connection(relational_dir),
        get_vector_db_local_connection(settings.vector_db_local_dir),
    )


# Module-level singletons: both connections are configured once, at import time.
db_engine, vector_db_client = configure_db_engine()

# Session factory bound to the relational engine; sessions neither autoflush
# nor autocommit.
SessionLocal = sessionmaker(
    bind=db_engine,
    autoflush=False,
    autocommit=False,
)


def get_db() -> Generator[Session, None, None]:
    """
    Yields a fresh database session and closes it afterwards.

    The session is created from ``SessionLocal`` and is closed once the
    generator is resumed or finalised, including when the consumer raises.

    :return: a generator yielding a single new database session.
    """
    # The session's context manager closes it on exit, exactly like an explicit
    # try/finally around ``close()``.
    with SessionLocal() as session:
        yield session


def get_vector_db(
    embedding_function: Embeddings = DEFAULT_EMBEDDING_MODEL,
) -> Generator[Chroma, None, None]:
    """
    Yields a vector store backed by the shared vector DB client.

    This is a generator function (same shape as ``get_db``): it yields exactly
    one ``Chroma`` store bound to the module-level ``vector_db_client`` and to
    the collection named by ``settings.vector_db_collection_name``.

    :param embedding_function: function to use for embedding the data to be saved.
    :return: a generator yielding a single vector store.
    """
    vector_store = Chroma(
        client=vector_db_client,
        collection_name=settings.vector_db_collection_name,
        embedding_function=embedding_function,
    )

    # No teardown follows the yield: there is no known equivalent of
    # ``db.close()`` for the langchain Chroma wrapper, and the underlying
    # client is shared at module level.
    yield vector_store
