#! -*- coding: utf-8
import re
import typing

import numpy as np
import torch

# Public API of this module: only the model factory is exported on star-import.
__all__ = ["load_model"]


def count_params(model: torch.nn.Module) -> typing.Tuple[int, int]:
    """Count scalar parameter elements of ``model``.

    Returns:
        A ``(trainable, total)`` pair. ``trainable`` sums ``numel()`` over
        parameters with ``requires_grad`` set; ``total`` sums it over all
        parameters yielded by ``model.parameters()``.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(size for size, _ in sizes)
    trainable = sum(size for size, needs_grad in sizes if needs_grad)
    return trainable, total


def set_trainables(model: torch.nn.Module,
                   trainables: typing.Optional[typing.List[str]] = None,
                   frozens: typing.Optional[typing.List[str]] = None) -> torch.nn.Module:
    """Freeze parameters of ``model`` according to regex name filters.

    A parameter gets ``requires_grad = False`` when its name matches none of
    ``trainables`` or matches any of ``frozens``. Matching uses ``re.match``,
    so patterns are anchored at the start of the name only. Parameters that
    pass both filters are left untouched; this function never sets
    ``requires_grad`` back to ``True``.

    Args:
        model: module whose parameters are modified in place.
        trainables: regex patterns of parameter names allowed to stay
            trainable. ``None`` (default) means ``[".*"]``, i.e. all names.
        frozens: regex patterns of parameter names to freeze. ``None``
            (default) means ``[]``, i.e. freeze nothing extra.

    Returns:
        The same ``model`` instance, for chaining.

    Raises:
        AssertionError: if no trainable parameter elements remain.
    """
    # ``None`` sentinels avoid shared mutable list defaults.
    if trainables is None:
        trainables = [".*"]
    if frozens is None:
        frozens = []

    trainable_targets = [re.compile(pattern) for pattern in trainables]
    frozen_targets = [re.compile(pattern) for pattern in frozens]

    for param_name, param in model.named_parameters():
        allowed = any(reg.match(param_name) is not None for reg in trainable_targets)
        blocked = any(reg.match(param_name) is not None for reg in frozen_targets)
        if not allowed or blocked:
            param.requires_grad = False

    ntrainables, _ = count_params(model)
    if ntrainables <= 0:
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the AssertionError type is kept for existing callers.
        raise AssertionError(f"no trainable parameters. model: {model}")

    return model


def _load_hf_sequence_classifier(*args, **kwargs) -> torch.nn.Module:
    """Build a Hugging Face sequence classifier, optionally quantized and/or LoRA-wrapped.

    Two extra keyword arguments are consumed here and not forwarded:
    ``quantization`` (mapping of ``BitsAndBytesConfig`` arguments) and
    ``lora`` (mapping of ``peft.LoraConfig`` arguments). Everything else is
    forwarded to ``AutoModelForSequenceClassification.from_pretrained``.
    """
    # Heavy optional dependencies are imported only when this model is requested.
    import peft
    from transformers import (AutoModelForSequenceClassification,
                              BitsAndBytesConfig)

    quantization_config = kwargs.pop("quantization", None)
    lora_config = kwargs.pop("lora", None)

    if quantization_config is not None:
        quantization_config = BitsAndBytesConfig(**quantization_config)
    model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs,
                                                               quantization_config=quantization_config)
    if lora_config is not None:
        lora_config = peft.LoraConfig(**lora_config)
        # Presumably needed so gradients reach the adapters when gradient
        # checkpointing runs over a frozen base model — confirm against
        # transformers/peft docs before removing.
        model.enable_input_require_grads()
        model.gradient_checkpointing_enable()
        model = peft.get_peft_model(model, lora_config)

    return model


def load_model(name: str, *args, trainables: typing.Optional[typing.List[str]] = None,
               frozens: typing.Optional[typing.List[str]] = None,
               **kwargs) -> torch.nn.Module:
    """Instantiate the model called ``name`` and apply parameter freezing.

    Args:
        name: one of ``"CNN"``, ``"CNN_GNORM"``, ``"ResNet18"``, ``"LeNet5"``,
            ``"LeNet5_gnorm"``, ``"AutoModelForSequenceClassification"``.
        *args, **kwargs: forwarded to the selected model constructor. For
            ``"AutoModelForSequenceClassification"`` the extra keys
            ``quantization`` and ``lora`` are also accepted.
        trainables: regex patterns of parameter names allowed to train;
            ``None`` (default) means ``[".*"]``.
        frozens: regex patterns of parameter names to freeze; ``None``
            (default) means ``[]``.

    Returns:
        The constructed model after ``set_trainables`` has been applied.

    Raises:
        ValueError: if ``name`` is not a supported model.
    """
    # Resolve defaults here (rather than passing None on) so this works with
    # any set_trainables signature and avoids shared mutable list defaults.
    if trainables is None:
        trainables = [".*"]
    if frozens is None:
        frozens = []

    # Imports stay lazy so only the selected model's module is loaded.
    if name == "CNN":
        from .cnn import CNN as factory
    elif name == "CNN_GNORM":
        from .cnn_gnorm import CNN as factory
    elif name == "ResNet18":
        from .resnet_group_norm import ResNet18 as factory
    elif name == "LeNet5":
        from .lenet5 import LeNet5 as factory
    elif name == "LeNet5_gnorm":
        from .lenet5_gnorm import LeNet5 as factory
    elif name == "AutoModelForSequenceClassification":
        factory = _load_hf_sequence_classifier
    else:
        raise ValueError(f"Unsupported model: {name}")

    model = factory(*args, **kwargs)
    return set_trainables(model, trainables=trainables, frozens=frozens)
