import os
import json
from typing import List, Dict, Any, Optional, Optional
import numpy as np
import random
import faiss
from sentence_transformers import SentenceTransformer


class PersistentRAG:
    """
    Persistent similarity store for PROBLEMS, built on SentenceTransformers + FAISS.

    On-disk layout (inside ``<save_dir>_rag_database/``):
      - ``documents.json``: list of problem dicts, each with at least
        ``{"question": "..."}``; any extra keys are stored unchanged.
      - ``index.faiss``: ``IndexFlatIP`` over L2-normalized question
        embeddings, so inner product equals cosine similarity.
      - ``meta.json``: embedding model name, dimension and vector count.

    Invariant: the vector at index position ``i`` belongs to ``documents[i]``.

    Retrieval:
      1) retrieve()        -> top-k problems with a ``problem_score``.
      2) random_retrieve() -> k problems sampled uniformly at random.
      3) kth_similar()     -> similarity score of the k-th best match.
    """

    def __init__(
        self,
        save_dir: str = "solution_strategy",
        embedding_model: str = "all-MiniLM-L6-v2",
        device: Optional[str] = None,
    ):
        """
        Create the store and load any previously persisted state.

        Args:
            save_dir: Prefix of the storage directory; ``_rag_database`` is
                appended to it and the directory is created if missing.
            embedding_model: SentenceTransformer model name used for encoding.
            device: Passed through to ``SentenceTransformer``.

        Raises:
            ValueError: If ``meta.json`` records a different embedding model.
        """
        self.save_dir = f"{save_dir}_rag_database"
        os.makedirs(self.save_dir, exist_ok=True)

        self.embedding_model = embedding_model
        self.embedder = SentenceTransformer(embedding_model, device=device)

        # problem info
        self.documents: List[Dict[str, Any]] = []
        self.index: Optional["faiss.Index"] = None

        # files
        self.docs_file = os.path.join(self.save_dir, "documents.json")
        self.index_file = os.path.join(self.save_dir, "index.faiss")
        self.meta_file = os.path.join(self.save_dir, "meta.json")

        self._load_if_exists()

    def _load_if_exists(self) -> None:
        """
        Load documents, index and meta from disk when the files exist.

        After loading, the index is reconciled with the documents so that the
        position invariant holds even if the files on disk were out of sync.

        Raises:
            ValueError: If the stored embedding model differs from the current one.
        """
        print("Loading existing store...")

        # docs
        if os.path.exists(self.docs_file):
            with open(self.docs_file, "r", encoding="utf-8") as f:
                self.documents = json.load(f)
            print(f"- documents.json loaded ({len(self.documents)} items)")
        else:
            print("- documents.json not found (starting empty)")

        if os.path.exists(self.index_file):
            self.index = faiss.read_index(self.index_file)
            print(f"- index.faiss loaded ({self.index.ntotal} x {self.index.d})")
        else:
            print("- index.faiss not found (will build on first add)")

        # meta (optional)
        if os.path.exists(self.meta_file):
            with open(self.meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
            if meta.get("embedding_model") and meta["embedding_model"] != self.embedding_model:
                raise ValueError(
                    f"Stored embedding model '{meta['embedding_model']}' does not match current '{self.embedding_model}'"
                )

        self._reconcile_index_with_documents()

    def _reconcile_index_with_documents(self) -> None:
        """
        Rebuild the in-memory index when it does not match ``self.documents``.

        The two files are written one after the other, so a crash (or a
        deleted/edited file) can leave them with different item counts. Any
        later add would then append vectors at positions that no longer
        correspond to ``self.documents``. Nothing is written to disk here;
        the next successful add persists the repaired state.
        """
        n_docs = len(self.documents)
        n_vecs = 0 if self.index is None else int(self.index.ntotal)
        if n_docs == n_vecs:
            return

        print(
            f"- WARNING: documents ({n_docs}) and index ({n_vecs}) are out of sync; "
            "rebuilding index from documents"
        )
        if n_docs == 0:
            # Vectors without documents cannot be mapped back to anything.
            self.index = None
            return

        embs = self._encode([doc["question"] for doc in self.documents])
        rebuilt = faiss.IndexFlatIP(embs.shape[1])
        rebuilt.add(embs)
        self.index = rebuilt

    def _encode(self, texts: List[str]) -> np.ndarray:
        """Return float32, L2-normalized embeddings."""
        embs = self.embedder.encode(
            texts,
            normalize_embeddings=True,
            convert_to_numpy=True,
            batch_size=64,
            show_progress_bar=False,
        )
        # FAISS only accepts float32 input.
        if embs.dtype != np.float32:
            embs = embs.astype("float32")
        return embs

    def _atomic_write_text(self, path: str, text: str) -> None:
        """Write UTF-8 text to ``path`` via a fsynced ``.tmp`` file and ``os.replace``."""
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            f.write(text.encode("utf-8"))
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)

    def _atomic_write_faiss(self, index: "faiss.Index", path: str) -> None:
        """Write FAISS index atomically by writing to tmp and replacing."""
        tmp = path + ".tmp"
        faiss.write_index(index, tmp)
        os.replace(tmp, path)

    def _save_meta(self) -> None:
        """Persist the embedding model name, index dimension and vector count."""
        if self.index is None:
            dim = None
            ntotal = 0
        else:
            dim = int(self.index.d)
            ntotal = int(self.index.ntotal)
        meta = {
            "embedding_model": self.embedding_model,
            "dimension": dim,
            "ntotal_documents": ntotal,
        }
        self._atomic_write_text(self.meta_file, json.dumps(meta, ensure_ascii=False, indent=2))

    def _save_documents_and_index(self) -> None:
        """Persist documents, then the index (if any), then the meta file."""
        self._atomic_write_text(self.docs_file, json.dumps(self.documents, ensure_ascii=False, indent=2))

        if self.index is not None:
            self._atomic_write_faiss(self.index, self.index_file)

        self._save_meta()
        print(f"- Saved documents ({len(self.documents)}), index ({0 if self.index is None else self.index.ntotal})")

    def add_documents(self, new_docs: List[Dict[str, Any]]) -> None:
        """
        Add problems to the store and persist the result.

        Each doc must have a string 'question'; all other keys are stored
        unchanged. A doc is skipped when its question (stripped, lowercased)
        is already stored or already appeared earlier in the same batch.

        Raises:
            ValueError: If a doc has no string 'question', or if the new
                embeddings do not match the existing index dimension.
        """
        if not new_docs:
            print("- add_documents: nothing to add")
            return

        seen = {doc["question"].strip().lower() for doc in self.documents}
        to_add = []
        for d in new_docs:
            if "question" not in d or not isinstance(d["question"], str):
                raise ValueError("Each document must include a string 'question'")
            key = d["question"].strip().lower()
            if key in seen:
                continue
            # Record the key immediately so duplicates inside this batch are dropped too.
            seen.add(key)
            to_add.append(d)

        if not to_add:
            print("- add_documents: all documents already present")
            return

        # add problems
        q_texts = [d["question"] for d in to_add]
        q_embs = self._encode(q_texts)
        if self.index is None:
            self.index = faiss.IndexFlatIP(q_embs.shape[1])
        elif int(self.index.d) != int(q_embs.shape[1]):
            raise ValueError(
                f"Embedding dimension mismatch: index has dim {self.index.d}, "
                f"but new embeddings have dim {q_embs.shape[1]}"
            )
        self.index.add(q_embs)

        # Same order as the vectors above, which keeps positions aligned.
        self.documents.extend(to_add)

        self._save_documents_and_index()

        print(f"- Added {len(to_add)} problems")

    def retrieve(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """
        Return up to ``k`` stored problems most similar to ``query``.

        ``k`` is clamped to the range ``[1, index size]``. Each result is a
        copy of the stored dict plus a float ``"problem_score"`` (the inner
        product of the normalized embeddings). Returns ``[]`` when the store
        is empty.
        """
        if self.index is None or self.index.ntotal == 0 or not self.documents:
            print("- retrieve: no documents/index available")
            return []

        k = max(1, min(k, int(self.index.ntotal)))
        q = self._encode([query])
        sims, idxs = self.index.search(q, k)  # cosine (IP on normalized vectors)

        n_docs = len(self.documents)
        results: List[Dict[str, Any]] = []
        for score, idx in zip(sims[0], idxs[0]):
            i = int(idx)
            # FAISS labels a missing hit as -1; a negative or out-of-range id
            # must never be used to index the document list.
            if i < 0 or i >= n_docs:
                continue
            results.append({**self.documents[i], "problem_score": float(score)})
        return results

    def random_retrieve(self, k: int = 3) -> List[Dict[str, Any]]:
        """
        Return ``k`` distinct stored problems chosen uniformly at random.

        ``k`` is clamped to ``[1, number of documents]``. Results are shallow
        copies without a score. Returns ``[]`` when the store is empty.
        """
        if not self.documents:
            print("- random_retrieve: no documents available")
            return []

        k = max(1, min(k, len(self.documents)))
        idxs = random.sample(range(len(self.documents)), k)
        print(f"- random_retrieve: sampled {k} documents out of {len(self.documents)}")

        return [dict(self.documents[i]) for i in idxs]

    def kth_similar(self, query: str, k: int) -> float:
        """
        Return the similarity score of the k-th most similar item to `query`.

        k is 1-based (k=1 is the best match). If fewer than k results exist,
        the score of the last available result is returned. If the store is
        empty, NaN is returned.

        Raises:
            ValueError: If k < 1.
        """
        if k < 1:
            raise ValueError("k must be a positive integer (1-based rank)")

        # retrieve() returns results in FAISS order (best first) and clamps k
        # to the index size, so the last entry is the k-th match or the last available.
        results = self.retrieve(query, k)

        if not results:
            return float("nan")
        return float(results[-1]["problem_score"])

# NOTE: Everything between the triple quotes below is a manual smoke test for
# PersistentRAG that has been disabled by wrapping it in a string literal; it is
# never executed. To run it, remove the surrounding quotes. The faiss.read_index()
# call inside uses a hard-coded absolute path ("/Multi-LLM/...") that will need
# adjusting to the local save directory first.
'''
# code to test the PersistentRAG class
if __name__ == "__main__":
    example_docs = [
        {"question": "RAG stands for Retrieval-Augmented Generation", "number": 1},
        {"question": "FAISS is a library for efficient similarity search developed by Facebook", "number": 2},
        {"question": "SentenceTransformers create dense vector embeddings for text", "number": 3},
        {"question": "Python is a popular programming language for machine learning", "number": 4},
        {"question": "HuggingFace provides state-of-the-art NLP models", "number": 5},
        {"question": "Vector databases optimize similarity search operations", "number": 6},
        {"question": "The quick brown fox jumps over the lazy dog", "number": 7}
    ]
    save_dir = "solution_strategy"

    # initialize PersistentRAG
    rag1 = PersistentRAG(save_dir)
    rag1.add_documents(example_docs)
    
    query = "What is RAG?"
    result = rag1.retrieve(query, k=8)
    print(f"\nQuery: {query}; Results: {result}\n")
    score = rag1.kth_similar(query, k=8)
    print(f"Score: {score}\n")

    new_docs = [
        {"question": "PyTorch is a machine learning framework", "number": 8},
        {"question": "GPUs accelerate deep learning computations", "number": 9}
    ]

    rag2 = PersistentRAG(save_dir)
    rag2.add_documents(new_docs)
    result = rag2.retrieve(query, k=20)
    print(f"\nQuery: {query}; Results: {result}\n")

    result = rag2.retrieve(new_docs[0]["question"], k=20)
    print(f"\nQuery: {new_docs[0]['question']}; Results: {result}\n")

    # check the saved index and embeddings
    index = faiss.read_index("/Multi-LLM/solution_strategy_rag_database/index.faiss")
    vectors = np.array([index.reconstruct(i) for i in range(index.ntotal)])
    print(vectors)

    # check the embedding from SentenceTransformer
    emb = rag2._encode([example_docs[5]["question"]])[0]
    vec = vectors[5]
    print("Embedding (SentenceTransformer):", emb[:10])
    print("Embedding (FAISS reconstruct):", vec[:10])
    print("Difference:", np.linalg.norm(emb - vec))
'''   
