import types

import torch
import torch.nn as nn
from transformers.models.deberta.modeling_deberta import (
    DebertaOutput,
    DebertaSelfOutput,
)
from transformers.models.roberta.modeling_roberta import (
    RobertaOutput,
    RobertaSelfOutput,
)


def _adapter_forward(
    self: nn.Module, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
    """Adapter-aware ``forward`` for an output layer.

    Applies ``dense`` -> ``adapter`` -> ``dropout``, then
    ``LayerNorm(hidden_states + input_tensor)``.

    Args:
        self: the layer being patched; it must expose ``dense``, ``adapter``,
            ``dropout`` and ``LayerNorm`` submodules.
        hidden_states: input to the dense projection.
        input_tensor: residual added before the final ``LayerNorm``.

    Returns:
        The normalized output tensor.
    """
    hidden_states = self.dense(hidden_states)
    hidden_states = self.adapter(hidden_states)
    hidden_states = self.dropout(hidden_states)
    return self.LayerNorm(hidden_states + input_tensor)


def create_new_forward(module: nn.Module) -> types.MethodType:
    """Return a replacement ``forward`` for *module* that routes through its adapter.

    The result is callable as ``fn(hidden_states, input_tensor)`` and is meant
    to be assigned to ``module.forward``.

    It is a bound method rather than a closure over *module*. ``copy.deepcopy``
    copies plain functions by reference, so a closure stored on the instance
    would keep pointing at the original module after a deep copy. A bound
    method is rebound to the copied instance instead. Binding also avoids
    the module -> closure -> module reference cycle.

    Args:
        module: layer exposing ``dense``, ``adapter``, ``dropout`` and
            ``LayerNorm`` attributes.

    Returns:
        ``_adapter_forward`` bound to *module*.
    """
    return types.MethodType(_adapter_forward, module)


class Adapter(nn.Module):
    """Residual bottleneck block computing ``x + fc(x)``.

    ``fc`` is Linear(input_dim -> hidden_dim), GELU, Linear(hidden_dim -> input_dim).
    The output therefore has the same last-dimension size as the input, so the
    skip connection lines up.

    Args:
        input_dim: size of the last dimension of the incoming tensor (and of the output).
        hidden_dim: width of the bottleneck between the two projections.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # One Sequential named ``fc`` keeps parameter names stable
        # (fc.0.weight, fc.2.weight, ...) for existing state dicts.
        bottleneck = [
            nn.Linear(input_dim, hidden_dim),  # down-projection
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),  # back up to input_dim
        ]
        self.fc = nn.Sequential(*bottleneck)

    def forward(self, input):
        transformed = self.fc(input)
        # Skip connection around the bottleneck.
        return transformed + input


def adapter(model, adapter_dim, device=None):
    """Inject a bottleneck ``Adapter`` into every RoBERTa/DeBERTa output layer of *model*.

    Every submodule that is a ``RobertaSelfOutput``, ``RobertaOutput``,
    ``DebertaSelfOutput`` or ``DebertaOutput`` gets two changes:

    - an ``adapter`` submodule, sized from its ``dense`` weight;
    - a ``forward`` override that applies the adapter right after ``dense``.

    *model* is modified in place; nothing is returned.

    Args:
        model: module tree to patch.
        adapter_dim: bottleneck width of each injected adapter.
        device: where to place the adapters. ``None`` uses the device of each
            layer's ``dense`` weight.
    """
    target_types = (RobertaSelfOutput, RobertaOutput, DebertaSelfOutput, DebertaOutput)

    # Snapshot the module list first: registering new submodules while
    # lazily walking the tree would mutate what is being traversed.
    for layer in list(model.modules()):
        if not isinstance(layer, target_types):
            continue

        dense_weight = layer.dense.weight
        to_kwargs = {"device": dense_weight.device if device is None else device}
        # Match the layer's floating-point precision; otherwise an fp16/bf16
        # model would feed half-precision activations into fp32 adapter weights.
        if dense_weight.is_floating_point():
            to_kwargs["dtype"] = dense_weight.dtype

        # size(0) of a Linear weight is its out_features, i.e. the width the
        # adapter sees right after ``dense``.
        adapter_module = Adapter(dense_weight.size(0), adapter_dim).to(**to_kwargs)
        layer.register_module("adapter", adapter_module)
        layer.forward = create_new_forward(layer)
