from typing import Optional, TypeVar

import torch

from utils.logger.logger import Logger
from utils.utils import get_class_name, get_object_name

# Constrained type variable: the loader below returns the same kind of object
# (Module or Optimizer) that it was given.
T = TypeVar('T', torch.nn.Module, torch.optim.Optimizer)


def load_from_state_dict(
    entity: T,
    load_path: Optional[str] = None,
    load_keys: Optional[list[str]] = None,
) -> T:
    """Load a saved state dict from ``load_path`` into ``entity`` and return it.

    The file is read with ``torch.load(..., map_location='cpu')``. If
    ``load_keys`` is given, the loaded object is indexed one key at a time, in
    order (e.g. ``['model']`` selects ``checkpoint['model']``). The result is
    passed to ``entity.load_state_dict``. When ``load_path`` is ``None``,
    nothing is loaded, ``load_keys`` is ignored, and ``entity`` is returned
    untouched.

    Args:
        entity: The ``torch.nn.Module`` or ``torch.optim.Optimizer`` to load into.
        load_path: Path of the saved file; ``None`` skips loading entirely.
        load_keys: Keys used to descend into the loaded object before loading.

    Returns:
        The same ``entity`` object, after ``entity.load_state_dict`` was called
        on it (when ``load_path`` is given).

    Raises:
        AssertionError: If a level being indexed is not a ``dict``, or does not
            contain the requested key. Raised explicitly instead of via
            ``assert`` so the validation is not stripped under ``python -O``.
    """
    Logger.debug(
        f'{get_class_name(load_from_state_dict)} - '
        f'entity: {get_object_name(entity)}, '
        f'load_path: {load_path}, '
        f'load_keys: {load_keys}'
    )
    if load_path is None:
        return entity

    Logger.debug(f'loading from path: {load_path}')
    # NOTE(security): torch.load may unpickle arbitrary Python objects (depending
    # on the installed torch version's `weights_only` default), which can execute
    # code. Only load checkpoints from trusted sources.
    state_dict = torch.load(load_path, map_location='cpu')
    for load_key in load_keys or []:
        # Each level we index into must be a dict containing the next key.
        if not isinstance(state_dict, dict):
            raise AssertionError(f'state_dict must be dict: {state_dict}')
        if load_key not in state_dict:
            raise AssertionError(
                f'load_key must be in state_dict keys: {load_key}, {state_dict.keys()}'
            )
        Logger.debug(f'loading from key: {load_key}')
        state_dict = state_dict[load_key]
    entity.load_state_dict(state_dict)
    return entity
