
import numpy as np


from .RobustCommunication import *
from .Topology import *
from .ByzantineAttack import *
from .Optimization import *
from .identifiers import *


class Simulation(object):
    """
    Base class for robust decentralized simulations with Byzantine nodes.

    The constructor selects the robust communication scheme and the Byzantine
    attack from their identifiers; subclasses implement ``run``.
    """

    def __init__(
            self, topology, communication_rule, attack: AttackType,
            step_size_communication='Auto', nb_iterations=100, save_trajectories=False,
            optimizing_factor=5, step_size='Auto', wandb_run=None, log_every_n_step=10
    ):
        """
        :param topology: graph object; this class calls its ``laplacian_honest()``
            method and reads its ``nb_honest`` attribute
        :param communication_rule: a ``CommunicationRule`` member selecting the
            robust aggregation scheme
        :param attack: an ``AttackType`` member selecting the Byzantine attack
        :param step_size_communication: communication step size, or ``'Auto'``
            to use ``1 / (largest eigenvalue of the honest Laplacian + 1e-8)``
        :param nb_iterations: number of iterations, stored for ``run``
        :param save_trajectories: stored as ``self.save_trajectories``
        :param optimizing_factor: stored as ``self.optimizing_factor``
        :param step_size: deprecated alias; when not ``'Auto'`` it overrides
            ``step_size_communication``
        :param wandb_run: optional logger, stored as ``self.wandb_run``
        :param log_every_n_step: stored as ``self.log_every_n_step``
        :raises NotImplementedError: for ``CommunicationRule.LOCAL_CLIPPING_HE``
        :raises ValueError: for an unsupported communication rule or attack
        """
        if step_size != 'Auto':
            # Deprecated alias kept for backward compatibility.
            step_size_communication = step_size
            print("Warning: step_size is deprecated")

        self.topology = topology
        self._direction = None
        self.save_trajectories = save_trajectories
        self.optimizing_factor = optimizing_factor
        self.log_every_n_step = log_every_n_step
        self.wandb_run = wandb_run

        if step_size_communication == 'Auto':
            # Only the spectrum is needed, so eigvalsh avoids computing the
            # eigenvectors. Eigenvalues are returned in ascending order, so
            # [-1] is the largest one; 1e-8 guards against division by zero.
            eigenvalues = np.linalg.eigvalsh(self.topology.laplacian_honest())
            self.step_size_communication = 1 / (eigenvalues[-1] + 1e-8)
        else:
            self.step_size_communication = step_size_communication

        self.nb_iterations = nb_iterations
        self.data = None

        # initialization of the communication scheme

        if communication_rule is CommunicationRule.GLOBAL_CLIPPING:
            self.robust_communication = GlobalClipping(self.topology, self.step_size_communication)
        elif communication_rule is CommunicationRule.GLOBAL_CLIPPING_APPROX:
            self.robust_communication = SimplifiedGlobalClipping(self.topology, self.step_size_communication)
        elif communication_rule is CommunicationRule.LOCAL_CLIPPING_HE:
            # Disabled on purpose. Once the clipping value is fixed, the intended
            # construction is:
            #   LocalClippedGossip(self.topology, self.step_size_communication, local_ratio=1)
            raise NotImplementedError("He et al. Clipping Value currently miss-defined")
        elif communication_rule is CommunicationRule.LOCAL_CLIPPING_OURS:
            self.robust_communication = LocalClippedGossip(
                self.topology, self.step_size_communication, local_ratio=2)
        elif communication_rule is CommunicationRule.LOCAL_TRIMMING:
            self.robust_communication = LocalTrimmedGossip(self.topology, self.step_size_communication)
        elif communication_rule is CommunicationRule.LOCAL_CLIPPING_SYM:
            self.robust_communication = LocalClippedGossipSymetric(
                self.topology, self.step_size_communication, local_ratio=2)
        else:
            raise ValueError("Communication scheme not implemented ")

        # initialization of the attack type

        if attack is AttackType.DISSENSION_HE:
            self.byzantine_attack = PlainDissensusAttack(self.topology, self.robust_communication)
        elif attack is AttackType.DISSENSION_SPECTRAL:
            self.byzantine_attack = SpectralDissensusAttack(self.topology, self.robust_communication)
        elif attack is AttackType.NO_ATTACK_ATTACK:
            self.byzantine_attack = NoAttackAttack(self.topology, self.robust_communication)
        elif attack is AttackType.ALIE:
            self.byzantine_attack = ALIE(self.topology, self.robust_communication)
        elif attack is AttackType.FOE:
            self.byzantine_attack = FOE(self.topology, self.robust_communication)
        else:
            raise ValueError("Attack not implemented")

    def communicate(self, x_honest):
        """
        Perform one robust communication round in the presence of Byzantine nodes.

        The Byzantine rows produced by the attack are stacked after the honest
        rows, the whole array is aggregated, and only the honest rows are kept.

        :param x_honest: (nb_honest, d) parameters of nodes before communication
        :return: (nb_honest, d) parameters of honest nodes after communication
        """
        x_byzantine = self.byzantine_attack.declare_byz_parameters(x_honest)
        x = np.concatenate([x_honest, x_byzantine], axis=0)
        x_new = self.robust_communication.aggregate(x)
        return x_new[:self.topology.nb_honest, :]

    def run(self):
        """Run the simulation; must be implemented by subclasses."""
        raise NotImplementedError


class RobustAvgConsensusSimulation(Simulation):
    """
    Robust average-consensus simulation: honest nodes only run communication
    rounds (no local optimization step) while Byzantine nodes attack.
    """

    def __init__(
            self, topology, communication_rule, attack: AttackType,
            step_size_communication='Auto', nb_iterations=100, direction=None,
            save_trajectories=False, optimizing_factor=1, wandb_run=None, log_every_n_step=1
    ):
        """
        Build the simulation; see ``Simulation.__init__`` for the shared parameters.

        :param direction: stored as ``self._direction``
        """
        # ``Simulation.__init__`` has no ``direction`` parameter: forwarding it
        # raised a TypeError. Store it after the base constructor instead
        # (the base constructor initializes ``_direction`` to None).
        super().__init__(
            topology=topology, communication_rule=communication_rule, attack=attack,
            step_size_communication=step_size_communication,
            nb_iterations=nb_iterations, save_trajectories=save_trajectories,
            optimizing_factor=optimizing_factor, wandb_run=wandb_run,
            log_every_n_step=log_every_n_step
        )
        self._direction = direction

    def run(self, x_honest_init):
        """
        deprecated: run only the average consensus problem

        :param x_honest_init: (nb_honest, d) initial parameters of the honest
            nodes; a 1-D input of length nb_honest is treated as d = 1
        :return: dict with key ``"X"``: the (nb_honest, d, nb_iterations + 1)
            trajectory, whose slice 0 is the initial point
        """
        x_honest = np.asarray(x_honest_init)
        if x_honest.ndim == 1:
            x_honest = x_honest[:, np.newaxis]

        x_storage = np.empty((x_honest.shape[0], x_honest.shape[1], self.nb_iterations + 1))
        x_storage[:, :, 0] = x_honest

        for t in range(1, self.nb_iterations + 1):
            # Feed each iterate into the next round. Previously the result was
            # never fed back, so every round restarted from the initial point.
            x_honest = self.communicate(x_honest)
            x_storage[:, :, t] = x_honest

        self.data = {"X": x_storage}
        return self.data


class RobustDecentralizedOptimSimulation(Simulation):
    """
    Robust decentralized optimization: each iteration combines a local step of
    ``optimization_task`` with a robust communication round, in either the
    primal or the dual formulation selected by ``algorithm_duality``.
    """

    def __init__(
            self, topology, communication_rule, attack: AttackType,
            optimization_task: OptimizationTask, algorithm_duality: AlgorithmDuality,
            wandb_run=None, step_size_communication='Auto', nb_iterations=100,
            save_trajectories=False, optimizing_factor=0, log_every_n_step=1
    ):
        """
        Build the simulation; see ``Simulation.__init__`` for the shared parameters.

        :param optimization_task: object providing ``dim``, the metric methods
            (``mean_accuracy``, ``loss_train``, ``loss_test``, ``variance``) and
            the update methods (``gradient_descent_step``, ``dual_gradient``)
        :param algorithm_duality: ``AlgorithmDuality.PRIMAL`` or ``AlgorithmDuality.DUAL``
        """
        super().__init__(
            topology=topology, communication_rule=communication_rule, attack=attack,
            step_size_communication=step_size_communication,
            nb_iterations=nb_iterations, save_trajectories=save_trajectories,
            optimizing_factor=optimizing_factor, wandb_run=wandb_run, log_every_n_step=log_every_n_step
        )
        self.optimization_task = optimization_task
        self.algorithm_duality = algorithm_duality

    def _log_to_wandb(self, t, accuracy, loss_train, loss_test, variance):
        """Send the metrics of iteration ``t`` to ``wandb_run`` if one is set."""
        # A single helper keeps keys and scales identical for every logged
        # iteration. Iteration 0 used to log ``1 - accuracy`` under 'acc' and
        # 'var' instead of 'variance'.
        if self.wandb_run is not None:
            self.wandb_run.log(
                {"acc": accuracy, "loss_train": loss_train, 'loss_test': loss_test,
                 'variance': variance, "iter": t})

    def run(self):
        """
        Run ``nb_iterations`` iterations starting from zero parameters.

        Accuracy and losses are evaluated every ``log_every_n_step`` iterations
        and always at the last iteration. In between, the previous value is
        carried forward. The variance is evaluated at every iteration.

        :return: dict with per-iteration arrays of length ``nb_iterations + 1``
            (``"accuracy"``, ``"loss_train"``, ``'loss_test'``, ``'variance'``);
            ``'X'``: the (nb_honest, dim, nb_iterations + 1) trajectory if
            ``save_trajectories`` is set, else None; ``'out'``: the final
            honest parameters.
        :raises ValueError: if ``algorithm_duality`` is neither PRIMAL nor DUAL
        """
        task = self.optimization_task
        nb_honest = self.topology.nb_honest
        nb_points = self.nb_iterations + 1

        mean_accuracy = np.zeros(nb_points)
        loss_train = np.zeros(nb_points)
        variance = np.zeros(nb_points)
        loss_test = np.zeros(nb_points)

        x_storage = None
        if self.save_trajectories:
            x_storage = np.empty((nb_honest, task.dim, nb_points))

        ### Initializations
        if self.algorithm_duality is AlgorithmDuality.PRIMAL:
            x_honest = np.zeros((nb_honest, task.dim))
        elif self.algorithm_duality is AlgorithmDuality.DUAL:
            y_honest = np.zeros((nb_honest, task.dim))
            x_honest = task.dual_gradient(y_honest, y_honest)
        else:
            raise ValueError(f"{self.algorithm_duality} unknown duality")

        ### Metrics of the initial point
        mean_accuracy[0] = task.mean_accuracy(x_honest)
        loss_train[0] = task.loss_train(x_honest)
        loss_test[0] = task.loss_test(x_honest)
        variance[0] = task.variance(x_honest)

        if self.save_trajectories:
            x_storage[:, :, 0] = x_honest

        self._log_to_wandb(0, mean_accuracy[0], loss_train[0], loss_test[0], variance[0])

        for t in range(1, nb_points):
            ### Optimization step
            if self.algorithm_duality is AlgorithmDuality.PRIMAL:
                x_honest = task.gradient_descent_step(x_honest)
                x_honest = self.communicate(x_honest)
            elif self.algorithm_duality is AlgorithmDuality.DUAL:
                x_honest = task.dual_gradient(y_honest=y_honest, x_guess=x_honest)
                y_honest = y_honest + (self.communicate(x_honest) - x_honest)

            ### Metrics
            # The last iteration is always evaluated so the final entries of
            # the returned metrics describe 'out', not a stale earlier iterate.
            evaluate = t % self.log_every_n_step == 0 or t == self.nb_iterations
            if evaluate:
                mean_accuracy[t] = task.mean_accuracy(x_honest)
                loss_train[t] = task.loss_train(x_honest)
                loss_test[t] = task.loss_test(x_honest)
            else:
                mean_accuracy[t] = mean_accuracy[t - 1]
                loss_train[t] = loss_train[t - 1]
                loss_test[t] = loss_test[t - 1]

            variance[t] = task.variance(x_honest)  # more or less free

            if self.save_trajectories:
                x_storage[:, :, t] = x_honest

            if evaluate:
                self._log_to_wandb(t, mean_accuracy[t], loss_train[t], loss_test[t], variance[t])

        self.data = {"accuracy": mean_accuracy, "loss_train": loss_train, 'loss_test': loss_test,
                     'variance': variance, 'X': x_storage, 'out': x_honest}
        return self.data