from .comlib import *


class MeanGradWithChosenWorkers:
    """Weighted-gradient aggregator restricted to a fixed set of workers.

    Each chosen worker that also appears in ``workers_list`` receives weight
    ``1 / len(chosen_workers)``; every other chosen worker receives weight 0.
    The weighted gradient itself is delegated to
    ``fed_learning.LinearGradR.get_weighted_grad``.
    """

    def __init__(self, criterion, batch_size, workers_list):
        """Store the aggregation configuration.

        Args:
            criterion: Loss criterion, forwarded to ``LinearGradR`` and to
                ``get_weighted_grad``.
            batch_size: Batch size, forwarded the same way.
            workers_list: Worker ids eligible for a non-zero weight
                (presumably a list of ints -- confirm against callers).
        """
        self.criterion = criterion
        self.batch_size = batch_size
        self.linear_grad = fed_learning.LinearGradR(criterion, batch_size)
        self.workers_list = workers_list

    def get_weight(self, chosen_workers):
        """Return a CPU float tensor of per-worker weights for ``chosen_workers``.

        ``weight[i]`` is ``1 / len(chosen_workers)`` if ``chosen_workers[i]`` is
        in ``self.workers_list`` and 0 otherwise.

        NOTE(review): non-matching workers still count in the denominator, so
        the weights sum to ``#matches / #chosen`` rather than 1. This looks
        deliberate (e.g. scaling the aggregate down); confirm with the author.

        Args:
            chosen_workers: Sequence or 1-D tensor of worker ids.

        Returns:
            A 1-D float tensor with one weight per chosen worker. It is empty
            when ``chosen_workers`` is empty.
        """
        # as_tensor avoids copy-constructing an existing tensor. Moving to CPU
        # keeps every operand below on one device, regardless of the caller.
        chosen = torch.as_tensor(chosen_workers).cpu()
        n_chosen = len(chosen)
        weight = torch.zeros(n_chosen)
        if n_chosen == 0:
            return weight

        allowed = torch.as_tensor(self.workers_list).cpu()
        # An empty workers_list matches nothing; skipping isin also avoids
        # comparing against a float32 empty tensor.
        if allowed.numel() > 0:
            weight[torch.isin(chosen, allowed)] = 1.0
        return weight / n_chosen

    def get(self, chosen_workers, worker_dataset, modelSetup, *args):
        """Compute the weighted gradient over ``chosen_workers``.

        Args:
            chosen_workers: Worker ids selected this round.
            worker_dataset: Per-worker data, forwarded to ``get_weighted_grad``.
            modelSetup: Object exposing ``.device``; the weights are moved there.
            *args: Accepted for interface compatibility and ignored.

        Returns:
            The ``.tensors`` attribute of the object returned by
            ``get_weighted_grad`` (presumably per-parameter gradient tensors --
            confirm).
        """
        device = modelSetup.device
        weight = self.get_weight(chosen_workers).to(device)
        grad = self.linear_grad.get_weighted_grad(
            weight,
            chosen_workers,
            self.criterion,
            modelSetup,
            worker_dataset,
            self.batch_size,
        )
        return grad.tensors