import os
import torch

def save(state, event, model):
    """Track the best accuracy seen so far and periodically checkpoint ``model``.

    Reads the current epoch and accuracy from ``state.all`` and stores the
    running maximum accuracy back into ``state.all["best_acc"]``. It then writes
    a checkpoint to ``state["filename"]`` whenever the epoch is positive and a
    multiple of ``state["every_epoch"]``.

    Args:
        state: framework state. ``state[...]`` reads this module's settings
            (``filename``, ``every_epoch``). ``state.all`` is a dict-like global
            store (presumably miniflask state — confirm).
        event: triggering event. If it has a ``validate`` attribute and
            ``"val_accuracy"`` is present in ``state.all``, that accuracy is used.
            Otherwise ``state.all["last_accuracy"]`` is used.
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
    """
    ckpt_file = state["filename"]
    all_state = state.all

    epoch = all_state["main.current_epoch"]
    if hasattr(event, "validate") and "val_accuracy" in all_state:
        current_acc = all_state["val_accuracy"]
    else:
        current_acc = all_state["last_accuracy"]

    # Remember the best accuracy across calls (starts at 0 before the first call).
    best_acc = all_state["best_acc"] if "best_acc" in all_state else 0
    best_acc = max(current_acc, best_acc)
    all_state["best_acc"] = best_acc

    # The registered default interval is int(0.1 * main.epoch), which is 0 for
    # runs shorter than 10 epochs; clamp to 1 so the modulo never divides by zero.
    every = max(1, state["every_epoch"])
    if epoch > 0 and epoch % every == 0:
        torch.save(
            {
                # epoch + 1: presumably the epoch to resume from — confirm with the loader.
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_acc": best_acc,
            },
            ckpt_file,
        )

def save_now(state, event, model):
    """Unconditionally write a checkpoint of ``model`` to ``state["filename"]``.

    The checkpoint stores the model's ``state_dict()``, the best accuracy
    recorded in ``state.all["best_acc"]`` (0 if none was recorded yet, the same
    initial value ``save`` uses), and ``state.all["val_accuracy"]`` as
    ``final_acc``.

    NOTE(review): ``register`` in this file does not register this function as
    an event; confirm how it is meant to be invoked.

    Args:
        state: framework state; ``state.all`` is a dict-like global store.
        event: unused.
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).

    Raises:
        KeyError: if ``state.all`` has no ``"val_accuracy"``.
    """
    all_state = state.all
    current_acc = all_state["val_accuracy"]
    best_acc = all_state["best_acc"] if "best_acc" in all_state else 0
    torch.save(
        {
            "state_dict": model.state_dict(),
            "best_acc": best_acc,
            "final_acc": current_acc,
        },
        state["filename"],
    )

def register(mf):
    """Register this module's default settings and its ``save_ckpt`` event with ``mf``.

    Defaults (each computed lazily from ``state``):
        filename: ``state["log.dir"] + ".ckpt"`` when ``log.dir`` is set,
            otherwise ``"model.ckpt"``.
        every_epoch: checkpoint interval in epochs, 10% of ``main.epoch``
            (rounded down) but never less than 1.
    """
    mf.register_defaults({
        "filename": lambda state, event: state["log.dir"] + ".ckpt" if "log.dir" in state else "model.ckpt",
        # Clamp to 1: int(0.1 * main.epoch) is 0 for runs shorter than 10 epochs,
        # which would make `save` evaluate `epoch % 0`.
        "every_epoch": lambda state, event: max(1, int(state["main.epoch"] * 0.1)),
    })
    mf.register_event('save_ckpt', save)
