from .comlib import *

class WorkerModel():
    '''
    Evaluate per-worker losses and gradients of a model over an aggregated
    multi-worker dataset.

    Batches produced by the dataloaders are unpacked as
    ``(input_args, output_args)``. ``input_args`` go to the model and
    ``output_args`` (which start with the per-sample worker ids) go to the
    loss side.
    '''
    def __init__(self, modelSetup: training.ModelSetup, workerDataset: dataset.AggWorkersDatasetFromConf):
        '''
        model input: (data,target)

        Args:
            modelSetup: provides ``device``, ``calcuLoss`` and ``calcuJacobian``.
            workerDataset: aggregated dataset; must expose ``worker_num``,
                ``getSubWokerDataset`` and ``getSubWokerDataNum``.
        '''
        self.device = modelSetup.device
        self.modelSetup = modelSetup
        self.workerDataset = workerDataset
        self.worker_num = self.workerDataset.worker_num

    @staticmethod
    def getfuncOut(func, *args):
        '''
        Partially apply ``func`` on its leading arguments.

        Returns ``f(out, *out_args)``, which calls ``func(*args, out, *out_args)``.
        '''
        def f1(out, *out_args):
            return func(*args, out, *out_args)
        return f1

    @staticmethod
    def getfunc(func, funcOut):
        '''
        Bind ``funcOut`` as the middle argument of ``func``.

        Returns ``f(input_args, output_args)``, which calls
        ``func(input_args, funcOut, output_args)``.
        '''
        def f1(input_args, output_args):
            return func(input_args, funcOut, output_args)
        return f1

    @staticmethod
    def composefunc(f1, f2):
        '''Return the composition ``lambda *args: f1(f2(*args))``.'''
        return lambda *args: f1(f2(*args))

    @staticmethod
    def scatterValue(criterion_unreduced, worker_num, out, worker_ids, *out_args):
        '''
        Sum an unreduced per-sample criterion into per-worker buckets.

        Args:
            criterion_unreduced: callable ``(out, *out_args)`` returning
                per-sample values (flattened with ``view(-1)``).
            worker_num: number of worker buckets.
            out: model output.
            worker_ids: int64 index tensor, one worker id per sample.

        Returns:
            Tensor of shape ``(worker_num,)`` with the loss dtype.
        '''
        device = out.device
        loss_data_vec = criterion_unreduced(out, *out_args).view(-1)
        loss_worker_vec = torch.zeros((worker_num,), dtype=loss_data_vec.dtype, device=device)
        loss_worker_vec.scatter_add_(0, worker_ids, loss_data_vec)
        return loss_worker_vec

    @staticmethod
    def scatterValues(criterion_list, worker_num, out, worker_ids, *out_args):
        '''
        Multi-criterion version of ``scatterValue``.

        Returns:
            Tensor of shape ``(len(criterion_list), worker_num)``. Row ``i``
            holds the per-worker sums of ``criterion_list[i]``. The dtype
            follows the first criterion's output (float32 if the list is empty).
        '''
        device = out.device
        loss_data_vecs = [criterion_unreduced(out, *out_args).view(-1)
                          for criterion_unreduced in criterion_list]
        # Follow the loss dtype (as scatterValue does) instead of hard-coding
        # float32: scatter_add_ requires source and destination dtypes to match.
        dtype = loss_data_vecs[0].dtype if loss_data_vecs else torch.float32
        loss_worker_vec = torch.zeros((len(criterion_list), worker_num), dtype=dtype, device=device)
        for i, loss_data_vec in enumerate(loss_data_vecs):
            loss_worker_vec[i].scatter_add_(0, worker_ids, loss_data_vec.to(dtype))
        return loss_worker_vec

    @staticmethod
    def getDataloaderAvg(dataloader, func):
        '''
        Apply ``func(input_args, output_args)`` to every batch and accumulate
        the results in a ``util.MovingAvg``.

        Each batch is added with weight 1. The callers in this class read the
        accumulated total via ``get_sum()``.
        '''
        value = util.MovingAvg()
        for input_args, output_args in dataloader:
            value.update(func(input_args, output_args), 1)
        return value

    def getSubWorkerDataset(self, chosen_workers):
        '''Sub-dataset restricted to ``chosen_workers``.'''
        # NOTE: the dataset API name is spelled ``getSubWokerDataset`` upstream.
        return self.workerDataset.getSubWokerDataset(chosen_workers)

    def getSubWorkerDataNum(self, chosen_workers):
        '''Per-worker sample counts for ``chosen_workers``.'''
        return self.workerDataset.getSubWokerDataNum(chosen_workers)

    def getValueFromDataset(self, criterion_unreduced, dataset, batch_size, scatter_meth=None):
        '''
        Sum a criterion per worker over the whole ``dataset``.

        Args:
            criterion_unreduced: a criterion_unreduced, or a list of them when
                ``scatter_meth`` is ``scatterValues``.
            dataset: dataset exposing ``worker_num``.
            batch_size: dataloader batch size (the order is not shuffled).
            scatter_meth: ``scatterValue`` (default) or ``scatterValues``.

        Returns:
            The per-worker sums (unnormalised) produced by ``scatter_meth``,
            added up over all batches.
        '''
        if scatter_meth is None:
            scatter_meth = self.scatterValue
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        funcLoss = self.getfuncOut(scatter_meth, criterion_unreduced, dataset.worker_num)
        func = self.getfunc(self.modelSetup.calcuLoss, funcLoss)
        value = self.getDataloaderAvg(dataloader, func)
        return value.get_sum()

    def getChosenWorkersValue(self, valueFunc, chosen_workers, batch_size) -> torch.Tensor:
        '''
        Per-worker mean of ``valueFunc`` for ``chosen_workers``.

        If the dataset is an ``AggWorkersDatasetWithRedundancy``, the result
        is expanded back via ``restoreFromUnique``.
        '''
        chosenWorkerDataset = self.getSubWorkerDataset(chosen_workers)
        chosenWorkersDataNum = self.getSubWorkerDataNum(chosen_workers).to(self.device)

        workerValues = self.getValueFromDataset(valueFunc, chosenWorkerDataset, batch_size)

        # Sums -> means.
        workerValues = workerValues / chosenWorkersDataNum
        if isinstance(self.workerDataset, dataset.AggWorkersDatasetWithRedundancy):
            workerValues = self.workerDataset.restoreFromUnique(workerValues, chosen_workers)
        return workerValues

    def getChosenWorkersValues(self, valueFunc, chosen_workers, batch_size) -> torch.Tensor:
        '''
        Per-worker means for a list of criteria.

        Returns:
            Tensor of shape ``(len(valueFunc), len(chosen_workers))``.
        '''
        chosenWorkerDataset = self.getSubWorkerDataset(chosen_workers)
        chosenWorkersDataNum = self.getSubWorkerDataNum(chosen_workers).to(self.device)

        workerValues = self.getValueFromDataset(valueFunc, chosenWorkerDataset, batch_size, self.scatterValues)
        # Broadcast the per-worker counts over the criterion axis (row-wise divide).
        return workerValues / chosenWorkersDataNum

    def getJacFromDataset(self, criterion_unreduced, dataset, batch_size) -> training.TensorTuple:
        '''
        Jacobian of the per-worker summed criterion w.r.t. the model
        parameters, accumulated over all batches of ``dataset``.
        '''
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        funcLoss = self.getfuncOut(self.scatterValue, criterion_unreduced, dataset.worker_num)
        # Wrap the raw Jacobian returned by calcuJacobian in a ModelParaS.
        funcGetJac = self.composefunc(training.ModelParaS, self.modelSetup.calcuJacobian)
        func = self.getfunc(funcGetJac, funcLoss)
        jacobi_tuple = self.getDataloaderAvg(dataloader, func)
        return jacobi_tuple.get_sum()

    def getChosenWorkersGrad(self, valueFunc, chosen_workers, batch_size, flat=False) -> training.ModelParaS | util.FlatModelParaS:
        '''
        Per-worker mean gradient of ``valueFunc`` for ``chosen_workers``.

        Returns a ``util.FlatModelParaS`` when ``flat`` is true or the dataset
        has redundancy (restored via ``restoreFromUnique``). Otherwise it
        returns the per-parameter ``ModelParaS``.
        '''
        chosenWorkerDataset = self.getSubWorkerDataset(chosen_workers)
        chosenWorkersDataNum = self.getSubWorkerDataNum(chosen_workers).to(self.device)

        jacobi_tuple = self.getJacFromDataset(valueFunc, chosenWorkerDataset, batch_size)
        jacobi_tuple.divide_coeffs(chosenWorkersDataNum)

        if isinstance(self.workerDataset, dataset.AggWorkersDatasetWithRedundancy):
            jacobi = util.flattenToFlatModelParaS(jacobi_tuple.tensors)
            jacobi = self.workerDataset.restoreFromUnique(jacobi.tensors, chosen_workers)
            jacobi = util.FlatModelParaS(jacobi)
        elif flat:
            jacobi = util.flattenToFlatModelParaS(jacobi_tuple.tensors)
        else:
            jacobi = jacobi_tuple

        return jacobi
    
class WorkerModelKernel(WorkerModel):
    '''
    ``WorkerModel`` extension that builds Gram ("kernel") matrices of
    flattened per-worker gradients, ``K[a, b] = <g_a, g_b>``.

    Gradients are computed ``seg_len`` workers at a time, which limits how
    many per-worker gradients are materialised at once.
    '''
    @staticmethod
    def segment_list(lst, seg_len):
        '''
        Split ``lst`` into consecutive chunks of at most ``seg_len`` items.

        Returns:
            ``(index_chunks, item_chunks)``, where ``index_chunks[k]`` lists the
            positions in ``lst`` of the items in ``item_chunks[k]``.

        Raises:
            ValueError: if ``seg_len`` is not positive.
        '''
        if seg_len <= 0:
            raise ValueError(f"seg_len must be positive, got {seg_len}")
        n = len(lst)
        # Number of segments = ceil(n / seg_len).
        seg_num = (n + seg_len - 1) // seg_len
        ilst = list(range(n))
        return ([ilst[i * seg_len:(i + 1) * seg_len] for i in range(seg_num)],
                [lst[i * seg_len:(i + 1) * seg_len] for i in range(seg_num)])

    def getFlattenChosenWorkersGrad(self, valueFunc, chosen_workers, batch_size):
        '''
        Per-worker mean gradients of ``valueFunc`` for ``chosen_workers``,
        flattened to a single tensor (one row per worker).
        '''
        chosenWorkerDataset = self.getSubWorkerDataset(chosen_workers)
        chosenWorkersDataNum = self.getSubWorkerDataNum(chosen_workers).to(self.device)

        jacobi_tuple = self.getJacFromDataset(valueFunc, chosenWorkerDataset, batch_size)
        jacobi_tuple.divide_coeffs(chosenWorkersDataNum)

        jacobi = util.flattenToFlatModelParaS(jacobi_tuple.tensors)
        return jacobi.tensors

    def getChosenWorkersGradKernel(self, valueFunc, chosen_workers, batch_size, seg_len):
        '''
        Full gradient Gram matrix for ``chosen_workers``, filled block by block.

        Returns:
            ``(len(chosen_workers), len(chosen_workers))`` tensor on ``self.device``.
        '''
        chosen_num = len(chosen_workers)
        seg_indices, seg_workers = self.segment_list(chosen_workers, seg_len)
        kernel = torch.zeros([chosen_num, chosen_num], device=self.device)
        for i, indices_i in enumerate(seg_indices):
            jacobi1 = self.getFlattenChosenWorkersGrad(valueFunc, seg_workers[i], batch_size)
            for j, indices_j in enumerate(seg_indices):
                # The diagonal block reuses the row segment's Jacobian instead of recomputing it.
                if j == i:
                    jacobi2 = jacobi1
                else:
                    jacobi2 = self.getFlattenChosenWorkersGrad(valueFunc, seg_workers[j], batch_size)
                kernel[np.ix_(indices_i, indices_j)] = jacobi1 @ jacobi2.T

        return kernel

    def getChosenWorkersGradKernelHalf(self, valueFunc, chosen_workers, batch_size, seg_len):
        '''
        Same result as ``getChosenWorkersGradKernel``, but only the upper
        block triangle is computed. Each lower block is the transpose of its
        mirror, since the kernel is symmetric.
        '''
        chosen_num = len(chosen_workers)
        seg_indices, seg_workers = self.segment_list(chosen_workers, seg_len)
        kernel = torch.zeros([chosen_num, chosen_num], device=self.device)
        for i, indices_i in enumerate(seg_indices):
            jacobi1 = self.getFlattenChosenWorkersGrad(valueFunc, seg_workers[i], batch_size)
            kernel[np.ix_(indices_i, indices_i)] = jacobi1 @ jacobi1.T
            for j in range(i + 1, len(seg_indices)):
                indices_j = seg_indices[j]
                jacobi2 = self.getFlattenChosenWorkersGrad(valueFunc, seg_workers[j], batch_size)
                block = jacobi1 @ jacobi2.T
                kernel[np.ix_(indices_i, indices_j)] = block
                kernel[np.ix_(indices_j, indices_i)] = block.T
        return kernel

    def getChosenWorkersGradKernelPrior(self, valueFunc, chosen_workers, batch_size, seg_len, r_map, priorList=None):
        '''
        Gradient Gram matrix for ``chosen_workers``, computed only on their
        representative workers.

        The kernel is computed on ``priorList``. Entry ``(a, b)`` of the result
        is the kernel entry between ``r_map[chosen_workers[a]]`` and
        ``r_map[chosen_workers[b]]``.

        Args:
            r_map: indexable mapping from a worker id to its representative
                worker id (presumably the redundancy map; confirm with callers).
            priorList: list of representative ids to evaluate. It must contain
                ``r_map[w]`` for every chosen ``w``. When ``None``, it defaults
                to the distinct ``r_map[w]`` values in first-seen order.

        Raises:
            ValueError: if some ``r_map[w]`` is not in ``priorList``.
        '''
        if priorList is None:
            # Normalise scalar tensors/numpy ints to Python scalars so that
            # duplicates collapse in the dict-based de-duplication.
            def _key(v):
                return v.item() if hasattr(v, "item") else v
            priorList = list(dict.fromkeys(_key(r_map[w]) for w in chosen_workers))

        kernel = self.getChosenWorkersGradKernelHalf(valueFunc, priorList, batch_size, seg_len)
        inverse_idx = [priorList.index(r_map[i]) for i in chosen_workers]
        restored_kernel = kernel[inverse_idx][:, inverse_idx]
        return restored_kernel

    

class WrapWorkers():
    '''
    Default batch wrapper: only ``data`` is fed to the model. ``worker_ids``
    and ``target`` are returned as the output-side arguments.
    '''
    def __init__(self):
        '''Stateless; nothing to initialise.'''

    @staticmethod
    def wrap(worker_ids, data, target):
        '''Return ``((data,), (worker_ids, target))``.'''
        return (data,), (worker_ids, target)

class WrapWorkersDiscriminator():
    '''
    Discriminator-style batch wrapper: the model consumes both ``data`` and
    ``target``. Only ``worker_ids`` is returned as the output-side argument.
    '''
    def __init__(self):
        '''Stateless; nothing to initialise.'''

    @staticmethod
    def wrap(worker_ids, data, target):
        '''Return ``((data, target), (worker_ids,))``.'''
        return (data, target), (worker_ids,)
    

def create_wrapper(wrap_method:Literal["default","discrim"]):
    '''
    Return the batch wrapper selected by ``wrap_method``.

    ``"default"`` returns ``WrapWorkers`` and ``"discrim"`` returns
    ``WrapWorkersDiscriminator``.

    Raises:
        ValueError: if ``wrap_method`` is not one of the supported names.
    '''
    if wrap_method == "default":
        return WrapWorkers()
    if wrap_method == "discrim":
        return WrapWorkersDiscriminator()
    # Fail loudly with a clear message. Before this fix, the code printed a
    # message and then died with UnboundLocalError on an unset local.
    raise ValueError(f"invalid wrap_method: {wrap_method!r} (expected 'default' or 'discrim')")