# import sys
# sys.path.append('/home/yjf/FL/robustfl')

from .comlib import *
from .discriminator_strategy import StrategySigmoidValue,LinearGrad

class MaxVariance(StrategySigmoidValue):
    """Strategy whose objective is the variance of the chosen workers' sigmoid values.

    ``getVarCoeff`` produces one coefficient per chosen worker: its sigmoid
    value, optionally perturbed by scaled uniform noise, minus the mean of
    those values. For a flat vector of values this is proportional to the
    derivative of the population variance with respect to each value. The
    coefficients are handed to ``g_strategy`` to obtain a gradient or an
    instantaneous loss. ``getLossFromSigmoidValue`` reports the population
    variance (``correction=0``) of already-computed sigmoid values.
    """

    def __init__(self, g_strategy, permutation=0, value_batch_size=400):
        """
        Args:
            g_strategy: delegate forwarded to the parent constructor. The
                methods below read it back as ``self.g_strategy``, which the
                parent is expected to set (NOTE(review): confirm in
                ``StrategySigmoidValue``).
            permutation: multiplier for the ``torch.rand_like`` noise added
                to the sigmoid values before centring. The name presumably
                means "perturbation". With ``0`` the noise contributes
                nothing, but a random tensor is still drawn.
            value_batch_size: batch size passed to
                ``workerModel.getChosenWorkersSigmoidValue``.
        """
        super().__init__(g_strategy)
        self.permutation = permutation
        self.value_batch_size = value_batch_size

    def getVarCoeff(self, workerModel, chosen_workers):
        """Return the mean-centred (and optionally noise-perturbed) sigmoid values.

        Args:
            workerModel: object exposing
                ``getChosenWorkersSigmoidValue(chosen_workers, batch_size)``.
                Presumably it returns a tensor with one value per chosen
                worker — TODO confirm the shape.
            chosen_workers: forwarded unchanged to ``workerModel``.

        Returns:
            Tensor shaped like the sigmoid values, minus their overall mean.
        """
        sigmoid_vals = workerModel.getChosenWorkersSigmoidValue(
            chosen_workers, self.value_batch_size
        )
        # Uniform [0, 1) noise scaled by ``permutation``. Its constant offset
        # is removed by the centring step below.
        jitter = torch.rand_like(sigmoid_vals)
        sigmoid_vals = sigmoid_vals + self.permutation * jitter
        return sigmoid_vals - sigmoid_vals.mean()

    @staticmethod
    def getVar(chosenWorkersValue):
        """Return the population variance (``correction=0``) over all elements.

        Args:
            chosenWorkersValue: tensor of sigmoid values.

        Returns:
            0-dim tensor holding the variance.
        """
        # ``dim=None`` is passed explicitly so the ``correction`` overload
        # resolves on older torch releases as well.
        return torch.var(chosenWorkersValue, dim=None, correction=0, keepdim=False)

    def getGrad(self, workerModel, chosen_workers):
        """Return ``g_strategy.getGrad`` evaluated with the centred coefficients."""
        centred = self.getVarCoeff(workerModel, chosen_workers)
        return self.g_strategy.getGrad(centred, workerModel, chosen_workers)

    def getInstantLoss(self, workerModel, chosen_workers):
        """Return ``g_strategy.getInstantLoss`` evaluated with the centred coefficients."""
        centred = self.getVarCoeff(workerModel, chosen_workers)
        return self.g_strategy.getInstantLoss(centred, workerModel, chosen_workers)

    def getLossFromSigmoidValue(self, chosen_workers, chosenWorkersValue):
        """Return the variance of precomputed sigmoid values.

        ``chosen_workers`` is not used here. It is accepted presumably to
        match the signature the parent class expects — verify against
        ``StrategySigmoidValue``.
        """
        return self.getVar(chosenWorkersValue)
    


    

        



