import os
import argparse
from dotenv import load_dotenv

from inses.llm_factory import LLMFactory
from inses.data_loader import DataLoader
from inses.qdrant_vectordb import QdrantVectorDB
from inses.neo4j_graphdb import Neo4jGraphDB
from inses.rag_router import RAGRouter

def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            original zero-argument call did.

    Returns:
        Namespace with ``dataset``, ``sample_size``, ``similarity_top_k`` and
        ``confidence`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Run the RAG router over a sample of a QA dataset.",
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="2wiki",
        choices=["2wiki", "hotpotqa", "musique"],
        help="Name of the dataset handed to the data loader and the router.",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=50,
        help="Sample size handed to the data loader.",
    )
    parser.add_argument(
        "--similarity-top-k",
        type=int,
        default=5,
        help="similarity_top_k value handed to the router.",
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.65,
        help="Confidence threshold handed to the router's run().",
    )
    return parser.parse_args(argv)

def main(args=None) -> None:
    """Run the RAG routing pipeline end to end.

    Loads environment variables via ``load_dotenv()``, builds the LLM, loads
    the dataset sample, creates the ``QdrantVectorDB`` and ``Neo4jGraphDB``
    instances, attaches them to a ``RAGRouter`` and runs the router over the
    loaded questions, printing progress along the way.

    Args:
        args: Parsed options exposing ``dataset``, ``sample_size``,
            ``similarity_top_k`` and ``confidence``. When ``None`` (the
            default) the command line is parsed with ``parse_args()``.
    """
    if args is None:
        # Resolve the options here rather than reading a module-level global:
        # that global only existed when the file was executed as a script, so
        # importing the module and calling main() raised NameError.
        args = parse_args()

    load_dotenv()

    # 1) Build the LLM (OpenAI as an example; change the provider as needed).
    llm = LLMFactory.create_llm(
        provider="openai",
        model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
        api_key=os.getenv("OPENAI_API_KEY", ""),
        temperature=0.0,
        max_tokens=1000,
    )

    # 2) Data. load() returns a pair; only the QA list is consumed here.
    loader = DataLoader(dataset_name=args.dataset, sample_size=args.sample_size)
    qa_list, _ctx_list = loader.load()
    print(f"Loaded {len(qa_list)} samples from {args.dataset}")

    # 3) Vector and graph backends.
    qdrant_db = QdrantVectorDB(llm=llm)
    neo4j_db = Neo4jGraphDB(llm=llm)

    # 4) Router. The backends are attached as attributes after construction.
    router = RAGRouter(
        llm=llm,
        embed_model=None,  # if your code expects it, pass the embedding model
        similarity_top_k=args.similarity_top_k,
        dataset_name=args.dataset,
    )
    router.qdrant_db = qdrant_db
    router.neo4j_db = neo4j_db

    # 5) Run.
    results = router.run(
        qa_list,
        confidence_threshold=args.confidence,
    )
    print(f"Done. Results: {len(results)} items.")


if __name__ == "__main__":
    main()
