from .extraction import memory_extraction
from .consolidation import MemoryManager
from .retrieval import memory_retriever
from .refinement import memory_refiner
import json
import os
import uuid
from dotenv import load_dotenv

# Load variables from a .env file into the process environment at import
# time. The Memory class below reads `memory_storage_path` via os.getenv, so
# this presumably is where that setting is expected to come from when it is
# not already exported in the environment.
load_dotenv()


class Memory:
    """Facade over the extraction, retrieval, refinement and storage helpers.

    The storage location is taken from the ``memory_storage_path`` environment
    variable, both when the storage manager is created and each time the
    stored memories are read back for retrieval.
    """

    def __init__(self) -> None:
        """Create the storage manager and the extractor/retriever/refiner."""
        self.manager = MemoryManager(os.getenv('memory_storage_path'))
        self.extractor = memory_extraction()
        self.retriever = memory_retriever()
        self.refiner = memory_refiner()

    def memory_extraction_embedding(self, goal):
        """Return whatever the extractor's ``store_embedding`` yields for *goal*."""
        return self.extractor.store_embedding(goal)

    def memory_retriever(self, embedding):
        """Look up the stored memory that best matches *embedding*.

        The storage file named by ``memory_storage_path`` is parsed as JSON
        and its ``memories`` collection is handed to the retriever.

        Returns:
            A ``(retriever_number, retriever_content)`` tuple.
            ``retriever_number`` is the value reported by the retriever and
            ``retriever_content`` is ``memories[retriever_number]``. When the
            storage holds no memories, or the retriever reports ``-1``, the
            content is ``None`` and the number is ``-1``.
        """
        with open(os.getenv('memory_storage_path'), 'r', encoding='utf-8') as f:
            stored = json.load(f)

        # Treat "no document", "no 'memories' key" and "empty collection"
        # alike: there is nothing to compare against, so skip the retriever
        # instead of failing on a missing key or on indexing an empty
        # collection afterwards.
        memories = stored.get('memories') if isinstance(stored, dict) else None
        if not memories:
            return -1, None

        retriever_number = self.retriever.retrieve_most_similar_memory(embedding, memories)
        if retriever_number == -1:
            return retriever_number, None
        return retriever_number, memories[retriever_number]

    def memory_refiner(self, goal, retriever_content, is_success):
        """Return the refiner's ``memory_refinement`` result for the given inputs."""
        return self.refiner.memory_refinement(goal, retriever_content, is_success)

    def memory_extraction(self, action_data, goal_data, result_data) -> dict:
        """Assemble a memory record from raw action, goal and result data.

        The extractor is called in three dependent steps: the high-level goal
        is derived from *action_data*, the low-level instructions from
        *action_data* plus that goal, and the task state from *goal_data*
        plus both previous results.

        Returns:
            A dict with the keys ``high_level_goal``,
            ``low_level_instructions``, ``task_state``, ``result`` (the
            unmodified *result_data*) and ``task_id`` (a fresh UUID string).
        """
        high_level_goal = self.extractor.extract_high_level_goal(action_data)
        low_level_instructions = self.extractor.extract_low_level_instructions(
            action_data, high_level_goal
        )
        task_state = self.extractor.extract_task_state(
            goal_data, high_level_goal, low_level_instructions
        )
        return {
            'high_level_goal': high_level_goal,
            'low_level_instructions': low_level_instructions,
            'task_state': task_state,
            'result': result_data,
            'task_id': str(uuid.uuid1()),
        }

    def memory_consolidation_success(self, retriever_number, memory_bank) -> None:
        """Replace the memory at *retriever_number* with *memory_bank*.

        ``-1`` is the "nothing matched" sentinel returned by
        ``memory_retriever``; in that case there is no entry to replace, so
        only the add is performed.
        """
        # NOTE(review): MemoryManager.delete_memory is not visible here. If it
        # indexes a list, forwarding -1 would address the last entry rather
        # than "none", hence the guard.
        if retriever_number != -1:
            self.manager.delete_memory(retriever_number)
        self.manager.add_memory(memory_bank)

    def memory_consolidation_failure(self, memory_bank) -> None:
        """Store *memory_bank* as a new entry without removing anything."""
        self.manager.add_memory(memory_bank)
