# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from  https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT

import logging
from collections import Counter
from typing import Optional

from .agentchat import Agent
from .import_utils import optional_import_block, require_optional_import

with optional_import_block():
    import matplotlib.pyplot as plt
    import networkx as nx


def has_self_loops(allowed_speaker_transitions: dict[str, list[Agent]]) -> bool:
    """Check if there are self loops in the allowed_speaker_transitions.

    A self loop means a speaker is allowed to transition to itself, i.e. a key appears in its own
    list of allowed successors.

    Args:
        allowed_speaker_transitions (dict[str, list[Agent]]): A dictionary mapping each speaker to the
            list of agents it can transition to. A self loop is only detected when a key compares equal
            to an element of its own list, so the keys are presumably the Agent objects themselves
            (despite the ``str`` annotation).

    Returns:
        True if any key appears in its own list of allowed successors, False otherwise.
    """
    # Generator expression (not a list) so any() can stop at the first self loop found.
    return any(key in value for key, value in allowed_speaker_transitions.items())


def check_graph_validity(
    allowed_speaker_transitions_dict: dict[str, list[Agent]],
    agents: list[Agent],
) -> None:
    """Check the validity of the allowed_speaker_transitions_dict.

    Despite the ``str`` key annotation, the keys are expected to be ``Agent`` objects: every key must
    be a member of ``agents``, and its ``.name`` is used in warning messages.

    Args:
        allowed_speaker_transitions_dict (dict[str, list[Agent]]):
            A dictionary mapping each speaker agent to the list of agents it can transition to.
        agents (list[Agent]): A list of Agents.

    Raises:
        ValueError: If any of the following checks fails:
            1. ``allowed_speaker_transitions_dict`` is a dict and all of its values are lists.
            2. Every key exists in ``agents``.
            3. Every value is a list of ``Agent`` instances (not strings); empty lists are allowed.

    Warnings (emitted with ``logging.warning``, never raised):
        1. There are isolated agent nodes (no incoming and no outgoing edges).
        2. The set of agents referenced by the transitions does not match ``agents``.
        3. Some value of ``allowed_speaker_transitions_dict`` contains the same agent more than once.
    """
    # Errors

    # Check 1. The dictionary must have a structure of keys and list as values
    if not isinstance(allowed_speaker_transitions_dict, dict):
        raise ValueError("allowed_speaker_transitions_dict must be a dictionary.")

    # All values must be lists of Agent or empty
    if not all(isinstance(value, list) for value in allowed_speaker_transitions_dict.values()):
        raise ValueError("allowed_speaker_transitions_dict must be a dictionary with lists as values.")

    # Check 2. Every key exists in agents
    if not all(key in agents for key in allowed_speaker_transitions_dict):
        raise ValueError("allowed_speaker_transitions_dict has keys not in agents.")

    # Check 3. Every value is a list of Agents or empty list (not string).
    if not all(
        isinstance(agent, Agent) for value in allowed_speaker_transitions_dict.values() for agent in value
    ):
        raise ValueError("allowed_speaker_transitions_dict has values that are not lists of Agents.")

    # Warnings
    # dict.fromkeys() is used as an ordered set below: it de-duplicates while keeping first-seen order,
    # so the agents listed in the warning messages appear in a deterministic order. (Plain sets of
    # Agents iterate in hash order, which varies between runs.)

    # Keys with at least one allowed successor have an outgoing edge.
    has_outgoing_edge = dict.fromkeys(
        key for key, agent_list in allowed_speaker_transitions_dict.items() if agent_list
    )
    # Any agent appearing in some successor list has an incoming edge.
    has_incoming_edge = dict.fromkeys(
        agent for agent_list in allowed_speaker_transitions_dict.values() for agent in agent_list
    )
    unique_agents = dict.fromkeys(agents)

    # Warning 1. Warning if there are isolated agent nodes, there are not incoming nor outgoing edges
    isolated_agents = [
        agent for agent in unique_agents if agent not in has_outgoing_edge and agent not in has_incoming_edge
    ]
    if isolated_agents:
        logging.warning(
            f"""Warning: There are isolated agent nodes, there are not incoming nor outgoing edges. Isolated agents: {[agent.name for agent in isolated_agents]}"""
        )

    # Warning 2. Warning if the set of agents in allowed_speaker_transitions do not match agents.
    # This is the symmetric difference between the referenced agents and `agents`, listed as: members
    # of `agents` that the transitions never reference, then referenced agents missing from `agents`.
    agents_in_allowed_speaker_transitions = {**has_outgoing_edge, **has_incoming_edge}
    full_anti_join = [agent for agent in unique_agents if agent not in agents_in_allowed_speaker_transitions] + [
        agent for agent in agents_in_allowed_speaker_transitions if agent not in unique_agents
    ]
    if full_anti_join:
        logging.warning(
            f"""Warning: The set of agents in allowed_speaker_transitions do not match agents. Offending agents: {[agent.name for agent in full_anti_join]}"""
        )

    # Warning 3. Warning if there are duplicated agents in any values of `allowed_speaker_transitions_dict`
    for key, values in allowed_speaker_transitions_dict.items():
        # Counter counts in one pass (calling list.count per item was quadratic) and keeps
        # first-seen order, so duplicates are reported deterministically.
        unique_duplicates = [agent for agent, count in Counter(values).items() if count > 1]
        if unique_duplicates:
            logging.warning(
                f"Agent '{key.name}' has duplicate elements: {[agent.name for agent in unique_duplicates]}. Please remove duplicates manually."
            )


def invert_disallowed_to_allowed(
    disallowed_speaker_transitions_dict: dict[str, list[Agent]], agents: list[Agent]
) -> dict[str, list[Agent]]:
    """Build the allowed-transitions mapping as the complement of a disallowed-transitions mapping.

    The result starts as a complete graph in which every agent in ``agents`` may hand over to every
    agent, including itself. Then, for each speaker in ``disallowed_speaker_transitions_dict``, the
    agents listed as forbidden are removed from that speaker's successors.

    Args:
        disallowed_speaker_transitions_dict: Maps a speaker agent to the agents it must not transition to.
            Every key must also appear in ``agents``; otherwise a ``KeyError`` is raised.
        agents: A list of Agents.

    Returns:
        A new dict keyed by each agent in ``agents`` (in first-seen order). Each value is the list of
        agents that key may transition to.
    """
    # Complete graph: each speaker gets its own independent copy of the full agent list.
    allowed = {}
    for speaker in agents:
        allowed[speaker] = list(agents)

    # Strip forbidden successors, speaker by speaker.
    for speaker, forbidden in disallowed_speaker_transitions_dict.items():
        remaining = allowed[speaker]
        allowed[speaker] = [candidate for candidate in remaining if candidate not in forbidden]

    return allowed


@require_optional_import(["matplotlib", "networkx"], "graph")
def visualize_speaker_transitions_dict(
    speaker_transitions_dict: dict[str, list[Agent]], agents: list[Agent], export_path: Optional[str] = None
) -> None:
    """Visualize the speaker_transitions_dict using networkx.

    Each agent in ``agents`` becomes a node labelled with its ``name``. Each entry ``key -> value`` becomes
    a directed edge from ``key.name`` to the ``name`` of every agent in ``value``.

    Args:
        speaker_transitions_dict: A dictionary mapping each speaker agent to the list of agents it can
            transition to. Both keys and list elements are accessed via ``.name``.
        agents: A list of Agents. Each one is drawn as a node even if it has no transitions.
        export_path: The path to export the graph. If None, the graph will be shown.

    Returns:
        None
    """
    g = nx.DiGraph()

    # Add nodes first so agents without any transitions still appear in the drawing.
    g.add_nodes_from([agent.name for agent in agents])

    # Add edges
    for key, value in speaker_transitions_dict.items():
        for agent in value:
            g.add_edge(key.name, agent.name)

    # Draw on a fresh figure with explicit full-size axes. Left to itself, nx.draw reuses pyplot's
    # current figure, so repeated calls would overlay their graphs on one another.
    fig = plt.figure()
    ax = fig.add_axes((0, 0, 1, 1))
    nx.draw(g, ax=ax, with_labels=True, font_weight="bold")

    if export_path is not None:
        try:
            fig.savefig(export_path)
        finally:
            # Release the exported figure so it is neither leaked nor re-displayed later.
            plt.close(fig)
    else:
        plt.show()
