import chromadb
from chromadb.utils import embedding_functions
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel, Field, field_validator
import uvicorn
import traceback
from typing import Dict, Any, List, Optional, AsyncGenerator, Set, Union
from contextlib import asynccontextmanager
import sqlite3
import json
import numpy as np
import logging
import re
from collections import defaultdict
import time
from pathlib import Path
import os


# --- 1. Configuration & Initialization ---

# Root logger setup. NOTE(review): the level is WARNING, so every
# logging.info(...) call in this module (startup messages, per-step timings)
# is suppressed; lower the level to INFO if those messages are wanted.
logging.basicConfig(level=logging.WARNING,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# --- ChromaDB Configuration ---
# Names of the Chroma collections searched by every request. "xxx" looks like
# a redacted placeholder — TODO confirm the real name is set before deploying.
COLLECTION_NAMES = ["xxx"]
# Address of the Chroma server reached through chromadb.HttpClient at startup.
CHROMA_SERVER_HOST = "localhost"
CHROMA_SERVER_PORT = 7001

# Parent of the directory that contains this file (parents[1] of the resolved
# file path); the assets directory is resolved relative to it.
ROOT_DIR = Path(__file__).resolve().parents[1]
ASSETS_DIR = os.path.join(ROOT_DIR, "assets")


# Local path passed as model_name to SentenceTransformerEmbeddingFunction.
MODEL_PATH = os.path.join(ASSETS_DIR, "xxx")

# --- SQLite Configuration ---
# SQLite database that holds the full-text table used for BM25 keyword search.
BM25_DB_PATH = os.path.join(
    ASSETS_DIR, "xxx")
# Full-text table queried with MATCH and ranked with the bm25() function.
BM25_TABLE_NAME = "wiki_fts"

# --- Global State Variables ---
# Process-wide registry filled by lifespan() at startup with the keys
# "bm25_conn", "bm25_cursor", "chroma_client", "ef" and "collections".
app_state: Dict[str, Any] = {}

# --- 2. FastAPI Application Lifespan Management ---


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Open the shared search resources on startup and release them on shutdown.

    Populates ``app_state`` with:
      * ``bm25_conn`` / ``bm25_cursor`` – SQLite connection/cursor for the BM25 index,
      * ``chroma_client`` – Chroma HTTP client (verified with ``heartbeat()``),
      * ``ef`` – sentence-transformer embedding function loaded from ``MODEL_PATH``,
      * ``collections`` – one Chroma collection per entry of ``COLLECTION_NAMES``.

    Raises:
        FileNotFoundError: if ``BM25_DB_PATH`` is not an existing file. Without
            this check ``sqlite3.connect`` would silently create an empty
            database there, every BM25 query would fail, and the search
            functions would report that as "no results".
        RuntimeError: if no collections were loaded.
        Any exception raised by chromadb or the embedding function while
        connecting/loading propagates and aborts startup.
    """
    logging.info("Application startup: Initializing resources...")
    try:
        if not Path(BM25_DB_PATH).is_file():
            raise FileNotFoundError(
                f"BM25 index database not found: {BM25_DB_PATH}")
        app_state["bm25_conn"] = sqlite3.connect(
            BM25_DB_PATH, check_same_thread=False)
        app_state["bm25_cursor"] = app_state["bm25_conn"].cursor()
        logging.info(
            f"Successfully connected to BM25 index database: {BM25_DB_PATH}")

        app_state["chroma_client"] = chromadb.HttpClient(
            host=CHROMA_SERVER_HOST, port=CHROMA_SERVER_PORT)
        # Fail fast at startup if the Chroma server is unreachable.
        app_state["chroma_client"].heartbeat()
        logging.info(
            f"Successfully connected to ChromaDB server at {CHROMA_SERVER_HOST}:{CHROMA_SERVER_PORT}")

        app_state["ef"] = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=MODEL_PATH, device="cpu")

        collections = [
            app_state["chroma_client"].get_collection(
                name=name, embedding_function=app_state["ef"])
            for name in COLLECTION_NAMES
        ]
        # Validate before reporting success so the log never claims a load
        # that is about to be rejected.
        if not collections:
            raise RuntimeError("No collections were loaded.")
        app_state["collections"] = collections
        logging.info(f"Successfully loaded collections: {COLLECTION_NAMES}")

        yield

    finally:
        logging.info("Application shutdown: Closing resources...")
        cursor = app_state.pop("bm25_cursor", None)
        if cursor is not None:
            cursor.close()
        conn = app_state.pop("bm25_conn", None)
        if conn is not None:
            conn.close()
            logging.info("BM25 index database connection closed.")
        # Drop the remaining handles so nothing can reuse them after shutdown.
        app_state.clear()

# --- 3. Pydantic Models ---


class SearchPlanRequest(BaseModel):
    """Request body for ``POST /execute_search``.

    ``semantic_query`` always drives a vector search; ``bm25_query_keywords``
    (a list of phrases, or a single string the handler splits on whitespace)
    enables BM25 + hybrid fusion weighted by ``bm25_weight``.
    ``include_doc_ids``/``exclude_doc_ids`` accept any JSON scalars in a list
    and are coerced to strings before validation.
    """
    semantic_query: str = Field(...,
                                description="Descriptive query for semantic search.")
    bm25_query_keywords: Optional[Union[List[str], str]] = Field(
        None, description="List of keyword phrases for BM25 search. Each phrase is an AND group, phrases are ORed together.")
    bm25_weight: float = Field(
        0.5, ge=0.0, le=1.0, description="Weight for BM25 search results.")
    entity_match: Optional[str] = Field(
        None, description="Named entity to retrieve supplementary information.")
    include_doc_ids: Optional[List[str]] = Field(
        None, description="List of document IDs to supplement search results from.")
    exclude_doc_ids: Optional[List[str]] = Field(
        None, description="List of document IDs to exclude from search.")
    # The handler slices results with ``[:top_k]`` and ``[:top_k // 2]``; a
    # negative value would silently drop items from the END of the ranking,
    # so it is rejected at validation time (422) instead.
    top_k: int = Field(5, ge=0, description="Number of final results to return.")

    @field_validator('include_doc_ids', 'exclude_doc_ids', mode='before')
    @classmethod
    def coerce_ids_to_str_list(cls, v: Any) -> Optional[List[str]]:
        """Coerce every element of an ID list to ``str`` before field validation.

        Lets clients send numeric IDs (e.g. ``[123, "abc"]``). ``None`` and
        non-list values are returned unchanged for the field's own validation.
        """
        if v is None:
            return None
        if isinstance(v, list):
            return [str(item) for item in v]
        return v


class ChunkResult(BaseModel):
    """One chunk returned in ``SearchPlanResponse.search_results``."""
    # ID of the document the chunk belongs to.
    doc_id: str
    # Identifier of the chunk in Chroma / the BM25 table.
    chunk_id: str
    title: str
    # Chunk text; extra included-doc chunks carry a
    # "(Extra Included Doc Chunk): " prefix.
    content: str
    # Raw ``1 - distance`` in semantic-only mode, min-max normalized in hybrid
    # mode, -1.0 sentinel for extra included-doc chunks.
    semantic_score: float
    # Min-max normalized BM25 score in hybrid mode; -1.0 sentinel when BM25
    # was not run for the chunk.
    bm25_score: float
    # Chroma metadata for the chunk; {} when none is available (BM25-only hits).
    metadata: Dict[str, Any]


class EntitySnippet(BaseModel):
    """A short passage mentioning the requested ``entity_match``."""
    # Document the snippet's source chunk belongs to.
    doc_id: str
    # Title of the source chunk.
    title: str
    # One sentence, or two consecutive sentences, from the source chunk.
    snippet: str
    # SQLite FTS5 bm25() value; snippets are ordered ascending on it, so a
    # lower value ranks higher.
    score: float


class SearchPlanResponse(BaseModel):
    """Response body for ``POST /execute_search``."""
    # Ranked main-search chunks followed by any extra included-doc chunks.
    search_results: List[ChunkResult]
    # Ranked entity snippets; None when no ``entity_match`` was requested or
    # no candidate snippets were produced.
    entity_snippets: Optional[List[EntitySnippet]] = None

# --- 4. Helper Functions ---


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of vectors *a* and *b* in [-1, 1].

    Accepts NumPy arrays or any array-like (e.g. lists of floats). When either
    vector has zero norm the similarity is undefined; 0.0 is returned instead
    of the ``nan`` (plus RuntimeWarning) the bare division would produce.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return float(np.dot(a, b) / denominator)


# English stop words removed from the semantic query before its words are
# reused as extra keywords for entity search.
STOP_WORDS = {
    'a', 'about', 'above', 'after', 'again', 'against', 'all', 'am', 'an', 'and', 'any', 'are', 'as', 'at',
    'be', 'because', 'been', 'before', 'being', 'below', 'between', 'both', 'but', 'by',
    'can\'t', 'cannot', 'could', 'did', 'do', 'does', 'doing', 'don\'t', 'down', 'during',
    'each', 'few', 'for', 'from', 'further',
    'had', 'has', 'have', 'having', 'he', 'her', 'here', 'hers', 'herself', 'him', 'himself', 'his', 'how',
    'i', 'if', 'in', 'into', 'is', 'it', 'its', 'itself',
    'let\'s', 'me', 'more', 'most', 'my', 'myself',
    'no', 'nor', 'not',
    'of', 'off', 'on', 'once', 'only', 'or', 'other', 'ought', 'our', 'ours', 'ourselves', 'out', 'over', 'own',
    'same', 'she', 'should', 'so', 'some', 'such',
    'than', 'that', 'the', 'their', 'theirs', 'them', 'themselves', 'then', 'there', 'these', 'they', 'this', 'those', 'through', 'to', 'too',
    'under', 'until', 'up',
    'very',
    'was', 'we', 'were', 'what', 'when', 'where', 'which', 'while', 'who', 'whom', 'why', 'with', 'would',
    'you', 'your', 'yours', 'yourself', 'yourselves'
}
# The entity-search path strips punctuation with re.sub(r'[^\w\s]', '', ...)
# BEFORE looking words up here, so "can't" arrives as "cant" and the
# apostrophe entries above could never match. Add the punctuation-free
# spellings produced by that same substitution.
STOP_WORDS |= {re.sub(r'[^\w\s]', '', word) for word in STOP_WORDS}


def _build_fts_query(keywords: List[str]) -> str:
    if not keywords:
        return ""
    parenthesized_phrases = [
        f'({phrase.strip()})' for phrase in keywords if phrase.strip()]
    if not parenthesized_phrases:
        return ""
    return " OR ".join(parenthesized_phrases)


async def _fetch_and_rank_extra_chunks(
    doc_ids: List[str],
    query_embedding: np.ndarray,
    existing_chunk_ids: Set[str],
    min_similarity: float = 0.8,
) -> List[Dict]:
    """Fetch chunks of explicitly requested documents that are relevant to the query.

    Reads every chunk whose metadata ``doc_id`` is in *doc_ids* (with stored
    embeddings) from each loaded Chroma collection. Chunks already in
    *existing_chunk_ids* are skipped; the rest are kept only when their cosine
    similarity to *query_embedding* is at least *min_similarity*.

    Returns:
        Result dicts (``doc_id``, ``chunk_id``, ``title``, ``content``,
        ``semantic_score``, ``bm25_score``, ``metadata``) sorted by DESCENDING
        cosine similarity, so callers that truncate the list keep the most
        relevant chunks. ``content`` is prefixed with
        ``"(Extra Included Doc Chunk): "``; both scores keep the -1.0 sentinel
        and the similarity itself is used only for filtering and ordering.

    NOTE(review): declared ``async`` but contains no ``await``, so the Chroma
    calls below run synchronously on the event loop.
    """
    if not doc_ids:
        return []

    scored: List[tuple] = []
    for collection in app_state["collections"]:
        results = collection.get(
            where={"doc_id": {"$in": doc_ids}},
            include=["documents", "metadatas", "embeddings"],
        )
        if not results or not results["ids"]:
            continue
        rows = zip(results["ids"], results["documents"],
                   results["metadatas"], results["embeddings"])
        for chunk_id, document, metadata, embedding in rows:
            if chunk_id in existing_chunk_ids:
                continue
            similarity = cosine_similarity(query_embedding, np.array(embedding))
            # Positive comparison so a NaN similarity is rejected.
            if not similarity >= min_similarity:
                continue
            # Guard against a None metadata entry (it would crash on .get and
            # fail ChunkResult validation downstream).
            metadata = metadata or {}
            scored.append((similarity, {
                "doc_id": metadata.get("doc_id", chunk_id.split("_")[0]),
                "chunk_id": chunk_id,
                "title": metadata.get("title", "N/A"),
                "content": "(Extra Included Doc Chunk): " + document,
                "semantic_score": -1.0,
                "bm25_score": -1.0,
                "metadata": metadata,
            }))

    # Previously every entry carried semantic_score=-1.0 and the list was
    # "sorted" on that constant, i.e. left in arbitrary fetch order. Sort on
    # the real similarity (stable, so ties keep fetch order).
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [chunk for _, chunk in scored]


def min_max_normalize(scores: List[float], is_bm25: bool = False) -> List[float]:
    """Linearly rescale *scores* onto [0, 1] (best -> 1.0, worst -> 0.0).

    With ``is_bm25=True`` the scores are negated first, because SQLite's
    ``bm25()`` ranks lower values as better (the BM25 queries here use
    ``ORDER BY score`` ascending). If every score is identical, each one
    maps to 1.0. An empty input yields an empty list.
    """
    if not scores:
        return []
    values = [-score for score in scores] if is_bm25 else scores
    lowest = min(values)
    highest = max(values)
    if lowest == highest:
        return [1.0 for _ in values]
    span = highest - lowest
    return [(value - lowest) / span for value in values]


def bm25_search(keywords: List[str], top_k: int = 16) -> List[Dict]:
    """Run a BM25 full-text search over the ``BM25_TABLE_NAME`` table.

    Args:
        keywords: keyword phrases; see ``_build_fts_query`` for how they are
            combined into the MATCH expression.
        top_k: maximum number of rows to return.

    Returns:
        Up to *top_k* dicts ``{"chunk_id", "title", "content", "score"}``
        ordered by ``bm25()`` ascending (best first). Returns ``[]`` when there
        is nothing to search for or when SQLite rejects/fails the query
        (logged). Only database errors are degraded this way; any other
        exception (e.g. a programming error) propagates to the caller.
    """
    if not keywords:
        return []

    fts_query = _build_fts_query(keywords)
    if not fts_query:
        logging.warning("BM25 search skipped: generated FTS query is empty.")
        return []

    cursor = app_state["bm25_cursor"]
    # Only the table name (a module constant) is interpolated; the user-derived
    # MATCH expression and limit are bound as parameters.
    sql = f"SELECT chunk_id, title, content, bm25({BM25_TABLE_NAME}) as score FROM {BM25_TABLE_NAME} WHERE {BM25_TABLE_NAME} MATCH ? ORDER BY score LIMIT ?;"
    try:
        cursor.execute(sql, (fts_query, top_k))
        rows = cursor.fetchall()
    except sqlite3.Error as e:
        logging.error(f"BM25 search failed for FTS query '{fts_query}': {e}")
        return []
    return [
        {"chunk_id": chunk_id, "title": title, "content": content, "score": score}
        for chunk_id, title, content, score in rows
    ]


def entity_search(entity_term: str, keywords: List[str], top_k: int = 16) -> List[Dict]:
    """Find chunks that contain *entity_term* as a phrase, optionally narrowed by keywords.

    The MATCH expression is ``"<entity>" AND (<keyword groups>)``, or just
    ``"<entity>"`` when ``_build_fts_query(keywords)`` is empty. Double quotes
    inside the entity are doubled (FTS5 string escaping), so an entity such as
    ``Dwayne "The Rock" Johnson`` stays one phrase instead of producing a
    malformed or differently-parsed query.

    Returns:
        Up to *top_k* dicts ``{"chunk_id", "title", "content", "score"}``
        ordered by ``bm25()`` ascending (best first); ``[]`` for a blank
        entity or when SQLite rejects/fails the query (logged).
    """
    if not entity_term or not entity_term.strip():
        return []

    escaped_entity = entity_term.replace('"', '""')
    keywords_query_part = _build_fts_query(keywords)
    if keywords_query_part:
        combined_query = f'"{escaped_entity}" AND ({keywords_query_part})'
    else:
        combined_query = f'"{escaped_entity}"'

    logging.info(f"The query for entity search is {combined_query}")

    sql = f"""
        SELECT chunk_id, title, content, bm25({BM25_TABLE_NAME}) as score
        FROM {BM25_TABLE_NAME}
        WHERE {BM25_TABLE_NAME} MATCH ?
        ORDER BY score
        LIMIT ?;
    """
    cursor = app_state["bm25_cursor"]
    try:
        cursor.execute(sql, (combined_query, top_k))
        rows = cursor.fetchall()
    except sqlite3.Error as e:
        logging.error(
            f"Entity search failed for combined query '{combined_query}': {e}")
        return []
    return [
        {"chunk_id": chunk_id, "title": title, "content": content, "score": score}
        for chunk_id, title, content, score in rows
    ]

# --- 5. FastAPI Application & Main Endpoints ---


# The lifespan hook opens the SQLite / Chroma / embedding-model resources on
# startup and releases them on shutdown.
app = FastAPI(
    title="Advanced RAG Search API",
    version="2.0",
    lifespan=lifespan,
)


def _result_doc_id(result: Dict[str, Any]) -> str:
    """Return the document ID of a search-result dict, always as ``str``.

    Prefers ``metadata["doc_id"]`` (the field Chroma is filtered on for
    included docs) and falls back to the part of ``chunk_id`` before the first
    underscore, which is all BM25 rows carry. Stringifying keeps the value
    comparable with the string-coerced ``exclude_doc_ids`` and valid for
    ``ChunkResult.doc_id``.
    """
    metadata = result.get("metadata") or {}
    doc_id = metadata.get("doc_id")
    if doc_id is None:
        chunk_id = result.get("chunk_id")
        doc_id = str(chunk_id if chunk_id is not None else "_").split("_")[0]
    return str(doc_id)


def _semantic_search(query_vector: List[float], n_results: int) -> List[Dict[str, Any]]:
    """Query every loaded Chroma collection with a precomputed query embedding.

    Each hit is ``{"chunk_id", "title", "content", "score", "metadata"}`` with
    ``score = 1 - distance``.
    """
    hits: List[Dict[str, Any]] = []
    for collection in app_state["collections"]:
        res = collection.query(
            query_embeddings=[query_vector],
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )
        for chunk_id, doc, meta, dist in zip(
                res["ids"][0], res["documents"][0], res["metadatas"][0], res["distances"][0]):
            meta = meta or {}
            hits.append({"chunk_id": chunk_id, "title": meta.get("title", ""),
                         "content": doc, "score": 1 - dist, "metadata": meta})
    return hits


def _fuse_hybrid_results(
    semantic_results: List[Dict[str, Any]],
    bm25_results: List[Dict[str, Any]],
    bm25_weight: float,
) -> List[Dict[str, Any]]:
    """Merge semantic and BM25 hits into one list ranked by a weighted score.

    Each score list is min-max normalized separately (BM25 sign-flipped
    first); a chunk missing from one retriever scores 0.0 on that side, and
    ``combined_score = (1 - bm25_weight) * semantic + bm25_weight * bm25``.
    When a chunk appears in both, the semantic hit's data is kept.
    """
    fused: Dict[str, Dict[str, Any]] = {}

    sem_norm = min_max_normalize([r["score"] for r in semantic_results])
    for res, score in zip(semantic_results, sem_norm):
        slot = fused.setdefault(res["chunk_id"], {"semantic": 0.0, "bm25": 0.0, "data": None})
        slot["semantic"] = score
        slot["data"] = res

    bm25_norm = min_max_normalize([r["score"] for r in bm25_results], is_bm25=True)
    for res, score in zip(bm25_results, bm25_norm):
        slot = fused.setdefault(res["chunk_id"], {"semantic": 0.0, "bm25": 0.0, "data": None})
        slot["bm25"] = score
        if not slot["data"]:
            slot["data"] = res

    ranked = [
        {
            **slot["data"],
            "combined_score": (1 - bm25_weight) * slot["semantic"] + bm25_weight * slot["bm25"],
            "semantic_score": slot["semantic"],
            "bm25_score": slot["bm25"],
        }
        for slot in fused.values()
    ]
    ranked.sort(key=lambda r: r["combined_score"], reverse=True)
    return ranked


def _to_chunk_result(result: Dict[str, Any]) -> ChunkResult:
    """Convert an internal result dict into the public ``ChunkResult`` model."""
    return ChunkResult(
        doc_id=_result_doc_id(result),
        chunk_id=result.get("chunk_id"),
        title=result.get("title") or "",
        content=result.get("content"),
        semantic_score=result.get("semantic_score", result.get("semantic", -1.0)),
        bm25_score=result.get("bm25_score", result.get("bm25", -1.0)),
        metadata=result.get("metadata") or {},
    )


def _entity_search_keywords(keywords: Optional[List[str]], semantic_query: str) -> List[str]:
    """BM25 keywords plus non-stop-word words of the semantic query, de-duplicated in order."""
    terms = list(keywords or [])
    if semantic_query:
        sanitized = re.sub(r'[^\w\s]', '', semantic_query.lower())
        terms.extend(word for word in sanitized.split() if word not in STOP_WORDS)
    return list(dict.fromkeys(terms))


def _split_into_snippets(chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Split each chunk into overlapping two-sentence windows (or one sentence)."""
    snippets: List[Dict[str, Any]] = []
    for chunk in chunks:
        sentences = re.split(r'(?<=[.!?])\s+', (chunk["content"] or "").strip())
        if len(sentences) > 1:
            snippets.extend(
                {"text": f"{first} {second}", "source_chunk": chunk}
                for first, second in zip(sentences, sentences[1:]))
        elif sentences:
            snippets.append({"text": sentences[0], "source_chunk": chunk})
    return snippets


def _rank_entity_snippets(
    entity: str,
    keywords: List[str],
    snippets: List[Dict[str, Any]],
    limit: int = 3,
) -> List[EntitySnippet]:
    """Rank *snippets* with a throw-away in-memory FTS5 index (porter stemming).

    Uses the same MATCH shape as ``entity_search``, with embedded quotes in
    *entity* doubled. The connection is closed explicitly: a bare
    ``with sqlite3.connect(...)`` block only commits/rolls back and leaves the
    connection open.

    Raises:
        sqlite3.Error: if the index cannot be built or queried.
    """
    escaped_entity = entity.replace('"', '""')
    keywords_query_part = _build_fts_query(keywords)
    if keywords_query_part:
        match_query = f'"{escaped_entity}" AND ({keywords_query_part})'
    else:
        match_query = f'"{escaped_entity}"'

    conn = sqlite3.connect(':memory:')
    try:
        conn.execute("CREATE VIRTUAL TABLE snippet_fts USING fts5(text, tokenize='porter');")
        conn.executemany("INSERT INTO snippet_fts (text) VALUES (?)",
                         [(s["text"],) for s in snippets])
        rows = conn.execute(
            "SELECT rowid, bm25(snippet_fts) FROM snippet_fts WHERE snippet_fts MATCH ? "
            "ORDER BY bm25(snippet_fts) LIMIT ?",
            (match_query, limit)).fetchall()
    finally:
        conn.close()

    ranked: List[EntitySnippet] = []
    for rowid, score in rows:
        # rowids of a fresh table are assigned 1..N in insertion order.
        snippet = snippets[rowid - 1]
        source = snippet["source_chunk"]
        ranked.append(EntitySnippet(doc_id=_result_doc_id(source), title=source["title"],
                                    snippet=snippet["text"], score=score))
    return ranked


@app.post("/execute_search", response_model=SearchPlanResponse)
async def execute_search(request: SearchPlanRequest = Body(...)):
    """Execute a search plan: semantic (+ optional BM25 hybrid) search, filters, entity snippets.

    Steps:
      0. Embed ``semantic_query``; a string ``bm25_query_keywords`` is split on whitespace.
      1. Semantic search over every collection; with keywords, BM25 search and
         weighted fusion (``bm25_weight``), otherwise pure semantic ranking.
      2. Drop results whose document is in ``exclude_doc_ids``.
      3. Add up to ``top_k // 2`` relevant chunks from ``include_doc_ids``.
      4. Keep the top ``top_k`` main results and convert to ``ChunkResult``.
      5. With ``entity_match``, rank up to 3 short entity snippets from chunks
         not already returned (``None`` if no candidates).

    Raises:
        HTTPException: 500 for any unexpected failure. The traceback is logged
            server-side only and is no longer echoed to the client.
    """
    total_start_time = time.perf_counter()
    logging.info(
        f"Starting new search for semantic_query: '{request.semantic_query}'")
    try:
        # Candidate pool size for each retriever before fusion/truncation.
        top_k_hybrid = 20

        # --- Step 0: Query Embedding and Input Normalization ---
        t0 = time.perf_counter()
        query_embedding = app_state["ef"]([request.semantic_query])[0]

        keywords = request.bm25_query_keywords
        if keywords and isinstance(keywords, str):
            logging.warning(
                f"bm25_query_keywords was a string ('{keywords}'), splitting into words.")
            keywords = keywords.split()
        logging.info(
            f"Step 0: Query embedding and normalization took: {time.perf_counter() - t0:.4f}s")

        # --- Step 1a: Semantic search (always executed) ---
        t0 = time.perf_counter()
        # Reuse the embedding computed above; passing query_texts would make
        # Chroma run the same model on the same text a second time.
        query_vector = np.asarray(query_embedding, dtype=float).tolist()
        semantic_results_raw = _semantic_search(query_vector, top_k_hybrid)
        logging.info(
            f"Step 1a: Semantic search ({len(semantic_results_raw)} results) took: {time.perf_counter() - t0:.4f}s")

        # --- Step 1b/1c: BM25 search and hybrid fusion (only with keywords) ---
        if keywords:
            t0 = time.perf_counter()
            bm25_results_raw = bm25_search(keywords, top_k=top_k_hybrid)
            logging.info(
                f"Step 1b: BM25 search ({len(bm25_results_raw)} results) took: {time.perf_counter() - t0:.4f}s")

            t0 = time.perf_counter()
            main_search_results = _fuse_hybrid_results(
                semantic_results_raw, bm25_results_raw, request.bm25_weight)
            logging.info(
                f"Step 1c: Hybrid fusion (fused {len(main_search_results)} results) took: {time.perf_counter() - t0:.4f}s")
        else:
            logging.info("Step 1b/c: Skipped BM25 search and fusion.")
            # Pure semantic ranking; -1.0 marks "BM25 not computed".
            main_search_results = [
                {**res, "semantic_score": res["score"], "bm25_score": -1.0}
                for res in semantic_results_raw
            ]
            main_search_results.sort(
                key=lambda r: r["semantic_score"], reverse=True)

        # Captured before exclusion so included-doc chunks never duplicate a
        # main-search hit.
        main_search_chunk_ids = {r["chunk_id"] for r in main_search_results}

        # --- Step 2: Apply `exclude_doc_ids` filter ---
        if request.exclude_doc_ids:
            excluded = set(request.exclude_doc_ids)
            original_count = len(main_search_results)
            main_search_results = [
                r for r in main_search_results if _result_doc_id(r) not in excluded]
            logging.info(
                f"Step 2: Exclusion filter removed {original_count - len(main_search_results)} items.")

        # --- Step 3: Handle `include_doc_ids` ---
        extra_chunks = []
        if request.include_doc_ids:
            t0 = time.perf_counter()
            extra_chunks = await _fetch_and_rank_extra_chunks(
                request.include_doc_ids, query_embedding, main_search_chunk_ids
            )
            logging.info(
                f"Step 3: Included docs processing found {len(extra_chunks)} extra chunks, took: {time.perf_counter() - t0:.4f}s")

        # --- Step 4: Combine results and convert to Pydantic models ---
        top_k = request.top_k
        combined_results_data = main_search_results[:top_k] + extra_chunks[:top_k // 2]
        hybrid_search_results = [_to_chunk_result(r) for r in combined_results_data]
        final_chunk_ids = {r.chunk_id for r in hybrid_search_results}

        # --- Step 5: Handle entity matching ---
        entity_snippets = None
        if request.entity_match:
            t0 = time.perf_counter()
            entity_search_keywords = _entity_search_keywords(
                keywords, request.semantic_query)
            entity_chunks = [
                c for c in entity_search(request.entity_match, entity_search_keywords, top_k=top_k_hybrid)
                if c["chunk_id"] not in final_chunk_ids
            ]
            t_snippets = time.perf_counter()
            logging.info(
                f"Step 5a: Primary Entity search took: {t_snippets - t0:.4f}s")

            snippets_to_rank = _split_into_snippets(entity_chunks)
            if snippets_to_rank:
                try:
                    entity_snippets = _rank_entity_snippets(
                        request.entity_match, entity_search_keywords, snippets_to_rank)
                except sqlite3.Error as e:
                    # Snippets are supplementary: degrade instead of failing the request.
                    logging.warning(f"Entity snippet ranking failed: {e}")
                found_snippets = len(entity_snippets) if entity_snippets else 0
                logging.info(
                    f"Step 5b: Entity snippet ranking found {found_snippets} snippets, took: {time.perf_counter() - t_snippets:.4f}s")

        logging.info(
            f"Search finished. Total duration: {time.perf_counter() - total_start_time:.4f}s")
        return SearchPlanResponse(search_results=hybrid_search_results, entity_snippets=entity_snippets)

    except Exception as e:
        logging.error(f"Error during search execution: {e}", exc_info=True)
        # Do not leak stack traces / file paths to clients; details are in the log.
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred while executing the search; see server logs for details.") from e

# --- 6. Application Startup ---
if __name__ == "__main__":
    # Script entry point: serve the API on every IPv4 interface, port 8011.
    server_host, server_port = "0.0.0.0", 8011
    logging.info("Starting Uvicorn server on http://0.0.0.0:8011")
    uvicorn.run(app, host=server_host, port=server_port)
