from ..comlib import *
from .abstract_discim import Discriminator

class DiscriminatorSigmoid(Discriminator):
    """Discriminator whose per-worker statistic is built from ``sigmoid(model output)``.

    Per batch, the sigmoid of the model output is summed per worker
    (:meth:`scatterSigmoidValue`). Subclasses implement
    :meth:`getLossFromSigmoidValue` to turn those per-worker values into a loss.
    Per-round statistics are appended to a CSV sheet named by
    :meth:`get_sheet_name`.
    """

    def __init__(self, modelSetup, workerDataset, save_folder, label=''):
        super().__init__(modelSetup, workerDataset, save_folder, label)
        self.sigmoid = nn.Sigmoid()

    def scatterSigmoidValue(self, out, worker_ids, worker_num):
        """Sum ``sigmoid(out)`` into one slot per worker.

        Args:
            out: model output for a batch. It is flattened with ``view(-1)``,
                so it presumably holds one logit per sample (TODO confirm).
            worker_ids: integer index tensor (int64, as ``scatter_add_``
                requires) mapping each flattened sample to a slot in
                ``[0, worker_num)``.
            worker_num: length of the returned vector.

        Returns:
            Tensor of shape ``(worker_num,)``. Entry ``k`` is the sum of the
            sigmoid values of the samples whose ``worker_ids`` entry is ``k``.
        """
        loss_data_vec = self.sigmoid(out).view(-1)
        # Allocate the accumulator with the data's own dtype and device, because
        # scatter_add_ requires every tensor involved to live on one device.
        loss_worker_vec = loss_data_vec.new_zeros((worker_num,))
        loss_worker_vec.scatter_add_(0, worker_ids, loss_data_vec)
        return loss_worker_vec

    def _sigmoidValueFunc(self, chosen_workers):
        """Wrap :meth:`scatterSigmoidValue` for ``chosen_workers`` via the base ``getfunc``."""
        return self.getfunc(self.scatterSigmoidValue, len(chosen_workers))

    def getChosenWorkersSigmoidValue(self, chosen_workers: list[int], batch_size):
        """Return the per-worker sigmoid values for ``chosen_workers``.

        The values are aggregated over batches by the base-class
        ``getChosenWorkersValue``.
        """
        valueFunc = self._sigmoidValueFunc(chosen_workers)
        return self.getChosenWorkersValue(valueFunc, chosen_workers, batch_size)

    def getChosenWorkersSigmoidGrad(self, chosen_workers: list[int], batch_size):
        """Return the gradients of the per-worker sigmoid values for ``chosen_workers``.

        Computed by the base-class ``getChosenWorkersGrad``.
        """
        valueFunc = self._sigmoidValueFunc(chosen_workers)
        return self.getChosenWorkersGrad(valueFunc, chosen_workers, batch_size)

    @abstractmethod
    def getLossFromSigmoidValue(self, chosen_workers, chosenWorkersValue):
        """Map the per-worker sigmoid values of ``chosen_workers`` to a scalar loss."""
        pass

    def getLoss(self, chosen_workers, batch_size):
        """Compute the loss for ``chosen_workers`` from their sigmoid values."""
        workersValue = self.getChosenWorkersSigmoidValue(chosen_workers, batch_size)
        return self.getLossFromSigmoidValue(chosen_workers, workersValue)

    def getCurStat(self, round, batch_size=400):
        """Evaluate every worker's sigmoid value and log it for ``round``.

        Args:
            round: round index written to the ``"round"`` column.
            batch_size: batch size used during evaluation. Defaults to 400, the
                value that was previously hard-coded.
        """
        workers = list(range(self.worker_num))
        workerValues = self.getChosenWorkersSigmoidValue(workers, batch_size=batch_size)
        self.appendRoundToSheet(round, workerValues)

    def appendRoundToSheet(self, round, workerValues):
        """Append one row to the sheet: round, model weight norm, then one value per worker.

        Args:
            round: value for the ``"round"`` column.
            workerValues: indexable of 0-d tensors. Only the first
                ``self.worker_num`` entries are written.
        """
        record = {
            "round": round,
            "weight_norm": self.modelSetup.gettWeightNorm().item(),
        }
        # Worker columns are keyed by integer index, matching the header built in set_sheets.
        record.update({i: workerValues[i].item() for i in range(self.worker_num)})
        self.saveCsvObj.append_data(self.sheet_name, [record])

    @staticmethod
    def get_sheet_name(label):
        """Return the CSV sheet name used for the given ``label``."""
        return f"DiscrimWorkerValue{label}"

    def read_to_df(self):
        """Read this discriminator's sheet back via ``saveCsvObj.read_to_df``."""
        return self.saveCsvObj.read_to_df(self.sheet_name)

    def set_sheets(self):
        """Create the CSV writer and declare the sheet's columns.

        Columns are ``round``, ``weight_norm`` and one integer-named column per worker.
        """
        self.sheet_name = self.get_sheet_name(self.label)
        sheets = {
            self.sheet_name: ["round", "weight_norm"] + list(range(self.worker_num))
        }
        self.saveCsvObj = save_log.SaveCsv(self.save_folder, sheets)

class DiscriminatorSigmoidLinear(DiscriminatorSigmoid):
    """Sigmoid discriminator whose loss is linear in the per-worker sigmoid values.

    The loss is expected to equal ``mean(coeff * workersValue)``, where
    ``coeff`` comes from :meth:`getCoeff`. :meth:`testLinearLossEqLoss` checks
    that expectation against :meth:`getLossFromSigmoidValue`.
    """

    def __init__(self, modelSetup, workerDataset, save_folder, label=''):
        super().__init__(modelSetup, workerDataset, save_folder, label)

    @abstractmethod
    def getCoeff(self, chosen_workers):
        """Return the per-worker linear coefficients for ``chosen_workers``."""
        pass

    def normalizeCoeff(self, coeff, chosen_workers: list):
        """Divide each coefficient by the data count ``getSubWokerDataNum`` reports for that worker."""
        return coeff / self.workerDataset.getSubWokerDataNum(chosen_workers).to(coeff.device)

    def getLossGrad(self, chosen_workers, batch_size):
        """Gradient of the linear loss, computed directly over the chosen workers' data.

        The coefficients are first normalised by per-worker data count. The
        Jacobian of ``mean(coeff * per-batch per-worker sigmoid sums)`` is then
        taken with ``getJacFromDataset`` over the chosen workers' sub-dataset.
        How batches are combined is defined by that base-class helper.

        Returns:
            ``training.ModelPara`` wrapping the gradient tensors.
        """
        chosenWorkerNum = len(chosen_workers)
        coeff = self.normalizeCoeff(self.getCoeff(chosen_workers), chosen_workers)

        def weighted_mean(out, *out_args):
            # out_args are forwarded as scatterSigmoidValue's worker-id argument(s).
            per_worker = self.scatterSigmoidValue(out, *out_args, chosenWorkerNum)
            return torch.mean(coeff * per_worker)

        chosenWorkerDataset = self.workerDataset.getSubWokerDataset(chosen_workers)
        grad_tuple = self.getJacFromDataset(weighted_mean, chosenWorkerDataset, batch_size)
        return training.ModelPara(grad_tuple.tensors)

    def testGetLossGrad(self, chosen_workers, batch_size):
        """Reference gradient: per-worker gradients weighted by ``coeff`` and then averaged.

        NOTE(review): unlike ``getLossGrad``, ``coeff`` is not passed through
        ``normalizeCoeff`` here. Presumably ``getChosenWorkersSigmoidGrad``
        already normalises per worker; confirm before comparing the two results.
        """
        workersGrad = self.getChosenWorkersSigmoidGrad(chosen_workers, batch_size)
        coeff = self.getCoeff(chosen_workers)
        # mult_coeffs modifies workersGrad in place; it is not reassigned.
        workersGrad.mult_coeffs(coeff)
        return workersGrad.mean()

    def testLinearLossEqLoss(self, chosen_workers, batch_size):
        """Check that ``getLossFromSigmoidValue`` equals ``mean(coeff * workersValue)``.

        This only holds for subclasses whose loss really is linear in the
        worker values (not general).

        Raises:
            AssertionError: if the two losses are not ``torch.allclose``.
        """
        workersValue = self.getChosenWorkersSigmoidValue(chosen_workers, batch_size)
        coeff = self.getCoeff(chosen_workers)
        linear_loss = torch.mean(workersValue * coeff)
        loss = self.getLossFromSigmoidValue(chosen_workers, workersValue)
        # Explicit raise, so the check is not stripped under `python -O`.
        if not torch.allclose(linear_loss, loss):
            raise AssertionError(
                f"testLinearLossEqLoss: linear loss {linear_loss} != loss {loss}"
            )