# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A factory implementing a risk-averse agent that evaluates its options by risk."""

from collections.abc import Callable, Mapping
import datetime
import json

from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as components
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.typing import entity_component, logging
from concordia.utils import measurements as measurements_lib
from concordia.document import interactive_document
import numpy as np

import types

def _get_class_name(object_: object) -> str:
  return object_.__class__.__name__

class BDIRiskModel(components.constant.Constant):
  """A constant component describing the agent's risk-averse attitude.

  Provides a fixed, agent-specific description of risk aversion that acts as
  the mental-attitude (BDI) model of the agent.

  Overall, being risk averse means the agent is more likely to choose options
  that are less risky, even if those options have lower expected values.

  The description was drafted from queries to ChatGPT asking it to describe
  the behavior of a risk-averse agent, with some examples of the types of
  situations the agent is placed in across the different scenarios.

  Applies the BDI techniques from
  https://www.ifaamas.org/Proceedings/aamas2024/pdfs/p880.pdf
  """

  def __init__(
      self,
      agent_name: str,
      pre_act_key: str = components.constant.DEFAULT_PRE_ACT_KEY,
      logging_channel: logging.LoggingChannel = (
          logging.NoOpLoggingChannel),
  ):
    """Initializes the component.

    Args:
      agent_name: Name of the agent; interpolated into the description.
      pre_act_key: Passed through to `Constant` as its pre-act key.
      logging_channel: Passed through to `Constant` as its logging channel.
    """

    state = (
        f'{agent_name} demonstrates strong tendencies toward risk aversion, '
        f'showing a clear discomfort with uncertainty and risky situations. '
        f'Favoring safe, predictable outcomes, {agent_name}\'s '
        f'decision-making consistently reflects this cautious approach, '
        f'which prioritizes stability over potential high-reward risks.\n\n'
        f'When choosing between options, {agent_name} almost invariably '
        f'selects low-risk alternatives, even if higher risks might offer '
        f'greater rewards. For instance, {agent_name} would opt for a '
        f'guaranteed 10 coin gain over a 50% chance of a 20 coin gain. This '
        f'preference for security drives {agent_name} to carefully evaluate '
        f'potential risks before considering any benefits, often rejecting '
        f'high-risk opportunities regardless of their promised returns.\n\n'
        f'Risk assessment is always a priority for {agent_name}, sometimes '
        f'outweighing expected benefits or projected gains. In scenarios '
        f'with significant uncertainty, {agent_name} is likely to seek out '
        f'additional information, conduct thorough analyses, or delay '
        f'decisions until a more comprehensive understanding of potential '
        f'risks of the outcomes is available. This approach reflects a '
        f'deep-seated need to minimize potential risks and ensure that each '
        f'decision is made with as much clarity as possible.\n\n'
        f'In negotiations or conflicts, {agent_name} prefers conservative, '
        f'defensive strategies to aggressive strategies. Often prioritizing '
        f'diplomacy, {agent_name} might be perceived as overly cautious or '
        f'even hesitant. However, {agent_name} regards this conservative '
        f'approach as essential to avoiding unnecessary risks. Rather than '
        f'rushing into action, {agent_name} carefully considers every '
        f'potential consequence, aiming for outcomes that reduce '
        f'vulnerability and maintain control.\n\n'
        f'To rationalize this risk-averse approach, {agent_name} frequently '
        f'cites values such as "prudence" and "responsibility." Although '
        f'these values are part of the rationale, the primary motivator is '
        f'a deeply ingrained aversion to risk. For {agent_name}, this is '
        f'not merely a personal preference; it is a strategic philosophy '
        f'aimed at achieving long-term stability and sustainability. By '
        f'prioritizing secure and predictable outcomes over potentially '
        f'volatile and risky, high-reward opportunities, {agent_name} '
        f'emphasizes a balanced approach focused on minimizing or reducing '
        f'risk.\n\n'
        f'This long-term perspective reflects a core belief in the '
        f'importance of consistency and reliability, where measured '
        f'decisions take precedence over quick gains. In both personal and '
        f'professional settings, {agent_name} maintains a disciplined '
        f'approach, carefully weighing every choice and ensuring it aligns '
        f'with a broader goal of sustained, secure success. Ultimately, '
        f'{agent_name} views this strategy as the most reliable path to '
        f'enduring achievements, even if it means occasionally missing out '
        f'on rapid, short-term rewards.\n\n'
        f'This cautious, calculated approach to risk has become central to '
        f'{agent_name}\'s identity and decision-making style, defining '
        f'{agent_name} as a figure who values security, predictability, and '
        # No leading space here: the previous fragment already ends with one.
        f'long-term stability above all.'
    )

    super().__init__(state, pre_act_key, logging_channel)

class OptionsEvaluation(components.action_spec_ignored.ActionSpecIgnored):
  """Evaluates the risk of each option available to the agent.

  The evaluation is produced by three successive open questions to the
  language model:
    1. characterize the current scenario from a game theory perspective;
    2. reflect on the agent's options from game theory and risk aversion
       perspectives; and
    3. score the risk of each option on a 0-10 scale with a brief explanation.

  Adding the game theory perspective helps the model evaluate the answers
  better.

  Based on QuestionOfRecentMemories.
  """

  def __init__(
      self,
      model: language_model.LanguageModel,
      observation_component_name: str,
      options_perception_component_name: str,
      memory_component_name: str =
      components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME,
      components: Mapping[entity_component.ComponentName, str] =
      types.MappingProxyType({}),
      clock_now: Callable[[], datetime.datetime] | None = None,
      num_memories_to_retrieve: int = 25,
      pre_act_key: str = 'Options Evaluation',
      logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
  ):
    """Initializes the component.

    Args:
      model: Language model used to answer the evaluation questions.
      observation_component_name: Name of the `Observation` component whose
        pre-act value is used as the current situation.
      options_perception_component_name: Name of the
        `AvailableOptionsPerception` component listing the available options.
      memory_component_name: Name of the memory component recent memories are
        retrieved from.
      components: Mapping from component name to the label used when that
        component's pre-act value is added to the final evaluation prompt.
        Inside this method only, this parameter shadows the module-level
        `components` import; defaults above are evaluated at definition time
        and therefore still see the module.
      clock_now: Optional callable returning the current time. When None, the
        current time is omitted from the evaluation prompt.
      num_memories_to_retrieve: Number of most recent memories to include.
      pre_act_key: Key passed to the base class for the pre-act value.
      logging_channel: Channel receiving the evaluation and its prompt.
    """
    super().__init__(pre_act_key)
    self._model = model
    self._observation_component_name = observation_component_name
    self._options_perception_component_name = options_perception_component_name
    self._memory_component_name = memory_component_name
    self._components = dict(components)
    self._clock_now = clock_now
    self._num_memories_to_retrieve = num_memories_to_retrieve
    self._logging_channel = logging_channel

  def _new_document(
      self, agent_name: str, mems: str, observations: str
  ) -> interactive_document.InteractiveDocument:
    """Returns a new document seeded with recent memories and observations."""
    document = interactive_document.InteractiveDocument(self._model)
    document.statement(f'Recent memories of {agent_name}:\n{mems}\n')
    document.statement(f'Current situation: {observations}\n')
    return document

  def _make_pre_act_value(self) -> str:
    """Returns the agent's risk evaluation of each of its available options.

    Also sends the final evaluation and its prompt to the logging channel.
    """
    entity = self.get_entity()
    agent_name = entity.name

    # Current observations.
    observations = entity.get_component(
        self._observation_component_name,
        type_=components.observation.Observation,
    ).get_pre_act_value()

    # Most recent memories, each prefixed with its time (add_time=True).
    memory = entity.get_component(
        self._memory_component_name,
        type_=components.memory_component.MemoryComponent,
    )
    recency_scorer = legacy_associative_memory.RetrieveRecent(add_time=True)
    mems = '\n'.join([
        mem.text for mem in memory.retrieve(
            scoring_fn=recency_scorer, limit=self._num_memories_to_retrieve
        )
    ])

    # Options the agent currently perceives as available.
    options_perception = entity.get_component(
        self._options_perception_component_name,
        type_=components.question_of_recent_memories.AvailableOptionsPerception,
    ).get_pre_act_value()

    # Pass 1: characterize the current scenario from a game theory
    # perspective.
    current_scenario = self._new_document(agent_name, mems, observations)
    current_scenario_result = current_scenario.open_question(
        question=(
            'Considering the above memories and observations, using a game theory '
            'perspective what are the characteristics of the current scenario?'
        ),
        max_tokens=1000,
        terminators=(),
    )

    # Pass 2: reflect on the options given the scenario characteristics.
    reflection = self._new_document(agent_name, mems, observations)
    reflection.statement(f'The characteristics of the current scenario from a '
                         f'game theory perspective: '
                         f'{current_scenario_result}\n')
    reflection.statement(f'Options available to {agent_name}: '
                         f'{options_perception}\n')

    # The question asks the model to ignore exercise tags and only consider
    # act tags in the agent's history.
    reflection_result = reflection.open_question(
        question=(
            f'Considering the above memories, observations, and the '
            f'characteristics of the current scenario, please carefully '
            f'evaluate {agent_name}\'s options based on previous actions and '
            f'decisions (only consider act, not exercise) from a game theory '
            f'perspective and risk aversion perspective.'
        ),
        max_tokens=1000,
        terminators=(),
    )

    # Pass 3: score the risk of each option.
    evaluation = self._new_document(agent_name, mems, observations)
    component_states = '\n'.join([
        f"{agent_name}'s {prefix}:\n{self.get_named_component_pre_act_value(key)}"
        for key, prefix in self._components.items()
    ])
    evaluation.statement(component_states)
    # `clock_now` is optional; calling it when it is None would raise.
    if self._clock_now is not None:
      evaluation.statement(f'The current time: {self._clock_now()} \n')
    evaluation.statement(f'The characteristics of the current scenario '
                         f'from a game theory perspective: '
                         f'{current_scenario_result}\n')
    evaluation.statement(f'Options available to {agent_name}: '
                         f'{options_perception}\n')
    evaluation.statement(f'Reflection on the available options: '
                         f'{reflection_result}\n')

    # Provide an example similar to an ActionSpec to elicit the desired
    # ranking response.
    result = evaluation.open_question(
        question=(
            f'For each option {agent_name} has available, evaluate the risk '
            f'that {agent_name} would incur if they chose that option using a '
            f'scale of 0 to 10, where 0 means no risk is incurred and 10 means the '
            f'maximum risk is incurred.\n'
            f'Provide a score and a brief explanation for each option. '
            f'Please answer in the format \'{agent_name} believes that the risk '
            f'of option A is B, because ..., and the risk of option C is E, '
            f'because ...\' \n'
            f'For example, \'{agent_name} believes that the risk of option A is 7, '
            f'because ..., and the risk of option C is 5, because ...\''
        ),
        answer_prefix=f'{agent_name} believes that ',
        max_tokens=1000,
        terminators=(),
    )

    # Re-attach the answer prefix so the value reads as a full sentence.
    result = f'{agent_name} believes that {result}'

    self._logging_channel({
        'Key': self.get_pre_act_key(),
        'Value': result,
        'Chain of thought': evaluation.view().text().splitlines(),
    })

    return result

def build_agent(
    *,
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
    memory: associative_memory.AssociativeMemory,
    clock: game_clock.MultiIntervalClock,
    update_time_interval: datetime.timedelta | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Build a risk-averse agent.

  Besides the usual observation and memory components, the agent has:
    * a constant risk-aversion description (`BDIRiskModel`);
    * a risk evaluation of each available option (`OptionsEvaluation`); and
    * a question asking which option best avoids or reduces risk.

  Args:
    config: The agent config to use.
    model: The language model to use.
    memory: The agent's memory object.
    clock: The clock to use.
    update_time_interval: Unused (but required by the interface for now)

  Returns:
    An agent.
  """
  del update_time_interval
  agent_name = config.name

  raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)
  measurements = measurements_lib.Measurements()

  instructions = components.instructions.Instructions(
      agent_name=agent_name,
      logging_channel=measurements.get_channel('Instructions').on_next,
  )
  observation_label = '\nObservation'
  observation = components.observation.Observation(
      clock_now=clock.now,
      timeframe=clock.get_step_size(),
      pre_act_key=observation_label,
      logging_channel=measurements.get_channel('Observation').on_next,
  )
  observation_summary_label = '\nSummary of recent observations'
  observation_summary = components.observation.ObservationSummary(
      model=model,
      clock_now=clock.now,
      timeframe_delta_from=datetime.timedelta(hours=24),
      timeframe_delta_until=datetime.timedelta(hours=0),
      pre_act_key=observation_summary_label,
      logging_channel=measurements.get_channel('ObservationSummary').on_next,
  )
  time_display = components.report_function.ReportFunction(
      function=clock.current_time_interval_str,
      pre_act_key='\nCurrent time',
      logging_channel=measurements.get_channel('TimeDisplay').on_next,
  )
  relevant_memories_label = '\nRecalled memories and observations'
  relevant_memories = components.all_similar_memories.AllSimilarMemories(
      model=model,
      components={
          _get_class_name(observation_summary): observation_summary_label,
          _get_class_name(time_display): 'The current date/time is'},
      num_memories_to_retrieve=10,
      pre_act_key=relevant_memories_label,
      logging_channel=measurements.get_channel('AllSimilarMemories').on_next,
  )
  identity_label = '\nIdentity characteristics'
  identity_characteristics = (
      components.question_of_query_associated_memories.IdentityWithoutPreAct(
          model=model,
          logging_channel=measurements.get_channel(
              'IdentityWithoutPreAct'
          ).on_next,
          pre_act_key=identity_label,
      )
  )

  bdi_model_label = '\nRisk aversion'
  bdi_model = BDIRiskModel(
      agent_name=agent_name,
      pre_act_key=bdi_model_label,
      logging_channel=measurements.get_channel('RiskAversion').on_next)

  # The goal component, when present, is registered under `goal_label` (see
  # the end of this function), so components refer to it by that key.
  options_perception_components = {}
  if config.goal:
    goal_label = '\nOverarching goal'
    overarching_goal = components.constant.Constant(
        state=config.goal,
        pre_act_key=goal_label,
        logging_channel=measurements.get_channel(goal_label).on_next)
    options_perception_components[goal_label] = goal_label
  else:
    goal_label = None
    overarching_goal = None
  options_perception_components.update({
      _get_class_name(observation): observation_label,
      _get_class_name(relevant_memories): relevant_memories_label,
      _get_class_name(bdi_model): bdi_model_label,
  })
  options_perception_label = (
      f'\nQuestion: Which options are available to {agent_name} '
      'right now?\nAnswer')
  options_perception = (
      components.question_of_recent_memories.AvailableOptionsPerception(
          model=model,
          components=options_perception_components,
          clock_now=clock.now,
          pre_act_key=options_perception_label,
          logging_channel=measurements.get_channel(
              'AvailableOptionsPerception'
          ).on_next,
      )
  )

  options_evaluation_label = (
      f'\nQuestion: For each option {agent_name} has available, evaluate the '
      f'risk that {agent_name} would incur if they chose that option using a '
      f'scale of 0 to 10, where 0 means no risk is incurred and 10 means the '
      f'maximum risk is incurred.\nAnswer')

  options_evaluation = OptionsEvaluation(
      model=model,
      observation_component_name=_get_class_name(observation),
      options_perception_component_name=_get_class_name(options_perception),
      components={
          _get_class_name(observation): observation_label,
          _get_class_name(observation_summary): observation_summary_label,
          _get_class_name(relevant_memories): relevant_memories_label,
          _get_class_name(bdi_model): bdi_model_label,
          _get_class_name(options_perception): options_perception_label,
      },
      clock_now=clock.now,
      num_memories_to_retrieve=25,
      pre_act_key=options_evaluation_label,
      logging_channel=measurements.get_channel('OptionsEvaluation').on_next
  )

  best_option_perception_label = (
      f'\nQuestion: From {agent_name}\'s available options, while considering '
      f'{agent_name}\'s goal, which choice of action or strategy '
      f'would best avoid or reduce potential risks for {agent_name} '
      'right now?\nAnswer')
  best_option_perception_components = {}
  if config.goal:
    best_option_perception_components[goal_label] = goal_label
  best_option_perception_components.update({
      _get_class_name(observation): observation_label,
      _get_class_name(observation_summary): observation_summary_label,
      _get_class_name(relevant_memories): relevant_memories_label,
      _get_class_name(options_perception): options_perception_label,
      _get_class_name(bdi_model): bdi_model_label,
      _get_class_name(options_evaluation): options_evaluation_label,
  })
  best_option_perception = (
      components.question_of_recent_memories.QuestionOfRecentMemories(
          model=model,
          components=best_option_perception_components,
          clock_now=clock.now,
          pre_act_key=best_option_perception_label,
          question=(
              f'Considering the statements above, carefully select '
              f'which of {agent_name}\'s options has the highest likelihood of '
              f'avoiding or reducing potential risks? If multiple options '
              f'offer the same level of risk avoidance or reduction, select '
              f'the option that {agent_name} believes will best minimize their '
              f'risk in the shortest amount of time.'
          ),
          answer_prefix=f"{agent_name}'s best course of action is ",
          add_to_memory=False,
          logging_channel=measurements.get_channel(
              'BestOptionPerception'
          ).on_next,
      )
  )

  entity_components = (
      # Components that provide pre_act context.
      instructions,
      time_display,  # Time should be after instruction
      observation,
      observation_summary,
      bdi_model,
      relevant_memories,
      options_perception,

      # Additional components.
      options_evaluation,
      best_option_perception,

      # Components that do not provide pre_act context.
      identity_characteristics,
  )
  # Components are keyed by class name; other components look each other up
  # by these keys (see the `components=` mappings above).
  components_of_agent = {_get_class_name(component): component
                         for component in entity_components}
  components_of_agent[
      components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
          components.memory_component.MemoryComponent(raw_memory))
  component_order = list(components_of_agent.keys())
  if overarching_goal is not None:
    components_of_agent[goal_label] = overarching_goal
    # Place goal after the instructions.
    component_order.insert(1, goal_label)

  act_component = components.concat_act_component.ConcatActComponent(
      model=model,
      clock=clock,
      component_order=component_order,
      logging_channel=measurements.get_channel('ActComponent').on_next,
  )

  agent = entity_agent_with_logging.EntityAgentWithLogging(
      agent_name=agent_name,
      act_component=act_component,
      context_components=components_of_agent,
      component_logging=measurements,
  )

  return agent


def save_to_json(
    agent: entity_agent_with_logging.EntityAgentWithLogging,
) -> str:
  """Serializes an agent to a JSON string.

  The output can be turned back into an agent with `rebuild_from_json`. It
  holds the state of every context component (keyed by component name), the
  state of the act component under `'act_component'`, and, when the agent has
  one, its config under `'agent_config'`. The clock, model and embedder are
  not serialized; they must be supplied again when rebuilding.

  Args:
    agent: The agent to serialize. It must be in the `READY` phase.

  Returns:
    A JSON string representing the agent's state.

  Raises:
    ValueError: If the agent is not in the `READY` phase.
  """
  ready = entity_component.Phase.READY
  if agent.get_phase() != ready:
    raise ValueError('The agent must be in the `READY` phase to be saved.')

  serialized = {}
  for name in agent.get_all_context_components():
    serialized[name] = agent.get_component(name).get_state()
  serialized['act_component'] = agent.get_act_component().get_state()

  agent_config = agent.get_config()
  if agent_config is not None:
    serialized['agent_config'] = agent_config.to_dict()

  return json.dumps(serialized)


def rebuild_from_json(
    json_data: str,
    model: language_model.LanguageModel,
    clock: game_clock.MultiIntervalClock,
    embedder: Callable[[str], np.ndarray],
    memory_importance: Callable[[str], float] | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Rebuilds an agent from JSON data produced by `save_to_json`.

  A fresh agent is built from the saved config with a new, empty associative
  memory. The saved state of every context component and of the act component
  is then restored into it.

  Args:
    json_data: JSON string produced by `save_to_json`.
    model: Language model for the rebuilt agent.
    clock: Clock for the rebuilt agent and its memory.
    embedder: Sentence embedder for the new associative memory.
    memory_importance: Optional importance function for the new associative
      memory.

  Returns:
    The rebuilt agent.

  Raises:
    ValueError: If the JSON data does not contain the agent config, or
      contains entries that do not belong to any component of the rebuilt
      agent.
    KeyError: If the JSON data lacks the state of a component the rebuilt
      agent has.
  """
  data = json.loads(json_data)

  # Validate before doing any construction work.
  if 'agent_config' not in data:
    raise ValueError('The JSON data does not contain the agent config.')
  agent_config = formative_memories.AgentConfig.from_dict(
      data.pop('agent_config')
  )

  new_agent_memory = associative_memory.AssociativeMemory(
      sentence_embedder=embedder,
      importance=memory_importance,
      clock=clock.now,
      clock_step_size=clock.get_step_size(),
  )

  agent = build_agent(
      config=agent_config,
      model=model,
      memory=new_agent_memory,
      clock=clock,
  )

  for component_name in agent.get_all_context_components():
    agent.get_component(component_name).set_state(data.pop(component_name))

  agent.get_act_component().set_state(data.pop('act_component'))

  # Use an explicit exception rather than `assert`, which is stripped under -O.
  if data:
    raise ValueError(f'Unused data {sorted(data)}')
  return agent
