# import sys
# sys.path.append('/home/yjf/FL/robustfl')

from ..comlib import *
from ..z import DiscriminatorSigmoidLinear
from .cross_entro_grad import DiscriminatorCELinear

def getVarCoeff(d, chosen_workers, batch_size=400):
    """Return the mean-centred sigmoid values of the chosen workers.

    The centred values ``x_i - mean(x)`` are proportional to the gradient of
    the population variance with respect to each ``x_i``
    (d var / d x_i = 2 * (x_i - mean) / n). They are presumably used as
    per-worker weights by the caller; confirm against the discriminator base
    class.

    Args:
        d: Object exposing ``getChosenWorkersSigmoidValue(chosen_workers,
            batch_size)`` and a numeric ``permutation`` attribute.
        chosen_workers: Worker selection forwarded unchanged to
            ``d.getChosenWorkersSigmoidValue``.
        batch_size: Batch size forwarded to
            ``d.getChosenWorkersSigmoidValue``. The default of 400 matches the
            previously hard-coded value.

    Returns:
        Tensor of the same shape as the sigmoid values. Its elements sum to
        (approximately) zero.
    """
    workersValue = d.getChosenWorkersSigmoidValue(chosen_workers, batch_size)
    # Perturb with uniform [0, 1) noise scaled by d.permutation. rand_like is
    # always called, even when permutation == 0, so the global RNG stream
    # advances identically regardless of the noise scale.
    workersValue = workersValue + d.permutation * torch.rand_like(workersValue)
    meanWorkersValue = torch.mean(workersValue)
    coeff = workersValue - meanWorkersValue
    return coeff

def getVar(chosenWorkersValue):
    """Return the population variance (no Bessel correction) of all elements.

    The tensor is reduced over every dimension, so the result is a 0-dim tensor.
    """
    population_variance = chosenWorkersValue.var(dim=None, correction=0, keepdim=False)
    return population_variance

class MaxVariance(DiscriminatorSigmoidLinear):
    """DiscriminatorSigmoidLinear variant whose loss is the variance of worker values.

    ``getLossFromSigmoidValue`` returns the population variance of the chosen
    workers' sigmoid values. ``getCoeff`` returns their mean-centred values.
    ``permutation`` is stored on the instance and read by ``getVarCoeff`` as
    the scale of the additive uniform noise.
    """

    def __init__(self, modelSetup, workerDataset, save_file, label='', permutation=0):
        super().__init__(modelSetup, workerDataset, save_file, label)
        # Assigned after the base initialiser, as in the original ordering.
        self.permutation = permutation

    def getLossFromSigmoidValue(self, chosen_workers, chosenWorkersValue):
        """Return the population variance of ``chosenWorkersValue``.

        ``chosen_workers`` is accepted for interface compatibility but unused.
        """
        variance_loss = getVar(chosenWorkersValue)
        return variance_loss

    def getCoeff(self, chosen_workers):
        """Return the mean-centred (noise-perturbed) sigmoid values of ``chosen_workers``."""
        centred_values = getVarCoeff(self, chosen_workers)
        return centred_values
    
class MaxVarianceCE(DiscriminatorCELinear):
    """DiscriminatorCELinear variant whose loss is the variance of worker values.

    This mirrors ``MaxVariance`` but builds on ``DiscriminatorCELinear``.
    ``permutation`` is stored on the instance and read by ``getVarCoeff`` as
    the scale of the additive uniform noise.
    """

    def __init__(self, modelSetup, workerDataset, save_folder, label='', permutation=0):
        super().__init__(modelSetup, workerDataset, save_folder, label)
        # Assigned after the base initialiser, as in the original ordering.
        self.permutation = permutation

    def getLossFromSigmoidValue(self, chosen_workers, chosenWorkersValue):
        """Return the population variance of ``chosenWorkersValue``.

        ``chosen_workers`` is accepted for interface compatibility but unused.
        """
        variance_loss = getVar(chosenWorkersValue)
        return variance_loss

    def getCoeff(self, chosen_workers):
        """Return the mean-centred (noise-perturbed) sigmoid values of ``chosen_workers``."""
        centred_values = getVarCoeff(self, chosen_workers)
        return centred_values


        



