import loralib as lora
import torch
from torch import nn


def replace_lora_layers(module, r, lora_alpha, lora_dropout):
    """Recursively swap ``nn.Linear`` / ``nn.Embedding`` submodules for loralib layers.

    Every ``nn.Linear`` child is replaced by a ``lora.Linear`` and every
    ``nn.Embedding`` child by a ``lora.Embedding`` built with the same
    constructor options (all with ``merge_weights=False``). The replacement
    inherits the original layer's weight (and bias, if present), device,
    dtype and train/eval mode, so swapping does not discard pretrained
    parameters or break a model that was already moved/cast.

    Children whose attribute name contains ``"pair_update"`` are skipped
    together with their whole subtree. Layers that already are loralib
    layers are left untouched, so calling this twice does not reset
    previously trained LoRA adapters.

    Args:
        module: Root ``nn.Module``; modified in place via ``setattr``.
        r: LoRA rank forwarded to the loralib layer constructors.
        lora_alpha: LoRA alpha forwarded to the loralib layer constructors.
        lora_dropout: Dropout for the LoRA path; forwarded to ``lora.Linear``
            only (it is not passed to ``lora.Embedding`` here).

    Returns:
        None. ``module`` is modified in place.
    """
    for name, child in module.named_children():
        # Leave any "pair_update" submodule (and everything beneath it) as-is.
        if "pair_update" in name:
            continue

        if isinstance(child, (nn.Linear, nn.Embedding)):
            # lora.Linear / lora.Embedding subclass nn.Linear / nn.Embedding;
            # re-wrapping one would throw away its trained LoRA matrices.
            if isinstance(child, lora.LoRALayer):
                continue
            if isinstance(child, nn.Linear):
                new_layer = _make_lora_linear(child, r, lora_alpha, lora_dropout)
            else:
                new_layer = _make_lora_embedding(child, r, lora_alpha)
            _inherit_state(child, new_layer)
            setattr(module, name, new_layer)
        else:
            replace_lora_layers(child, r, lora_alpha, lora_dropout)


def _make_lora_linear(child, r, lora_alpha, lora_dropout):
    """Build a ``lora.Linear`` with the same shape and bias setting as ``child``."""
    return lora.Linear(
        child.in_features,
        child.out_features,
        r,
        lora_alpha,
        lora_dropout,
        merge_weights=False,
        bias=child.bias is not None,
    )


def _make_lora_embedding(child, r, lora_alpha):
    """Build a ``lora.Embedding`` mirroring ``child``'s constructor options."""
    return lora.Embedding(
        child.num_embeddings,
        child.embedding_dim,
        r,
        lora_alpha,
        merge_weights=False,
        padding_idx=child.padding_idx,
        max_norm=child.max_norm,
        norm_type=child.norm_type,
        scale_grad_by_freq=child.scale_grad_by_freq,
        sparse=child.sparse,
    )


def _inherit_state(src, dst):
    """Give ``dst`` the device, dtype, base weight/bias and train/eval mode of ``src``."""
    ref = src.weight
    # Match placement first so the copies below stay on-device and the new
    # layer (including any LoRA parameters) agrees with the rest of the model.
    dst.to(device=ref.device, dtype=ref.dtype)
    with torch.no_grad():
        dst.weight.copy_(ref)
        # nn.Embedding has no bias attribute; nn.Linear(bias=False) has None.
        if getattr(src, "bias", None) is not None:
            dst.bias.copy_(src.bias)
    # A freshly built module starts in train mode; keep the original's mode.
    dst.train(src.training)
