import torch
from torch import nn
import pytorch_lightning as PL
import torchmetrics
import numpy as np

class BaseMPCmodel(nn.Module):
    """Train a block-structured network with a sliding look-ahead window.

    The children of ``model`` are flattened into ``self.blocks``: a child named
    ``stem`` is kept aside as ``self.stem``, a child named ``head`` is skipped,
    the members of any ``nn.ModuleList`` child are added one by one and every
    other child is added as a single block.  ``train_one_batch`` then walks over
    the blocks in steps of ``stride``; at each step it runs the blocks up to
    ``horizon`` ahead, evaluates ``loss_fn`` through the matching entry of
    ``self.lossblocks`` and back-propagates it.  The activation handed to the
    next step is detached, so a loss never reaches blocks before its window.

    Args:
        model: Network whose children define the blocks.  ``model.head`` is
            required when ``lossblocks`` is ``None`` or has ``T - 1`` entries.
        horizon: Number of blocks (counted from the window start) that are run
            before the loss is evaluated.  Must be >= 1.
        stride: Number of blocks the window advances per step.  Must be >= 1.
        lossblocks: Modules mapping a block output to a prediction:
            ``None`` uses ``model.head`` at every position; a single module
            (not an ``nn.ModuleList``) is shared by every position; a sequence
            of ``T - 1`` modules gets ``model.head`` appended; a sequence of
            ``T`` modules is used as is (``T`` is the number of blocks).
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar.
        **kwargs: ``update_stride`` (default ``True``) enables the
            ``requires_grad`` toggling described in ``train_one_batch``.

    Attributes set here:
        head_untrained: ``True`` when ``model`` has a ``head`` that none of the
            loss blocks is known to train (a full-length sequence, or a shared
            module that is not ``model.head``); ``False`` otherwise.

    Raises:
        ValueError: If ``horizon``/``stride`` are not positive or the number of
            loss blocks is neither ``T`` nor ``T - 1``.
    """

    def __init__(self, model, horizon, stride, lossblocks, loss_fn, **kwargs):
        super().__init__()
        # A non-positive stride would make ``train_one_batch`` loop forever.
        if horizon < 1 or stride < 1:
            raise ValueError(
                f"horizon and stride must be >= 1, got horizon={horizon}, stride={stride}")
        self.model = model
        self.loss_fn = loss_fn

        self.update_stride = kwargs.get('update_stride', True)
        self.horizon = horizon
        self.stride = stride
        self.blocks = nn.ModuleList()
        for n, m in model.named_children():
            if n == 'head':
                continue
            if n == 'stem':
                self.stem = m
                continue
            if isinstance(m, nn.ModuleList):
                self.blocks.extend(m)
            else:
                self.blocks.append(m)
        if not hasattr(self, 'stem'):
            self.stem = nn.Identity()
        self.T = len(self.blocks)

        head = getattr(model, 'head', None)
        # An nn.ModuleList is itself an nn.Module but is not callable as a loss
        # block, so it must be treated as a sequence of loss blocks.
        if isinstance(lossblocks, nn.Module) and not isinstance(lossblocks, nn.ModuleList):
            self.lossblocks = nn.ModuleList([lossblocks for _ in range(self.T)])
            self.head_untrained = head is not None and lossblocks is not head
        elif lossblocks is None:
            self.lossblocks = nn.ModuleList([model.head for _ in range(self.T)])
            self.head_untrained = False
        elif len(lossblocks) == self.T - 1:
            self.lossblocks = nn.ModuleList(lossblocks)
            self.lossblocks.append(model.head)
            self.head_untrained = False
        elif len(lossblocks) == self.T:
            self.lossblocks = nn.ModuleList(lossblocks)
            self.head_untrained = True
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid option")

    def _do_backward(self, loss):
        """Back-propagate ``loss``; subclasses may route this elsewhere."""
        loss.backward()

    def _detach_x(self, x):
        """Return the activation that seeds the next window, cut from the graph."""
        return x.detach()

    def unpack_batch(self, batch):
        """Return ``(inputs, targets)`` taken from the first two batch entries."""
        return (batch[0], batch[1])

    def get_gradient(self, x, y, lastblock,):
        """Evaluate the loss after block ``lastblock`` and back-propagate it.

        ``x`` is the output of block ``lastblock - 1``; the loss block with the
        same index turns it into a prediction.  Returns the loss tensor.
        """
        loss = self.loss_fn(self.lossblocks[lastblock - 1](x), y)
        self._do_backward(loss)
        return loss

    def train_one_batch(self, xs, y, batch_idx=None, optimizer=None, modify_grad=None,):
        """Run the windowed forward/backward passes for one batch.

        Args:
            xs: Network input (passed through ``self.stem`` in the first window).
            y: Targets handed to ``loss_fn``.
            batch_idx: Unused here; accepted for API symmetry with callers.
            optimizer: If given, gradients are zeroed before the passes and a
                step (followed by another zeroing) is taken afterwards.
            modify_grad: Optional zero-argument callable invoked after all
                backward passes and before the optimizer step.

        Returns:
            ``(loss, prediction)`` where ``loss`` belongs to the last window and
            ``prediction`` is recomputed from the last window's output (after
            the optimizer step, if any).
        """
        if optimizer:
            optimizer.zero_grad()
        s = 0
        while True:
            x = self.stem(xs) if s == 0 else xs
            lastblock = min(s + self.horizon, self.T)
            update_end = s + self.stride
            if self.update_stride:
                if lastblock < self.T:
                    # Only the blocks this window advances over are trained;
                    # the look-ahead blocks beyond them are frozen.  The freeze
                    # must start at ``update_end`` (not ``s + 1``), otherwise a
                    # stride > 1 would immediately re-freeze blocks that this
                    # window is supposed to update.
                    for b in self.blocks[s:update_end]:
                        b.requires_grad_(True)
                    for b in self.blocks[update_end:lastblock]:
                        b.requires_grad_(False)
                else:
                    # Final window: everything that is left gets trained.
                    for b in self.blocks[s:lastblock]:
                        b.requires_grad_(True)

            for b in self.blocks[s:update_end]:
                x = b(x)
            # Seed for the next window; detached so later losses stop here.
            xs = self._detach_x(x)
            for b in self.blocks[update_end:lastblock]:
                x = b(x)

            loss = self.get_gradient(x, y, lastblock,)

            if lastblock == self.T:
                break
            s += self.stride

        if modify_grad:
            modify_grad()
        if optimizer:
            optimizer.step()
            optimizer.zero_grad()

        return loss, self.lossblocks[lastblock - 1](x)

    def forward(self, x):
        """Plain forward pass through the wrapped ``model``."""
        return self.model(x)
    
class MPCmodel(BaseMPCmodel,PL.LightningModule):
    """Lightning module that drives ``BaseMPCmodel.train_one_batch``.

    Optimisation is manual (``automatic_optimization = False``): every training
    step fetches the optimizer and lets ``train_one_batch`` perform the
    backward passes and the step.  Running means of the loss (and optional
    ``metrics``) are logged per step to the progress bar and per epoch to the
    logger.

    Args:
        model, horizon, stride, lossblocks, loss_fn: Forwarded to
            ``BaseMPCmodel``.
        lr: Learning rate used by ``configure_optimizers``.
        optimizer: ``'sgd'`` or ``'adam'``.
        momentum: SGD momentum; ``None`` means 0.
        metrics: Optional metric collection; it is cloned with prefix ``val_``
            for validation.
        **kwargs: Forwarded to ``BaseMPCmodel``.  When ``head_untrained`` is
            set, ``head_epoch`` (default 30), ``head_loss_fn`` (default
            ``loss_fn``), ``head_metrics`` (default ``None``) and ``head_lr``
            (default ``lr``) configure the head fitting in ``on_train_end``.
    """
    def __init__(self,model, horizon, stride, lossblocks, loss_fn,
                 lr=1e-3,optimizer='sgd',momentum=None,metrics=None,**kwargs):
        super().__init__(model, horizon, stride, lossblocks, loss_fn, **kwargs)
        self.automatic_optimization = False
        self.metrics=metrics
        self.val_metrics=metrics.clone(prefix='val_') if metrics else None
        self.lr=lr
        self.train_loss_metric = torchmetrics.MeanMetric()
        self.val_loss_metric = torchmetrics.MeanMetric()

        self.optimizer_name=optimizer
        self.momentum=momentum

        # The base class may not define the flag for every ``lossblocks``
        # form; fall back to "head is trained" instead of raising.
        self.head_untrained=getattr(self,'head_untrained',False)
        if self.head_untrained:
            self.head_epoch=kwargs.get('head_epoch',30)
            self.head_loss_fn=kwargs.get('head_loss_fn',loss_fn)
            self.head_loss = torchmetrics.MeanMetric()
            self.head_metrics=kwargs.get('head_metrics',None)
            self.val_head_loss = torchmetrics.MeanMetric()
            self.val_head_metrics=self.head_metrics.clone(prefix='val_') if self.head_metrics else None
            self.head_lr=kwargs.get('head_lr',lr)

    def _do_backward(self,loss):
        """Use Lightning's ``manual_backward`` instead of ``loss.backward()``."""
        self.manual_backward(loss)

    def training_step(self, batch, batch_idx):
        """Train on one batch and update the step-level loss/metric logs."""
        opt=self.optimizers()
        xs,y=self.unpack_batch(batch)
        loss,pred=self.train_one_batch(xs,y,batch_idx=batch_idx,optimizer=opt)
        self.train_loss_metric(loss)
        self.log('loss', self.train_loss_metric, on_step=True, on_epoch=False, prog_bar=True, logger=False)
        if self.metrics:
            self.metrics(pred, y)
            self.log_dict(self.metrics, on_step=True, on_epoch=False, prog_bar=True, logger=False)

    def on_train_epoch_end(self,):
        """Log the epoch aggregates of the training loss/metrics and reset them."""
        self.log('loss', self.train_loss_metric.compute(), on_epoch=True, prog_bar=True, logger=True)
        self.train_loss_metric.reset()

        if self.metrics:
            self.log_dict(self.metrics.compute(), on_epoch=True, prog_bar=True, logger=True)
            self.metrics.reset()

    def on_train_end(self):
        """Fit ``model.head`` on frozen features when it was left untrained.

        All blocks and the stem are frozen, then the head is trained with Adam
        for ``head_epoch`` passes over the training dataloader, each followed
        by a pass over the validation dataloader.  Results are sent through
        ``self.logger.experiment.log`` (assumes the logger's experiment object
        exposes ``log`` taking a dict — TODO confirm against the logger used).
        """
        if not self.head_untrained:
            return
        for b in self.blocks:
            b.requires_grad_(False)
        self.stem.requires_grad_(False)
        self.model.head.requires_grad_(True)
        opt=torch.optim.Adam(self.model.head.parameters(), lr=self.head_lr)
        opt.zero_grad()
        for epoch_idx in range(self.head_epoch):
            for batch in self.trainer.train_dataloader:
                X, y=self.unpack_batch(batch)
                # Batches iterated by hand are not moved to the device by Lightning.
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X)
                loss = self.head_loss_fn(pred, y)
                self.manual_backward(loss)
                opt.step()
                opt.zero_grad()
                self.head_loss(loss)
                if self.head_metrics:
                    self.head_metrics(pred, y)

            self.logger.experiment.log({'head_loss': self.head_loss.compute().item()})
            self.head_loss.reset()

            if self.head_metrics:
                self.logger.experiment.log({k:v.item() for k,v in self.head_metrics.compute().items()})
                self.head_metrics.reset()

            # Evaluation only: no autograd graph is needed for these passes.
            with torch.no_grad():
                for batch in self.trainer.val_dataloaders:
                    X, y=self.unpack_batch(batch)
                    X, y = X.to(self.device), y.to(self.device)
                    pred = self.model(X)
                    loss = self.head_loss_fn(pred, y)
                    self.val_head_loss(loss)
                    if self.val_head_metrics:
                        self.val_head_metrics(pred, y)

            val_head_loss=self.val_head_loss.compute().item()
            self.logger.experiment.log({'val_head_loss': val_head_loss,})
            print(f'Epoch {epoch_idx+1}/{self.head_epoch}: val_head_loss: {val_head_loss}')
            self.val_head_loss.reset()

            if self.val_head_metrics:
                self.logger.experiment.log({k:v.item() for k,v in self.val_head_metrics.compute().items()},)
                self.val_head_metrics.reset()

    def validation_step(self, batch, batch_idx):
        """Evaluate the full model on one batch and update the step-level logs."""
        X,y=self.unpack_batch(batch)
        pred = self.model(X)
        val_loss = self.loss_fn(pred, y)
        self.val_loss_metric(val_loss)
        self.log("val_loss", self.val_loss_metric, on_step=True, on_epoch=False, prog_bar=True, logger=False)
        if self.val_metrics:
            self.val_metrics(pred, y)
            self.log_dict(self.val_metrics, on_step=True, on_epoch=False, prog_bar=True, logger=False)

    def on_validation_epoch_end(self,):
        """Log the epoch aggregates of the validation loss/metrics and reset them."""
        self.log('val_loss', self.val_loss_metric.compute(), on_epoch=True, prog_bar=True, logger=True)
        self.val_loss_metric.reset()

        if self.val_metrics:
            self.log_dict(self.val_metrics.compute(), on_epoch=True, prog_bar=True, logger=True)
            self.val_metrics.reset()

    def configure_optimizers(self,):
        """Build the optimizer named by ``optimizer_name``.

        Raises:
            ValueError: If the name is neither ``'sgd'`` nor ``'adam'``.
        """
        if self.optimizer_name=='sgd':
            return torch.optim.SGD(self.parameters(), lr=self.lr,
                                   momentum=self.momentum if self.momentum else 0)
        if self.optimizer_name=='adam':
            return torch.optim.Adam(self.parameters(), lr=self.lr)
        # Previously an unknown name fell through to an unbound local.
        raise ValueError(
            f"Unsupported optimizer {self.optimizer_name!r}; expected 'sgd' or 'adam'")
    
class BaseLossModifiedModel(BaseMPCmodel):
    """Windowed training with a periodically estimated gradient correction.

    During a "modification" phase the incoming data is processed in small
    batches.  For each small batch two gradients are computed: the windowed
    one (``horizon`` as configured) and the full-horizon one (``horizon`` set
    to ``T``).  Their sums and squared norms are accumulated and, on the last
    small batch, ``modify_loss`` turns them into per-block (``blockwise``) or
    global statistics that ``compute_a`` / ``compute_lambda`` consume.
    ``modify_grad`` applies the resulting scale ``a`` and correction
    ``modified_g`` to the gradients currently stored on the parameters.

    The class reads ``self.device``, ``self.count``, ``self.batch_size`` and
    ``self._small_batches``; it does not set them itself (in this file
    ``LossModifiedModel_PLwrapper`` provides them).

    NOTE(review): no call in this class passes ``modify_grad`` to
    ``train_one_batch``, so the correction is only applied if a caller wires
    ``self.modify_grad`` in — confirm whether that is intended.

    Args:
        *args, **kwargs: Forwarded to ``BaseMPCmodel``.
        small_batch_size: Size of the chunks used while gathering statistics.
        modify_batch: Number of small batches per modification phase (> 0).
        use_modify_dataloader: Draw the small batches from ``modify_dataloader``
            instead of splitting the training batches.
        modify_dataloader: Required when ``use_modify_dataloader`` is true.
        period: Stored as ``self.period``; subclasses divide by it, so it must
            be numeric by the time they run.
        lambda_g: Initial correction weight.
        lambda_scale_lb: Stored; used by the Lightning wrapper's schedule.
        a: Initial gradient scale.
        blockwise: Keep one statistic per block instead of a global one.
        compute_var: Also accumulate variance estimates.
        modify_g: Build ``self.modified_g`` in ``modify_loss``.
    """
    def __init__(self, *args, small_batch_size=8, modify_batch=10,
                 use_modify_dataloader=False, modify_dataloader=None,period='epoch',
                 lambda_g=1.,lambda_scale_lb=1., a=1.,
                 blockwise=True, compute_var=True, modify_g=False,**kwargs, ):
        super().__init__(*args, **kwargs)
        assert modify_batch>0
        self.small_batch_size=small_batch_size
        self.modify_batch=modify_batch
        self.use_modify_dataloader=use_modify_dataloader
        self.modify_dataloader=modify_dataloader
        assert use_modify_dataloader==False or modify_dataloader
        self.period=period
        self._lambda_g=lambda_g
        self.lambda_scale=1.
        self.lambda_scale_lb=lambda_scale_lb
        self._a=a

        self.blockwise=blockwise
        self.compute_var=compute_var
        self.modify_g=modify_g

        # Call the base implementations explicitly: subclass overrides expect
        # statistics that do not exist yet at construction time.
        BaseLossModifiedModel.compute_a(self,)
        BaseLossModifiedModel.compute_lambda(self,)

    def get_grads(self, xs, y):
        """Return ``(loss, pred, grads)`` of one windowed pass without stepping.

        ``grads`` holds, per block, clones of the gradients of the parameters
        that require grad; the stored ``.grad`` tensors are zeroed afterwards.
        """
        loss,pred=self.train_one_batch(xs, y, modify_grad=None, optimizer=None,)
        grads=[]
        for b in self.blocks:
            block_grad=[]
            for p in b.parameters():
                if p.requires_grad:
                    block_grad.append(p.grad.detach().clone())
                    p.grad.zero_()
            grads.append(block_grad)
        return loss,pred,grads

    # ``torch.no_grad()`` is instantiated explicitly: the bare-class decorator
    # form is not accepted by every PyTorch release.
    @torch.no_grad()
    def compute_a(self,**kwargs):
        """Reset the scale ``a`` to its configured initial value."""
        if self.blockwise:
            self.a=[self._a]*self.T
        else:
            self.a=self._a

    @torch.no_grad()
    def compute_lambda(self,**kwargs):
        """Reset the correction weight ``lambda_g`` to its initial value."""
        if self.blockwise:
            self.lambda_g=[self._lambda_g]*self.T
        else:
            self.lambda_g=self._lambda_g

    @torch.no_grad()
    def log_modify_term(self,**kwargs):
        """Hook for logging the statistics; a no-op in the base class."""
        pass

    @torch.no_grad()
    def modify_grad(self,):
        """Scale stored gradients by ``a`` and add the weighted correction.

        For every parameter that requires grad: ``grad <- a * grad`` and, if
        ``modified_g`` exists, ``+ modified_g * lambda_g * lambda_scale``.
        """
        for t in range(self.T):
            a=self.a[t] if self.blockwise else self.a
            trainable=[p for p in self.blocks[t].parameters() if p.requires_grad]
            if hasattr(self,'modified_g'):
                lambda_g=self.lambda_g[t] if self.blockwise else self.lambda_g
                for p,modified_g in zip(trainable,self.modified_g[t]):
                    p.grad.mul_(a).add_(modified_g*lambda_g*self.lambda_scale)
            else:
                for p in trainable:
                    p.grad.mul_(a)

    @torch.no_grad()
    def modify_loss(self,):
        """Turn the accumulated gradients into statistics and update a/lambda.

        With ``m = modify_batch`` and gradient sums ``G_h`` (windowed) and
        ``G_T`` (full horizon) it computes ``norm_gh = |G_h|^2 / m^2``,
        ``norm_gT = |G_T|^2 / m^2`` and ``dot = <G_h, G_T> / m^2``.  When
        ``compute_var`` is set, the variance terms are
        ``(sum_of_squared_norms / m - squared_norm_of_mean) * small_batch_size``
        for the windowed, full-horizon and difference gradients.  The values
        are passed as keyword arguments to ``compute_a``, ``compute_lambda``
        (if ``modify_g``) and ``log_modify_term``.
        """
        norm_g_block_mpc=[] if self.blockwise else torch.zeros(1).to(self.device)
        norm_g_block_true=[] if self.blockwise else torch.zeros(1).to(self.device)
        dot_block=[] if self.blockwise else torch.zeros(1).to(self.device)

        def get_block_norm(block_g):
            return sum([torch.sum(g.square()) for g in block_g])/self.modify_batch**2

        def get_block_dot(block_g_mpc, block_g_true):
            return sum([torch.sum(g_mpc*g_true) for g_mpc,g_true in zip(block_g_mpc, block_g_true)])/self.modify_batch**2

        if self.compute_var and self.blockwise:
            sigma_g_block_mpc=[]
            sigma_g_block_true=[]
            sigma_g_block_delta=[]

            for block_g_mpc, block_g_true,block_var_mpc,block_var_true,block_var_delta in zip(self.mpcgrad,self.truegrad,self.var_mpc,self.var_true,self.var_delta):
                norm_g_block_mpc.append(get_block_norm(block_g_mpc))
                norm_g_block_true.append(get_block_norm(block_g_true))
                dot_block.append(get_block_dot(block_g_mpc, block_g_true))
                sigma_g_block_mpc.append((block_var_mpc/self.modify_batch-norm_g_block_mpc[-1])*self.small_batch_size)
                sigma_g_block_true.append((block_var_true/self.modify_batch-norm_g_block_true[-1])*self.small_batch_size)
                sigma_g_block_delta.append((block_var_delta-sum([(gh-gT).square().sum() for gh,gT in zip(block_g_mpc,block_g_true)])/self.modify_batch)/self.modify_batch*self.small_batch_size)
        else:
            if self.compute_var:
                norm_dg=torch.zeros(1).to(self.device)
            for block_g_mpc, block_g_true in zip(self.mpcgrad,self.truegrad):
                if self.blockwise:
                    norm_g_block_mpc.append(get_block_norm(block_g_mpc))
                    norm_g_block_true.append(get_block_norm(block_g_true))
                    dot_block.append(get_block_dot(block_g_mpc, block_g_true))
                else:
                    norm_g_block_mpc.add_(get_block_norm(block_g_mpc))
                    norm_g_block_true.add_(get_block_norm(block_g_true))
                    dot_block.add_(get_block_dot(block_g_mpc, block_g_true))
                    if self.compute_var:
                        norm_dg.add_(sum([(g_mpc-g_true).square().sum() for g_mpc,g_true in zip(block_g_mpc, block_g_true)])/self.modify_batch**2)
            if self.compute_var:
                sigma_g_block_mpc=(self.var_mpc/self.modify_batch-norm_g_block_mpc)*self.small_batch_size
                sigma_g_block_true=(self.var_true/self.modify_batch-norm_g_block_true)*self.small_batch_size
                sigma_g_block_delta=(self.var_delta/self.modify_batch-norm_dg)*self.small_batch_size

        modify_dict=dict(norm_gh=norm_g_block_mpc,norm_gT=norm_g_block_true,dot=dot_block,)
        if self.compute_var:
            modify_dict.update(dict(var_gh=sigma_g_block_mpc,var_gT=sigma_g_block_true,var_dg=sigma_g_block_delta))

        self.compute_a(**modify_dict)
        if self.blockwise:
            # The last ``horizon`` blocks keep their gradient unscaled.
            self.a=[ai if t<self.T-self.horizon else 1.0 for t,ai in enumerate(self.a)]
        modify_dict.update(dict(a=self.a))

        if self.modify_g:
            self.compute_lambda(**modify_dict)
            if self.blockwise:
                # ... and receive no correction term.
                self.lambda_g=[lambda_t if t<self.T-self.horizon else 0 for t,lambda_t in enumerate(self.lambda_g)]
            modify_dict.update(dict(lambda_g=self.lambda_g))
            self.modified_g=[]
            for t,(block_g_mpc, block_g_true) in enumerate(zip(self.mpcgrad,self.truegrad)):
                ai=self.a[t] if self.blockwise else self.a
                self.modified_g.append([(g_true-ai*g_mpc)/self.modify_batch for g_mpc,g_true in zip(block_g_mpc,block_g_true)])

        self.log_modify_term(**modify_dict)

    #unused
    def _detach_clone(self,x):
        """Return a detached copy of ``x``."""
        return x.detach().clone()

    def _split_batch(self,xs,y,need_samples=None):
        """Split ``xs``/``y`` into chunks of ``small_batch_size``.

        When ``need_samples`` is truthy only the first ``need_samples`` rows
        are split.
        """
        if need_samples:
            return torch.split(xs[:need_samples],self.small_batch_size),torch.split(y[:need_samples],self.small_batch_size)
        else:
            return torch.split(xs,self.small_batch_size),torch.split(y,self.small_batch_size)

    def train_small_batches(self, xs, y, batch_idx, optimizer):
        """Run the modification phase for the current training batch.

        Returns ``(loss, pred)``.  In the dataloader branch and on the last
        batch of the phase the return value comes from a regular
        ``train_one_batch`` on the whole of ``xs``/``y``; otherwise it is the
        mean small-batch loss and the concatenated small-batch predictions.
        """
        if self.use_modify_dataloader:
            # Statistics come from the dedicated loader; the training batch is
            # kept intact and used for a regular step afterwards, mirroring the
            # last-batch branch below.  This also guarantees the
            # ``(loss, pred)`` return that callers unpack.
            data_iter=iter(self.modify_dataloader)
            for idx in range(self.modify_batch):
                xs_mod,y_mod=self.unpack_batch(next(data_iter))
                # Hand-iterated batches are not placed on the module's device.
                xs_mod,y_mod=xs_mod.to(self.device),y_mod.to(self.device)
                self.train_small_batch(xs_mod, y_mod, optimizer=optimizer, init=(idx==0), final=(idx==self.modify_batch-1))
            if hasattr(self,'lambda_scale_ub'):
                self.lambda_scale=self.lambda_scale_ub
            return self.train_one_batch(xs,y,batch_idx=batch_idx,optimizer=optimizer,)

        if self.count==self._small_batches-1:
            # Last training batch of the phase: only the samples still missing
            # from ``modify_batch * small_batch_size`` feed the statistics.
            need_samples=self.modify_batch*self.small_batch_size-self.count*self.batch_size
            xs_half,y_half = self._split_batch(xs,y,need_samples)
            for idx,(xsi,yi) in enumerate(zip(xs_half,y_half)):
                self.train_small_batch(xsi, yi, optimizer=optimizer,init=((self.count==0)&(idx==0)),
                                       final=(idx==len(xs_half)-1))
            if hasattr(self,'lambda_scale_ub'):
                self.lambda_scale=self.lambda_scale_ub
            return self.train_one_batch(xs,y,batch_idx=batch_idx,optimizer=optimizer,)

        xs,y = self._split_batch(xs,y)
        preds=[]
        loss=0
        for idx,(xsi,yi) in enumerate(zip(xs,y)):
            lossi,predi=self.train_small_batch(xsi, yi, optimizer=optimizer,init=((self.count==0)&(idx==0)),final=False)
            loss+=lossi
            preds.append(predi)
        return loss/len(xs),torch.cat(preds)

    def train_small_batch(self, xs, y, optimizer, init=False, final=False):
        """Accumulate both gradients for one small batch, then step.

        Args:
            xs, y: The small batch.
            optimizer: Zeroed around the two passes and stepped on the
                full-horizon gradient.
            init: Allocate fresh accumulators before processing.
            final: Call ``modify_loss`` and release the accumulators afterwards.

        Returns:
            ``(loss, pred)`` of the full-horizon pass.
        """
        if init:
            self.mpcgrad=[[torch.zeros_like(p).to(self.device) for p in b.parameters() if p.requires_grad] for b in self.blocks]
            self.truegrad=[[torch.zeros_like(p).to(self.device) for p in b.parameters() if p.requires_grad] for b in self.blocks]
            if self.compute_var:
                self.var_mpc=torch.zeros(self.T if self.blockwise else 1).to(self.device)
                self.var_true=torch.zeros(self.T if self.blockwise else 1).to(self.device)
                self.var_delta=torch.zeros(self.T if self.blockwise else 1).to(self.device)

        optimizer.zero_grad()
        _,_,g_hs=self.get_grads(xs, y)
        optimizer.zero_grad()
        # Full-horizon pass: one window covering every block.  The configured
        # horizon is restored even if the pass raises.
        horizon=self.horizon
        self.horizon=self.T
        try:
            loss,pred=self.train_one_batch(xs, y, modify_grad=None, optimizer=None,)
        finally:
            self.horizon=horizon

        if self.compute_var and self.blockwise:
            # Iterating the 1-D accumulators yields per-block entries that are
            # updated in place.
            for block_g_T_past,block_g_h_past,b,block_g_h,block_var_T,block_var_h,block_var_delta in \
                zip(self.truegrad,self.mpcgrad,self.blocks,g_hs,self.var_true,self.var_mpc,self.var_delta):
                for g_T_past,g_h_past,p,g_h_now in \
                zip(block_g_T_past, block_g_h_past, [p for p in b.parameters() if p.requires_grad],block_g_h):
                    g_T_now=p.grad
                    g_T_past.add_(g_T_now)
                    g_h_past.add_(g_h_now)
                    block_var_T.add_(g_T_now.square().sum())
                    block_var_h.add_(g_h_now.square().sum())
                    block_var_delta.add_((g_T_now-g_h_now).square().sum())
        else:
            for block_g_T_past,block_g_h_past,b,block_g_h,in zip(self.truegrad,self.mpcgrad,self.blocks,g_hs):
                for g_T_past,g_h_past,p,g_h_now in \
                zip(block_g_T_past, block_g_h_past, [p for p in b.parameters() if p.requires_grad],block_g_h):
                    g_T_now=p.grad
                    g_T_past.add_(g_T_now)
                    g_h_past.add_(g_h_now)
                    if self.compute_var:
                        self.var_true.add_(g_T_now.square().sum())
                        self.var_mpc.add_(g_h_now.square().sum())
                        self.var_delta.add_((g_T_now-g_h_now).square().sum())
        optimizer.step()
        optimizer.zero_grad()

        if final:
            self.modify_loss()
            # Release the accumulators until the next phase re-creates them.
            self.mpcgrad=None
            self.truegrad=None
            if self.compute_var:
                self.var_mpc=None
                self.var_true=None
                self.var_delta=None
        return loss,pred

class BaseLossModifiedModel_a(BaseLossModifiedModel):
    """Variant that re-estimates the gradient scale ``a`` from statistics.

    ``compute_a`` expects the keyword arguments ``dot``, ``norm_gh``,
    ``var_gh``, ``var_gT`` and ``var_dg`` and reads ``self.lambda_g``,
    ``self.batch_size`` and a numeric ``self.period``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.modify_batch>1

    # ``torch.no_grad()`` is instantiated explicitly: the bare-class decorator
    # form is not accepted by every PyTorch release.
    @torch.no_grad()
    def compute_a(self,**kwargs):
        """Set ``self.a`` (a list when ``blockwise``, else a single value).

        With ``B' = modify_batch * small_batch_size / period`` and
        ``var_hT = var_gh + var_gT - var_dg`` the value is

            ((1-l)^2 * dot + l^2 * var_hT / 2 / B')
            / ((1-l)^2 * norm_gh + var_gh * (1/batch_size + l^2 / B'))

        where ``l`` is the matching entry of ``self.lambda_g``.
        """
        B_prime=self.modify_batch*self.small_batch_size/self.period
        inv_batch=1./self.batch_size

        def scale(dot,norm_gh,var_gh,var_g_hT,lam):
            return ((1-lam)**2*dot+lam**2*var_g_hT/2./B_prime)/ \
                ((1-lam)**2*norm_gh+var_gh*(inv_batch+lam**2/B_prime))

        if self.blockwise:
            self.a=[scale(dot_t,norm_gh_t,var_gh_t,var_gh_t+var_gT_t-var_dg_t,lambda_t)
                    for dot_t,norm_gh_t,var_gh_t,var_gT_t,var_dg_t,lambda_t in
                    zip(kwargs['dot'],kwargs['norm_gh'],kwargs['var_gh'],
                        kwargs['var_gT'],kwargs['var_dg'],self.lambda_g)]
        else:
            self.a=scale(kwargs['dot'],kwargs['norm_gh'],kwargs['var_gh'],
                         kwargs['var_gh']+kwargs['var_gT']-kwargs['var_dg'],self.lambda_g)

class BaseLossModifiedModel_Modify(BaseLossModifiedModel):
    """Variant that always constructs the base class with ``modify_g=True``."""
    def __init__(self, *args, **kwargs):
        # Any caller-supplied ``modify_g`` is overridden.
        forced={**kwargs,'modify_g':True}
        super().__init__(*args, **forced)

class BaseLossModifiedModel_lambda(BaseLossModifiedModel_Modify):
    """Variant that re-estimates the correction weight ``lambda_g``.

    ``compute_lambda`` expects the keyword arguments ``norm_gh``, ``norm_gT``,
    ``dot``, ``var_gh``, ``var_gT`` and ``var_dg`` and reads ``self.a`` and a
    numeric ``self.period``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.modify_batch>1

    # ``torch.no_grad()`` is instantiated explicitly: the bare-class decorator
    # form is not accepted by every PyTorch release.
    @torch.no_grad()
    def compute_lambda(self,**kwargs):
        """Set ``self.lambda_g`` (a list when ``blockwise``, else one value).

        With ``B' = modify_batch * small_batch_size / period``,
        ``mse = a^2 * norm_gh + norm_gT - 2 * a * dot`` and
        ``noise = a^2 * var_gh + var_gT - a * (var_gh + var_gT - var_dg)`` the
        value is ``mse / (mse + noise / B')``.
        """
        B_prime=self.modify_batch*self.small_batch_size/self.period

        def weight(a,norm_gh,norm_gT,dot,var_gh,var_gT,var_dg):
            mse=a**2*norm_gh+norm_gT-2.*a*dot
            return mse/(mse+(a**2*var_gh+var_gT-a*(var_gh+var_gT-var_dg))/B_prime)

        if self.blockwise:
            self.lambda_g=[weight(a_t,norm_gh_t,norm_gT_t,dot_t,var_gh_t,var_gT_t,var_dg_t)
                           for a_t,norm_gh_t,norm_gT_t,dot_t,var_gh_t,var_gT_t,var_dg_t in
                           zip(self.a,kwargs['norm_gh'],kwargs['norm_gT'],kwargs['dot'],
                               kwargs['var_gh'],kwargs['var_gT'],kwargs['var_dg'])]
        else:
            self.lambda_g=weight(self.a,kwargs['norm_gh'],kwargs['norm_gT'],kwargs['dot'],
                                 kwargs['var_gh'],kwargs['var_gT'],kwargs['var_dg'])

class BaseLossModifiedModel_a_lambda(BaseLossModifiedModel_lambda,BaseLossModifiedModel_a):
    """Variant that alternates the ``a`` and ``lambda_g`` updates.

    ``compute_a`` repeats the parent ``compute_a`` / ``compute_lambda`` pair
    until ``a`` changes by less than 0.1 % between two consecutive iterations
    or ``iter_a_lambda`` iterations (keyword argument, default 50) have run.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.iter_a_lambda=kwargs.get('iter_a_lambda',50)

    # ``torch.no_grad()`` is instantiated explicitly: the bare-class decorator
    # form is not accepted by every PyTorch release.
    @torch.no_grad()
    def compute_a(self, **kwargs):
        """Iterate the a/lambda updates until ``a`` stops changing."""
        def settled(new,old):
            # Relative change below 1e-3, written without a division so a
            # zero or negative previous value cannot break or flip the test.
            return abs(new-old)<1e-3*abs(old)

        for _ in range(self.iter_a_lambda):
            # Snapshot the previous iterate on every pass; comparing against
            # the value from before the loop would not measure convergence.
            last_a=self.a.copy() if hasattr(self.a,'copy') else self.a
            super().compute_a(**kwargs)
            super().compute_lambda(**kwargs)
            if self.blockwise:
                if all([settled(ai,last_ai) for ai,last_ai in zip(self.a,last_a)]):
                    break
            else:
                if settled(self.a,last_a):
                    break

class LossModifiedModel_PLwrapper(BaseLossModifiedModel,MPCmodel):
    """Lightning glue for ``BaseLossModifiedModel``.

    ``on_train_start`` derives the bookkeeping values (``batch_size``,
    ``_small_batches``, numeric ``period``, ``count``, ``lambda_scale_ub``,
    ``true_lambda_scale_lb``); ``training_step`` uses ``count`` to decide
    between the modification phase and regular windowed steps.
    """
    # ``torch.no_grad()`` is instantiated explicitly: the bare-class decorator
    # form is not accepted by every PyTorch release.
    @torch.no_grad()
    def log_modify_term(self,**kwargs):
        """Send the statistics to ``self.logger.experiment.log``.

        In blockwise mode each entry is expanded to ``<name>_<block index>``
        with 1-based indices.
        """
        if self.blockwise:
            logdict={}
            for name,value in kwargs.items():
                logdict.update({f'{name}_{i+1}':v for i,v in enumerate(value)})
            self.logger.experiment.log(logdict)
        else:
            self.logger.experiment.log(kwargs)

    def training_step(self, batch, batch_idx):
        """Run one step of either the modification phase or regular training.

        While fewer than ``small_batch_size * modify_batch`` samples of the
        current period have been seen, the batch goes through
        ``train_small_batches``.  Afterwards ``lambda_scale`` moves linearly
        from ``lambda_scale_ub`` towards ``true_lambda_scale_lb`` over the
        rest of the period and the batch is trained normally.

        NOTE(review): with ``use_modify_dataloader`` the second condition can
        still route batches with ``count > 0`` to ``train_small_batches``;
        confirm whether the dataloader phase is meant to run only once per
        period.
        """
        opt=self.optimizers()
        xs,y = self.unpack_batch(batch)
        if self.use_modify_dataloader and self.count==0:
            loss,pred=self.train_small_batches(xs,y,batch_idx,optimizer=opt)
        elif self.count*self.batch_size<self.small_batch_size*self.modify_batch:
            loss,pred=self.train_small_batches(xs,y,batch_idx,optimizer=opt)
        else:
            self.lambda_scale=self.lambda_scale_ub-(self.lambda_scale_ub-self.true_lambda_scale_lb)* \
                        (self.count-self._small_batches)/(self.period-self._small_batches)
            loss,pred=self.train_one_batch(xs,y,batch_idx,optimizer=opt)

        self.train_loss_metric(loss)
        self.log('loss', self.train_loss_metric, on_step=True, on_epoch=False, prog_bar=True, logger=False)
        if self.metrics:
            self.metrics(pred, y)
            self.log_dict(self.metrics, on_step=True, on_epoch=False, prog_bar=True, logger=False)
        self.count=(self.count+1)%self.period

    def on_train_start(self):
        """Initialise the per-run bookkeeping used by ``training_step``.

        Raises:
            ValueError: If ``period`` is neither ``'epoch'`` nor an ``int``.
        """
        self.batch_size=self.trainer.train_dataloader.batch_size
        # Number of training batches needed to collect all small batches;
        # stored as a plain int rather than a NumPy float.
        self._small_batches=int(np.ceil(self.small_batch_size*self.modify_batch/self.batch_size))
        if self.period=='epoch':
            self.period=len(self.trainer.train_dataloader)
        elif not isinstance(self.period,int):
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid option: Period must be integral or 'epoch'")
        self.count=0
        if self.momentum:
            self.lambda_scale_ub=1-self.momentum
        elif 'adam' in self.optimizer_name:
            self.lambda_scale_ub=0.1
        else:
            # No momentum: same as ``1 - momentum`` with momentum 0.  Without
            # this branch ``training_step`` failed on the missing attribute.
            self.lambda_scale_ub=1.0

        self.true_lambda_scale_lb=self.lambda_scale_lb*self.lambda_scale
        
class LossModifiedModel_test(BaseLossModifiedModel_a,LossModifiedModel_PLwrapper):
    """Combines ``BaseLossModifiedModel_a`` with ``LossModifiedModel_PLwrapper``; adds nothing of its own."""

class LossModifiedModel_final(BaseLossModifiedModel_a_lambda,LossModifiedModel_PLwrapper):
    """Combines ``BaseLossModifiedModel_a_lambda`` with ``LossModifiedModel_PLwrapper``; adds nothing of its own."""


