# faiss_store.py

import faiss
import numpy as np
from typing import Dict, List, Optional
from collections import defaultdict

# Adjust these imports to match your project structure
from MentalModelTypes import Hypothesis  # import your Hypothesis class
from embed import create_embedding  # import your embedding function

class FaissStore:
    """Exact-search vector store keyed by (dimension_id, hypothesis_uuid).

    Each dimension gets its own FAISS index. Hypotheses are addressed
    externally by their string ``hypothesis_id`` and internally by a
    per-dimension integer id, because FAISS ids are 64-bit integers.

    Vectors are L2-normalized before being stored or queried, so the
    inner-product index ranks results by cosine similarity.
    """

    def __init__(self, embedding_dim: int = 3072):
        """Create an empty store.

        Args:
            embedding_dim: Width of the vectors produced by the embedding
                function; every stored and queried vector must match it.

        Raises:
            ValueError: If ``embedding_dim`` is not positive.
        """
        if embedding_dim <= 0:
            raise ValueError(f"embedding_dim must be positive, got {embedding_dim}")
        self.embedding_dim = embedding_dim
        # dimension_id -> FAISS IndexIDMap2
        self.stores: Dict[int, faiss.IndexIDMap2] = {}
        # dimension_id -> {int_id: Hypothesis}
        self.hypotheses: Dict[int, Dict[int, Hypothesis]] = defaultdict(dict)
        # dimension_id -> {uuid: int_id}
        self.id_maps: Dict[int, Dict[str, int]] = defaultdict(dict)
        # dimension_id -> next available int_id (ids start at 1)
        self.next_id: Dict[int, int] = defaultdict(lambda: 1)

    # ---------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------
    def _init_dimension(self, dim: int) -> None:
        """Create a new IDMap2(IndexFlatIP) for a dimension if absent."""
        if dim in self.stores:
            return
        # Inner product equals cosine similarity on unit-length vectors.
        base_index = faiss.IndexFlatIP(self.embedding_dim)
        # The ID-map wrapper lets vectors be added and removed by our own ids.
        self.stores[dim] = faiss.IndexIDMap2(base_index)

    def _normalize(self, vec: np.ndarray) -> np.ndarray:
        """Return ``vec`` as float32 scaled to unit L2 norm.

        Accepts any array-like (including a plain list). A zero vector is
        returned unscaled to avoid dividing by zero.
        """
        vec = np.asarray(vec, dtype=np.float32)
        norm = np.linalg.norm(vec)
        return vec / norm if norm else vec

    def _embed(self, text: str) -> np.ndarray:
        """Embed ``text`` into a normalized ``(1, embedding_dim)`` float32 row.

        Raises:
            ValueError: If the embedding width differs from ``embedding_dim``.
        """
        row = self._normalize(create_embedding(text)).reshape(1, -1)
        if row.shape[1] != self.embedding_dim:
            raise ValueError(
                f"embedding has dimension {row.shape[1]}, "
                f"expected {self.embedding_dim}"
            )
        return np.ascontiguousarray(row)

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    def insert(self, hyp: Hypothesis) -> None:
        """Upsert a hypothesis: delete old vector (if any) then add the new one.

        The embedding is computed before any state is touched, and the
        uuid -> int id bookkeeping is recorded only after the vector has been
        added, so a failing embedding call leaves the store unchanged.

        Args:
            hyp: Hypothesis to store; its ``dimension_id``, ``hypothesis_id``
                and ``description`` attributes are read.

        Raises:
            ValueError: If the embedding width differs from ``embedding_dim``.
        """
        dim = hyp.dimension_id
        row = self._embed(hyp.description)

        self._init_dimension(dim)
        index = self.stores[dim]
        id_map = self.id_maps[dim]

        int_id = id_map.get(hyp.hypothesis_id)
        is_new = int_id is None
        if is_new:
            int_id = self.next_id[dim]
        ids = np.array([int_id], dtype=np.int64)

        if not is_new:
            # Replace semantics: drop the stale vector before re-adding.
            index.remove_ids(ids)
        index.add_with_ids(row, ids)

        if is_new:
            id_map[hyp.hypothesis_id] = int_id
            self.next_id[dim] = int_id + 1
        self.hypotheses[dim][int_id] = hyp

    # Alias for readability
    update = insert

    def get(self, dimension_id: int, hypothesis_id: str) -> Optional[Hypothesis]:
        """Return the stored hypothesis, or ``None`` if it is unknown.

        Lookups use ``.get`` on the outer mappings so that querying an unknown
        dimension does not create empty entries in the defaultdicts.
        """
        int_id = self.id_maps.get(dimension_id, {}).get(hypothesis_id)
        if int_id is None:
            return None
        return self.hypotheses.get(dimension_id, {}).get(int_id)

    def semantic_search(self, query: str, dimension_id: int, top_k: int = 5) -> List[Hypothesis]:
        """Exact cosine-similarity search within a dimension.

        Args:
            query: Text to embed and search for.
            dimension_id: Dimension whose index is searched.
            top_k: Maximum number of hypotheses to return.

        Returns:
            Up to ``top_k`` hypotheses in the order FAISS returns them. Empty
            when ``top_k`` is not positive or the dimension holds nothing; in
            those cases neither the embedding function nor FAISS is called.

        Raises:
            ValueError: If the query embedding width differs from
                ``embedding_dim``.
        """
        if top_k <= 0:
            return []
        stored = self.hypotheses.get(dimension_id)
        index = self.stores.get(dimension_id)
        if not stored or index is None:
            return []

        row = self._embed(query)
        # Asking for more neighbours than exist only yields -1 padding.
        k = min(top_k, len(stored))
        _, indices = index.search(row, k)

        # -1 marks an unfilled result slot; int() turns numpy ids into dict keys.
        return [stored[int(int_id)] for int_id in indices[0] if int_id != -1]

    def list_all_one(self, dim: int) -> List[Hypothesis]:
        """List all Hypothesis instances in a single dimension.

        Returns an empty list for a dimension that has never been written to.
        """
        return list(self.hypotheses.get(dim, {}).values())

    def list_all(self, dims: List[int]) -> Dict[int, List[Hypothesis]]:
        """Map each requested dimension id to the list of its hypotheses."""
        return {dim: self.list_all_one(dim) for dim in dims}