import logging
from typing import Type
import torch
import numpy as np
import copy

from federatedscope.core.trainers import GeneralTorchTrainer
from torch.nn.utils import parameters_to_vector, vector_to_parameters

logger = logging.getLogger(__name__)


def wrap_backdoorTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:

    # ---------------- attribute-level plug-in -----------------------
    base_trainer.ctx.target_label_ind \
                = base_trainer.cfg.attack.target_label_ind
    base_trainer.ctx.trigger_type = base_trainer.cfg.attack.trigger_type
    base_trainer.ctx.label_type = base_trainer.cfg.attack.label_type

    # ---- action-level plug-in -------

    if base_trainer.cfg.attack.self_opt:

        base_trainer.ctx.self_lr = base_trainer.cfg.attack.self_lr
        base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_start_init_local_opt,
            trigger='on_fit_start',
            insert_pos=-1)

        base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_reset_opt,
                                            trigger='on_fit_end',
                                            insert_pos=0)

    scale_poisoning = base_trainer.cfg.attack.scale_poisoning
    pgd_poisoning = base_trainer.cfg.attack.pgd_poisoning

    if scale_poisoning or pgd_poisoning:

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_start_init_local_model,
            trigger='on_fit_start',
            insert_pos=-1)

    if base_trainer.cfg.attack.scale_poisoning:

        base_trainer.ctx.scale_para = base_trainer.cfg.attack.scale_para

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_end_scale_poisoning,
            trigger="on_fit_end",
            insert_pos=-1)

    if base_trainer.cfg.attack.pgd_poisoning:

        base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch
        base_trainer.ctx.pgd_lr = base_trainer.cfg.attack.pgd_lr
        base_trainer.ctx.pgd_eps = base_trainer.cfg.attack.pgd_eps
        base_trainer.ctx.batch_index = 0

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_start_init_local_pgd,
            trigger='on_fit_start',
            insert_pos=-1)

        base_trainer.register_hook_in_train(
            new_hook=hook_on_batch_end_project_grad,
            trigger='on_batch_end',
            insert_pos=-1)

        base_trainer.register_hook_in_train(
            new_hook=hook_on_epoch_end_project_grad,
            trigger='on_epoch_end',
            insert_pos=-1)

        base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_reset_opt,
                                            trigger='on_fit_end',
                                            insert_pos=0)

    return base_trainer


def hook_on_fit_start_init_local_opt(ctx):

    ctx.original_epoch = ctx["num_train_epoch"]
    ctx["num_train_epoch"] = ctx.self_epoch


def hook_on_fit_end_reset_opt(ctx):

    ctx["num_train_epoch"] = ctx.original_epoch


def hook_on_fit_start_init_local_model(ctx):

    # the original global model
    ctx.original_model = copy.deepcopy(ctx.model)


def hook_on_fit_end_scale_poisoning(ctx):

    # conduct the scale poisoning
    scale_para = ctx.scale_para

    v = torch.nn.utils.parameters_to_vector(ctx.original_model.parameters())
    logger.info("the Norm of the original global model: {}".format(
        torch.norm(v).item()))

    v = torch.nn.utils.parameters_to_vector(ctx.model.parameters())
    logger.info("Attacker before scaling : Norm = {}".format(
        torch.norm(v).item()))

    ctx.original_model = list(ctx.original_model.parameters())

    for idx, param in enumerate(ctx.model.parameters()):
        param.data = (param.data - ctx.original_model[idx]
                      ) * scale_para + ctx.original_model[idx]

    v = torch.nn.utils.parameters_to_vector(ctx.model.parameters())
    logger.info("Attacker after scaling : Norm = {}".format(
        torch.norm(v).item()))

    logger.info('finishing model scaling poisoning attack')


def hook_on_fit_start_init_local_pgd(ctx):

    ctx.original_optimizer = ctx.optimizer
    ctx.original_epoch = ctx["num_train_epoch"]
    ctx["num_train_epoch"] = ctx.self_epoch
    ctx.optimizer = torch.optim.SGD(ctx.model.parameters(), lr=ctx.pgd_lr)
    # looks like adversary needs same lr to hide with others


def hook_on_batch_end_project_grad(ctx):
    '''
    after every 10 iters, we project update on the predefined norm ball.
    '''
    eps = ctx.pgd_eps
    project_frequency = 10
    ctx.batch_index += 1
    w = list(ctx.model.parameters())
    w_vec = parameters_to_vector(w)
    model_original_vec = parameters_to_vector(
        list(ctx.original_model.parameters()))
    # make sure you project on last iteration otherwise,
    # high LR pushes you really far
    if (ctx.batch_index % project_frequency
            == 0) and (torch.norm(w_vec - model_original_vec) > eps):
        # project back into norm ball
        w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm(
            w_vec - model_original_vec) + model_original_vec
        # plug w_proj back into model
        vector_to_parameters(w_proj_vec, w)


def hook_on_epoch_end_project_grad(ctx):
    '''
    after the whole epoch, we project the update on the predefined norm ball.
    '''
    ctx.batch_index = 0
    eps = ctx.pgd_eps
    w = list(ctx.model.parameters())
    w_vec = parameters_to_vector(w)
    model_original_vec = parameters_to_vector(
        list(ctx.original_model.parameters()))
    if (torch.norm(w_vec - model_original_vec) > eps):
        # project back into norm ball
        w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm(
            w_vec - model_original_vec) + model_original_vec
        # plug w_proj back into model
        vector_to_parameters(w_proj_vec, w)


def hook_on_fit_end_reset_pgd(ctx):

    ctx.optimizer = ctx.original_optimizer
