from ANONYMOUS import Singleton
from ANONYMOUS.utils.retry import retry_call
from ANONYMOUStorch.trainer import LossException, check_loss_error
from ANONYMOUStorch.trainer.ema_trainer import EMATrainer


class Trainer(EMATrainer, metaclass=Singleton):
    """EMA trainer with retry-on-loss-error training steps.

    Instances are created through the project ``Singleton`` metaclass.
    ``train_step`` re-runs the parent step when it raises ``LossException``.
    ``mmodel`` resolves to the EMA copy of the model whenever one exists.
    """

    def __init__(self, cfg, loss_fn):
        """Build the underlying EMA trainer and set single-process defaults.

        Args:
            cfg: training configuration, forwarded unchanged to ``EMATrainer``.
            loss_fn: loss callable, forwarded unchanged to ``EMATrainer``.
        """
        super().__init__(cfg, loss_fn)
        # Single-process identity. Presumably mirrors the attributes the DDP
        # trainer exposes, so shared code can read them uniformly — confirm.
        self.rank = 0
        self.is_master = True  # for sync ddp_trainer
        self.bm = None  # pylint: disable= invalid-name
        # Project hook from ANONYMOUStorch.trainer; its effect on `self` is
        # not visible from this file.
        check_loss_error(self)

    def train_step(self, feed_dict):
        """Run one parent training step, retrying on ``LossException``.

        The parent ``train_step`` is attempted up to 3 times. Assuming
        ``retry_call`` follows the usual retry semantics (verify against
        ANONYMOUS.utils.retry), a ``LossException`` raised on the last attempt
        propagates to the caller.

        Args:
            feed_dict: batch input, forwarded unchanged to the parent step.

        Returns:
            The value returned by ``retry_call``, i.e. the parent
            ``train_step`` result (presumably; confirm ``retry_call``
            passes it through).
        """
        return retry_call(
            super().train_step, fargs=[feed_dict], tries=3, exceptions=LossException
        )

    @property
    def mmodel(self):
        """Return the model to use, preferring the EMA copy.

        EMA is always enabled, so this resolves to ``self.ema.model``:
        - if it has a ``.module`` attribute, that attribute is returned;
        - if it is ``None``, this falls back to ``self.model``.
        """
        ema_model = self.ema.model
        if ema_model is None:
            return self.model
        # Presumably unwraps DataParallel/DDP-style wrappers, which expose the
        # real network as `.module`; unwrapped models are returned as-is.
        return getattr(ema_model, "module", ema_model)
