import torch
import numpy as np
from scipy.stats import norm

from ..simulators.worker import ByzantineWorker
from ..utils import get_vectorized_parameters


class ALittleIsEnoughAttack(ByzantineWorker):
    """Byzantine worker implementing the "A Little Is Enough" attack.

    On every round the attacker reads the gradients of all honest workers,
    computes their element-wise mean ``mu`` and standard deviation ``std``,
    and submits ``mu - std * z_max``.

    Args:
        n (int): Total number of workers.
        m (int): Number of Byzantine workers.
        z (float, optional): Multiplier applied to the standard deviation.
            When ``None``, it is derived from ``n`` and ``m`` as
            ``norm.ppf((n - m - s) / (n - m))`` where
            ``s = floor(n / 2 + 1) - m`` is the number of "supporters".
        *args, **kwargs: Forwarded unchanged to ``ByzantineWorker.__init__``.
    """

    def __init__(self, n, m, z=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if z is not None:
            self.z_max = z
        else:
            # Number of supporters.
            s = np.floor(n / 2 + 1) - m
            cdf_value = (n - m - s) / (n - m)
            # NOTE: for large m, cdf_value reaches or exceeds 1, so norm.ppf
            # returns inf or nan. With m == n the division raises
            # ZeroDivisionError. Pass an explicit `z` in those regimes.
            self.z_max = norm.ppf(cdf_value)
        self.n_good = n - m
        self._gradient = None

    def __str__(self) -> str:
        return "ALittleIsEnoughWorker"

    def get_gradient(self):
        """Return the crafted gradient (``None`` until the first callback)."""
        return self._gradient

    def omniscient_callback(self):
        """Craft the malicious gradient from the honest workers' gradients.

        Raises:
            RuntimeError: if the simulator has no honest workers to observe.
        """
        gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        if not gradients:
            raise RuntimeError(
                "ALittleIsEnoughAttack requires at least one honest worker"
            )

        # Stack along a new dim 1 and reduce over it, so mu/std have the
        # shape of a single honest gradient.
        stacked_gradients = torch.stack(gradients, 1)
        mu = torch.mean(stacked_gradients, 1)
        # torch.std is unbiased (Bessel-corrected) by default, so a single
        # honest worker yields NaN here.
        std = torch.std(stacked_gradients, 1)

        self._gradient = mu - std * self.z_max

    def get_update(self, server_iterate):
        """Build the result dict sent to the server for this round.

        Loss, length and every metric in ``self.metrics`` are reported as 0.
        ``local_iterate`` holds the crafted gradient. Before any gradient has
        been crafted, it holds the server's iterate instead.
        """
        results = {}
        results['loss'] = 0
        results['length'] = 0
        results['metrics'] = {name: 0 for name, _ in self.metrics.items()}

        update = self.get_gradient()
        if update is None:
            # First round: honest workers haven't sent updates yet, so echo
            # the server's iterate. Vectorize it when possible (for
            # compatibility with FLTrust); otherwise pass it through as-is.
            # Catch Exception (not a bare except) so KeyboardInterrupt and
            # SystemExit still propagate.
            try:
                update = get_vectorized_parameters(server_iterate)
            except Exception:
                update = server_iterate

        results['local_iterate'] = update
        return results

    def set_gradient(self, gradient) -> None:
        """Not supported by this attacker."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Not supported by this attacker."""
        raise NotImplementedError


