import os

from lightning_fabric import Fabric
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

### Load ###


def load_weights(
    fabric: Fabric,
    ckpt_path: str,
    model: Module,
    strict: bool = True,
    model_prefix: str = "model.",
):
    """Load only the model weights from a training checkpoint.

    The checkpoint is read with ``fabric.load`` and the training module's
    state dict is taken from its ``"training_module"`` entry (the key
    ``save_training`` writes). Only entries whose key starts with
    ``model_prefix`` are kept; the prefix is stripped before loading them
    into ``model``. Everything else in the checkpoint (optimizer, scheduler,
    other submodules) is ignored.

    Args:
        fabric: Fabric instance used to read the checkpoint file.
        ckpt_path: Path of the checkpoint file.
        model: Module that receives the weights.
        strict: Forwarded to ``Module.load_state_dict``.
        model_prefix: Key prefix identifying the model's entries inside the
            training module's state dict.

    Returns:
        Whatever ``model.load_state_dict`` returns (the missing and
        unexpected keys). With ``strict=False`` this is the only way for a
        caller to notice that, e.g., a wrong ``model_prefix`` matched nothing.
    """
    ckpt = fabric.load(ckpt_path)
    training_module_state_dict = ckpt["training_module"]
    model_state_dict = {
        k.removeprefix(model_prefix): v
        for k, v in training_module_state_dict.items()
        if k.startswith(model_prefix)
    }
    return model.load_state_dict(model_state_dict, strict=strict)


def load_training(
    fabric: Fabric,
    ckpt_path: str,
    optimizer: Optimizer,
    scheduler: LRScheduler,
    training_module: Module,
) -> tuple[int, int]:
    """Restore optimizer, scheduler and training-module states from a checkpoint.

    The checkpoint is read with ``fabric.load`` and is expected to use the
    keys written by ``save_training``: ``"optimizer"``, ``"scheduler"``,
    ``"training_module"``, ``"epoch"`` and ``"step"``. The given objects are
    updated in place via their ``load_state_dict`` methods.

    Args:
        fabric: Fabric instance used to read the checkpoint file.
        ckpt_path: Path of the checkpoint file.
        optimizer: Optimizer whose state is restored.
        scheduler: Learning-rate scheduler whose state is restored.
        training_module: Module whose state is restored.

    Returns:
        The ``(epoch, step)`` pair stored in the checkpoint.

    Raises:
        KeyError: If the checkpoint lacks one of the expected keys.
    """
    ckpt = fabric.load(ckpt_path)
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    training_module.load_state_dict(ckpt["training_module"])
    return ckpt["epoch"], ckpt["step"]


### Save ###


def save_training(
    fabric: Fabric,
    path: str,
    epoch: int,
    step: int,
    optimizer_state_dict: dict,
    scheduler_state_dict: dict,
    training_module_state_dict: dict,
):
    """Save everything needed to resume training later.

    The state dicts and progress counters are bundled into one dict with
    the keys ``"epoch"``, ``"step"``, ``"optimizer"``, ``"scheduler"`` and
    ``"training_module"`` (the layout ``load_training`` reads back). The
    bundle is then written with ``fabric.save``.

    Args:
        fabric: Fabric instance used to write the checkpoint file.
        path: Destination file path; its parent directory is created if needed.
        epoch: Current epoch counter.
        step: Current step counter.
        optimizer_state_dict: ``state_dict()`` of the optimizer.
        scheduler_state_dict: ``state_dict()`` of the LR scheduler.
        training_module_state_dict: ``state_dict()`` of the training module.
    """
    # For a bare filename os.path.dirname returns "", and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is given.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    state = {
        "epoch": epoch,
        "step": step,
        "optimizer": optimizer_state_dict,
        "scheduler": scheduler_state_dict,
        "training_module": training_module_state_dict,
    }
    fabric.save(path, state)
