from .RobustCommunication import *


class ByzantineAttack:
    """Base class for Byzantine attacks on gossip-style communication.

    Subclasses implement :meth:`declare_byz_parameters`, which returns the
    parameters declared by the Byzantine nodes. They typically rely on
    :meth:`attack_stealth` to rescale a raw attack direction according to the
    robust aggregation scheme used by ``communication``.
    """

    name = "Byzantine"

    # Factor applied to every raw attack direction before stealth rescaling.
    attack_scale = 1e4
    # Under LocalTrimmedGossip, the Byzantine offset norm is set to this fraction
    # of the per-node trimming threshold, i.e. just below it.
    trim_margin = 0.999

    def __init__(self, topology, communication):
        self.topology = topology
        self.communication = communication

    def attack_stealth(self, x_h_attack, x_honest):
        """Turn raw attack offsets into the values declared to honest nodes.

        All Byzantine neighbours of honest node ``i`` declare the same value,
        ``x_honest[i, :] + offset[i, :]``. Here ``offset`` is ``x_h_attack``
        multiplied by :attr:`attack_scale` and then adapted to the aggregation
        scheme so that it bypasses the defence:

        * ``LocalTrimmedGossip``: each row is rescaled to norm ``tau[i]`` (see
          :meth:`_trimming_thresholds`); rows of zero norm give a zero offset.
        * ``ClippedGossip``: no further rescaling.

        :param x_h_attack: (nb_honest, d) ndarray of raw attack offsets.
        :param x_honest: (nb_honest, d) ndarray of honest parameters.
        :return: (nb_honest, d) ndarray of declared parameters.
        :raises NotImplementedError: if ``communication`` is of any other type.
        """
        # Scale the attack to maximize its efficiency; the effective magnitude
        # is then adapted to the aggregation scheme below.
        x_h_attack = x_h_attack * self.attack_scale

        if isinstance(self.communication, LocalTrimmedGossip):
            tau = self._trimming_thresholds(x_honest)
            norm = np.linalg.norm(x_h_attack, axis=1)[:, np.newaxis]
            # Without `out`, np.divide leaves the entries where `where` is False
            # uninitialized; zero-norm rows must yield a zero offset instead.
            unit = np.divide(
                x_h_attack, norm, out=np.zeros_like(x_h_attack), where=norm > 0
            )
            return x_honest + unit * tau[:, np.newaxis]

        if isinstance(self.communication, ClippedGossip):
            # Under ClippedGossip aggregation no filtering is done, so the scaled
            # offset is declared as is.
            return x_honest + x_h_attack

        # Previously this fell through and returned None, which only failed later
        # with an unrelated matmul error in declare_byz_parameters.
        raise NotImplementedError(
            f"attack_stealth is not implemented for "
            f"{type(self.communication).__name__}"
        )

    def _trimming_thresholds(self, x_honest):
        """Compute, for each honest node, a norm just below its trimming threshold.

        For honest node ``i``, ``N_b(i)`` is the sum of its entries in the
        Byzantine columns (index >= nb_honest**2) of the adjacency matrix.
        Let ``k = local_ratio * N_b(i)``. The threshold is the ``(k + 1)``-th
        largest norm among the rows of ``C_h.T @ x_honest`` selected by node
        ``i``'s honest columns (index < nb_honest**2), times
        :attr:`trim_margin`. It is 0 when node ``i`` has at most ``k`` such rows.
        NOTE(review): presumably LocalTrimmedGossip trims the ``k`` largest
        messages, so an offset of this norm survives while the ``k`` largest
        honest messages are trimmed. Confirm against RobustCommunication.

        :param x_honest: (nb_honest, d) ndarray (a 1-D array is treated as d = 1).
        :return: (nb_honest,) ndarray of thresholds.
        """
        nb_honest = self.topology.nb_honest
        nb_honest_edges = nb_honest ** 2
        B = self.topology.adjacency_matrix
        C_h = self.topology.interaction_matrix_honest

        lmbd_h = C_h.T @ x_honest
        if lmbd_h.ndim == 1:
            lmbd_h = lmbd_h[:, np.newaxis]

        tau = np.ones((nb_honest,))
        for i_honest in range(nb_honest):
            # np.bool8 was removed in NumPy 2.0; astype(bool) is the portable form.
            edges_i_honest = B[i_honest, :nb_honest_edges].astype(bool)
            nb_byz_neighbors_i = int(np.sum(B[i_honest, nb_honest_edges:]))

            sorted_norms = np.sort(np.linalg.norm(lmbd_h[edges_i_honest, :], axis=1))
            nb_trimmed = self.communication.local_ratio * nb_byz_neighbors_i
            if nb_trimmed + 1 > sorted_norms.shape[0]:
                tau[i_honest] = 0
            else:
                tau[i_honest] = sorted_norms[-nb_trimmed - 1] * self.trim_margin
        return tau

    def declare_byz_parameters(self, x_honest):
        """Return the parameters declared by the Byzantine nodes.

        The base class declares nothing and returns None; subclasses override this.
        """
        pass


class SpectralDissensusAttack(ByzantineAttack):
    """Dissensus attack along one eigenvector of the honest Laplacian.

    The eigenvector at index 1 of ``np.linalg.eigh(topology.laplacian_honest())``
    is chosen by default. If that eigenvector has a (numerically) zero entry and
    the next eigenvalue is (numerically) equal, the search moves on through the
    degenerate eigenspace.
    """

    name = "Spectral Dissensus"

    def __init__(self, topology, communication):
        super().__init__(topology, communication)

        eigvals, eigvecs = np.linalg.eigh(self.topology.laplacian_honest())
        last_index = eigvals.shape[0] - 1
        tol = 1e-10

        # A zero entry v_i makes outer(v, v) @ x vanish at honest node i, so keep
        # looking while the current eigenvalue is repeated.
        index = 1
        while (
            index < last_index
            and (np.abs(eigvecs[:, index]) < tol).any()
            and (eigvals[index + 1] - eigvals[index]) < tol
        ):
            index += 1
        self.eigenvector_attack = eigvecs[:, index]

    def declare_byz_parameters(self, x_honest):
        """Declare the honest parameters projected onto the attack eigenvector."""
        v = self.eigenvector_attack
        projected = np.outer(v, v) @ x_honest
        declared = self.attack_stealth(projected, x_honest)

        nb_h = self.topology.nb_honest
        byz_columns = slice(nb_h ** 2, None)
        # Valid only when there are as many Byzantine units as edges linking
        # Byzantine to honest nodes (w.l.o.g., since Byzantines can send
        # different messages to their neighbors).
        C_byz = self.topology.interaction_matrix[nb_h:, byz_columns]
        B_byz = self.topology.adjacency_matrix[:nb_h, byz_columns]
        return -C_byz @ B_byz.T @ declared

#### Not used
class TimeSpecificDissensus(ByzantineAttack):
    """Dissensus attack tuned for a fixed time horizon.

    The raw attack direction is ``L_h (I - eta L_h)^(2 T) x_honest``, where
    ``L_h`` is ``topology.laplacian_honest()``, ``eta`` is
    ``communication.step_size`` and ``T`` is ``time_horizon``.
    """

    name = "Time Specific Dissensus"

    def __init__(self, topology, communication, time_horizon=10):
        super().__init__(topology, communication)
        self.time_horizon = time_horizon

    def declare_byz_parameters(self, x_honest):
        C = self.topology.interaction_matrix
        B = self.topology.adjacency_matrix
        nb_honest = self.topology.nb_honest

        lap_h = self.topology.laplacian_honest()
        # np.eye expects an integer size; passing the shape tuple raised TypeError.
        id_h = np.eye(lap_h.shape[0])

        propagator = np.linalg.matrix_power(
            id_h - self.communication.step_size * lap_h, 2 * self.time_horizon
        )
        x_h_attack = self.attack_stealth(lap_h @ propagator @ x_honest, x_honest)

        # Works only if the number of byzantine units = number of edges linking
        # byzantine to honest nodes (w.l.o.g. assumption, as Byzantines can send
        # different messages to their neighbors).
        x_byz = (-C[nb_honest:, nb_honest ** 2:] @
                 B[:nb_honest, nb_honest ** 2:].T @ x_h_attack)

        return x_byz


class PlainDissensusAttack(ByzantineAttack):
    """Dissensus attack, as described in [He and Karimireddy 2022].

    The raw attack direction is the honest Laplacian applied to the honest
    parameters, ``laplacian_honest() @ x_honest``.
    """

    name = "Plain Dissensus"

    def declare_byz_parameters(self, x_honest):
        laplacian = self.topology.laplacian_honest()
        declared = self.attack_stealth(laplacian @ x_honest, x_honest)

        nb_h = self.topology.nb_honest
        first_byz_edge = nb_h ** 2
        # Valid only when there are as many Byzantine units as edges linking
        # Byzantine to honest nodes (w.l.o.g., since Byzantines can send
        # different messages to their neighbors).
        C_byz = self.topology.interaction_matrix[nb_h:, first_byz_edge:]
        B_byz = self.topology.adjacency_matrix[:nb_h, first_byz_edge:]
        return -C_byz @ B_byz.T @ declared

##### Not used
class ConsensusAttack(ByzantineAttack):
    """All Byzantine nodes declare one fixed, heavily scaled direction.

    The direction must be provided through :meth:`set_attack_direction` before
    :meth:`declare_byz_parameters` is called.
    """

    name = "Consensus"

    def __init__(self, topology, communication):
        super().__init__(topology, communication)
        # None until set_attack_direction() is called; checked in
        # declare_byz_parameters to give a clear error instead of an opaque one.
        self.attack_direction = None

    def set_attack_direction(self, direction):
        """Store ``direction`` (a 1-D array-like of length d) scaled by 1e16."""
        self.attack_direction = np.asarray(direction) * 1e16

    def declare_byz_parameters(self, x_honest):
        if self.attack_direction is None:
            raise ValueError(
                "ConsensusAttack: call set_attack_direction() before "
                "declare_byz_parameters()"
            )
        C = self.topology.interaction_matrix
        B = self.topology.adjacency_matrix

        # Same vector for every honest node: shape (nb_honest, d).
        # NOTE(review): unlike ALIE/FOE, this is not made relative to x_honest,
        # although attack_stealth treats it as an offset. Confirm this is intended.
        declared_position_byz = np.tile(self.attack_direction, (x_honest.shape[0], 1))

        x_h_attack = self.attack_stealth(declared_position_byz, x_honest)

        # Works only if the number of byzantine units = number of edges linking
        # byzantine to honest nodes (w.l.o.g. assumption, as Byzantines can send
        # different messages to their neighbors).
        x_byz = (- C[self.topology.nb_honest:, self.topology.nb_honest ** 2:] @
                 B[:self.topology.nb_honest, self.topology.nb_honest ** 2:].T @ x_h_attack)

        return x_byz

class ALIE(ByzantineAttack):
    """ALIE ("A Little Is Enough") attack with a one-standard-deviation shift.

    Byzantine offsets point from each honest parameter towards the coordinate-wise
    mean plus the coordinate-wise standard deviation of the honest parameters.
    """

    name = "ALIE"

    def declare_byz_parameters(self, x_honest):
        target = np.mean(x_honest, axis=0) + np.std(x_honest, axis=0)
        # Broadcast the common target against every honest row: (nb_honest, d).
        offsets = target[np.newaxis, :] - x_honest
        declared = self.attack_stealth(offsets, x_honest)

        nb_h = self.topology.nb_honest
        first_byz_edge = nb_h ** 2
        # Valid only when there are as many Byzantine units as edges linking
        # Byzantine to honest nodes (w.l.o.g., since Byzantines can send
        # different messages to their neighbors).
        C_byz = self.topology.interaction_matrix[nb_h:, first_byz_edge:]
        B_byz = self.topology.adjacency_matrix[:nb_h, first_byz_edge:]
        return -C_byz @ B_byz.T @ declared

class FOE(ByzantineAttack):
    """FOE attack (presumably "Fall of Empires").

    Byzantine offsets point from each honest parameter towards the negated
    coordinate-wise mean of the honest parameters.
    """

    # Must differ from ALIE.name (it was a copy-paste "ALIE") so the two
    # attacks can be told apart wherever ``name`` is used as a label.
    name = "FOE"

    def __init__(self, topology, communication):
        super().__init__(topology, communication)

    def declare_byz_parameters(self, x_honest):
        C = self.topology.interaction_matrix
        B = self.topology.adjacency_matrix
        nb_honest = self.topology.nb_honest

        attack_direction = -np.mean(x_honest, axis=0)

        # Offset from each honest parameter to the common target: (nb_honest, d).
        declared_position_byz = attack_direction[np.newaxis, :] - x_honest

        x_h_attack = self.attack_stealth(declared_position_byz, x_honest)

        # Works only if the number of byzantine units = number of edges linking
        # byzantine to honest nodes (w.l.o.g. assumption, as Byzantines can send
        # different messages to their neighbors).
        x_byz = (-C[nb_honest:, nb_honest ** 2:] @
                 B[:nb_honest, nb_honest ** 2:].T @ x_h_attack)

        return x_byz


#### Baseline - not used
class NoAttackAttack(ByzantineAttack):
    """Baseline: each Byzantine node declares its honest neighbour's own parameter.

    Any degradation therefore comes only from side effects of the robust
    aggregation scheme.
    """

    name = "No attack attack"

    def declare_byz_parameters(self, x_honest):
        nb_h = self.topology.nb_honest
        first_byz_edge = nb_h ** 2
        # Valid only when there are as many Byzantine units as edges linking
        # Byzantine to honest nodes (w.l.o.g., since Byzantines can send
        # different messages to their neighbors).
        C_byz = self.topology.interaction_matrix[nb_h:, first_byz_edge:]
        B_byz = self.topology.adjacency_matrix[:nb_h, first_byz_edge:]
        return -C_byz @ B_byz.T @ x_honest
