import hydra
import torch
from typing import Union
from omegaconf import DictConfig
import pytorch_lightning as pl

def instantiate_model(cfg: DictConfig) -> "Union[torch.nn.Module, pl.LightningModule]":
    """Build the model described by the ``model`` section of ``cfg``.

    Thin wrapper that hands ``cfg.model`` to ``instantiate_generic`` so that
    instantiation and checkpoint handling live in a single place instead of
    being duplicated here.

    Args:
        cfg: Config containing a ``model`` node. That node is what gets passed
            to ``hydra.utils.instantiate`` and is also where
            ``training.ckpt_file`` is looked up.

    Returns:
        Whatever ``instantiate_generic`` returns for ``cfg.model``.
    """
    return instantiate_generic(cfg.model)


def instantiate_generic(cfg: DictConfig) -> "Union[torch.nn.Module, pl.LightningModule]":
    """Instantiate the object described by ``cfg``, optionally from a checkpoint.

    The object is first built with ``hydra.utils.instantiate`` using
    ``_recursive_=False`` (per Hydra's docs, nested configs are then handed to
    the target as-is rather than being instantiated first). If
    ``cfg.training.ckpt_file`` is set, that freshly built instance is used
    only to determine the class: it is replaced by
    ``<class>.load_from_checkpoint(ckpt_file, strict=False)``.

    Args:
        cfg: Config of the object to build. ``training.ckpt_file`` is
            optional; when the entry is absent or ``None`` no checkpoint is
            loaded.

    Returns:
        The instantiated object, or the object returned by the class's
        ``load_from_checkpoint`` when a checkpoint file is configured.

    Raises:
        AttributeError: If a checkpoint file is configured but the
            instantiated class does not provide ``load_from_checkpoint``
            (e.g. a plain ``torch.nn.Module`` rather than a
            ``LightningModule``).
    """
    # Function-scope import: omegaconf is already a dependency of this module
    # (DictConfig); only the OmegaConf helper is needed in addition.
    from omegaconf import OmegaConf

    model = hydra.utils.instantiate(cfg, _recursive_=False)

    # select() yields None when `training` or `ckpt_file` is absent, so a
    # config without a checkpoint entry behaves like `ckpt_file: null` instead
    # of raising. throw_on_missing=True keeps a mandatory-missing value ("???")
    # an error, as plain attribute access would.
    ckpt_file = OmegaConf.select(cfg, "training.ckpt_file", throw_on_missing=True)
    if ckpt_file is None:
        return model

    model_cls = model.__class__
    load_from_checkpoint = getattr(model_cls, "load_from_checkpoint", None)
    if load_from_checkpoint is None:
        # Same exception type the bare attribute lookup would raise, but with
        # a message that points at the actual configuration problem.
        raise AttributeError(
            f"'training.ckpt_file' is set to {ckpt_file!r}, but "
            f"{model_cls.__name__} has no 'load_from_checkpoint'; checkpoint "
            "loading requires a LightningModule-style class."
        )

    # NOTE(review): in PyTorch Lightning, strict=False relaxes the state-dict
    # key matching — confirm that partially matching checkpoints are intended.
    return load_from_checkpoint(ckpt_file, strict=False)
