import pandas as pd
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer


class TransitionRetriever:
    """Similarity search over recorded transitions loaded from a CSV file.

    Every usable CSV row is embedded as ``"<Observation> | <Action>"`` and
    stored in a FAISS inner-product index. Embeddings are L2-normalised, so
    the inner-product score is the cosine similarity.
    """

    def __init__(self, csv_path, embedding_model_name='all-mpnet-base-v2'):
        """Load the embedding model, create the index and ingest the CSV.

        Args:
            csv_path: CSV source handed straight to ``pandas.read_csv``.
                ``Observation`` and ``Action`` must be non-empty for a row to
                be indexed; ``Next-Observation``, ``Successes`` and
                ``Failure`` are optional.
            embedding_model_name: Name of the SentenceTransformer model.
        """
        self.model = SentenceTransformer(embedding_model_name)
        self.dim = self.model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatIP(self.dim)

        # Raw transition records, used to build the retrieval results.
        # Position i here corresponds to vector i in the FAISS index.
        self.transition_db = []

        # Load and index the transition memory.
        self._load_transitions(csv_path)

    @staticmethod
    def _cell_text(value):
        """Return a CSV cell as text, mapping missing values to ``""``.

        pandas reads empty cells as NaN and ``str(NaN)`` is ``"nan"``, which
        would otherwise slip past the empty-row check.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    def _load_transitions(self, csv_path):
        """Read the CSV, skip incomplete rows, then embed and index the rest.

        Rows whose ``Observation`` or ``Action`` is missing or blank are
        skipped. All remaining rows are encoded in one batched call.
        """
        df = pd.read_csv(csv_path)

        texts = []
        records = []
        for _, row in df.iterrows():
            memory = self._cell_text(row.get("Observation", ""))
            action = self._cell_text(row.get("Action", ""))

            if not memory.strip() or not action.strip():
                continue  # skip rows without usable data

            # Text that represents this transition in embedding space.
            texts.append(memory + " | " + action)
            records.append({
                "memory": memory,
                "action": action,
                "next_memory": self._cell_text(row.get("Next-Observation", "")),
                "successes": self._cell_text(row.get("Successes", "")),
                "failure": self._cell_text(row.get("Failure", "")),
            })

        if not texts:
            return

        vectors = self.model.encode(texts, normalize_embeddings=True)
        # FAISS expects a C-contiguous float32 matrix.
        self.index.add(np.ascontiguousarray(vectors, dtype=np.float32))
        self.transition_db.extend(records)

    def retrieve(self, query_text, top_k=3):
        """Return up to ``top_k`` stored transitions most similar to the query.

        Args:
            query_text: Text to embed and search with.
            top_k: Maximum number of results requested from the index.

        Returns:
            A list of dicts, best match first, with keys ``experience``
            (1-based rank), ``score``, ``observaton``, ``action``,
            ``next_observation``, ``successes`` and ``failure``. The list is
            shorter than ``top_k`` when the index holds fewer vectors.
        """
        query_vec = self.model.encode([query_text], normalize_embeddings=True)
        query_vec = np.ascontiguousarray(query_vec, dtype=np.float32)
        scores, indices = self.index.search(query_vec, top_k)

        results = []
        ranked_hits = zip(indices[0], scores[0])
        for rank, (idx, score) in enumerate(ranked_hits, start=1):
            # FAISS pads with -1 when fewer than top_k vectors exist; a
            # negative index would wrap around to the last record.
            if idx < 0 or idx >= len(self.transition_db):
                continue
            record = self.transition_db[idx]
            results.append({
                "experience": rank,
                "score": float(score),
                # NOTE: the misspelled key is kept as-is because callers may
                # already read "observaton".
                "observaton": record["memory"],
                "action": record["action"],
                "next_observation": record["next_memory"],
                "successes": record["successes"],
                "failure": record["failure"],
            })
        return results


class RuleRetriever:
    """Similarity search over constraint rules loaded from a CSV file.

    Every usable CSV row is embedded as ``"<Observation> | <Action>"`` and
    stored in a FAISS inner-product index; only the row's ``Constraint`` text
    is kept for the results. Embeddings are L2-normalised, so the
    inner-product score is the cosine similarity.
    """

    def __init__(self, csv_path, embedding_model_name='all-mpnet-base-v2'):
        """Load the embedding model, create the index and ingest the CSV.

        Args:
            csv_path: CSV source handed straight to ``pandas.read_csv``.
                ``Observation`` and ``Action`` must be non-empty for a row to
                be indexed; ``Constraint`` is optional.
            embedding_model_name: Name of the SentenceTransformer model.
        """
        self.model = SentenceTransformer(embedding_model_name)
        self.dim = self.model.get_sentence_embedding_dimension()
        self.index = faiss.IndexFlatIP(self.dim)

        # Raw rule records, used to build the retrieval results.
        # Position i here corresponds to vector i in the FAISS index.
        self.rule_db = []

        # Load and index the rule memory.
        self._load_rules(csv_path)

    @staticmethod
    def _cell_text(value):
        """Return a CSV cell as text, mapping missing values to ``""``.

        pandas reads empty cells as NaN and ``str(NaN)`` is ``"nan"``, which
        would otherwise slip past the empty-row check.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    def _load_rules(self, csv_path):
        """Read the CSV, skip incomplete rows, then embed and index the rest.

        Rows whose ``Observation`` or ``Action`` is missing or blank are
        skipped. All remaining rows are encoded in one batched call.
        """
        df = pd.read_csv(csv_path)

        texts = []
        records = []
        for _, row in df.iterrows():
            observation = self._cell_text(row.get("Observation", ""))
            action = self._cell_text(row.get("Action", ""))

            if not observation.strip() or not action.strip():
                continue  # skip rows without usable data

            # Text that represents this rule in embedding space.
            texts.append(observation + " | " + action)
            records.append({
                "Constraint": self._cell_text(row.get("Constraint", "")),
            })

        if not texts:
            return

        vectors = self.model.encode(texts, normalize_embeddings=True)
        # FAISS expects a C-contiguous float32 matrix.
        self.index.add(np.ascontiguousarray(vectors, dtype=np.float32))
        self.rule_db.extend(records)

    def retrieve(self, query_text, top_k=3):
        """Return up to ``top_k`` constraints whose rule best matches the query.

        Args:
            query_text: Text to embed and search with.
            top_k: Maximum number of results requested from the index.

        Returns:
            A list of dicts, best match first, with keys
            ``Failure_experience`` (1-based rank) and ``Constraint``. The
            list is shorter than ``top_k`` when the index holds fewer vectors.
        """
        query_vec = self.model.encode([query_text], normalize_embeddings=True)
        query_vec = np.ascontiguousarray(query_vec, dtype=np.float32)
        # Similarity scores are not part of this retriever's output.
        _scores, indices = self.index.search(query_vec, top_k)

        results = []
        for rank, idx in enumerate(indices[0], start=1):
            # FAISS pads with -1 when fewer than top_k vectors exist; a
            # negative index would wrap around to the last record.
            if idx < 0 or idx >= len(self.rule_db):
                continue
            record = self.rule_db[idx]
            results.append({
                "Failure_experience": rank,
                "Constraint": record["Constraint"],
            })
        return results