import os

import torch


def save_checkpoint(
    model,
    ema,
    optimizer,
    args,
    global_step,
    save_path,
    accelerator,
    min_recorded_avg_fid=None,
):
    """Write model, EMA and optimizer state to a ``.pt`` checkpoint.

    The file holds a dict with keys ``"model"``, ``"ema"``, ``"opt"``,
    ``"args"`` and ``"steps"``, plus ``"avg_fid"`` when
    ``min_recorded_avg_fid`` is not ``None``.

    Args:
        model: Module whose ``state_dict()`` is saved. If it exposes a
            ``.module`` attribute (e.g. a distributed wrapper), it is first
            unwrapped with ``accelerator.unwrap_model``.
        ema: Object whose ``state_dict()`` is saved under ``"ema"``.
        optimizer: Optimizer whose ``state_dict()`` is saved under ``"opt"``.
        args: Run configuration, stored as-is (pickled by ``torch.save``).
        global_step: Step count stored under ``"steps"``; also names the file
            when ``save_path`` is a directory.
        save_path: ``str`` or ``os.PathLike``. A path ending in ``.pt`` is used
            as the file name; anything else is treated as a directory and the
            file becomes ``<save_path>/<global_step:07d>.pt``. Missing parent
            directories are created.
        accelerator: Provides ``unwrap_model``; only used when ``model`` has a
            ``.module`` attribute.
        min_recorded_avg_fid: Optional value stored under ``"avg_fid"``.

    Returns:
        The ``str`` path the checkpoint was written to.
    """
    inner = accelerator.unwrap_model(model) if hasattr(model, "module") else model
    checkpoint = {
        "model": inner.state_dict(),
        "ema": ema.state_dict(),
        "opt": optimizer.state_dict(),
        "args": args,
        "steps": global_step,
    }
    if min_recorded_avg_fid is not None:
        checkpoint["avg_fid"] = min_recorded_avg_fid

    # Accept pathlib.Path (and other PathLike) as well as plain strings.
    save_path = os.fspath(save_path)
    if save_path.endswith(".pt"):
        checkpoint_path = save_path
    else:
        # Not a .pt file name: treat as a directory, name the file by step.
        checkpoint_path = os.path.join(save_path, f"{global_step:07d}.pt")

    parent = os.path.dirname(checkpoint_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Write to a sibling temp file and atomically rename it into place so an
    # interruption mid-save never leaves a truncated file at checkpoint_path.
    tmp_path = checkpoint_path + ".tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error
        # is re-raised below, so a failed removal is deliberately ignored.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return checkpoint_path


def load_checkpoint(model, ema, optimizer, args, checkpoint_dir):
    """Restore model, EMA and optimizer state from a step-named checkpoint.

    Reads ``<checkpoint_dir>/<args.resume_step zero-padded to 7 chars>.pt``
    (the naming ``save_checkpoint`` uses for directory targets) and loads its
    ``"model"``, ``"ema"`` and ``"opt"`` entries into the given objects in
    place via ``load_state_dict``.

    Args:
        model: Module receiving the ``"model"`` state dict.
        ema: Object receiving the ``"ema"`` state dict.
        optimizer: Optimizer receiving the ``"opt"`` state dict.
        args: Object with a ``resume_step`` attribute selecting the file.
        checkpoint_dir: Directory containing the checkpoint files.

    Returns:
        ``(model, ema, optimizer, global_step)``: the same objects that were
        passed in, plus the checkpoint's ``"steps"`` value.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        KeyError: if an expected entry is missing from the checkpoint.
    """
    ckpt_name = str(args.resume_step).zfill(7) + ".pt"
    ckpt_path = os.path.join(checkpoint_dir, ckpt_name)
    # Checkpoints carry non-tensor objects (the pickled ``args``), which
    # torch.load's weights_only=True default (PyTorch >= 2.6) refuses to
    # unpickle. weights_only=False performs full unpickling and can execute
    # arbitrary code: only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    ema.load_state_dict(ckpt["ema"])
    # map_location="cpu" is fine here: Optimizer.load_state_dict casts state
    # tensors to the device of the corresponding parameters.
    optimizer.load_state_dict(ckpt["opt"])
    global_step = ckpt["steps"]
    return model, ema, optimizer, global_step
