import torch
import torch.nn as nn

from transformers.activations import ACT2FN


class SVD_LlamaMLP(nn.Module):
    """Gated MLP whose three projections are each a pair of linear layers.

    Every projection (gate, up, down) is stored as two bias-free
    ``nn.Linear`` layers applied back to back: ``*_v_proj`` maps the input
    into a ``rank``-dimensional space and ``*_u_proj`` maps it out again.
    All three pairs share the same ``rank``.

    The class name suggests the factor weights come from a truncated SVD of
    dense MLP weights; this module only allocates the layers and does not
    compute any decomposition itself.

    Args:
        hidden_size: Feature size of the MLP input and output.
        intermediate_size: Feature size of the gate/up outputs.
        hidden_act: Key looked up in ``ACT2FN`` to get the activation.
        ratio: Scales the shared rank. A factor pair holds
            ``rank * (hidden_size + intermediate_size)`` weights, and
            ``rank`` is chosen so this is at most ``ratio`` times the
            ``hidden_size * intermediate_size`` weights of a dense layer.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        ratio: float = 1.0,
    ) -> None:
        super().__init__()

        # Weight count of one dense projection vs. the weight count that a
        # (v, u) factor pair spends per unit of rank.
        dense_weights = hidden_size * intermediate_size
        weights_per_rank = hidden_size + intermediate_size
        # int() truncates, so the factor pair never exceeds the budget.
        rank = int(dense_weights * ratio / weights_per_rank)

        def linear(fan_in, fan_out):
            # None of the factor layers carries a bias term.
            return nn.Linear(fan_in, fan_out, bias=False)

        # NOTE: keep this creation order. It fixes the order of the entries
        # in state_dict()/parameters() and the order in which random
        # initialisation draws from the RNG.

        # down: intermediate_size -> rank -> hidden_size
        self.down_u_proj = linear(rank, hidden_size)
        self.down_v_proj = linear(intermediate_size, rank)

        # gate: hidden_size -> rank -> intermediate_size
        self.gate_u_proj = linear(rank, intermediate_size)
        self.gate_v_proj = linear(hidden_size, rank)

        # up: hidden_size -> rank -> intermediate_size
        self.up_u_proj = linear(rank, intermediate_size)
        self.up_v_proj = linear(hidden_size, rank)

        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the factorised gated MLP to ``x``.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor shaped like ``x``, with last dimension ``hidden_size``.
        """
        # Each branch goes through its v (into rank) then u (out of rank).
        gate_low_rank = self.gate_v_proj(x)
        gate_out = self.gate_u_proj(gate_low_rank)

        up_low_rank = self.up_v_proj(x)
        up_out = self.up_u_proj(up_low_rank)

        # The activated gate branch scales the up branch element-wise.
        gated = self.act_fn(gate_out) * up_out

        down_low_rank = self.down_v_proj(gated)
        return self.down_u_proj(down_low_rank)
