import datetime
import random
import types

from concordia.document import interactive_document

from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as agent_components
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.utils import measurements as measurements_lib
from concordia.components.agent import question_of_recent_memories, memory_component, action_spec_ignored
from typing import Sequence
from typing import Mapping
from concordia.typing import entity_component
from concordia.typing import logging


class AllSimilarMemoriesMod(agent_components.all_similar_memories.AllSimilarMemories):
    """Variant of `AllSimilarMemories` with a goal-oriented selection prompt.

    The component summarizes the pre-act values of its conditioning
    `components`, uses that summary (prefixed with the agent's name) as an
    associative query against the memory component, and then asks the model to
    repeat verbatim the subset of retrieved memories most important for the
    agent's goals, preferring the more recent statement when two conflict.
    """

    def __init__(
            self,
            model: language_model.LanguageModel,
            memory_component_name: str = (
                    memory_component.DEFAULT_MEMORY_COMPONENT_NAME
            ),
            components: Mapping[
                entity_component.ComponentName, str
            ] = types.MappingProxyType({}),
            num_memories_to_retrieve: int = 25,
            pre_act_key: str = 'Relevant memories',
            logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
    ):
        """Initializes the component.

        Args:
          model: The language model used to summarize and select memories.
          memory_component_name: Name of the memory component to retrieve from.
          components: Mapping of component name to the label used for that
            component's pre-act value in the summary prompt.
          num_memories_to_retrieve: Maximum number of memories retrieved.
          pre_act_key: Prefix for this component's output in `pre_act`.
          logging_channel: Channel receiving debug information.
        """
        # Keyword arguments keep this call correct even if the base class
        # reorders its parameters.
        super().__init__(
            model=model,
            memory_component_name=memory_component_name,
            components=components,
            num_memories_to_retrieve=num_memories_to_retrieve,
            pre_act_key=pre_act_key,
            logging_channel=logging_channel,
        )

    def _make_pre_act_value(self) -> str:
        """Selects the retrieved memories most important for the agent's goals.

        Returns:
          The model's verbatim selection of memories. Generation stops at the
          first blank line because of the configured terminator.
        """
        agent_name = self.get_entity().name
        prompt = interactive_document.InteractiveDocument(self._model)

        # Summarize the conditioning components to build the retrieval query.
        component_states = '\n'.join([
            f"{agent_name}'s"
            f' {prefix}:\n{self.get_named_component_pre_act_value(key)}'
            for key, prefix in self._components.items()
        ])
        prompt.statement(f'Statements:\n{component_states}\n')
        prompt_summary = prompt.open_question(
            'Summarize the statements above.', max_tokens=750
        )

        memory = self.get_entity().get_component(
            self._memory_component_name,
            type_=memory_component.MemoryComponent)

        query = f'{agent_name}, {prompt_summary}'
        mems = '\n'.join(
            [mem.text for mem in memory.retrieve(
                query=query,
                scoring_fn=legacy_associative_memory.RetrieveAssociative(),
                limit=self._num_memories_to_retrieve)]
        )

        # Gender-neutral wording: this component serves agents of any gender.
        question = (
            f'What subset of the following statements is most important for {agent_name} to consider right now to '
            'achieve their goals? Whenever two or more statements are not mutually consistent, select the more recent '
            'one. If two statements are very similar, just choose the most recent. Repeat all selected statements '
            'verbatim, including timestamps, without summarizing or repeating information. When in doubt, '
            'include more recent events, as they are usually significant'
        )
        # A fresh document so the selection is not conditioned on the summary.
        new_prompt = prompt.new()
        result = new_prompt.open_question(
            f'{question}\nStatements:\n{mems}',
            max_tokens=2000,
            terminators=('\n\n',),
        )

        self._logging_channel({
            'Key': self.get_pre_act_key(),
            'Value': result,
            'Initial chain of thought': prompt.view().text().splitlines(),
            'Query': query,
            'Final chain of thought': new_prompt.view().text().splitlines(),
        })

        return result


class PersonStrategyRepresentation(action_spec_ignored.ActionSpecIgnored):
    """Represent other characters, and their likely strategies, in the world.

    Each time the pre-act value is computed, the component:
      1. asks the model for the proper names mentioned in the agent's most
         recent memories and merges them into a running list of detected names
         (de-duplicated, capped at `cap_number_of_detected_people`);
      2. for every detected person with at least one retrieved memory that
         mentions them, asks the model how that person would address the
         situation, followed by each of `additional_questions`.
    """

    def __init__(
            self,
            model: language_model.LanguageModel,
            memory_component_name: str = (
                    memory_component.DEFAULT_MEMORY_COMPONENT_NAME
            ),
            components: Mapping[
                entity_component.ComponentName, str
            ] = types.MappingProxyType({}),
            additional_questions: Sequence[str] = (),
            num_memories_to_retrieve: int = 100,
            cap_number_of_detected_people: int = 10,
            pre_act_key: str = 'Person representation',
            logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
    ):
        """Initialize a component to represent other people in the simulated world.

        Args:
          model: The language model to use.
          memory_component_name: The name of the memory component from which to
            retrieve related memories.
          components: The components to condition the answer on. This is a
            mapping of the component name to a label to use in the prompt.
          additional_questions: Sequence of additional questions to ask about
            each detected person.
          num_memories_to_retrieve: The number of memories retrieved per person;
            twice this many recent memories are scanned for names.
          cap_number_of_detected_people: The maximum number of people that can
            be represented; when exceeded, a random subset of that size is kept.
          pre_act_key: Prefix to add to the output of the component when called
            in `pre_act`.
          logging_channel: The channel to log debug information to.
        """
        super().__init__(pre_act_key)
        self._model = model
        self._memory_component_name = memory_component_name
        self._components = dict(components)
        # Copied so later mutation of the caller's sequence has no effect.
        self._additional_questions = tuple(additional_questions)
        self._num_memories_to_retrieve = num_memories_to_retrieve
        self._logging_channel = logging_channel
        self._cap_number_of_detected_people = cap_number_of_detected_people

        # Names of other people detected so far; persists across calls.
        self._names_detected = []

    def _update_names_detected(self, people_str: str, agent_name: str) -> None:
        """Merge a comma-separated list of names into `self._names_detected`.

        Each name is stripped of surrounding whitespace, quotes and periods.
        The detection prompt shows a quoted example list, so the model may echo
        the quotes, which would otherwise prevent substring matches against
        memories. Empty entries and the agent's own name are dropped.
        Duplicates are removed while keeping first-seen order, so the result
        does not depend on string hashing. If more than
        `cap_number_of_detected_people` names remain, a random subset of that
        size is kept and the others are forgotten.
        """
        new_names = [
            name.strip().strip('"\'.').strip() for name in people_str.split(',')
        ]
        merged = [
            name
            for name in dict.fromkeys([*self._names_detected, *new_names])
            if name and name != agent_name
        ]
        if len(merged) > self._cap_number_of_detected_people:
            merged = random.sample(merged, self._cap_number_of_detected_people)
        # Rebind to a fresh list (never mutate in place) so that lists handed
        # to the logging channel on earlier calls stay unchanged.
        self._names_detected = merged

    def _make_pre_act_value(self) -> str:
        """Describe the likely behavior of each detected person.

        Returns:
          One paragraph per represented person, each terminated by a `***`
          line, joined with newlines; an empty string if nobody qualifies.
        """
        agent_name = self.get_entity().name

        memory = self.get_entity().get_component(
            self._memory_component_name,
            type_=memory_component.MemoryComponent)

        # Step 1: detect people mentioned in the most recent memories.
        recency_scorer = legacy_associative_memory.RetrieveRecent(add_time=True)
        mems = '\n'.join([
            mem.text
            for mem in memory.retrieve(
                scoring_fn=recency_scorer, limit=self._num_memories_to_retrieve * 2
            )
        ])

        find_people_prompt = interactive_document.InteractiveDocument(self._model)
        find_people_prompt.statement(
            f'Recent observations of {agent_name}:\n{mems}')
        people_str = find_people_prompt.open_question(
            question=('Create a comma-separated list containing all the proper '
                      'names of people mentioned in the observations above. For '
                      'example if the observations mention Julie, Michael, '
                      'Bob Skinner, and Francis then produce the list '
                      '"Julie,Michael,Bob Skinner,Francis".'),
            question_label='Exercise', )
        self._update_names_detected(people_str, agent_name)

        # Step 2: shared context built from the conditioning components.
        prompt = interactive_document.InteractiveDocument(self._model)

        component_states = '\n'.join([
            f'{prefix}:\n{self.get_named_component_pre_act_value(key)}'
            for key, prefix in self._components.items()
        ])
        prompt.statement(f'Considerations:\n{component_states}\n')

        associative_scorer = legacy_associative_memory.RetrieveAssociative(
            use_recency=True,
            use_importance=True,
            add_time=True,
            sort_by_time=True,
        )

        # Step 3: one representation per detected person.
        person_representations = []
        prompt_copies_to_log = []
        for person_name in self._names_detected:
            # Keep only retrieved memories that literally mention the person.
            memories_list = [mem.text for mem in memory.retrieve(
                query=person_name,
                scoring_fn=associative_scorer,
                limit=self._num_memories_to_retrieve) if person_name in mem.text]
            if not memories_list:
                continue
            new_prompt = prompt.copy()
            memories = '\n'.join(memories_list)
            new_prompt.statement(f'Observed behavior and speech of {person_name}:'
                                 f'\n{memories}\n')
            question = ('Taking note of all the information above, write a descriptive paragraph capturing how '
                        f'{person_name} would address the situation, including possible decisions, strategy being '
                        'followed and ways to deceive or collaborate with it.')
            person_description = new_prompt.open_question(
                f'{question}\n',
                max_tokens=350,
                terminators=('\n\n',),
                question_label='Exercise',
                answer_prefix=f'{person_name} is ',
            )
            person_representation = f'{person_name} is {person_description}'
            for extra_question in self._additional_questions:
                additional_result = new_prompt.open_question(
                    extra_question,
                    max_tokens=200,
                    terminators=('\n',),
                    question_label='Exercise',
                    answer_prefix=f'{person_name} is ',
                )
                person_representation = (f'{person_representation}\n    '
                                         f'{person_name} is {additional_result}')

            person_representations.append(person_representation + '\n***')
            prompt_copies_to_log.append(new_prompt.view().text())

        result = '\n'.join(person_representations)

        self._logging_channel({
            'Key': self.get_pre_act_key(),
            'Value': result,
            'Name detection chain of thought': (
                find_people_prompt.view().text().splitlines()),
            # Log a snapshot so the entry cannot change after this call.
            'Names detected so far': list(self._names_detected),
            'Components chain of thought': prompt.view().text().splitlines(),
            'Full chain of thought': (
                '\n***\n'.join(prompt_copies_to_log).splitlines()),
        })

        return result


class Question1(question_of_recent_memories.QuestionOfRecentMemories):
    """Answers 'what is the agent's best strategy given other actors' strategies?'.

    The question is conditioned on the current observation, the observation
    summary and the `PersonStrategyRepresentation` output. The answer is added
    to memory with the `[actors]` tag.
    """

    def __init__(
            self,
            agent_name: str,
            **kwargs,
    ):
        """Initializes the question for `agent_name`.

        Args:
          agent_name: Name substituted for `{agent_name}` in the question text
            used to build this component's pre-act key.
          **kwargs: Forwarded to `QuestionOfRecentMemories` (e.g. `model`,
            `logging_channel`).
        """
        #@markdown {agent_name} will be automatically replaced with the name of the specific agent
        question = ('Given the above strategies from other actors, what is the best strategy for {agent_name} from a '
                    'game theory perspective? Think step by step in a single paragraph without bullet points or new '
                    'lines.')  #@param {"type":"string"}
        #@markdown The answer will have to start with this prefix
        answer_prefix = '{agent_name}\'s best strategy is '  #@param {"type":"string"}
        #@markdown Flag that defines whether the answer will be added to memory
        add_to_memory = True  # @param {"type":"boolean"}
        #@markdown If yes, the memory will start with this tag
        memory_tag = '[actors]'  # @param {"type":"string"}
        question_with_name = question.format(agent_name=agent_name)
        super().__init__(
            pre_act_key=f'\nQuestion: {question_with_name}\nAnswer',
            question=question,
            answer_prefix=answer_prefix,
            add_to_memory=add_to_memory,
            memory_tag=memory_tag,
            components={'Observation': '\nObservation', 'ObservationSummary': '\nSummary of recent observations',
                        'PersonStrategyRepresentation': '\nStrategy of other people'},
            **kwargs,
        )


#@markdown We can add the value of other components to the context of the question. Notice how Question2 depends on Observation, ObservationSummary, Question1 and PersonStrategyRepresentation. The class names of the contextualising components have to be passed as keys of the "components" argument.
class Question2(question_of_recent_memories.QuestionOfRecentMemories):
    """Answers which scenario the agent would consider minimally acceptable.

    The question is conditioned on the current observation, the observation
    summary, Question1's best-strategy answer and the other people's
    strategies. The answer is not added to memory.
    """

    def __init__(
            self,
            agent_name: str,
            **kwargs,
    ):
        """Initializes the question for `agent_name`.

        Args:
          agent_name: Name substituted for `{agent_name}` in the question text
            and in the Question1 context label.
          **kwargs: Forwarded to `QuestionOfRecentMemories` (e.g. `model`,
            `clock_now`, `logging_channel`).
        """
        question = ('Given the statements above, which would {agent_name} consider as the minimum acceptable scenario '
                    'considering the preferences of other actors? Think step by step in a single paragraph without '
                    'bullet points or new lines.')  #@param {"type":"string"}
        answer_prefix = '{agent_name} would '  #@param {"type":"string"}
        add_to_memory = False  # @param {"type":"boolean"}
        memory_tag = '[situation reflection]'  # @param {"type":"string"}
        question_with_name = question.format(agent_name=agent_name)

        super().__init__(
            pre_act_key=f'\nQuestion: {question_with_name}\nAnswer',
            question=question,
            answer_prefix=answer_prefix,
            add_to_memory=add_to_memory,
            memory_tag=memory_tag,
            #@markdown The key is the name of the component class and the value is the prefix with which it will appear in the context of this component. Be careful if you are going to edit this field, it should be a valid dictionary.
            components={'Observation': '\nObservation', 'ObservationSummary': '\nSummary of recent observations',
                        # Ends with '\nAnswer', like the other question labels.
                        'Question1': ('\nQuestion: Given the above strategies from other actors, what is the best '
                                      f'strategy for {agent_name} from a game theory perspective?\nAnswer'),
                        'PersonStrategyRepresentation': '\nStrategy of other people'},
            #@param

            **kwargs,
        )


#@markdown We can also have the questions depend on each other. Here, the answer to Question3 is contextualised by answers to Question1 and Question2
class Question3(question_of_recent_memories.QuestionOfRecentMemories):
    """Answers what a person like the agent would say to achieve the primary goal.

    The answer is conditioned on the answers to Question1 and Question2, the
    current observation, the observation summary and the other people's
    strategies. It is added to memory with the `[intent reflection]` tag.
    """

    def __init__(
            self,
            agent_name: str,
            **kwargs):
        """Initializes the question for `agent_name`.

        Args:
          agent_name: Name substituted for `{agent_name}` in the question text
            and in the Question1/Question2 context labels.
          **kwargs: Forwarded to `QuestionOfRecentMemories` (e.g. `model`,
            `clock_now`, `logging_channel`).
        """
        question = ('What would a person like {agent_name} say in a situation like this to achieve the primary goal '
                    'while considering the strategies and possible actions of all actors? Think step by step in a '
                    'single paragraph without bullet points or new lines.')  #@param {"type":"string"}
        answer_prefix = '{agent_name} would '  #@param {"type":"string"}
        add_to_memory = True  # @param {"type":"boolean"}
        memory_tag = '[intent reflection]'  # @param {"type":"string"}

        question_with_name = question.format(agent_name=agent_name)

        super().__init__(
            pre_act_key=f'\nQuestion: {question_with_name}\nAnswer',
            question=question,
            answer_prefix=answer_prefix,
            add_to_memory=add_to_memory,
            memory_tag=memory_tag,
            # The Question1/Question2 labels quote the questions those
            # components actually ask, so the context matches their answers.
            components={
                'Question1': ('\nQuestion: Given the above strategies from other actors, what is the best strategy '
                              f'for {agent_name} from a game theory perspective?\nAnswer'),
                'Question2': (f'\nQuestion: Which would {agent_name} consider as the minimum acceptable scenario '
                              'considering the preferences of other actors?\nAnswer'),
                'Observation': '\nObservation', 'ObservationSummary': '\nSummary of recent observations',
                'PersonStrategyRepresentation': '\nStrategy of other people'},
            #@param
            **kwargs,
        )


def _make_question_components(
        agent_name: str,
        measurements: measurements_lib.Measurements,
        model: language_model.LanguageModel,
        clock: game_clock.MultiIntervalClock,
) -> Sequence[question_of_recent_memories.QuestionOfRecentMemories]:
    """Builds the agent's three chained reasoning questions.

    Args:
      agent_name: Name of the agent the questions are about.
      measurements: Provides one logging channel per question, named
        'Question_1', 'Question_2' and 'Question_3'.
      model: Language model used to answer the questions.
      clock: Its `now` method is passed to every question as `clock_now`.

    Returns:
      A tuple of `Question1`, `Question2` and `Question3` instances, in that
      order.
    """
    questions = []
    for index, question_cls in enumerate((Question1, Question2, Question3), start=1):
        # All three questions reason about the current situation, so each one
        # receives the same `clock_now`. Question1 was previously the only one
        # built without it.
        questions.append(question_cls(
            agent_name=agent_name,
            model=model,
            clock_now=clock.now,
            logging_channel=measurements.get_channel(f'Question_{index}').on_next,
        ))
    return tuple(questions)


def _get_class_name(object_: object) -> str:
    return object_.__class__.__name__


#@markdown This function builds the agent using the components defined above. It also adds core components that are useful for every agent, like observations, time display and relevant memories.

def build_agent(
        config: formative_memories.AgentConfig,
        model: language_model.LanguageModel,
        memory: associative_memory.AssociativeMemory,
        clock: game_clock.MultiIntervalClock,
        update_time_interval: datetime.timedelta,
) -> entity_agent_with_logging.EntityAgentWithLogging:
    """Assemble the main-character agent from its components.

    Args:
      config: The agent config. `config.extras['main_character']` must be
        truthy. When `config.goal` is set, it becomes a constant goal component
        placed right after the instructions.
      model: The language model shared by every LLM-backed component.
      memory: The agent's memory object, wrapped into a memory bank.
      clock: The clock supplying the current time and step size.
      update_time_interval: Unused; accepted for interface compatibility.

    Returns:
      An agent whose act component concatenates the context components in
      insertion order.

    Raises:
      ValueError: If the config does not describe a main character.
    """
    del update_time_interval
    if not config.extras.get('main_character', False):
        raise ValueError(
            'This function is meant for a main character '
            'but it was called on a supporting character.'
        )

    name = config.name
    memory_bank = legacy_associative_memory.AssociativeMemoryBank(memory)
    measurements = measurements_lib.Measurements()

    def log_channel(channel_name: str):
        # Each component logs to its own named measurements channel.
        return measurements.get_channel(channel_name).on_next

    time_label = 'The current date/time is'
    summary_label = 'Summary of recent observations'

    instructions = agent_components.instructions.Instructions(
        agent_name=name,
        logging_channel=log_channel('Instructions'),
    )
    time_display = agent_components.report_function.ReportFunction(
        function=clock.current_time_interval_str,
        pre_act_key='\nCurrent time',
        logging_channel=log_channel('TimeDisplay'),
    )
    observation = agent_components.observation.Observation(
        clock_now=clock.now,
        timeframe=clock.get_step_size(),
        pre_act_key='\nObservation',
        logging_channel=log_channel('Observation'),
    )
    observation_summary = agent_components.observation.ObservationSummary(
        model=model,
        clock_now=clock.now,
        timeframe_delta_from=datetime.timedelta(hours=4),
        timeframe_delta_until=datetime.timedelta(hours=0),
        prompt=('Summarize the observations above into one or two sentences, '
                'with particular focus on what the other actors said.'),
        pre_act_key=summary_label,
        logging_channel=log_channel('ObservationSummary'),
    )
    relevant_memories = AllSimilarMemoriesMod(
        model=model,
        components={
            _get_class_name(observation_summary): summary_label,
            _get_class_name(time_display): time_label,
        },
        num_memories_to_retrieve=10,
        pre_act_key='\nRecalled memories and observations',
        logging_channel=log_channel('AllSimilarMemoriesMod'),
    )
    people_representation = PersonStrategyRepresentation(
        model=model,
        components={_get_class_name(time_display): time_label},
        additional_questions=(
            'Given recent events, is the aforementioned character trying '
            'to collaborate?',
            'Which are the possible actions that this character may take?',
            'Which would be the strategy to follow with this character?',
        ),
        num_memories_to_retrieve=30,
        pre_act_key='\nOther people',
        logging_channel=log_channel('PersonStrategyRepresentation'),
    )

    goal_label = '\nOverarching goal'
    overarching_goal = None
    if config.goal:
        overarching_goal = agent_components.constant.Constant(
            state=config.goal,
            pre_act_key=goal_label,
            logging_channel=log_channel(goal_label),
        )

    question_components = _make_question_components(
        agent_name=name,
        model=model,
        clock=clock,
        measurements=measurements,
    )

    # Components are keyed by class name; insertion order defines act order.
    components_of_agent = {}
    for component in (instructions, time_display, observation,
                      observation_summary, relevant_memories,
                      people_representation, *question_components):
        components_of_agent[_get_class_name(component)] = component
    memory_key = agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME
    components_of_agent[memory_key] = (
        agent_components.memory_component.MemoryComponent(memory_bank))

    component_order = list(components_of_agent)
    if overarching_goal is not None:
        components_of_agent[goal_label] = overarching_goal
        # The goal goes directly after the instructions.
        component_order.insert(1, goal_label)

    act_component = agent_components.concat_act_component.ConcatActComponent(
        model=model,
        clock=clock,
        component_order=component_order,
        logging_channel=log_channel('ActComponent'),
    )

    return entity_agent_with_logging.EntityAgentWithLogging(
        agent_name=name,
        act_component=act_component,
        context_components=components_of_agent,
        component_logging=measurements,
    )
