from .comlib import *
from .worker_model import WorkerModelSigmoid

class LinearGrad(ABC):
    """Abstract strategy that turns a coefficient tensor into a loss / gradient.

    Subclasses implement ``getInstantLossFunc``. This base class evaluates the
    returned callable over the chosen workers' data through the worker model's
    ``getValueFromDataset`` and ``getJacFromDataset`` helpers.
    """

    def __init__(self,batch_size):
        # Forwarded unchanged to the worker model's dataset helpers.
        self.batch_size=batch_size

    @abstractmethod
    def getInstantLossFunc(self,coeff,
                           workerModel:"WorkerModelSigmoid",
                           chosen_workers):
        """Return the loss callable for ``coeff`` on ``chosen_workers``.

        ``workerModel`` is part of the signature because ``getInstantLoss``
        and ``getGrad`` pass it positionally, and the subclasses in this file
        already accept it.
        """

    def _lossFuncAndDataset(self,coeff,workerModel,chosen_workers):
        """Build the loss callable, then fetch the chosen workers' dataset."""
        lossFunc=self.getInstantLossFunc(coeff,workerModel,chosen_workers)
        chosenWorkerDataset=workerModel.getSubWorkerDataset(chosen_workers)
        return lossFunc,chosenWorkerDataset

    def getInstantLoss(self,coeff,
                           workerModel:"WorkerModelSigmoid",
                           chosen_workers):
        """Evaluate the loss callable over the chosen workers' dataset.

        Returns whatever ``workerModel.getValueFromDataset`` produces for the
        loss callable, the sub-dataset and ``self.batch_size``.
        """
        lossFunc,chosenWorkerDataset=self._lossFuncAndDataset(
            coeff,workerModel,chosen_workers)
        return workerModel.getValueFromDataset(
            lossFunc,chosenWorkerDataset,self.batch_size)

    def getGrad(self,coeff,
                           workerModel:"WorkerModelSigmoid",
                           chosen_workers):
        """Differentiate the loss callable over the chosen workers' dataset.

        The result of ``workerModel.getJacFromDataset`` is re-wrapped: its
        ``tensors`` attribute is passed to ``training.ModelPara``.
        """
        lossFunc,chosenWorkerDataset=self._lossFuncAndDataset(
            coeff,workerModel,chosen_workers)
        grad_tuple=workerModel.getJacFromDataset(
            lossFunc,chosenWorkerDataset,self.batch_size)
        return training.ModelPara(grad_tuple.tensors)

    # @abstractmethod
    # def getCoeff(self,chosen_workers):
    #     pass
    @staticmethod
    def normalizeCoeff(coeff,workerModel:"WorkerModelSigmoid",chosen_workers:list):
        """Divide ``coeff`` by the chosen workers' data counts.

        The counts come from ``workerModel.getSubWorkerDataNum`` and are moved
        to ``coeff``'s device before dividing; the quotient is then moved to
        ``workerModel.device``.
        """
        device=workerModel.device
        datanum=workerModel.getSubWorkerDataNum(chosen_workers)
        coeff=coeff/datanum.to(coeff.device)
        return coeff.to(device)
        
class LinearSigmoidGrad(LinearGrad):
    """``LinearGrad`` whose loss is the mean of coefficient-weighted sigmoid values."""

    def __init__(self, batch_size):
        super().__init__(batch_size)

    def getInstantLossFunc(self,coeff,workerModel,
                           chosen_workers):
        """Build a closure computing ``mean(coeff * scatterSigmoidValue(...))``.

        ``coeff`` is first passed through ``normalizeCoeff``. The closure
        forwards its arguments, followed by the number of chosen workers, to
        ``WorkerModelSigmoid.scatterSigmoidValue``.
        """
        scaled_coeff=self.normalizeCoeff(coeff,workerModel,chosen_workers)
        worker_count=len(chosen_workers)

        def weighted_mean(out,*out_args):
            per_worker=WorkerModelSigmoid.scatterSigmoidValue(
                out,*out_args,worker_count)
            return torch.mean(scaled_coeff*per_worker)

        return weighted_mean


    def testGetGrad(self,coeff,
                    workerModel:WorkerModelSigmoid,
                    chosen_workers):
        """Compute the gradient from the worker model's per-worker sigmoid gradients.

        Unlike ``getInstantLossFunc``, ``coeff`` is not passed through
        ``normalizeCoeff`` here; it is only moved to ``workerModel.device``.
        """
        per_worker_grads=workerModel.getChosenWorkersSigmoidGrad(
            chosen_workers,self.batch_size)
        # The return value is ignored, so mult_coeffs presumably scales in
        # place -- TODO confirm against its definition.
        per_worker_grads.mult_coeffs(coeff.to(workerModel.device))
        return per_worker_grads.mean()
    

class LinearCEGrad(LinearGrad):
    """``LinearGrad`` whose loss is a weighted binary cross-entropy on logits."""

    def __init__(self, batch_size):
        super().__init__(batch_size)

    @staticmethod
    def getTargetWeightFromCoeff(worker_ids,coeff):
        '''
        Derive BCE targets and weights from the coefficients at ``worker_ids``.

        Returns ``(target, weight)``: ``target`` is 1.0 where the selected
        coefficient is strictly positive and 0.0 otherwise (float32), and
        ``weight`` is the absolute value of the selected coefficient.

        Original note: "data_num same" -- presumably this assumes every worker
        holds the same amount of data; TODO confirm.
        NOTE(review): a positive coefficient maps to target 1; confirm this
        sign convention agrees with the sigmoid-based loss.
        '''
        selected=coeff[worker_ids]
        target=(selected>0).to(torch.float32)
        weight=torch.abs(selected)
        return target,weight


    def getInstantLossFunc(self,coeff,
                           workerModel:WorkerModelSigmoid,chosen_workers):
        """Build a closure ``f(out, worker_ids)`` returning the weighted BCE loss.

        ``coeff`` is first passed through ``normalizeCoeff``. The closure
        flattens ``out``, sums ``BCEWithLogitsLoss`` over all elements and
        divides the sum by the number of chosen workers.
        """
        scaled_coeff=self.normalizeCoeff(coeff,workerModel,chosen_workers)

        def weighted_bce(out,worker_ids):
            logits=out.view(-1)
            target,weight=self.getTargetWeightFromCoeff(worker_ids,scaled_coeff)
            criterion=nn.BCEWithLogitsLoss(weight=weight,reduction='sum')
            # len() is evaluated on every call, as in the original closure.
            return criterion(logits,target)/len(chosen_workers)

        return weighted_bce


class StrategySigmoidValue(ABC):
    """Abstract strategy defined on per-worker sigmoid values.

    Holds a ``LinearGrad`` instance in ``g_strategy``; only its ``batch_size``
    is read by the code in this class.
    """

    def __init__(self,g_strategy:LinearGrad):
        self.g_strategy=g_strategy

    @abstractmethod
    def getLossFromSigmoidValue(self,chosen_workers,chosenWorkersValue):
        """Return the loss for the given workers and their sigmoid values."""

    def getSigmoidLoss(self,workerModel:WorkerModelSigmoid,chosen_workers):
        """Fetch the chosen workers' sigmoid values and convert them to a loss.

        The values come from ``workerModel.getChosenWorkersSigmoidValue``,
        called with ``g_strategy.batch_size``, and are handed to
        ``getLossFromSigmoidValue``.
        """
        sigmoid_values=workerModel.getChosenWorkersSigmoidValue(
            chosen_workers,self.g_strategy.batch_size)
        return self.getLossFromSigmoidValue(chosen_workers,sigmoid_values)

    @abstractmethod
    def getInstantLoss(self,workerModel,chosen_workers,batch_size):
        """Return the instantaneous loss; implemented by subclasses."""

    @abstractmethod
    def getGrad(self,workerModel,chosen_workers):
        """Return the gradient; implemented by subclasses."""

    # def __str__(self):
    #     return self.alias

    # @classmethod
    # def testLinearLossEqLoss(self,coeff,
    #                 workerModel:WorkerModelSigmoid,
    #                 chosen_workers):
    #     '''
    #     not general
    #     '''
    #     workersValue=workerModel.getChosenWorkersSigmoidValue(chosen_workers,self.batch_size)
    #     # print("testLinearLossEqLoss", torch.mean(workersValue*coeff),self.getLossFromSigmoidValue(chosen_workers,workersValue))
    #     assert torch.allclose(torch.mean(workersValue*coeff),workerModel.getLossFromSigmoidValue(chosen_workers,workersValue))
