import warnings
from collections.abc import Iterable

from transformers.modeling_utils import PreTrainedModel


from .models.llama import *
from .models.qwen3 import *


# Dispatch table mapping a ``config.model_type`` string to the patch function
# for that architecture. The callables are presumably supplied by the
# ``.models.llama`` / ``.models.qwen3`` star imports above.
MODEL_TYPE_TO_APPLY_FN = dict(
    llama=apply_patch_to_llama_model,
    qwen3=apply_patch_to_qwen3_model,
)


def apply_patch_to_model(
    model: PreTrainedModel,
    patch_locations: Iterable | None = None,
    compress_kwargs: dict | None = None,
) -> None:
    """Apply the architecture-specific patch to ``model`` at the requested locations.

    The patch function is looked up in ``MODEL_TYPE_TO_APPLY_FN`` using
    ``model.config.model_type``. Each selected location is forwarded to that
    function as a ``<location>=True`` keyword argument, alongside ``model``
    and ``compress_kwargs``.

    Args:
        model: Model to patch. It must expose ``config.model_type``.
        patch_locations: Where to apply the patch. Accepted forms:

            * a ``dict`` of ``{location: flag}``. Only locations whose flag
              compares equal to ``True`` are selected.
            * a single location name (``str``).
            * any other iterable of location names. Duplicates are collapsed.

            ``None``, an empty value, or a selection that ends up empty means
            nothing is patched and the patch function is not called.
        compress_kwargs: Passed through unchanged to the patch function.

    Raises:
        TypeError: If ``patch_locations`` is truthy but not iterable. The patch
            function itself may also raise for location names it does not
            accept (not verifiable from here).

    Warns:
        UserWarning: If the model type is missing or has no registered patch.
            In that case nothing is patched.
    """
    # getattr on None falls back to the default, so a missing ``config``
    # yields ``None`` here.
    model_type = getattr(getattr(model, "config", None), "model_type", None)

    if not model_type:
        warnings.warn("Model type could not be determined from model config.")
        return

    apply_fn = MODEL_TYPE_TO_APPLY_FN.get(model_type)
    if apply_fn is None:
        warnings.warn(f"No patch supported for model type: {model_type}.")
        return

    if not patch_locations:
        # Nothing requested: leave the model untouched.
        return

    if isinstance(patch_locations, dict):
        # ``== True`` (not truthiness) is kept deliberately, so the accepted
        # flag values match the previous behaviour (True, 1, 1.0).
        selected = [loc for loc, enabled in patch_locations.items() if enabled == True]  # noqa: E712
    elif isinstance(patch_locations, str):
        # A str is itself Iterable. Without this branch, "mlp" would be split
        # into the locations "m", "l", "p".
        selected = [patch_locations]
    elif isinstance(patch_locations, Iterable):
        selected = list(patch_locations)
    else:
        raise TypeError("Invalid type of `patch_locations`, must be `Iterable` or `None`.")

    # dict.fromkeys preserves first-seen order and drops duplicate names.
    locations_kwargs = dict.fromkeys(selected, True)
    if not locations_kwargs:
        # Every location was disabled (e.g. {"mlp": False}) or the iterable was
        # empty. Skip the call rather than let apply_fn fall back to its own
        # defaults, matching the no-op taken for patch_locations=None.
        return

    print(f"Applying patch to {model_type} model in: {tuple(locations_kwargs)}")
    apply_fn(model=model, **locations_kwargs, compress_kwargs=compress_kwargs)
