### Preamble ###################################################################

"""
Simple helper functions for working with the HuggingFace packages
"""

#######################################################################################################################

### Imports ###########################################################################################################

from timm.data import resolve_data_config
from typing import Union, Iterable, Optional, Tuple

#######################################################################################################################


def parse_timm_preprocess_config(timm_config: dict) -> dict:
    """
    :param timm_config: dict
        The config describing the preprocessing functions to apply to a classifier derived from the `timm` package
        from HuggingFace. This should be the output of `timm.data.resolve_data_config({}, model=inception_model)`.

    Returns a preprocessing config dictionary compatible with classifiers from the `transformers` package from
    HuggingFace and classifier-based guidance controllers from `gcontrol`.
    """

    gc_config = {}
    if "input_size" in timm_config:
        gc_config["size"] = timm_config["input_size"]
    else:
        raise KeyError("`timm_config` must have an 'input_size' attribute")

    if "mean" in timm_config:
        gc_config["image_mean"] = list(timm_config["mean"])
    if "std" in timm_config:
        gc_config["image_std"] = list(timm_config["std"])
    if ("image_mean" in gc_config) and ("image_std" in gc_config):
        gc_config["do_normalize"] = True
    else:
        gc_config["do_normalize"] = False
    if "crop_pct" in timm_config:
        gc_config["crop_pct"] = timm_config["crop_pct"]

    gc_config["do_rescale"] = True
    gc_config["rescale_factor"] = 1 / 255.0
    gc_config["do_resize"] = True

    return gc_config


def get_timm_config(model):
    """
    :param model:
        A classifier model from the `timm` package.

    Returns a preprocessing config dictionary compatible with classifiers from the `transformers` package from
    HuggingFace and classifier-based guidance controllers from `gcontrol`.
    """

    config = resolve_data_config({}, model=model)
    return parse_timm_preprocess_config(config)


#######################################################################################################################
