import torch
import torch.nn as nn
import torch.nn.functional as F
from .activations import ACT2FN


class FeedForward(nn.Module):
    """Two-layer position-wise feed-forward block.

    Computes ``down_proj(dropout(act_fn(up_proj(x))))``: a linear expansion
    from ``hidden_size`` to ``intermediate_size``, an activation, optional
    dropout, and a linear projection back to ``hidden_size``.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        intermediate_size: Width of the hidden (expanded) layer.
        dropout: Dropout probability applied after the activation. Dropout is
            skipped entirely unless this value is greater than ``0.0``.
        act_type: Key used to look up the activation in ``ACT2FN``; a failed
            lookup propagates out of the constructor.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout: float = 0.0,
        act_type: str = "swish",
    ) -> None:
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.dropout = dropout
        self.act_fn = ACT2FN[act_type]
        # Deliberately go through the legacy (misspelled) name: it delegates
        # to ``reset_parameters``, so subclasses that override either
        # spelling still have their initializer run here.
        self.reset_paramters()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transformation.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``hidden_size``.
        """
        hidden_states = self.up_proj(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        if self.dropout > 0.0:
            # F.dropout is a no-op when ``self.training`` is False.
            hidden_states = F.dropout(
                hidden_states, p=self.dropout, training=self.training
            )
        hidden_states = self.down_proj(hidden_states)
        return hidden_states

    def reset_parameters(self) -> None:
        """Re-initialize weights (Xavier uniform) and biases (zeros) in place."""
        # nn.init functions already run under ``torch.no_grad()``, so the
        # parameters are passed directly instead of through ``.data`` (which
        # would hide the in-place write from autograd's version tracking).
        for proj in (self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(proj.weight)
        for proj in (self.up_proj, self.down_proj):
            nn.init.zeros_(proj.bias)

    def reset_paramters(self) -> None:
        """Backward-compatible alias for :meth:`reset_parameters` (legacy typo)."""
        self.reset_parameters()