# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Workflow for the SecAgg+ protocol."""


import random
from dataclasses import dataclass, field
from logging import DEBUG, ERROR, INFO, WARN
from typing import Dict, List, Optional, Set, Tuple, Union, cast

import flwr.common.recordset_compat as compat
from flwr.common import (
    ConfigsRecord,
    Context,
    FitRes,
    Message,
    MessageType,
    NDArrays,
    RecordSet,
    bytes_to_ndarray,
    log,
    ndarrays_to_parameters,
)
from flwr.common.secure_aggregation.crypto.shamir import combine_shares
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
    bytes_to_private_key,
    bytes_to_public_key,
    generate_shared_key,
)
from flwr.common.secure_aggregation.ndarrays_arithmetic import (
    factor_extract,
    get_parameters_shape,
    parameters_addition,
    parameters_mod,
    parameters_subtraction,
)
from flwr.common.secure_aggregation.quantization import dequantize
from flwr.common.secure_aggregation.secaggplus_constants import (
    RECORD_KEY_CONFIGS,
    Key,
    Stage,
)
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
from flwr.server.client_proxy import ClientProxy
from flwr.server.compat.legacy_context import LegacyContext
from flwr.server.driver import Driver

from ..constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD
from ..constant import Key as WorkflowKey


@dataclass
class WorkflowState:  # pylint: disable=R0902
    """The state of the SecAgg+ protocol."""

    nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict)
    nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
    sampled_node_ids: Set[int] = field(default_factory=set)
    active_node_ids: Set[int] = field(default_factory=set)
    num_shares: int = 0
    threshold: int = 0
    clipping_range: float = 0.0
    quantization_range: int = 0
    mod_range: int = 0
    max_weight: float = 0.0
    nid_to_neighbours: Dict[int, Set[int]] = field(default_factory=dict)
    nid_to_publickeys: Dict[int, List[bytes]] = field(default_factory=dict)
    forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
    forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
    aggregate_ndarrays: NDArrays = field(default_factory=list)
    legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)


class SecAggPlusWorkflow:
    """The workflow for the SecAgg+ protocol.

    The SecAgg+ protocol ensures the secure summation of integer vectors owned by
    multiple parties, without accessing any individual integer vector. This workflow
    allows the server to compute the weighted average of model parameters across all
    clients, ensuring individual contributions remain private. This is achieved by
    clients sending both, a weighting factor and a weighted version of the locally
    updated parameters, both of which are masked for privacy. Specifically, each
    client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
    number of examples ('num_examples') and 'params' represents the model parameters
    ('parameters') from the client's `FitRes`. The server then aggregates these
    contributions to compute the weighted average of model parameters.

    The protocol involves four main stages:
    - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
    - 'share keys': Broadcast public keys among clients and collect encrypted secret
      key shares.
    - 'collect masked vectors': Forward encrypted secret key shares to target clients
      and collect masked model parameters.
    - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.

    Only the aggregated model parameters are exposed and passed to
    `Strategy.aggregate_fit`, ensuring individual data privacy.

    Parameters
    ----------
    num_shares : Union[int, float]
        The number of shares into which each client's private key is split under
        the SecAgg+ protocol. If specified as a float, it represents the proportion
        of all selected clients, and the number of shares will be set dynamically in
        the run time. A private key can be reconstructed from these shares, allowing
        for the secure aggregation of model updates. Each client sends one share to
        each of its neighbors while retaining one.
    reconstruction_threshold : Union[int, float]
        The minimum number of shares required to reconstruct a client's private key,
        or, if specified as a float, it represents the proportion of the total number
        of shares needed for reconstruction. This threshold ensures privacy by allowing
        for the recovery of contributions from dropped clients during aggregation,
        without compromising individual client data.
    max_weight : Optional[float] (default: 1000.0)
        The maximum value of the weight that can be assigned to any single client's
        update during the weighted average calculation on the server side, e.g., in the
        FedAvg algorithm.
    clipping_range : float, optional (default: 8.0)
        The range within which model parameters are clipped before quantization.
        This parameter ensures each model parameter is bounded within
        [-clipping_range, clipping_range], facilitating quantization.
    quantization_range : int, optional (default: 4194304, this equals 2**22)
        The size of the range into which floating-point model parameters are quantized,
        mapping each parameter to an integer in [0, quantization_range-1]. This
        facilitates cryptographic operations on the model updates.
    modulus_range : int, optional (default: 4294967296, this equals 2**32)
        The range of values from which random mask entries are uniformly sampled
        ([0, modulus_range-1]). `modulus_range` must be less than 4294967296.
        Please use 2**n values for `modulus_range` to prevent overflow issues.
    timeout : Optional[float] (default: None)
        The timeout duration in seconds. If specified, the workflow will wait for
        replies for this duration each time. If `None`, there is no time limit and
        the workflow will wait until replies for all messages are received.

    Notes
    -----
    - Generally, higher `num_shares` means more robust to dropouts while increasing the
      computational costs; higher `reconstruction_threshold` means better privacy
      guarantees but less tolerance to dropouts.
    - Too large `max_weight` may compromise the precision of the quantization.
    - `modulus_range` must be 2**n and larger than `quantization_range`.
    - When `num_shares` is a float, it is interpreted as the proportion of all selected
      clients, and hence the number of shares will be determined in the runtime. This
      allows for dynamic adjustment based on the total number of participating clients.
    - Similarly, when `reconstruction_threshold` is a float, it is interpreted as the
      proportion of the number of shares needed for the reconstruction of a private key.
      This feature enables flexibility in setting the security threshold relative to the
      number of distributed shares.
    - `num_shares`, `reconstruction_threshold`, and the quantization parameters
      (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
      balancing privacy, robustness, and efficiency within the SecAgg+ protocol.
    """

    def __init__(  # pylint: disable=R0913
        self,
        num_shares: Union[int, float],
        reconstruction_threshold: Union[int, float],
        *,
        max_weight: float = 1000.0,
        clipping_range: float = 8.0,
        quantization_range: int = 4194304,
        modulus_range: int = 4294967296,
        timeout: Optional[float] = None,
    ) -> None:
        self.num_shares = num_shares
        self.reconstruction_threshold = reconstruction_threshold
        self.max_weight = max_weight
        self.clipping_range = clipping_range
        self.quantization_range = quantization_range
        self.modulus_range = modulus_range
        self.timeout = timeout

        self._check_init_params()

    def __call__(self, driver: Driver, context: Context) -> None:
        """Run the SecAgg+ protocol."""
        if not isinstance(context, LegacyContext):
            raise TypeError(
                f"Expect a LegacyContext, but get {type(context).__name__}."
            )
        state = WorkflowState()

        steps = (
            self.setup_stage,
            self.share_keys_stage,
            self.collect_masked_vectors_stage,
            self.unmask_stage,
        )
        log(INFO, "Secure aggregation commencing.")
        for step in steps:
            if not step(driver, context, state):
                log(INFO, "Secure aggregation halted.")
                return
        log(INFO, "Secure aggregation completed.")

    def _check_init_params(self) -> None:  # pylint: disable=R0912
        # Check `num_shares`
        if not isinstance(self.num_shares, (int, float)):
            raise TypeError("`num_shares` must be of type int or float.")
        if isinstance(self.num_shares, int):
            if self.num_shares == 1:
                self.num_shares = 1.0
            elif self.num_shares <= 2:
                raise ValueError("`num_shares` as an integer must be greater than 2.")
            elif self.num_shares > self.modulus_range / self.quantization_range:
                log(
                    WARN,
                    "A `num_shares` larger than `modulus_range / quantization_range` "
                    "will potentially cause overflow when computing the aggregated "
                    "model parameters.",
                )
        elif self.num_shares <= 0:
            raise ValueError("`num_shares` as a float must be greater than 0.")

        # Check `reconstruction_threshold`
        if not isinstance(self.reconstruction_threshold, (int, float)):
            raise TypeError("`reconstruction_threshold` must be of type int or float.")
        if isinstance(self.reconstruction_threshold, int):
            if self.reconstruction_threshold == 1:
                self.reconstruction_threshold = 1.0
            elif isinstance(self.num_shares, int):
                if self.reconstruction_threshold >= self.num_shares:
                    raise ValueError(
                        "`reconstruction_threshold` must be less than `num_shares`."
                    )
        else:
            if not 0 < self.reconstruction_threshold <= 1:
                raise ValueError(
                    "If `reconstruction_threshold` is a float, "
                    "it must be greater than 0 and less than or equal to 1."
                )

        # Check `max_weight`
        if self.max_weight <= 0:
            raise ValueError("`max_weight` must be greater than 0.")

        # Check `quantization_range`
        if self.quantization_range <= 0:
            raise ValueError("`quantization_range` must be greater than 0.")

        # Check `quantization_range`
        if not isinstance(self.quantization_range, int) or self.quantization_range <= 0:
            raise ValueError(
                "`quantization_range` must be an integer and greater than 0."
            )

        # Check `modulus_range`
        if (
            not isinstance(self.modulus_range, int)
            or self.modulus_range <= self.quantization_range
        ):
            raise ValueError(
                "`modulus_range` must be an integer and "
                "greater than `quantization_range`."
            )
        if bin(self.modulus_range).count("1") != 1:
            raise ValueError("`modulus_range` must be a power of 2.")

    def _check_threshold(self, state: WorkflowState) -> bool:
        for node_id in state.sampled_node_ids:
            active_neighbors = state.nid_to_neighbours[node_id] & state.active_node_ids
            if len(active_neighbors) < state.threshold:
                log(ERROR, "Insufficient available nodes.")
                return False
        return True

    def setup_stage(  # pylint: disable=R0912, R0914, R0915
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'setup' stage."""
        # Obtain fit instructions
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
        current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])
        parameters = compat.parametersrecord_to_parameters(
            context.state.parameters_records[MAIN_PARAMS_RECORD],
            keep_input=True,
        )
        proxy_fitins_lst = context.strategy.configure_fit(
            current_round, parameters, context.client_manager
        )
        if not proxy_fitins_lst:
            log(INFO, "configure_fit: no clients selected, cancel")
            return False
        log(
            INFO,
            "configure_fit: strategy sampled %s clients (out of %s)",
            len(proxy_fitins_lst),
            context.client_manager.num_available(),
        )

        state.nid_to_fitins = {
            proxy.node_id: compat.fitins_to_recordset(fitins, True)
            for proxy, fitins in proxy_fitins_lst
        }
        state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}

        # Protocol config
        sampled_node_ids = list(state.nid_to_fitins.keys())
        num_samples = len(sampled_node_ids)
        if num_samples < 2:
            log(ERROR, "The number of samples should be greater than 1.")
            return False
        if isinstance(self.num_shares, float):
            state.num_shares = round(self.num_shares * num_samples)
            # If even
            if state.num_shares < num_samples and state.num_shares & 1 == 0:
                state.num_shares += 1
            # If too small
            if state.num_shares <= 2:
                state.num_shares = num_samples
        else:
            state.num_shares = self.num_shares
        if isinstance(self.reconstruction_threshold, float):
            state.threshold = round(self.reconstruction_threshold * state.num_shares)
            # Avoid too small threshold
            state.threshold = max(state.threshold, 2)
        else:
            state.threshold = self.reconstruction_threshold
        state.active_node_ids = set(sampled_node_ids)
        state.clipping_range = self.clipping_range
        state.quantization_range = self.quantization_range
        state.mod_range = self.modulus_range
        state.max_weight = self.max_weight
        sa_params_dict = {
            Key.STAGE: Stage.SETUP,
            Key.SAMPLE_NUMBER: num_samples,
            Key.SHARE_NUMBER: state.num_shares,
            Key.THRESHOLD: state.threshold,
            Key.CLIPPING_RANGE: state.clipping_range,
            Key.TARGET_RANGE: state.quantization_range,
            Key.MOD_RANGE: state.mod_range,
            Key.MAX_WEIGHT: state.max_weight,
        }

        # The number of shares should better be odd in the SecAgg+ protocol.
        if num_samples != state.num_shares and state.num_shares & 1 == 0:
            log(WARN, "Number of shares in the SecAgg+ protocol should be odd.")
            state.num_shares += 1

        # Shuffle node IDs
        random.shuffle(sampled_node_ids)
        # Build neighbour relations (node ID -> secure IDs of neighbours)
        half_share = state.num_shares >> 1
        state.nid_to_neighbours = {
            nid: {
                sampled_node_ids[(idx + offset) % num_samples]
                for offset in range(-half_share, half_share + 1)
            }
            for idx, nid in enumerate(sampled_node_ids)
        }

        state.sampled_node_ids = state.active_node_ids

        # Send setup configuration to clients
        cfgs_record = ConfigsRecord(sa_params_dict)  # type: ignore
        content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})

        def make(nid: int) -> Message:
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
            )

        log(
            DEBUG,
            "[Stage 0] Sending configurations to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 0] Received public keys from %s clients.",
            len(state.active_node_ids),
        )

        for msg in msgs:
            if msg.has_error():
                continue
            key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            node_id = msg.metadata.src_node_id
            pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2]
            state.nid_to_publickeys[node_id] = [cast(bytes, pk1), cast(bytes, pk2)]

        return self._check_threshold(state)

    def share_keys_stage(  # pylint: disable=R0914
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'share keys' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]

        def make(nid: int) -> Message:
            neighbours = state.nid_to_neighbours[nid] & state.active_node_ids
            cfgs_record = ConfigsRecord(
                {str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
            )
            cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
            )

        # Broadcast public keys to clients and receive secret key shares
        log(
            DEBUG,
            "[Stage 1] Forwarding public keys to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 1] Received encrypted key shares from %s clients.",
            len(state.active_node_ids),
        )

        # Build forward packet list dictionary
        srcs: List[int] = []
        dsts: List[int] = []
        ciphertexts: List[bytes] = []
        fwd_ciphertexts: Dict[int, List[bytes]] = {
            nid: [] for nid in state.active_node_ids
        }  # dest node ID -> list of ciphertexts
        fwd_srcs: Dict[int, List[int]] = {
            nid: [] for nid in state.active_node_ids
        }  # dest node ID -> list of src node IDs
        for msg in msgs:
            node_id = msg.metadata.src_node_id
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
            ctxt_lst = cast(List[bytes], res_dict[Key.CIPHERTEXT_LIST])
            srcs += [node_id] * len(dst_lst)
            dsts += dst_lst
            ciphertexts += ctxt_lst

        for src, dst, ciphertext in zip(srcs, dsts, ciphertexts):
            if dst in fwd_ciphertexts:
                fwd_ciphertexts[dst].append(ciphertext)
                fwd_srcs[dst].append(src)

        state.forward_srcs = fwd_srcs
        state.forward_ciphertexts = fwd_ciphertexts

        return self._check_threshold(state)

    def collect_masked_vectors_stage(
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'collect masked vectors' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]

        # Send secret key shares to clients (plus FitIns) and collect masked vectors
        def make(nid: int) -> Message:
            cfgs_dict = {
                Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
                Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
                Key.SOURCE_LIST: state.forward_srcs[nid],
            }
            cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
            content = state.nid_to_fitins[nid]
            content.configs_records[RECORD_KEY_CONFIGS] = cfgs_record
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(cfg[WorkflowKey.CURRENT_ROUND]),
            )

        log(
            DEBUG,
            "[Stage 2] Forwarding encrypted key shares to %s clients.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 2] Received masked vectors from %s clients.",
            len(state.active_node_ids),
        )

        # Clear cache
        del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins

        # Sum collected masked vectors and compute active/dead node IDs
        masked_vector = None
        for msg in msgs:
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
            client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
            if masked_vector is None:
                masked_vector = client_masked_vec
            else:
                masked_vector = parameters_addition(masked_vector, client_masked_vec)
        if masked_vector is not None:
            masked_vector = parameters_mod(masked_vector, state.mod_range)
            state.aggregate_ndarrays = masked_vector

        # Backward compatibility with Strategy
        for msg in msgs:
            fitres = compat.recordset_to_fitres(msg.content, True)
            proxy = state.nid_to_proxies[msg.metadata.src_node_id]
            state.legacy_results.append((proxy, fitres))

        return self._check_threshold(state)

    def unmask_stage(  # pylint: disable=R0912, R0914, R0915
        self, driver: Driver, context: LegacyContext, state: WorkflowState
    ) -> bool:
        """Execute the 'unmask' stage."""
        cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
        current_round = cast(int, cfg[WorkflowKey.CURRENT_ROUND])

        # Construct active node IDs and dead node IDs
        active_nids = state.active_node_ids
        dead_nids = state.sampled_node_ids - active_nids

        # Send secure IDs of active and dead clients and collect key shares from clients
        def make(nid: int) -> Message:
            neighbours = state.nid_to_neighbours[nid]
            cfgs_dict = {
                Key.STAGE: Stage.UNMASK,
                Key.ACTIVE_NODE_ID_LIST: list(neighbours & active_nids),
                Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
            }
            cfgs_record = ConfigsRecord(cfgs_dict)  # type: ignore
            content = RecordSet(configs_records={RECORD_KEY_CONFIGS: cfgs_record})
            return driver.create_message(
                content=content,
                message_type=MessageType.TRAIN,
                dst_node_id=nid,
                group_id=str(current_round),
            )

        log(
            DEBUG,
            "[Stage 3] Requesting key shares from %s clients to remove masks.",
            len(state.active_node_ids),
        )
        msgs = driver.send_and_receive(
            [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
        )
        state.active_node_ids = {
            msg.metadata.src_node_id for msg in msgs if not msg.has_error()
        }
        log(
            DEBUG,
            "[Stage 3] Received key shares from %s clients.",
            len(state.active_node_ids),
        )

        # Build collected shares dict
        collected_shares_dict: Dict[int, List[bytes]] = {}
        for nid in state.sampled_node_ids:
            collected_shares_dict[nid] = []
        for msg in msgs:
            res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
            nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
            shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
            for owner_nid, share in zip(nids, shares):
                collected_shares_dict[owner_nid].append(share)

        # Remove masks for every active client after collect_masked_vectors stage
        masked_vector = state.aggregate_ndarrays
        del state.aggregate_ndarrays
        for nid, share_list in collected_shares_dict.items():
            if len(share_list) < state.threshold:
                log(
                    ERROR, "Not enough shares to recover secret in unmask vectors stage"
                )
                return False
            secret = combine_shares(share_list)
            if nid in active_nids:
                # The seed for PRG is the private mask seed of an active client.
                private_mask = pseudo_rand_gen(
                    secret, state.mod_range, get_parameters_shape(masked_vector)
                )
                masked_vector = parameters_subtraction(masked_vector, private_mask)
            else:
                # The seed for PRG is the secret key 1 of a dropped client.
                neighbours = state.nid_to_neighbours[nid]
                neighbours.remove(nid)

                for neighbor_nid in neighbours:
                    shared_key = generate_shared_key(
                        bytes_to_private_key(secret),
                        bytes_to_public_key(state.nid_to_publickeys[neighbor_nid][0]),
                    )
                    pairwise_mask = pseudo_rand_gen(
                        shared_key, state.mod_range, get_parameters_shape(masked_vector)
                    )
                    if nid > neighbor_nid:
                        masked_vector = parameters_addition(
                            masked_vector, pairwise_mask
                        )
                    else:
                        masked_vector = parameters_subtraction(
                            masked_vector, pairwise_mask
                        )
        recon_parameters = parameters_mod(masked_vector, state.mod_range)
        q_total_ratio, recon_parameters = factor_extract(recon_parameters)
        inv_dq_total_ratio = state.quantization_range / q_total_ratio
        # recon_parameters = parameters_divide(recon_parameters, total_weights_factor)
        aggregated_vector = dequantize(
            recon_parameters,
            state.clipping_range,
            state.quantization_range,
        )
        offset = -(len(active_nids) - 1) * state.clipping_range
        for vec in aggregated_vector:
            vec += offset
            vec *= inv_dq_total_ratio

        # Backward compatibility with Strategy
        results = state.legacy_results
        parameters = ndarrays_to_parameters(aggregated_vector)
        for _, fitres in results:
            fitres.parameters = parameters

        # No exception/failure handling currently
        log(
            INFO,
            "aggregate_fit: received %s results and %s failures",
            len(results),
            0,
        )
        aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
        parameters_aggregated, metrics_aggregated = aggregated_result

        # Update the parameters and write history
        if parameters_aggregated:
            paramsrecord = compat.parameters_to_parametersrecord(
                parameters_aggregated, True
            )
            context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
            context.history.add_metrics_distributed_fit(
                server_round=current_round, metrics=metrics_aggregated
            )
        return True
