from .comlib import *

class WorkerModel():
    '''
    Evaluate a model on an aggregated multi-worker dataset and reduce
    per-sample quantities (unreduced losses, Jacobians) to per-worker ones.

    model input: (data,target)
    '''
    def __init__(self, modelSetup: "training.ModelSetup", workerDataset: "dataset.AggWorkersDatasetFromConf"):
        '''
        model input: (data,target)

        modelSetup: supplies `device`, `calcuLoss` and `calcuJacobian`.
        workerDataset: aggregated dataset exposing `worker_num`,
            `getSubWokerDataset` and `getSubWokerDataNum`.
        '''
        self.device = modelSetup.device
        self.modelSetup = modelSetup
        self.workerDataset = workerDataset
        self.worker_num = self.workerDataset.worker_num

    @staticmethod
    def getfuncOut(func, *args):
        '''
        Bind the leading arguments of `func`.

        Returns f(out, *out_args) which calls func(*args, out, *out_args).
        '''
        def f1(out, *out_args):
            return func(*args, out, *out_args)
        return f1

    @staticmethod
    def getfunc(func, funcOut):
        '''
        Adapt func(input_args, funcOut, output_args) to the per-batch
        signature f(input_args, output_args) used by getDataloaderAvg.
        '''
        def f1(input_args, output_args):
            return func(input_args, funcOut, output_args)
        return f1

    @staticmethod
    def composefunc(f1, f2):
        '''Return the composition: (*args) -> f1(f2(*args)).'''
        return lambda *args: f1(f2(*args))

    @staticmethod
    def scatterValue(criterion_unreduced, worker_num, out, worker_ids, *out_args):
        '''
        Sum an unreduced criterion per worker.

        criterion_unreduced(out, *out_args) is flattened to one value per
        sample, and sample k is added to bin worker_ids[k].
        worker_ids must be an int64 index tensor (a torch scatter_add_
        requirement). Presumably modelSetup.calcuLoss passes it as the first
        extra argument after `out`.

        Returns a tensor of shape (worker_num,) with the criterion's dtype.
        '''
        device = out.device
        loss_data_vec = criterion_unreduced(out, *out_args).view(-1)
        loss_worker_vec = torch.zeros((worker_num,), dtype=loss_data_vec.dtype, device=device)
        loss_worker_vec.scatter_add_(0, worker_ids, loss_data_vec)
        return loss_worker_vec

    @staticmethod
    def scatterValues(criterion_list, worker_num, out, worker_ids, *out_args):
        '''
        Per-worker sums for several unreduced criteria at once.

        Returns a (len(criterion_list), worker_num) tensor. Row i holds the
        per-worker sums of criterion_list[i](out, *out_args).view(-1).

        The accumulator dtype is float32 promoted with every criterion's
        output dtype. So float64 criteria stay float64, while half, integer
        and bool criteria (e.g. 0/1 correctness indicators) are summed in
        float32.
        '''
        loss_data_vecs = [criterion_unreduced(out, *out_args).view(-1)
                          for criterion_unreduced in criterion_list]
        # scatter_add_ requires matching dtypes; a hard-coded float32
        # accumulator made any non-float32 criterion raise.
        dtype = torch.float32
        for loss_data_vec in loss_data_vecs:
            dtype = torch.promote_types(dtype, loss_data_vec.dtype)
        loss_worker_vec = torch.zeros((len(loss_data_vecs), worker_num), dtype=dtype, device=out.device)
        for i, loss_data_vec in enumerate(loss_data_vecs):
            loss_worker_vec[i].scatter_add_(0, worker_ids, loss_data_vec.to(dtype))
        return loss_worker_vec

    @staticmethod
    def getDataloaderAvg(dataloader, func):
        '''
        Apply func(input_args, output_args) to every batch of `dataloader`.

        Each batch's result is accumulated into a util.MovingAvg with weight 1,
        and the MovingAvg is returned. Callers in this class read the total
        through get_sum().
        '''
        value = util.MovingAvg()
        for input_args, output_args in dataloader:
            value.update(func(input_args, output_args), 1)
        return value

    def getSubWorkerDataset(self, chosen_workers):
        '''Sub-dataset restricted to `chosen_workers` (dataset API spells it "Woker").'''
        return self.workerDataset.getSubWokerDataset(chosen_workers)

    def getSubWorkerDataNum(self, chosen_workers):
        '''Per-worker sample counts of the sub-dataset for `chosen_workers`.'''
        return self.workerDataset.getSubWokerDataNum(chosen_workers)

    def getValueFromDataset(self, criterion_unreduced, dataset, batch_size, scatter_meth=None):
        '''
        Accumulate per-worker criterion sums over `dataset`, without gradients.

        criterion_unreduced or list of criterion_unreduced.
        scatter_meth: scatterValue (default, single criterion) or
            scatterValues (list of criteria).

        Returns the accumulated get_sum() of the per-batch results. Its shape
        is (dataset.worker_num,) for scatterValue, or
        (len(list), dataset.worker_num) for scatterValues.
        '''
        if scatter_meth is None:
            scatter_meth = self.scatterValue
        # shuffle=False keeps the batch order deterministic across calls
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        funcLoss = self.getfuncOut(scatter_meth, criterion_unreduced, dataset.worker_num)
        func = self.getfunc(self.modelSetup.calcuLoss, funcLoss)
        with torch.no_grad():
            value = self.getDataloaderAvg(dataloader, func)
        return value.get_sum()

    def getChosenWorkersValue(self, valueFunc, chosen_workers, batch_size) -> torch.Tensor:
        '''
        Per-worker mean of `valueFunc`: the per-worker sum divided by that
        worker's sample count (getSubWokerDataNum).
        '''
        chosenWorkerDataset = self.workerDataset.getSubWokerDataset(chosen_workers)
        chosenWorkersDataNum = self.workerDataset.getSubWokerDataNum(chosen_workers).to(self.device)

        workerValues = self.getValueFromDataset(valueFunc, chosenWorkerDataset, batch_size)
        return workerValues / chosenWorkersDataNum

    def getChosenWorkersValues(self, valueFunc, chosen_workers, batch_size) -> torch.Tensor:
        '''
        Like getChosenWorkersValue, but `valueFunc` is a list of unreduced
        criteria. Returns a (len(valueFunc), num_chosen) tensor of per-worker
        means.
        '''
        chosenWorkerDataset = self.workerDataset.getSubWokerDataset(chosen_workers)
        chosenWorkersDataNum = self.workerDataset.getSubWokerDataNum(chosen_workers).to(self.device)

        workerValues = self.getValueFromDataset(valueFunc, chosenWorkerDataset, batch_size, self.scatterValues)
        # broadcast the (num_chosen,) counts over every criterion row
        return workerValues / chosenWorkersDataNum

    def getJacFromDataset_(self, funcOut, dataset, batch_size) -> "training.TensorTuple":
        '''
        Sum over batches of training.ModelParaS(calcuJacobian(input_args,
        funcOut, output_args)).

        Gradients stay enabled here: there is no torch.no_grad, since the
        Jacobian needs them.
        '''
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        funcGetJac = self.composefunc(training.ModelParaS, self.modelSetup.calcuJacobian)
        func = self.getfunc(funcGetJac, funcOut)
        jacobi_tuple = self.getDataloaderAvg(dataloader, func)
        return jacobi_tuple.get_sum()

    def getJacFromDataset(self, criterion_unreduced, dataset, batch_size) -> "training.TensorTuple":
        '''Jacobian of the per-worker sums of `criterion_unreduced` over `dataset`.'''
        funcLoss = self.getfuncOut(self.scatterValue, criterion_unreduced, dataset.worker_num)
        return self.getJacFromDataset_(funcLoss, dataset, batch_size)

    def getChosenWorkersGrad(self, valueFunc, chosen_workers, batch_size, flat=False) -> "training.ModelParaS | util.FlatModelParaS":
        '''
        Per-worker gradient of the mean criterion.

        The Jacobian of the per-worker sums is divided in place
        (divide_coeffs) by each worker's sample count.
        flat=True returns util.flattenToFlatModelParaS(jacobi_tuple.tensors);
        otherwise the Jacobian tuple itself is returned.
        '''
        chosenWorkerDataset = self.workerDataset.getSubWokerDataset(chosen_workers)
        chosenWorkersDataNum = self.workerDataset.getSubWokerDataNum(chosen_workers).to(self.device)

        jacobi_tuple = self.getJacFromDataset(valueFunc, chosenWorkerDataset, batch_size)
        jacobi_tuple.divide_coeffs(chosenWorkersDataNum)
        if flat:
            return util.flattenToFlatModelParaS(jacobi_tuple.tensors)
        return jacobi_tuple

    
class WorkerModelR(WorkerModel):
    '''
    WorkerModel variant aware of redundant workers.

    When the dataset is a dataset.AggWorkersDatasetWithRedundancy, a query for
    `chosen_workers` goes through redundentTensor.get_sub(chosen_workers).
    The base computation runs only on that result's `unique_keys`, and the
    answer is expanded back with `inverse_pos`. Without a redundancy tensor,
    every override simply defers to WorkerModel.
    '''
    def __init__(self, modelSetup, workerDataset):
        super().__init__(modelSetup, workerDataset)
        has_redundancy = isinstance(self.workerDataset, dataset.AggWorkersDatasetWithRedundancy)
        # None means "no redundancy info": all overrides fall back to the base class
        self.redundentTensor = self.workerDataset.redundentTensor if has_redundancy else None

    def getChosenWorkersValue(self, valueFunc, chosen_workers, batch_size) -> torch.Tensor:
        '''Per-worker mean value, evaluated once per unique key and then expanded.'''
        if self.redundentTensor is not None:
            sub = self.redundentTensor.get_sub(chosen_workers)
            unique_values = super().getChosenWorkersValue(valueFunc, sub.unique_keys, batch_size)
            return unique_values[sub.inverse_pos]
        return super().getChosenWorkersValue(valueFunc, chosen_workers, batch_size)

    def getChosenWorkersValues(self, valueFunc, chosen_workers, batch_size) -> torch.Tensor:
        '''Multi-criterion version; expansion happens along the worker (last) axis.'''
        if self.redundentTensor is not None:
            sub = self.redundentTensor.get_sub(chosen_workers)
            unique_values = super().getChosenWorkersValues(valueFunc, sub.unique_keys, batch_size)
            return unique_values[:, sub.inverse_pos]
        return super().getChosenWorkersValues(valueFunc, chosen_workers, batch_size)

    def getChosenWorkersGrad(self, valueFunc, chosen_workers, batch_size, flat=False) -> util.FlatModelParaS:
        '''
        Per-worker gradients, computed once per unique key and then expanded.

        NOTE(review): when a redundancy tensor is present, the result is always
        a util.FlatModelParaS and `flat` is ignored. Without one, `flat` is
        forwarded to the base class. All callers in this file pass flat=True.
        '''
        if self.redundentTensor is not None:
            sub = self.redundentTensor.get_sub(chosen_workers)
            unique_grads = super().getChosenWorkersGrad(valueFunc, sub.unique_keys, batch_size, True)
            return util.FlatModelParaS(unique_grads.tensors[sub.inverse_pos])
        return super().getChosenWorkersGrad(valueFunc, chosen_workers, batch_size, flat)
    
class WorkerModelKernel(WorkerModelR):
    '''
    WorkerModelR that builds Gram matrices ("kernels") of per-worker flat
    gradients. Row r of a segment's getChosenWorkersGrad(..., flat=True).tensors
    belongs to that segment's r-th worker, and each kernel block is
    jacobi_a @ jacobi_b.T.

    Gradients are computed `seg_len` workers at a time, so at most two
    segments' Jacobians are held at once (presumably to bound memory).
    '''
    def __init__(self, modelSetup, workerDataset):
        super().__init__(modelSetup, workerDataset)

    @staticmethod
    def segment_list(lst, seg_len):
        '''
        Split `lst` into consecutive chunks of at most `seg_len` items.

        Returns (index_segments, value_segments). index_segments[k] lists the
        positions in `lst` covered by chunk k, and value_segments[k] is the
        matching slice of `lst`. An empty `lst` yields ([], []).

        Raises ValueError if seg_len < 1.
        '''
        # seg_len == 0 used to divide by zero, and a negative seg_len silently
        # produced no segments (and hence an all-zero kernel downstream).
        if seg_len < 1:
            raise ValueError(f"seg_len must be a positive integer, got {seg_len!r}")
        n = len(lst)
        starts = range(0, n, seg_len)
        index_segments = [list(range(start, min(start + seg_len, n))) for start in starts]
        value_segments = [lst[start:start + seg_len] for start in starts]
        return index_segments, value_segments

    def getChosenWorkersGradKernelFull(self, valueFunc, chosen_workers, batch_size, seg_len):
        '''
        Gram matrix of the chosen workers' flat gradients, filling every
        (i, j) segment block explicitly.

        This needs about S^2 gradient evaluations for S segments;
        getChosenWorkersGradKernelHalf uses symmetry to do roughly half of them.
        '''
        chosen_num = len(chosen_workers)
        seg_indices, seg_workers = self.segment_list(chosen_workers, seg_len)
        kernel = torch.zeros([chosen_num, chosen_num], device=self.device)
        for i, indices_i in enumerate(seg_indices):
            jacobi1 = self.getChosenWorkersGrad(valueFunc, seg_workers[i], batch_size, flat=True).tensors
            for j, indices_j in enumerate(seg_indices):
                # the diagonal block pairs segment i with itself: reuse its gradients
                if j == i:
                    jacobi2 = jacobi1
                else:
                    jacobi2 = self.getChosenWorkersGrad(valueFunc, seg_workers[j], batch_size, flat=True).tensors
                kernel[np.ix_(indices_i, indices_j)] = jacobi1 @ jacobi2.T
        return kernel

    def getChosenWorkersGradKernelHalf(self, valueFunc, chosen_workers, batch_size, seg_len):
        '''
        Gram matrix of the chosen workers' flat gradients, exploiting symmetry.

        Only the diagonal and upper segment blocks are computed; each
        off-diagonal block is mirrored into the lower triangle.
        '''
        chosen_num = len(chosen_workers)
        seg_indices, seg_workers = self.segment_list(chosen_workers, seg_len)
        seg_num = len(seg_indices)
        kernel = torch.zeros([chosen_num, chosen_num], device=self.device)
        for i in range(seg_num):
            indices_i = seg_indices[i]
            jacobi1 = self.getChosenWorkersGrad(valueFunc, seg_workers[i], batch_size, flat=True).tensors
            kernel[np.ix_(indices_i, indices_i)] = jacobi1 @ jacobi1.T
            for j in range(i + 1, seg_num):
                indices_j = seg_indices[j]
                jacobi2 = self.getChosenWorkersGrad(valueFunc, seg_workers[j], batch_size, flat=True).tensors
                block = jacobi1 @ jacobi2.T
                kernel[np.ix_(indices_i, indices_j)] = block
                kernel[np.ix_(indices_j, indices_i)] = block.T
        return kernel

    def getChosenWorkersGradKernelPrior(self, valueFunc, chosen_workers, batch_size, seg_len, priorList=None):
        '''
        not normalized, not centered

        Computes the kernel once per key in `priorList` (default: the unique
        keys of the redundancy mapping for `chosen_workers`). It then expands
        the result to a len(chosen_workers) x len(chosen_workers) matrix,
        where entry (a, b) comes from the keys rt.map[chosen_workers[a]] and
        rt.map[chosen_workers[b]].

        Raises ValueError if a chosen worker's key is not in `priorList`.
        '''
        if self.redundentTensor is None:
            rt = util.RedundentTensor({i: i for i in chosen_workers})
        else:
            rt = self.redundentTensor.get_sub(chosen_workers)

        if priorList is None:
            priorList = rt.unique_keys

        kernel = self.getChosenWorkersGradKernelHalf(valueFunc, priorList, batch_size, seg_len)

        # Position of the first occurrence of each key (same as list.index),
        # replacing an O(len(chosen_workers) * len(priorList)) scan.
        first_pos = {}
        for pos, key in enumerate(priorList):
            first_pos.setdefault(key, pos)
        inverse_idx = []
        for worker in chosen_workers:
            key = rt.map[worker]
            # on a miss, defer to list.index: it handles keys equal but not
            # hash-identical, and raises the same ValueError as before
            inverse_idx.append(first_pos[key] if key in first_pos else priorList.index(key))
        return kernel[inverse_idx][:, inverse_idx]

    def getChosenWorkersGradKernel(self, valueFunc, chosen_workers, batch_size, seg_len, priorList=None, centered=True, normalize=True):
        '''
        Gradient kernel for `chosen_workers`.

        normalize: divide by n = len(chosen_workers).
        centered: apply util.Kernel.center_kernel afterwards (presumably
            H K H with H = I - 1 1^T / n; see that helper).
        '''
        kernel = self.getChosenWorkersGradKernelPrior(valueFunc, chosen_workers, batch_size, seg_len, priorList)
        if normalize:
            kernel = kernel / len(kernel)
        if centered:
            kernel = util.Kernel.center_kernel(kernel)
        return kernel

    




    # @staticmethod
    # def center_kernel(A):
    #     """
    #     Input:  A  (n*n) covariance matrix A = G Gᵀ
    #     Output: B  (n*n) centered kernel matrix H Hᵀ
    #     """
    #     n = A.size(0)
    #     device = A.device
    #     dtype = A.dtype
    #     ones = torch.ones(n, 1, device=device, dtype=dtype)
        
    #     # Centering matrix I - 1 1ᵀ / n
    #     C = torch.eye(n, device=device, dtype=dtype) - ones @ ones.T / n
        
    #     # Formula: B = C A Cᵀ
    #     B = C @ A @ C.T
    #     return B