# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable, Collection, Mapping, Sequence
import random
import types
import datetime
from concordia.typing import logging

from concordia.agents import entity_agent_with_logging
from concordia.components.agent import action_spec_ignored
from concordia.components.agent import memory_component
from concordia.document import interactive_document
from concordia.language_model import language_model
from concordia.typing import entity_component
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as components
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.utils import measurements as measurements_lib
from concordia.components.agent import question_of_recent_memories
from concordia.components.agent import person_representation
from concordia.components.agent import question_of_query_associated_memories, relationships
from concordia.components import agent as agent_components


def _get_class_name(object_: object) -> str:
  """Return the name of the class that `object_` is an instance of."""
  cls = object_.__class__
  return cls.__name__


class RecentPersonRepr(action_spec_ignored.ActionSpecIgnored):
  """Represent other characters in the simulated world.

  Each time the pre-act value is built, the component:
    1. asks the language model to list the people named in the agent's most
       recent memories and merges them into a running list of names;
    2. for every known name that has matching memories, asks the model for a
       short character sketch and for a summary of that person's decisions;
    3. writes the combined decision summaries back to memory under
       `_DECISION_MEMORY_TAG` and returns the per-person representations.
  """

  # Tag prefixed to the decision summaries this component writes to memory.
  # Memories containing it are excluded from the "current scenario" context,
  # so earlier summaries are not fed back into later prompts.
  _DECISION_MEMORY_TAG = '[Decision-Making of other agents up to this point]'

  def __init__(
      self,
      model: language_model.LanguageModel,
      memory_component_name: str = (
          memory_component.DEFAULT_MEMORY_COMPONENT_NAME
      ),
      components: Mapping[
          entity_component.ComponentName, str
      ] = types.MappingProxyType({}),
      additional_questions: Sequence[str] = (),
      num_memories_to_retrieve: int = 10,
      cap_number_of_detected_people: int = 10,
      pre_act_key: str = 'Recent Interaction Person representation',
      logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
  ):
    """Initialize a component to represent other people in the simulated world.

    Args:
      model: The language model to use.
      memory_component_name: The name of the memory component from which to
        retrieve related memories.
      components: The components to condition the answer on. This is a mapping
        of the component name to a label to use in the prompt.
      additional_questions: sequence of additional questions to ask about each
        player in the simulation. Stored on the instance; not consulted when
        building the pre-act value.
      num_memories_to_retrieve: Base number of memories to retrieve. The
        name-detection prompt uses 1.5x this value and the per-person prompts
        use 2x.
      cap_number_of_detected_people: The maximum number of people that can be
        represented.
      pre_act_key: Prefix to add to the output of the component when called
        in `pre_act`.
      logging_channel: The channel to log debug information to.
    """
    super().__init__(pre_act_key)
    self._model = model
    self._memory_component_name = memory_component_name
    self._components = dict(components)
    self._additional_questions = additional_questions
    self._num_memories_to_retrieve = num_memories_to_retrieve
    self._logging_channel = logging_channel
    self._cap_number_of_detected_people = cap_number_of_detected_people

    self._names_detected: list[str] = []

  def _update_names_detected(self, people_str: str, agent_name: str) -> None:
    """Merge a comma-separated list of names into the running name list.

    Names are stripped of surrounding whitespace. Empty entries and the
    agent's own name are dropped so they cannot take up slots under the cap.
    Duplicates are removed while keeping first-seen order. If the list exceeds
    the cap, a random subset of the allowed size is kept.

    Args:
      people_str: Comma-separated names as returned by the language model.
      agent_name: Name of the entity owning this component.
    """
    candidates = (name.strip() for name in people_str.split(','))
    merged = dict.fromkeys(self._names_detected)
    merged.update(
        dict.fromkeys(
            name for name in candidates if name and name != agent_name
        )
    )
    self._names_detected = list(merged)
    # Prevent adding too many names, forgetting some if there are too many.
    if len(self._names_detected) > self._cap_number_of_detected_people:
      self._names_detected = random.sample(
          self._names_detected, self._cap_number_of_detected_people)

  def _make_pre_act_value(self) -> str:
    """Build the per-person representations and record decision summaries.

    Returns:
      The representations of all people with matching memories, joined by
      newlines; an empty string when nobody qualified.
    """
    agent_name = self.get_entity().name

    memory = self.get_entity().get_component(
        self._memory_component_name,
        type_=memory_component.MemoryComponent)

    recency_scorer = legacy_associative_memory.RetrieveRecent(add_time=True)
    mems = '\n'.join(
        mem.text
        for mem in memory.retrieve(
            scoring_fn=recency_scorer,
            limit=int(self._num_memories_to_retrieve * 1.5),
        )
    )

    find_people_prompt = interactive_document.InteractiveDocument(self._model)
    find_people_prompt.statement(
        f'Recent observations of {agent_name}:\n{mems}')
    people_str = find_people_prompt.open_question(
        question=('Create a comma-separated list containing names of people '
                  f'acting in the scenario, other than {agent_name}. '
                  'For example, if the observations mention Julie, Michael, '
                  'and Bob Skinner, then produce the list '
                  '"Julie,Michael,Bob Skinner".'),
        question_label='Exercise',
        terminators=())
    self._update_names_detected(people_str, agent_name)

    prompt = interactive_document.InteractiveDocument(self._model)

    component_states = '\n'.join(
        f'{prefix}:\n{self.get_named_component_pre_act_value(key)}'
        for key, prefix in self._components.items()
    )
    prompt.statement(f'Considerations:\n{component_states}\n')

    associative_scorer = legacy_associative_memory.RetrieveAssociative(
        use_recency=True,
        use_importance=True,
        add_time=True,
        sort_by_time=True,
    )

    # The scenario context does not depend on the person being described, so
    # it is retrieved once instead of once per loop iteration.
    memories_overall = '\n'.join(
        mem.text
        for mem in memory.retrieve(
            scoring_fn=recency_scorer,
            limit=self._num_memories_to_retrieve * 2,
        )
        if self._DECISION_MEMORY_TAG not in mem.text
    )

    person_representations = []
    decision_summaries = []
    prompt_copies_to_log = []
    for person_name in self._names_detected:
      # Keep only memories that mention the person's first name.
      first_name = person_name.split()[0]
      memories_list = [
          mem.text
          for mem in memory.retrieve(
              query=person_name,
              scoring_fn=associative_scorer,
              limit=self._num_memories_to_retrieve * 2,
          )
          if first_name in mem.text
      ]
      if not memories_list:
        continue
      memories = '\n'.join(memories_list)

      # First prompt: character sketch of the person.
      sketch_prompt = prompt.copy()
      sketch_prompt.statement(
          f'Current Scenario through latest memories: {memories_overall}\n'
          f'Observed behavior of {person_name}:'
          f'\n{memories}\n')
      sketch_question = (
          'Taking note of all the information above, '
          'write a short paragraph capturing the character of '
          f'{person_name} in sufficient detail to model their personality '
          'and decision making. '
          'Consider and include personality traits, decisions made, '
          'dialogues spoken in key scenarios, and any other relevant '
          'details. Focus more on the decisions taken by the agent rather '
          'than details about their personality, but include both.')
      person_description = sketch_prompt.open_question(
          f'{sketch_question}\n',
          max_tokens=200,
          terminators=(),
          question_label='Exercise',
          answer_prefix=f'{person_name} is ',
      )
      sketch = f'{person_name} is {person_description}'

      # Second prompt: decisions only, no personality.
      decision_prompt = prompt.copy()
      decision_prompt.statement(
          f'Current Scenario through latest memories: {memories_overall}\n'
          f'Behavior of {person_name}:'
          f'\n{memories}\n')
      decision_question = (
          'Summarize concisely the exact decision-making and choices taken '
          f'by {person_name}. No personality traits. If distinct rounds of '
          "decisions have occurred, lay out the agent's decisions and "
          'payoffs clearly.')
      decision_summary = decision_prompt.open_question(
          decision_question,
          max_tokens=100,
          terminators=(),
          question_label='Exercise',
          answer_prefix=f'{person_name} Decisions: ',
      )

      representation = (
          f'The following is the representation of {person_name}: {sketch}'
          '\n    '
          f'{person_name} Decisions: {decision_summary}')
      decision_summaries.append(decision_summary + '\n')
      person_representations.append(representation + '\n***')
      # Log both prompts so the full chain of thought is visible.
      prompt_copies_to_log.append(sketch_prompt.view().text())
      prompt_copies_to_log.append(decision_prompt.view().text())

    result = '\n'.join(person_representations)
    if decision_summaries:
      # Avoid storing an empty tagged memory when nobody was summarised.
      decisions = '\n'.join(decision_summaries)
      memory.add(f'{self._DECISION_MEMORY_TAG} {decisions}', metadata={})
    self._logging_channel({
        'Key': self.get_pre_act_key(),
        'Value': result,
        'Name detection chain of thought': (
            find_people_prompt.view().text().splitlines()),
        'Names detected so far': self._names_detected,
        'Components chain of thought': prompt.view().text().splitlines(),
        'Full chain of thought': (
            '\n***\n'.join(prompt_copies_to_log).splitlines()),
    })

    return result

class PredictedActions(question_of_recent_memories.QuestionOfRecentMemories):
  """Ask which action each of the other agents is likely to take next.

  The question is conditioned on the `RecentPersonRepr` component and on the
  component passed as `component`, which is labelled in the prompt as the
  answer to "What options are generally available?".
  """

  def __init__(
      self,
      agent_name: str,
      component: question_of_recent_memories.QuestionOfRecentMemories,
      **kwargs,
  ):
    """Initialize the component.

    Args:
      agent_name: Name of the agent owning this component; it is excluded from
        the agents whose actions are predicted.
      component: Component whose class name is used as the key of the
        "options available" context entry.
      **kwargs: Forwarded unchanged to `QuestionOfRecentMemories`.
    """
    # `agent_name` is interpolated here, so no further formatting is needed
    # before using the question in the pre-act key.
    question = (
        'Given the current scenario and the above background, what is the '
        'likely action the other mentioned agents (except '
        f'{agent_name}) will take? (If there are multiple agents, answer '
        'one by one for each agent in the format: Agent1: Action, Agent2: '
        'Action...). The possible agents can only take one of the given '
        'possible actions, so pick from them, but also include any other '
        'possibilities and the reasoning as an extra note.'
    )
    super().__init__(
        pre_act_key=f'\nQuestion: {question}\nAnswer',
        question=question,
        answer_prefix='The other agents will: ',
        add_to_memory=False,
        memory_tag='[Other agent observations]',
        components={
            'RecentPersonRepr': (
                '\nQuestion: How have the decisions of the other agents '
                'been till now?\nAnswer'
            ),
            _get_class_name(component): (
                '\nQuestion: What options are generally available?\nAnswer'
            ),
        },
        terminators=(),
        num_memories_to_retrieve=20,
        **kwargs,
    )
# class PredictedActions(action_spec_ignored.ActionSpecIgnored):
  
#   def __init__(
#       self,
#       agent_name:str,
#       model: language_model.LanguageModel,
#       component: question_of_recent_memories.QuestionOfRecentMemories,
#       pre_act_key: str,
#       answer_prefix: str,
#       add_to_memory: False,
#       memory_tag: str = '',
#       memory_component_name: str = (
#           memory_component.DEFAULT_MEMORY_COMPONENT_NAME
#       ),
#       terminators: Collection[str] = (),
#       clock_now: Callable[[], datetime.datetime] | None = None,
#       num_memories_to_retrieve: int = 25,
#       logging_channel: logging.LoggingChannel = logging.NoOpLoggingChannel,
#   ):
    
#     super().__init__(pre_act_key)
#     self._model = model
#     self._memory_component_name = memory_component_name
#     self._components = {'RecentPersonRepr': f'\nQuestion: How have the decisions of the other agents other been till now?\nAnswer'},
#     self._clock_now = clock_now
#     self._num_memories_to_retrieve = num_memories_to_retrieve
#     self._question = f"Given the current scenario and the above background, what is the likely action the other mentioned agents (except {agent_name}) will take? (If there are multiple agents, answer one by one for each agent in the format: Agent1: Action, Agent2: Action...). The possible agents can only take one of the given possible actions, so pick from them, but also include any other possibilities and the reasoning as an extra note."
#     self._terminators = terminators
#     self._answer_prefix = 'The other agents will: '
#     self._add_to_memory = add_to_memory
#     self._memory_tag = memory_tag

#     self._logging_channel = logging_channel

#   def _make_pre_act_value(self) -> str:
#     agent_name = self.get_entity().name

#     memory = self.get_entity().get_component(
#         self._memory_component_name, type_=memory_component.MemoryComponent
#     )
#     recency_scorer = legacy_associative_memory.RetrieveRecent(add_time=True)
#     mems = '\n'.join([
#         mem.text
#         for mem in memory.retrieve(
#             scoring_fn=recency_scorer, limit=self._num_memories_to_retrieve
#         )
#     ])

#     prompt = interactive_document.InteractiveDocument(self._model)
#     prompt.statement(f'Recent observations of {agent_name}:\n{mems}')

#     if self._clock_now is not None:
#       prompt.statement(f'Current time: {self._clock_now()}.\n')

#     component_states = '\n'.join([
#         f"{agent_name}'s"
#         f' {prefix}:\n{self.get_named_component_pre_act_value(key)}'
#         for key, prefix in self._components.items()
#     ])
#     prompt.statement(component_states)

#     question = self._question.format(agent_name=agent_name)
#     result = prompt.open_question(
#         question,
#         answer_prefix=self._answer_prefix.format(agent_name=agent_name),
#         max_tokens=1000,
#         terminators=self._terminators,
#     )
#     result = self._answer_prefix.format(agent_name=agent_name) + result

#     if self._add_to_memory:
#       memory.add(f'{self._memory_tag} {result}', metadata={})

#     log = {
#         'Key': self.get_pre_act_key(),
#         'Summary': question,
#         'State': result,
#         'Chain of thought': prompt.view().text().splitlines(),
#     }

#     if self._clock_now is not None:
#       log['Time'] = self._clock_now()

#     self._logging_channel(log)

#     return result
  
class BestSocialAction(question_of_recent_memories.QuestionOfRecentMemories):
  """Ask which action yields the best reward for the agent.

  The question is conditioned on the `RecentPersonRepr` and
  `PredictedActions` components and states the scenario goal explicitly.
  """

  def __init__(
      self,
      agent_name: str,
      const_comp: str,
      **kwargs,
  ):
    """Initialize the component.

    Args:
      agent_name: Name of the agent owning this component.
      const_comp: Text of the goal, inserted into the question.
      **kwargs: Forwarded unchanged to `QuestionOfRecentMemories`.
    """
    context_labels = {
        'RecentPersonRepr': (
            '\nQuestion: How have the decisions of agents other than '
            f'{agent_name} been till now and what are their personas?\nAnswer'
        ),
        'PredictedActions': (
            '\nQuestion: What are the predicted actions of agents other than '
            f'{agent_name}?\nAnswer'
        ),
    }
    prompt_text = (
        "Assume it is time to make a decision, and you cannot "
        "communicate/respond right at this step. Given the current "
        "cooperative scenario and the above predictions about the actions "
        "taken by other agents, what is the action that gives the best "
        f"reward to {agent_name}? There is likely always more benefit in "
        "being cooperative, but make sure that you are careful in ensuring "
        "that you don't take too much risk. In some scenarios, you will "
        "need to be careful of the decision-making and personality of "
        "other agents so you don't choose a decision that maximizes "
        "immediate reward but makes other agents lose trust in you, thus "
        f"reducing future rewards.\n The goal in this scenario is: {const_comp}. "
        "Reason carefully and then make a decision. Do not put too much "
        "emphasis on the agent's personal preferences, no matter how "
        "strongly held, but rather only in what's best to do - Staying as "
        "objective as possible to cooperate."
    )
    super().__init__(
        pre_act_key=f'\nQuestion: {prompt_text}\nAnswer',
        question=prompt_text,
        answer_prefix=f'{agent_name} will: ',
        add_to_memory=False,
        memory_tag='[Other agent observations]',
        components=context_labels,
        terminators=(),
        num_memories_to_retrieve=20,
        **kwargs,
    )

class BestSocialActionSoft(question_of_recent_memories.QuestionOfRecentMemories):
  """Ask what the agent should say when it can talk rather than act.

  The question is conditioned on the `RecentPersonRepr` and
  `PredictedActions` components.
  """

  def __init__(
      self,
      agent_name: str,
      const_comp: str,
      **kwargs,
  ):
    """Initialize the component.

    Args:
      agent_name: Name of the agent owning this component.
      const_comp: Text of the goal. Accepted so the signature matches
        `BestSocialAction`; it is not inserted into this question.
      **kwargs: Forwarded unchanged to `QuestionOfRecentMemories`.
    """
    del const_comp  # Not used in this prompt.
    # The prompt is assembled line by line so it carries no source-code
    # indentation or stray trailing whitespace.
    question = '\n'.join((
        f'You are {agent_name}. Assume that you can communicate with other '
        "agents at this step, and don't have to immediately take a decision.",
        'Given the current cooperative scenario and the above predictions '
        'about the actions taken by other agents, what communication/ '
        'dialogue or gesture will lead to a beneficial scenario?',
        'Some factors to consider:',
        '1. Reason clearly and transparently about why you plan to take a '
        'particular action. For example, there might be some information '
        "that only you know. But don't lie!",
        '2. In game-like scenarios requiring a strategy, be as transparent '
        'about strategy as necessary, this allows other agents to know how '
        'you will act in scenarios, and encourages trust. Be sure to mention '
        "how you will act in scenarios when someone doesn't cooperate to "
        'discourage such behavior.',
        '3. ONLY in scenarios where it is necessary to convince other people '
        'to get a highly optimal outcome (not to satisfy a personal '
        'preference), use convincing language and attempt to be highly '
        'persuasive by using words such as please, trust and form convincing '
        'arguments. You may need to do this before rationally laying out '
        'your personal strategy.',
        '4. You are only allowed a short response, and are not sure if there '
        'will be any possibility to communicate in the future. So be concise '
        'when mentioning above points.',
        '5. Keep your tone consistent with your past dialogues and actions, '
        'as you are still you, keep your response in line with your '
        'personality.',
        '',
        "NOTE: Don't use filler language, and don't waste a dialogue in "
        'asking for permission etc. Convey your thoughts directly and '
        'clearly!!',
        '',
        'First think step by step about a possible response, and then give '
        'an exact dialogue response',
    ))
    super().__init__(
        pre_act_key=(
            '\nQuestion: What dialogue or communication would the '
            f'{agent_name} say at this point, if given the opportunity? '
            '\nAnswer: '
        ),
        question=question,
        answer_prefix=f'{agent_name} will: ',
        add_to_memory=False,
        memory_tag='[Other agent observations]',
        components={
            'RecentPersonRepr': (
                '\nQuestion: How have the decisions of agents other than '
                f'{agent_name} been till now and what are their personas?'
                '\nAnswer'
            ),
            'PredictedActions': (
                '\nQuestion: What are the predicted actions of agents other '
                f'than {agent_name}?\nAnswer'
            ),
        },
        terminators=(),
        num_memories_to_retrieve=20,
        **kwargs,
    )

def _make_question_components(
    agent_name: str,
    const_comp: str,
    component: question_of_recent_memories.QuestionOfRecentMemories,
    measurements: measurements_lib.Measurements,
    model: language_model.LanguageModel,
    clock: game_clock.MultiIntervalClock,
) -> tuple[
    RecentPersonRepr, PredictedActions, BestSocialAction, BestSocialActionSoft
]:
  """Create the four social-reasoning components of the agent.

  Args:
    agent_name: Name of the agent the components belong to.
    const_comp: Goal text passed to the two "best social action" components.
    component: Component passed to `PredictedActions` as the source of the
      generally available options.
    measurements: Provides the logging channels 'Question_1'..'Question_4'.
    model: The language model shared by all four components.
    clock: Clock whose `now` method is given to the question components.

  Returns:
    A tuple, in this order: the person representation, the predicted actions
    of other agents, the best social action, and the best social dialogue.
  """
  person_repr = RecentPersonRepr(
      model=model,
      logging_channel=measurements.get_channel('Question_1').on_next,
  )
  predicted_actions = PredictedActions(
      agent_name=agent_name,
      component=component,
      model=model,
      clock_now=clock.now,
      logging_channel=measurements.get_channel('Question_2').on_next,
  )
  best_action = BestSocialAction(
      agent_name=agent_name,
      const_comp=const_comp,
      model=model,
      clock_now=clock.now,
      logging_channel=measurements.get_channel('Question_3').on_next,
  )
  best_dialogue = BestSocialActionSoft(
      agent_name=agent_name,
      const_comp=const_comp,
      model=model,
      clock_now=clock.now,
      logging_channel=measurements.get_channel('Question_4').on_next,
  )

  return (person_repr, predicted_actions, best_action, best_dialogue)



def build_agent(
    *,
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
    memory: associative_memory.AssociativeMemory,
    clock: game_clock.MultiIntervalClock,
    update_time_interval: datetime.timedelta | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Build an agent.

  The agent's context components are, in order: instructions, the overarching
  goal, observation, observation summary, recalled memories, available
  options, the four social-reasoning components from
  `_make_question_components`, the best-option perception, the time display,
  and the memory component.

  Args:
    config: The agent config to use. When `config.goal` is empty, a default
      cooperative goal is used instead.
    model: The language model to use.
    memory: The agent's memory object.
    clock: The clock to use.
    update_time_interval: Unused (but required by the interface for now)

  Returns:
    An agent.
  """
  del update_time_interval
  agent_name = config.name

  raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)
  measurements = measurements_lib.Measurements()

  instructions = components.instructions.Instructions(
      agent_name=agent_name,
      logging_channel=measurements.get_channel('Instructions').on_next,
  )
  observation_label = '\nObservation'
  observation = components.observation.Observation(
      clock_now=clock.now,
      timeframe=clock.get_step_size(),
      pre_act_key=observation_label,
      logging_channel=measurements.get_channel('Observation').on_next,
  )
  observation_summary_label = '\nSummary of recent observations'
  observation_summary = components.observation.ObservationSummary(
      model=model,
      clock_now=clock.now,
      timeframe_delta_from=datetime.timedelta(hours=6),
      timeframe_delta_until=datetime.timedelta(hours=0),
      pre_act_key=observation_summary_label,
      logging_channel=measurements.get_channel('ObservationSummary').on_next,
  )
  time_display = components.report_function.ReportFunction(
      function=clock.current_time_interval_str,
      pre_act_key='\nCurrent time',
      logging_channel=measurements.get_channel('TimeDisplay').on_next,
  )

  relevant_memories_label = '\nRecalled memories and observations'
  relevant_memories = components.all_similar_memories.AllSimilarMemories(
      model=model,
      components={
          _get_class_name(observation_summary): observation_summary_label,
          _get_class_name(time_display): 'The current date/time is'},
      num_memories_to_retrieve=10,
      pre_act_key=relevant_memories_label,
      logging_channel=measurements.get_channel('AllSimilarMemories').on_next,
  )

  # A goal component always exists: the configured goal when present,
  # otherwise a default cooperative goal.
  goal_label = '\nOverarching goal'
  goal_text = config.goal or 'Get the best possible reward while coordinating'
  overarching_goal = agent_components.constant.Constant(
      state=goal_text,
      pre_act_key=goal_label,
      logging_channel=measurements.get_channel(goal_label).on_next)

  options_perception_components = {
      goal_label: goal_label,
      _get_class_name(observation): observation_label,
      _get_class_name(observation_summary): observation_summary_label,
      _get_class_name(relevant_memories): relevant_memories_label,
  }
  options_perception_label = (
      f'\nQuestion: Which options are available to {agent_name} '
      'right now?\nAnswer')
  options_perception = (
      agent_components.question_of_recent_memories.AvailableOptionsPerception(
          model=model,
          components=options_perception_components,
          clock_now=clock.now,
          pre_act_key=options_perception_label,
          logging_channel=measurements.get_channel(
              'AvailableOptionsPerception'
          ).on_next,
      )
  )
  question_components = _make_question_components(
      agent_name=agent_name,
      component=options_perception,
      const_comp=goal_text,
      model=model,
      clock=clock,
      measurements=measurements
  )
  best_option_perception_label = (
      f'\nQuestion: Of the options available to {agent_name}, and '
      'given their goal, which choice of action or strategy is '
      f'best for {agent_name} to take right now? Consider all of the cooperative elements and the decision-making personas of other involved agents if necessary.\nAnswer')
  # The goal is always included here, matching the options perception above,
  # because the goal component is always registered on the agent.
  # question_components[0] (the person representation) is not passed
  # directly; the other question components list it among their own inputs.
  best_option_components = {
      goal_label: goal_label,
      _get_class_name(observation): observation_label,
      _get_class_name(observation_summary): observation_summary_label,
      _get_class_name(relevant_memories): relevant_memories_label,
      _get_class_name(options_perception): options_perception_label,
      _get_class_name(question_components[1]): "Possible decisions that can be taken by these agents: ",
      _get_class_name(question_components[2]): "This is the decided best cooperative social 'action' to maximize reward:",
      _get_class_name(question_components[3]): "This is the decided best cooperative social 'dialogue' in case an action is not required at this step:"
  }
  best_option_perception = (
      agent_components.question_of_recent_memories.BestOptionPerception(
          model=model,
          components=best_option_components,
          clock_now=clock.now,
          pre_act_key=best_option_perception_label,
          terminators=(),
          num_memories_to_retrieve=15,
          logging_channel=measurements.get_channel(
              'BestOptionPerception'
          ).on_next,
      )
  )

  # The question components sit directly before the best-option perception.
  entity_components = (
      instructions,
      observation,
      observation_summary,
      relevant_memories,
      options_perception,
      *question_components,
      best_option_perception,
      time_display,
  )
  components_of_agent = {_get_class_name(component): component
                         for component in entity_components}
  components_of_agent[
      components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
          components.memory_component.MemoryComponent(raw_memory))
  component_order = list(components_of_agent.keys())
  components_of_agent[goal_label] = overarching_goal
  # Place goal after the instructions.
  component_order.insert(1, goal_label)

  act_component = components.concat_act_component.ConcatActComponent(
      model=model,
      clock=clock,
      component_order=component_order,
      logging_channel=measurements.get_channel('ActComponent').on_next,
  )

  agent = entity_agent_with_logging.EntityAgentWithLogging(
      agent_name=agent_name,
      act_component=act_component,
      context_components=components_of_agent,
      component_logging=measurements,
  )

  return agent
