# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Aggregation functions for strategy implementations."""

from functools import reduce
from typing import List, Tuple
import numpy as np

from flwr.common import NDArrays

def aggregate(
    results: Tuple[List[Tuple[NDArrays, int]], List[NDArrays]],
) -> Tuple[NDArrays, List[int], None]:
    """Compute the example-weighted average of client parameters.

    Args:
        results: A pair whose first element is a list of
            ``(parameters, num_examples)`` tuples, one per client, where
            ``parameters`` is a list of per-layer arrays. The second element
            of the pair is not used by this function.

    Returns:
        A 3-tuple ``(weights_prime, [], None)``. For each layer position,
        ``weights_prime`` holds the sum of every client's layer scaled by
        that client's ``num_examples``, divided by the total number of
        examples. The trailing ``[]`` and ``None`` are constant placeholders.
        An empty client list yields ``([], [], None)``.
    """
    # Use a distinct name rather than rebinding the parameter to another type.
    client_results = results[0]
    num_examples_total = sum(num_examples for _, num_examples in client_results)

    # Scale each client's layers by its example count before summing.
    weighted_weights = [
        [layer * num_examples for layer in weights]
        for weights, num_examples in client_results
    ]

    # zip(*...) regroups the per-client layer lists into per-layer tuples.
    # NOTE(review): zip truncates silently if clients report different
    # numbers of layers; this assumes all clients share one architecture.
    weights_prime: NDArrays = [
        reduce(np.add, layer_updates) / num_examples_total
        for layer_updates in zip(*weighted_weights)
    ]
    return weights_prime, [], None

def simple_average(
    results: Tuple[List[Tuple[NDArrays, int]], List[NDArrays]],
) -> Tuple[NDArrays, List[int], None]:
    """Compute the unweighted average of client parameters.

    Every client contributes equally, regardless of its example count.

    Args:
        results: A pair whose first element is a list of
            ``(parameters, num_examples)`` tuples, one per client, where
            ``parameters`` is a list of per-layer arrays. ``num_examples``
            and the second element of the pair are ignored.

    Returns:
        A 3-tuple ``(weights_prime, [], None)``. For each layer position,
        ``weights_prime`` holds the sum of all clients' layers divided by
        the number of clients. The trailing ``[]`` and ``None`` are constant
        placeholders. An empty client list yields ``([], [], None)``.
    """
    # Use a distinct name rather than rebinding the parameter to another type.
    client_results = results[0]
    num_clients = len(client_results)

    # zip(*...) regroups the per-client layer lists into per-layer tuples.
    # NOTE(review): zip truncates silently if clients report different
    # numbers of layers; this assumes all clients share one architecture.
    weights_prime: NDArrays = [
        reduce(np.add, layer_updates) / num_clients
        for layer_updates in zip(*(weights for weights, _ in client_results))
    ]
    return weights_prime, [], None

def weighted_loss_avg(results: List[Tuple[int, float]]) -> float:
    """Aggregate evaluation losses obtained from multiple clients.

    Args:
        results: ``(num_examples, loss)`` pairs, one per client.

    Returns:
        The mean loss, weighted by each client's ``num_examples``.

    Raises:
        ZeroDivisionError: With plain Python numbers, raised when ``results``
            is empty or the example counts sum to zero.
    """
    num_total_evaluation_examples = sum(
        num_examples for num_examples, _ in results
    )
    weighted_loss_sum = sum(num_examples * loss for num_examples, loss in results)
    return weighted_loss_sum / num_total_evaluation_examples