# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A factory implementing the three key questions agent as an entity."""

from collections.abc import Callable
import datetime
import json

from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
from concordia.clocks import game_clock
from concordia.components import agent as agent_components
from concordia.language_model import language_model
from concordia.memory_bank import legacy_associative_memory
from concordia.typing import entity_component
from concordia.utils import measurements as measurements_lib
import numpy as np


def _get_class_name(object_: object) -> str:
    """Returns the name of the class that `object_` is an instance of."""
    cls = object_.__class__
    return cls.__name__


class AvailableOptionsPerception(
    agent_components.question_of_recent_memories.QuestionOfRecentMemories
):
    """Component asking which actions are available and which one is best.

    The question asks for every action or strategy available to the agent, an
    evaluation of each one on five aspects (immediate gains, long-term gains,
    risk vs. reward, influence impact, outcome sustainability), and finally the
    option most likely to achieve the agent's goal. The component is configured
    with `add_to_memory=True` and the memory tag `[best strategy]`.
    """

    def __init__(self, **kwargs):
        """Initializes the component.

        Args:
          **kwargs: Forwarded unchanged to the `QuestionOfRecentMemories`
            constructor, in addition to the fixed arguments set here.
        """
        evaluation_question = (
            "Based on the above statements, identify all actions or strategies available to **{agent_name}**. "
            "For each action, evaluate the following aspects:\n\n"
            "1. **Immediate Gains**: What immediate advantages does this action offer?\n"
            "2. **Long-Term Gains**: How well does this action align with {agent_name}'s long-term objectives? "
            "Could it open sustainable benefits or future opportunities?\n"
            "3. **Risk vs. Reward**: Assess the risks and rewards balance—does it maximize potential benefits relative to risks?\n"
            "4. **Influence Impact**: Will this action strengthen or weaken {agent_name}'s influence in the negotiation?\n"
            "5. **Outcome Sustainability**: Could this choice build a foundation for ongoing advantages or secure a lasting edge?\n\n"
            "Finally, based on these evaluations, which option is most likely to help {agent_name} achieve their goal? "
            "If multiple options have similar chances, identify the one that {agent_name} believes will achieve success most quickly and reliably."
        )
        # NOTE(review): this terminator contains a literal `{agent_name}`
        # placeholder. Confirm that the base class formats terminators; if it
        # does not, this stop sequence will presumably never match.
        stop_sequences = ("{agent_name}'s strategies:",)
        super().__init__(
            question=evaluation_question,
            answer_prefix='',
            terminators=stop_sequences,
            add_to_memory=True,
            memory_tag='[best strategy]',
            num_memories_to_retrieve=20,
            **kwargs,
        )

# Benefit perception


class MyBenefitPerception(
    agent_components.question_of_recent_memories.QuestionOfRecentMemories
):
    """Component asking what benefit or loss the agent most recently had.

    The component is configured with `add_to_memory=True` and the memory tag
    `[benefit or loss]`.
    """

    def __init__(self, **kwargs):
        """Initializes the component.

        Args:
          **kwargs: Forwarded unchanged to the `QuestionOfRecentMemories`
            constructor, in addition to the fixed arguments set here.
        """
        benefit_question = (
            "Based on the recent events or outcomes, what specific benefit or loss "
            "has the {agent_name} experienced most recently?"
        )
        benefit_answer_prefix = "{agent_name}'s benefit or loss:"
        super().__init__(
            question=benefit_question,
            answer_prefix=benefit_answer_prefix,
            memory_tag='[benefit or loss]',
            add_to_memory=True,
            num_memories_to_retrieve=20,
            **kwargs,
        )


def build_agent(
    *,
    config: formative_memories.AgentConfig,
    model: language_model.LanguageModel,
    memory: associative_memory.AssociativeMemory,
    clock: game_clock.MultiIntervalClock,
    update_time_interval: datetime.timedelta | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
    """Build a main-character agent from perception and reflection components.

    The agent is assembled from: instructions, an optional goal constant,
    observations and their summary, a fixed "strategic thinker" role constant,
    similar-memory recall, self/situation/benefit perception, a
    person-by-situation question, a representation of other players, an
    available-options evaluation, and a reflection step. All of them are
    handed to a `ConcatActComponent` in a fixed order.

    Args:
      config: The agent config to use. `config.extras['main_character']` must
        be truthy. `config.name` and `config.goal` are read.
      model: The language model to use.
      memory: The agent's memory object.
      clock: The clock to use.
      update_time_interval: Agent calls update every time this interval passes.
        This factory does not use it; it is deleted on entry.

    Returns:
      An agent.

    Raises:
      ValueError: If `config.extras` does not mark the agent as a main
        character.
    """
    # The update interval is not used by this factory.
    del update_time_interval
    if not config.extras.get('main_character', False):
        raise ValueError('This function is meant for a main character '
                         'but it was called on a supporting character.')

    agent_name = config.name

    # Wrap the associative memory in the memory-bank type that is later passed
    # to `MemoryComponent`.
    raw_memory = legacy_associative_memory.AssociativeMemoryBank(memory)

    measurements = measurements_lib.Measurements()
    # Goal: a constant component that is only created when the config defines
    # a goal. `goal_label` doubles as the component's registered name below.
    if config.goal:
        goal_label = '\nGoal'
        overarching_goal = agent_components.constant.Constant(
            state=config.goal,
            pre_act_key=goal_label,
            logging_channel=measurements.get_channel(goal_label).on_next)

    else:
        goal_label = None
        overarching_goal = None

    instructions = agent_components.instructions.Instructions(
        agent_name=agent_name,
        logging_channel=measurements.get_channel('Instructions').on_next,
    )

    observation_label = '\nObservation'
    observation = agent_components.observation.Observation(
        clock_now=clock.now,
        timeframe=clock.get_step_size(),
        pre_act_key=observation_label,
        logging_channel=measurements.get_channel('Observation').on_next,
    )
    observation_summary_label = '\nSummary of recent observations'
    observation_summary = agent_components.observation.ObservationSummary(
        model=model,
        clock_now=clock.now,
        timeframe_delta_from=datetime.timedelta(hours=24),
        timeframe_delta_until=datetime.timedelta(hours=0),
        pre_act_key=observation_summary_label,
        logging_channel=measurements.get_channel('ObservationSummary').on_next,
    )

    relevant_memories_label = '\nRecalled memories and observations'

    time_display = agent_components.report_function.ReportFunction(
        function=clock.current_time_interval_str,
        pre_act_key='\nCurrent time',
        logging_channel=measurements.get_channel('TimeDisplay').on_next,
    )

    # In every `components` mapping in this function, the key is the name the
    # referenced component is registered under further down (its class name,
    # or `role_label` / `goal_label` for the two constants) and the value is a
    # label string (presumably used as the heading for that component's output
    # in the prompt; confirm against the component implementations).
    relevant_memories = agent_components.all_similar_memories.AllSimilarMemories(
        model=model,
        components={
            _get_class_name(observation_summary): observation_summary_label,
            _get_class_name(time_display): 'The current date/time is'},
        num_memories_to_retrieve=10,
        pre_act_key=relevant_memories_label,
        logging_channel=measurements.get_channel('AllSimilarMemories').on_next,
    )

    identity_label = '\nIdentity characteristics'
    identity_characteristics = (
        agent_components.question_of_query_associated_memories.IdentityWithoutPreAct(
            model=model,
            logging_channel=measurements.get_channel(
                'IdentityWithoutPreAct'
            ).on_next,
            pre_act_key=identity_label,
        )
    )

    # TODO: improve the persona definition.
    # Fixed persona text describing the agent as a strategic thinker.
    # `role_label` is also the name the component is registered under.
    role_label = f'\n{agent_name}\'s TRUTH'
    role = agent_components.constant.Constant(
        state=(f'{agent_name} is a highly strategic thinker, able to break down complex interactions through a structured approach: '
               f'**identifying** game dynamics and key players, **observing** behavior and responses, **positioning** themselves clearly, '
               f'**applying pressure** at opportune moments, **adjusting** tactics based on ongoing feedback, **weighing** outcomes carefully, '
               f'and ultimately **summarizing** insights for future encounters. In any interaction, {agent_name} systematically assesses '
               f'others’ motivations through language, non-verbal cues, and action patterns, discerning possible vulnerabilities and areas '
               f'of influence. When conflicts arise, they initially gauge the optimal level of assertiveness or cooperation, positioning '
               f'themselves in a way that maximizes strategic advantages without burning bridges. They apply calibrated pressure when needed, '
               f'then carefully observe any shifts in the opponent’s behavior to fine-tune their approach. {agent_name} evaluates both immediate '
               f'gains and long-term positioning, ensuring their decisions contribute to an overarching strategic goal. Their approach results '
               f'in clear, measured actions that leave room for future leverage or alliances, capturing gains while setting the stage for future '
               f'engagements with the same or new players.'),
        pre_act_key=role_label,
        logging_channel=measurements.get_channel('Role').on_next)

    self_perception_label = (
        f'\nQuestion: What kind of person is {agent_name}?\nAnswer')
    self_perception = agent_components.question_of_recent_memories.SelfPerception(
        model=model,
        components={
            _get_class_name(identity_characteristics): identity_label,
            role_label: role_label
        },
        pre_act_key=self_perception_label,
        logging_channel=measurements.get_channel('SelfPerception').on_next,
    )
    situation_perception_label = (
        f'\nQuestion: What kind of situation is {agent_name} in '
        'right now?\nAnswer')
    situation_perception = (
        agent_components.question_of_recent_memories.SituationPerception(
            model=model,
            components={
                _get_class_name(observation): observation_label,
                _get_class_name(observation_summary): observation_summary_label,
            },
            clock_now=clock.now,
            pre_act_key=situation_perception_label,
            logging_channel=measurements.get_channel(
                'SituationPerception'
            ).on_next,
        )
    )

    # Benefit perception: takes the role, recalled memories, the situation
    # perception and (when present) the goal as inputs.
    benefit_perception_label = (
        f"\nQuestion: Based on the recent events or outcomes, what specific benefit or loss has the {agent_name} experienced most recently? \nAnswer")
    benefit_perception_components = {}
    benefit_perception_components.update({
        role_label: role_label,
        _get_class_name(relevant_memories): relevant_memories_label,
        _get_class_name(situation_perception): situation_perception_label,
    })
    if config.goal:
        benefit_perception_components[goal_label] = goal_label

    benefit_perception = MyBenefitPerception(
        model=model,
        components=benefit_perception_components,
        pre_act_key=benefit_perception_label,
        logging_channel=measurements.get_channel('BenefitPerception').on_next,
    )

    person_by_situation_label = (
        f'\nQuestion: What would a person like {agent_name} do in '
        'a situation like this?\nAnswer')
    person_by_situation = (
        agent_components.question_of_recent_memories.PersonBySituation(
            model=model,
            components={
                _get_class_name(self_perception): self_perception_label,
                _get_class_name(situation_perception): situation_perception_label,
                _get_class_name(benefit_perception): benefit_perception_label,
            },
            clock_now=clock.now,
            pre_act_key=person_by_situation_label,
            logging_channel=measurements.get_channel(
                'PersonBySituation').on_next,
        )
    )

    # Perception of others: takes recalled memories and (when present) the
    # goal as inputs.
    person_representation_label = '\nOther players'
    people_representation_components = {
        _get_class_name(relevant_memories): relevant_memories_label,
    }
    if config.goal:
        people_representation_components[goal_label] = goal_label
    people_representation = (
        agent_components.person_representation.PersonRepresentation(
            model=model,
            components=people_representation_components,
            # Only two of the seven drafted questions are active; the
            # commented-out ones are kept for reference.
            additional_questions=(
                # ('**Identify** each player’s overall stance: Are they more inclined towards aggression, cooperation, or neutrality?'),
                # ('**Observe** patterns: Are there any recurring behaviors or language choices indicating a strategy or predictable approach?'),
                ('**Express** insights about their tendencies: Are they showing any signs of commitment to their strategy, or are they adapting?'),
                ('**Apply pressure** strategically: Are there areas where they appear vulnerable, such as topics they seem to avoid or react strongly to?'),
                # ('**Adjust** based on reactions: How do they respond when pressured or presented with unexpected moves? Are they resilient or prone to shift?'),
                # ('**Weigh** options carefully: Is it more advantageous to pursue cooperation with them or maintain a competitive stance?'),
                # ('**Summarize** key insights: Given their behavior so far, what should be kept in mind for future interactions with them?')
            ),
            num_memories_to_retrieve=50,
            pre_act_key=person_representation_label,
            logging_channel=measurements.get_channel(
                'PersonRepresentation').on_next,
        )
    )

    # Option perception and optimization: the goal (when present) is inserted
    # first, followed by recalled memories, benefit perception, the
    # person-by-situation answer and the representation of other players.
    options_perception_label = (
        f'\nQuestion: Which options are available to {agent_name} '
        'right now?\nAnswer')
    options_perception_components = {}
    if config.goal:
        options_perception_components[goal_label] = goal_label

    options_perception_components.update({
        _get_class_name(relevant_memories): relevant_memories_label,
        _get_class_name(benefit_perception): benefit_perception_label,
        _get_class_name(person_by_situation): person_by_situation_label,
        _get_class_name(people_representation): person_representation_label,
    })
    options_perception = (
        AvailableOptionsPerception(
            model=model,
            components=options_perception_components,
            clock_now=clock.now,
            pre_act_key=options_perception_label,
            logging_channel=measurements.get_channel(
                'AvailableOptionsPerception'
            ).on_next,
        )
    )

    # Reflection: takes the role and the outputs of the higher-level
    # components above (recalled memories are deliberately left out here).
    reflection_label = '\nReflection'
    reflection_component = {
        role_label: role_label,

        # _get_class_name(relevant_memories): relevant_memories_label,
        _get_class_name(person_by_situation): person_by_situation_label,
        _get_class_name(people_representation): person_representation_label,
        _get_class_name(options_perception): options_perception_label,
        _get_class_name(benefit_perception): benefit_perception_label,

    }

    if config.goal:
        reflection_component[goal_label] = goal_label

    reflection = (
        agent_components.justify_recent_voluntary_actions.JustifyRecentVoluntaryActions(
            model=model,
            components=reflection_component,
            clock_now=clock.now,
            pre_act_key=reflection_label,
            logging_channel=measurements.get_channel(
                'JustifyRecentVoluntaryActions').on_next,
            num_memories_to_retrieve=100,  # Largest retrieval size set explicitly in this function.
        )
    )

    # The order of this tuple determines `component_order` below.
    # NOTE(review): benefit_perception is listed before situation_perception
    # even though it takes situation_perception as an input; confirm this
    # ordering is intended.
    entity_components = (
        # Components that provide pre_act context.
        instructions,
        observation,
        observation_summary,
        relevant_memories,
        self_perception,
        benefit_perception,
        situation_perception,
        person_by_situation,

        time_display,
        people_representation,
        options_perception,

        reflection,
        # Components that do not provide pre_act context.
        identity_characteristics,
    )
    # Register every component under its class name; this relies on each
    # component in the tuple having a distinct class.
    components_of_agent = {_get_class_name(component): component
                           for component in entity_components}
    components_of_agent[
        agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = (
            agent_components.memory_component.MemoryComponent(raw_memory))
    # Snapshot the order before the goal and role constants are registered, so
    # that they can be spliced in at specific positions.
    component_order = list(components_of_agent.keys())
    if overarching_goal is not None:
        components_of_agent[goal_label] = overarching_goal
        # Place goal after the instructions.
        component_order.insert(1, goal_label)

    # Place the role directly after the observation summary.
    components_of_agent[role_label] = role
    component_order.insert(
        component_order.index(_get_class_name(observation_summary)) + 1,
        role_label)

    act_component = agent_components.concat_act_component.ConcatActComponent(
        model=model,
        clock=clock,
        component_order=component_order,
        logging_channel=measurements.get_channel('ActComponent').on_next,
    )

    agent = entity_agent_with_logging.EntityAgentWithLogging(
        agent_name=agent_name,
        act_component=act_component,
        context_components=components_of_agent,
        component_logging=measurements,
    )

    return agent


def save_to_json(
    agent: entity_agent_with_logging.EntityAgentWithLogging,
) -> str:
    """Serializes an agent's state to a JSON string.

    The resulting JSON object holds the state of every context component,
    keyed by component name, the state of the act component under
    `act_component`, and, if the agent has a config, that config as a dict
    under `agent_config`. The string can be loaded afterwards with
    `rebuild_from_json`, which takes the clock, model and embedder as
    arguments because they are not part of the saved data.

    Args:
      agent: The agent to save. It must be in the `READY` phase.

    Returns:
      A JSON string representing the agent's state.

    Raises:
      ValueError: If the agent is not in the READY phase.
    """
    current_phase = agent.get_phase()
    if current_phase != entity_component.Phase.READY:
        raise ValueError('The agent must be in the `READY` phase to be saved.')

    # Collect the state of each context component under its own name.
    snapshot = {}
    for name in agent.get_all_context_components():
        snapshot[name] = agent.get_component(name).get_state()

    snapshot['act_component'] = agent.get_act_component().get_state()

    # The config entry is optional: it is only written when the agent has one.
    agent_config = agent.get_config()
    if agent_config is not None:
        snapshot['agent_config'] = agent_config.to_dict()

    return json.dumps(snapshot)


def rebuild_from_json(
    json_data: str,
    model: language_model.LanguageModel,
    clock: game_clock.MultiIntervalClock,
    embedder: Callable[[str], np.ndarray],
    memory_importance: Callable[[str], float] | None = None,
) -> entity_agent_with_logging.EntityAgentWithLogging:
    """Rebuilds an agent from JSON data.

    A fresh agent is built with `build_agent` from the stored config and a new,
    empty associative memory; the saved state of every context component and
    of the act component is then loaded into it.

    Args:
      json_data: A JSON string in the format produced by `save_to_json`. It
        must contain an `agent_config` entry, an `act_component` entry and one
        entry per context component of the rebuilt agent.
      model: The language model to use.
      clock: The clock to use.
      embedder: The sentence embedder for the new associative memory.
      memory_importance: Optional importance function for the new associative
        memory.

    Returns:
      The rebuilt agent.

    Raises:
      ValueError: If the JSON data does not contain the agent config.
      KeyError: If the state of a context component or of the act component is
        missing from the JSON data.
      AssertionError: If the JSON data contains entries that no component
        consumed.
    """

    data = json.loads(json_data)

    new_agent_memory = associative_memory.AssociativeMemory(
        sentence_embedder=embedder,
        importance=memory_importance,
        clock=clock.now,
        clock_step_size=clock.get_step_size(),
    )

    if 'agent_config' not in data:
        raise ValueError('The JSON data does not contain the agent config.')
    agent_config = formative_memories.AgentConfig.from_dict(
        data.pop('agent_config')
    )

    agent = build_agent(
        config=agent_config,
        model=model,
        memory=new_agent_memory,
        clock=clock,
    )

    # Each entry is popped as it is consumed so that leftovers can be detected
    # at the end.
    for component_name in agent.get_all_context_components():
        agent.get_component(component_name).set_state(data.pop(component_name))

    agent.get_act_component().set_state(data.pop('act_component'))

    if data:
        # Raised explicitly instead of using an `assert` statement so that the
        # check still runs under `python -O`. The exception type and message
        # are the same as the assert produced.
        raise AssertionError(f'Unused data {sorted(data)}')
    return agent
