from .comlib import *
from .z.sigmoid import DiscriminatorSigmoidLinear
from .cross_entro_grad import DiscriminatorCELinear

def setupKBCoeff(byzantine_ids,device,worker_num):
    '''
        Build the tensor of known Byzantine worker ids and the scaling
        coefficient of each worker group.

        byzantine: worker_num / byzantine_num
        normal:    worker_num / normal_num

        Args:
            byzantine_ids: ids of the workers known to be Byzantine, in any
                form accepted by ``torch.tensor``. ``None`` is treated as
                "no Byzantine workers".
            device: device on which the id tensor is created.
            worker_num: total number of workers.

        Returns:
            (byzantine_ids, byzantine_coeff, normal_coeff): the ids as a
            tensor on ``device`` and the two coefficients as Python floats.
            A group without members gets the coefficient 0.0 instead of
            raising ZeroDivisionError.
    '''
    if byzantine_ids is None:
        byzantine_ids=[]
    byzantine_ids=torch.tensor(byzantine_ids,device=device)
    if byzantine_ids.numel()==0:
        # torch.tensor([]) defaults to a floating dtype; keep the ids
        # integral so membership tests against worker ids stay well-typed.
        byzantine_ids=byzantine_ids.to(torch.long)

    byzantine_num=len(byzantine_ids)
    normal_num=worker_num-byzantine_num

    # With no Byzantine ids, torch.isin against them is always False, so the
    # Byzantine coefficient is never selected and 0.0 is a safe placeholder.
    byzantine_coeff=worker_num/byzantine_num if byzantine_num else 0.0
    # NOTE(review): assumes byzantine_ids are unique valid worker ids, so
    # normal_num == 0 means there is no normal worker to apply this to.
    normal_coeff=worker_num/normal_num if normal_num else 0.0
    return byzantine_ids,byzantine_coeff, normal_coeff

def getKBCoeff_(chosen_workers,device,byzantine_ids,byzantine_coeff,normal_coeff):
    '''
        Return one signed coefficient per chosen worker.

        Workers whose id appears in ``byzantine_ids`` get ``+byzantine_coeff``;
        every other worker gets ``-normal_coeff``.

        Args:
            chosen_workers: worker ids, converted to a tensor on ``device``.
            device: device used for the converted id tensor.
            byzantine_ids: tensor of known Byzantine worker ids.
            byzantine_coeff: value assigned to Byzantine workers.
            normal_coeff: value whose negation is assigned to the others.

        Returns:
            Tensor with one entry per element of ``chosen_workers``.
    '''
    worker_ids=torch.tensor(chosen_workers,device=device)
    is_byzantine=torch.isin(worker_ids,byzantine_ids)
    return torch.where(is_byzantine, byzantine_coeff,-normal_coeff)

def getKBLoss_(chosen_workers,chosenWorkersValue,device,byzantine_ids,byzantine_coeff,normal_coeff):
    '''
        Mean of the per-worker values weighted by their signed coefficients.

        The coefficients come from ``getKBCoeff_`` (positive for known
        Byzantine workers, negative for the others) and are multiplied
        elementwise with ``chosenWorkersValue`` before averaging.
    '''
    signed_coeff=getKBCoeff_(chosen_workers,device,byzantine_ids,byzantine_coeff,normal_coeff)
    weighted_value=signed_coeff*chosenWorkersValue
    return torch.mean(weighted_value)

def getKBCoeff(d,chosen_workers):
    '''
        Signed per-worker coefficients, using the settings stored on ``d``.

        ``d`` must expose ``device``, ``byzantine_ids``, ``byzantine_coeff``
        and ``normal_coeff``; they are forwarded to ``getKBCoeff_``.
    '''
    return getKBCoeff_(
        chosen_workers,
        d.device,
        d.byzantine_ids,
        d.byzantine_coeff,
        d.normal_coeff,
    )

def getKBLoss(d,chosen_workers,chosenWorkersValue):
    '''
        Known-Byzantine loss, using the settings stored on ``d``.

        ``d`` must expose ``device``, ``byzantine_ids``, ``byzantine_coeff``
        and ``normal_coeff``. The result is the mean of the signed
        per-worker coefficients multiplied by ``chosenWorkersValue``.
    '''
    return getKBLoss_(
        chosen_workers,
        chosenWorkersValue,
        d.device,
        d.byzantine_ids,
        d.byzantine_coeff,
        d.normal_coeff,
    )

class KnownByzantineSigmoid(DiscriminatorSigmoidLinear):
    '''
        Sigmoid discriminator whose loss uses a known set of Byzantine ids.

        Stores ``byzantine_ids``, ``byzantine_coeff`` and ``normal_coeff``
        (computed by ``setupKBCoeff``) for the module-level loss helpers.
    '''
    def __init__(self, modelSetup, workerDataset, save_file, label='',byzantine_ids=None):
        super().__init__(modelSetup, workerDataset, save_file, label)

        # self.device / self.worker_num are read after the base initializer
        # has run; presumably it is the one that sets them.
        ids,byz_coeff,norm_coeff=setupKBCoeff(byzantine_ids,self.device,self.worker_num)
        self.byzantine_ids=ids
        self.byzantine_coeff=byz_coeff
        self.normal_coeff=norm_coeff

    def getLossFromSigmoidValue(self,chosen_workers,chosenWorkersValue):
        '''Mean of the signed-coefficient-weighted values of the chosen workers.'''
        return getKBLoss(self,chosen_workers,chosenWorkersValue)

    def getCoeff(self,chosen_workers):
        '''Signed coefficient of each chosen worker (see ``getKBCoeff``).'''
        return getKBCoeff(self,chosen_workers)

    
class KnownByzantineCrossEntropy(DiscriminatorCELinear):
    '''
        Cross-entropy discriminator whose loss uses a known set of Byzantine ids.

        Stores ``byzantine_ids``, ``byzantine_coeff`` and ``normal_coeff``
        (computed by ``setupKBCoeff``) for the module-level loss helpers.
    '''
    def __init__(self, modelSetup, workerDataset, save_folder, label='',byzantine_ids=None):
        super().__init__(modelSetup, workerDataset, save_folder, label)

        # self.device / self.worker_num are read after the base initializer
        # has run; presumably it is the one that sets them.
        ids,byz_coeff,norm_coeff=setupKBCoeff(byzantine_ids,self.device,self.worker_num)
        self.byzantine_ids=ids
        self.byzantine_coeff=byz_coeff
        self.normal_coeff=norm_coeff

    def getLossFromSigmoidValue(self,chosen_workers,chosenWorkersValue):
        '''Mean of the signed-coefficient-weighted values of the chosen workers.'''
        return getKBLoss(self,chosen_workers,chosenWorkersValue)

    def getCoeff(self,chosen_workers):
        '''Signed coefficient of each chosen worker (see ``getKBCoeff``).'''
        return getKBCoeff(self,chosen_workers)

class KnownByzantineCrossEntropy2(KnownByzantineSigmoid):
    '''
        Variant of ``KnownByzantineSigmoid`` that adds a weighted binary
        cross-entropy loss (and its gradient) over per-sample outputs.
    '''
    def __init__(self, modelSetup, workerDataset, save_file, label='',byzantine_ids=None):
        super().__init__(modelSetup, workerDataset, save_file, label,byzantine_ids)

    def getByzantineInChosenWorker(self,chosen_workers):
        '''
            Locate the known Byzantine workers inside ``chosen_workers``.

            Returns:
                (positions, mask): ``positions`` holds the indices into
                ``chosen_workers`` whose id is in ``self.byzantine_ids``;
                ``mask`` is the boolean tensor over all chosen workers.
        '''
        chosen=torch.tensor(chosen_workers,device=self.device)
        is_byzantine=torch.isin(chosen,self.byzantine_ids)
        positions=torch.arange(len(chosen),device=self.device)
        return positions[is_byzantine],is_byzantine

    def getTargetWeight(self,worker_ids,chosen_workers):
        '''
            Build the per-sample BCE target and weight.

            ``worker_ids`` is used both as a membership key against positions
            in ``chosen_workers`` and as an index into a per-chosen-worker
            tensor, so it is expected to hold positions into
            ``chosen_workers`` (one per sample) -- TODO confirm with caller.

            Returns:
                (target, weight): ``target`` is 1.0 for samples of Byzantine
                workers and 0.0 otherwise (float32); ``weight`` is the
                owning worker's coefficient divided by the value returned by
                ``getSubWokerDataNum`` for it (presumably its sample count).

            NOTE(review): the original note here read "data_num same"; its
            exact meaning is unclear -- TODO confirm.
        '''
        byzantine_pos,is_byzantine=self.getByzantineInChosenWorker(chosen_workers)
        target=torch.isin(worker_ids,byzantine_pos)

        data_num=self.workerDataset.getSubWokerDataNum(chosen_workers).to(self.device)
        group_coeff=torch.where(is_byzantine,self.byzantine_coeff,self.normal_coeff)
        # Keep the (1/data_num) * coeff evaluation order of the original.
        per_worker_coeff=(1/data_num)*group_coeff
        weight=per_worker_coeff[worker_ids]
        return target.to(torch.float32),weight

    def getLossFromOut(self,out,worker_ids,chosen_workers):
        '''
            Weighted BCE-with-logits of the flattened ``out``, summed over
            samples and divided by the number of chosen workers.
        '''
        logits=out.view(-1)
        target,weight=self.getTargetWeight(worker_ids,chosen_workers)
        criterion=nn.BCEWithLogitsLoss(weight=weight,reduction='sum')
        return criterion(logits,target)/len(chosen_workers)

    def _lossFuncAndDataset(self,chosen_workers):
        '''Loss callable bound via ``getfunc`` plus the dataset of the chosen workers.'''
        lossFunc=self.getfunc(self.getLossFromOut,chosen_workers)
        chosenWorkerDataset=self.workerDataset.getSubWokerDataset(chosen_workers)
        return lossFunc,chosenWorkerDataset

    def getCELoss(self,chosen_workers,batch_size):
        '''Evaluate the loss on the chosen workers' data via ``getValueFromDataset``.'''
        lossFunc,chosenWorkerDataset=self._lossFuncAndDataset(chosen_workers)
        return self.getValueFromDataset(lossFunc,chosenWorkerDataset,batch_size)

    def getLossGrad(self,chosen_workers,batch_size):
        '''
            Result of ``getJacFromDataset`` for the loss on the chosen
            workers' data, with its ``tensors`` wrapped in ``training.ModelPara``.
        '''
        lossFunc,chosenWorkerDataset=self._lossFuncAndDataset(chosen_workers)
        grad_tuple=self.getJacFromDataset(lossFunc,chosenWorkerDataset,batch_size)
        return training.ModelPara(grad_tuple.tensors)
    
    # def testScatterValue(self,out,worker_ids,worker_num):
    #     loss_data_vec=self.sigmoid(out).view(-1)
    #     loss_data_vec1=torch.log(loss_data_vec)
    #     loss_data_vec0=torch.log(1-loss_data_vec)
    #     loss_data_vec1
    #     loss_worker_vec = torch.zeros((worker_num,), dtype=loss_data_vec.dtype,device=self.device)
    #     loss_worker_vec.scatter_add_(0, worker_ids, loss_data_vec)

    #     return loss_worker_vec
    