from utils.insight_tree import InsightTree
from utils.rag_retrieval import *
from scipy.spatial.distance import cosine
import math
import random
import torch, gc


class RAG:
    """Select memory nodes from an ``InsightTree`` to use as retrieval context.

    Nodes are chosen either uniformly at random or by cosine distance between
    an encoded query idea and each node's stored embedding. Selected nodes are
    returned as ``{'idea': ..., 'score': ...}`` dicts, or rendered into a
    newline-joined context string when ``add_context=True``.
    """

    def __init__(self, tree: "InsightTree", retrieve_model):
        """
        Args:
            tree: Tree whose ``nodes`` mapping (values iterated) is searched.
                Each node is read for ``idea``, ``mean_score`` and ``embedding``.
            retrieve_model: Embedding model handed to ``encode_text``; it is
                moved to the CPU via ``.to("cpu")``. ``None`` is accepted so the
                instance can still be constructed; retrieval methods then raise
                ``NotImplementedError`` through ``is_ok``.
        """
        # Guard the .to() call: unguarded, a None model raised AttributeError
        # here and the None check in is_ok() could never fire.
        self.retrieve_model = retrieve_model.to("cpu") if retrieve_model is not None else None
        self.tree = tree

    def is_ok(self):
        """Raise ``NotImplementedError`` if no retrieve model is configured."""
        if self.retrieve_model is None:
            raise NotImplementedError("Retrieve model is None")

    @staticmethod
    def to_context(memory_nodes):
        """Render node dicts as ``"Score: <score> | Idea: <idea>"`` lines joined by newlines."""
        return "\n".join([f"Score: {n['score']} | Idea: {n['idea']}" for n in memory_nodes])

    def sample_random_nodes(self, exclude_nodes=None, n=3, add_context=False):
        """Return up to ``n`` randomly sampled tree nodes.

        Args:
            exclude_nodes: Nodes to leave out of the sample (compared with ``in``).
            n: Maximum number of nodes; fewer are returned if the tree is smaller.
            add_context: If true, return the ``to_context`` string instead.

        Returns:
            A list of ``{'idea': node.idea, 'score': node.mean_score}`` dicts,
            or a context string when ``add_context`` is true.

        Raises:
            NotImplementedError: If no retrieve model is configured.
        """
        self.is_ok()
        if exclude_nodes is None:
            exclude_nodes = []
        all_nodes = [node for node in self.tree.nodes.values() if node not in exclude_nodes]
        sampled = random.sample(all_nodes, min(n, len(all_nodes)))
        memory_nodes = [{'idea': node.idea, 'score': node.mean_score} for node in sampled]

        if add_context:
            return self.to_context(memory_nodes)
        return memory_nodes

    def get_most_distant_nodes(self, parent_idea, top_n=3, distant_ideas=True, add_context=False):
        """Rank embedded tree nodes by cosine distance to ``parent_idea``.

        Args:
            parent_idea: Text encoded with ``encode_text(..., task='query')``.
                The result is assumed to be a vector comparable with each
                ``node.embedding`` (TODO confirm against ``encode_text``).
            top_n: Number of nodes to return.
            distant_ideas: If true, return the farthest nodes (largest cosine
                distance) first; otherwise return the nearest first.
            add_context: If true, return the ``to_context`` string instead.

        Returns:
            A list of ``{'idea', 'score'}`` dicts, or a context string.
            Nodes whose ``embedding`` is ``None`` are skipped.

        Raises:
            NotImplementedError: If no retrieve model is configured.
        """
        self.is_ok()
        parent_emb = encode_text(parent_idea, self.retrieve_model, task='query')
        gc.collect()
        torch.cuda.empty_cache()

        ranked = []
        undefined = []
        for node in self.tree.nodes.values():
            if node.embedding is None:
                continue
            dist = cosine(parent_emb, node.embedding)
            # cosine() yields NaN for zero-norm vectors. NaN keys make sort
            # order arbitrary, so such nodes are kept aside and ranked last.
            if math.isnan(dist):
                undefined.append((dist, node))
            else:
                ranked.append((dist, node))
        ranked.sort(key=lambda pair: pair[0], reverse=distant_ideas)
        ranked.extend(undefined)
        memory_nodes = [{'idea': node.idea, 'score': node.mean_score} for _, node in ranked[:top_n]]

        if add_context:
            return self.to_context(memory_nodes=memory_nodes)
        return memory_nodes
