from .no_atk_ic import *
from .label_flip_ic import *
from .bit_flip_ic import *
from .ng_atk_ic import *
from .rd_atk_ic import *


def non_omniscient_attack_ic(byz_mode, byz_num, rank, world_size, model, w_flat, criterion, instances,
                            target, weight_decay, class_num):
    """
    Dispatch to the gradient routine this worker should run for an image classification model.

    Workers whose ``rank`` is at or above ``byz_num`` always use ``no_atk_ic``. The same
    routine is used by every worker when ``byz_mode`` is 'noAtk', 'ALIE' or 'FoE'
    (NOTE(review): presumably 'ALIE' and 'FoE' are applied elsewhere -- confirm against caller).
    Otherwise the routine matching ``byz_mode`` is called.

    :param byz_mode: type of Byzantine attack ('noAtk', 'ALIE', 'FoE', 'labelFlip', 'bitFlip',
                     'NG_atk' or 'RD_atk')
    :param byz_num: total number of Byzantine workers; ranks below this value are Byzantine
    :param rank: the number of this worker
    :param world_size: total number of workers (not used by this function)
    :param model: current model
    :param w_flat: current model parameter in flat tensor form
    :param criterion: loss function
    :param instances: a batch of training instances
    :param target: the labels of training instances
    :param weight_decay: hyper-parameter of weight decay
    :param class_num: total number of classes on this image classification task
                      (forwarded to the label-flipping attack only)

    :return: whatever the selected gradient routine returns
    :raises ValueError: if this worker's rank is below ``byz_num`` and ``byz_mode`` is not one
                        of the modes listed above
    """

    # Arguments shared by every gradient routine, in the order they expect them.
    shared_args = (model, w_flat, criterion, instances, target, weight_decay)

    # Honest workers, and modes without a per-worker attack here, take the plain path.
    if rank >= byz_num or byz_mode in ('noAtk', 'ALIE', 'FoE'):
        return no_atk_ic(*shared_args)

    # From here on the worker is Byzantine and must run a recognised attack.
    if byz_mode == 'labelFlip':
        # Label flipping is the only attack that needs the number of classes.
        return label_flip_ic(*shared_args, class_num)
    if byz_mode == 'bitFlip':
        return bit_flip_ic(*shared_args)
    if byz_mode == 'NG_atk':
        return ng_atk_ic(*shared_args)
    if byz_mode == 'RD_atk':
        return rd_atk_ic(*shared_args)

    raise ValueError("undefined byz_mode!")

