"""Utility functions for handling parameter masks in federated learning."""
from collections.abc import Callable
from enum import Enum, auto

import numpy as np
from flwr.common.typing import NDArrays

from repo.conf.base_schema import BaseConfig
from repo.shm.utils import compress_with_strict


class ModelStateNames(str, Enum):  # noqa: UP042
    """Names of the model-state groups that a layer entry can be tagged with.

    Members also subclass ``str``, so each one compares equal to its plain string
    value. Values are produced by ``auto()`` through ``_generate_next_value_``,
    which makes every value identical to the member name (e.g. ``"EXP_AVG"``).
    """

    @staticmethod
    def _generate_next_value_(
        name: str,
        start: int,  # noqa: ARG004
        count: int,  # noqa: ARG004
        last_values: list[int],  # noqa: ARG004
    ) -> str:
        """Return the value that ``auto()`` assigns to a member.

        Stand-in for ``StrEnum`` so the module also works on Python 3.10.

        Parameters
        ----------
        name: str
            Name of the enum member being defined
        start: int
            Start value (unused)
        count: int
            Number of members defined so far (unused)
        last_values: list[int]
            Values assigned to the previous members (unused)

        Returns
        -------
        str
            The member name, returned unchanged (no case conversion)

        """
        # NOTE(review): ``StrEnum``'s own ``auto()`` lower-cases the member name,
        # whereas this hook keeps the original casing — confirm which casing the
        # consumers of ``.value`` (configs, transmitted keys) expect.
        return name

    # Wildcard: ``generate_mask`` treats a scheduler result of ``[ALL]`` as the
    # three concrete states below.
    ALL = auto()
    PARAMETERS = auto()
    # Presumably optimizer moment estimates (Adam-style naming) — TODO confirm.
    EXP_AVG = auto()
    EXP_AVG_SQ = auto()


def mask_to_batches(
    full_mask: tuple[tuple[bool, ...], list[str], list[str]],
    layer_names_and_types: tuple[tuple[str, ModelStateNames], ...],
    n_batches: int,
) -> list[tuple[tuple[bool, ...], list[str], list[str]]]:
    """Split a full parameter mask into multiple batches.

    This function takes a full parameter mask and divides it into a specified number
    of smaller batch masks. The True values (parameters to be included) are handed
    out in their original order, ``ceil(n_true / n_batches)`` per batch, so every
    selected entry appears in exactly one batch mask. This is useful for processing
    large parameter sets in smaller chunks.

    Parameters
    ----------
    full_mask : tuple[tuple[bool, ...], list[str], list[str]]
        The full mask tuple containing:
        - A tuple of boolean values representing the mask
        - A list of layer names corresponding to True values in the mask
        - A list of model state type values corresponding to True values in the mask
        Only the boolean tuple is read; names and types are rebuilt from
        ``layer_names_and_types``.
    layer_names_and_types : tuple[tuple[str, ModelStateNames], ...]
        A tuple of tuples where each inner tuple contains a layer name and its
        corresponding model state type (PARAMETERS, EXP_AVG, EXP_AVG_SQ).
    n_batches : int
        The number of batch masks to produce. Must be at least 1.

    Returns
    -------
    list[tuple[tuple[bool, ...], list[str], list[str]]]
        Exactly ``n_batches`` mask tuples, each containing:
        - A tuple of boolean values representing a batch mask
        - A list of layer names corresponding to True values in the batch mask
        - A list of model state type values corresponding to True values in the batch
        mask

    Raises
    ------
    ValueError
        If ``n_batches`` is smaller than 1.

    Notes
    -----
    - The last non-empty batch may hold fewer True values when the division leaves a
      remainder.
    - When there are fewer selected entries than the rounded-up batch size can spread
      over ``n_batches``, the trailing batches are all-False masks with empty name
      and type lists.
    - ``compress_with_strict`` is called with ``strict=True``; presumably it rejects
      a mask whose length differs from ``layer_names_and_types`` — confirm against
      the helper.

    """
    if n_batches < 1:
        msg = f"n_batches must be a positive integer, got {n_batches}"
        raise ValueError(msg)

    # Positions of the selected (truthy) entries in the original mask
    boolean_mask = full_mask[0]
    true_indices = [i for i, value in enumerate(boolean_mask) if value]

    # Selected entries per batch, rounded up so that n_batches batches cover them all
    batch_size = -(-len(true_indices) // n_batches)

    batches: list[tuple[tuple[bool, ...], list[str], list[str]]] = []
    for batch_number in range(n_batches):
        # Each batch is sliced by its own index, so both bounds stay non-negative.
        # Once the selected entries run out the slice is simply empty; a negative
        # bound would wrap around and re-select entries of an earlier batch.
        first = batch_number * batch_size
        batch_indices = true_indices[first : first + batch_size]

        # Start from an all-False mask and switch on this batch's entries
        mask = [False] * len(layer_names_and_types)
        for idx in batch_indices:
            mask[idx] = True

        # Filter the (name, type) pairs once and derive both output lists from them
        selected = list(
            compress_with_strict(layer_names_and_types, mask, strict=True),
        )
        batches.append(
            (
                tuple(mask),
                [name for name, _ in selected],
                [state.value for _, state in selected],
            ),
        )

    return batches


def combine_masks(
    mask_a: tuple[tuple[bool, ...], list[str], list[str]],
    mask_b: tuple[tuple[bool, ...], list[str], list[str]],
    layer_names_and_types: tuple[tuple[str, ModelStateNames], ...],
) -> tuple[tuple[bool, ...], list[str], list[str]]:
    """Merge two parameter masks into their element-wise union.

    Only the boolean part of each input mask is read. The two boolean tuples are
    OR-ed position by position, and the resulting mask is then used to filter
    ``layer_names_and_types`` so that the returned names and types describe every
    entry selected by at least one of the inputs.

    Parameters
    ----------
    mask_a : tuple[tuple[bool, ...], list[str], list[str]]
        First mask tuple (boolean mask, layer names, model state type values).
    mask_b : tuple[tuple[bool, ...], list[str], list[str]]
        Second mask tuple, with a boolean mask of the same length as ``mask_a``'s.
    layer_names_and_types : tuple[tuple[str, ModelStateNames], ...]
        A tuple of tuples where each inner tuple contains a layer name and its
        corresponding model state type (PARAMETERS, EXP_AVG, EXP_AVG_SQ).

    Returns
    -------
    tuple[tuple[bool, ...], list[str], list[str]]
        A tuple containing:
        - A tuple of boolean values representing the combined mask
        - A list of layer names corresponding to True values in the combined mask
        - A list of model state type values corresponding to True values in the
        combined mask

    Raises
    ------
    ValueError
        If the two boolean masks differ in length (``zip(..., strict=True)``).

    """
    # Element-wise OR of the two boolean masks
    union_flags = [
        left or right for left, right in zip(mask_a[0], mask_b[0], strict=True)
    ]

    # Walk the selected (name, type) pairs once, filling both output lists
    kept_names: list[str] = []
    kept_types: list[str] = []
    for layer_name, state in compress_with_strict(
        layer_names_and_types,
        union_flags,
        strict=True,
    ):
        kept_names.append(layer_name)
        kept_types.append(state.value)

    return tuple(union_flags), kept_names, kept_types


def generate_full_mask(
    layer_names_and_types: tuple[tuple[str, ModelStateNames], ...],
) -> tuple[tuple[bool, ...], list[str], list[str]]:
    """Build the mask that selects every entry.

    The boolean part has one ``True`` per entry of ``layer_names_and_types``; the
    name and type lists therefore list every entry, in order.

    Parameters
    ----------
    layer_names_and_types : tuple[tuple[str, ModelStateNames], ...]
        A tuple of tuples where each inner tuple contains a layer name and its
        corresponding model state type (PARAMETERS, EXP_AVG, EXP_AVG_SQ).

    Returns
    -------
    tuple[tuple[bool, ...], list[str], list[str]]
        A tuple containing:
        - A tuple of boolean values (all True) representing the mask
        - A list with every layer name from layer_names_and_types
        - A list with every state type value from layer_names_and_types

    """
    # One True flag per entry
    all_selected = (True,) * len(layer_names_and_types)

    # Split the (name, type) pairs into the two parallel output lists
    names: list[str] = []
    state_values: list[str] = []
    for layer_name, state in layer_names_and_types:
        names.append(layer_name)
        state_values.append(state.value)

    return all_selected, names, state_values


def generate_empty_mask(
    layer_names_and_types: tuple[tuple[str, ModelStateNames], ...],
) -> tuple[tuple[bool, ...], list[str], list[str]]:
    """Build the mask that selects no entry.

    The boolean part has one ``False`` per entry of ``layer_names_and_types``. It is
    meant as a neutral starting point to which specific parameters can be added
    later.

    Parameters
    ----------
    layer_names_and_types : tuple[tuple[str, ModelStateNames], ...]
        A tuple of tuples where each inner tuple contains a layer name and its
        corresponding model state type (PARAMETERS, EXP_AVG, EXP_AVG_SQ).

    Returns
    -------
    tuple[tuple[bool, ...], list[str], list[str]]
        A tuple containing:
        - A tuple of boolean values (all False) representing the mask
        - A list with every layer name from layer_names_and_types
        - A list with every state type value from layer_names_and_types

    Notes
    -----
    NOTE(review): unlike the masks built by ``combine_masks`` and ``generate_mask``,
    whose name/type lists only cover the True positions, this one returns *all*
    names and types next to an all-False mask — confirm callers expect that.

    """
    # One False flag per entry
    none_selected = (False,) * len(layer_names_and_types)

    # Names and types are taken from every entry, regardless of the mask
    names: list[str] = []
    state_values: list[str] = []
    for layer_name, state in layer_names_and_types:
        names.append(layer_name)
        state_values.append(state.value)

    return none_selected, names, state_values


def generate_mask(
    layer_names_and_types: tuple[tuple[str, ModelStateNames], ...],
    scheduler: Callable[[str | int, int], list[ModelStateNames]],
    current_cid: int | str,
    server_round: int,
) -> tuple[tuple[bool, ...], list[str], list[str]]:
    """Generate the mask of entries the scheduler selects for a client and round.

    The scheduler is asked which model state types are due for ``current_cid`` in
    ``server_round``. Every entry of ``layer_names_and_types`` whose type is among
    them is selected; all other entries are masked out.

    Parameters
    ----------
    layer_names_and_types : tuple[tuple[str, ModelStateNames], ...]
        A tuple of tuples where each inner tuple contains a layer name and its
        corresponding model state type (PARAMETERS, EXP_AVG, EXP_AVG_SQ).
    scheduler : Callable[[str | int, int], list[ModelStateNames]]
        Called as ``scheduler(current_cid, server_round)``; returns the model state
        types to select. ``ModelStateNames.ALL`` anywhere in the result selects all
        three concrete state types.
    current_cid : int | str
        The ID of the current client. It is passed to the scheduler unchanged.
    server_round : int
        The current server round.

    Returns
    -------
    tuple[tuple[bool, ...], list[str], list[str]]
        A tuple containing:
        - A tuple of boolean values, one per entry of layer_names_and_types
        - A list of layer names corresponding to True values in the mask
        - A list of model state type values corresponding to True values in the mask

    Notes
    -----
    - The scheduler result is not validated: a state type that matches no entry
      simply selects nothing.
    - The client id is not coerced to ``int``, so non-numeric string ids are
      accepted, as the signature advertises.

    """
    requested_states = scheduler(current_cid, server_round)

    # ALL is a wildcard for every concrete state type, also when it is listed
    # together with other entries
    if ModelStateNames.ALL in requested_states:
        requested_states = [
            ModelStateNames.PARAMETERS,
            ModelStateNames.EXP_AVG,
            ModelStateNames.EXP_AVG_SQ,
        ]

    # An entry is selected exactly when its state type was requested
    mask = [state in requested_states for _, state in layer_names_and_types]

    # Filter the (name, type) pairs once and derive both output lists from them
    selected = list(compress_with_strict(layer_names_and_types, mask, strict=True))
    return (
        tuple(mask),
        [name for name, _ in selected],
        [state.value for _, state in selected],
    )


def reconcile_model_state_with_scheduler(
    model_states: NDArrays,
    layer_names_and_types: tuple[tuple[str, ModelStateNames], ...],
    cfg: BaseConfig,
) -> tuple[NDArrays, tuple[tuple[str, ModelStateNames], ...]]:
    """Blank out model state tensors whose type is absent from the scheduler config.

    The keys of ``cfg.fl.parameter_scheduler_kwargs`` are taken as the set of
    scheduled layer types (the associated values are ignored here). Each tensor in
    ``model_states`` is paired by position with an entry of
    ``layer_names_and_types``; a tensor is kept when the entry's type value is one
    of those keys and is otherwise replaced by the one-element array
    ``np.array([0.0])``.

    Parameters
    ----------
    model_states : NDArrays
        The model state tensors, one per entry of ``layer_names_and_types``.
    layer_names_and_types : tuple[tuple[str, ModelStateNames], ...]
        A tuple of tuples, each containing a layer name and its corresponding layer
        type. Each tuple maps to the tensor in model_states at the same index.
    cfg : BaseConfig
        Configuration object; only ``cfg.fl.parameter_scheduler_kwargs`` is read.

    Returns
    -------
    tuple[NDArrays, tuple[tuple[str, ModelStateNames], ...]]
        A tuple containing:
        - A new list of tensors with placeholders for the non-scheduled types
        - The layer names and types, unchanged, as a tuple

    Raises
    ------
    ValueError
        If ``model_states`` and ``layer_names_and_types`` differ in length
        (``zip(..., strict=True)``).

    Notes
    -----
    NOTE(review): keys are matched literally against each entry's type value; a
    wildcard key such as ``ALL`` is not expanded here and would match no entry —
    confirm this is intended.

    """
    # Layer types that appear as keys in the scheduler configuration
    scheduled_types = {
        layer_type
        for layer_type, _frequency in cfg.fl.parameter_scheduler_kwargs.items()
    }

    # Materialise once so the same pairs feed both the zip and the returned tuple
    pairs = list(layer_names_and_types)

    # Keep scheduled tensors, substitute a one-element placeholder for the others
    filtered_states = [
        tensor if pair[1].value in scheduled_types else np.array([0.0])
        for tensor, pair in zip(model_states, pairs, strict=True)
    ]
    return filtered_states, tuple(pairs)
