from .comlib import *

class KnownByzantine:
    """Signed per-worker coefficients for a known set of Byzantine workers.

    ``get`` maps each chosen worker id to ``+worker_num / byzantine_num`` if
    the id is listed in ``byzantine_ids`` and to ``-worker_num / normal_num``
    otherwise, where ``normal_num = worker_num - byzantine_num``.
    """

    def __init__(self, worker_num, byzantine_ids):
        # setupKBCoeff stores worker_num / byzantine_ids and the coefficients.
        self.setupKBCoeff(byzantine_ids, worker_num)

    def setupKBCoeff(self, byzantine_ids, worker_num):
        """(Re)configure the Byzantine id set and the per-group coefficients.

        Coefficients (before the sign applied by ``get``):
            byzantine: worker_num / byzantine_num
            normal:    worker_num / normal_num   (``get`` returns it negated)

        A group with no members gets coefficient 0.0 instead of raising
        ZeroDivisionError, so configurations with no Byzantine workers (or
        with only Byzantine workers) can still be constructed.

        Args:
            byzantine_ids: ids of the Byzantine workers (list or tensor).
            worker_num: total number of workers.
        """
        # Integer dtype even for an empty list (torch.tensor([]) is float32);
        # clone so later mutation of a caller-owned tensor does not leak in.
        self.byzantine_ids = torch.as_tensor(byzantine_ids, dtype=torch.long).clone()
        self.worker_num = worker_num

        byzantine_num = len(byzantine_ids)
        normal_num = worker_num - byzantine_num

        self.byzantine_coeff = worker_num / byzantine_num if byzantine_num else 0.0
        self.normal_coeff = worker_num / normal_num if normal_num else 0.0

    def get(self, chosen_workers, *args):
        """Return the signed coefficient for each id in ``chosen_workers``.

        Args:
            chosen_workers: worker ids (list or tensor).
            *args: accepted and ignored (kept for interface compatibility).

        Returns:
            Floating-point tensor shaped like ``chosen_workers``: holds
            ``byzantine_coeff`` where the id is in ``byzantine_ids`` and
            ``-normal_coeff`` everywhere else.
        """
        # as_tensor avoids the copy-construct warning torch.tensor emits for
        # tensor inputs, and keeps the caller's device.
        chosen_workers = torch.as_tensor(chosen_workers)
        # isin needs both operands on the same device with a matching dtype.
        byzantine_ids = self.byzantine_ids.to(device=chosen_workers.device,
                                              dtype=chosen_workers.dtype)
        is_byzantine = torch.isin(chosen_workers, byzantine_ids)
        return torch.where(is_byzantine, self.byzantine_coeff, -self.normal_coeff)


    