from torch import nn
import torch.nn.functional as F
from fairseq.modules.layer_norm import LayerNorm
from fairseq.modules.fairseq_dropout import FairseqDropout


class AdapterLayer(nn.Module):
    """
    Bottleneck adapter block, following
    Simple, Scalable Adaptation for Neural Machine Translation
    (https://arxiv.org/abs/1909.08478)

    Lingvo has a module reference here:
    https://tensorflow.github.io/lingvo/_modules/lingvo/core/layers.html#ResidualAdapterLayer

    The block owns two linear projections: ``down`` maps ``input_dim`` to
    ``projection_dim`` and ``up`` maps it back to ``input_dim``.

    Two variants are selected by ``pfeiffer``:

    * ``pfeiffer=False``: ``forward`` computes
      ``x + up(dropout(relu(down(layer_norm(x)))))``.
    * ``pfeiffer=True``: ``forward`` computes ``up(relu(down(x)))`` only;
      no layer norm is created and no residual is added here.

    Args:
        input_dim: feature size of the input and of the output.
        projection_dim: feature size of the bottleneck.
        pfeiffer: choose the variant described above.
        init: ``'small'`` draws the weights and biases of both projections
            uniformly from ``[1e-5 - 1e-6, 1e-5 + 1e-6]``; ``'bert'`` draws
            them from a normal distribution with mean 0 and std 0.02; any
            other value keeps the default ``nn.Linear`` initialisation.
        dropout: when greater than 0, a ``FairseqDropout`` with this rate is
            created; otherwise ``self.dropout`` is ``None``.

    NOTE(review): the dropout module is only applied in the non-Pfeiffer
    branch of ``forward``; with ``pfeiffer=True`` it is built but never
    used. Presumably the caller handles dropout in that case -- confirm.
    """

    def __init__(self, input_dim, projection_dim, pfeiffer=False, init='small', dropout=0.0):
        super().__init__()

        # Registration order (down, up, dropout, layer_norm) is kept stable so
        # parameter ordering and RNG consumption do not change.
        self.down = nn.Linear(input_dim, projection_dim)
        self.up = nn.Linear(projection_dim, input_dim)
        self.pfeiffer = pfeiffer
        self.dropout = (
            FairseqDropout(dropout, module_name=self.__class__.__name__)
            if dropout > 0.0
            else None
        )
        if not pfeiffer:
            self.layer_norm = LayerNorm(input_dim)

        # Select the in-place initializer; unrecognised `init` values leave
        # the default nn.Linear initialisation in place.
        if init == 'bert':
            def fill(param):
                nn.init.normal_(param, mean=0.0, std=0.02)
        elif init == 'small':
            center = 1e-5
            spread = 1e-6
            low, high = center - spread, center + spread

            def fill(param):
                nn.init.uniform_(param, low, high)
        else:
            fill = None

        if fill is not None:
            # `up` is re-initialised before `down`, weight before bias.
            for param in (
                self.up.weight,
                self.up.bias,
                self.down.weight,
                self.down.bias,
            ):
                fill(param)

    def forward(self, x):
        """Run the adapter on ``x`` and return a tensor of the same feature size."""
        if self.pfeiffer:
            # Plain bottleneck: no normalisation, dropout or residual here.
            return self.up(F.relu(self.down(x)))

        hidden = F.relu(self.down(self.layer_norm(x)))
        if self.dropout is not None:
            hidden = self.dropout(hidden)
        # Residual connection around the bottleneck.
        return x + self.up(hidden)
