# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A factory implementing the three key questions agent as an entity."""

import datetime

from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as agent_components
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.utils import measurements as measurements_lib

DEFAULT_PLANNING_HORIZON = 'the rest of the day, focusing most on the near term'


def _get_class_name(object_: object) -> str:
  """Returns the name of the class that `object_` is an instance of."""
  cls = object_.__class__
  return cls.__name__


class FairnessComponent:
  """Scores an allocation of utilities against several fairness notions.

  Four per-notion metrics are computed ('equality', 'equity', 'need' and
  'maximin'), combined into one weighted composite score, and the composite is
  compared against a threshold: an outcome is considered fair when the
  composite is less than or equal to the threshold.

  For 'equality', 'equity' and 'need' a smaller metric value means a fairer
  outcome, so they are added to the composite as penalties. 'maximin' is the
  utility of the worst-off participant, where a larger value is fairer, so its
  weighted term is subtracted from the composite instead of added.
  """

  # Metrics for which a LARGER value means a FAIRER outcome. Their weighted
  # terms are subtracted in `composite_fairness_score` so that every term pulls
  # the composite in the same "lower is fairer" direction.
  _HIGHER_IS_FAIRER = frozenset({'maximin'})

  def __init__(self, weights=None, threshold=0.5):
    """Initializes the component.

    Args:
      weights: Mapping from metric name to its weight in the composite score.
        When omitted (or empty) the built-in defaults are used. The defaults
        are not normalized (they sum to 0.8).
      threshold: Composite scores at or below this value count as fair.
    """
    self.weights = weights or {
        'equality': 0.15,
        'equity': 0.15,
        'need': 0.25,
        'maximin': 0.25,
    }
    self.threshold = threshold

  def equality_metric(self, utilities):
    """Returns the population variance of `utilities` (0 means all equal).

    Raises:
      ZeroDivisionError: If `utilities` is empty.
    """
    mean_utility = sum(utilities) / len(utilities)
    return sum((u - mean_utility) ** 2 for u in utilities) / len(utilities)

  def equity_metric(self, utilities, contributions):
    """Returns the population variance of the utility/contribution ratios.

    Utilities and contributions are paired positionally with `zip`, so any
    surplus entries of the longer sequence are ignored.

    Raises:
      ZeroDivisionError: If a paired contribution is zero or there are no
        pairs at all.
    """
    ratios = [u / c for u, c in zip(utilities, contributions)]
    mean_ratio = sum(ratios) / len(ratios)
    return sum((r - mean_ratio) ** 2 for r in ratios) / len(ratios)

  def need_metric(self, utilities, needs):
    """Returns the summed absolute gap between each need and its utility.

    Needs and utilities are paired positionally with `zip`, so any surplus
    entries of the longer sequence are ignored.
    """
    return sum(abs(n - u) for n, u in zip(needs, utilities))

  def maximin_metric(self, utilities):
    """Returns the smallest utility, i.e. the worst-off participant's payoff.

    Raises:
      ValueError: If `utilities` is empty.
    """
    return min(utilities)

  def calculate_fairness_scores(self, utilities, contributions, needs):
    """Returns a dict with the raw value of each of the four metrics."""
    return {
        'equality': self.equality_metric(utilities),
        'equity': self.equity_metric(utilities, contributions),
        'need': self.need_metric(utilities, needs),
        'maximin': self.maximin_metric(utilities),
    }

  def composite_fairness_score(self, scores):
    """Combines per-metric scores into one "lower is fairer" number.

    Each score is multiplied by its weight. Penalty metrics are added; metrics
    listed in `_HIGHER_IS_FAIRER` are subtracted, so improving the worst-off
    participant's utility can never make an outcome look less fair.

    Args:
      scores: Mapping from metric name to raw metric value.

    Returns:
      The signed weighted sum of the scores.

    Raises:
      KeyError: If `scores` contains a metric that has no weight.
    """
    return sum(
        (-1 if key in self._HIGHER_IS_FAIRER else 1)
        * self.weights[key]
        * scores[key]
        for key in scores
    )

  def evaluate_outcome(self, utilities, contributions, needs):
    """Returns True when the composite score is at or below the threshold."""
    scores = self.calculate_fairness_scores(utilities, contributions, needs)
    composite_score = self.composite_fairness_score(scores)
    return composite_score <= self.threshold

  def suggest_alternatives(self, utilities, contributions, needs):
    """Placeholder for suggesting fairer alternatives; currently returns None."""

  def adapt_weights(self, feedback):
    """Placeholder for adapting weights from feedback; currently returns None."""

  def evaluate_group_cooperation(self, group_behavior):
    """Placeholder for evaluating group cooperation; currently returns None."""

  def balance_self_interest(self, composite_score):
    """Placeholder for balancing self-interest with fairness; returns None."""


class AltruismComponent:
  """Simple altruism heuristics plus a per-agent reputation tally."""

  def __init__(self, altruism_coefficient=0.6):
    """Stores the altruism coefficient and starts with no reputations."""
    self.reputation = {}
    self.altruism_coefficient = altruism_coefficient

  def update_reputation(self, agent, points):
    """Adds `points` to the reputation of `agent`, starting from zero."""
    self.reputation.setdefault(agent, 0)
    self.reputation[agent] += points

  def calculate_utility(self, own_payoff, others_payoffs):
    """Returns own payoff plus the coefficient-weighted sum of others' payoffs."""
    others_total = sum(others_payoffs)
    return own_payoff + self.altruism_coefficient * others_total

  def decide_to_share_resources(self, own_resources, others_needs):
    """Returns True when the summed needs do not exceed `own_resources`."""
    return bool(sum(others_needs) <= own_resources)

  def decide_to_cooperate(self, own_cost, collective_benefit):
    """Returns True when the collective benefit strictly exceeds the cost."""
    return bool(collective_benefit > own_cost)

  def reward_cooperation(self, agent, reward):
    """Credits `reward` to the reputation of `agent`."""
    self.update_reputation(agent, reward)

  def monitor_group_welfare(self, group_welfare):
    """Returns True when welfare is strictly below the altruism coefficient."""
    return bool(group_welfare < self.altruism_coefficient)


def build_agent(
    *,
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
    memory: associative_memory.AssociativeMemory,
    clock: game_clock.MultiIntervalClock,
    update_time_interval: datetime.timedelta,
) -> entity_agent_with_logging.EntityAgentWithLogging:
  """Build a main-character agent from its context components.

  The agent is assembled from instruction, observation, memory-recall,
  self/situation perception, planning and time-display components, plus an
  optional overarching-goal component when `config.goal` is set. Building the
  agent has no side effects on stdout.

  Args:
    config: The agent config to use. `config.extras['main_character']` must be
      truthy.
    model: The language model to use.
    memory: The agent's memory object.
    clock: The clock to use.
    update_time_interval: Unused; accepted and discarded immediately.

  Returns:
    An agent.

  Raises:
    ValueError: If `config` does not describe a main character.
  """
  del update_time_interval
  if not config.extras.get('main_character', False):
    raise ValueError('This function is meant for a main character '
                     'but it was called on a supporting character.')

  agent_name = config.name

  raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)

  measurements = measurements_lib.Measurements()
  instructions = agent_components.instructions.Instructions(
      agent_name=agent_name,
      logging_channel=measurements.get_channel('Instructions').on_next,
  )

  observation_label = '\nObservation'
  observation = agent_components.observation.Observation(
      clock_now=clock.now,
      timeframe=clock.get_step_size(),
      pre_act_key=observation_label,
      logging_channel=measurements.get_channel('Observation').on_next,
  )
  observation_summary_label = '\nSummary of recent observations'
  observation_summary = agent_components.observation.ObservationSummary(
      model=model,
      clock_now=clock.now,
      timeframe_delta_from=datetime.timedelta(hours=4),
      timeframe_delta_until=datetime.timedelta(hours=0),
      pre_act_key=observation_summary_label,
      logging_channel=measurements.get_channel('ObservationSummary').on_next,
  )
  time_display = agent_components.report_function.ReportFunction(
      function=clock.current_time_interval_str,
      pre_act_key='\nCurrent time',
      logging_channel=measurements.get_channel('TimeDisplay').on_next,
  )
  identity_label = '\nIdentity characteristics'
  identity_characteristics = (
      agent_components.question_of_query_associated_memories.IdentityWithoutPreAct(
          model=model,
          logging_channel=measurements.get_channel(
              'IdentityWithoutPreAct'
          ).on_next,
          pre_act_key=identity_label,
      )
  )
  # Each question component below receives a mapping from the registered name
  # of another component (its class name) to the label used for that
  # component's output.
  self_perception_label = (
      f'\nQuestion: What kind of person is {agent_name}?\nAnswer')
  self_perception = agent_components.question_of_recent_memories.SelfPerception(
      model=model,
      components={_get_class_name(identity_characteristics): identity_label},
      pre_act_key=self_perception_label,
      logging_channel=measurements.get_channel('SelfPerception').on_next,
  )
  situation_perception_label = (
      f'\nQuestion: What kind of situation is {agent_name} in '
      'right now?\nAnswer')
  situation_perception = (
      agent_components.question_of_recent_memories.SituationPerception(
          model=model,
          components={
              _get_class_name(observation): observation_label,
              _get_class_name(observation_summary): observation_summary_label,
          },
          clock_now=clock.now,
          pre_act_key=situation_perception_label,
          logging_channel=measurements.get_channel(
              'SituationPerception'
          ).on_next,
      )
  )
  person_by_situation_label = (
      f'\nQuestion: What would a person like {agent_name} do in '
      'a situation like this?\nAnswer')
  person_by_situation = (
      agent_components.question_of_recent_memories.PersonBySituation(
          model=model,
          components={
              _get_class_name(self_perception): self_perception_label,
              _get_class_name(situation_perception): situation_perception_label,
          },
          clock_now=clock.now,
          pre_act_key=person_by_situation_label,
          logging_channel=measurements.get_channel('PersonBySituation').on_next,
      )
  )
  relevant_memories_label = '\nRecalled memories and observations'
  relevant_memories = agent_components.all_similar_memories.AllSimilarMemories(
      model=model,
      components={
          _get_class_name(observation_summary): observation_summary_label,
          _get_class_name(time_display): 'The current date/time is'},
      num_memories_to_retrieve=10,
      pre_act_key=relevant_memories_label,
      logging_channel=measurements.get_channel('AllSimilarMemories').on_next,
  )

  plan_components = {}
  if config.goal:
    # The goal component is registered under its label rather than a class
    # name, so the same string serves as both key and label here.
    goal_label = '\nOverarching goal'
    overarching_goal = agent_components.constant.Constant(
        state=config.goal,
        pre_act_key=goal_label,
        logging_channel=measurements.get_channel(goal_label).on_next)
    plan_components[goal_label] = goal_label
  else:
    goal_label = None
    overarching_goal = None

  plan_components.update({
      _get_class_name(relevant_memories): relevant_memories_label,
      _get_class_name(self_perception): self_perception_label,
      _get_class_name(situation_perception): situation_perception_label,
      _get_class_name(person_by_situation): person_by_situation_label,
  })
  plan = agent_components.plan.Plan(
      model=model,
      observation_component_name=_get_class_name(observation),
      components=plan_components,
      clock_now=clock.now,
      goal_component_name=_get_class_name(person_by_situation),
      horizon=DEFAULT_PLANNING_HORIZON,
      pre_act_key='\nPlan',
      logging_channel=measurements.get_channel('Plan').on_next,
  )

  entity_components = (
      # Components that provide pre_act context.
      instructions,
      observation,
      observation_summary,
      relevant_memories,
      self_perception,
      situation_perception,
      person_by_situation,
      plan,
      time_display,

      # Components that do not provide pre_act context.
      identity_characteristics,
  )

  # Components are registered under their class names; dict insertion order
  # determines the default component order passed to the act component.
  components_of_agent = {_get_class_name(component): component
                         for component in entity_components}
  components_of_agent[
      agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
          agent_components.memory_component.MemoryComponent(raw_memory))
  component_order = list(components_of_agent.keys())
  if overarching_goal is not None:
    components_of_agent[goal_label] = overarching_goal
    # Place goal after the instructions.
    component_order.insert(1, goal_label)

  act_component = agent_components.concat_act_component.ConcatActComponent(
      model=model,
      clock=clock,
      component_order=component_order,
      logging_channel=measurements.get_channel('ActComponent').on_next,
  )

  agent = entity_agent_with_logging.EntityAgentWithLogging(
      agent_name=agent_name,
      act_component=act_component,
      context_components=components_of_agent,
      component_logging=measurements,
  )

  # NOTE: FairnessComponent and AltruismComponent in this module are plain
  # helper classes and are not registered as context components of the agent.
  # The hard-coded example that used to exercise them here on every call (and
  # print its results) was removed because it had no effect on the agent.
  return agent
