from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, List, Optional

from . import order_registry as OrderRegistry
from .base import BaseOrder
from agentverse.logging import logger

if TYPE_CHECKING:
    from agentverse.environments import BaseEnvironment


@OrderRegistry.register("classroom")
class ClassroomOrder(BaseOrder):
    """The order for a classroom discussion.

    The agents speak in the following order:
    1. The professor speaks first.
    2. Then the professor can continue to speak, and the students can raise hands.
    3. The professor can call on a student, then the student can speak or ask a question.
    4. In a group discussion, the students in each group speak in turn.

    Agent index 0 is treated as the professor throughout.
    """

    def get_next_agent_idx(self, environment: BaseEnvironment) -> List[int]:
        """Return the indices of the agents that act in the next turn.

        Dispatches on two flags in ``environment.rule_params``:
        ``is_grouped_ended`` (the group discussion just ended) hands the floor
        back to the professor, and ``is_grouped`` (a group discussion is in
        progress) selects the per-group rotation.
        """
        if environment.rule_params.get("is_grouped_ended", False):
            return [0]
        if environment.rule_params.get("is_grouped", False):
            return self.get_next_agent_idx_grouped(environment)
        return self.get_next_agent_idx_ungrouped(environment)

    def get_next_agent_idx_ungrouped(self, environment: BaseEnvironment) -> List[int]:
        """Choose the next speakers outside of a group discussion.

        Returns:
            ``[0]`` (the professor) when nobody or several agents spoke last
            turn, or when a student spoke; the called-on student's index after
            a resolvable ``[CallOn]`` message from the professor; otherwise
            every agent index, so that anyone may act.
        """
        last_messages = environment.last_messages
        if len(last_messages) != 1:
            # No message: the class just began or nobody spoke last turn, so
            # only the professor speaks.
            # Several messages: at least one student raised a hand or spoke,
            # or a group discussion just ended; the professor takes the floor.
            return [0]

        message = last_messages[0]
        sender = message.sender
        content = message.content

        if sender.startswith("Professor"):
            if not content.startswith("[CallOn]"):
                # The professor spoke normally, so anyone can act.
                return list(range(len(environment.agents)))
            # The professor called on someone; that student speaks next.
            student_idx = self._resolve_call_on(environment, content)
            if student_idx is not None:
                return [student_idx]
            # The call-on could not be routed (malformed message or unknown
            # name). Previously this crashed the run; instead treat it like
            # ordinary speech so the intended student can still respond.
            logger.warn(
                f"Could not resolve the student called on in {content!r}; "
                "letting all agents act."
            )
            return list(range(len(environment.agents)))

        if not sender.startswith("Student"):
            # Unexpected sender: hand the floor back to the professor rather
            # than failing (an `assert` here would vanish under `python -O`).
            logger.warn(
                f"Unexpected message sender {sender!r}; letting the professor speak."
            )
        # A student spoke: they asked a question after being called on, the
        # only student who raised a hand was heard, or a group discussion just
        # ended with a single student speaking. The professor speaks next.
        return [0]

    def _resolve_call_on(
        self, environment: BaseEnvironment, content: str
    ) -> Optional[int]:
        """Return the index of the student named in a ``[CallOn]`` message.

        Expects content like ``"[CallOn] Yes, Student Alice"`` (the
        ``Student ``/``student `` prefix is optional). Returns ``None`` when the
        message does not have that form or no agent carries that name.
        """
        match = re.search(r"\[CallOn\] Yes, ([sS]tudent )?(\w+)", content)
        if match is None:
            return None
        # Agent names are assumed to look like "Student <name>"; strip that
        # prefix so the bare name from the message can be looked up.
        name_to_id = {
            agent.name[len("Student ") :]: i
            for i, agent in enumerate(environment.agents)
        }
        return name_to_id.get(match.group(2))

    def get_next_agent_idx_grouped(self, environment: BaseEnvironment) -> List[int]:
        """Pick one speaker per group, rotating through each group's members.

        Reads ``groups`` from ``environment.rule_params``: a list of lists of
        agent ids, where the i-th list holds the members of the i-th group. It
        is expected to be set by the corresponding visibility; when absent, all
        agents form a single group. The cursor ``group_speaker_mapping``
        (group id -> position of the next speaker within that group) is
        maintained here and written back to ``rule_params`` after advancing.
        """
        if "groups" not in environment.rule_params:
            logger.warn(
                "The environment is grouped, but the grouping information is not provided."
            )
        groups = environment.rule_params.get(
            "groups", [list(range(len(environment.agents)))]
        )
        previous_mapping = environment.rule_params.get("group_speaker_mapping", {})

        next_agent_idx = []
        group_speaker_mapping = {}
        for group_id, group in enumerate(groups):
            if not group:
                # An empty group has nobody to speak; skip it instead of
                # failing on an out-of-range index.
                continue
            # Wrap the stored cursor so that a group without an entry yet, or a
            # group that shrank since the last turn, still yields a member.
            speaker_index = previous_mapping.get(group_id, 0) % len(group)
            next_agent_idx.append(group[speaker_index])
            group_speaker_mapping[group_id] = (speaker_index + 1) % len(group)

        # Rebuilt from `groups` each turn, so stale ids for groups that no
        # longer exist are dropped rather than indexed.
        environment.rule_params["group_speaker_mapping"] = group_speaker_mapping
        return next_agent_idx
