# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import asyncio
from enum import Enum
from typing import Dict, List, Optional

from camel.tasks import Task


class PacketStatus(Enum):
    r"""The status of a packet. The packet can be in one of the following
    states:

    - ``SENT``: The packet has been sent to a worker.
    - ``PROCESSING``: The packet has been claimed by a worker and is being
    processed.
    - ``RETURNED``: The packet has been returned by the worker, meaning that
      the status of the task inside has been updated.
    - ``ARCHIVED``: The packet has been archived, meaning that the content of
      the task inside will not be changed. The task is considered
      as a dependency.
    """

    SENT = "SENT"
    PROCESSING = "PROCESSING"
    RETURNED = "RETURNED"
    ARCHIVED = "ARCHIVED"


class Packet:
    r"""The basic element inside the channel. A task is wrapped inside a
    packet. The packet will contain the task, along with the task's assignee,
    and the task's status.

    Args:
        task (Task): The task that is wrapped inside the packet.
        publisher_id (str): The ID of the workforce that published the task.
        assignee_id (str): The ID of the workforce that is assigned
            to the task. Defaults to None, meaning that the task is posted as
            a dependency in the channel.

    Attributes:
        task (Task): The task that is wrapped inside the packet.
        publisher_id (str): The ID of the workforce that published the task.
        assignee_id (Optional[str], optional): The ID of the workforce that is
            assigned to the task. Would be None if the task is a dependency.
            Defaults to None.
        status (PacketStatus): The status of the task.
    """

    def __init__(
        self,
        task: Task,
        publisher_id: str,
        assignee_id: Optional[str] = None,
        status: PacketStatus = PacketStatus.SENT,
    ) -> None:
        self.task = task
        self.publisher_id = publisher_id
        self.assignee_id = assignee_id
        self.status = status

    def __repr__(self):
        return (
            f"Packet(publisher_id={self.publisher_id}, assignee_id="
            f"{self.assignee_id}, status={self.status})"
        )


class TaskChannel:
    r"""An internal class used by Workforce to manage tasks."""

    def __init__(self) -> None:
        self._condition = asyncio.Condition()
        self._task_dict: Dict[str, Packet] = {}

    async def get_returned_task_by_publisher(self, publisher_id: str) -> Task:
        r"""Get a task from the channel that has been returned by the
        publisher.
        """
        async with self._condition:
            while True:
                for packet in self._task_dict.values():
                    if packet.publisher_id != publisher_id:
                        continue
                    if packet.status != PacketStatus.RETURNED:
                        continue
                    return packet.task
                await self._condition.wait()

    async def get_assigned_task_by_assignee(self, assignee_id: str) -> Task:
        r"""Atomically get and claim a task from the channel that has been
        assigned to the assignee. This prevents race conditions where multiple
        concurrent calls might retrieve the same task.
        """
        async with self._condition:
            while True:
                for packet in self._task_dict.values():
                    if (
                        packet.status == PacketStatus.SENT
                        and packet.assignee_id == assignee_id
                    ):
                        # Atomically claim the task by changing its status
                        packet.status = PacketStatus.PROCESSING
                        self._condition.notify_all()
                        return packet.task
                await self._condition.wait()

    async def post_task(
        self, task: Task, publisher_id: str, assignee_id: str
    ) -> None:
        r"""Send a task to the channel with specified publisher and assignee,
        along with the dependency of the task."""
        async with self._condition:
            packet = Packet(task, publisher_id, assignee_id)
            self._task_dict[packet.task.id] = packet
            self._condition.notify_all()

    async def post_dependency(
        self, dependency: Task, publisher_id: str
    ) -> None:
        r"""Post a dependency to the channel. A dependency is a task that is
        archived, and will be referenced by other tasks."""
        async with self._condition:
            packet = Packet(
                dependency, publisher_id, status=PacketStatus.ARCHIVED
            )
            self._task_dict[packet.task.id] = packet
            self._condition.notify_all()

    async def return_task(self, task_id: str) -> None:
        r"""Return a task to the sender, indicating that the task has been
        processed by the worker."""
        async with self._condition:
            if task_id in self._task_dict:
                packet = self._task_dict[task_id]
                packet.status = PacketStatus.RETURNED
            self._condition.notify_all()

    async def archive_task(self, task_id: str) -> None:
        r"""Archive a task in channel, making it to become a dependency."""
        async with self._condition:
            if task_id in self._task_dict:
                packet = self._task_dict[task_id]
                packet.status = PacketStatus.ARCHIVED
            self._condition.notify_all()

    async def remove_task(self, task_id: str) -> None:
        r"""Remove a task from the channel."""
        async with self._condition:
            # Check if task ID exists before removing
            if task_id in self._task_dict:
                del self._task_dict[task_id]
            self._condition.notify_all()

    async def get_dependency_ids(self) -> List[str]:
        r"""Get the IDs of all dependencies in the channel."""
        async with self._condition:
            dependency_ids = []
            for task_id, packet in self._task_dict.items():
                if packet.status == PacketStatus.ARCHIVED:
                    dependency_ids.append(task_id)
            return dependency_ids

    async def get_task_by_id(self, task_id: str) -> Task:
        r"""Get a task from the channel by its ID."""
        async with self._condition:
            packet = self._task_dict.get(task_id)
            if packet is None:
                raise ValueError(f"Task {task_id} not found.")
            return packet.task

    async def get_channel_debug_info(self) -> str:
        r"""Get the debug information of the channel."""
        async with self._condition:
            return str(self._task_dict)
