from ruamel.yaml import YAML
import torch
from torch.nn import Module

from training import TrainingConfig


def prepare_model(
    checkpoint_path: str,
    *,
    map_location: "str | torch.device | None" = None,
) -> Module:
    """
    Load a model checkpoint and return the model in evaluation mode.

    The checkpoint is read with ``torch.load``. Its ``"training_config"`` entry is parsed
    with ruamel.yaml to obtain the training configuration. That configuration instantiates
    the model, and the checkpoint's ``"model"`` state dict is then loaded into it.

    Args:
        checkpoint_path: Path of the checkpoint file to load.
        map_location: Forwarded unchanged to ``torch.load`` to remap stored tensors, e.g.
            ``"cpu"`` to load a checkpoint saved on GPU on a CPU-only machine. ``None``
            (the default, same as ``torch.load``'s default) restores tensors to the devices
            they were saved on.

    Returns:
        The instantiated model with the checkpoint parameters loaded, switched to eval mode.

    Raises:
        KeyError: If the checkpoint lacks the ``"training_config"``, ``"model"`` or
            ``"epoch"`` entry.
        RuntimeError: If the checkpoint state dict does not match the model's parameters
            (``load_state_dict`` is called with its default ``strict=True``).
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, which can execute
    # code on load. Only use this with checkpoints from trusted sources; consider
    # weights_only=True if checkpoints contain only tensors and primitive containers.
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)

    # YAML() uses ruamel's default round-trip loader. Getting a TrainingConfig back (rather
    # than a plain mapping) presumably relies on the class being registered with ruamel.yaml
    # elsewhere — TODO confirm.
    config: TrainingConfig = YAML().load(checkpoint["training_config"])

    model = config.model_config.instantiate_model()
    model.load_state_dict(checkpoint["model"])
    # Switch to inference behaviour (e.g. dropout disabled, batch-norm uses running stats).
    model.eval()

    print(f"Loaded {config.model_config.MODEL_NAME}, trained for {checkpoint['epoch']} epochs on {config.dataset_name}")
    return model
