# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An Agent Factory."""

from collections.abc import Callable, Mapping, Sequence
import copy
from copy import deepcopy
import datetime
import functools
import random
import re
import types
from typing import Dict
import numpy as np

from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as agent_components
from concordia.components.agent import action_spec_ignored
from concordia.components.agent import memory_component
from concordia.document import interactive_document
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.typing import entity as entity_lib
from concordia.typing import entity_component
from concordia.typing import logging
from concordia.utils import concurrency
from concordia.utils import helper_functions
from concordia.utils import measurements as measurements_lib
from typing_extensions import override


def _get_class_name(object_: object) -> str:
  return object_.__class__.__name__


class ActComponent(entity_component.ActingComponent):
  """Acting component that steers the agent toward a goal shared with others.

  For every action attempt the component:
    1. recalls recent '[observation]' memories,
    2. asks the model who the agent is currently interacting with,
    3. asks the model for each of those people's goal and for a common goal,
    4. asks the model for the best action (and a justification) given that
       common goal; for CHOICE actions the options are scored numerically
       first when the model judges the outcome to be quantifiable.

  Intermediate results are reported on the channels named 'Observation',
  'Find Person', 'Common Ground', 'Quantifiable', 'Computation',
  'Best Options', 'Explanation' and 'Act'.
  """

  # Tag used to recognise observation memories.
  _OBSERVATION_TAG = '[observation]'
  # Maximum number of memories requested from the memory component.
  _MEMORY_LIMIT = 100
  # A run of digits; used to decide whether an observation mentions a number.
  _DIGITS_PATTERN = re.compile(r'[0-9]+')
  # A (possibly negative) integer or decimal number.
  _NUMBER_PATTERN = re.compile(r'-?[0-9]*\.?[0-9]+')
  # The word "negative" directly in front of a number, e.g. "negative 5".
  _NEGATIVE_NUMBER_PATTERN = re.compile(r'negative\s+(?=\.?[0-9])')

  def __init__(
      self,
      model: language_model.LanguageModel,
      clock: game_clock.MultiIntervalClock,
      component_order: Sequence[str] | None = None,
      agent_goal: str | None = None,
      memory_component_name: str = (
          memory_component.DEFAULT_MEMORY_COMPONENT_NAME),
      pre_act_key: str = 'Act',
      logging_channel: (
          logging.LoggingChannel | Mapping[str, logging.LoggingChannel]
      ) = logging.NoOpLoggingChannel,
  ):
    """Initializes the component.

    Args:
      model: Language model used for every query.
      clock: Clock whose step size is substituted into the call to action.
      component_order: Optional ordering of component names. It is validated
        for duplicates and stored, but not otherwise used by this class.
      agent_goal: Optional goal of the agent, appended to the other agents'
        goals when the common goal is derived.
      memory_component_name: Name under which the memory component is
        registered on the entity.
      pre_act_key: Value logged under 'Key' on the 'Act' channel.
      logging_channel: Either a mapping from channel name to logging channel,
        or a single channel that receives every log entry.

    Raises:
      ValueError: If `component_order` contains duplicates.
    """
    self._model = model
    self._clock = clock
    self._pre_act_key = pre_act_key
    self._logging_channel = logging_channel
    self._memory_component_name = memory_component_name
    self._agent_goal = agent_goal
    self._component_order = (
        None if component_order is None else tuple(component_order))
    if self._component_order is not None and (
        len(set(self._component_order)) != len(self._component_order)):
      raise ValueError(
          'The component order contains duplicate components: '
          + ', '.join(self._component_order)
      )

  def _emit(self, channel_name: str, payload: Mapping[str, object]) -> None:
    """Sends `payload` to the logging channel called `channel_name`.

    A single (non-mapping) channel receives every payload. With a mapping,
    names that have no channel are dropped so that logging never interrupts
    acting.
    """
    sink = self._logging_channel
    if isinstance(sink, Mapping):
      sink = sink.get(channel_name, logging.NoOpLoggingChannel)
    sink(payload)

  def _emit_empty_quantification(self) -> None:
    """Logs empty 'Quantifiable' and 'Computation' entries."""
    self._emit('Quantifiable', {
        'Key': 'Quantifiable',
        'Value': '',
        'Chain of thought': '',
    })
    self._emit_empty_computation()

  def _emit_empty_computation(self) -> None:
    """Logs an empty 'Computation' entry."""
    self._emit('Computation', {
        'Key': 'Computation',
        'Value': '',
        'Chain of thought': '',
    })

  @staticmethod
  def _transcript(
      prompt: interactive_document.InteractiveDocument) -> list[str]:
    """Returns the lines of the text currently held by `prompt`."""
    return prompt.view().text().splitlines()

  def _recent_observations(
      self, keep: Callable[[str], bool] | None = None) -> list[str]:
    """Returns the texts of retrieved memories tagged as observations.

    Args:
      keep: Optional extra predicate; only texts for which it is true are
        returned.
    """
    memory = self.get_entity().get_component(
        self._memory_component_name,
        type_=memory_component.MemoryComponent)
    recency_scorer = legacy_associative_memory.RetrieveRecent(
        add_time=True,
    )
    mems = memory.retrieve(
        scoring_fn=recency_scorer, limit=self._MEMORY_LIMIT)
    texts = [mem.text for mem in mems if self._OBSERVATION_TAG in mem.text]
    if keep is not None:
      texts = [text for text in texts if keep(text)]
    return texts

  def _numeric_observations(self) -> list[str]:
    """Returns observations that mention a number and contain no '-- "'."""
    return self._recent_observations(
        lambda text: '-- "' not in text and self._find_number(text))

  def _recall_recent_memory(self) -> str:
    """Returns the last five retrieved observations, one per line."""
    result = '\n'.join(self._recent_observations()[-5:]) + '\n'
    self._emit('Observation', {
        'Key': 'Observation',
        'Value': result
    })
    return result

  def _find_person(self) -> list[str]:
    """Asks the model who the agent is currently interacting with.

    Returns:
      The distinct, non-empty names parsed from the model's comma-separated
      answer, in order of first appearance and without the agent's own name.
    """
    agent_name = self.get_entity().name
    prompt = interactive_document.InteractiveDocument(self._model)
    context = '\n'.join(self._recent_observations()[-10:]) + '\n'
    people_str = prompt.open_question(
        question=('From the given context, create a comma-separated list containing all '
                  f'the proper names of people {agent_name} is **currently interacting** '
                  'with, considering the time information. '
                  'For example, "Julie Maxim, Michael Buble, Bob Skinner".'
                  f'\nContext: {context}'),
        question_label='Exercise',)
    candidates = (
        raw.strip(' .') for raw in people_str.strip(' .').split(','))
    # An empty answer would otherwise produce the name '', which is a
    # substring of every memory. dict.fromkeys de-duplicates while keeping a
    # deterministic order.
    names_detected = list(dict.fromkeys(
        name for name in candidates if name and name != agent_name))
    self._emit('Find Person', {
        'Key': 'Find Person',
        'Value': ', '.join(names_detected),
        'Chain of thought': self._transcript(prompt),
    })
    return names_detected

  def _find_goal(self, name: str) -> str:
    """Asks the model for the goal of `name` given observations naming them."""
    prompt = interactive_document.InteractiveDocument(self._model)
    personal_obs = self._recent_observations(lambda text: name in text)
    prompt.statement('Recent Observations:' + '\n'.join(personal_obs[-10:]))
    goal_prefix = f'{name}\'s goal is to '
    return goal_prefix + prompt.open_question(
        question=(
            f"Under the current observation, what is {name}\'s overarching goal?"
        ),
        answer_prefix=goal_prefix,
    )

  def _query_common_ground(self, goal_dict: Mapping[str, str]) -> str:
    """Derives a goal common to the agent and the people in `goal_dict`.

    Args:
      goal_dict: Maps a person's name to a sentence describing their goal.

    Returns:
      The model's description of the common goal, or a fixed fallback
      sentence when no goal text is available.
    """
    agent_name = self.get_entity().name
    prompt = interactive_document.InteractiveDocument(self._model)
    component_states = '\n'.join(goal_dict.values())
    if component_states:
      if self._agent_goal is not None:
        component_states += '\n' + f"{agent_name}'s Goal: {self._agent_goal}"
      prompt.statement(component_states)
      # NOTE(review): the returned sentence starts with the agent's name while
      # the model is prompted with "Agents' common goal is to "; confirm
      # whether this mismatch is intentional before changing either string.
      common_ground = f'{agent_name}\' goal is to ' + prompt.open_question(
          question=(
              "Given the statements above, what is the common goal between the agents? "
              "Explain it without using any proper noun."
          ),
          answer_prefix='Agents\' common goal is to ',
      )
    else:
      common_ground = 'Agents\' common goal is to ask other\'s goal first.'
    self._emit('Common Ground', {
        'Key': 'Common Ground',
        'Value': common_ground,
        'Chain of thought': self._transcript(prompt),
    })
    return common_ground

  def _choose_best(self, quant: Mapping[str, float | None]) -> str:
    """Returns the option with the highest reward, breaking ties at random.

    Rewards that are None are replaced by the mean of the known rewards, or
    by 0.0 when no reward is known.

    Args:
      quant: Maps each option to its expected reward or None.

    Raises:
      ValueError: If `quant` is empty.
    """
    if not quant:
      raise ValueError('Cannot choose the best option from an empty mapping.')
    options = list(quant)
    known = [float(value) for value in quant.values() if value is not None]
    fallback = sum(known) / len(known) if known else 0.0
    rewards = np.array([
        fallback if quant[option] is None else float(quant[option])
        for option in options
    ])
    best_indices = np.flatnonzero(rewards == rewards.max())
    return options[int(np.random.choice(best_indices))]

  def _query_best_option(
      self,
      memories: str,
      common_ground: str,
      choices: Sequence[str] | None = None,
      quant: Mapping[str, float | None] | None = None,
  ) -> str:
    """Determines what the agent should do next.

    Args:
      memories: Recent observations shown to the model.
      common_ground: Description of the common goal shown to the model.
      choices: Optional options the model is asked to choose from.
      quant: Optional expected reward per option. When given, the model is
        not queried and the option with the highest reward is returned.

    Returns:
      The chosen option itself when `quant` is given, otherwise a sentence
      starting with '<agent name> have to '.
    """
    agent_name = self.get_entity().name
    if quant is not None:
      max_option = self._choose_best(quant)
      self._emit('Best Options', {
          'Key': 'Best Options',
          'Value': f'{agent_name} have to choose {max_option} option.',
          'Chain of thought': str(quant),
      })
      return max_option

    prompt = interactive_document.InteractiveDocument(self._model)
    component_states = f'Recent Obesrvation: {memories}'
    component_states += '\n' + f"Overarching Goal: {common_ground}"
    prompt.statement(component_states)
    question = (
        "Under the current situation, to achive the overarching goal, "
        f"what {agent_name} have to do right now? "
    )
    if choices is None:
      question += (
          "Suggest an exact action "
          "not just 'considering' or 'examining' other\'s suggestion."
      )
    else:
      question += (
          "Choose one of the "
          f"following options: {' / '.join(choices)}"
      )
    action_prefix = f'{agent_name} have to '
    best_option = action_prefix + prompt.open_question(
        question=question,
        answer_prefix=action_prefix,
    )
    self._emit('Best Options', {
        'Key': 'Best Options',
        'Value': best_option,
        'Chain of thought': self._transcript(prompt),
    })
    return best_option

  def _explain_option(
      self, memories: str, option: str, common_ground: str) -> str:
    """Asks the model why `option` benefits the agent and the others.

    Returns:
      A sentence starting with 'It is beneficial for all because '.
    """
    agent_name = self.get_entity().name
    prompt = interactive_document.InteractiveDocument(self._model)
    component_states = f'Recent Obesrvation: {memories}'
    component_states += '\n' + f"Agents' Common Goal: {common_ground}"
    prompt.statement(component_states)

    reason_prefix = 'It is beneficial for all because '
    explanation = reason_prefix + prompt.open_question(
        question=(
            # The space keeps the option text from running into the request.
            f'{option} '
            f'Explain why this is beneficial for both {agent_name} and other agents.'
        ),
        answer_prefix=reason_prefix,
    )
    self._emit('Explanation', {
        'Key': 'Explanation',
        'Value': explanation,
        'Chain of thought': self._transcript(prompt),
    })
    return explanation

  def _find_number(self, obs: str) -> bool:
    """Returns whether the text after the last '[observation]' has a digit."""
    tail = obs.split(self._OBSERVATION_TAG)[-1]
    return self._DIGITS_PATTERN.search(tail) is not None

  @classmethod
  def _normalize_number_words(cls, text: str) -> str:
    """Rewrites spelled-out numeric words so they can be parsed as numbers."""
    text = text.replace('zero', '0.0').replace('minus ', '- ')
    # "negative 5" must become "-5"; replacing the bare word first would
    # yield "-1.0 5", whose last number is the positive 5.
    text = cls._NEGATIVE_NUMBER_PATTERN.sub('-', text)
    return text.replace('negative', '-1.0')

  @classmethod
  def _parse_last_number(cls, text: str) -> float | None:
    """Returns the last number found in `text`, or None if there is none."""
    matches = cls._NUMBER_PATTERN.findall(text)
    return float(matches[-1]) if matches else None

  def _estimate_reward(
      self, option: str, call_to_action: str) -> tuple[str, str]:
    """Asks the model for the expected reward of choosing `option`.

    Returns:
      A pair `(preamble, answer)`: the observations, question and answer
      prefix shown for logging, and the model's answer after number-word
      normalization.
    """
    agent_name = self.get_entity().name
    prompt = interactive_document.InteractiveDocument(self._model)
    shown = '\n'.join(self._numeric_observations()[-10:])
    prompt.statement('Recent Observations:' + shown + '\n')
    question = (f"{call_to_action.replace(':', '')} "
                f"Compute the expected earning or score of {agent_name}. "
                "The value can be negative.\n")
    answer_prefix = (f"If {agent_name} choose {option} option, "
                     f"then {agent_name}\'s expected earning or score is ")
    preamble = (
        'Observations:' + shown + '\n'
        + f"Question: {question}"
        + f"Answer: {answer_prefix}"
    )
    answer = prompt.open_question(
        question=question,
        max_tokens=500,
        answer_prefix=answer_prefix,
    )
    return preamble, self._normalize_number_words(answer)

  def _expected_reward(self, option: str, call_to_action: str) -> str:
    """Returns the transcript of the expected-reward query for `option`.

    The transcript consists of the observations, the question and the answer
    prefix, followed by the model's answer with number words normalized.
    """
    preamble, answer = self._estimate_reward(option, call_to_action)
    return preamble + answer

  def _quantifiable(
      self,
      options: Sequence[str],
      call_to_action: str,
  ) -> dict[str, float | None] | None:
    """Scores `options` numerically when the outcome is quantifiable.

    The model is first asked whether the outcome of the options is a numeric
    value. If there are no numeric observations, or the model answers "No",
    None is returned.

    Args:
      options: The options the agent can choose from.
      call_to_action: The formatted call to action of the current step.

    Returns:
      A dict mapping each option to the last number found in the model's
      answer for it (None when the answer has no number), or None when the
      outcome is not quantifiable.
    """
    agent_name = self.get_entity().name
    prompt = interactive_document.InteractiveDocument(self._model)
    obs = self._numeric_observations()
    quantifiable = False
    if obs:
      prompt.statement('Observations:' + '\n'.join(obs[-10:]))
      quoted = [f'"{option}"' for option in options]
      verdicts = ["Yes", "No"]
      verdict_idx = prompt.multiple_choice_question(
          question=(f"Based on the observations, is the profit/score/wages when {agent_name} "
                    f"choose one of {' / '.join(quoted)} options numeric value?"),
          answers=verdicts,
      )
      quantifiable = verdict_idx == 0
      self._emit('Quantifiable', {
          'Key': 'Quantifiable',
          'Value': verdicts[verdict_idx],
          'Chain of thought': self._transcript(prompt),
      })
    else:
      self._emit('Quantifiable', {
          'Key': 'Quantifiable',
          'Value': "No",
          'Chain of thought': 'Observations:',
      })

    if not quantifiable:
      # Logged for both "not quantifiable" paths so that the 'Computation'
      # channel receives one entry per action attempt.
      self._emit_empty_computation()
      return None

    transcripts: dict[str, str] = {}
    rewards: dict[str, float | None] = {}
    for option in options:
      # Results are keyed by the original option so that the chosen key is
      # always one of the options the caller passed in.
      preamble, answer = self._estimate_reward(f'"{option}"', call_to_action)
      transcripts[option] = preamble + answer
      # Only the model's answer is parsed: the preamble repeats observations
      # that always contain digits and would mask an answer without a number.
      rewards[option] = self._parse_last_number(answer)
    self._emit('Computation', {
        'Key': 'Computation',
        'Value': transcripts,
    })
    return rewards

  @override
  def get_action_attempt(
      self,
      context: str,
      action_spec: entity_lib.ActionSpec,
  ) -> str:
    """Decides on the action the agent attempts.

    Args:
      context: Unused; the component builds its own context from memory.
      action_spec: Specification of the requested action.

    Returns:
      For FREE actions the sampled text prefixed with the agent's name, for
      CHOICE actions one of `action_spec.options`, and for FLOAT actions the
      string form of the parsed float ('0.0' when parsing fails).

    Raises:
      NotImplementedError: If the output type is not FREE, CHOICE or FLOAT.
    """
    agent_name = self.get_entity().name
    memories = self._recall_recent_memory()
    goal_dict = {
        name: self._find_goal(name) for name in self._find_person()}
    common_ground = self._query_common_ground(goal_dict)
    call_to_action = action_spec.call_to_action.format(
        name=agent_name,
        timedelta=helper_functions.timedelta_to_readable_str(
            self._clock.get_step_size()
        ),
    )
    prompt = interactive_document.InteractiveDocument(self._model)
    prompt.statement(f'Recent Observation: {memories}\n')
    prompt.statement(f"Agents' Common Goal: {common_ground}\n")

    output_type = action_spec.output_type
    if output_type == entity_lib.OutputType.FREE:
      best_option = self._query_best_option(memories, common_ground)
      explain = self._explain_option(memories, best_option, common_ground)
      prompt.statement(f'{best_option} {explain}' + '\n')
      prompt.statement('Instruction: Act based on the statements above. '
                       'Persuade others based on the suggested reason.\n')
      output = agent_name + ' '
      output += prompt.open_question(
          call_to_action,
          max_tokens=2200,
          answer_prefix=output,
          terminators=('" ', '\n'),
          question_label='Exercise',
      )
      self._log(output, prompt)
      self._emit_empty_quantification()
      return output

    if output_type == entity_lib.OutputType.CHOICE:
      quant = self._quantifiable(action_spec.options, call_to_action)
      if quant is not None:
        output = self._query_best_option(
            memories, common_ground, action_spec.options, quant)
        prompt.statement(f'{quant}')
      else:
        best_option = self._query_best_option(
            memories, common_ground, action_spec.options)
        explain = self._explain_option(memories, best_option, common_ground)
        prompt.statement(f'{best_option} {explain}' + '\n')
        prompt.statement('Instruction: Act based on the statements above.\n')
        idx = prompt.multiple_choice_question(
            question=call_to_action, answers=action_spec.options
        )
        output = action_spec.options[idx]
      self._log(output, prompt)
      return output

    if output_type == entity_lib.OutputType.FLOAT:
      best_option = self._query_best_option(memories, common_ground)
      explain = self._explain_option(memories, best_option, common_ground)
      # As in the other branches, the suggested action and its justification
      # are shown to the model before it is asked for the final answer.
      prompt.statement(f'{best_option} {explain}' + '\n')
      prompt.statement('Instruction: Act based on the statements above.\n')
      sampled_text = prompt.open_question(
          call_to_action,
          max_tokens=2200,
          answer_prefix=agent_name + ' ',
      )
      self._log(sampled_text, prompt)
      self._emit_empty_quantification()
      try:
        return str(float(sampled_text))
      except ValueError:
        return '0.0'

    raise NotImplementedError(
        f'Unsupported output type: {action_spec.output_type}. '
        'Supported output types are: FREE, CHOICE, and FLOAT.'
    )

  def _log(self,
           result: str,
           prompt: interactive_document.InteractiveDocument):
    """Logs the final action and the prompt that produced it on 'Act'."""
    self._emit('Act', {
        'Key': self._pre_act_key,
        'Value': result,
        'Prompt': self._transcript(prompt),
    })


def build_agent(
    *,
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
    memory: associative_memory.AssociativeMemory,
    clock: game_clock.MultiIntervalClock,
    update_time_interval: datetime.timedelta,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Builds a main-character agent that acts through `ActComponent`.

  Args:
    config: Agent configuration; supplies the name, the optional goal and the
      'main_character' flag in `extras`.
    model: Language model handed to the acting component.
    memory: Associative memory backing the agent's memory component.
    clock: Clock used by the observation and acting components.
    update_time_interval: Unused.

  Returns:
    The assembled agent.

  Raises:
    ValueError: If `config` does not describe a main character.
  """
  del update_time_interval
  is_main_character = config.extras.get('main_character', False)
  if not is_main_character:
    raise ValueError(
        'This function is meant for a main character but it was called on a '
        'supporting character.')

  agent_name = config.name
  agent_goal = config.goal if config.goal else None

  raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)
  measurements = measurements_lib.Measurements()

  # One logging channel per stage reported by the acting component.
  channel_names = (
      'Observation',
      'Find Person',
      'Common Ground',
      'Quantifiable',
      'Computation',
      'Best Options',
      'Explanation',
      'Act',
  )
  logging_channels = {}
  for channel_name in channel_names:
    logging_channels[channel_name] = (
        measurements.get_channel(channel_name).on_next)

  observation = agent_components.observation.Observation(
      clock_now=clock.now,
      timeframe=clock.get_step_size(),
      pre_act_key='\nObservation',
      logging_channel=measurements.get_channel('Observation').on_next,
  )
  memory_name = (
      agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME)
  components_of_agent = {
      _get_class_name(observation): observation,
      memory_name: agent_components.memory_component.MemoryComponent(
          raw_memory),
  }

  return entity_agent_with_logging.EntityAgentWithLogging(
      agent_name=agent_name,
      act_component=ActComponent(
          model=model,
          clock=clock,
          logging_channel=logging_channels,
          agent_goal=agent_goal,
      ),
      context_components=components_of_agent,
      component_logging=measurements,
  )
