import hashlib
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer


class ConstraintRetriever:
    """Semantic retriever over constraint rules.

    Each rule is embedded with a SentenceTransformer model and stored in a
    FAISS inner-product index. Embeddings are L2-normalised before indexing
    and before querying, so the inner-product score equals cosine similarity.

    Attributes:
        model: The SentenceTransformer used to embed rules and queries.
        dim: Embedding dimension reported by the model.
        index: ``faiss.IndexFlatIP`` holding one vector per indexed rule.
        indexed_rule_ids: Set of rule ids already present in the index
            (used to make additions idempotent / incremental).
        rule_db: List of rule dicts in insertion order; position ``i``
            corresponds to vector ``i`` in ``index``. Each dict has the keys
            ``"rule_id"``, ``"symbolic"``, ``"description"``, ``"type"`` and
            ``"frequency"``.
    """

    def __init__(self, embedding_model_name='all-mpnet-base-v2'):
        """Load the embedding model and create an empty index.

        Args:
            embedding_model_name: Name passed to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(embedding_model_name)
        self.dim = self.model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatIP(self.dim)

        # Ids of rules already added to the index (for incremental updates).
        self.indexed_rule_ids = set()
        # Original rule records, kept so search hits can be mapped back to
        # full rule information. Position in this list == row in the index.
        self.rule_db = []

    def add_rule(self, rule_id, symbolic, description, types, frequency):
        """Embed a single rule and append it to the index.

        A rule whose ``rule_id`` has already been indexed is silently skipped.

        Args:
            rule_id: Hashable identifier used for de-duplication.
            symbolic: Symbolic form of the rule (str).
            description: Natural-language description (str).
            types: Rule type label (str); stored under the ``"type"`` key.
            frequency: Frequency value; stringified for the embedding text.

        Raises:
            TypeError: If ``symbolic``, ``description`` or ``types`` is not a
                string.
        """
        if rule_id in self.indexed_rule_ids:
            return  # already indexed, skip

        text = ": ".join((symbolic, description, types, str(frequency)))
        vector = self.model.encode([text], normalize_embeddings=True)[0]

        # FAISS expects a 2-D float32 array of shape (n_vectors, dim).
        self.index.add(np.asarray(vector, dtype=np.float32).reshape(1, -1))
        self.indexed_rule_ids.add(rule_id)
        self.rule_db.append({
            "rule_id": rule_id,
            "symbolic": symbolic,
            "description": description,
            "type": types,
            "frequency": frequency,
        })

    def add_rules_batch(self, rule_list):
        """Add every rule in ``rule_list`` via :meth:`add_rule`.

        Args:
            rule_list: Iterable of dicts with the keys ``"rule_id"``,
                ``"symbolic"``, ``"description"``, ``"type"`` and
                ``"frequency"``.
        """
        for rule in rule_list:
            self.add_rule(
                rule["rule_id"],
                rule["symbolic"],
                rule["description"],
                rule["type"],
                rule["frequency"],
            )

    def retrieve(self, query_text, top_k=5):
        """Return up to ``top_k`` indexed rules most similar to ``query_text``.

        Args:
            query_text: Free-text query to embed and search with.
            top_k: Maximum number of hits to return.

        Returns:
            A list of dicts ordered by descending score, each with the keys
            ``"rank"`` (1-based), ``"score"``, ``"rule_id"``, ``"symbolic"``,
            ``"description"``, ``"type"`` and ``"frequency"``. The list is
            empty when no rules are indexed or ``top_k`` is not positive.
        """
        if top_k <= 0 or not self.rule_db:
            return []

        query_vec = np.asarray(
            self.model.encode([query_text], normalize_embeddings=True),
            dtype=np.float32,
        )
        # Never ask for more neighbours than there are stored rules.
        k = min(top_k, len(self.rule_db))
        scores, indices = self.index.search(query_vec, k)

        results = []
        for i, idx in enumerate(indices[0]):
            idx = int(idx)
            # FAISS pads missing neighbours with label -1; a negative index
            # would otherwise wrap around and return the last rule.
            if idx < 0 or idx >= len(self.rule_db):
                continue
            rule = self.rule_db[idx]
            results.append({
                "rank": i + 1,
                "score": float(scores[0][i]),
                "rule_id": rule['rule_id'],
                "symbolic": rule['symbolic'],
                "description": rule['description'],
                "type": rule['type'],
                "frequency": rule['frequency'],
            })
        return results


def sync_step_constraints_to_retriever(prev_rule_set, curr_rule_set, retriever):
    """Push constraints that are new in the current step into ``retriever``.

    The previous and current rule collections are compared by ``"rule_id"``;
    every rule in ``curr_rule_set`` whose id does not occur in
    ``prev_rule_set`` is handed to ``retriever.add_rule``.

    Args:
        prev_rule_set: Iterable of rule dicts from the previous step.
        curr_rule_set: Iterable of rule dicts from the current step.
        retriever: Object exposing ``add_rule(rule_id, symbolic, description,
            types, frequency)``.
    """
    known_ids = set()
    for old_rule in prev_rule_set:
        known_ids.add(old_rule["rule_id"])

    # Collect all additions first, then apply them, so a malformed entry in
    # curr_rule_set is detected before anything is added.
    pending = list(filter(lambda r: r["rule_id"] not in known_ids, curr_rule_set))

    for entry in pending:
        retriever.add_rule(
            entry["rule_id"],
            entry["symbolic"],
            entry["description"],
            entry["type"],
            entry["frequency"],
        )


# # Example usage:
# if __name__ == '__main__':
#     retriever = ConstraintRetriever()

#     step0 = [
#         {
#             "rule_id": "0",
#             "symbolic": "Open(?agent, ?container) → Reachable(?agent, ?container)",
#             "description": "An agent must be able to reach a container before opening it.",
#             "type": "precondition",
#             "frequency": 1
#         }
#     ]

#     step1 = [
#         {
#             "rule_id": "0",
#             "symbolic": "Open(?agent, ?container) → Reachable(?agent, ?container)",
#             "description": "An agent must be able to reach a container before opening it.",
#             "type": "precondition",
#             "frequency": 1
#         },
#         {
#             "rule_id": "1",
#             "symbolic": "PutObject(?agent, ?object) → Holding(?agent, ?object)",
#             "description": "The agent must be holding the object before attempting to put it down.",
#             "type": "precondition",
#             "frequency": 1
#         }
#     ]

#     sync_step_constraints_to_retriever(step0, step1, retriever)

#     query = "Agent tries to put something but is not holding it."
#     results = retriever.retrieve(query, top_k=2)
#     for r in results:
#         print(f"#{r['rank']}: {r['symbolic']}\nScore: {r['score']:.4f}\n")
