from calendar import c
from mmengine import Registry

from mmengine import ConfigDict
from mmengine.config import Config
from typing import Dict, Optional, Any, Union
from mmengine.registry import Registry
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
from mmengine.registry import MODELS as MMENGINE_MODELS
from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
from mmengine.registry import FUNCTIONS as MMENGINE_FUNCTIONS
from mmengine.registry import METRICS as MMENGINE_METRICS
from mmengine.registry import LOOPS as MMENGINE_LOOPS
from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS


def build_trainer_from_cfg(
    cfg: Union[Dict[str, Any], ConfigDict, Config],
    registry: Registry,
    default_args: Optional[Dict[str, Any]] = None,
) -> Any:
    """
    Instantiate a Trainer based on a user-provided configuration.

    The configuration must include a `"trainer"` section (either as a string or a dict)
    that specifies the Trainer class to build, plus any Trainer-specific parameters.
    All other top-level keys in `cfg` (e.g., `"model"`, `"data"`, `"accelerator"`,
    `"checkpoints"`, etc.) are passed into the Trainer constructor under the
    single `cfg` argument for downstream initialization. That `cfg` argument is a
    shallow copy of the input with `"trainer"` removed. The caller's `cfg`
    (including its nested trainer dict) is left unmodified, so the same config
    can be built more than once.

    Example valid configs:

        # Minimal: trainer name only
        cfg = {
            "trainer": "MyFancyTrainer",
            "model": { ... },
            "data":  { ... },
            "accelerator": { ... },
        }

        # Extended: trainer name plus custom trainer args
        cfg = {
            "trainer": {
                "type": "MyFancyTrainer",
                "foo": 123,
                "bar": "baz",
            },
            "model": { ... },
            "data":  { ... },
            "checkpoints": { ... },
        }

    Args:
        cfg (Union[Dict[str, Any], ConfigDict, Config]):
            A nested configuration object or dict. Must contain the key `"trainer"`,
            whose value is either:
              - A string naming the Trainer class to build.
              - A dict with a `"type"` field (a registry key or a callable),
                plus any Trainer-specific keyword arguments.
        registry (Registry):
            Registry used to resolve a string trainer type via `registry.get`.
        default_args (Optional[Dict[str, Any]]):
            Fallback keyword arguments for the Trainer constructor. They never
            override entries from `cfg["trainer"]` or the injected `cfg`
            argument. A `"type"` entry is only used when `cfg["trainer"]` is a
            dict without its own `"type"`.

    Returns:
        An instance of the requested Trainer class. It is created by calling the
        class resolved from `registry` (or the callable given as type) with the
        trainer keyword arguments.

    Raises:
        AssertionError: If `"trainer"` is missing from `cfg`.
        KeyError: If `"type"` is missing when `cfg["trainer"]` is a dict (and not
            supplied by `default_args`), or if a string type is not registered.
        TypeError: If the trainer type is neither a string nor a callable.
    """
    import copy

    # 1. Ensure the top-level 'trainer' key is present. Raised explicitly rather
    #    than via `assert` so the check is not stripped under `python -O`.
    if "trainer" not in cfg:
        raise AssertionError("Configuration must include a 'trainer' key.")

    # A "type" in default_args is a fallback trainer type, not a constructor kwarg.
    defaults = dict(default_args) if default_args else {}
    default_type = defaults.pop("type", None)

    # 2. Extract the trainer section from a shallow copy, so neither the caller's
    #    cfg nor its nested trainer dict is mutated. copy.copy keeps the
    #    container's type for downstream attribute-style access.
    remaining_cfg = copy.copy(cfg)
    trainer_cfg = remaining_cfg.pop("trainer")
    if isinstance(trainer_cfg, str):
        # Bare string: it only names the Trainer; no extra constructor kwargs.
        trainer_type = trainer_cfg
        trainer_cfg = dict()
    else:
        trainer_cfg = copy.copy(trainer_cfg)
        trainer_type = trainer_cfg.pop("type", default_type)
        if trainer_type is None:
            raise KeyError("When 'trainer' is a dict, it must include a 'type' field.")

    # 3. Pass the rest of the config into the Trainer constructor under 'cfg',
    #    then fill in defaults without overriding explicit entries.
    trainer_cfg["cfg"] = remaining_cfg
    for key, value in defaults.items():
        trainer_cfg.setdefault(key, value)

    # 4. Resolve the Trainer class, then build and return the instance.
    if isinstance(trainer_type, str):
        obj_cls = registry.get(trainer_type)
        if obj_cls is None:
            raise KeyError(f"'{trainer_type}' is not registered in {registry.name}")
    elif callable(trainer_type):
        obj_cls = trainer_type
    else:
        raise TypeError(
            f"`type` must be a string or callable, but got {type(trainer_type)}"
        )
    return obj_cls(**trainer_cfg)


def _ensure(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` if ``condition`` is false.

    Unlike a bare ``assert`` statement, this check is not stripped when Python
    runs with ``-O``. It still raises the AssertionError type that
    ``build_hf_model_from_cfg`` documents for malformed method sections.
    """
    if not condition:
        raise AssertionError(message)


def build_hf_model_from_cfg(
    cfg: Dict[str, Any],
    registry: Registry,
    default_args: Optional[Dict[str, Any]] = None,
) -> Any:
    """
    Build and initialize a Hugging Face module (model, pipeline, scheduler, etc. from
    transformers or diffusers) from a configuration dictionary and a registry of
    available classes.

    Depending on the contents of `cfg`, one of four initialization methods is used,
    checked in this order:
      1. from_pretrained()   : Load pretrained weights from a folder or model hub.
      2. from_single_file()  : Load from a single checkpoint file (e.g. .ckpt/.safetensors).
      3. from_config()       : Initialize from a config object or dict.
      4. __init__            : Direct constructor call (no method key present).
    At most one of the method keys may be present.

    Args:
        cfg (Dict[str, Any]):
            Configuration dictionary. After merging `default_args`, it must contain
            a `"type"` key that is either a registry key or a class/callable.
            May contain at most one of:
            - `"from_pretrained"`: a dict of kwargs for `from_pretrained`; must
              include `"pretrained_model_name_or_path"`.
            - `"from_single_file"`: a dict of kwargs for `from_single_file`; must
              include `"pretrained_model_link_or_path"`.
            - `"from_config"`: a dict with a `"config"` entry (a config object,
              dict, or path), passed as the first positional argument of
              `from_config`. Its other entries are passed as keyword arguments.
            All remaining top-level entries are passed as keyword arguments to the
            selected method, where entries inside the method dict win on
            conflicts. When no method key is present, they go to the
            constructor instead. `cfg` itself is not modified.
        registry (Registry):
            An mmengine.Registry instance that maps string keys to classes.
            It must include all HF classes you intend to build.
        default_args (Optional[Dict[str, Any]]):
            Default entries merged into a copy of `cfg`; explicit `cfg` entries win.

    Returns:
        An instantiated Hugging Face object (model, pipeline, scheduler, etc.).

    Raises:
        TypeError: If `cfg` is not a dict, `registry` is not a Registry, or
            `"type"` is neither a string nor a callable.
        KeyError: If `"type"` is missing or not found in the registry.
        AssertionError: If more than one method key is given, a method section is
            not a dict or lacks its required key, or the class does not implement
            the requested method.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f"`cfg` must be a dict, but got {type(cfg)}")
    if not isinstance(registry, Registry):
        raise TypeError(
            f"`registry` must be an mmengine Registry, but got {type(registry)}"
        )

    # Merge in defaults on a shallow copy, without overriding explicit cfg entries.
    args = dict(cfg)
    if default_args:
        for key, value in default_args.items():
            args.setdefault(key, value)

    # Extract and validate the target class name / type.
    if "type" not in args:
        raise KeyError("`cfg` (or `default_args`) must contain the key 'type'")
    obj_type = args.pop("type")

    # Resolve the class object.
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(f"'{obj_type}' is not registered in {registry.name}")
    elif callable(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f"`type` must be a string or callable, but got {type(obj_type)}"
        )
    # Only used in error messages; callables such as functools.partial have no
    # __name__, and the messages below are built eagerly.
    cls_name = getattr(obj_cls, "__name__", repr(obj_cls))

    # Otherwise every extra method key would leak as a stray kwarg into whichever
    # method is checked first, so reject ambiguous configs up front.
    requested = [
        key
        for key in ("from_pretrained", "from_single_file", "from_config")
        if key in args
    ]
    _ensure(
        len(requested) <= 1,
        "Only one of 'from_pretrained', 'from_single_file', 'from_config' "
        f"may be given, but got {requested}",
    )

    # 1) from_pretrained
    if "from_pretrained" in args:
        fp_args = args.pop("from_pretrained")
        _ensure(
            isinstance(fp_args, dict),
            "`from_pretrained` must map to a dict of kwargs",
        )
        _ensure(
            "pretrained_model_name_or_path" in fp_args,
            "`from_pretrained` dict must include 'pretrained_model_name_or_path'",
        )
        _ensure(
            hasattr(obj_cls, "from_pretrained"),
            f"{cls_name} does not implement 'from_pretrained'",
        )
        # Remaining cfg entries act as defaults; the method dict wins on conflicts.
        return obj_cls.from_pretrained(**{**args, **fp_args})

    # 2) from_single_file
    if "from_single_file" in args:
        fs_args = args.pop("from_single_file")
        _ensure(
            isinstance(fs_args, dict),
            "`from_single_file` must map to a dict of kwargs",
        )
        _ensure(
            "pretrained_model_link_or_path" in fs_args,
            "`from_single_file` dict must include 'pretrained_model_link_or_path'",
        )
        _ensure(
            hasattr(obj_cls, "from_single_file"),
            f"{cls_name} does not implement 'from_single_file'",
        )
        return obj_cls.from_single_file(**{**args, **fs_args})

    # 3) from_config
    if "from_config" in args:
        fc_args = args.pop("from_config")
        _ensure(
            isinstance(fc_args, dict),
            "`from_config` must map to a dict of kwargs",
        )
        _ensure("config" in fc_args, "`from_config` dict must include 'config'")
        _ensure(
            hasattr(obj_cls, "from_config"),
            f"{cls_name} does not implement 'from_config'",
        )
        fc_kwargs = {**args, **fc_args}
        config = fc_kwargs.pop("config")
        # diffusers' ConfigMixin.from_config(config, ..., **kwargs) and
        # transformers' Auto*.from_config(config, **kwargs) both take the config
        # first, plus extra keyword arguments. Passing the wrapper dict itself
        # as the config would bury the real config under a "config" key.
        return obj_cls.from_config(config, **fc_kwargs)

    # 4) Fallback: direct constructor call with all remaining entries.
    return obj_cls(**args)


# HF_MODELS builds Hugging Face objects (diffusers / transformers models,
# pipelines, schedulers, losses, ...) through `build_hf_model_from_cfg`.
# It is a root registry: no `parent` is given.
HF_MODELS = Registry(
    name="hf_model",
    build_func=build_hf_model_from_cfg,
    locations=[
        "mmhug.pipelines",
        "mmhug.models",
        "mmhug.schedulers",
        "mmhug.losses",
    ],
)

# Registries below with `parent=` are linked to the matching mmengine registry;
# `locations` lists the mmhug modules to import for each registry's entries.
METRICS = Registry(
    name="metric",
    parent=MMENGINE_METRICS,
    locations=["mmhug.evaluators"],
)
MODEL_WRAPPERS = Registry(
    name="model_wrapper",
    parent=MMENGINE_MODEL_WRAPPERS,
    locations=["mmhug.model_wrappers"],
)

DATASETS = Registry(
    name="dataset",
    parent=MMENGINE_DATASETS,
    locations=["mmhug.datasets"],
)
TRANSFORMS = Registry(
    name="transform",
    parent=MMENGINE_TRANSFORMS,
    locations=["mmhug.datasets.transforms"],
)
# TRAINERS is a root registry whose build step is `build_trainer_from_cfg`.
TRAINERS = Registry(
    name="trainer",
    build_func=build_trainer_from_cfg,
    locations=["mmhug.trainers"],
)
MODELS = Registry(
    name="model",
    parent=MMENGINE_MODELS,
    locations=["mmhug.models"],
)

FUNCTIONS = Registry(
    name="function",
    parent=MMENGINE_FUNCTIONS,
    locations=["mmhug.datasets.collate_fn"],
)


LOOPS = Registry(
    name="loop",
    parent=MMENGINE_LOOPS,
    locations=["mmhug.runners"],
)
