from .comlib import *
from .worker_model import WorkerModelSigmoid


class LinearCEGrad(fed_learning.LinearGradR):
    """Gradient of a coefficient-weighted binary cross-entropy over worker samples.

    Each signed per-worker coefficient ``c_w`` is turned into a binary target
    (``1`` if ``c_w > 0`` else ``0``) and a per-sample weight ``|c_w|``. The
    summed, weighted BCE-with-logits loss is then handed to
    ``modelSetup.calcuDatasetJacobian`` to obtain the gradient.
    """

    def __init__(self, batch_size):
        # NOTE(review): the parent's first positional argument is passed as
        # None here; its meaning is defined in LinearGradR (not visible).
        super().__init__(None, batch_size)

    @staticmethod
    def getTargetWeightFromCoeff(worker_ids, coeff):
        """Map signed per-worker coefficients to per-sample BCE targets/weights.

        Args:
            worker_ids: index selecting one coefficient per sample (presumably
                the worker id of each sample -- confirm against the caller).
            coeff: tensor of signed coefficients indexed by worker id.

        Returns:
            ``(target, weight)``. ``target`` is a float32 tensor that is
            ``1.0`` where the selected coefficient is positive and ``0.0``
            otherwise. ``weight`` is the absolute value of the selected
            coefficient. A zero coefficient therefore gives target 0 with
            weight 0, i.e. no contribution.
        """
        # Index once, then derive both quantities from the selected values
        # (same result as comparing/abs-ing the full vector first).
        selected = coeff[worker_ids]
        target = (selected > 0).to(torch.float32)
        return target, torch.abs(selected)

    @staticmethod
    def getWorkerWeightLoss(weight, out, worker_ids):
        """Summed, coefficient-weighted BCE-with-logits loss for one batch.

        Args:
            weight: signed per-worker coefficients. Despite the name, these are
                raw coefficients: the sign selects the target and the magnitude
                is the sample weight.
            out: model logits for the batch, flattened to one logit per sample.
            worker_ids: per-sample indices into ``weight``.

        Returns:
            Scalar tensor equal to
            ``sum_i |c[w_i]| * BCEWithLogits(out_i, float(c[w_i] > 0))``.
        """
        # reshape rather than view: identical for contiguous logits, and it
        # does not raise when the model hands back a non-contiguous tensor.
        logits = out.reshape(-1)
        target, sample_weight = LinearCEGrad.getTargetWeightFromCoeff(worker_ids, weight)
        # Functional form: same computation BCEWithLogitsLoss.forward performs,
        # without constructing a throwaway nn.Module on every call.
        return nn.functional.binary_cross_entropy_with_logits(
            logits,
            target,
            weight=sample_weight,
            reduction='sum',
        )

    @staticmethod
    def get_weighted_grad(weight, chosen_workers, modelSetup, worker_dataset, batch_size):
        """Compute the weighted-BCE gradient over the chosen workers' samples.

        Args:
            weight: signed per-worker coefficients. They are normalized with
                ``normalizeCoeff`` and moved to ``modelSetup.device`` before use.
            chosen_workers: workers whose samples are included.
            modelSetup: object exposing ``device`` and ``calcuDatasetJacobian``.
            worker_dataset: dataset restricted to ``chosen_workers`` via
                ``getSubWokerDataset``.
            batch_size: forwarded to ``calcuDatasetJacobian``.

        Returns:
            The object returned by ``modelSetup.calcuDatasetJacobian`` for the
            weighted loss (its type is defined outside this file chunk).
        """
        device = modelSetup.device
        sub_dataset = worker_dataset.getSubWokerDataset(chosen_workers)
        # normalizeCoeff is not defined in this chunk (presumably inherited).
        coeff = LinearCEGrad.normalizeCoeff(weight, worker_dataset, chosen_workers).to(device)

        def loss_fn(out, *out_args):
            # out_args presumably carries the per-sample worker ids supplied by
            # calcuDatasetJacobian -- confirm against its implementation.
            return LinearCEGrad.getWorkerWeightLoss(coeff, out, *out_args)

        return modelSetup.calcuDatasetJacobian(loss_fn, sub_dataset, batch_size)

    def get(self, weight, chosen_workers, worker_dataset, modelSetup, *args):
        """Return the gradient tensors for ``weight`` over ``chosen_workers``.

        ``weight`` and ``chosen_workers`` are first adjusted by
        ``get_weight_for_redundent`` (defined outside this chunk). Extra
        positional ``args`` are ignored (presumably accepted to match the
        parent's ``get`` signature).

        Returns:
            The ``tensors`` attribute of the gradient object produced by
            ``get_weighted_grad``.
        """
        weight, chosen_workers = self.get_weight_for_redundent(weight, chosen_workers, worker_dataset)
        grad = self.get_weighted_grad(weight, chosen_workers, modelSetup, worker_dataset, self.batch_size)
        return grad.tensors
    

