from typing import List

from src.schema import (
    User,
    Memory,
)
from src.utils.json import read_json_file


class MemoryService:
    """In-memory index of users' memories, loaded once from a JSON file.

    Each item of the decoded file is unpacked into ``User(**item)``. For each
    of the five memory categories (``dialogue``, ``observation``, ``summary``,
    ``episodic_memory``, ``semantic_memory``) the service keeps:

    * ``<category>_map``: ``memory_id -> Memory`` across all users, and
    * ``<category>_prompt_map``: ``user_id -> str``, the user's contents in
      that category joined with newlines and stripped.

    ``memory_map`` merges the five id maps into a single lookup table.
    """

    def __init__(self, mem_data_path: str) -> None:
        """Load users from *mem_data_path* and build all lookup tables.

        Args:
            mem_data_path: Path handed to ``read_json_file``. The decoded
                value is iterated and each item is unpacked into ``User``.
        """
        memory_data = read_json_file(file_path=mem_data_path)
        users = [User(**user_dict) for user_dict in memory_data]

        # memory_id -> Memory, one table per category.
        self.dialogue_map = self._index_by_id(users, "dialogue")
        self.observation_map = self._index_by_id(users, "observation")
        self.summary_map = self._index_by_id(users, "summary")
        self.episodic_memory_map = self._index_by_id(users, "episodic_memory")
        self.semantic_memory_map = self._index_by_id(users, "semantic_memory")

        # Merged view over every category. If the same memory_id occurs in
        # several categories, the later entry in this unpacking order wins.
        # NOTE(review): presumably memory_ids are unique across categories;
        # confirm against the data producer.
        self.memory_map = {
            **self.dialogue_map,
            **self.observation_map,
            **self.summary_map,
            **self.episodic_memory_map,
            **self.semantic_memory_map,
        }

        # user_id -> newline-joined prompt text, one table per category.
        self.dialogue_prompt_map = self._prompt_by_user(users, "dialogue")
        self.observation_prompt_map = self._prompt_by_user(users, "observation")
        self.summary_prompt_map = self._prompt_by_user(users, "summary")
        self.episodic_memory_prompt_map = self._prompt_by_user(
            users, "episodic_memory"
        )
        self.semantic_memory_prompt_map = self._prompt_by_user(
            users, "semantic_memory"
        )

    @staticmethod
    def _index_by_id(users: List[User], category: str) -> dict:
        """Map ``memory_id`` to memory for every *category* memory of *users*.

        On duplicate ids, the memory seen last (in user, then list order)
        is kept.
        """
        return {
            mem.memory_id: mem
            for user in users
            for mem in getattr(user, category)
        }

    @staticmethod
    def _prompt_by_user(users: List[User], category: str) -> dict:
        """Map ``user_id`` to that user's *category* contents joined by newlines.

        The joined text is stripped of leading/trailing whitespace. Users with
        no memories in *category* map to ``""``.
        """
        return {
            user.user_id: "\n".join(
                [mem.content for mem in getattr(user, category)]
            ).strip()
            for user in users
        }

    def get_memories(self, memory_ids: List[str]) -> List[Memory]:
        """Return the memories for *memory_ids*, in the same order.

        Raises:
            KeyError: if any id is not present in ``memory_map``.
        """
        return [self.memory_map[memory_id] for memory_id in memory_ids]

    def get_dialogue_prompt(self, user_id: str) -> str:
        """Return *user_id*'s dialogue text; raises ``KeyError`` if unknown."""
        return self.dialogue_prompt_map[user_id]

    def get_observation_prompt(self, user_id: str) -> str:
        """Return *user_id*'s observation text; raises ``KeyError`` if unknown."""
        return self.observation_prompt_map[user_id]

    def get_summary_prompt(self, user_id: str) -> str:
        """Return *user_id*'s summary text; raises ``KeyError`` if unknown."""
        return self.summary_prompt_map[user_id]

    def get_episodic_memory_prompt(self, user_id: str) -> str:
        """Return *user_id*'s episodic-memory text; raises ``KeyError`` if unknown."""
        return self.episodic_memory_prompt_map[user_id]

    def get_semantic_memory_prompt(self, user_id: str) -> str:
        """Return *user_id*'s semantic-memory text; raises ``KeyError`` if unknown."""
        return self.semantic_memory_prompt_map[user_id]
