import os
import torch
from .base_util import device


def save_chkpt(folder_out, model, i_epoch, optimizer=None):
    """Save model (and optionally optimizer) state for epoch ``i_epoch``.

    The checkpoint is written twice into ``<folder_out>/chkpt``: once as
    ``chkpt_ep{i_epoch}.pth`` and once as ``chkpt_last.pth``. The second copy
    is overwritten on every call so it always holds the most recent save.

    Args:
        folder_out: Output directory; the ``chkpt`` subfolder is created if
            missing.
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        i_epoch: Epoch index, stored in the checkpoint and used in the file
            name.
        optimizer: Optional object exposing ``state_dict()``. When ``None``,
            no ``'optimizer'`` entry is written.
    """
    folder_path = os.path.join(folder_out, "chkpt")
    os.makedirs(folder_path, exist_ok=True)

    chkpt = {'model': model.state_dict(), 'i_epoch': i_epoch}
    # optimizer is optional; calling .state_dict() on None would raise.
    if optimizer is not None:
        chkpt['optimizer'] = optimizer.state_dict()

    chkpt_path = os.path.join(folder_path, 'chkpt_ep{}.pth'.format(i_epoch))
    torch.save(chkpt, chkpt_path)

    # Rolling "latest" copy used by load_chkpt when no epoch is requested.
    chkpt_last_path = os.path.join(folder_path, 'chkpt_last.pth')
    torch.save(chkpt, chkpt_last_path)

    print("Saved checkpoint to ", folder_path)


def load_chkpt(folder_out, model, optimizer=None, i_epoch=None):
    """Restore model (and optionally optimizer) state from a saved checkpoint.

    Reads ``<folder_out>/chkpt/chkpt_last.pth`` when ``i_epoch`` is ``None``,
    otherwise ``<folder_out>/chkpt/chkpt_ep{i_epoch}.pth``. Tensors are mapped
    to the ``device`` imported from ``base_util``. The model is switched to
    training mode afterwards.

    Args:
        folder_out: Directory that contains the ``chkpt`` subfolder.
        model: Object with ``load_state_dict()`` and ``train()``.
        optimizer: Optional object with ``load_state_dict()``. It is restored
            only if the checkpoint contains optimizer state.
        i_epoch: Epoch to load; ``None`` selects the latest checkpoint.

    Returns:
        The ``'i_epoch'`` value stored in the checkpoint, or ``None`` if the
        checkpoint does not contain one.
    """
    if i_epoch is None:
        chkpt_name = 'chkpt_last.pth'
    else:
        chkpt_name = 'chkpt_ep{}.pth'.format(i_epoch)
    chkpt_path = os.path.join(folder_out, "chkpt", chkpt_name)
    chkpt = torch.load(chkpt_path, map_location=device)

    model.load_state_dict(chkpt['model'])

    if optimizer is not None:
        # save_chkpt stores optimizer state under 'optimizer'; the previous
        # check looked for 'optimizer_state_dict', so that key is still
        # accepted as a fallback for checkpoints written elsewhere.
        optim_state = chkpt.get('optimizer', chkpt.get('optimizer_state_dict'))
        if optim_state is not None:
            optimizer.load_state_dict(optim_state)

    model.train()

    if i_epoch is None:
        print("Loaded checkpoint from ", folder_out, " with the latest weights")
    else:
        print("Loaded checkpoint from ", folder_out, " with the weights from epoch ", i_epoch)

    return chkpt.get('i_epoch')
