import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

from .activations import ACT2FN
from .normalizations import NORM2FN

# pylint:disable=no-member


class FeedForward(nn.Module):
    """Two-layer feed-forward block: ``down_proj(dropout(act_fn(up_proj(x))))``.

    Args:
        config: Object exposing ``hidden_size``, ``intermediate_size``,
            ``hidden_dropout_prob`` and ``hidden_act``; the latter is used
            as a key into ``ACT2FN`` to select the activation.
    """

    def __init__(self, config):
        super().__init__()
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.act_fn = ACT2FN[config.hidden_act]
        # Dispatch through the historical (misspelled) name so that
        # subclasses overriding it keep taking effect.
        self.reset_paramters()

    def forward(self, hidden_states):
        """Apply the feed-forward transformation.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``hidden_size``.
        """
        expanded = self.up_proj(hidden_states)
        activated = self.dropout(self.act_fn(expanded))
        return self.down_proj(activated)

    def reset_parameters(self):
        """Re-initialize both projections (correctly spelled entry point).

        Delegates to :meth:`reset_paramters`, which is kept as the
        implementation for backward compatibility.
        """
        self.reset_paramters()

    def reset_paramters(self):
        """Xavier-uniform the projection weights and zero their biases.

        The name is a historical misspelling; prefer
        :meth:`reset_parameters` in new code.
        """
        # Weights are drawn in the order up_proj, down_proj; biases are
        # deterministic, so interleaving them does not affect the RNG stream.
        for layer in (self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)


class FeedForwardGLU(nn.Module):
    """Gated feed-forward block.

    A single fused ``up_proj`` produces ``2 * intermediate_size`` features
    that are split in half along the last dimension; the first half is
    passed through the activation and multiplied element-wise with the
    second half before dropout and ``down_proj``.

    Args:
        config: Object exposing ``hidden_size``, ``intermediate_size``,
            ``hidden_dropout_prob`` and ``hidden_act``; the latter is used
            as a key into ``ACT2FN`` to select the activation.
    """

    def __init__(self, config):
        super().__init__()
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size)

        self.act_fn = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Dispatch through the historical (misspelled) name so that
        # subclasses overriding it keep taking effect.
        self.reset_paramters()

    def forward(self, hidden_states):
        """Apply the gated feed-forward transformation.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``hidden_size``.
        """
        # First half of the fused projection is activated, second half is
        # used as-is as the multiplicative partner.
        gate, value = self.up_proj(hidden_states).chunk(2, dim=-1)
        gated = self.dropout(self.act_fn(gate) * value)
        return self.down_proj(gated)

    def reset_parameters(self):
        """Re-initialize both projections (correctly spelled entry point).

        Delegates to :meth:`reset_paramters`, which is kept as the
        implementation for backward compatibility.
        """
        self.reset_paramters()

    def reset_paramters(self):
        """Xavier-uniform the projection weights and zero their biases.

        The name is a historical misspelling; prefer
        :meth:`reset_parameters` in new code.

        Note that the Xavier bound for ``up_proj`` is computed from the
        fused weight, i.e. with a fan-out of ``2 * intermediate_size``.
        """
        # Weights are drawn in the order up_proj, down_proj; biases are
        # deterministic, so interleaving them does not affect the RNG stream.
        for layer in (self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
