from __future__ import annotations

from typing import TYPE_CHECKING, List

from . import role_assigner_registry
from .base import BaseRoleAssigner

if TYPE_CHECKING:
    from agentverse.message import RoleAssignerMessage
    from agentverse.agents import CriticAgent, RoleAssignerAgent


@role_assigner_registry.register("role_description")
class DescriptionAssigner(BaseRoleAssigner):
    """
    Generates descriptions for each agent.
    """

    async def astep(
        self,
        role_assigner: RoleAssignerAgent,
        group_members: List[CriticAgent],
        advice: str = "No advice yet.",
        task_description: str = "",
        *args,
        **kwargs,
    ) -> List[CriticAgent]:
        assert task_description != ""
        assert len(group_members) > 0

        roles = await role_assigner.astep(advice, task_description, len(group_members))
        if len(roles.content) < len(group_members):
            raise ValueError(
                f"Number of roles ({len(roles.content)}) and number of group members ({len(group_members)}) do not match."
            )
        for role, member in zip(roles.content[: len(group_members)], group_members):
            description = role.strip().strip(".")
            member.role_description = description
            member.name = description

        return group_members

    def reset(self):
        pass


@role_assigner_registry.register("role_description_name")
class DescriptionNameAssigner(BaseRoleAssigner):
    """
    Generates description and name for each agent.
    """

    async def astep(
        self,
        role_assigner: RoleAssignerAgent,
        group_members: List[CriticAgent],
        advice: str = "No advice yet.",
        task_description: str = "",
        *args,
        **kwargs,
    ) -> List[CriticAgent]:
        assert task_description != ""
        assert len(group_members) > 0

        # roles: [{'name': 'xxx', 'description': 'xxx'}, ...]
        roles = await role_assigner.astep(advice, task_description, len(group_members))

        if len(group_members) < 2:
            pass
        else:
            if len(roles.content) != len(group_members):
                raise ValueError(
                    f"Number of roles ({len(roles.content)}) and number of group members ({len(group_members)}) do not match."
                )

        for role_dict, member in zip(roles.content, group_members):
            description = role_dict["description"].strip().strip(".")
            member.role_description = description
            member.name = role_dict["name"].strip()

        return group_members
