#! -*- coding: utf-8
import re
import typing
from logging import getLogger

import numpy as np
import torch

from .resnet_group_norm import (ResNet18, ResNet34, ResNet50, ResNet101,
                                ResNet152)
from .vgg import VGG11, VGG13, VGG16, VGG19

# Module logger; load_model reports its configs, quantization/LoRA settings and
# trainable-parameter counts through it.
_LOGGER = getLogger("trainers.models.utils")


def load_model(name: str,
               args: typing.Optional[typing.Sequence[typing.Any]] = None,
               kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
               **configs) -> torch.nn.Module:
    """Build a model by name and optionally freeze part of its parameters.

    Args:
        name: Model identifier. One of ``ResNet18``, ``ResNet34``,
            ``ResNet50``, ``ResNet101``, ``ResNet152``, ``VGG11``, ``VGG13``,
            ``VGG16``, ``VGG19``, ``LogisticRegressionModel``, ``MLPModel`` or
            ``AutoModelForSequenceClassification``.
        args: Positional arguments forwarded to the model constructor (or to
            ``AutoModelForSequenceClassification.from_pretrained``).
            ``None`` means no positional arguments.
        kwargs: Keyword arguments forwarded the same way. ``None`` means none.
        **configs: Loader options. ``trainables`` is a list of regex patterns;
            every parameter whose name does not ``re.match`` any pattern gets
            ``requires_grad = False``. For
            ``AutoModelForSequenceClassification`` the keys ``quantization``
            (keyword arguments for ``transformers.BitsAndBytesConfig``) and
            ``lora`` (keyword arguments for ``peft.LoraConfig``) are also read.
            Any other key is ignored.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``name`` is not a supported model.
        AssertionError: If freezing leaves the model without any trainable
            parameter.
    """
    # None sentinels instead of mutable defaults shared across calls.
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}

    if name == "AutoModelForSequenceClassification":
        return _load_sequence_classifier(args, kwargs, configs)

    builtin_ctors: typing.Dict[str, typing.Callable[..., torch.nn.Module]] = {
        "ResNet18": ResNet18,
        "ResNet34": ResNet34,
        "ResNet50": ResNet50,
        "ResNet101": ResNet101,
        "ResNet152": ResNet152,
        "VGG11": VGG11,
        "VGG13": VGG13,
        "VGG16": VGG16,
        "VGG19": VGG19,
    }
    if name in builtin_ctors:
        ctor = builtin_ctors[name]
    elif name == "LogisticRegressionModel":
        # Imported only when this model is requested.
        from .logistic_regresion_model import LogisticRegressionModel
        ctor = LogisticRegressionModel
    elif name == "MLPModel":
        # Imported only when this model is requested.
        from .mlp import MLPModel
        ctor = MLPModel
    else:
        raise ValueError(f"Unsupported model: {name}")

    model = ctor(*args, **kwargs)
    # For these models nothing is frozen unless `trainables` is given.
    patterns = list(configs.pop("trainables", []))
    if patterns:
        _freeze_non_trainables(model, patterns)
    return model


def _load_sequence_classifier(args: typing.Sequence[typing.Any],
                              kwargs: typing.Dict[str, typing.Any],
                              configs: typing.Dict[str, typing.Any]) -> torch.nn.Module:
    """Load a HF sequence classifier with optional 8/4-bit quantization and LoRA."""
    from transformers import (AutoModelForSequenceClassification,
                              BitsAndBytesConfig)
    _LOGGER.info("configs: %s", configs)

    quantization_config = configs.pop("quantization", None)
    lora_config = configs.pop("lora", None)
    if lora_config is not None:
        # peft is only needed for LoRA. Import it before the (expensive)
        # model download/load so a missing package fails fast.
        import peft

    if quantization_config is not None:
        quantization_config = BitsAndBytesConfig(**quantization_config)
    _LOGGER.info("quantization: %s", quantization_config)
    model = AutoModelForSequenceClassification.from_pretrained(
        *args, **kwargs, quantization_config=quantization_config)

    if lora_config is not None:
        lora_config = peft.LoraConfig(**lora_config)
        _LOGGER.info("LoRA: %s", lora_config)
        model.enable_input_require_grads()
        model.gradient_checkpointing_enable()
        model = peft.get_peft_model(model, lora_config)

    # The default ".*" matches every parameter name, so nothing extra is frozen.
    # The trainable count is still logged and checked.
    _freeze_non_trainables(model, configs.pop("trainables", [".*"]))
    return model


def _freeze_non_trainables(model: torch.nn.Module,
                           patterns: typing.Iterable[str]) -> None:
    """Freeze parameters matching no pattern; log and check the trainable count.

    Matching uses ``re.match`` (anchored at the start of the parameter name).
    Parameters that match are left untouched (never switched back to
    trainable).

    Raises:
        AssertionError: If no trainable parameter remains.
    """
    regexes = [re.compile(pattern) for pattern in patterns]
    for param_name, param in model.named_parameters():
        if any(reg.match(param_name) for reg in regexes):
            _LOGGER.debug("trainable parameter: %s", param_name)
        else:
            param.requires_grad = False
            _LOGGER.debug("fixed parameter: %s", param_name)

    ntrainables, nparams = count_params(model)
    _LOGGER.info("trainable parameters: %s / %s", ntrainables, nparams)
    # Explicit raise (same exception type as the former `assert`) so the check
    # survives `python -O`.
    if ntrainables <= 0:
        raise AssertionError(f"no trainable parameters. model: {model}")


def count_params(model: torch.nn.Module) -> typing.Tuple[int, int]:
    """Count parameter elements of ``model``.

    Returns:
        ``(trainable, total)``: the number of elements in parameters with
        ``requires_grad`` set, and the number of elements over all of
        ``model.parameters()``.
    """
    sizes = [(tensor.numel(), tensor.requires_grad)
             for tensor in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    return trainable, total
