# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import re
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Union,
)

from pydantic import BaseModel, Field

if TYPE_CHECKING:
    from camel.agents import ChatAgent
import uuid

from camel.logger import get_logger
from camel.messages import BaseMessage
from camel.prompts import TextPrompt

# NOTE: validate_task_content is defined in this module (see below) rather
# than imported, to avoid circular imports.
from .task_prompt import (
    TASK_COMPOSE_PROMPT,
    TASK_DECOMPOSE_PROMPT,
    TASK_EVOLVE_PROMPT,
)

logger = get_logger(__name__)


def validate_task_content(
    content: str, task_id: str = "unknown", min_length: int = 10
) -> bool:
    r"""Check that a task result is substantial enough to be accepted.

    The content is rejected when it is ``None``, when nothing is left after
    stripping surrounding whitespace, or when the stripped text is shorter
    than ``min_length``. Every rejection is logged as a warning so that an
    unusable result does not pass silently.

    Args:
        content (str): The task result content to validate.
        task_id (str): Task ID, used only in the log messages.
            (default: :obj:`"unknown"`)
        min_length (int): Minimum number of characters required once the
            surrounding whitespace has been stripped. (default: :obj:`10`)

    Returns:
        bool: True if content passes every check, False otherwise.
    """
    # Guard 1: a missing result can never be valid.
    if content is None:
        logger.warning(f"Task {task_id}: None content rejected")
        return False

    text = content.strip()
    text_length = len(text)

    # Guard 2: nothing but whitespace.
    if text_length == 0:
        logger.warning(
            f"Task {task_id}: Empty or whitespace-only content rejected."
        )
        return False

    # Guard 3: too short to be a meaningful answer.
    if text_length < min_length:
        preview = text[:50]
        logger.warning(
            f"Task {task_id}: Content too short ({text_length} "
            f"chars < {min_length} minimum). Content preview: "
            f"'{preview}...'"
        )
        return False

    logger.debug(
        f"Task {task_id}: Content validation passed "
        f"({text_length} chars)"
    )
    return True


def parse_response(
    response: str, task_id: Optional[str] = None
) -> List["Task"]:
    r"""Extract sub-tasks from the ``<task>...</task>`` blocks of a response.

    Each matched block is stripped and checked with
    ``validate_task_content``; blocks that fail are skipped with a warning.
    Ids have the form ``"{task_id}.{index}"`` where ``index`` is the position
    of the block among all matches, so a skipped block leaves a gap in the
    numbering.

    Args:
        response (str): The model response.
        task_id (str, optional): The parent task id used as the id prefix.
            ``None`` is treated as ``"0"``. (default: :obj:`None`)

    Returns:
        List[Task]: The valid tasks, as :obj:`Task` instances, in the order
            they appear in the response.
    """
    parent_id = "0" if task_id is None else task_id
    matches = re.findall("<task>(.*?)</task>", response, re.DOTALL)

    parsed: List["Task"] = []
    for index, raw in enumerate(matches):
        text = raw.strip()
        subtask_id = f"{parent_id}.{index}"
        # Validate before constructing the Task so that unusable blocks
        # never become sub-tasks.
        if not validate_task_content(text, subtask_id):
            logger.warning(
                f"Skipping invalid subtask {subtask_id} "
                f"during decomposition: "
                f"Content '{text[:50]}...' failed validation"
            )
            continue
        parsed.append(Task(content=text, id=subtask_id))
    return parsed


class TaskState(str, Enum):
    r"""The states a task can be in.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    OPEN = "OPEN"
    RUNNING = "RUNNING"
    DONE = "DONE"
    FAILED = "FAILED"
    DELETED = "DELETED"

    @classmethod
    def states(cls):
        r"""Return the string value of every state, in definition order."""
        return [member.value for member in cls.__members__.values()]


class Task(BaseModel):
    r"""Task is a specific assignment that can be passed to an agent.

    Tasks form a tree: every task may hold sub-tasks and points back to its
    parent.

    Attributes:
        content (str): string content for task.
        id (str): A unique string identifier for the task. This should
            ideally be provided by the provider/model which created the task.
            (default: :obj:`str(uuid.uuid4())`)
        state (TaskState): The state of the task, one of OPEN, RUNNING,
            DONE, FAILED or DELETED. (default: :obj:`TaskState.FAILED`)
        type (Optional[str]): task type. (default: :obj:`None`)
        parent (Optional[Task]): The parent task, None for root task.
            (default: :obj:`None`)
        subtasks (List[Task]): The child sub-tasks for the task.
            (default: :obj:`[]`)
        result (Optional[str]): The answer for the task.
            (default: :obj:`""`)
        failure_count (int): The failure count for the task.
            (default: :obj:`0`)
        additional_info (Optional[Dict[str, Any]]): Additional information for
            the task. (default: :obj:`None`)
    """

    content: str

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    state: TaskState = (
        TaskState.FAILED
    )  # TODO: Add logic for OPEN in workforce.py

    type: Optional[str] = None

    parent: Optional["Task"] = None

    subtasks: List["Task"] = []

    result: Optional[str] = ""

    failure_count: int = 0

    additional_info: Optional[Dict[str, Any]] = None

    def __repr__(self) -> str:
        r"""Return a string representation of the task.

        Only ``id``, ``content`` and ``state`` are shown; ``parent`` and
        ``subtasks`` are left out.
        """
        return (
            f"Task(id='{self.id}', content='{self.content}', "
            f"state='{self.state.value}')"
        )

    @classmethod
    def from_message(cls, message: BaseMessage) -> "Task":
        r"""Create a task from a message.

        Args:
            message (BaseMessage): The message whose content becomes the
                task content.

        Returns:
            Task: A new task with id ``"0"``.
        """
        return cls(content=message.content, id="0")

    @staticmethod
    def to_message():
        r"""Convert a Task to a Message.

        Not implemented yet; currently returns ``None``.
        """
        # TODO
        pass

    def reset(self):
        r"""Reset Task to initial state.

        Restores ``state`` and ``result`` to their defaults. Other fields
        (``failure_count``, ``subtasks``, ...) are left untouched.
        """
        self.state = (
            TaskState.FAILED
        )  # TODO: Add logic for OPEN in workforce.py
        self.result = ""

    def update_result(self, result: str):
        r"""Set task result and mark the task as DONE.

        Because the state change goes through :meth:`set_state`, the DONE
        state is also applied to every non-DELETED subtask.

        Args:
            result (str): The task result.
        """
        self.result = result
        self.set_state(TaskState.DONE)

    def set_id(self, id: str):
        r"""Set the id of the task.

        Args:
            id (str): The id of the task.
        """
        self.id = id

    def set_state(self, state: TaskState):
        r"""Set the state of the task and propagate it through the tree.

        * ``DONE`` is pushed down, recursively, to every subtask that is not
          ``DELETED``.
        * ``RUNNING`` is pushed up, recursively, to the parent chain.
        * Any other state only changes this task.

        Args:
            state (TaskState): The given state.
        """
        self.state = state
        if state == TaskState.DONE:
            for subtask in self.subtasks:
                if subtask.state != TaskState.DELETED:
                    subtask.set_state(state)
        elif state == TaskState.RUNNING and self.parent:
            self.parent.set_state(state)

    def add_subtask(self, task: "Task"):
        r"""Add a subtask to the current task.

        The subtask's ``parent`` is set to this task.

        Args:
            task (Task): The subtask to be added.
        """
        task.parent = self
        self.subtasks.append(task)

    def remove_subtask(self, id: str):
        r"""Remove a subtask from the current task.

        Every direct subtask whose id equals ``id`` is dropped. The removed
        task's ``parent`` attribute is not modified.

        Args:
            id (str): The id of the subtask to be removed.
        """
        self.subtasks = [task for task in self.subtasks if task.id != id]

    def get_running_task(self) -> Optional["Task"]:
        r"""Get the deepest RUNNING task reachable from this task.

        The first RUNNING subtask (in list order) is followed recursively.
        If no subtask is RUNNING, this task is returned when it is RUNNING
        itself.

        Returns:
            Optional[Task]: The RUNNING task found, or None.
        """
        for sub in self.subtasks:
            if sub.state == TaskState.RUNNING:
                return sub.get_running_task()
        if self.state == TaskState.RUNNING:
            return self
        return None

    def to_string(self, indent: str = "", state: bool = False) -> str:
        r"""Convert task to a string.

        One line is produced for this task, followed by the lines of its
        subtasks, each level indented by two more spaces.

        Args:
            indent (str): The indent for hierarchical tasks.
            state (bool): Whether to include the task state.

        Returns:
            str: The printable task string.
        """
        if state:
            # Use ``.value`` (as ``__repr__`` does). Formatting the enum
            # member itself renders "TaskState.OPEN" instead of "OPEN" on
            # Python 3.11+, which made this output interpreter-dependent.
            header = (
                f"{indent}[{self.state.value}] "
                f"Task {self.id}: {self.content}\n"
            )
        else:
            header = f"{indent}Task {self.id}: {self.content}\n"
        child_indent = indent + "  "
        return header + "".join(
            subtask.to_string(child_indent, state)
            for subtask in self.subtasks
        )

    def get_result(self, indent: str = "") -> str:
        r"""Render the result of this task and of all its subtasks.

        Args:
            indent (str): The indent for hierarchical tasks.

        Returns:
            str: One ``Task <id> result: <result>`` line per task, with
                subtasks indented by two more spaces per level.
        """
        header = f"{indent}Task {self.id} result: {self.result}\n"
        child_indent = indent + "  "
        return header + "".join(
            subtask.get_result(child_indent) for subtask in self.subtasks
        )

    def decompose(
        self,
        agent: "ChatAgent",
        prompt: Optional[str] = None,
        task_parser: Callable[[str, str], List["Task"]] = parse_response,
    ) -> List["Task"]:
        r"""Decompose a task to a list of sub-tasks. It can be used for data
        generation and planner of agent.

        The returned tasks are NOT attached to this task; call
        :meth:`add_subtask` to do so. Each returned task receives this
        task's ``additional_info`` object (the same reference, not a copy).

        Args:
            agent (ChatAgent): An agent that used to decompose the task.
            prompt (str, optional): A prompt to decompose the task. If not
                provided, the default prompt will be used.
            task_parser (Callable[[str, str], List[Task]], optional): A
                function to extract Task from response. It is called with
                the response text and this task's id. If not provided,
                the default parse_response will be used.

        Returns:
            List[Task]: A list of tasks which are :obj:`Task` instances.
        """

        role_name = agent.role_name
        content = prompt or TASK_DECOMPOSE_PROMPT.format(
            role_name=role_name,
            content=self.content,
        )
        msg = BaseMessage.make_user_message(
            role_name=role_name, content=content
        )
        response = agent.step(msg)
        tasks = task_parser(response.msg.content, self.id)
        for task in tasks:
            task.additional_info = self.additional_info
        return tasks

    def compose(
        self,
        agent: "ChatAgent",
        template: TextPrompt = TASK_COMPOSE_PROMPT,
        result_parser: Optional[Callable[[str], str]] = None,
    ):
        r"""Compose the task result from the results of its sub-tasks.

        Does nothing when the task has no subtasks. Otherwise the agent is
        asked to combine the results and the answer is stored through
        :meth:`update_result`, which also marks the task as DONE.

        Args:
            agent (ChatAgent): An agent that used to compose the task result.
            template (TextPrompt, optional): The prompt template to compose
                task. If not provided, the default template will be used.
            result_parser (Callable[[str], str], optional): A function
                applied to the response text before it is stored as the
                result. (default: :obj:`None`)
        """

        if not self.subtasks:
            return

        sub_tasks_result = self.get_result()

        role_name = agent.role_name
        content = template.format(
            role_name=role_name,
            content=self.content,
            additional_info=self.additional_info,
            other_results=sub_tasks_result,
        )
        msg = BaseMessage.make_user_message(
            role_name=role_name, content=content
        )
        response = agent.step(msg)
        result = response.msg.content
        if result_parser:
            result = result_parser(result)
        self.update_result(result)

    def get_depth(self) -> int:
        r"""Get current task depth.

        Returns:
            int: 1 for a task without parent, otherwise the parent's depth
                plus one.
        """
        if self.parent is None:
            return 1
        return 1 + self.parent.get_depth()


class TaskManager:
    r"""TaskManager is used to manage tasks.

    Attributes:
        root_task: The root task.
        tasks: The ordered tasks.
        task_map: A map for task.id to Task.
        current_task_id: The current "RUNNING" task.id.

    Args:
        task (Task): The root Task.
    """

    def __init__(self, task: Task):
        self.root_task: Task = task
        self.current_task_id: str = task.id
        self.tasks: List[Task] = [task]
        self.task_map: Dict[str, Task] = {task.id: task}

    def gen_task_id(self) -> str:
        r"""Generate a new task id.

        Returns:
            str: The number of currently managed tasks, as a string.
        """
        return f"{len(self.tasks)}"

    def exist(self, task_id: str) -> bool:
        r"""Check if a task with the given id exists."""
        return task_id in self.task_map

    @property
    def current_task(self) -> Optional[Task]:
        r"""Get the current task, or None if its id is not managed."""
        return self.task_map.get(self.current_task_id)

    @staticmethod
    def topological_sort(tasks: List[Task]) -> List[Task]:
        r"""Sort a list of tasks by topological way.

        Tasks are visited depth-first through their ``subtasks``, so every
        sub-task appears before its parent in the result. Sub-tasks that are
        not in ``tasks`` are included as well, and each task id appears at
        most once (later tasks with an already-seen id are dropped).

        Args:
            tasks (List[Task]): The given list of tasks.

        Returns:
            List[Task]: The sorted list of tasks.
        """
        ordered: List[Task] = []
        visited = set()

        # recursively visit the vertices
        def visit(task: Task):
            if task.id in visited:
                return
            visited.add(task.id)

            # go deep for dependencies
            for sub_task in task.subtasks:
                visit(sub_task)

            # all dependencies are emitted; now emit the task itself.
            ordered.append(task)

        for task in tasks:
            visit(task)

        return ordered

    @staticmethod
    def set_tasks_dependence(
        root: Task,
        others: List[Task],
        type: Literal["serial", "parallel"] = "parallel",
    ):
        r"""Set relationship between root task and other tasks.
        Two relationships are currently supported: serial and parallel.
        `serial` :  root -> other1 -> other2
        `parallel`: root -> other1
                         -> other2

        Args:
            root (Task): A root task.
            others (List[Task]): A list of tasks.
            type (Literal["serial", "parallel"]): How to link the tasks.
                With ``"parallel"`` every task becomes a direct subtask of
                ``root``; any other value chains them, each task becoming
                the subtask of the previous one. (default: :obj:`"parallel"`)
        """
        # filter the root task in the others to avoid self-loop dependence.
        others = [other for other in others if other != root]

        if not others:
            return
        if type == "parallel":
            for other in others:
                root.add_subtask(other)
        else:
            parent = root
            for child in others:
                parent.add_subtask(child)
                parent = child

    def add_tasks(self, tasks: Union[Task, List[Task]]) -> None:
        r"""Register new tasks with the manager.

        ``self.tasks`` is re-sorted with :meth:`topological_sort` over the
        existing and the new tasks, and ``self.task_map`` is rebuilt from
        the sorted list.

        Args:
            tasks (Union[Task, List[Task]]): A task or a list of tasks to
                add. Falsy input (e.g. ``None`` or an empty list) is
                ignored.

        Raises:
            AssertionError: If the id of one of the tasks is already
                managed.
        """
        if not tasks:
            return
        if not isinstance(tasks, list):
            tasks = [tasks]
        for task in tasks:
            # Raised explicitly instead of using ``assert`` so that the
            # check is not stripped under ``python -O`` (where a duplicate
            # would be silently dropped by the sort below). The exception
            # type is kept for backward compatibility.
            if self.exist(task.id):
                raise AssertionError(f"`{task.id}` already existed.")
        self.tasks = self.topological_sort(self.tasks + tasks)
        self.task_map = {task.id: task for task in self.tasks}

    def evolve(
        self,
        task: Task,
        agent: "ChatAgent",
        template: Optional[TextPrompt] = None,
        task_parser: Optional[Callable[[str, str], List[Task]]] = None,
    ) -> Optional[Task]:
        r"""Evolve a task to a new task.
            Evolve is only used for data generation.

        The new task is only returned; it is not added to the manager.

        Args:
            task (Task): A given task.
            agent (ChatAgent): An agent that used to evolve the task.
            template (TextPrompt, optional): A prompt template to evolve task.
                If not provided, the default template will be used.
            task_parser (Callable, optional): A function to extract Task from
                response. It is called with the response text and the id
                of ``task``. If not provided, the default parser will be
                used.

        Returns:
            Task: The first :obj:`Task` extracted from the response, or None
                if the parser returned no task.
        """

        if template is None:
            template = TASK_EVOLVE_PROMPT
        if task_parser is None:
            task_parser = parse_response

        role_name = agent.role_name
        content = template.format(role_name=role_name, content=task.content)
        msg = BaseMessage.make_user_message(
            role_name=role_name, content=content
        )
        response = agent.step(msg)
        tasks = task_parser(response.msg.content, task.id)
        if tasks:
            return tasks[0]
        return None
