import copy
import logging
import warnings
from typing import Dict, Union

import torch
import torch.nn as nn
from mmcls.models import build_classifier
from mmcv.runner import load_checkpoint


def init_mmcls_classifier(
        classifier_cfg: Dict,
        classifier_ckpt: str,
        logger: logging.Logger,
        device: Union[str, torch.device, None] = None) -> nn.Module:
    """Initialize an mmcls classifier for inference.

    Build a classifier from the config, load the checkpoint, attach class
    names, move it to the device, and turn the eval mode on.

    Args:
        classifier_cfg: Config of the classifier. It is not modified; a deep
            copy with ``pretrained`` set to ``None`` is used for building.
        classifier_ckpt: Path to pre-trained weights.
        logger: Logger passed to ``load_checkpoint``.
        device: Device to which the classifier will be moved. If ``None``,
            the classifier is not moved.

    Returns:
        The classifier in eval mode. Its ``CLASSES`` attribute is taken from
        the checkpoint's meta data when present, otherwise ImageNet class
        names are used.
    """
    # Work on a copy so the caller's config is not mutated. ``pretrained`` is
    # cleared because the weights are loaded from ``classifier_ckpt`` below.
    cfg = copy.deepcopy(classifier_cfg)
    cfg['pretrained'] = None
    classifier = build_classifier(cfg)

    checkpoint = load_checkpoint(
        classifier, classifier_ckpt, map_location='cpu', logger=logger)
    # ``meta`` may be missing or explicitly ``None`` in the checkpoint.
    meta = checkpoint.get('meta') or {}
    if 'CLASSES' in meta:
        classifier.CLASSES = meta['CLASSES']
    else:
        # Imported lazily: only needed for the fallback class names.
        from mmcls.datasets import ImageNet
        # Do not call ``warnings.simplefilter`` here: it would change the
        # process-wide filter for unrelated warnings. Under Python's default
        # filters this warning is already reported once per call site.
        warnings.warn(
            'Class names are not saved in the checkpoint\'s '
            'meta data, use imagenet by default.')
        classifier.CLASSES = ImageNet.CLASSES

    if device is not None:
        classifier.to(device)
    classifier.eval()
    return classifier
