# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from  https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import asyncio
import datetime
import logging
import warnings
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Any

from ..doc_utils import export_module
from ..events.agent_events import PostCarryoverProcessingEvent
from ..io.base import IOStream
from .utils import consolidate_chat_info

# Module-level logger; the async helpers below emit debug records through it.
logger = logging.getLogger(__name__)
# One dependency edge between chats, stored as (chat_id, prerequisite_chat_id):
# the chat that has to wait, followed by a chat it waits for.
Prerequisite = tuple[int, int]

# Names exported by `from <this module> import *`.
__all__ = ["ChatResult", "a_initiate_chats", "initiate_chats"]


@dataclass
@export_module("autogen")
class ChatResult:
    """(Experimental) Record describing the outcome of one chat.

    The layout of this record is almost certain to change. Every field defaults to None.
    """

    chat_id: int = None
    """Id of the chat."""
    chat_history: list[dict[str, Any]] = None
    """History of the chat."""
    summary: str = None
    """Summary obtained from the chat."""
    cost: dict[str, dict[str, Any]] = None
    """Cost of the chat, keyed by usage type.

    Each usage type maps to a dictionary holding the cost information for that type:
        - "usage_including_cached_inference": cost of the total usage, counting tokens served from cached inference.
        - "usage_excluding_cached_inference": cost of the usage without tokens served from the cache; never larger
          than "usage_including_cached_inference".
    """
    human_input: list[str] = None
    """Human input solicited during the chat."""


def _validate_recipients(chat_queue: list[dict[str, Any]]) -> None:
    """Validate that every chat has a recipient and warn about repeated recipients.

    Args:
        chat_queue: The chat configurations to check. Each one must contain a
            hashable "recipient" entry.

    Raises:
        AssertionError: If a chat configuration has no "recipient" key.
    """
    recipients = set()
    for chat_info in chat_queue:
        # Raised explicitly instead of through an `assert` statement so the check
        # still runs under `python -O`; the exception type and message are kept
        # so existing handlers continue to work.
        if "recipient" not in chat_info:
            raise AssertionError("recipient must be provided.")
        recipients.add(chat_info["recipient"])
    # Fewer distinct recipients than chats means at least one recipient repeats.
    if len(recipients) < len(chat_queue):
        warnings.warn(
            "Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
            UserWarning,
        )


def __create_async_prerequisites(chat_queue: list[dict[str, Any]]) -> list[Prerequisite]:
    """Collect the dependency pairs declared by the chats in the queue.

    Args:
        chat_queue: Chat configurations. Each needs a "chat_id"; an optional
            "prerequisites" entry lists the ids of chats it depends on.

    Returns:
        A list of (chat_id, prerequisite_chat_id) tuples, one per declared
        prerequisite, in queue order.

    Raises:
        ValueError: If a chat has no "chat_id" or a prerequisite id is not an int.
    """
    pairs = []
    for entry in chat_queue:
        if "chat_id" not in entry:
            raise ValueError("Each chat must have a unique id for async multi-chat execution.")
        dependent_id = entry["chat_id"]
        for required_id in entry.get("prerequisites", []):
            if not isinstance(required_id, int):
                raise ValueError("Prerequisite chat id is not int.")
            # The dependent chat comes first, the chat it waits for second.
            pairs += [(dependent_id, required_id)]
    return pairs


def __find_async_chat_order(chat_ids: set[int], prerequisites: list[Prerequisite]) -> list[int]:
    """Order chats so that every chat comes after all of its prerequisites.

    The order is built level by level: first the chats without prerequisites,
    then the chats whose prerequisites have all been placed, and so on.

    Args:
        chat_ids: All chat ids that need to be scheduled.
        prerequisites: (chat_id, prerequisite_chat_id) tuples, each saying that
            chat_id depends on prerequisite_chat_id. Repeated tuples count once.

    Returns:
        list: The chat ids in execution order, or an empty list when some chat
        can never be released (for example a dependency cycle, or a prerequisite
        that is not among `chat_ids`).
    """
    # successors[p]: chats waiting on p; unmet[c]: number of prerequisites of c not yet placed.
    successors = defaultdict(set)
    unmet = defaultdict(int)
    for link in prerequisites:
        dependent, required = link[0], link[1]
        if dependent in successors[required]:
            continue
        unmet[dependent] += 1
        successors[required].add(dependent)

    ready = [candidate for candidate in chat_ids if candidate not in unmet]
    ordered = []
    while ready:
        ordered += ready
        released = []
        for placed in ready:
            if placed not in successors:
                continue
            for waiting in successors.pop(placed):
                unmet[waiting] -= 1
                if unmet[waiting] == 0:
                    released.append(waiting)
                    del unmet[waiting]
        ready = released

    # Anything still counted in `unmet` could not be released.
    return [] if unmet else ordered


def _post_process_carryover_item(carryover_item):
    """Turn one carryover entry into a string.

    A string is returned unchanged. For a dict that has a "content" key, the
    string form of that value is returned. Any other object is converted with
    `str()` as a whole.
    """
    if isinstance(carryover_item, str):
        return carryover_item
    has_content = isinstance(carryover_item, dict) and "content" in carryover_item
    payload = carryover_item["content"] if has_content else carryover_item
    return str(payload)


def __post_carryover_processing(chat_info: dict[str, Any]) -> None:
    """Send a PostCarryoverProcessingEvent for a chat on the default IOStream.

    A UserWarning is issued beforehand when the chat configuration has no
    "message" entry.

    Args:
        chat_info: The chat configuration, passed on to the event unchanged.
    """
    stream = IOStream.get_default()

    message_missing = "message" not in chat_info
    if message_missing:
        warnings.warn(
            "message is not provided in a chat_queue entry. input() will be called to get the initial message.",
            UserWarning,
        )

    event = PostCarryoverProcessingEvent(chat_info=chat_info)
    stream.send(event)


@export_module("autogen")
def initiate_chats(chat_queue: list[dict[str, Any]]) -> list[ChatResult]:
    """Run a list of chats one after another, carrying summaries forward.

    Before each chat starts, its "carryover" entry is replaced by the carryover it
    already had (a single string is wrapped in a list) followed by the summaries of
    the chats finished so far. The chat's own dictionary in `chat_queue` is updated
    in place.

    Args:
        chat_queue (List[Dict]): A list of dictionaries containing the information about the chats.

            Each dictionary is passed as keyword arguments to
            [`ConversableAgent.initiate_chat`](../ConversableAgent#initiate-chat)
            of its `"sender"`. For example:
                - `"sender"` - the sender agent.
                - `"recipient"` - the recipient agent.
                - `"clear_history"` (bool) - whether to clear the chat history with the agent.
                Default is True.
                - `"silent"` (bool or None) - (Experimental) whether to print the messages in this
                conversation. Default is False.
                - `"cache"` (Cache or None) - the cache client to use for this conversation.
                Default is None.
                - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat
                will continue until a termination condition is met. Default is None.
                - `"summary_method"` (str or callable) - a string or callable specifying the method to get
                a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
                - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method.
                Default is {}.
                - `"message"` (str, callable or None) - if None, input() will be called to get the
                initial message.
                - `**context` - additional context information to be passed to the chat.
                - `"carryover"` - carryover information to be passed to this chat. If provided, it is
                combined with the "message" content when generating the initial chat message in
                `generate_init_message`.
                - `"finished_chat_indexes_to_exclude_from_carryover"` - a list of indexes into the list of
                finished chats whose summaries must be left out of the carryover. If it is not provided
                or is an empty list, the summaries of all the finished chats are taken.

    Returns:
        (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
    """
    consolidate_chat_info(chat_queue)
    _validate_recipients(chat_queue)

    results = []
    # Walk a copy of the queue, taken once up front, in its original order.
    for chat_info in chat_queue.copy():
        carryover = chat_info.get("carryover", [])
        excluded_indexes = chat_info.get("finished_chat_indexes_to_exclude_from_carryover", [])
        if isinstance(carryover, str):
            carryover = [carryover]

        earlier_summaries = [
            finished.summary for position, finished in enumerate(results) if position not in excluded_indexes
        ]
        chat_info["carryover"] = carryover + earlier_summaries

        silent = chat_info.get("silent", False)
        if not silent:
            __post_carryover_processing(chat_info)

        initiator = chat_info["sender"]
        results.append(initiator.initiate_chat(**chat_info))
    return results


def __system_now_str():
    """Return the current local time wrapped as a log-message suffix."""
    return f" System time at {datetime.datetime.now()}. "


def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
    """Done-callback that stamps a finished chat's result with its chat id.

    Args:
        chat_future: The completed future (or task) of the chat.
        chat_id: The id to store on the chat's result.

    A cancelled or failed future has no result to update. Calling `result()` on
    it would re-raise inside the callback, where the event loop can only log it
    as a callback error. The failure itself still reaches whoever awaits the
    future, so this callback just records the situation and returns.
    """
    logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
    if chat_future.cancelled():
        logger.debug(f"Chat {chat_id} was cancelled; there is no result to update.")
        return
    if chat_future.exception() is not None:
        logger.debug(f"Chat {chat_id} failed; there is no result to update.")
        return
    chat_result = chat_future.result()
    chat_result.chat_id = chat_id


async def _dependent_chat_future(
    chat_id: int, chat_info: dict[str, Any], prerequisite_chat_futures: dict[int, asyncio.Future]
) -> asyncio.Task:
    """Wait for a chat's prerequisites, then start the chat as an asyncio task.

    The summaries of the awaited prerequisite chats are appended to the chat's
    "carryover" (a single string is wrapped in a list first), skipping every
    prerequisite whose id is listed in
    "finished_chat_indexes_to_exclude_from_carryover". `chat_info` is updated in place.

    Args:
        chat_id: Id of the chat to start.
        chat_info: Configuration of the chat; passed as keyword arguments to the
            sender's `a_initiate_chat`.
        prerequisite_chat_futures: Futures of the prerequisite chats, keyed by chat id.

    Returns:
        The task running the chat. It is returned without being awaited, and
        `_on_chat_future_done` is registered on it as a done-callback.

    Raises:
        RuntimeError: If a prerequisite future is already cancelled when it is
            about to be awaited.
    """
    logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
    carryover = chat_info.get("carryover", [])
    excluded_ids = chat_info.get("finished_chat_indexes_to_exclude_from_carryover", [])

    prerequisite_results = {}
    for pre_id, pre_future in prerequisite_chat_futures.items():
        if pre_future.cancelled():
            raise RuntimeError(f"Chat {pre_id} is cancelled.")

        # The prerequisite's result is needed for this chat's carryover.
        prerequisite_results[pre_id] = await pre_future

    if isinstance(carryover, str):
        carryover = [carryover]
    summaries = [
        pre_result.summary for pre_id, pre_result in prerequisite_results.items() if pre_id not in excluded_ids
    ]
    chat_info["carryover"] = carryover + summaries

    silent = chat_info.get("silent", False)
    if not silent:
        __post_carryover_processing(chat_info)

    initiator = chat_info["sender"]
    chat_task = asyncio.create_task(initiator.a_initiate_chat(**chat_info))
    chat_task.add_done_callback(partial(_on_chat_future_done, chat_id=chat_id))
    logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
    return chat_task


async def a_initiate_chats(chat_queue: list[dict[str, Any]]) -> dict[int, ChatResult]:
    """(async) Initiate a list of chats, honoring the prerequisites between them.

    Chats are started in an order in which every chat comes after its prerequisites.
    A chat's carryover is extended with the summaries of its prerequisite chats.

    Args:
        chat_queue (List[Dict]): A list of dictionaries containing the information about the chats.

            Each dictionary should contain the input arguments for
            [`ConversableAgent.initiate_chat`](../../../ConversableAgent#initiate-chat).
            For example:
                - `"chat_id"` (int) - required; the id of the chat, used as the key of the returned dict.
                - `"prerequisites"` (list of int) - ids of the chats that must finish before this chat
                starts. Default is no prerequisites.
                - `"sender"` - the sender agent.
                - `"recipient"` - the recipient agent.
                - `"clear_history"` (bool) - whether to clear the chat history with the agent.
                Default is True.
                - `"silent"` (bool or None) - (Experimental) whether to print the messages in this
                conversation. Default is False.
                - `"cache"` (Cache or None) - the cache client to use for this conversation.
                Default is None.
                - `"max_turns"` (int or None) - maximum number of turns for the chat. If None, the chat
                will continue until a termination condition is met. Default is None.
                - `"summary_method"` (str or callable) - a string or callable specifying the method to get
                a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
                - `"summary_args"` (dict) - a dictionary of arguments to be passed to the summary_method.
                Default is {}.
                - `"message"` (str, callable or None) - if None, input() will be called to get the
                initial message.
                - `**context` - additional context information to be passed to the chat.
                - `"carryover"` - It can be used to specify the carryover information to be passed
                to this chat. If provided, we will combine this carryover with the "message" content when
                generating the initial chat message in `generate_init_message`.
                - `"finished_chat_indexes_to_exclude_from_carryover"` - a list of prerequisite chats whose
                summaries are left out of the carryover. In this async version the entries are compared
                with the prerequisite chat ids. If it is not provided or is an empty list, the summaries of
                all prerequisite chats are taken.

    Returns:
        - (Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
          It is empty when no valid execution order exists.

    Raises:
        ValueError: If a chat has no "chat_id" or a prerequisite id is not an int.
    """
    consolidate_chat_info(chat_queue)
    _validate_recipients(chat_queue)
    # Validate ids and prerequisites before indexing by "chat_id", so that a chat
    # without an id is reported with the descriptive ValueError instead of a KeyError.
    prerequisites = __create_async_prerequisites(chat_queue)
    chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
    chat_order_by_id = __find_async_chat_order(chat_book.keys(), prerequisites)
    if chat_book and not chat_order_by_id:
        # An empty order for a non-empty queue means the dependencies cannot be
        # satisfied; say so instead of silently running nothing.
        warnings.warn(
            "No valid execution order could be determined for the chats (circular or unknown prerequisite chat ids); no chats were run.",
            UserWarning,
        )

    chat_tasks = {}
    for chat_id in chat_order_by_id:
        chat_info = chat_book[chat_id]
        # Prerequisites precede this chat in the order, so their tasks already exist.
        pre_chat_futures = {pre_chat_id: chat_tasks[pre_chat_id] for pre_chat_id in chat_info.get("prerequisites", [])}
        chat_tasks[chat_id] = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures)

    # Wait for every chat to finish before collecting the results.
    await asyncio.gather(*chat_tasks.values())
    return {chat_id: chat_task.result() for chat_id, chat_task in chat_tasks.items()}
