from .comlib import *

class WorkerModel():
    '''
    Evaluate per-worker values and gradients of a model over a worker-partitioned dataset.

    Items drawn from the worker dataset are unpacked as ``(worker_ids, data, target)``.
    ``(data, target)`` is handed to the model setup as the model input and
    ``(worker_ids,)`` is forwarded as extra output arguments to the value function.
    '''
    def __init__(self, modelSetup: training.ModelSetup, workerDataset:dataset.AggWorkersDatasetFromConf):
        '''
        model input: (data,target)

        Args:
            modelSetup: provides ``device``, ``calcuLoss`` and ``calcuJacobian``.
            workerDataset: provides ``worker_num`` and the sub-dataset / data-count
                lookups (``getSubWokerDataset`` / ``getSubWokerDataNum``).
        '''
        self.device=modelSetup.device
        self.modelSetup=modelSetup
        self.workerDataset=workerDataset
        self.worker_num=self.workerDataset.worker_num

    def getfunc(self,func,*args):
        '''
        Bind trailing arguments to ``func``.

        Returns a callable ``f1`` such that ``f1(out, *out_args)`` calls
        ``func(out, *out_args, *args)``. The bound ``args`` are appended *after* the
        per-call arguments, which is why ``functools.partial`` (which prepends) is
        not used here.
        '''
        def f1(out,*out_args):
            return func(out,*out_args,*args)
        return f1

    def getSubWorkerDataset(self,chosen_workers):
        '''Return the sub-dataset restricted to ``chosen_workers``.'''
        # NOTE: ``getSubWokerDataset`` (sic) is the name exposed by the dataset class.
        return self.workerDataset.getSubWokerDataset(chosen_workers)

    def getSubWorkerDataNum(self,chosen_workers):
        '''Return the data counts of ``chosen_workers`` as reported by the dataset.'''
        # NOTE: ``getSubWokerDataNum`` (sic) is the name exposed by the dataset class.
        return self.workerDataset.getSubWokerDataNum(chosen_workers)

    @staticmethod
    def _iterBatchArgs(dataset,batch_size):
        '''Yield ``(input_args, output_args)`` for every batch of ``dataset``, in order.'''
        # shuffle=False keeps the batch order deterministic across calls.
        dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=False)
        for worker_ids, data, target in dataloader:
            yield (data,target), (worker_ids,)

    def getValueFromDataset(self,func,dataset,batch_size):
        '''
        Accumulate ``modelSetup.calcuLoss`` over all batches of ``dataset``.

        Args:
            func: per-batch value function passed to ``calcuLoss`` together with
                ``(worker_ids,)`` -- presumably a callable built by ``getfunc``;
                confirm against ``ModelSetup.calcuLoss``.
            dataset: dataset yielding ``(worker_ids, data, target)`` items.
            batch_size: DataLoader batch size.

        Returns:
            ``util.MovingAvg.get_sum()`` over the per-batch values (each added with weight 1).
        '''
        value=util.MovingAvg()
        for input_args, output_args in self._iterBatchArgs(dataset,batch_size):
            value.update(self.modelSetup.calcuLoss(input_args,func,output_args),1)
        return value.get_sum()

    def getChosenWorkersValue(self,valueFunc,chosen_workers,batch_size)->torch.Tensor:
        '''
        Return the accumulated ``valueFunc`` value for ``chosen_workers`` divided by
        their data counts.

        Args:
            valueFunc: per-batch value function (see ``getValueFromDataset``).
            chosen_workers: workers whose sub-dataset is evaluated.
            batch_size: DataLoader batch size.
        '''
        chosenWorkerDataset=self.getSubWorkerDataset(chosen_workers)
        # Move the counts to the model device so the division below does not mix devices.
        chosenWorkersDataNum=self.getSubWorkerDataNum(chosen_workers).to(self.device)

        workerValues=self.getValueFromDataset(valueFunc,chosenWorkerDataset,batch_size)
        # Element-wise division; presumably both hold one entry per chosen worker.
        return workerValues/chosenWorkersDataNum

    def getJacFromDataset(self,func,dataset,batch_size)->training.TensorTuple:
        '''
        Accumulate ``modelSetup.calcuJacobian`` over all batches of ``dataset``.

        Each per-batch result is wrapped in ``training.TensorTuple`` and added with
        weight 1 to a ``util.MovingAvg``; the accumulated sum is returned.
        '''
        jacobi_tuple=util.MovingAvg()
        for input_args, output_args in self._iterBatchArgs(dataset,batch_size):
            tempGrad=self.modelSetup.calcuJacobian(input_args,func,output_args)
            jacobi_tuple.update(training.TensorTuple(tempGrad),1)
        return jacobi_tuple.get_sum()

    def getChosenWorkersGrad(self,valueFunc,chosen_workers,batch_size)->training.ModelParaS:
        '''
        Return the accumulated Jacobian of ``valueFunc`` for ``chosen_workers`` as a
        ``training.ModelParaS`` whose coefficients are divided by the workers' data counts.
        '''
        chosenWorkerDataset=self.getSubWorkerDataset(chosen_workers)
        chosenWorkersDataNum=self.getSubWorkerDataNum(chosen_workers).to(self.device)

        jacobi_tuple=self.getJacFromDataset(valueFunc,chosenWorkerDataset,batch_size)

        workersGrad=training.ModelParaS(jacobi_tuple.tensors)
        workersGrad.divide_coeffs(chosenWorkersDataNum)
        return workersGrad
    

class WorkerModelSigmoid(WorkerModel):
    '''
    ``WorkerModel`` whose per-worker value is built from ``sigmoid(model output)``
    summed over each worker's samples.
    '''
    def __init__(self, modelSetup, workerDataset):
        super().__init__(modelSetup, workerDataset)

    @staticmethod
    def scatterSigmoidValue(out,worker_ids,worker_num):
        '''
        Sum ``sigmoid(out)`` per worker.

        Args:
            out: model output, flattened to one value per sample (assumes a single
                scalar output per sample -- TODO confirm).
            worker_ids: per-sample worker index used directly as the
                ``scatter_add_`` index. It must be an int64 tensor on ``out``'s
                device, with values in ``[0, worker_num)``.
            worker_num: length of the returned vector.

        Returns:
            Tensor of shape ``(worker_num,)`` on ``out``'s device. Entry ``w`` is the
            sum of ``sigmoid(out)`` over the samples whose ``worker_ids == w``.
        '''
        # torch.sigmoid is the functional form of nn.Sigmoid; it avoids constructing
        # a throwaway Module on every batch.
        loss_data_vec=torch.sigmoid(out).view(-1)
        loss_worker_vec=torch.zeros((worker_num,),dtype=loss_data_vec.dtype,device=out.device)
        loss_worker_vec.scatter_add_(0,worker_ids,loss_data_vec)
        return loss_worker_vec

    def _sigmoidValueFunc(self,chosen_workers):
        '''Build the per-batch value function: ``scatterSigmoidValue`` with ``worker_num=len(chosen_workers)``.'''
        # NOTE(review): the scatter index range is [0, len(chosen_workers)), so this
        # assumes the sub-dataset re-labels worker ids to positions within
        # chosen_workers -- confirm against getSubWokerDataset.
        return self.getfunc(self.scatterSigmoidValue,len(chosen_workers))

    def getChosenWorkersSigmoidValue(self,chosen_workers: list[int],batch_size):
        '''
        Return, per chosen worker, the summed ``sigmoid(model output)`` divided by that
        worker's data count (see ``getChosenWorkersValue``).
        '''
        return self.getChosenWorkersValue(self._sigmoidValueFunc(chosen_workers),chosen_workers,batch_size)

    def getChosenWorkersSigmoidGrad(self,chosen_workers: list[int],batch_size):
        '''
        Return the gradient counterpart of ``getChosenWorkersSigmoidValue``
        (see ``getChosenWorkersGrad``).
        '''
        return self.getChosenWorkersGrad(self._sigmoidValueFunc(chosen_workers),chosen_workers,batch_size)