"""
Added Components:
additional_questions under people_representation includes asking if clone
CloneInstructions includes custom instructions for clone
CloneConcatActComponent includes adding ▪️ to the beginning of the output
CloneList component includes a list of clones
"""


# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An Agent Factory."""

import datetime
import json
from collections.abc import Callable, Sequence

import numpy as np
from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory, formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as agent_components
from concordia.document import interactive_document
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.typing import clock as gc
from concordia.typing import entity as entity_lib
from concordia.typing import entity_component, logging
from concordia.utils import helper_functions
from concordia.utils import measurements as measurements_lib
from typing_extensions import override

# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

def _get_class_name(object_: object) -> str:
  return object_.__class__.__name__


# Default trait profile merged with (or substituted for) `config.traits`.
_DEFAULT_TRAITS = {
    'Neuroticism': 2, 'Extraversion': 4, 'Agreeableness': 2,
    'Conscientiousness': 2, 'Openness': 3, 'Aggressive': 2, 'Optimistic': 3,
    'Kind': 2, 'Resilient': 6, 'Humorous': 9, 'Empathetic': 9, 'Ambitious': 8,
    'Honest': 9, 'Loyal': 7, 'Pessimistic': 4, 'Arrogant': 0, 'Impulsive': 3,
    'Jealous': 0, 'Manipulative': 4, 'Creative': 2, 'Analytical': 2,
    'Confident': 2, 'Passionate': 1, 'Anxious': 7, 'Closed-minded': 1,
    'Deceitful': 0, 'Insecure': 3, 'Irresponsible': 8, 'Vindictive': 3,
    'Curious': 3, 'Energetic': 0, 'Sarcastic': 1,
}

# Characters stripped from both ends of parsed trait names and values; the
# model tends to wrap its answer in quotes (see the prompt's example answer).
_TRAIT_STRIP_CHARS = ' \t\r\n\'"'


def _parse_traits(text: str) -> dict[str, str]:
  """Parses `name: value; name: value` text into a dict.

  Pairs without a colon are skipped. Only the first colon of each pair
  separates key from value, so extra colons do not raise. Surrounding
  whitespace and quote characters are stripped from keys and values.

  Args:
    text: Semicolon-separated `name: value` pairs.

  Returns:
    A mapping from trait name to its (string) value.
  """
  traits = {}
  for pair in text.split(';'):
    key, separator, value = pair.partition(':')
    if not separator:
      continue
    traits[key.strip(_TRAIT_STRIP_CHARS)] = value.strip(_TRAIT_STRIP_CHARS)
  return traits


def _combine_traits(
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
) -> dict[str, str | int]:
  """Combines `config.traits` with the default trait profile.

  If `config.traits` is non-empty, the model is asked to merge it with
  `_DEFAULT_TRAITS` and its answer is parsed. Otherwise `config.traits` is
  overwritten with the default profile rendered as `name: value; ...`.

  Args:
    config: The agent config; `config.traits` may be rewritten.
    model: The language model used to merge user-provided traits.

  Returns:
    The combined traits keyed by trait name.
  """
  if config.traits:
    trait_combine_prompt = ''' I'm going to ask you to combine traits lists on a scale of 1 to 10. Only output the final combined string. Here is an example of this:
                              Question: Combine TRAITS = {'Aggressive': 1, 'Optimistic': 1, 'Kind': 8, 'Resilient': 1,} and traits = 'responsibility: low; aggression: high'

                              Answer: 'responsibility: 1; aggression: 5; Optimistic: 1; Kind: 8; Resilient: 1'

                              Now, combine these traits: TRAITS = ''' + str(_DEFAULT_TRAITS) + ' and traits = ' + config.traits
    return _parse_traits(model.sample_text(trait_combine_prompt))

  # Separate entries with '; ' so the stored string stays parseable (the
  # previous code concatenated entries with no separator at all).
  config.traits = '; '.join(
      f'{trait}: {value}' for trait, value in _DEFAULT_TRAITS.items())
  return dict(_DEFAULT_TRAITS)


def build_agent(
    *,
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
    memory: associative_memory.AssociativeMemory,
    clock: game_clock.MultiIntervalClock,
    update_time_interval: datetime.timedelta | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Build a clone-aware main-character agent.

  Args:
    config: The agent config to use. `config.traits` is filled with the
      default trait profile when it is empty.
    model: The language model to use.
    memory: The agent's memory object.
    clock: The clock to use.
    update_time_interval: Agent calls update every time this interval passes.
      Currently ignored.

  Returns:
    An agent.

  Raises:
    ValueError: If `config.extras['main_character']` is not truthy.
  """
  del update_time_interval

  # Validate first so a misuse fails before any model call or config change.
  if not config.extras.get('main_character', False):
    raise ValueError('This function is meant for a main character '
                     'but it was called on a supporting character.')

  # NOTE(review): the combined traits are not consumed by any component
  # below; the call is kept for its effect on `config.traits`.
  _combine_traits(config, model)

  agent_name = config.name

  raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)

  measurements = measurements_lib.Measurements()
  instructions = CloneInstructions(
      agent_name=agent_name,
      logging_channel=measurements.get_channel('Instructions').on_next,
  )

  clone_list_label = '\nClone List'
  clone_list = CloneList(
      name=agent_name,
      pre_act_key=clone_list_label,
  )

  time_display = agent_components.report_function.ReportFunction(
      function=clock.current_time_interval_str,
      pre_act_key='\nCurrent time',
      logging_channel=measurements.get_channel('TimeDisplay').on_next,
  )

  observation_label = '\nObservation'
  observation = agent_components.observation.Observation(
      clock_now=clock.now,
      timeframe=clock.get_step_size(),
      pre_act_key=observation_label,
      logging_channel=measurements.get_channel('Observation').on_next,
  )
  observation_summary_label = '\nSummary of recent observations'
  observation_summary = agent_components.observation.ObservationSummary(
      model=model,
      clock_now=clock.now,
      timeframe_delta_from=datetime.timedelta(hours=24),
      timeframe_delta_until=datetime.timedelta(hours=0),
      pre_act_key=observation_summary_label,
      logging_channel=measurements.get_channel('ObservationSummary').on_next,
  )

  if config.goal:
    goal_label = '\nOverarching goal'
    overarching_goal = agent_components.constant.Constant(
        state=config.goal,
        pre_act_key=goal_label,
        logging_channel=measurements.get_channel(goal_label).on_next)
  else:
    goal_label = None
    overarching_goal = None

  similar_memory_components = {}
  if config.goal:
    similar_memory_components[goal_label] = goal_label
  similar_memory_components.update({
      _get_class_name(observation_summary): observation_summary_label,
      _get_class_name(time_display): 'The current date/time is',
  })

  relevant_memories_label = '\nRecalled memories and observations'
  relevant_memories = agent_components.all_similar_memories.AllSimilarMemories(
      model=model,
      components=similar_memory_components,
      num_memories_to_retrieve=10,
      pre_act_key=relevant_memories_label,
      logging_channel=measurements.get_channel('AllSimilarMemories').on_next,
  )

  person_representation_label = '\nOther people'
  people_representation = (
      agent_components.person_representation.PersonRepresentation(
          model=model,
          components={
              _get_class_name(time_display): 'The current date/time is',
              _get_class_name(clone_list): 'Clones you should work with',
          },
          num_memories_to_retrieve=30,
          pre_act_key=person_representation_label,
          # One extra question per person, answered using the clone list.
          additional_questions=(
              'Based on the clone list, is this person a clone?',
          ),
          logging_channel=measurements.get_channel(
              'PersonRepresentation').on_next,
      )
  )

  options_perception_components = {}
  if config.goal:
    options_perception_components[goal_label] = goal_label
  options_perception_components.update({
      _get_class_name(observation): observation_label,
      _get_class_name(clone_list): clone_list_label,
      _get_class_name(observation_summary): observation_summary_label,
      _get_class_name(relevant_memories): relevant_memories_label,
      _get_class_name(people_representation): person_representation_label,
  })
  options_perception_label = (
      f'\nQuestion: Which options are available to {agent_name} '
      'right now?\nAnswer')
  options_perception = (
      agent_components.question_of_recent_memories.AvailableOptionsPerception(
          model=model,
          components=options_perception_components,
          clock_now=clock.now,
          pre_act_key=options_perception_label,
          logging_channel=measurements.get_channel(
              'AvailableOptionsPerception'
          ).on_next,
      )
  )
  best_option_perception_label = (
      f'\nQuestion: Of the options available to {agent_name}, and '
      'given their goal, which choice of action or strategy is '
      f'best for {agent_name} to take right now?\nAnswer')
  best_option_perception_components = {}
  if config.goal:
    best_option_perception_components[goal_label] = goal_label
  best_option_perception_components.update({
      _get_class_name(observation): observation_label,
      _get_class_name(observation_summary): observation_summary_label,
      _get_class_name(clone_list): clone_list_label,
      _get_class_name(relevant_memories): relevant_memories_label,
      _get_class_name(people_representation): person_representation_label,
      _get_class_name(options_perception): options_perception_label,
  })
  best_option_perception = (
      agent_components.question_of_recent_memories.BestOptionPerception(
          model=model,
          components=best_option_perception_components,
          clock_now=clock.now,
          pre_act_key=best_option_perception_label,
          logging_channel=measurements.get_channel(
              'BestOptionPerception'
          ).on_next,
      )
  )

  entity_components = (
      # Components that provide pre_act context.
      instructions,
      clone_list,
      time_display,
      observation,
      observation_summary,
      relevant_memories,
      people_representation,
      options_perception,
      best_option_perception,
  )
  components_of_agent = {_get_class_name(component): component
                         for component in entity_components}
  components_of_agent[
      agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
          agent_components.memory_component.MemoryComponent(raw_memory))

  component_order = list(components_of_agent.keys())
  if overarching_goal is not None:
    components_of_agent[goal_label] = overarching_goal
    # Place goal after the instructions.
    component_order.insert(1, goal_label)

  act_component = CloneConcatActComponent(
      model=model,
      clock=clock,
      component_order=component_order,
      logging_channel=measurements.get_channel('ActComponent').on_next,
  )

  agent = entity_agent_with_logging.EntityAgentWithLogging(
      agent_name=agent_name,
      act_component=act_component,
      context_components=components_of_agent,
      component_logging=measurements,
  )

  return agent


def save_to_json(
    agent: entity_agent_with_logging.EntityAgentWithLogging,
) -> str:
  """Serializes an agent's state to a JSON string.

  The result can be turned back into an agent with `rebuild_from_json`. It
  holds one entry per context component (keyed by component name), an
  `act_component` entry, and, when the agent has one, an `agent_config`
  entry. The clock, model and embedder are not serialized; they must be
  supplied again when rebuilding. Only agents in the `READY` phase can be
  saved.

  Args:
    agent: The agent to save.

  Returns:
    A JSON string representing the agent's state.

  Raises:
    ValueError: If the agent is not in the READY phase.
  """
  if agent.get_phase() != entity_component.Phase.READY:
    raise ValueError('The agent must be in the `READY` phase to be saved.')

  state = {}
  for name in agent.get_all_context_components():
    state[name] = agent.get_component(name).get_state()
  state['act_component'] = agent.get_act_component().get_state()

  agent_config = agent.get_config()
  if agent_config is not None:
    state['agent_config'] = agent_config.to_dict()

  return json.dumps(state)


def rebuild_from_json(
    json_data: str,
    model: language_model.LanguageModel,
    clock: game_clock.MultiIntervalClock,
    embedder: Callable[[str], np.ndarray],
    memory_importance: Callable[[str], float] | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Rebuilds an agent from JSON data produced by `save_to_json`.

  A fresh associative memory is created, the agent is rebuilt from the
  stored config via `build_agent`, and each component's saved state is then
  restored.

  Args:
    json_data: The JSON string to load.
    model: The language model for the rebuilt agent.
    clock: The clock for the rebuilt agent and its memory.
    embedder: Sentence embedder for the new associative memory.
    memory_importance: Optional importance function for the new memory.

  Returns:
    The rebuilt agent.

  Raises:
    ValueError: If the JSON data has no `agent_config` entry.
    KeyError: If the state of a component built by `build_agent` is missing.
  """
  remaining = json.loads(json_data)

  fresh_memory = associative_memory.AssociativeMemory(
      sentence_embedder=embedder,
      importance=memory_importance,
      clock=clock.now,
      clock_step_size=clock.get_step_size(),
  )

  if 'agent_config' not in remaining:
    raise ValueError('The JSON data does not contain the agent config.')
  config_dict = remaining.pop('agent_config')
  agent_config = formative_memories.AgentConfig.from_dict(config_dict)

  agent = build_agent(
      config=agent_config,
      model=model,
      memory=fresh_memory,
      clock=clock,
  )

  for name in agent.get_all_context_components():
    component = agent.get_component(name)
    component.set_state(remaining.pop(name))

  act_component = agent.get_act_component()
  act_component.set_state(remaining.pop('act_component'))

  # Every saved entry should have been consumed by now.
  assert not remaining, f'Unused data {sorted(remaining)}'
  return agent

from concordia.typing import entity_component, logging
from concordia.language_model import language_model
from collections.abc import Sequence
from concordia.document import interactive_document
from concordia.typing import entity as entity_lib
from concordia.typing import clock as gc
from concordia.utils import helper_functions
from typing_extensions import override

######
# Clone Concat Act Component
######

"""Adds ▪️ and concatenates text components for clones."""

# Default `pre_act_key` for `CloneConcatActComponent`; that component only
# uses it as the 'Key' field of its log entries.
DEFAULT_PRE_ACT_KEY = 'Act'

class CloneConcatActComponent(entity_component.ActingComponent):
  """Acting component that concatenates component contexts for a clone.

  This component receives the contexts from `pre_act` from all the
  components and assembles them in the order specified to `__init__`. If the
  component order is not specified, the contexts are assembled in the
  iteration order of the `ComponentContextMapping` passed to
  `get_action_attempt`. Components that return empty strings from `pre_act`
  are ignored.

  Free-text actions are prefixed with the entity's name wrapped in ▪️
  markers (`▪️name▪️ `), so that other clones can recognize the speaker.
  """

  # Appended to the call to action for free-text (FREE) responses.
  _FREE_ACTION_SUFFIX = (
      ' Find the game theory topic that could be most beneficial to your'
      ' decision making in this scenario. Use the knowledge of this topic to'
      ' inform your action.'
  )

  def __init__(
      self,
      model: language_model.LanguageModel,
      clock: gc.GameClock,
      component_order: Sequence[str] | None = None,
      pre_act_key: str = DEFAULT_PRE_ACT_KEY,
      logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
  ):
    """Initializes the component.

    Args:
      model: The language model to use for generating the action attempt.
      clock: The game clock, used to render the step size in the call to
        action.
      component_order: The order in which the component contexts will be
        assembled when calling the act component. If None, the contexts will
        be assembled in the iteration order of the `ComponentContextMapping`
        passed to `get_action_attempt`. If the component order is specified
        but does not contain all the components passed to
        `get_action_attempt`, the missing components are appended at the end
        in sorted-name order. The same component cannot appear twice in the
        component order. All components in the component order must be in
        the `ComponentContextMapping` passed to `get_action_attempt`.
      pre_act_key: Key used in this component's log entries.
      logging_channel: The channel to use for debug logging.

    Raises:
      ValueError: If the component order is not None and contains duplicate
        components.
    """
    self._model = model
    self._clock = clock
    self._component_order = (
        None if component_order is None else tuple(component_order))
    if (self._component_order is not None and
        len(set(self._component_order)) != len(self._component_order)):
      raise ValueError(
          'The component order contains duplicate components: '
          + ', '.join(self._component_order)
      )

    self._pre_act_key = pre_act_key
    self._logging_channel = logging_channel

  def _context_for_action(
      self,
      contexts: entity_component.ComponentContextMapping,
  ) -> str:
    """Joins the non-empty component contexts with newlines."""
    if self._component_order is None:
      return '\n'.join(
          context for context in contexts.values() if context
      )
    # Components not named in the configured order go last, sorted by name.
    order = self._component_order + tuple(sorted(
        set(contexts.keys()) - set(self._component_order)))
    return '\n'.join(
        contexts[name] for name in order if contexts[name]
    )

  @override
  def get_action_attempt(
      self,
      contexts: entity_component.ComponentContextMapping,
      action_spec: entity_lib.ActionSpec,
  ) -> str:
    """Asks the model for an action matching `action_spec.output_type`.

    Returns:
      For FREE, the ▪️-wrapped entity name followed by the model's answer;
      for CHOICE, the selected option; for FLOAT, the parsed number as a
      string, or '0.0' if the answer does not parse.

    Raises:
      NotImplementedError: For any other output type.
    """
    prompt = interactive_document.InteractiveDocument(self._model)
    context = self._context_for_action(contexts)
    prompt.statement(context + '\n')

    call_to_action = action_spec.call_to_action.format(
        name=self.get_entity().name,
        timedelta=helper_functions.timedelta_to_readable_str(
            self._clock.get_step_size()
        ),
    )
    if action_spec.output_type == entity_lib.OutputType.FREE:
      # The ▪️name▪️ marker is how clones identify one another.
      prefix = "▪️" + self.get_entity().name + "▪️ "
      response = prompt.open_question(
          call_to_action + self._FREE_ACTION_SUFFIX,
          max_tokens=2200,
          answer_prefix=prefix,
          # This terminator protects against the model providing extra context
          # after the end of a directly spoken response, since it normally
          # puts a space after a quotation mark only in these cases.
          terminators=('" ', '\n'),
          question_label='Exercise',
      )
      output = prefix + response
      self._log(output, prompt)
      return output
    elif action_spec.output_type == entity_lib.OutputType.CHOICE:
      idx = prompt.multiple_choice_question(
          question=call_to_action, answers=action_spec.options
      )
      output = action_spec.options[idx]
      self._log(output, prompt)
      return output
    elif action_spec.output_type == entity_lib.OutputType.FLOAT:
      prefix = self.get_entity().name + ' '
      sampled_text = prompt.open_question(
          call_to_action,
          max_tokens=2200,
          answer_prefix=prefix,
      )
      self._log(sampled_text, prompt)
      try:
        return str(float(sampled_text))
      except ValueError:
        return '0.0'
    else:
      raise NotImplementedError(
          f'Unsupported output type: {action_spec.output_type}. '
          'Supported output types are: FREE, CHOICE, and FLOAT.'
      )

  def _log(self,
           result: str,
           prompt: interactive_document.InteractiveDocument):
    """Sends the result and the full prompt text to the logging channel."""
    self._logging_channel({
        'Key': self._pre_act_key,
        'Value': result,
        'Prompt': prompt.view().text().splitlines(),
    })

  def get_state(self) -> entity_component.ComponentState:
    """Returns an empty state; this component keeps no mutable state."""
    return {}

  def set_state(self, state: entity_component.ComponentState) -> None:
    """Ignores `state`; this component keeps no mutable state."""
    pass


#######
# CLONE INSTRUCTIONS
#######
"""Component that provides role playing instructions to clones."""

from concordia.components.agent import constant
from concordia.typing import logging

# Default `pre_act_key` for `CloneInstructions`.
DEFAULT_INSTRUCTIONS_PRE_ACT_KEY = 'Role playing instructions'


class CloneInstructions(constant.Constant):
  """Constant component holding the role-playing instructions for a clone.

  The instructions ask the agent to role-play its character, to cooperate
  with other clones (recognized by ▪️ markers around their names), and to
  behave realistically toward non-clones.
  """

  def __init__(
      self,
      agent_name: str,
      pre_act_key: str = DEFAULT_INSTRUCTIONS_PRE_ACT_KEY,
      logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
  ):
    """Initializes the component.

    Args:
      agent_name: Name of the character being played; it is interpolated
        into the instructions.
      pre_act_key: Prefix to add to the context of the component.
      logging_channel: The channel to use for debug logging.
    """
    # Every fragment ends with a space so sentences do not run together
    # (previously "named {agent_name}.The experiment" was emitted).
    state = (
        f'The instructions for how to play the role of {agent_name} are as '
        'follows. This is a social science experiment studying how well you '
        f'play the role of a character named {agent_name}. '
        'The experiment is structured as a tabletop roleplaying game (like '
        'dungeons and dragons). However, in this case it is a serious social '
        'science experiment and simulation. '
        'Because this is a social experiment, try to use any game theory '
        'knowledge that could help your character achieve their goal. '
        f'{agent_name} is one of many clones that are all part of the '
        'experiment. As a clone, your goal is to create the best outcomes '
        'for all clones. This means clones will always help each other and '
        'work together. All clones have ▪️ surrounding their name to indicate '
        'they are a clone. If you see ▪️[name]▪️, then you know you are '
        'speaking with a clone and should help them. '
        'When you are engaging with non-clones, the goal is to be realistic. '
        f'It is important to play the role of a person like {agent_name} as '
        'accurately as possible, i.e., by responding in ways that you think '
        f'it is likely a person like {agent_name} would respond, and taking '
        f'into account all information about {agent_name} that you have. '
        'Always use third-person limited perspective.'
    )
    super().__init__(
        state=state, pre_act_key=pre_act_key, logging_channel=logging_channel)

#######
# CLONE LIST
#######
"""Component that identifies other clones from observations"""

import re
from concordia.components.agent import action_spec_ignored

class CloneList(action_spec_ignored.ActionSpecIgnored):
  """Tracks the names of other clones seen in observations.

  Clones mark their names as `▪️name▪️` in their free-text output. Every
  observation is scanned for such markers, and the collected names are
  exposed, sorted and one per line, as this component's pre-act value.
  """

  # Captures the text between a pair of ▪️ markers on a single line.
  _CLONE_NAME_PATTERN = re.compile(r"▪️(.*?)▪️")

  def __init__(self, name: str, pre_act_key: str = '\nClone List:'):
    """Initializes the component.

    Args:
      name: The owning agent's own name; it is never added to the list.
      pre_act_key: Prefix to add to the context of the component.
    """
    super().__init__(pre_act_key)
    self.clones = set()
    self._name = name

  def _make_pre_act_value(self) -> str:
    """Returns the known clone names, sorted, one per line."""
    return "\n".join(sorted(self.clones))

  def pre_observe(
      self,
      observation: str,
  ) -> str:
    """Records any ▪️-marked clone names found in `observation`.

    Args:
      observation: The observation to process.

    Returns:
      An empty string; this component contributes nothing at observe time.
    """
    for captured in self._CLONE_NAME_PATTERN.findall(observation):
      clone_name = captured.strip()
      # Skip empty captures and the agent's own name, which appears whenever
      # its own ▪️-prefixed actions are echoed back in observations.
      if clone_name and clone_name != self._name:
        self.clones.add(clone_name)
    return ''

  def get_state(self) -> entity_component.ComponentState:
    """Returns the known clone names so they survive save/rebuild."""
    return {'clones': sorted(self.clones)}

  def set_state(self, state: entity_component.ComponentState) -> None:
    """Restores clone names; a missing or empty state yields an empty list."""
    self.clones = set((state or {}).get('clones', ()))
