# import sys
# sys.path.append('/home/yjf/FL/robustfl')

from .comlib import *
from .worker_model import WorkerModelSigmoid

class MaxVariance:
    """Strategy that weights workers by their deviation from the mean value.

    For each call to ``get`` it asks the worker model for the chosen
    workers' values. It optionally adds uniform random noise to them and
    returns each value minus the mean of all of them. Up to a constant
    factor, that difference is the gradient of the population variance of
    the values. The class name presumably reflects this.
    """

    def __init__(self, permutation=0, value_batch_size=400):
        # Multiplier on the U[0, 1) noise added to worker values in ``get``.
        # NOTE(review): the name reads like a typo for "perturbation"; kept
        # as-is because callers may pass it by keyword.
        self.permutation = permutation
        # Forwarded unchanged as the batch-size argument of
        # ``workerModel.getValue``.
        self.value_batch_size = value_batch_size

    def get(self, chosen_workers, workerModel:WorkerModelSigmoid):
        '''
        Compute the per-worker weights for the linear gradient.

        Args:
            chosen_workers: forwarded unchanged to ``workerModel.getValue``.
            workerModel: model whose ``getValue`` produces the values to
                centre.

        Returns:
            A tensor shaped like the values from ``getValue``, holding
            ``(value + permutation * noise) - mean(value + permutation * noise)``,
            where ``noise`` comes from ``torch.rand_like``.
        '''
        values = workerModel.getValue(chosen_workers, self.value_batch_size)
        # torch.rand_like is always drawn, even when permutation == 0, so the
        # global RNG advances identically on every call.
        # NOTE(review): assumes getValue returns a floating-point tensor;
        # rand_like does not support integer dtypes.
        noisy = values + self.permutation * torch.rand_like(values)
        return noisy - torch.mean(noisy)

    # @staticmethod
    # def getVar(chosenWorkersValue):
    #     var=torch.var(chosenWorkersValue, dim=None, correction=0, keepdim=False, out=None)
    #     return var

    # def getGrad(self,workerModel,chosen_workers):
    #     coeff=self.getVarCoeff(workerModel,chosen_workers)
    #     grad=self.g_strategy.getGrad(coeff,workerModel,chosen_workers)
    #     return grad

    # def getInstantLoss(self,workerModel,chosen_workers):
    #     coeff=self.getVarCoeff(workerModel,chosen_workers)
    #     loss=self.g_strategy.getInstantLoss(coeff,workerModel,chosen_workers)
    #     return loss

    # def getLossFromSigmoidValue(self,chosen_workers,chosenWorkersValue):
    #     return self.getVar(chosenWorkersValue)
    


    

        



