from .comlib import *
    

class LinearGrad:
    """Gradient of a loss whose per-sample terms are weighted per worker.

    Each chosen worker gets a coefficient from ``weight``.
    ``normalizeCoeff`` divides that coefficient by the value returned from
    ``worker_dataset.getSubWokerDataNum`` (presumably the worker's sample
    count -- confirm). ``getWorkerWeightLoss`` then applies the result to
    every sample via ``weight[worker_ids]``. The gradient/Jacobian of the
    summed loss is delegated to ``ModelSetup.calcuDatasetJacobian``.
    """

    def __init__(self, criterion, batch_size):
        # Used as ``criterion_unreduced`` in getWorkerWeightLoss, so it must
        # return per-sample losses (e.g. reduction='none'), not a scalar.
        self.criterion = criterion
        # Forwarded to ModelSetup.calcuDatasetJacobian.
        self.batch_size = batch_size

    @staticmethod
    def normalizeCoeff(coeff, worker_dataset, chosen_workers: list):
        """Divide each worker's coefficient by its data count.

        Args:
            coeff: tensor of coefficients; assumed to be aligned with
                ``chosen_workers`` -- TODO confirm.
            worker_dataset: provides ``getSubWokerDataNum(chosen_workers)``,
                which must return a tensor (it is moved with ``.to``).
            chosen_workers: workers whose counts are requested.

        Returns:
            ``coeff / datanum``. ``datanum`` is moved to ``coeff``'s device
            first.
        """
        datanum = worker_dataset.getSubWokerDataNum(chosen_workers)
        return coeff / datanum.to(coeff.device)

    @staticmethod
    def getWorkerWeightLoss(criterion_unreduced, weight, out, worker_ids, *out_args):
        """Return ``sum_i loss_i * weight[worker_ids[i]]`` as a scalar tensor.

        Args:
            criterion_unreduced: callable ``(out, *out_args)`` returning one
                loss per sample.
            weight: per-worker (already normalized) coefficients.
            out: model output passed as the criterion's first argument.
            worker_ids: index into ``weight`` for each sample.
            *out_args: remaining criterion arguments (e.g. targets).
        """
        # reshape (not view) so a non-contiguous criterion output still works.
        loss_data_vec = criterion_unreduced(out, *out_args).reshape(-1)
        # Flatten the gathered weights as well. A (B, 1) worker_ids tensor
        # would otherwise broadcast against the (B,) loss vector into a
        # (B, B) product and silently inflate the loss.
        sample_weight = weight[worker_ids].reshape(-1)
        return torch.sum(loss_data_vec * sample_weight)

    @staticmethod
    def get_weighted_grad(weight,chosen_workers,criterion,modelSetup:training.ModelSetup,worker_dataset:dataset.AggWorkersDatasetFromConf,batch_size):
        """Compute the Jacobian of the worker-weighted loss over the chosen workers' data.

        The weights are normalized with ``normalizeCoeff`` and moved to
        ``modelSetup.device``. The loss callable is then given to
        ``modelSetup.calcuDatasetJacobian`` together with the sub-dataset of
        ``chosen_workers``, and its result is returned unchanged.
        """
        device = modelSetup.device
        chosen_worker_dataset = worker_dataset.getSubWokerDataset(chosen_workers)
        weight = LinearGrad.normalizeCoeff(weight, worker_dataset, chosen_workers).to(device)

        def weighted_loss(out, *out_args):
            # The first element of out_args is taken as worker_ids by
            # getWorkerWeightLoss; the rest go to the criterion.
            return LinearGrad.getWorkerWeightLoss(criterion, weight, out, *out_args)

        return modelSetup.calcuDatasetJacobian(weighted_loss, chosen_worker_dataset, batch_size)

    def get(self, weight, chosen_workers, worker_dataset, modelSetup, *args):
        """Return ``.tensors`` of the weighted gradient for ``chosen_workers``.

        Extra positional ``args`` are accepted for interface compatibility
        and ignored.
        """
        grad = self.get_weighted_grad(weight, chosen_workers, self.criterion, modelSetup, worker_dataset, self.batch_size)
        return grad.tensors
    
class LinearGradR(LinearGrad):
    """``LinearGrad`` that merges the weights of redundant workers first.

    If ``worker_dataset`` has a ``redundentTensor`` attribute, every chosen
    worker is mapped through ``rt.map``. The weights of workers mapping to
    the same entry of ``rt.get_sub(chosen_workers).unique_keys`` are summed.
    The gradient is then computed over those unique keys only.
    """

    @staticmethod
    def get_weight_for_redundent(weight, chosen_workers, worker_dataset):
        """Collapse per-worker weights onto the dataset's unique (non-redundant) keys.

        Args:
            weight: 1-D tensor; ``weight[i]`` belongs to ``chosen_workers[i]``.
            chosen_workers: list of worker keys.
            worker_dataset: may expose ``redundentTensor``
                (``util.RedundentTensor``).

        Returns:
            ``(new_weight, unique_chosen_workers)`` when the dataset exposes
            ``redundentTensor``; otherwise ``(weight, chosen_workers)``
            unchanged.

        Raises:
            ValueError: a mapped key is missing from ``unique_keys``.
        """
        if not hasattr(worker_dataset, 'redundentTensor'):
            return weight, chosen_workers

        rt: util.RedundentTensor = worker_dataset.redundentTensor
        rt = rt.get_sub(chosen_workers)
        unique_chosen_workers = rt.unique_keys

        # Build the key -> position lookup once (the old per-item list.index
        # was O(n*m)). Keep the first occurrence to match list.index.
        position = {}
        for idx, unique_key in enumerate(unique_chosen_workers):
            position.setdefault(unique_key, idx)

        # new_zeros keeps weight's dtype and device. torch.zeros would use the
        # default dtype and silently downcast e.g. float64 weights, unlike
        # the non-redundant branch, which passes weight through untouched.
        new_weight = weight.new_zeros(len(unique_chosen_workers))
        for i, key in enumerate(chosen_workers):
            representative = rt.map[key]
            if representative not in position:
                raise ValueError(f"{representative!r} is not in unique_keys")
            j = position[representative]
            new_weight[j] = new_weight[j] + weight[i]
        return new_weight, unique_chosen_workers

    def get(self, weight, chosen_workers, worker_dataset, modelSetup, *args):
        """Merge redundant workers' weights, then delegate to ``LinearGrad.get``."""
        weight, chosen_workers = self.get_weight_for_redundent(weight, chosen_workers, worker_dataset)
        return super().get(weight, chosen_workers, worker_dataset, modelSetup, *args)