import os
import torch
from moduleloader import print_info as print

def load(state, event, model):
    """Restore ``model`` weights and training metadata from a checkpoint file.

    The path is taken from ``state["filename"]``. If no file exists at that
    path, a message is logged and nothing else happens. Otherwise the file is
    read with ``torch.load`` and:

    * ``state.all["main.start_epoch"]`` is set to ``checkpoint["epoch"]``,
    * ``state["best_acc"]`` is set to ``checkpoint["best_acc"]``,
    * ``model.load_state_dict(checkpoint["state_dict"])`` is called.

    Args:
        state: moduleloader state for this module. Must support
            ``state["filename"]``, item assignment, and a ``state.all``
            mapping (presumably the global, fully-qualified state -- confirm).
        event: unused here; part of the event-handler signature.
        model: the ``torch.nn.Module`` whose weights are restored.

    Raises:
        KeyError: if the checkpoint lacks ``epoch``, ``best_acc`` or
            ``state_dict``.
        RuntimeError: propagated from ``load_state_dict`` when the stored
            weights do not match ``model``.
    """
    ckpt_file = state["filename"]
    if not os.path.isfile(ckpt_file):
        print("... no ckpt found at '{}'".format(ckpt_file))
        return

    print("... loading ckpt '{}'".format(ckpt_file))
    # Map all tensors to CPU so a checkpoint written on a GPU machine can also
    # be opened on a CPU-only one; load_state_dict then copies the values into
    # the model's parameters on whatever device those already live on.
    checkpoint = torch.load(ckpt_file, map_location="cpu")

    # NOTE: a model-type check against checkpoint["model_type"] was sketched
    # here previously and left disabled.

    # Training metadata used to resume (start epoch, best accuracy so far).
    state.all["main.start_epoch"] = checkpoint["epoch"]
    state["best_acc"] = checkpoint["best_acc"]

    model.load_state_dict(checkpoint["state_dict"])
    print("... loaded ckpt '{}' (saved epoch {})".format(ckpt_file, checkpoint["epoch"]))

def register(mf):
    """Register this module's defaults and its ``load_ckpt`` event with ``mf``.

    Defaults:
        filename: computed from the state when requested --
            ``state.all["log.dir"] + ".ckpt"`` if ``state.all`` contains a
            ``"log.dir"`` entry, otherwise ``"model.ckpt"``.

    Args:
        mf: the moduleloader registration handle; only its
            ``register_defaults`` and ``register_event`` methods are used here.
    """
    def default_filename(state, event):
        # Test membership on state.all -- the same mapping the value is read
        # from. `"log.dir" in state` may resolve keys differently (e.g.
        # relative to this module) and would then never match.
        # NOTE(review): confirm against moduleloader's state implementation.
        if "log.dir" in state.all:
            return state.all["log.dir"] + ".ckpt"
        return "model.ckpt"

    mf.register_defaults({
        "filename": default_filename,
    })
    mf.register_event('load_ckpt', load, unique=True)
