import torch.nn as nn
from typing import List
# from mmengine.logging import MMLogger

# One-shot flags originally meant to make the parameter / train-mode summary
# logging fire only on the first call. Nothing in this module currently reads
# them; they are kept in case other code imports them.
first_set_requires_grad = first_set_train = True


def set_requires_grad(model: nn.Module, keywords: List[str]) -> List[str]:
    """Freeze every parameter of ``model`` except those whose name contains a keyword.

    Notice: matching is by substring ("key in name"). A parameter stays
    trainable when any entry of ``keywords`` occurs anywhere in its
    fully-qualified name as yielded by ``model.named_parameters()``. For
    example, ``"head"`` matches both ``"head.weight"`` and
    ``"decoder.head.bias"``.

    Args:
        model: Module whose parameters' ``requires_grad`` flags are overwritten.
        keywords: Substrings selecting the parameters to keep trainable. A
            single ``str`` is treated as a one-element list. Iterating a bare
            string would otherwise match on its individual characters.

    Returns:
        Names of the parameters left with ``requires_grad=True``, in
        ``named_parameters()`` order.
    """
    if isinstance(keywords, str):
        keywords = [keywords]
    requires_grad_names = []
    for name, param in model.named_parameters():
        trainable = any(key in name for key in keywords)
        param.requires_grad = trainable
        if trainable:
            requires_grad_names.append(name)
    return requires_grad_names


def _set_train(model: nn.Module, keywords: List[str], prefix: str = ""):
    """Recursively switch matching submodules of ``model`` to training mode.

    Walks ``model.named_children()``. When a child's own (unqualified) name
    starts with any entry of ``keywords``, ``child.train()`` is called on it,
    which also puts its whole subtree into training mode, and the walk does
    not descend into it. Non-matching children are searched recursively.

    Args:
        model: Module whose children are inspected.
        keywords: Name prefixes selecting the submodules to train. A single
            ``str`` is treated as one prefix.
        prefix: Dotted path of ``model`` inside the root module; ``""`` for
            the root.

    Returns:
        Dotted paths (relative to the root) of the submodules set to train
        mode, e.g. ``["backbone.0", "head"]``.
    """
    # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
    prefixes = (keywords,) if isinstance(keywords, str) else tuple(keywords)
    train_names = []
    for name, child in model.named_children():
        # Avoid a leading "." for direct children of the root module.
        fullname = f"{prefix}.{name}" if prefix else name
        if name.startswith(prefixes):
            train_names.append(fullname)
            child.train()
        else:
            train_names += _set_train(child, keywords, prefix=fullname)
    return train_names


def set_train(model: nn.Module, keywords: List[str]):
    """Put ``model`` in eval mode, then re-enable training for selected submodules.

    Notice: the sub name (the child's own attribute name, not its full dotted
    path) must start with a key. A matching submodule is switched to training
    mode together with all of its descendants.

    Args:
        model: Root module. It is first set to eval mode via ``model.train(False)``.
        keywords: Name prefixes selecting the submodules to train. A single
            ``str`` is treated as one prefix.

    Returns:
        Dotted paths of the submodules switched to training mode.
    """
    model.train(False)
    return _set_train(model, keywords)