from .comlib import *
from .z.sigmoid import DiscriminatorSigmoid,DiscriminatorSigmoidLinear

class DiscriminatorCELinear(DiscriminatorSigmoidLinear):
    """Discriminator trained with a coefficient-weighted binary cross-entropy loss.

    Every sample is labelled by the (normalized) coefficient of the worker it
    belongs to: a coefficient ``> 0`` yields target ``1``, otherwise target
    ``0``, and ``|coefficient|`` is the sample's weight in the BCE-with-logits
    loss.
    """

    def __init__(self, modelSetup, workerDataset, save_folder, label=''):
        super().__init__(modelSetup, workerDataset, save_folder, label)

    @staticmethod
    def getTargetWeightFromCoeff(worker_ids, coeff):
        """Map each sample's worker id to a BCE target and a BCE weight.

        Args:
            worker_ids: per-sample worker ids, used to index ``coeff``.
            coeff: per-worker coefficient tensor indexable by worker id
                (presumably 1-D, one entry per worker — TODO confirm).

        Returns:
            ``(target, weight)``, both with one entry per indexed sample (i.e.
            the same number of entries as the batch's data points):
            ``target`` is ``1.0`` where the sample's worker coefficient is
            ``> 0`` and ``0.0`` otherwise, as ``torch.float32``; ``weight`` is
            the absolute value of that coefficient.
        """
        # Gather the coefficient of every sample once, then derive both outputs.
        sample_coeff = coeff[worker_ids]
        target = (sample_coeff > 0).to(torch.float32)
        weight = torch.abs(sample_coeff)
        return target, weight

    def getCELossFunc(self, chosen_workers):
        """Build the per-batch weighted BCE loss for ``chosen_workers``.

        The coefficients are obtained once, when the closure is built, from
        ``self.getCoeff`` followed by ``self.normalizeCoeff`` (both defined
        elsewhere) and are reused for every batch.

        Args:
            chosen_workers: the workers whose data the loss is evaluated on.

        Returns:
            A function ``f(out, worker_ids)`` that flattens the logits ``out``,
            derives per-sample targets/weights via
            ``getTargetWeightFromCoeff`` and returns the *summed* weighted
            BCE-with-logits loss divided by ``len(chosen_workers)``.
        """
        coeff = self.getCoeff(chosen_workers)
        coeff = self.normalizeCoeff(coeff, chosen_workers)

        def f(out, worker_ids):
            # reshape (not view) so non-contiguous logits are accepted as well.
            logits = out.reshape(-1)
            target, weight = self.getTargetWeightFromCoeff(worker_ids, coeff)
            # Functional form: same computation as nn.BCEWithLogitsLoss, but
            # without constructing a fresh nn.Module for every batch.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, target, weight=weight, reduction='sum'
            )
            # Normalized by the number of workers, not the number of samples.
            # NOTE(review): an empty chosen_workers makes this a division by
            # zero (inf/nan on a tensor) — confirm callers never pass one.
            return loss / len(chosen_workers)

        return f

    def getCELoss(self, chosen_workers, batch_size):
        """Evaluate the weighted BCE loss over the chosen workers' data.

        Args:
            chosen_workers: workers whose sub-dataset is evaluated.
            batch_size: batch size forwarded to ``self.getValueFromDataset``.

        Returns:
            Whatever ``self.getValueFromDataset`` produces for the per-batch
            loss from ``getCELossFunc`` (presumably the loss accumulated over
            all batches — TODO confirm against its definition).
        """
        lossFunc = self.getCELossFunc(chosen_workers)
        chosenWorkerDataset = self.workerDataset.getSubWokerDataset(chosen_workers)
        return self.getValueFromDataset(lossFunc, chosenWorkerDataset, batch_size)

    def getLossGrad(self, chosen_workers, batch_size):
        """Gradient of the weighted BCE loss over the chosen workers' data.

        Args:
            chosen_workers: workers whose sub-dataset is evaluated.
            batch_size: batch size forwarded to ``self.getJacFromDataset``.

        Returns:
            ``training.ModelPara`` wrapping the ``.tensors`` of the result of
            ``self.getJacFromDataset`` (presumably one gradient tensor per
            model parameter — TODO confirm).
        """
        lossFunc = self.getCELossFunc(chosen_workers)
        chosenWorkerDataset = self.workerDataset.getSubWokerDataset(chosen_workers)
        grad_tuple = self.getJacFromDataset(lossFunc, chosenWorkerDataset, batch_size)
        return training.ModelPara(grad_tuple.tensors)