from tool.logger import Logger
from data.dl_getter import get_dl_tr, get_dl_vl, cycle
from model.model_getter import get_model
from train.opt import get_opt
from model.model_io import _save_model


class MLBase:
    """Bundle of run arguments, model, optimizer and (optionally) data loaders.

    An instance is built either from scratch out of ``args`` or by sharing
    the components of an already constructed instance passed as ``other``.
    In both cases the epoch counter starts at 0.
    """

    def __init__(self, args=None, other=None):
        """Populate the instance from ``args`` or from ``other``.

        Args:
            args: Run configuration; only used when ``other`` is None. It is
                read for ``cls`` here and for ``epochs`` / ``save_freq`` in
                ``save_model``, and is forwarded to the project factories.
                NOTE(review): passing neither ``args`` nor ``other`` fails
                with AttributeError on ``args.cls``.
            other: Existing instance whose ``args``, ``model``, ``optimizer``
                and (when ``args.cls`` is truthy) loaders are reused by
                reference, not copied.
        """
        # The epoch counter is reset on both construction paths.
        self.epoch = 0

        if other is not None:
            # Share every component with the donor instance.
            self.args = other.args
            if self.args.cls:
                self.tr_dt_dl = other.tr_dt_dl
                self.vl_dt_dl = other.vl_dt_dl
            self.model = other.model
            self.optimizer = other.optimizer
            return

        # Fresh construction through the project factories.
        self.args = args
        if self.args.cls:
            # Loaders exist only when ``args.cls`` is truthy; otherwise the
            # attributes are left undefined.
            # presumably training / validation loaders -- confirm in dl_getter
            self.tr_dt_dl = get_dl_tr(args)
            self.vl_dt_dl = get_dl_vl(args)
        self.model = get_model(args)
        self.optimizer = get_opt(args, self.model)

    def save_model(self, loss, best_acc):
        """Write a checkpoint when the current epoch qualifies.

        A checkpoint is written when either
          * ``loss`` is below 1e3 and the 1-based epoch number is a multiple
            of ``args.save_freq``, or
          * the current epoch is the last one (``args.epochs - 1``).

        Args:
            loss: Value compared against the 1e3 threshold.
            best_acc: Forwarded unchanged to ``_save_model``.
        """
        is_final_epoch = self.epoch == self.args.epochs - 1
        epoch_number = self.epoch + 1  # 1-based, also stored with the checkpoint
        # ``save_freq`` is only consulted when the loss check passes.
        is_periodic_save = (
            loss < 1e3 and epoch_number % self.args.save_freq == 0
        )
        if not (is_periodic_save or is_final_epoch):
            return
        _save_model(
            self.args,
            best_acc,
            self.model,
            self.optimizer,
            epoch=epoch_number,
        )
