"""Load balancer for clients and nodes used by the server."""

import random
from collections import defaultdict
from logging import DEBUG, WARNING

from flwr.common import log
from flwr.server import Driver


class ServerLoadBalancer:
    """Track server nodes and spread sampled clients across them.

    The balancer keeps three pieces of state per node id: an availability flag
    (``nodes_states``), a monotonically increasing registration index
    (``nodes_to_fake_ids``) and the list of client ids currently assigned to the
    node (``node_id_to_client_ids``).
    """

    def __init__(
        self,
        driver: Driver,
        total_number_of_clients: int,
        rng: random.Random,
        *,
        is_production: bool,
    ) -> None:
        """Store the configuration and register the nodes reported by the driver.

        Parameters
        ----------
        driver : Driver
            The Flower Driver instance queried for the current node ids.
        total_number_of_clients : int
            The total number of clients in the FL system. Stored on the instance;
            not used by the methods of this class.
        rng : random.Random
            Random generator. Stored on the instance; not used by the methods of
            this class.
        is_production : bool
            Whether the system is running in production mode. Stored on the
            instance; not used by the methods of this class.

        """
        self.driver = driver
        self.total_number_of_clients = total_number_of_clients
        self.rng = rng
        self.is_production = is_production
        # node id -> availability flag (new nodes start as False)
        self.nodes_states: dict[int, bool] = {}
        # node id -> registration index taken from `historical_node_counter`
        self.nodes_to_fake_ids: dict[int, int] = {}
        self.historical_node_counter = 0
        # node id -> client ids currently assigned to that node
        self.node_id_to_client_ids: dict[int, list[int]] = defaultdict(list)

        # Register whatever nodes the driver already knows about.
        self.check_nodes_health()

    def set_all_nodes_states(self, *, state: bool) -> None:
        """Set the availability flag of every node tracked by the balancer.

        Only nodes already present in ``nodes_states`` are updated; the driver is
        not queried. Each update is logged at DEBUG level.

        Parameters
        ----------
        state : bool
            The flag to store for every tracked node. ``get_available_nodes``
            treats ``True`` as available.

        """
        for node_id in self.nodes_states:
            self.nodes_states[node_id] = state
            log(DEBUG, f"Node {node_id} state set to {state}.")

    def check_nodes_health(self) -> None:
        """Synchronise the tracked nodes with the node ids reported by the driver.

        Node ids reported by ``driver.get_node_ids()`` that are not tracked yet are
        registered with state ``False`` and the next value of
        ``historical_node_counter``. Tracked nodes that the driver no longer
        reports are removed from every tracking dictionary and a WARNING is
        logged for each of them.

        """
        # Snapshot taken before registration, so it only holds previously known ids.
        previously_known = set(self.nodes_states)
        reported: set[int] = set()
        for node_id in self.driver.get_node_ids():
            reported.add(node_id)
            if node_id not in self.nodes_states:
                self.nodes_states[node_id] = False
                self.nodes_to_fake_ids[node_id] = self.historical_node_counter
                self.historical_node_counter += 1
                log(DEBUG, f"New node {node_id} added.")
        # Forget every node that the driver stopped reporting
        for node_id in previously_known - reported:
            self.nodes_states.pop(node_id)
            self.nodes_to_fake_ids.pop(node_id)
            # A node only has an entry here once it has been available during an
            # assignment call, and `defaultdict.pop` does not create missing
            # keys, so a default is required to avoid a KeyError.
            self.node_id_to_client_ids.pop(node_id, None)
            log(WARNING, f"Node {node_id} is not available anymore.")

    def get_available_nodes(self) -> list[int]:
        """Return the ids of the tracked nodes whose state is truthy.

        Returns
        -------
        list[int]
            Node ids from ``nodes_states`` whose flag is set, in insertion order.

        """
        return [node_id for node_id, state in self.nodes_states.items() if state]

    def get_registered_nodes(self) -> list[int]:
        """Return the ids of all tracked nodes, whatever their state.

        Returns
        -------
        list[int]
            Every key of ``nodes_states``, in insertion order.

        """
        return list(self.nodes_states)

    def assign_clients_to_nodes(self, sampled_clients: list[int]) -> bool:
        """Assign the sampled clients to nodes and rebalance the assignment.

        Clients that are assigned to an available node but are not part of
        ``sampled_clients`` are dropped from that node. Every sampled client that
        is not assigned to any node yet is then appended, one at a time, to the
        node that currently holds the fewest clients. Finally
        ``balance_nodes_assignments`` is called.

        The candidate nodes are the keys of ``node_id_to_client_ids``, i.e. the
        nodes that were available during this call or an earlier one and have not
        been dropped since.

        Parameters
        ----------
        sampled_clients : list[int]
            Client ids selected for the current round.

        Returns
        -------
        bool
            The result of ``balance_nodes_assignments``: ``True`` when the clients
            ended up evenly spread over the candidate nodes.

        Raises
        ------
        ValueError
            If there are clients to assign but no candidate node.

        """
        sampled = set(sampled_clients)
        # Drop the clients that are assigned to an available node but were not
        # sampled this round. The kept list is built separately because removing
        # from a list while iterating over it skips the following element.
        for available_node_id in self.get_available_nodes():
            # Indexing the defaultdict also registers the node as a candidate
            assigned = self.node_id_to_client_ids[available_node_id]
            kept: list[int] = []
            for cid in assigned:
                if cid in sampled:
                    kept.append(cid)
                else:
                    log(DEBUG, f"Client {cid} removed from node {available_node_id}.")
            # In-place update keeps the identity of the stored list
            assigned[:] = kept
        # Check which cids are currently assigned to nodes
        currently_assigned_cids = {
            cid
            for client_ids in self.node_id_to_client_ids.values()
            for cid in client_ids
        }
        # Check the cids that are not assigned to any node
        cid_to_assign = sampled - currently_assigned_cids
        if cid_to_assign and not self.node_id_to_client_ids:
            msg = "Cannot assign clients: no node is available for assignment."
            raise ValueError(msg)
        # Iteratively assign the cids to the node with the least number of clients
        for cid in cid_to_assign:
            min_node_id = min(
                self.node_id_to_client_ids,
                key=lambda x: len(self.node_id_to_client_ids[x]),
            )
            self.node_id_to_client_ids[min_node_id].append(cid)
            log(DEBUG, f"Client {cid} assigned to node {min_node_id}.")
        # Rebalance the nodes
        return self.balance_nodes_assignments()

    def balance_nodes_assignments(self) -> bool:
        """Even out the client assignments when an exact split is possible.

        When the number of assigned clients is divisible by the number of nodes in
        ``node_id_to_client_ids``, clients are moved one at a time from the first
        fullest node to the first emptiest node until every node holds the same
        number of clients. When it is not divisible, the assignments are left
        untouched.

        Returns
        -------
        bool
            ``True`` if the assigned clients could be (and were) split evenly over
            the nodes, ``False`` otherwise. With no node at all there is nothing
            to balance and ``True`` is returned.

        """
        assignments = self.node_id_to_client_ids
        if not assignments:
            # Avoids a division by zero below; zero clients are trivially balanced
            return True
        total_number_of_assigned_clients = sum(
            len(client_ids) for client_ids in assignments.values()
        )
        if total_number_of_assigned_clients % len(assignments) != 0:
            return False

        def load(node_id: int) -> int:
            return len(assignments[node_id])

        while True:
            # `max`/`min` return the first extreme node in dictionary order
            fullest = max(assignments, key=load)
            emptiest = min(assignments, key=load)
            if load(fullest) == load(emptiest):
                break
            # Move a client from the fullest node to the emptiest one
            assignments[emptiest].append(assignments[fullest].pop())
        return True
