import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.activations import ACT2FN
from torch.nn import functional as F

class ResidualLayer(torch.nn.Module):
    """Low-rank projection ``B(act?(A(x)))`` added to a learnably scaled skip input.

    The projection factors a ``hidden_dim -> up_dim`` map through a
    ``rank``-sized bottleneck: ``A`` is ``Linear(hidden_dim, rank)`` and
    ``B`` is ``Linear(rank, up_dim)``. The activation between them is applied
    only when ``config.compress_mode == 'low_rank_with_act_fn'``. The output is
    ``B(...) + last_layer_states * beta``, where ``beta`` is a learnable scalar
    parameter initialised to 1.

    Args:
        config: model config; this class reads ``compress_act_fn`` (a key into
            ``ACT2FN``) and ``compress_mode``.
        hidden_dim: input feature size of ``A``.
        up_dim: output feature size of ``B``.
        rank: bottleneck width between ``A`` and ``B``.
        idx: accepted but not used by this class.
    """

    def __init__(self, config: LlamaConfig, hidden_dim, up_dim, rank, idx):
        super().__init__()
        self.rank, self.hidden_dim, self.up_dim = rank, hidden_dim, up_dim

        # The activation is resolved eagerly from the config even when
        # enable_act ends up False, so compress_act_fn must be a valid key.
        self.compress_act = ACT2FN[config.compress_act_fn]
        self.enable_act = (config.compress_mode == 'low_rank_with_act_fn')

        # Creation order (A, then B, then beta) is kept deliberately: it fixes
        # RNG consumption during init and the order of parameters().
        self.A = torch.nn.Linear(hidden_dim, rank)
        self.B = torch.nn.Linear(rank, up_dim)
        self.beta = torch.nn.Parameter(torch.ones(1))

    def forward(self, x, last_layer_states):
        """Return ``B(act?(A(x))) + last_layer_states * beta``.

        ``x`` presumably has last dimension ``hidden_dim``, and
        ``last_layer_states`` must broadcast against the ``up_dim``-wide
        output of ``B``. TODO: confirm shapes against callers.
        """
        projected = self.A(x)
        activated = self.compress_act(projected) if self.enable_act else projected
        return self.B(activated) + last_layer_states * self.beta