from typing import List, Union

from ..agent import SIGNAL_END_OF_CONVERSATION, Moderator
from ..config import AgentConfig, EnvironmentConfig
from ..message import Message, MessagePool
from .base import Environment, TimeStep


class Conversation(Environment):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin.

    In round-robin (sequential) mode every message gets its own turn number.
    In parallel mode all players of one round share the same turn number; the
    turn counter advances only after the last player of the round has spoken.
    """

    type_name = "conversation"

    def __init__(self, player_names: List[str], parallel: bool = False, **kwargs):
        """
        Args:
            player_names: names of the players, in speaking order.
            parallel: if True, all players in a round share one turn number.
            **kwargs: forwarded unchanged to the base ``Environment`` constructor.
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        self.parallel = parallel

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()

        # Turn number stamped on the next message appended by ``step``.
        self._current_turn = 0
        # Index into ``player_names`` of the player expected to act next.
        self._next_player_index = 0

    def reset(self):
        """
        Clear the message pool and rewind the turn/player counters.

        Returns:
            The initial ``TimeStep``: empty observation, zero rewards, not terminal.
        """
        self._current_turn = 0
        self._next_player_index = 0
        self.message_pool.reset()

        init_timestep = TimeStep(
            observation=[], reward=self.get_zero_rewards(), terminal=False
        )
        return init_timestep

    @property
    def phase_index(self):
        """Current phase index.

        NOTE(review): this class never initializes ``_phase_index``; reading
        this before it has been assigned raises ``AttributeError`` unless the
        base ``Environment`` sets it — confirm.
        """
        return self._phase_index

    @phase_index.setter
    def phase_index(self, value):
        self._phase_index = value

    def to_config(self) -> EnvironmentConfig:
        """Serialize ``player_names`` and ``parallel`` into an ``EnvironmentConfig``."""
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
        )

    def print(self):
        """Print the conversation by delegating to the message pool."""
        self.message_pool.print()

    def get_next_player(self) -> str:
        """Return the name of the player expected to act next."""
        return self.player_names[self._next_player_index]

    def get_observation(self, player_name=None) -> List[Message]:
        """
        Get observation for the player.

        Args:
            player_name: the observing player, or None for the full history.

        Returns:
            Every message when ``player_name`` is None; otherwise the messages
            the pool reports as visible to ``player_name`` at the current turn.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        return self.message_pool.get_visible_messages(
            player_name, turn=self._current_turn
        )

    def is_terminal(self) -> bool:
        """
        Check if the conversation is over.

        Returns:
            True if the most recent message starts with
            ``SIGNAL_END_OF_CONVERSATION``, False otherwise (including when the
            pool has no last message).
        """
        last_message = self.message_pool.last_message
        # Guard against an empty pool (e.g. right after ``reset``); presumably
        # ``last_message`` is None in that case — verify against MessagePool.
        if last_message is None:
            return False
        return last_message.content.startswith(SIGNAL_END_OF_CONVERSATION)

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A ``TimeStep`` whose observation is the full message history,
            with zero rewards and the current terminal status.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Advance to the next player first so that, in parallel mode, the turn
        # counter ticks only once the index wraps back to 0, i.e. after the
        # last player of the round. Checking before advancing would bump the
        # turn after the *first* player, letting later players of the same
        # round see that player's message.
        self._next_player_index = (self._next_player_index + 1) % self.num_players
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=self.is_terminal(),
        )  # Return all the messages
        return timestep


class ModeratedConversation(Conversation):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin.
    Moderator is a special agent that can see all messages and can decide whether the conversation is over.
    """

    type_name = "moderated_conversation"

    # Accepted explicit values for ``moderator_period``.
    _MODERATOR_PERIODS = ("turn", "round")

    def __init__(
        self,
        player_names: List[str],
        moderator: Union[Moderator, AgentConfig],
        parallel: bool = False,
        moderator_visibility="all",
        moderator_period=None,
        **kwargs,
    ):
        """
        Args:
            player_names: names of the players, in speaking order.
            moderator: a ``Moderator`` instance, or an ``AgentConfig`` used to
                build one via ``Moderator.from_config``.
            parallel: if True, all players in a round share one turn number.
            moderator_visibility: stored as ``visible_to`` on every moderator
                message.
            moderator_period: ``"turn"`` (moderator speaks after every player
                message) or ``"round"`` (after the last player of each round).
                Defaults to ``"round"`` when ``parallel`` else ``"turn"``.
            **kwargs: forwarded to ``Conversation.__init__``.

        Raises:
            ValueError: if ``moderator`` is neither an ``AgentConfig`` nor a
                ``Moderator``, or if ``moderator_period`` is not None,
                ``"turn"`` or ``"round"``.
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        if isinstance(moderator, AgentConfig):
            moderator = Moderator.from_config(moderator)
        elif not isinstance(moderator, Moderator):
            raise ValueError(
                "moderator must be either an AgentConfig or a Moderator instance."
            )

        if moderator_period is None:
            moderator_period = "round" if parallel else "turn"
        elif moderator_period not in self._MODERATOR_PERIODS:
            # Any other value would make ``step`` never consult the moderator.
            raise ValueError(
                f"moderator_period must be one of {self._MODERATOR_PERIODS} "
                f"or None, got {moderator_period!r}."
            )

        self.moderator = moderator
        self.moderator_visibility = moderator_visibility
        self.moderator_period = moderator_period

    def to_config(self) -> EnvironmentConfig:
        """Serialize constructor arguments, including the moderator's own config."""
        # This environment contains some special config arguments that need to be handled specially
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
            moderator=self.moderator.to_config(),
            moderator_visibility=self.moderator_visibility,
            moderator_period=self.moderator_period,
        )

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Appends the player's message, then — depending on ``moderator_period``
        — lets the moderator respond to the full history and decide whether
        the conversation has ended.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A ``TimeStep`` with the full message history, zero rewards and a
            boolean terminal flag.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Round-robin order for the next player
        self._next_player_index = (self._next_player_index + 1) % self.num_players

        moderator_speaks = self.moderator_period == "turn" or (
            self.moderator_period == "round" and self._next_player_index == 0
        )
        if moderator_speaks:
            # Moderator's turn: it sees every message and shares the player's turn number.
            moderator_history = self.message_pool.get_all_messages()
            moderator_response = self.moderator(moderator_history)
            moderator_message = Message(
                agent_name=self.moderator.name,
                content=moderator_response,
                turn=self._current_turn,
                visible_to=self.moderator_visibility,
            )
            self.message_pool.append_message(moderator_message)
            # bool() so ``terminal`` is never None even if an ``is_terminal``
            # implementation returns None for "not over".
            terminal = bool(
                self.moderator.is_terminal(moderator_history) or self.is_terminal()
            )
        else:
            terminal = bool(self.is_terminal())

        # Update the counters: in parallel mode the turn advances only once the
        # index has wrapped back to 0 (end of round).
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )  # Return all the messages
        return timestep
