from .comlib import *
from .discriminator_strategy import StrategySigmoidValue,LinearGrad

class KnownByzantine(StrategySigmoidValue):
    """Discriminator strategy driven by the known (ground-truth) Byzantine ids.

    Every chosen worker receives a signed weight:
    ``+worker_num / byzantine_num`` if its id is in ``byzantine_ids``,
    otherwise ``-worker_num / normal_num``. These weights are either passed to
    the wrapped ``g_strategy`` (``getGrad`` / ``getInstantLoss``) or applied
    directly to per-worker sigmoid values (``getLossFromSigmoidValue``).
    """

    def __init__(self, g_strategy: LinearGrad, worker_num, byzantine_ids):
        """
        Args:
            g_strategy: strategy handed to ``StrategySigmoidValue.__init__``;
                the methods below use it via ``self.g_strategy`` (presumably
                stored by the parent constructor -- TODO confirm).
            worker_num: total number of workers.
            byzantine_ids: ids of the workers known to be Byzantine
                (assumed to be integer worker ids).
        """
        super().__init__(g_strategy)
        self.setupKBCoeff(byzantine_ids, worker_num)

    def setupKBCoeff(self, byzantine_ids, worker_num):
        """Precompute the per-group weighting coefficients.

            byzantine: worker_num / byzantine_num
            normal:    worker_num / normal_num

        When every worker is chosen, the Byzantine weights therefore sum to
        ``worker_num``. The normal weights, which ``getKBCoeff`` negates, sum
        to ``-worker_num``. A group with no members gets a coefficient of 0.0
        instead of raising ``ZeroDivisionError``. ``getKBCoeff`` never assigns
        the empty Byzantine group's coefficient to a worker, so the 0.0 has no
        effect there. This allows running with no Byzantine workers at all.

        Raises:
            ValueError: if more Byzantine ids are given than there are workers.
        """
        # Force int64 so an empty id list does not become a float32 tensor.
        self.byzantine_ids = torch.tensor(byzantine_ids, dtype=torch.long)
        self.worker_num = worker_num

        byzantine_num = len(byzantine_ids)
        normal_num = worker_num - byzantine_num
        if normal_num < 0:
            raise ValueError(
                f"byzantine_ids has {byzantine_num} entries but worker_num is {worker_num}"
            )

        self.byzantine_coeff = worker_num / byzantine_num if byzantine_num else 0.0
        self.normal_coeff = worker_num / normal_num if normal_num else 0.0

    def getKBCoeff(self, chosen_workers):
        """Return the signed coefficient for each id in ``chosen_workers``.

        The result is a floating-point tensor with the same shape as
        ``chosen_workers``. Entry i is ``self.byzantine_coeff`` when
        ``chosen_workers[i]`` is a Byzantine id, and ``-self.normal_coeff``
        otherwise.
        """
        # as_tensor avoids the copy-construct warning when a tensor is passed in.
        chosen = torch.as_tensor(chosen_workers)
        # Put the id table on the same device so torch.isin works for CUDA inputs too.
        is_byzantine = torch.isin(chosen, self.byzantine_ids.to(chosen.device))
        return torch.where(is_byzantine, self.byzantine_coeff, -self.normal_coeff)

    def getGrad(self, workerModel, chosen_workers):
        """Delegate to ``g_strategy.getGrad`` with the known-Byzantine coefficients.

        Returns whatever ``self.g_strategy.getGrad`` returns.
        """
        coeff = self.getKBCoeff(chosen_workers)
        return self.g_strategy.getGrad(coeff, workerModel, chosen_workers)

    def getInstantLoss(self, workerModel, chosen_workers):
        """Delegate to ``g_strategy.getInstantLoss`` with the known-Byzantine coefficients.

        Returns whatever ``self.g_strategy.getInstantLoss`` returns.
        """
        coeff = self.getKBCoeff(chosen_workers)
        return self.g_strategy.getInstantLoss(coeff, workerModel, chosen_workers)

    def getLossFromSigmoidValue(self, chosen_workers, chosenWorkersValue):
        """Return ``mean(coeff * chosenWorkersValue)``.

        ``chosenWorkersValue`` is presumably one sigmoid output per chosen
        worker, aligned with ``chosen_workers``; it must broadcast against the
        1-D coefficient vector (TODO confirm shape with callers).
        """
        coeff = self.getKBCoeff(chosen_workers).to(chosenWorkersValue.device)
        return torch.mean(coeff * chosenWorkersValue)


    