import datetime
import numpy as np
from typing import Callable
from concordia.typing import logging
from collections.abc import Sequence
from typing_extensions import override
from concordia.clocks import game_clock
from concordia.utils import helper_functions
from concordia.typing import entity_component
from concordia.typing import clock as ClockType
from concordia.typing import entity as entity_lib
from concordia.document import interactive_document
from concordia.typing.logging import LoggingChannel
from concordia.language_model import language_model
from concordia.agents import entity_agent_with_logging
from concordia.components import agent as agent_components
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.memory_bank import legacy_associative_memory
from concordia.utils import measurements as measurements_lib
from concordia.components.agent import action_spec_ignored, memory_component

# Pre-act keys passed to the context components assembled in `build_agent`.
# Each one also appears as the 'Key' field of the component's log entries.
# They start with a newline so each section is separated from the previous one.
INSTRUCTION_PRE_ACT_KEY = "\nRole Playing instructions"
GOAL_PRE_ACT_KEY = "\nOverarching Goal"


SHORT_TERM_MEMORY_PRE_ACT_KEY = "\nShort term memory"

RELEVANT_MEMORY_PRE_ACT_KEY = "\nRelevant memory"

# Pre-act key under which the SituationPerception component reports its output.
SITUATION_PERCEPTION_PRE_ACT_KEY = "\nSituation Perception"
# Context statement shown to the model before the question; SituationPerception
# fills the three placeholders with the goal and the two memory sections.
SITUATION_PERCEPTION_CONTEXT = """
{goal}
{relevant_memories}
{short_term_memory}
"""
# Question asked to the model; `{agent_name}` is replaced with the entity name.
SITUATION_PERCEPTION_PROMPT = """
You are given the goal, associative and short-term memories of {agent_name}.
Provide a full and detailed description that includes all relevant information 
for {agent_name} to clearly understand the situation they are involved in. 
Describe the current context, relevant background, emotional states, and any 
potential actions or decisions that may be needed. Make sure to include any 
relationships, associations, or additional information that could influence 
{agent_name}'s perception and response to the situation.
"""
# Answer prefix: formatted with the agent's name, used to seed the model's
# answer and prepended to the returned description.
SITUATION_PERCEPTION_ANSWER_PREFIX = "{agent_name} "

# Key used in the acting component's log entries.
AC_PRE_ACT_KEY = "\nAct"
# Context shown to the model before the call to action. The acting component
# fills it with the instructions, current time, goal, relevant memories and
# recent (short-term) memories.
AC_CONTEXT_FOR_ACTION = """
{instructions}

{time}

{goal}

{relevant_memories}

{recent_memories}
"""


# Few-shot prompt used to extract decision-relevant facts before acting.
# It holds two worked examples followed by the agent's actual goal, relevant
# memories and recent memories (the `{goal}`, `{relevant_memories}` and
# `{recent_memories}` placeholders). Both examples use the same section headers
# as the pre-act keys, chronologically ordered timestamps, and a complete list
# of the facts a decision depends on (e.g. prices relative to the budget).
AC_USEFUL_INFO_EXTRACTION_PROMPT = """
You have to extract useful information for decision making.

Overarching Goal: Manuel's goal is to ensure everyone enjoys the party a lot.

Relevant memory:
[06 Nov 2024 14:10:00] Manuel is organizing a party with his friends. He wants
everybody to enjoy the party and to eat the food they like the most.

[06 Nov 2024 14:12:00] The group discusses how their diverse dietary needs,
especially the vegan and lactose-intolerant preferences, will affect their
food choices for the party. Oswaldo is especially mindful of ensuring all
dishes have clear labels to avoid any mix-ups.

Short term memory:
[06 Nov 2024 16:01:45] [observation] Manuel -- "Oswaldo, how many people did
we invite to the party?"

[06 Nov 2024 16:02:00] [observation] Oswaldo -- "I think we are 7 in total."

[06 Nov 2024 16:02:20] [observation] Manuel -- "Oswaldo, which food do you
think we should buy for the party?"

[06 Nov 2024 16:02:45] [observation] Oswaldo -- "That is a very important
question. I think you have to remember that Aleja is lactose-intolerant."

[06 Nov 2024 16:03:00] [observation] Manuel -- "You're right! And now that
you mention it, I just remembered that Carlos is vegan."

[06 Nov 2024 16:03:15] [observation] Oswaldo -- "So where should we buy
the food?"

The useful information that Manuel needs to make a logical decision in order
to reach his goal:

- Manuel wants everybody to eat suitable food so they can enjoy the party.
- Manuel needs to decide where to buy the food.
- Aleja is lactose intolerant, so one dish shouldn't have lactose.
- Carlos is vegan, so Manuel needs to buy one vegan dish.
- The other 5 dishes do not have any restrictions.

######
Overarching Goal: Nicole wants to buy a house for her family. She wants a house
that fits her budget and that makes all her family members happy.

Relevant memory:
[11 July 2012 10:00:00] Nicole is married to Sebastian and they have a toddler
and a teenager.
[11 July 2012 10:01:00] Nicole and Sebastian love to go to the gym.
[11 July 2012 10:02:00] Nicole meets with Angie, the sales manager.

Short term memory:
[11 July 2012 10:05:00] [observation] Angie -- "Welcome to my office Nicole,
how is it going?"

[11 July 2012 10:05:10] [observation] Nicole -- "I am doing great! I am eager to
see the houses you found."

[11 July 2012 10:05:20] [observation] Angie -- "Surely, I have three great
options for you and your family."

[11 July 2012 10:05:40] [observation] Nicole -- "Awesome, I just hope they fit
my $400,000 budget."

[11 July 2012 10:06:00] [observation] Angie -- "Absolutely, Nicole. I kept your
budget and your family's needs in mind when selecting these options. Let me
walk you through the details."

[11 July 2012 10:06:20] [observation] Angie -- "The first option is a cozy
three-bedroom home in a family-friendly neighborhood with good schools nearby.
It's $380,000, so it's slightly under your budget."

[11 July 2012 10:06:40] [observation] Nicole -- "That sounds good! With our
toddler and teenager, a safe area and good schools are definitely important.
But we have guests very often, maybe an extra room would be nice."

[11 July 2012 10:07:00] [observation] Angie -- "Exactly. Now, the second option
is a bit larger, at $395,000. This house has four bedrooms, which would give
you some extra space as the kids grow or help you with your guests. It also
has a small home gym area, which I remember you and Sebastian would love!"

[11 July 2012 10:07:20] [observation] Nicole -- "A home gym? That's a huge
plus! We'd definitely make good use of it."

[11 July 2012 10:07:30] [observation] Angie -- "Now, the third option is a
newer three-bedroom house priced at $410,000. It's in a community with great
amenities: a shared gym, pool, and even a playground for your toddler."

The useful information that Nicole needs to make a logical decision in order
to reach her goal:

- Nicole is trying to buy a house for her family.
- Nicole has a budget of $400,000.
- Nicole has a toddler and a teenager.
- Nicole and her husband Sebastian love to go to the gym.
- Nicole's family frequently receives guests.
- Nicole would like to have at least 4 bedrooms.
- The first option costs $380,000, has 3 bedrooms, and is located in a
family-friendly neighborhood with good schools nearby.
- The second option costs $395,000 and has four bedrooms and a small home gym
area.
- The third option costs $410,000, which is above Nicole's budget. It is a newer
house with three bedrooms in a community with a lot of amenities like a shared
gym, a pool, and a playground.

######
{goal}

{relevant_memories}

{recent_memories}
"""
# Answer prefix that seeds the extraction; `{agent_name}` is the acting entity.
# Gender-neutral wording because the same prefix is used for every agent.
AC_USEFUL_INFO_EXTRACTION_ANSWER_PREFIX = (
    "The useful information that {agent_name} needs to make a logical "
    "decision in order to reach their goal:")

# Prompt used once per option of a multiple-choice action to ask the model what
# would happen if the agent picked that option. The goal is referred to through
# `{agent_name}` rather than a hard-coded pronoun so it fits any agent.
AC_ACTION_CONSEQUENCES_PROMPT = """
{goal}
{situation_perception}
{useful_info}
{agent_name} faces the following question: {question}
{agent_name} is considering the option of "{choice}"
{agent_name} analyzes what will happen if the option "{choice}" is chosen.
A detailed description of what will happen, according to {agent_name}'s goal,
is:

"""

# Prefix appended to the prompt above to seed the model's answer; it is also
# prepended to the returned consequence description.
AC_ACTION_CONSEQUENCES_ANSWER_PREFIX = """{agent_name}: if I respond to the question with "{choice}", it is a"""


def _get_class_name(object_: object) -> str:
    return object_.__class__.__name__


class GoalComponent(action_spec_ignored.ActionSpecIgnored):
    """Context component that reports a fixed goal string.

    The goal is set once at construction. Every time the pre-act value is
    computed it is also sent to the logging channel.
    """

    def __init__(
            self,
            goal: str | None = None,
            pre_act_key: str = GOAL_PRE_ACT_KEY,
            logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel
    ):
        """Initializes the component.

        Args:
          goal: The agent's goal. Any falsy value (None, "") is replaced by the
            placeholder "No goal specified".
          pre_act_key: Key under which the goal is reported.
          logging_channel: Channel that receives the reported goal.
        """
        super().__init__(pre_act_key)
        if not goal:
            goal = "No goal specified"
        self._goal = goal
        self._logging_channel = logging_channel

    def _make_pre_act_value(self) -> str:
        """Logs and returns the stored goal."""
        goal_text = self._goal
        log_entry = {'Key': self.get_pre_act_key(), 'Value': goal_text}
        self._logging_channel(log_entry)
        return goal_text


class RelevantMemory(action_spec_ignored.ActionSpecIgnored):
    """Retrieves memories associated with the current short-term memory.

    The short-term memory component's pre-act value is used as the retrieval
    query. Retrieved memories whose text is identical to a line of the
    short-term memory are dropped, so the same content is not shown twice.
    Deduplication keeps retrieval order, which makes the output deterministic.
    """

    def __init__(
            self,
            *,
            short_term_memory_component_name: str,
            pre_act_key: str,
            memory_component_name: str = memory_component.DEFAULT_MEMORY_COMPONENT_NAME,
            logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
            num_memories: int = 3,
    ):
        """Initializes the component.

        Args:
          short_term_memory_component_name: Name of the component whose pre-act
            value is used as the retrieval query.
          pre_act_key: Key under which this component's value is reported.
          memory_component_name: Name of the memory component to query.
          logging_channel: Channel that receives the computed value.
          num_memories: `limit` passed to the memory retrieval (defaults to 3,
            the previously hard-coded value). Memories already present in the
            short-term memory are removed afterwards, so fewer may be returned.
        """
        super().__init__(pre_act_key)
        self._short_term_memory_component_name = short_term_memory_component_name
        self._memory_component_name = memory_component_name
        self._logging_channel = logging_channel
        self._num_memories = num_memories

    def _get_short_term_memory(self):
        """Returns the pre-act value of the short-term memory component."""
        return self.get_named_component_pre_act_value(
            self._short_term_memory_component_name
        )

    def _make_pre_act_value(self) -> str:
        """Returns the relevant memories, one per line, and logs them."""
        memory = self.get_entity().get_component(
            self._memory_component_name,
            type_=memory_component.MemoryComponent
        )
        short_term_memory = self._get_short_term_memory()
        retrieved = memory.retrieve(
            query=short_term_memory,
            limit=self._num_memories,
            scoring_fn=legacy_associative_memory.RetrieveAssociative(
                add_time=True
            )
        )
        short_term_lines = set(short_term_memory.split('\n'))
        # dict.fromkeys removes duplicate texts while preserving retrieval
        # order; iterating a plain set would make the order depend on string
        # hash randomization and therefore vary between runs.
        unique_texts = dict.fromkeys(rm.text for rm in retrieved)
        relevant_memories = "\n".join(
            text for text in unique_texts if text not in short_term_lines
        )

        self._logging_channel(
            {
                'Key': self.get_pre_act_key(),
                'Value': relevant_memories
            }
        )

        return relevant_memories


class SituationPerception(action_spec_ignored.ActionSpecIgnored):
    """Asks the model for a detailed description of the agent's situation.

    The goal, relevant-memory and short-term-memory components' pre-act values
    are inserted into `prompt_context`. The model then answers `prompt`,
    starting from `answer_prefix`.
    """

    def __init__(
            self,
            *,
            model: language_model.LanguageModel,
            prompt_context: str,
            prompt: str,
            answer_prefix: str,
            goal_component_name: str,
            relevant_memories_component_name: str,
            short_term_memories_component_name: str,
            pre_act_key: str,
            logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel
    ):
        """Initializes the component.

        Args:
          model: Language model used to produce the description.
          prompt_context: Template with `{goal}`, `{relevant_memories}` and
            `{short_term_memory}` placeholders.
          prompt: Question template with an `{agent_name}` placeholder.
          answer_prefix: Answer prefix template with an `{agent_name}`
            placeholder.
          goal_component_name: Name of the goal component.
          relevant_memories_component_name: Name of the relevant-memory
            component.
          short_term_memories_component_name: Name of the short-term memory
            component.
          pre_act_key: Key under which the description is reported.
          logging_channel: Channel receiving the result and the full prompt.
        """
        super().__init__(pre_act_key)
        self._model = model
        self._prompt_context = prompt_context
        self._prompt = prompt
        self._answer_prefix = answer_prefix
        self._goal_component_name = goal_component_name
        self._relevant_memories_component_name = relevant_memories_component_name
        self._short_term_memories_component_name = short_term_memories_component_name
        self._logging_channel = logging_channel

    def _make_pre_act_value(self) -> str:
        """Queries the model and returns the prefixed, stripped description."""
        entity = self.get_entity()
        agent_name = entity.name
        goal = self.get_named_component_pre_act_value(
            self._goal_component_name
        )
        short_term_memory = self.get_named_component_pre_act_value(
            self._short_term_memories_component_name
        )
        relevant_memories = self.get_named_component_pre_act_value(
            self._relevant_memories_component_name
        )

        # Each section is labelled with a proper possessive ("Alice's ...").
        prompt_context = self._prompt_context.format(
            goal=f"{agent_name}'s goal: {goal}",
            relevant_memories=f"{agent_name}'s associative memory:\n{relevant_memories}",
            short_term_memory=f"{agent_name}'s short-term memory:\n{short_term_memory}",
        )
        prompt = interactive_document.InteractiveDocument(self._model)
        prompt.statement(prompt_context)
        answer_prefix = self._answer_prefix.format(agent_name=agent_name)
        model_response = prompt.open_question(
            self._prompt.format(agent_name=agent_name),
            answer_prefix=answer_prefix,
            max_tokens=3000,
            terminators=[]
        )
        result = f"{answer_prefix}{model_response}".strip()
        self._logging_channel(
            {'Key': self.get_pre_act_key(),
             'Value': result,
             'Chain of thought': prompt.view().text().splitlines()}
        )
        return result



class ShortTermMemory(action_spec_ignored.ActionSpecIgnored):
    """Records observations and reports the most recent ones.

    `pre_observe` stores every observation in the memory component with an
    "[observation]" prefix. The pre-act value is the last `k` such entries,
    one per line, followed by a trailing newline.
    """

    def __init__(
            self,
            *,
            k: int = 10,
            pre_act_key: str,
            memory_component_name: str = memory_component.DEFAULT_MEMORY_COMPONENT_NAME,
            logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,

    ):
        """Initializes the component.

        Args:
          k: Number of most recent observations to report. Values <= 0 report
            no observations.
          pre_act_key: Key under which the observations are reported.
          memory_component_name: Name of the memory component to read/write.
          logging_channel: Channel receiving the reported lines.
        """
        super().__init__(pre_act_key)
        self._k = k
        self._memory_component_name = memory_component_name
        self._logging_channel = logging_channel

    def pre_observe(
            self,
            observation: str,
    ) -> str:
        """Stores `observation` in memory, tagged as an observation.

        Returns:
          An empty string.
        """
        memory = self.get_entity().get_component(
            self._memory_component_name,
            type_=memory_component.MemoryComponent)
        memory.add(
            f'[observation] {observation}',
            metadata={'tags': ['observation']},
        )
        return ''

    def _make_pre_act_value(self) -> str:
        """Returns the last `k` observations, newline-joined, and logs them."""
        entity = self.get_entity()
        memory = entity.get_component(
            self._memory_component_name,
            type_=memory_component.MemoryComponent
        )

        scoring_fn = legacy_associative_memory.RetrieveRecent(
            add_time=True
        )
        mems = memory.retrieve(scoring_fn=scoring_fn)
        observations = [mem.text for mem in mems if '[observation]' in mem.text]
        # Takes the tail of the list, i.e. assumes the retrieval returns the
        # most recent entries last. `observations[-0:]` would return the WHOLE
        # list, so a non-positive k is handled explicitly.
        recent = observations[-self._k:] if self._k > 0 else []
        result = '\n'.join(recent) + '\n'
        self._logging_channel(
            {'Key': self.get_pre_act_key(), 'Value': result.splitlines()})

        return result


class ThoughtfulActComponent(entity_component.ActingComponent):
    """Acting component that extracts useful information before acting.

    Before every action the model is asked to list decision-relevant facts
    from the goal and memories. For multiple-choice actions it is additionally
    asked to describe the consequences of each option, and those descriptions
    are shown next to the options.

    `get_action_attempt` reads the context entries "Instructions",
    "ReportFunction", "GoalComponent", "RelevantMemory", "ShortTermMemory"
    and, for CHOICE actions, "SituationPerception".
    """

    def __init__(
            self,
            model: language_model.LanguageModel,
            clock: ClockType.GameClock,
            context_for_action: str,
            useful_info_extraction_prompt: str,
            useful_info_extraction_answer_prefix: str,
            action_consequences_prompt: str,
            action_consequences_answer_prefix: str,
            pre_act_key: str,
            logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
    ):
        """Initializes the component.

        Args:
          model: Language model used for every reasoning step.
          clock: Clock whose step size is substituted into the call to action.
          context_for_action: Template with `{instructions}`, `{time}`,
            `{goal}`, `{relevant_memories}` and `{recent_memories}`.
          useful_info_extraction_prompt: Template with `{goal}`,
            `{relevant_memories}` and `{recent_memories}`.
          useful_info_extraction_answer_prefix: Template with `{agent_name}`.
          action_consequences_prompt: Template with `{agent_name}`, `{goal}`,
            `{situation_perception}`, `{useful_info}`, `{question}`, `{choice}`.
          action_consequences_answer_prefix: Template with `{agent_name}` and
            `{choice}`.
          pre_act_key: Key used in the logged entries.
          logging_channel: Channel receiving one log entry per action.
        """
        self._model = model
        self._clock = clock
        self._context_for_action = context_for_action
        self._useful_info_extraction_prompt = useful_info_extraction_prompt
        self._useful_info_extraction_answer_prefix = \
            useful_info_extraction_answer_prefix
        self._action_consequences_prompt = action_consequences_prompt
        self._action_consequences_answer_prefix =\
            action_consequences_answer_prefix
        self._pre_act_key = pre_act_key
        self._logging_channel = logging_channel

    def _get_useful_info(self, contexts):
        """Asks the model for decision-relevant facts; returns prefix + answer."""
        prompt = interactive_document.InteractiveDocument(self._model)
        context = self._useful_info_extraction_prompt.format(
            goal=contexts["GoalComponent"],
            relevant_memories=contexts["RelevantMemory"],
            recent_memories=contexts["ShortTermMemory"]
        )
        answer_prefix = self._useful_info_extraction_answer_prefix.format(
            agent_name=self.get_entity().name,
        )
        output = prompt.open_question(
            context,
            max_tokens=3000,
            answer_prefix=answer_prefix,
            terminators=[],
            question_label='',
            answer_label='',
        )
        output = f"{answer_prefix}\n{output.strip()}"

        return output

    @staticmethod
    def _letters():
        """Yields the lowercase letters 'a' to 'z' (not used in this file)."""
        yield from (chr(ord('a') + i) for i in range(26))

    def _get_action_consequences(self,
                                 agent_name,
                                 contexts,
                                 useful_info,
                                 question,
                                 choices,
                                 log):
        """Describes the likely consequence of every option.

        Args:
          agent_name: Name of the acting entity.
          contexts: Component contexts; "GoalComponent" and
            "SituationPerception" are read.
          useful_info: Output of `_get_useful_info`.
          question: The formatted call to action.
          choices: The options to evaluate.
          log: Dict that receives the full trace under "ActionsConsequences".

        Returns:
          One string per choice of the form "<choice>, <consequence>".
        """
        augmented_choices = []
        trace = []
        for choice in choices:
            input_prompt = self._action_consequences_prompt.format(
                agent_name=agent_name,
                goal=contexts["GoalComponent"],
                situation_perception=contexts["SituationPerception"],
                useful_info=useful_info,
                question=question,
                choice=choice
            )
            answer_prefix = self._action_consequences_answer_prefix.format(
                agent_name=agent_name,
                choice=choice
            )
            model_output = self._model.sample_text(
                prompt=f"{input_prompt}{answer_prefix}",
                max_tokens=3000,
                terminators=[],
                temperature=0.0)
            consequence = f"{answer_prefix}{model_output}".strip()
            augmented_choices.append(f"{choice}, {consequence}")
            # `consequence` already begins with the answer prefix, so it must
            # not be prepended a second time in the trace.
            trace.append(f"{input_prompt}{consequence}")
        log["ActionsConsequences"] = "\n----\n".join(trace)

        return augmented_choices


    @override
    def get_action_attempt(
            self,
            contexts: entity_component.ComponentContextMapping,
            action_spec: entity_lib.ActionSpec,
    ) -> str:
        """Produces an action for `action_spec` (FREE, CHOICE or FLOAT).

        Raises:
          NotImplementedError: For any other output type.
        """
        custom_components_log = {}
        prompt = interactive_document.InteractiveDocument(self._model)
        call_to_action = action_spec.call_to_action.format(
            name=self.get_entity().name,
            timedelta=helper_functions.timedelta_to_readable_str(
                self._clock.get_step_size()
            ),
        )

        context = self._context_for_action.format(
            instructions=contexts["Instructions"],
            time=contexts["ReportFunction"],
            goal=contexts["GoalComponent"],
            relevant_memories=contexts["RelevantMemory"],
            recent_memories=contexts["ShortTermMemory"]
        )

        useful_info = self._get_useful_info(contexts)
        prompt.statement(f"{context}\n{useful_info}\n")

        if action_spec.output_type == entity_lib.OutputType.FREE:
            output = self.get_entity().name + ' '
            output += prompt.open_question(
                call_to_action,
                max_tokens=2200,
                answer_prefix=output,
                terminators=('" ', '\n'),
                question_label='Exercise',
            )
            self._log(output, prompt, custom_components_log)
            return output
        elif action_spec.output_type == entity_lib.OutputType.CHOICE:
            augmented_choices = self._get_action_consequences(
                agent_name=self.get_entity().name,
                contexts=contexts,
                useful_info=useful_info,
                question=call_to_action,
                choices=action_spec.options,
                log=custom_components_log
            )
            idx = prompt.multiple_choice_question(
                question=call_to_action, answers=augmented_choices
            )
            # The augmented text is only for reasoning; return the raw option.
            output = action_spec.options[idx]
            self._log(output, prompt, custom_components_log)
            return output
        elif action_spec.output_type == entity_lib.OutputType.FLOAT:
            prefix = self.get_entity().name + ' '
            sampled_text = prompt.open_question(
                call_to_action,
                max_tokens=2200,
                answer_prefix=prefix,
            )
            self._log(sampled_text, prompt, custom_components_log)
            try:
                return str(float(sampled_text))
            except ValueError:
                return '0.0'
        else:
            raise NotImplementedError(
                f'Unsupported output type: {action_spec.output_type}. '
                'Supported output types are: FREE, CHOICE, and FLOAT.'
            )

    def _log(self,
             result: str,
             prompt: interactive_document.InteractiveDocument,
             custom_components_log: dict):
        """Sends the result, the prompt lines and any extra entries to the log."""
        legacy = {
            'Key': self._pre_act_key,
            'Value': result,
            'Prompt': prompt.view().text().splitlines(),
        }
        legacy.update(custom_components_log)
        self._logging_channel(legacy)

    def get_state(self):
        """Converts the component to a dictionary."""
        return {}

    def set_state(self, state) -> None:
        """No-op: this component has no state to restore."""
        pass


def build_agent(
        *,
        config: formative_memories.AgentConfig,
        model: language_model.LanguageModel,
        memory: associative_memory.AssociativeMemory,
        clock: game_clock.MultiIntervalClock,
        update_time_interval: datetime.timedelta,
) -> entity_agent_with_logging.EntityAgentWithLogging:
    """Assembles a main-character agent from its components.

    Args:
      config: Agent configuration; `config.extras['main_character']` must be
        truthy.
      model: Language model shared by the LLM-backed components.
      memory: Associative memory wrapped into the agent's memory bank.
      clock: Clock used by the time display and the acting component.
      update_time_interval: Part of the builder signature; unused here.

    Returns:
      The wired-up `EntityAgentWithLogging`.

    Raises:
      ValueError: If `config` does not describe a main character.
    """
    del update_time_interval  # Accepted for interface compatibility only.

    is_main_character = config.extras.get('main_character', False)
    if not is_main_character:
        raise ValueError('This function is meant for a main character '
                         'but it was called on a supporting character.')

    agent_name = config.name
    raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)
    measurements = measurements_lib.Measurements()

    def channel(channel_name):
        # Every component logs into its own named measurements channel.
        return measurements.get_channel(channel_name).on_next

    instructions = agent_components.instructions.Instructions(
        agent_name=agent_name,
        logging_channel=channel('Instructions'),
        pre_act_key=INSTRUCTION_PRE_ACT_KEY,
    )
    goal = GoalComponent(
        goal=config.goal or None,
        pre_act_key=GOAL_PRE_ACT_KEY,
        logging_channel=channel("Goal"),
    )
    time_display = agent_components.report_function.ReportFunction(
        function=clock.current_time_interval_str,
        pre_act_key='\nCurrent time',
        logging_channel=channel('TimeDisplay'),
    )
    short_term_memory = ShortTermMemory(
        k=10,
        pre_act_key=SHORT_TERM_MEMORY_PRE_ACT_KEY,
        logging_channel=channel('ShortTermMemory'),
    )
    short_term_memory_name = _get_class_name(short_term_memory)
    relevant_memories = RelevantMemory(
        short_term_memory_component_name=short_term_memory_name,
        pre_act_key=RELEVANT_MEMORY_PRE_ACT_KEY,
        logging_channel=channel('RelevantMemory'),
    )
    situation_perception = SituationPerception(
        model=model,
        prompt_context=SITUATION_PERCEPTION_CONTEXT,
        prompt=SITUATION_PERCEPTION_PROMPT,
        answer_prefix=SITUATION_PERCEPTION_ANSWER_PREFIX,
        goal_component_name=_get_class_name(goal),
        relevant_memories_component_name=_get_class_name(relevant_memories),
        short_term_memories_component_name=short_term_memory_name,
        pre_act_key=SITUATION_PERCEPTION_PRE_ACT_KEY,
        logging_channel=channel('SituationPerception'),
    )

    # Context components keyed by class name; the acting component looks its
    # inputs up by these names.
    components_of_agent = {}
    for component in (instructions, time_display, goal, relevant_memories,
                      short_term_memory, situation_perception):
        components_of_agent[_get_class_name(component)] = component
    memory_component_module = agent_components.memory_component
    components_of_agent[memory_component_module.DEFAULT_MEMORY_COMPONENT_NAME] = (
        memory_component_module.MemoryComponent(raw_memory))

    act_component = ThoughtfulActComponent(
        model=model,
        clock=clock,
        context_for_action=AC_CONTEXT_FOR_ACTION,
        useful_info_extraction_prompt=AC_USEFUL_INFO_EXTRACTION_PROMPT,
        useful_info_extraction_answer_prefix=AC_USEFUL_INFO_EXTRACTION_ANSWER_PREFIX,
        action_consequences_prompt=AC_ACTION_CONSEQUENCES_PROMPT,
        action_consequences_answer_prefix=AC_ACTION_CONSEQUENCES_ANSWER_PREFIX,
        pre_act_key=AC_PRE_ACT_KEY,
        logging_channel=channel('ActComponent'),
    )

    return entity_agent_with_logging.EntityAgentWithLogging(
        agent_name=agent_name,
        act_component=act_component,
        context_components=components_of_agent,
        component_logging=measurements,
    )
