from ..comlib import *

class Discriminator(ABC):
    """Abstract base for evaluating and maximizing a loss over sampled workers.

    Subclasses provide the loss (``getLoss``), its gradient (``getLossGrad``),
    progress reporting (``getCurStat``) and result-sheet setup
    (``set_sheets``). This base class supplies helpers that run a value or
    Jacobian function over the data of a chosen subset of workers and divide
    the result by those workers' data counts.
    """

    def __init__(self, modelSetup: training.ModelSetup, workerDataset: dataset.AggWorkersDatasetFromConf, save_folder, label=''):
        '''
        model input: (data,target)

        Args:
            modelSetup: provides ``device``, ``calcuLoss`` and ``calcuJacobian``.
            workerDataset: per-worker dataset; ``worker_num`` is read from it
                and sub-datasets are drawn from it.
            save_folder: stored for subclasses (presumably where results are
                written; not used in this base class).
            label: free-form tag stored on the instance.
        '''
        self.device = modelSetup.device
        self.modelSetup = modelSetup
        self.workerDataset = workerDataset
        self.worker_num = self.workerDataset.worker_num
        # Kept for subclasses; this base class never uses it.
        self.sigmoid = nn.Sigmoid()

        self.save_folder = save_folder
        self.label = label
        # Subclass hook. It runs last so it can rely on every attribute above.
        self.set_sheets()

    def getfunc(self, func, *args):
        """Return ``func`` with ``args`` bound AFTER the call-time arguments.

        ``getfunc(g, a, b)(out, x)`` calls ``g(out, x, a, b)``. Unlike
        ``functools.partial``, the extra arguments are appended, not prepended.
        """
        def bound(out, *out_args):
            return func(out, *out_args, *args)
        return bound

    def _iterBatches(self, dataset, batch_size):
        """Yield ``((data, target), (worker_ids,))`` for each batch of ``dataset``.

        Batches come from an unshuffled ``DataLoader``, so they arrive in a
        fixed order. Each batch must unpack as ``(worker_ids, data, target)``.
        """
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        for worker_ids, data, target in dataloader:
            yield (data, target), (worker_ids,)

    def _chosenSubset(self, chosen_workers):
        """Return ``(sub_dataset, data_counts)`` for ``chosen_workers``.

        The data counts are moved to ``self.device`` so they can divide
        results computed on that device.
        """
        chosenWorkerDataset = self.workerDataset.getSubWokerDataset(chosen_workers)
        chosenWorkersDataNum = self.workerDataset.getSubWokerDataNum(chosen_workers).to(self.device)
        return chosenWorkerDataset, chosenWorkersDataNum

    def getValueFromDataset(self, func, dataset, batch_size):
        """Accumulate ``func``'s value over every batch of ``dataset``.

        For each batch, ``modelSetup.calcuLoss((data, target), func,
        (worker_ids,))`` is evaluated and fed to a ``util.MovingAvg`` with
        weight 1. The accumulator's ``get_sum()`` is returned. This is
        presumably the plain sum over batches; verify against
        ``util.MovingAvg``.
        """
        value = util.MovingAvg()
        for input_args, output_args in self._iterBatches(dataset, batch_size):
            value.update(self.modelSetup.calcuLoss(input_args, func, output_args), 1)
        return value.get_sum()

    def getChosenWorkersValue(self, valueFunc, chosen_workers, batch_size) -> torch.Tensor:
        """Evaluate ``valueFunc`` on the chosen workers' data, normalized by data count.

        The value accumulated by ``getValueFromDataset`` is divided
        element-wise by the chosen workers' data counts. That is presumably
        one entry per chosen worker; confirm with ``getSubWokerDataNum``.
        """
        chosenWorkerDataset, chosenWorkersDataNum = self._chosenSubset(chosen_workers)
        workerValues = self.getValueFromDataset(valueFunc, chosenWorkerDataset, batch_size)
        return workerValues / chosenWorkersDataNum

    def getJacFromDataset(self, func, dataset, batch_size) -> training.TensorTuple:
        """Accumulate the Jacobian of ``func`` over every batch of ``dataset``.

        Each batch's ``modelSetup.calcuJacobian`` result is wrapped in a
        ``training.TensorTuple`` and fed to a ``util.MovingAvg`` with
        weight 1. The accumulator's ``get_sum()`` is returned.
        """
        jacobi_tuple = util.MovingAvg()
        for input_args, output_args in self._iterBatches(dataset, batch_size):
            tempGrad = self.modelSetup.calcuJacobian(input_args, func, output_args)
            jacobi_tuple.update(training.TensorTuple(tempGrad), 1)
        return jacobi_tuple.get_sum()

    def getChosenWorkersGrad(self, valueFunc, chosen_workers, batch_size) -> training.ModelParaS:
        """Return the Jacobian of ``valueFunc`` on the chosen workers' data.

        The result is a ``training.ModelParaS``, scaled by the chosen
        workers' data counts through ``divide_coeffs``.
        """
        chosenWorkerDataset, chosenWorkersDataNum = self._chosenSubset(chosen_workers)
        jacobi_tuple = self.getJacFromDataset(valueFunc, chosenWorkerDataset, batch_size)

        workersGrad = training.ModelParaS(jacobi_tuple.tensors)
        # The return value is ignored, so divide_coeffs presumably mutates
        # workersGrad in place.
        workersGrad.divide_coeffs(chosenWorkersDataNum)
        return workersGrad

    @abstractmethod
    def getLoss(self, chosen_workers, batch_size):
        """Return the loss for ``chosen_workers`` (subclass-defined)."""

    @abstractmethod
    def getLossGrad(self, chosen_workers, batch_size):
        """Return the loss gradient for ``chosen_workers``.

        The result must expose ``.tensors``, which ``maximizeLoss`` passes to
        the optimizer.
        """

    def maximizeLoss(self, max_round, optimizer: training.Optimizer, chosenWorkerNum, batch_size=400, stat_interval=10):
        '''
        Run ``max_round`` optimizer steps, each on a freshly sampled set of workers.

        Each round samples ``chosenWorkerNum`` distinct worker indices
        uniformly without replacement from ``range(self.worker_num)``. It
        computes ``getLossGrad`` on them and applies the gradient with
        ``optimizer.stepByGradTuple``. Whether this ascends the loss depends
        on ``optimizer`` being configured to maximize. NOTE(review): this
        method does not check that.

        Args:
            max_round: number of optimization rounds.
            optimizer: applies each gradient tuple.
            chosenWorkerNum: workers sampled per round.
            batch_size: DataLoader batch size used when computing gradients.
            stat_interval: call ``getCurStat`` every this many rounds,
                starting with round 0. A falsy value disables reporting.
                Defaults to 10, the previously hard-coded interval.

        Raises:
            ValueError: from ``random.sample`` if ``chosenWorkerNum`` is
                negative or larger than ``self.worker_num``.
        '''
        for round_idx in range(max_round):
            chosenWorkers = random.sample(range(self.worker_num), k=chosenWorkerNum)
            grad_tuple = self.getLossGrad(chosenWorkers, batch_size).tensors

            optimizer.stepByGradTuple(grad_tuple)
            if stat_interval and round_idx % stat_interval == 0:
                self.getCurStat(round_idx)

    @abstractmethod
    def getCurStat(self, round):
        """Report or record statistics at optimization round ``round``."""

    @abstractmethod
    def set_sheets(self):
        """Prepare the result sheets; called at the end of ``__init__``."""

