import numpy as np


class Communication:
    """
    Base class for one step of communication over a topology.

    Concrete communication rules subclass it and override ``aggregate``.
    """

    def __init__(self, topology, step_size):
        """
        :param topology: object describing the communication graph (stored as is)
        :param step_size: step size of the communication step (stored as is)
        """
        self.topology, self.step_size = topology, step_size

    def aggregate(self, x):
        """
        Perform one communication step; must be overridden by subclasses.

        :param x: (n,d) numpy array: input parameters of each node at time t
        :return: parameters after (robust) aggregation
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()


class ClippedGossip(Communication):
    """
    Base class of the gossip rules that clip the pairwise differences.

    It only provides the shared ``clip`` helper; subclasses choose the
    thresholds and implement ``aggregate``.
    """

    def clip(self, lmbd, tau):
        """
        node-wise clipping using thresholds tau

        Every row of ``lmbd`` whose Euclidean norm is strictly larger than
        the threshold of its edge is rescaled so that its norm equals that
        threshold; the other rows are returned unchanged. The per-edge
        thresholds are ``B[:nb_honest, :].T @ tau`` where ``B`` is the
        adjacency matrix of the topology.

        :param lmbd: (nb_honest**2 + nb_byzantine**2, d): pair wise differences
        :param tau: (nb_honest,) array: local clipping thresholds
        :return: (nb_honest**2 + nb_byzantine**2, d) clipped lmbd
        """
        B = self.topology.adjacency_matrix

        # Column vectors, so that they broadcast against the (edges, d) rows.
        norm = np.linalg.norm(lmbd, axis=1)[:, np.newaxis]
        # NOTE(review): presumably each edge column of B[:nb_honest] has a
        # single non-zero entry, so every edge inherits the threshold of one
        # honest node -- confirm against the topology implementation.
        tau_edges_wise = (B[:self.topology.nb_honest, :].T @ tau)[:, np.newaxis]
        index_clipping = norm > tau_edges_wise

        # The buffer is always floating point: with `np.zeros_like(lmbd)` an
        # integer-valued `lmbd` made `np.divide` fail on the output cast.
        unit_lmbd = np.zeros(lmbd.shape,
                             dtype=np.result_type(lmbd.dtype, np.float64))
        # Only the rows that are clipped are normalised, so rows with a zero
        # norm (never above a non-negative threshold) are never divided.
        np.divide(lmbd, norm, out=unit_lmbd, where=index_clipping)

        # Select rather than mask by multiplication: `0 * inf` would turn an
        # unclipped row into nan when a threshold is infinite.
        clipped_lmbd = np.where(index_clipping,
                                unit_lmbd * tau_edges_wise,
                                lmbd)

        if np.isnan(clipped_lmbd).any():
            print("WARNING : Clipping produce nan values")
        return clipped_lmbd


class LocalClippedGossip(ClippedGossip):
    """
    Clipping local by clipping k times the number of Byzantine neighbors
    """
    name = 'LocalClipping'

    def __init__(self, topology, step_size, local_ratio=2):
        """
        :param topology: object describing the communication graph
        :param step_size: step size of the communication step
        :param local_ratio: each honest node clip local_ratio times
        the number of Byzantine neighbors
        """
        super().__init__(topology, step_size)
        self.local_ratio = local_ratio

    def aggregate(self, x):
        """
        One communication step with node-wise clipping thresholds.

        :param x: (n,d) numpy array: input parameters of each node at time t;
            a 1-D array of length n is also accepted and returned as 1-D
        :return: ``x - step_size * B @ clip(C.T @ x, tau)``, same shape as x
        """
        B = self.topology.adjacency_matrix
        C = self.topology.interaction_matrix
        nb_honest = self.topology.nb_honest

        lmbd = C.T @ x
        # A 1-D x yields 1-D differences; work on a column and undo it at the
        # end, otherwise `x - update` broadcasts (n,) against (n, 1) to (n, n).
        flat_input = lmbd.ndim == 1
        if flat_input:
            lmbd = lmbd[:, np.newaxis]

        # Computing node-wise clipping thresholds
        tau = np.ones((nb_honest,))
        for i_honest in range(nb_honest):
            # Boolean mask of the edges of node i (`np.bool8` no longer
            # exists in NumPy >= 2.0, hence the explicit cast).
            edges_i = B[i_honest, :].astype(bool)
            # Edge columns from index nb_honest ** 2 onwards are counted as
            # Byzantine edges.
            nb_byz_neighbors_i = int(np.sum(B[i_honest, nb_honest ** 2:]))
            # Norms of the differences on the edges of node i, ascending.
            norms_i = np.sort(np.linalg.norm(lmbd[edges_i, :], axis=1))
            nb_edges_i = norms_i.shape[0]

            if self.local_ratio == 1:
                # case of [He and Karimirredy 2022]
                nb_dropped = nb_byz_neighbors_i + 1
                if nb_dropped >= nb_edges_i:
                    tau[i_honest] = 0
                else:
                    # Discard the nb_dropped largest norms.
                    kept_norms = norms_i[:-nb_dropped]
                    tau[i_honest] = np.sqrt(
                        self.step_size * np.sum(kept_norms ** 2))
            else:
                nb_clipped = self.local_ratio * nb_byz_neighbors_i
                if nb_clipped >= nb_edges_i:
                    tau[i_honest] = 0
                else:
                    # Largest norm left once the nb_clipped largest ones are
                    # set aside: exactly nb_clipped edges exceed it (barring
                    # ties).
                    tau[i_honest] = norms_i[-nb_clipped - 1]

        update = self.step_size * B @ self.clip(lmbd, tau)
        if flat_input:
            update = update[:, 0]
        return x - update



class GlobalClipping(ClippedGossip):
    """
    Implementation of the global clipping rule.

    One threshold, shared by all honest nodes, is selected at every call from
    the sorted norms of the pairwise differences. The instance is stateful:
    the thresholds selected so far are summed in
    ``accumulated_clipping_threshold`` and enter the next selection.
    """

    def __init__(self, topology, step_size):
        """
        :param topology: object describing the communication graph
        :param step_size: step size of the communication step
        """
        super().__init__(topology, step_size)

        # Running sum of the thresholds selected by the previous calls.
        self.accumulated_clipping_threshold = 0

    def aggregate(self, x):
        """
        One communication step with a single, global, clipping threshold.

        Without Byzantine nodes (``topology.nb_byz < 1``) nothing is clipped
        and the accumulated threshold is left untouched.

        :param x: (n,d) numpy array: input parameters of each node at time t;
            a 1-D array of length n is also accepted and returned as 1-D
        :return: parameters after the (clipped) gossip step, same shape as x
        """
        B = self.topology.adjacency_matrix
        C = self.topology.interaction_matrix
        nb_byz = self.topology.nb_byz

        lmbd = C.T @ x
        # A 1-D x yields 1-D differences; work on a column and undo it at the
        # end, otherwise `x - update` broadcasts (n,) against (n, 1) to (n, n).
        flat_input = lmbd.ndim == 1
        if flat_input:
            lmbd = lmbd[:, np.newaxis]

        if nb_byz < 1:
            update = self.step_size * B @ lmbd
        else:
            # Computing global clipping threshold using global Clipping Rule

            # Ascending norms without the 2 * nb_byz largest ones.
            # NOTE(review): when 2 * nb_byz >= number of edges this array is
            # empty and the indexing with -kappa below raises IndexError.
            norms_lmbd_h_sorted = np.sort(
                np.linalg.norm(lmbd, axis=1))[:-nb_byz * 2]
            # norms_sf[k] is the sum of the norms strictly after position k.
            norms_sf = (np.sum(norms_lmbd_h_sorted)
                        - np.cumsum(norms_lmbd_h_sorted))

            # The *2 term comes from the fact that we express each edge two
            # times using matrix C
            byz_factor = (2 * self.step_size * nb_byz ** 2
                          / self.topology.nb_honest)

            nb_kept = norms_lmbd_h_sorted.shape[0]
            kappa = nb_kept
            # Scan the retained norms from the largest one, two positions at
            # a time, and stop at the first one satisfying the rule.
            for i in range(1, nb_kept, 2):
                candidate = norms_lmbd_h_sorted[-i]
                # todo: dimension > 1
                bound = (norms_sf[0] * self.topology.delta_infinite()
                         + byz_factor
                         * (self.accumulated_clipping_threshold + candidate))
                if norms_sf[-i] >= bound:
                    kappa = i
                    break
            tau = norms_lmbd_h_sorted[-kappa]
            self.accumulated_clipping_threshold += tau

            # Same threshold for every honest node.
            tau = np.ones((self.topology.nb_honest,)) * tau
            update = self.step_size * B @ self.clip(lmbd, tau)

        if flat_input:
            update = update[:, 0]
        return x - update



class SimplifiedGlobalClipping(ClippedGossip):
    """
    Implementation of the simplified global clipping rule.

    One threshold, shared by all honest nodes, is selected at every call from
    the sorted norms of the pairwise differences. The instance is stateful:
    the thresholds selected so far are summed in
    ``accumulated_clipping_threshold`` and enter the next selection.
    """

    def __init__(self, topology, step_size):
        """
        :param topology: object describing the communication graph
        :param step_size: step size of the communication step
        """
        super().__init__(topology, step_size)

        # Running sum of the thresholds selected by the previous calls.
        self.accumulated_clipping_threshold = 0

    def aggregate(self, x):
        """
        One communication step with a single, global, clipping threshold.

        Without Byzantine nodes (``topology.nb_byz < 1``) nothing is clipped
        and the accumulated threshold is left untouched.

        :param x: (n,d) numpy array: input parameters of each node at time t;
            a 1-D array of length n is also accepted and returned as 1-D
        :return: parameters after the (clipped) gossip step, same shape as x
        """
        B = self.topology.adjacency_matrix
        C = self.topology.interaction_matrix
        nb_byz = self.topology.nb_byz

        lmbd = C.T @ x
        # A 1-D x yields 1-D differences; work on a column and undo it at the
        # end, otherwise `x - update` broadcasts (n,) against (n, 1) to (n, n).
        flat_input = lmbd.ndim == 1
        if flat_input:
            lmbd = lmbd[:, np.newaxis]

        if nb_byz < 1:
            # With nb_byz == 0 the slice `[:-nb_byz * 2]` below is `[:0]`,
            # i.e. empty, and picking a threshold from it raises IndexError.
            # There is nothing to defend against, so do a plain gossip step.
            update = self.step_size * B @ lmbd
        else:
            # Computing global clipping threshold using global Clipping Rule

            # Ascending norms without the 2 * nb_byz largest ones.
            # NOTE(review): when 2 * nb_byz >= number of edges this array is
            # empty and the indexing with -kappa below raises IndexError.
            norms_lmbd_h_sorted = np.sort(
                np.linalg.norm(lmbd, axis=1))[:-nb_byz * 2]
            # norms_sf[k] is the sum of the norms strictly after position k.
            norms_sf = (np.sum(norms_lmbd_h_sorted)
                        - np.cumsum(norms_lmbd_h_sorted))

            # The *2 term comes from the fact that we express each edge two
            # times using matrix C
            byz_factor = (2 * self.step_size * nb_byz ** 2
                          / self.topology.nb_honest)

            nb_kept = norms_lmbd_h_sorted.shape[0]
            kappa = nb_kept
            # Scan the retained norms from the largest one, two positions at
            # a time, and stop at the first one satisfying the rule.
            for i in range(1, nb_kept, 2):
                candidate = norms_lmbd_h_sorted[-i]
                bound = (norms_sf[0] * self.topology.delta_infinite()
                         * candidate
                         + byz_factor
                         * (self.accumulated_clipping_threshold + candidate))
                if i // 2 * candidate >= bound:
                    kappa = i
                    break
            tau = norms_lmbd_h_sorted[-kappa]
            self.accumulated_clipping_threshold += tau

            # Same threshold for every honest node.
            tau = np.ones((self.topology.nb_honest,)) * tau
            update = self.step_size * B @ self.clip(lmbd, tau)

        if flat_input:
            update = update[:, 0]
        return x - update


class LocalTrimmedGossip(Communication):
    """
        Trimming local by removing k times the number of Byzantine neighbors
    """

    name = 'LocalTrimmedGossip'

    def __init__(self, topology, step_size, local_ratio=2):
        """
        :param topology: object describing the communication graph
        :param step_size: step size of the communication step
        :param local_ratio: each honest node trim local_ratio times
        the number of Byzantine neighbors
        """
        super().__init__(topology, step_size)
        self.local_ratio = local_ratio

    def trim(self, lmbd, tau):
        """
        node-wise trimming using thresholds tau

        Every row of ``lmbd`` whose Euclidean norm is strictly larger than
        the threshold of its edge is replaced by zeros; the other rows are
        returned unchanged. The per-edge thresholds are
        ``B[:nb_honest, :].T @ tau`` where ``B`` is the adjacency matrix of
        the topology.

        :param lmbd: (nb_honest**2 + nb_byzantine**2, d): pair wise differences
        :param tau: (nb_honest,) array: local trimming thresholds
        :return: (nb_honest**2 + nb_byzantine**2, d) trimmed lmbd
        """
        B = self.topology.adjacency_matrix

        # Column vectors, so that they broadcast against the (edges, d) rows.
        norm = np.linalg.norm(lmbd, axis=1)[:, np.newaxis]
        tau_edges_wise = (B[:self.topology.nb_honest, :].T @ tau)[:, np.newaxis]
        index_trimming = norm > tau_edges_wise

        # Select rather than multiply by the mask: `False * inf` is nan, so a
        # trimmed row holding an infinite entry would not be removed.
        trimmed_lmbd = np.where(index_trimming, 0, lmbd)

        if np.isnan(trimmed_lmbd).any():
            print("WARNING: Trimming produce nan values")
        return trimmed_lmbd

    def aggregate(self, x):
        """
        One communication step with node-wise trimming thresholds.

        :param x: (n,d) numpy array: input parameters of each node at time t;
            a 1-D array of length n is also accepted and returned as 1-D
        :return: ``x - step_size * B @ trim(C.T @ x, tau)``, same shape as x
        """
        B = self.topology.adjacency_matrix
        C = self.topology.interaction_matrix
        nb_honest = self.topology.nb_honest

        lmbd = C.T @ x
        # A 1-D x yields 1-D differences; work on a column and undo it at the
        # end, otherwise `x - update` broadcasts (n,) against (n, 1) to (n, n).
        flat_input = lmbd.ndim == 1
        if flat_input:
            lmbd = lmbd[:, np.newaxis]

        # Computing node-wise trimming
        tau = np.ones((nb_honest,))
        for i_honest in range(nb_honest):
            # Boolean mask of the edges of node i (`np.bool8` no longer
            # exists in NumPy >= 2.0, hence the explicit cast).
            edges_i = B[i_honest, :].astype(bool)
            # Edge columns from index nb_honest ** 2 onwards are counted as
            # Byzantine edges.
            nb_byz_neighbors_i = int(np.sum(B[i_honest, nb_honest ** 2:]))
            # Norms of the differences on the edges of node i, ascending.
            norms_i = np.sort(np.linalg.norm(lmbd[edges_i, :], axis=1))

            nb_trimmed = self.local_ratio * nb_byz_neighbors_i
            if nb_trimmed >= norms_i.shape[0]:
                tau[i_honest] = 0
            else:
                # Largest norm left once the nb_trimmed largest ones are set
                # aside: exactly nb_trimmed edges exceed it (barring ties).
                tau[i_honest] = norms_i[-nb_trimmed - 1]

        update = self.step_size * B @ self.trim(lmbd, tau)
        if flat_input:
            update = update[:, 0]
        return x - update
