from .experiment import Experiment

from aggregation.aggregators.avg import Average
from aggregation.aggregators.cclip import CClip
from aggregation.aggregators.krum import Krum
from aggregation.aggregators.rsa import RSA
from aggregation.aggregators.fltrust import FLTrust


from aggregation.simulators.worker import TorchWorker, WorkerWithMomentum, RSAWorker, FLTrustWorker
from aggregation.simulators.simulator import ParallelTrainer, RSATrainer, FLTrustTrainer


class AverageExperiment(Experiment):
    """Experiment that combines worker updates with the ``Average`` aggregator.

    Runs ``ParallelTrainer`` over ``TorchWorker`` workers that receive no extra
    keyword arguments.
    """

    def _parse_args(self, args):
        """Check that ``args`` selects 'avg' and set up trainer/worker types.

        The trainer/worker attributes are assigned before delegating to the
        base-class ``_parse_args`` (presumably because the base class reads
        them — confirm against ``Experiment``).
        """
        agg_name = 'avg'
        assert args.agg == agg_name
        self.agg = agg_name

        self.trainer_type, self.worker_type = ParallelTrainer, TorchWorker
        self.worker_kwargs = dict()

        super()._parse_args(args)

    def _get_agg_results_path(self):
        """Return the aggregator-specific results sub-path ``avg/seed<seed>``."""
        seed = self.seed
        return f"avg/seed{seed}"

    def _get_aggregator(self):
        """Construct and return a fresh ``Average`` aggregator."""
        aggregator = Average()
        return aggregator


class CClipExperiment(Experiment):
    """Experiment using the ``CClip`` aggregator with momentum workers.

    Fixed across runs:  agg, attack, batch_size, lr, cclip_iters, cclip_scaling
    Varied across runs: cclip_tau, cclip_momentum, p_norm
    """

    def _parse_args(self, args):
        """Validate the CClip-specific ``args`` and configure the run.

        Requires ``cclip_tau`` and ``cclip_momentum`` to be set,
        ``cclip_iters`` to equal 1 and ``cclip_scaling`` to be ``None``.
        Workers are ``WorkerWithMomentum`` built with the chosen momentum.
        """
        assert args.agg == 'cclip'
        self.agg = 'cclip'

        tau = args.cclip_tau
        assert tau is not None
        self.tau = tau

        momentum = args.cclip_momentum
        assert momentum is not None
        self.momentum = momentum

        # iters and scaling are pinned by the asserts and only stored here;
        # they are not passed to CClip in _get_aggregator.
        iters = args.cclip_iters
        assert iters == 1
        self.iters = iters

        scaling = args.cclip_scaling
        assert scaling is None
        self.scaling = scaling

        self.trainer_type = ParallelTrainer
        self.worker_type = WorkerWithMomentum
        self.worker_kwargs = dict(momentum=self.momentum)

        super()._parse_args(args)

    def _get_agg_results_path(self):
        """Return ``p<p_norm>_tau<tau>_momentum<momentum>/seed<seed>``.

        NOTE(review): ``p_norm`` and ``seed`` are not set in this class;
        presumably the base-class ``_parse_args`` assigns them — confirm.
        """
        return f"p{self.p_norm}_tau{self.tau}_momentum{self.momentum}/seed{self.seed}"

    def _get_aggregator(self):
        """Construct the ``CClip`` aggregator from the parsed settings."""
        return CClip(
            tau=self.tau,
            p_norm=self.p_norm,
            fixed_point=self.fixed_point,
        )


class KrumExperiment(Experiment):
    """Experiment using the ``Krum`` aggregator, optionally on reduced dimensions.

    Runs ``ParallelTrainer`` over ``TorchWorker`` workers with no extra worker
    keyword arguments. ``n_reduced_dims`` and ``dim_reduce_method`` must either
    both be given or both be omitted.
    """

    def _parse_args(self, args):
        """Validate the Krum-specific ``args`` and configure the run.

        Raises:
            AssertionError: if ``agg`` is not 'krum', ``krum_m < 1``,
                ``n_reduced_dims`` is negative, or exactly one of
                ``n_reduced_dims`` / ``dim_reduce_method`` is set.
        """
        assert args.agg == 'krum'
        self.agg = 'krum'

        assert args.krum_m >= 1
        self.krum_m = args.krum_m

        # Dimensionality reduction needs both a method and a target size.
        if args.n_reduced_dims is not None:
            assert args.n_reduced_dims >= 0
            assert args.dim_reduce_method is not None
        if args.dim_reduce_method is not None:
            assert args.n_reduced_dims is not None
        self.n_reduced_dims = args.n_reduced_dims
        self.dim_reduce_method = args.dim_reduce_method

        self.trainer_type = ParallelTrainer
        self.worker_type = TorchWorker
        self.worker_kwargs = {}

        super()._parse_args(args)

    def _get_agg_results_path(self):
        """Return ``m<m>_p<p>_reduce-<method>_dims<dims>/seed<seed>``.

        Without dimensionality reduction, ``dims`` is a fixed per-task label;
        otherwise it is ``n_reduced_dims``.

        Raises:
            NotImplementedError: if no reduction is configured and
                ``task_name`` has no known label.
        """
        if self.n_reduced_dims is None:
            # Labels for the unreduced case; presumably each task's full model
            # parameter count (the cifar entries name the architecture).
            full_dims_by_task = {
                'mnist': 1199882,
                'cifar10': 11689512,  # resnet18
                'cifar100': 5252132,  # resnet18 with gn
                'emnist': 1201946,
            }
            if self.task_name not in full_dims_by_task:
                # Report the name that was actually looked up; the previous
                # ``self.task`` is not set anywhere in this class and could
                # mask this error with an AttributeError.
                raise NotImplementedError(self.task_name)
            dims_label = full_dims_by_task[self.task_name]
        else:
            dims_label = self.n_reduced_dims

        results_dir = f"m{self.krum_m}_p{self.p_norm}_"
        results_dir += f"reduce-{self.dim_reduce_method}_dims{dims_label}"
        results_dir += f"/seed{self.seed}"

        return results_dir

    def _get_aggregator(self):
        """Construct the ``Krum`` aggregator from the parsed settings.

        NOTE(review): ``n``, ``f``, ``p_norm`` and ``fixed_point`` are not set
        in this class; presumably the base-class ``_parse_args`` assigns them.
        """
        return Krum(n=self.n,
                    f=self.f,
                    m=self.krum_m,
                    p_norm=self.p_norm,
                    dim_reduce_method=self.dim_reduce_method,
                    n_reduced_dims=self.n_reduced_dims,
                    fixed_point=self.fixed_point)


class RSAExperiment(Experiment):
    """Experiment using the ``RSA`` aggregator with ``RSATrainer``/``RSAWorker``."""

    def _parse_args(self, args):
        """Validate the RSA-specific ``args`` and configure the run.

        Requires a non-negative ``rsa_weight_decay`` and a non-``None``,
        non-negative ``rsa_lambda``; the lambda is also forwarded to workers.
        """
        assert args.agg == 'rsa'
        self.agg = 'rsa'

        # NOTE(review): weight_decay is stored but not part of the results
        # path; confirm runs differing only in weight decay can't collide.
        weight_decay = args.rsa_weight_decay
        assert weight_decay >= 0
        self.weight_decay = weight_decay

        rsa_lambda = args.rsa_lambda
        assert rsa_lambda is not None and rsa_lambda >= 0
        self.rsa_lambda = rsa_lambda

        self.trainer_type = RSATrainer
        self.worker_type = RSAWorker
        self.worker_kwargs = {'rsa_lambda': rsa_lambda}

        super()._parse_args(args)

    def _get_agg_results_path(self):
        """Return the results sub-path ``lambda<rsa_lambda>/seed<seed>``."""
        return f"lambda{self.rsa_lambda}/seed{self.seed}"

    def _get_aggregator(self):
        """Construct the ``RSA`` aggregator with the parsed lambda."""
        aggregator = RSA(self.rsa_lambda)
        return aggregator


class FLTrustExperiment(Experiment):
    """Experiment using the ``FLTrust`` aggregator.

    Runs ``FLTrustTrainer`` over ``FLTrustWorker`` workers. ``fltrust_param1``
    is a placeholder value: it is forwarded to the workers and the aggregator
    and used to label the results path.
    """

    def _parse_args(self, args):
        """Check that ``args`` selects 'fltrust' and configure the run."""
        assert args.agg == 'fltrust'
        self.agg = 'fltrust'

        self.trainer_type = FLTrustTrainer
        self.worker_type = FLTrustWorker
        param1 = args.fltrust_param1
        self.param1 = param1
        self.worker_kwargs = {'param1': param1}

        super()._parse_args(args)

    def _get_agg_results_path(self):
        """Return the results sub-path ``param<param1>/seed<seed>``."""
        return f"param{self.param1}/seed{self.seed}"

    def _get_aggregator(self):
        """Construct the ``FLTrust`` aggregator (``param1`` is a placeholder)."""
        aggregator = FLTrust(self.param1)
        return aggregator