from .comlib import *
from .linear_grad import LinearGradR
    
class FedLearn(training.Train11):
    """Federated training loop that samples a subset of workers every round.

    The helpers ``get_save_flag``, ``save_grad_norm``, ``get_weight_norm`` and
    the attributes ``saveObj``, ``evalObjs`` and ``max_epoch`` are not defined
    here; they are presumably provided by ``training.Train11`` -- confirm
    against the base class.
    """

    def __init__(self, evalObjs, save_folder, max_epoch, chosenWorkerNum=None,grad_strategy=None,save_name="train",save_round_inteval=None, save_epoch_inteval=None,chosenWorkerRatio=None,gaggEvalObj=None):
        """Store the worker-sampling and gradient-aggregation configuration.

        Args:
            evalObjs: forwarded to the base class; indexed by key in
                ``save_model_stat`` and each value must expose ``get(modelSetup)``.
            save_folder, save_name, save_round_inteval, save_epoch_inteval:
                forwarded unchanged to the base-class constructor.
            max_epoch: number of epochs ``train`` iterates over.
            chosenWorkerNum: absolute number of workers sampled per round.
                Takes precedence over ``chosenWorkerRatio`` when both are set.
            grad_strategy: object exposing
                ``get(chosen_workers, worker_dataset, modelSetup, save_func)``.
            chosenWorkerRatio: fraction of workers sampled per round, used
                only when ``chosenWorkerNum`` is ``None``.
            gaggEvalObj: optional object exposing ``get(x)``; when present it
                is wrapped by ``get_save_func``.
        """
        # The third positional argument of the base constructor is passed as None.
        super().__init__(evalObjs, save_folder, None, max_epoch,save_name, save_round_inteval, save_epoch_inteval)
        self.chosenWorkerNum=chosenWorkerNum
        self.grad_strategy=grad_strategy
        self.chosenWorkerRatio=chosenWorkerRatio
        self.gaggEvalObj=gaggEvalObj

    def save_model_stat(self,modelSetup,flag):
        """Add the weight norm and every evaluator's output to the current row.

        Nothing is recorded when ``flag`` is falsy.
        """
        if not flag:
            return
        self.saveObj.add_data_to_row({"weight_norm":self.get_weight_norm(modelSetup)})
        for key in self.evalObjs:
            r=self.evalObjs[key].get(modelSetup)
            self.saveObj.add_data_to_row(r)

    def get_save_func(self,key,flag):
        """Return a callback that records ``gaggEvalObj.get(x)``, or ``None``.

        ``key`` is accepted for interface compatibility but is not used.
        A callback is returned only when ``flag`` is truthy and a
        ``gaggEvalObj`` was configured.
        """
        # Logical `and` (not bitwise `&`) so the flag is interpreted by
        # truthiness, exactly as `save_model_stat` interprets it.
        if flag and self.gaggEvalObj is not None:
            def save_func(x):
                r=self.gaggEvalObj.get(x)
                self.saveObj.add_data_to_row(r)
            return save_func
        return None

    def get_chosenWorkerNum(self,worker_num):
        """Return how many workers are sampled per round.

        Args:
            worker_num: total number of available workers.

        Returns:
            ``chosenWorkerNum`` when configured, otherwise
            ``int(chosenWorkerRatio * worker_num)``.

        Raises:
            ValueError: if neither ``chosenWorkerNum`` nor
                ``chosenWorkerRatio`` was configured.
        """
        if self.chosenWorkerNum is not None:
            return self.chosenWorkerNum
        if self.chosenWorkerRatio is not None:
            return int(self.chosenWorkerRatio*worker_num)
        # Previously this branch only printed a message and then failed with
        # an UnboundLocalError on the return statement.
        raise ValueError("no chosenWorkerNum: set chosenWorkerNum or chosenWorkerRatio")

    def get_chosen_workers(self,worker_num):
        """Sample distinct worker indices from ``range(worker_num)`` without replacement."""
        chosenWorkerNum=self.get_chosenWorkerNum(worker_num)
        return random.sample(range(worker_num), k=chosenWorkerNum)

    def train(self,worker_dataset,modelSetup:training.ModelSetup,
          optimizer):
        """Run ``max_epoch`` epochs of rounds over randomly chosen workers.

        Each epoch has ``worker_num // chosen`` rounds, where ``chosen`` is the
        value returned by ``get_chosenWorkerNum``. Every round records model
        statistics (when the save flag is set), samples workers, obtains a
        gradient from ``grad_strategy``, applies it through
        ``optimizer.stepByGradTuple`` and flushes the current row.

        Args:
            worker_dataset: must expose ``worker_num``; otherwise passed
                through to ``grad_strategy.get``.
            modelSetup: passed through to the evaluators and ``grad_strategy``.
            optimizer: must expose ``stepByGradTuple(grad)``.

        Raises:
            ValueError: if the number of workers chosen per round is not
                positive (for example a ratio that rounds down to zero).
        """
        worker_num=worker_dataset.worker_num
        chosen_num=self.get_chosenWorkerNum(worker_num)
        if chosen_num<=0:
            # Zero used to surface as a bare ZeroDivisionError; a negative
            # count silently produced no rounds at all.
            raise ValueError(
                f"number of workers chosen per round must be positive, got {chosen_num}")
        round_per_epoch=worker_num//chosen_num

        for epoch in range(self.max_epoch):
            for round_idx in range(round_per_epoch):
                flag=self.get_save_flag(epoch,round_per_epoch,round_idx)
                self.save_model_stat(modelSetup,flag)

                chosen_workers=self.get_chosen_workers(worker_num)
                grad=self.grad_strategy.get(chosen_workers,worker_dataset,modelSetup,
                                            self.get_save_func("grad_agg",flag))

                self.save_grad_norm(grad,flag)

                optimizer.stepByGradTuple(grad)
                self.saveObj.save_cur_row()

class MeanGrad():
    """Gradient strategy that gives every chosen worker the same weight.

    The actual gradient computation is delegated to a ``LinearGradR`` instance
    built from the same criterion and batch size.
    """

    def __init__(self,criterion,batch_size):
        """Remember the criterion / batch size and build the ``LinearGradR`` helper."""
        self.criterion=criterion
        self.batch_size=batch_size
        self.linear_grad=LinearGradR(criterion,batch_size)

    def get(self,chosen_workers,worker_dataset,modelSetup,*args):
        """Return ``.tensors`` of the uniformly weighted gradient.

        Extra positional arguments are accepted and ignored.
        """
        n_chosen=len(chosen_workers)
        # Uniform weights 1/n, created on the model's device.
        uniform=torch.ones(n_chosen,device=modelSetup.device)/n_chosen
        weighted=self.linear_grad.get_weighted_grad(
            uniform,
            chosen_workers,
            self.criterion,
            modelSetup,
            worker_dataset,
            self.batch_size,
        )
        return weighted.tensors


class MeanGrad_vb_1():
    """Uniform-weight gradient strategy built from self-contained static helpers."""

    def __init__(self,criterion,batch_size):
        """Remember the criterion and batch size used by ``get``."""
        self.batch_size=batch_size
        self.criterion=criterion

    @staticmethod
    def normalizeCoeff(coeff,worker_dataset,chosen_workers:list):
        """Divide ``coeff`` by what ``getSubWokerDataNum(chosen_workers)`` returns.

        The divisor is moved to ``coeff``'s device before the division.
        """
        counts=worker_dataset.getSubWokerDataNum(chosen_workers)
        return coeff/counts.to(coeff.device)

    @staticmethod
    def getWorkerWeightLoss(criterion_unreduced,weight,out,worker_ids,*out_args):
        """Sum the flattened criterion output, each entry scaled by ``weight[worker_ids]``."""
        per_sample=criterion_unreduced(out,*out_args).view(-1)
        scaled=per_sample*weight[worker_ids]
        return torch.sum(scaled)

    @staticmethod
    def get_weighted_grad(weight,chosen_workers,criterion,modelSetup,worker_dataset,batch_size):
        """Return ``modelSetup.calcuDatasetJacobian`` of the worker-weighted loss.

        ``weight`` is first passed through ``normalizeCoeff`` and moved to the
        model's device; the loss is evaluated over the sub-dataset of the
        chosen workers.
        """
        target_device=modelSetup.device
        sub_dataset=worker_dataset.getSubWokerDataset(chosen_workers)
        normalized=MeanGrad_vb_1.normalizeCoeff(weight,worker_dataset,chosen_workers).to(target_device)

        def weighted_loss(out,*out_args):
            # out_args is forwarded untouched, so its first element fills the
            # `worker_ids` slot of getWorkerWeightLoss.
            return MeanGrad_vb_1.getWorkerWeightLoss(criterion,normalized,out,*out_args)

        return modelSetup.calcuDatasetJacobian(weighted_loss,sub_dataset,batch_size)

    def get(self,chosen_workers,worker_dataset,modelSetup,*args)->tuple[torch.Tensor]:
        """Return ``.tensors`` of the gradient with weight 1/n for each chosen worker.

        Extra positional arguments are accepted and ignored.
        """
        n_chosen=len(chosen_workers)
        uniform=torch.ones(n_chosen,device=modelSetup.device)/n_chosen
        jac=self.get_weighted_grad(
            uniform,
            chosen_workers,
            self.criterion,
            modelSetup,
            worker_dataset,
            self.batch_size,
        )
        return jac.tensors

    @staticmethod
    def getWorkerWeightLoss2(criterion_unreduced,weight,out,worker_ids,*out_args):
        """Accumulate the flattened criterion output per worker, then take the weighted sum.

        Entries are scatter-added into a vector of length ``len(weight)`` at the
        positions given by ``worker_ids``; that vector is multiplied by
        ``weight`` and summed.
        """
        out_device=out.device
        n_workers=len(weight)
        per_sample=criterion_unreduced(out,*out_args).view(-1)
        # scatter_add_ works in place and returns the same tensor.
        per_worker=torch.zeros((n_workers,),dtype=per_sample.dtype,device=out_device).scatter_add_(0,worker_ids,per_sample)
        return torch.sum(per_worker*weight)