import re, warnings
import torch
from torch import nn

# Module types treated as normalization layers; `find_no_weightdecay_names`
# excludes their parameters from weight decay.
# nn.SyncBatchNorm and nn.GroupNorm are listed explicitly because neither
# derives from the BatchNorm{1,2,3}d classes; SyncBatchNorm is what
# `nn.SyncBatchNorm.convert_sync_batchnorm` swaps in for distributed training,
# so without it converted models would silently get weight decay on their norms.
_NORMS = (
    nn.LayerNorm,
    nn.GroupNorm,
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.SyncBatchNorm,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
)
# nn.RMSNorm only exists in newer torch releases; include it when available.
if hasattr(nn, "RMSNorm"):
    _NORMS += (nn.RMSNorm,)


def split_params(config, model):
    """Return optimizer parameter groups for ``model`` as configured by ``config``.

    ``config`` is a task configuration object read via attribute access.
    If ``config.lr_layer_wise_decay_rate`` is falsy, parameters are only split
    by weight decay (``split_params_for_weightdecay``). Otherwise grouping is
    delegated to ``split_params_for_layerdecaylr_and_wd``, which additionally
    reads ``learning_rate``, ``lr_layer_wise_decay_original_way`` and
    ``num_hidden_layers`` from ``config``.

    Returns:
        list[dict]: parameter groups suitable for a ``torch.optim`` optimizer.
    """
    if not config.lr_layer_wise_decay_rate:
        return split_params_for_weightdecay(
            model=model, weight_decay=config.weight_decay
        )

    return split_params_for_layerdecaylr_and_wd(
        model=model,
        learning_rate=config.learning_rate,
        decay_rate=config.lr_layer_wise_decay_rate,
        original_way=config.lr_layer_wise_decay_original_way,
        weight_decay=config.weight_decay,
        num_hidden_layers=config.num_hidden_layers,
    )


def split_params_for_weightdecay(model, weight_decay):
    """Split ``model``'s parameters into a decayed and a non-decayed group.

    A parameter goes to the non-decayed group when any token returned by
    ``find_no_weightdecay_names(model)`` is a substring of its name.

    Returns:
        list[dict]: ``[decayed_group, non_decayed_group]``; the first carries
        ``weight_decay`` as given, the second ``0.0``. Each group preserves
        the ``named_parameters()`` order.
    """
    excluded_tokens = find_no_weightdecay_names(model)

    decayed, not_decayed = [], []
    for param_name, param in model.named_parameters():
        if any(token in param_name for token in excluded_tokens):
            not_decayed.append(param)
        else:
            decayed.append(param)

    decayed_group = {"params": decayed, "weight_decay": weight_decay}
    not_decayed_group = {"params": not_decayed, "weight_decay": 0.0}
    return [decayed_group, not_decayed_group]


def split_params_for_layerdecaylr_and_wd(
    model, learning_rate, decay_rate, original_way, weight_decay, num_hidden_layers
):
    """Split parameters into groups with layer-wise decayed learning rates.

    Every parameter is assigned a "depth":

    * 0 -- parameters yielded by ``model.named_parameters()`` before the first
      hidden-layer parameter (e.g. embeddings);
    * ``i + 1`` -- parameters of hidden layer ``i`` (``0 <= i < num_hidden_layers``);
    * ``num_hidden_layers + 1`` -- parameters yielded after the first
      hidden-layer parameter that are not inside the hidden layers (e.g. a head).

    The hidden layers are the single ``nn.ModuleList`` whose module path
    contains "layer". Each depth gets two groups, first one with weight decay
    ``0.0`` (parameter names containing a token from
    ``find_no_weightdecay_names``) and then one with ``float(weight_decay)``.
    Empty groups are dropped.

    Learning rates: the deepest bucket (``num_hidden_layers + 1``) gets
    ``learning_rate`` and each step towards depth 0 multiplies by
    ``decay_rate``. With ``original_way`` the exponent 1 is skipped, so the
    top hidden layer gets ``learning_rate * decay_rate**2``.

    Args:
        model: the ``nn.Module`` whose parameters are grouped.
        learning_rate: base learning rate (used as-is for the last depth).
        decay_rate: multiplicative factor per depth step.
        original_way: whether to skip exponent 1 as described above.
        weight_decay: weight decay for the decayed groups.
        num_hidden_layers: expected length of the hidden-layer ModuleList.

    Returns:
        list[dict]: non-empty groups with ``"params"``, ``"lr"``, ``"weight_decay"``.

    Raises:
        RuntimeError: if no, or more than one, hidden-layer ModuleList is found.
        AssertionError: if ``num_hidden_layers`` does not match the ModuleList
            length, or a final sanity check fails.
    """
    no_decay = find_no_weightdecay_names(model)

    # Find the attribute path of the hidden layers (exactly one is supported).
    layers_name = None
    _num_hidden_layers = None
    for name, module in model.named_modules():
        if isinstance(module, nn.ModuleList) and "layer" in name:
            if layers_name:
                msg = f"We detected both'{layers_name}' and '{name}' as the hidden_layers, but currently we only support auto splitting for model with only one attribute to hold hidden layers."
                raise RuntimeError(msg)
            layers_name = name
            _num_hidden_layers = len(module)
    if layers_name is None:
        msg = "We didn't detect any attribute to hold hidden layers. Assure the model should have a nn.ModuleList attribute to hold hidden layers."
        raise RuntimeError(msg)
    assert num_hidden_layers == _num_hidden_layers, (
        f"num_hidden_layers={num_hidden_layers}, but '{layers_name}' holds "
        f"{_num_hidden_layers} layers"
    )

    # Learning rate per depth, from depth 0 (smallest lr) to the last depth
    # (exactly learning_rate).
    if original_way:
        exponents = [0, *range(2, num_hidden_layers + 3)]
    else:
        exponents = list(range(num_hidden_layers + 2))
    lrs = [learning_rate * (decay_rate ** e) for e in reversed(exponents)]

    # Two groups per depth: index depth * 2 is no-decay, depth * 2 + 1 is decay.
    param_groups = [
        {"params": [], "lr": lr, "weight_decay": wd}
        for lr in lrs
        for wd in (0.0, float(weight_decay))
    ]

    # Use an exact path-prefix test: a plain substring test would also match
    # siblings such as "layernorm.weight" or "encoder.layer_norm.weight" and
    # then fail to parse a layer index from them.
    prefix = layers_name + "."
    last_depth = num_hidden_layers + 1
    before_layers = True
    for name, param in model.named_parameters():
        if name.startswith(prefix):
            # "<layers_name>.<idx>.<rest>" -> depth idx + 1
            depth = int(name[len(prefix):].split(".", 1)[0]) + 1
            before_layers = False
        else:
            depth = 0 if before_layers else last_depth
        use_decay = not any(nd in name for nd in no_decay)
        param_groups[depth * 2 + int(use_decay)]["params"].append(param)
    param_groups = [pg for pg in param_groups if pg["params"]]  # drop empty groups

    # Sanity checks.
    ## every parameter landed in exactly one group
    assert sum(len(pg["params"]) for pg in param_groups) == len(
        list(model.parameters())
    )
    ## no NaN / inf values
    assert all(
        torch.isfinite(param).all() for param in model.parameters()
    ), "Some parameters are not finite / NaN"
    ## plausible number of groups
    assert num_hidden_layers + 2 < len(param_groups) <= 2 * (
        num_hidden_layers + 2
    ), f"Only {len(param_groups)} param groups."
    ## the last non-empty group uses the base learning rate
    assert param_groups[-1]["lr"] == learning_rate

    return param_groups


def find_no_weightdecay_names(model):
    """Collect name tokens identifying parameters that should skip weight decay.

    Callers in this module test ``token in param_name`` (plain substring)
    against ``model.named_parameters()`` names, so every token must be
    specific enough not to hit unrelated parameters.

    The result always contains ``"bias"``. For each submodule that is an
    instance of ``_NORMS``, its attribute name (last component of its dotted
    module path) is added, e.g. ``"LayerNorm"``. A norm registered by index
    (inside ``nn.Sequential`` / ``nn.ModuleList``) would give a bare token
    such as ``"1"``, which matches almost every parameter name; for those the
    full dotted path with a trailing ``"."`` is added instead (``"head.1."``).

    A ``UserWarning`` is emitted for a module whose own attribute name contains
    "norm" (case-insensitive) but which is not in ``_NORMS`` and has no direct
    child that is a norm or is named like one. Such a module is probably a
    custom norm whose parameters should also be excluded.

    Args:
        model: the ``nn.Module`` to inspect.

    Returns:
        set[str]: substrings marking no-weight-decay parameter names.
    """
    no_decay = {"bias"}
    for name, module in model.named_modules():
        attr_name = name.split(".")[-1]
        if isinstance(module, _NORMS):
            if attr_name.isdigit():
                # Index-registered norm: target only this module's params.
                # NOTE(review): for a norm at the root of an nn.Sequential the
                # token is e.g. "1.", which can still also match "11.weight".
                no_decay.add(f"{name}.")
            else:
                no_decay.add(attr_name)
        elif "norm" in attr_name.lower():
            # Check the module's own name, not its whole path, so descendants
            # of a norm-named wrapper are not all reported.
            wraps_norm = any(
                isinstance(child, _NORMS) or "norm" in child_name.lower()
                for child_name, child in module.named_children()
            )
            if not wraps_norm:
                msg = f"There is a module named '{name}' not excluded from weight decay but may should."
                warnings.warn(UserWarning(msg))
    return no_decay
