from main_pipeline import gpt_call
import json
import os
from prompt_templates import PRINCIPLE_MATCH_PROMPT

class MemoryManager:
    """Persist per-task lists of principles in a JSON file and merge new ones in.

    ``self.memory`` maps a task description (str) to a list of principle
    strings. Merging asks an LLM (via ``gpt_call``) how new principles relate
    to the stored ones and drops old principles it reports as superseded.
    """

    def __init__(self, path="memory/memory.json"):
        """Create a manager backed by the JSON file at *path*, loading it if present."""
        self.path = path
        self.memory = self.load()

    def load(self):
        """Return the parsed contents of ``self.path``, or ``{}`` if it does not exist."""
        if os.path.exists(self.path):
            with open(self.path, 'r', encoding='utf-8') as f:
                return json.load(f)
        return {}

    def save(self):
        """Write ``self.memory`` to ``self.path`` as indented UTF-8 JSON.

        The parent directory is created first if needed, so the default
        ``memory/memory.json`` path also works when ``memory/`` is missing.
        """
        directory = os.path.dirname(self.path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(self.path, 'w', encoding='utf-8') as f:
            json.dump(self.memory, f, indent=2, ensure_ascii=False)

    def retrieve(self, task_desc: str):
        """Look up a stored task, ignoring surrounding whitespace.

        Returns:
            ``(stored_key, principles)`` for the first key whose stripped form
            equals ``task_desc.strip()``, otherwise ``(None, [])``.
        """
        target = task_desc.strip()
        for task, principles in self.memory.items():
            if task.strip() == target:
                return task, principles
        return None, []

    def add_task(self, task_desc: str, principles: list):
        """Store *principles* under *task_desc*, replacing any existing entry."""
        self.memory[task_desc] = principles

    def merge_principles(self, task_desc: str, new_principles: list):
        """Merge *new_principles* into the entry for *task_desc* (in memory only).

        Call ``save()`` afterwards to persist the result.
        """
        old_principles = self.memory.get(task_desc, [])
        filtered = self._resolve_conflicts(old_principles, new_principles)
        self.memory[task_desc] = filtered

    def _resolve_conflicts(self, old: list, new: list) -> list:
        """Combine *old* and *new* principles, dropping old ones the LLM flags.

        If either list is empty the other is returned unchanged. Otherwise the
        LLM is asked (via ``PRINCIPLE_MATCH_PROMPT``) to classify how the new
        principles relate to the old ones, and the reply is applied by
        ``_apply_relations``. If the reply is not valid JSON, the de-duplicated
        union of both lists is returned (old entries first, order preserved).
        """
        if not old:
            return new
        if not new:
            return old

        prompt = PRINCIPLE_MATCH_PROMPT.format(
            old="\n".join(old),
            new="\n".join(new)
        )
        result = gpt_call(prompt)

        try:
            relations = json.loads(result)
        except (TypeError, ValueError) as e:
            # ValueError covers json.JSONDecodeError; TypeError covers a
            # non-string reply such as None. Keep everything in that case.
            print(f"[ERROR]: {e}")
            return list(dict.fromkeys(old + new))

        return self._apply_relations(old, new, relations)

    @staticmethod
    def _apply_relations(old: list, new: list, relations) -> list:
        """Apply parsed LLM relation records to the old/new principle lists.

        Each record is expected to be a dict with ``"old"``, ``"new"`` and
        ``"relation"`` keys. Relation labels are matched case-insensitively
        after stripping whitespace:

        * ``Redundant`` / ``Conflicting``: the referenced old principle is
          dropped and the new one is kept.
        * ``Irrelevant``: the new principle is kept; old ones are untouched.

        Records that are not dicts, or that carry any other label, are ignored.
        If *relations* is not a list at all, the de-duplicated union of *old*
        and *new* is returned instead.

        Returns:
            Surviving old principles (in their original order) followed by the
            retained new principles, with duplicates removed.
        """
        if not isinstance(relations, list):
            print(
                "[ERROR]: expected a JSON list of relations, got "
                f"{type(relations).__name__}"
            )
            return list(dict.fromkeys(old + new))

        dropped_old = set()
        retained_new = []
        for match in relations:
            if not isinstance(match, dict):
                continue
            # `or ""` also guards explicit nulls (e.g. "old": null on an
            # Irrelevant record), which .get(key, "") would pass through.
            old_rule = str(match.get("old") or "").strip()
            new_rule = str(match.get("new") or "").strip()
            relation = str(match.get("relation") or "").strip().casefold()

            if relation in ("redundant", "conflicting"):
                if old_rule:
                    dropped_old.add(old_rule)
            elif relation != "irrelevant":
                continue
            if new_rule:
                retained_new.append(new_rule)

        # NOTE(review): a new principle the LLM never mentions is not kept
        # (same as before); confirm PRINCIPLE_MATCH_PROMPT requests one record
        # per new principle.
        kept_old = [rule for rule in old if rule.strip() not in dropped_old]
        return list(dict.fromkeys(kept_old + retained_new))

