"""Utility functions for strategies."""

from repo.server.utils import ServerState
from repo.strategy.fedadam import FedAdam
from repo.strategy.fedavg_eff import FedAvgEfficient
from repo.strategy.fedmom import FedMom
from repo.strategy.fednestorov import FedNesterov
from repo.strategy.fedyogi import FedYogi


def initialize_strategy(
    server_state: ServerState,
) -> None:
    """Initialize strategy object with proper attributes.

    This function prepares the strategy object by setting necessary attributes
    from the server state. It clears the strategy's ``initial_parameters``
    reference, assigns model parameters, and hands the momentum vectors over to
    the strategies that require them. For strategies that do not use a (second)
    momentum vector, the corresponding attribute on ``server_state`` is reset
    to ``None`` so the server state no longer references it.

    Parameters
    ----------
    server_state : ServerState
        The server state containing the strategy to initialize and the necessary
        attributes like model states and momentum vectors. Both the server state
        and its strategy are modified in place.

    Raises
    ------
    ValueError
        If the strategy is a FedNesterov, FedMom, FedYogi or FedAdam instance and
        ``server_state.momentum_vector`` is None, or if the strategy is a FedYogi
        or FedAdam instance and ``server_state.second_momentum_vector`` is None.

    """
    strategy = server_state.strategy
    # Strategies that expose a `parameters` attribute fed from the model states.
    parameter_holding_strategies = (
        FedNesterov,
        FedMom,
        FedYogi,
        FedAdam,
        FedAvgEfficient,
    )
    # Strategies that require a (first) momentum vector.
    momentum_strategies = (FedNesterov, FedMom, FedYogi, FedAdam)
    # Strategies that additionally require a second momentum vector.
    second_momentum_strategies = (FedYogi, FedAdam)

    # NOTE: The strategy needs to hold a copy of the initial parameters, but we want
    # it to free the reference it holds as an attribute. Then, similarly to the
    # `initialize_parameters()` of FedAvg, we nullify such attribute.
    # `is not None` avoids an ambiguous truth-value test on array-like values.
    if strategy.initial_parameters is not None:
        strategy.initial_parameters = None

    # NOTE: Since we initialized the strategy object before creating the parameters,
    # we must assign to the strategy attributes the parameters we got from the
    # initialization.
    if isinstance(strategy, parameter_holding_strategies):
        strategy.parameters = server_state.model_states

    # Explicit checks instead of `assert`: assertions are stripped under
    # `python -O`, which would silently hand a `None` vector to the strategy.
    if isinstance(strategy, momentum_strategies):
        if server_state.momentum_vector is None:
            msg = "Momentum vector must be initialized"
            raise ValueError(msg)
        strategy.momentum_vector = server_state.momentum_vector
    else:
        # The strategy does not use momentum: drop the server-side reference.
        server_state.momentum_vector = None

    if isinstance(strategy, second_momentum_strategies):
        if server_state.second_momentum_vector is None:
            msg = "Second momentum vector must be initialized"
            raise ValueError(msg)
        strategy.second_momentum_vector = server_state.second_momentum_vector
    else:
        # The strategy does not use a second momentum: drop the reference.
        server_state.second_momentum_vector = None

    strategy.layer_names_and_types = server_state.layer_names_and_types
