import hashlib
from collections import defaultdict
import numpy as np

class CacheManager:
    """Per-question two-tier (cold/hot) cache of facts and their embeddings.

    Questions are keyed by the SHA-256 hex digest of their text. Fact texts
    are normalised with ``str.lower().strip()`` before storage and comparison.
    Every distinct fact goes into the cold tier. A fact is copied into the hot
    tier once it has been added at least ``threshold`` times for the same
    question.
    """

    def __init__(self):
        # question_id -> list of {"fact": normalised text, "embedding": vector},
        # one entry per distinct fact; the first embedding seen for a fact wins.
        self.cold_cache = defaultdict(list)
        # question_id -> list of promoted entries, same shape as cold_cache.
        self.hot_cache = defaultdict(list)

        # question_id -> defaultdict(int) mapping normalised fact -> times added.
        # The outer container is a plain dict (inner dicts are created on demand
        # by _get_fact_count_dict) because a defaultdict whose factory is a
        # lambda cannot be pickled.
        self.fact_count = {}

    def _get_fact_count_dict(self, qid):
        """Return the fact -> count mapping for ``qid``, creating it if absent."""
        if qid not in self.fact_count:
            self.fact_count[qid] = defaultdict(int)
        return self.fact_count[qid]

    # -----------------------------
    # Generate a unique ID for a question
    # -----------------------------
    def get_question_id(self, question_text):
        """Return the SHA-256 hex digest of ``question_text`` encoded as UTF-8."""
        return hashlib.sha256(question_text.encode('utf-8')).hexdigest()

    # -----------------------------
    # Add facts with embeddings to cold cache
    # facts_with_embeddings: list of dicts [{"fact": text, "embedding": np.array}]
    # -----------------------------
    def add_to_cold(self, question_text, facts_with_embeddings):
        """Record facts for a question in the cold tier and bump their counts.

        Args:
            question_text: The question the facts belong to.
            facts_with_embeddings: Iterable of ``{"fact": str, "embedding": ...}``
                dicts. ``None`` is treated as an empty batch.

        A fact already present in the cold tier (after normalisation) keeps its
        original embedding, but its count is still incremented. This includes
        repeats within the same call.
        """
        if facts_with_embeddings is None:
            return

        qid = self.get_question_id(question_text)
        fact_count_dict = self._get_fact_count_dict(qid)
        cold_entries = self.cold_cache[qid]
        # Build the set of cached facts once, so each duplicate check is O(1)
        # instead of a linear scan of the cold list.
        known_facts = {entry["fact"] for entry in cold_entries}

        for item in facts_with_embeddings:
            fact_text = item["fact"].lower().strip()
            embedding = item["embedding"]

            if fact_text not in known_facts:
                known_facts.add(fact_text)
                cold_entries.append({"fact": fact_text, "embedding": embedding})
            fact_count_dict[fact_text] += 1

    # -----------------------------
    # Promote facts from cold -> hot
    # -----------------------------
    def promote_to_hot(self, question_text, threshold=3):
        """Copy facts counted at least ``threshold`` times into the hot tier.

        Facts already in the hot tier are skipped. A fact is promoted with its
        cold-tier embedding. It is not promoted if that embedding is missing
        or ``None``. Promotion order follows the order in which facts were
        first counted. Calling this for a question with no recorded facts is
        a no-op and allocates no cache entries.
        """
        qid = self.get_question_id(question_text)
        fact_count_dict = self.fact_count.get(qid)
        if not fact_count_dict:
            return

        candidates = [fact for fact, count in fact_count_dict.items() if count >= threshold]
        if not candidates:
            return

        hot_entries = self.hot_cache[qid]
        already_hot = {entry["fact"] for entry in hot_entries}
        # First embedding per fact, matching a front-to-back search of the cold list.
        cold_embeddings = {}
        for entry in self.cold_cache[qid]:
            cold_embeddings.setdefault(entry["fact"], entry["embedding"])

        for fact_text in candidates:
            if fact_text in already_hot:
                continue
            embedding = cold_embeddings.get(fact_text)
            if embedding is not None:
                hot_entries.append({"fact": fact_text, "embedding": embedding})
                already_hot.add(fact_text)

    # -----------------------------
    # Retrieve cache for a question
    # -----------------------------
    def get_cache(self, question_text):
        """Return ``{"cold": [...], "hot": [...]}`` for ``question_text``.

        For a known question the lists are the live internal lists, so
        mutating them mutates the cache. For an unknown question, fresh
        empty lists are returned and nothing is stored.
        """
        qid = self.get_question_id(question_text)
        return {
            "cold": self.cold_cache.get(qid, []),
            "hot": self.hot_cache.get(qid, [])
        }

    # -----------------------------
    # Update cache with new facts (from fallback RAG2)
    # -----------------------------
    def update_cache(self, question_text, new_facts_with_embeddings, threshold=3):
        """Add ``new_facts_with_embeddings`` to the cold tier, then promote.

        ``threshold`` is forwarded to :meth:`promote_to_hot`. It defaults to
        that method's default of 3.
        """
        self.add_to_cold(question_text, new_facts_with_embeddings)
        self.promote_to_hot(question_text, threshold=threshold)
