import os, sys
import argparse
import math, random
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

from custom_transformer import FMoETransformerMLP
from custom_gates import *
import cmath


class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
    """Mixture-of-experts position-wise feed-forward layer.

    A thin wrapper around ``FMoETransformerMLP`` that fixes the activation
    to ReLU followed by dropout and forwards the remaining settings to the
    parent constructor under the parent's own keyword names.

    Args:
        gate: Passed through unchanged as the parent's ``gate`` argument.
        hidden_size: Forwarded to the parent as ``d_model``.
        inner_hidden_size: Forwarded to the parent as ``d_hidden``.
        dropout: Dropout probability used inside the activation module.
        pre_lnorm: Accepted but not used by this class.
        moe_num_expert: Forwarded to the parent as ``num_expert``.
        moe_top_k: Forwarded to the parent as ``moe_top_k``.
    """

    def __init__(
        self,
        gate,
        hidden_size,
        inner_hidden_size,
        dropout,
        pre_lnorm=False,
        moe_num_expert=16,
        moe_top_k=2,
    ):
        # Activation handed to the parent: ReLU, then dropout.
        relu_then_dropout = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        super().__init__(
            gate=gate,
            activation=relu_then_dropout,
            num_expert=moe_num_expert,
            moe_top_k=moe_top_k,
            d_model=hidden_size,
            d_hidden=inner_hidden_size,
        )
        # Starts at zero; nothing in this class updates it afterwards.
        # NOTE(review): presumably written by the gate or training loop -- confirm.
        self.load_balancing_loss = 0

    def forward(self, inp):
        """Run the parent MoE forward pass on ``inp`` and return its result.

        This layer adds no output dropout, residual connection or layer
        normalization of its own; the parent's output is returned as-is.
        """
        return super().forward(inp)