import logging
import torch.nn as nn


_logger = logging.getLogger(__name__)


def prompt_filter_weight_decay(
        model: nn.Module,
        weight_decay=1e-5,
        no_weight_decay_list=()
):
    """Partition a model's trainable, non-prompt parameters into two param groups.

    Parameters with ``requires_grad=False`` and parameters whose name contains
    the substring ``"prompt"`` are omitted from both groups. The remaining
    parameters go into:

    * a zero-decay group, if the tensor has at most one dimension, the name
      ends with ``".bias"``, or the name is listed in ``no_weight_decay_list``;
    * otherwise a group decayed by ``weight_decay``.

    Args:
        model: Module whose ``named_parameters()`` are partitioned.
        weight_decay: Decay value assigned to the second group.
        no_weight_decay_list: Exact parameter names to exempt from decay.

    Returns:
        ``[no_decay_group, decay_group]``, each a dict with ``'params'`` and
        ``'weight_decay'`` keys.
    """
    exempt_names = set(no_weight_decay_list)
    decayed, undecayed = [], []

    for param_name, tensor in model.named_parameters():
        is_candidate = tensor.requires_grad and "prompt" not in param_name
        if is_candidate:
            # Scalars/vectors (e.g. norm scales) and biases are not decayed.
            exempt = (
                tensor.ndim <= 1
                or param_name.endswith(".bias")
                or param_name in exempt_names
            )
            bucket = undecayed if exempt else decayed
            bucket.append(tensor)

    no_decay_group = {'params': undecayed, 'weight_decay': 0.}
    decay_group = {'params': decayed, 'weight_decay': weight_decay}
    return [no_decay_group, decay_group]

def param_filter_weight_decay_fn(
        model: nn.Module,
        weight_decay=1e-5,
        no_weight_decay_list=(),
        filter_terms=("prompt",)
):
    """Build optimizer param groups, skipping parameters that match filter terms.

    Parameters with ``requires_grad=False``, or whose name contains any entry
    of ``filter_terms`` as a substring, are omitted from both groups (and
    logged). The remaining parameters go into:

    * a zero-decay group, if the tensor has at most one dimension, the name
      ends with ``".bias"``, or the name is listed in ``no_weight_decay_list``;
    * otherwise a group decayed by ``weight_decay``.

    Args:
        model: Module whose ``named_parameters()`` are partitioned.
        weight_decay: Decay value assigned to the second group.
        no_weight_decay_list: Exact parameter names to exempt from decay.
        filter_terms: Iterable of name substrings to skip. A single string is
            treated as one term (not as a sequence of characters).

    Returns:
        ``[no_decay_group, decay_group]``, each a dict with ``'params'`` and
        ``'weight_decay'`` keys.
    """
    # A bare str would otherwise be iterated per character, filtering any name
    # that merely contains one of its letters.
    if isinstance(filter_terms, str):
        filter_terms = (filter_terms,)
    # Materialise once so one-shot iterators are not exhausted after the
    # first parameter (which would silently disable filtering).
    filter_terms = tuple(filter_terms)
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad or any(term in name for term in filter_terms):
            _logger.info("Skipping %s", name)
            continue

        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
            _logger.info("Adding %s to no_decay list", name)
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
