import math
import warnings
from dataclasses import dataclass
from typing import Optional, Union, Tuple

from random import random

from einops import rearrange, repeat

import torch
from torch import nn
from torch.nn import functional as F

from transformers.activations import ACT2FN
from transformers.utils import ModelOutput


@dataclass
class DiffusionOutput(ModelOutput):
    """Container for the tensors produced by one forward pass of the diffusion model."""

    # Output of the final classifier layer (one score per entry of the embedding table).
    logits: torch.FloatTensor
    # softmax(logits) projected back onto the normalized embedding table;
    # None unless the caller asked for the denoised estimate.
    denoised: Optional[torch.FloatTensor] = None
    # Cross-entropy training loss; None when no loss was computed.
    loss: Optional[torch.FloatTensor] = None


class DiffusionConfig:
    """Plain attribute bag holding the hyper-parameters of the diffusion model.

    Every constructor argument is stored under an attribute of the same name.
    ``flash_attention`` is the only derived value: it ends up ``True`` only if
    it was requested *and* the installed PyTorch exposes
    ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        num_embeddings: Size of the token embedding table and of the classifier output.
        embedding_size: Width of each token embedding.
        hidden_size: Width of the transformer hidden states.
        num_attention_heads: Number of attention heads.
        num_hidden_layers: Number of transformer blocks.
        rotary_pct: Fraction of each head's dimensions that receives rotary embeddings.
        hidden_act: Key into ``ACT2FN`` selecting the activation function.
        time_embedding_size: Width of the timestep embedding.
        initializer_range: Std of the normal init for ``nn.Linear`` / ``nn.Embedding`` weights.
        embedding_initializer_range: Stored only; not read by the code in this file.
        max_position_embeddings: Initial length of the rotary cos/sin cache.
        rotary_emb_base: Base of the rotary frequency schedule.
        num_bins: Stored only; not read by the code in this file.
        time_wrapping_epsilon: Stored only; not read by the code in this file.
        positional_embedding_type: ``"rotary"`` or ``"kerple"``; any other value
            disables both position mechanisms in the attention layer.
        time: ``"old"`` selects the learned sinusoidal timestep embedding, anything
            else the fixed sinusoidal one.
        flash_attention: Request the fused scaled-dot-product attention kernel.
    """

    def __init__(
        self,
        num_embeddings: int = 30000,
        embedding_size: int = 256,
        hidden_size: int = 1024,
        num_attention_heads: int = 8,
        num_hidden_layers: int = 8,
        rotary_pct=0.25,
        hidden_act="gelu_new",
        time_embedding_size=128,
        initializer_range=0.02,
        embedding_initializer_range=0.001,
        max_position_embeddings=64,
        rotary_emb_base=10000,
        num_bins=100,
        time_wrapping_epsilon=5e-4,
        positional_embedding_type="rotary",
        time="new",
        flash_attention=False
    ):
        self.num_embeddings = num_embeddings
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.rotary_pct = rotary_pct
        self.hidden_act = hidden_act
        self.time_embedding_size = time_embedding_size
        self.initializer_range = initializer_range
        self.embedding_initializer_range = embedding_initializer_range
        self.max_position_embeddings = max_position_embeddings
        self.rotary_emb_base = rotary_emb_base
        self.num_bins = num_bins
        self.time_wrapping_epsilon = time_wrapping_epsilon
        self.positional_embedding_type = positional_embedding_type
        self.time = time

        # The capability probe must use the exact function name; a misspelled
        # name makes hasattr() always False and silently disables the feature.
        flash_available = hasattr(F, "scaled_dot_product_attention")
        if flash_attention and not flash_available:
            warnings.warn("Flash attention is not available. Falling back to simple self-attention.")
        self.flash_attention = bool(flash_attention) and flash_available


class KerpleLog(torch.nn.Module):
    """Kernelized T5 Relative Position Bias parallelized in the heads dimension.

    Adds the log kernel ``-p * log(1 + a * d)`` to the attention scores, where
    ``p`` and ``a`` are learned per-head scalars and ``d[i, j] = i - j`` for
    ``j <= i`` and ``0`` otherwise (the distance matrix is lower-triangular, so
    positions with key index greater than the query index receive no bias).
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()

        self.num_heads = config.num_attention_heads
        # Lower bound applied to both learned scalars before every use.
        self.eps = 1e-2

        # Allocate weights and initialize.
        # The kernel has the form -p*log(1+a*|m-n|)
        def get_parameter(scale, init_method):
            if init_method == "ones":
                values = torch.ones(self.num_heads)
            elif init_method == "uniform":
                values = torch.rand(self.num_heads)
            else:
                raise ValueError(f"Unknown init_method: {init_method!r}")
            # Shape [num_heads, 1, 1] so it broadcasts over [sq, sk].
            return nn.Parameter(values[:, None, None] * scale)

        self.bias_p = get_parameter(2, "uniform")
        self.bias_a = get_parameter(1, "uniform")

        # Distance matrix of the last forward call and the key length it was built for.
        self.cached_matrix = None
        self.cached_seq_len = None

    def stats(self):
        """Return mean/std/max/min of both learned scalars as CPU tensors.

        Note: clamps ``bias_a`` and ``bias_p`` to ``eps`` in place before
        measuring them.
        """

        def get_stats(name, obj):
            return {
                name + "_mean": obj.mean().detach().cpu(),
                name + "_std": obj.std().detach().cpu(),
                name + "_max": obj.max().detach().cpu(),
                name + "_min": obj.min().detach().cpu(),
            }

        dd = {}
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
        dd.update(get_stats("bias_a", self.bias_a))
        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        dd.update(get_stats("bias_p", self.bias_p))
        return dd

    def forward(self, x):
        """Add the positional bias to attention scores ``x`` of shape [b, np, sq, sk].

        ``sq`` must equal ``sk`` unless ``sq == 1`` (single-token decoding with a
        key/value cache), in which case the bias row of the last key position is used.
        """
        # [b, np, sq, sk]
        seq_len_q = x.shape[-2]
        seq_len_k = x.shape[-1]

        cached = self.cached_matrix
        # Rebuild when the length changes, and also when the scores live on a
        # different device or use a different dtype than the cached matrix;
        # keying on the length alone would reuse a stale tensor after `.to(...)`.
        if (
            cached is None
            or self.cached_seq_len != seq_len_k
            or cached.device != x.device
            or cached.dtype != x.dtype
        ):
            positions = torch.arange(seq_len_k, device=x.device)
            diff = torch.tril(
                positions.view(seq_len_k, 1) - positions.view(1, seq_len_k)
            ).to(x.dtype)
            self.cached_seq_len = seq_len_k
            self.cached_matrix = diff
        else:
            diff = cached

        # Keep both scalars strictly positive so the log argument stays >= 1.
        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
        bias = -self.bias_p * torch.log(1 + self.bias_a * diff)  # log kernel

        if seq_len_q != seq_len_k:
            # In the train case x has dimensionality [b, np, sq, sk] with sq == sk.
            # At inference time with a cache in layer_past, sq only contains one
            # token (the last one in the full sequence), so sq != sk.
            assert (
                seq_len_q == 1
            ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"

            # seq_len_k - 1 points to the last token index in the current inference batch.
            bias = bias[:, seq_len_k - 1, :].view(bias.shape[0], 1, bias.shape[2])

        return x + bias


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the one moved to the front."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    """Rotate ``q`` and ``k`` with the cos/sin tables.

    The tables are cut along their second-to-last axis to the window
    ``[offset, offset + q.shape[-2])`` before being applied to both tensors.
    Returns the rotated ``(q, k)`` pair.
    """
    window = slice(offset, q.shape[-2] + offset)
    cos_win = cos[..., window, :]
    sin_win = sin[..., window, :]
    rotated_q = (q * cos_win) + (rotate_half(q) * sin_win)
    rotated_k = (k * cos_win) + (rotate_half(k) * sin_win)
    return rotated_q, rotated_k


class RotaryEmbedding(torch.nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    Args:
        dim: Number of rotary dimensions per head (the last axis of the tables).
        max_position_embeddings: Number of positions pre-computed at construction.
        base: Base of the geometric frequency schedule.
        device: Device on which the inverse frequencies are created.
    """

    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self._build_cache(self.max_seq_len_cached, device=self.inv_freq.device)

    def _build_cache(self, seq_len, device):
        """(Re)compute ``cos_cached`` / ``sin_cached`` with shape [1, 1, seq_len, dim]."""
        inv_freq = self.inv_freq.to(device)
        t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables covering the first ``seq_len`` positions.

        Args:
            x: Tensor of shape [bs, num_attention_heads, seq_len, head_size]; only
                its device (and, when ``seq_len`` is omitted, its sequence axis) is used.
            seq_len: Number of positions needed. Defaults to ``x.shape[-2]``.

        Returns:
            Two tensors of shape [1, 1, seq_len, dim] on ``x.device``.
        """
        if seq_len is None:
            # The declared default used to crash in the comparison below.
            seq_len = x.shape[-2]

        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            self._build_cache(seq_len, device=x.device)

        # The position axis is the third one; slicing the leading axis (size 1)
        # would return the whole table regardless of seq_len.
        cos = self.cos_cached[..., :seq_len, :].to(x.device)
        sin = self.sin_cached[..., :seq_len, :].to(x.device)
        return cos, sin


class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a scalar per batch element, followed by a small MLP.

    The embedding concatenates ``sin`` and ``cos`` of the input multiplied by
    ``dim // 2`` geometrically spaced frequencies, then feeds the result through
    ``projector`` (two bias-free linear layers, each followed by the configured
    activation).
    """

    def __init__(self, config):
        super().__init__()
        self.dim = config.time_embedding_size
        width = self.dim
        self.projector = nn.Sequential(
            nn.Linear(width, width, bias=False),
            ACT2FN[config.hidden_act],
            nn.Linear(width, width, bias=False),
            ACT2FN[config.hidden_act],
        )

    def forward(self, x):
        n_freqs = self.dim // 2
        # Log-spaced decay rate between consecutive frequencies.
        log_step = math.log(10000) / (n_freqs - 1)
        frequencies = torch.exp(torch.arange(n_freqs, device=x.device) * -log_step)
        # Outer product: one row of phase angles per input value.
        angles = x[:, None] * frequencies[None, :]
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return self.projector(features)


class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding whose frequencies are learned, followed by a small MLP.

    Following @crowsonkb 's lead with learned sinusoidal pos emb:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8
    """

    def __init__(self, config):
        super().__init__()
        dim = config.time_embedding_size
        assert (dim % 2) == 0
        # One learned frequency per sin/cos pair.
        self.weights = nn.Parameter(torch.randn(dim // 2))
        layers = [
            nn.Linear(dim, dim, bias=False),
            ACT2FN[config.hidden_act],
            nn.Linear(dim, dim, bias=False),
            ACT2FN[config.hidden_act],
        ]
        self.projector = nn.Sequential(*layers)

    def forward(self, x):
        column = rearrange(x, "b -> b 1")
        row = rearrange(self.weights, "d -> 1 d")
        phases = column * row * 2 * math.pi
        features = torch.cat((phases.sin(), phases.cos()), dim=-1)
        return self.projector(features)


class Attention(nn.Module):
    """Multi-head self-attention with optional rotary embeddings or a KERPLE bias.

    ``config.positional_embedding_type`` selects the position mechanism:
    ``"rotary"`` rotates the first ``rotary_ndims`` dimensions of every query/key
    head, ``"kerple"`` adds a learned relative bias to the attention scores, and
    any other value uses neither.

    When ``config.flash_attention`` is set, the fused
    ``F.scaled_dot_product_attention`` kernel is used whenever it can reproduce
    the explicit computation, i.e. when there is no KERPLE bias and no
    ``head_mask``; otherwise the explicit path in ``_attn`` is used.
    """

    def __init__(self, config: "DiffusionConfig"):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads

        # Not read by this class, but registered as a persistent buffer, so it is
        # part of the state dict and must stay for checkpoints to load strictly.
        self.register_buffer("masked_bias", torch.tensor(-1e9))
        self.rotary_emb = None
        if config.positional_embedding_type == "rotary":
            self.rotary_ndims = int(self.head_size * config.rotary_pct)
            self.rotary_emb = RotaryEmbedding(
                self.rotary_ndims,
                config.max_position_embeddings,
                base=config.rotary_emb_base,
            )
        # sqrt(head_size); scores are divided by this before the softmax.
        self.norm_factor = torch.sqrt(
            torch.tensor(self.head_size, dtype=torch.float32)
        ).to(torch.get_default_dtype())
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.rpe = None
        self.flash_attention = config.flash_attention
        if config.positional_embedding_type == "kerple":
            self.rpe = KerpleLog(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        layer_past=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape [batch, seq_len, hidden_size].
            attention_mask: Optional additive mask broadcastable to
                [batch, num_heads, query_len, key_len]; it is added to the raw scores.
            head_mask: Optional multiplicative mask applied to the attention weights.
            layer_past: Optional ``(past_key, past_value)`` pair that is prepended
                to the keys/values computed for this call.
            use_cache: When true, the (concatenated) keys and values are returned.
            output_attentions: When true, the attention weights are appended to the
                result if the explicit attention path computed them.

        Returns:
            ``(attn_output, present)`` plus, optionally, the attention weights.
            ``present`` is ``(key, value)`` when ``use_cache`` is set, else ``None``.
        """
        has_layer_past = layer_past is not None

        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (np * 3 * head_size)]
        qkv = self.query_key_value(hidden_states)

        # [batch, seq_len, (num_heads * 3 * head_size)]
        #   --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
        query = qkv[..., : self.head_size].permute(0, 2, 1, 3).contiguous()
        key = (
            qkv[..., self.head_size : 2 * self.head_size]
            .permute(0, 2, 1, 3)
            .contiguous()
        )
        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3).contiguous()

        if self.rotary_emb is not None:
            # Compute rotary embeddings on rotary_ndims
            query_rot = query[..., : self.rotary_ndims]
            query_pass = query[..., self.rotary_ndims :]
            key_rot = key[..., : self.rotary_ndims]
            key_pass = key[..., self.rotary_ndims :]

            # Compute token offset for rotary embeddings (when decoding)
            seq_len = key.shape[-2]
            offset = 0
            if has_layer_past:
                offset = layer_past[0].shape[-2]
                seq_len += offset
            cos, sin = self.rotary_emb(value, seq_len=seq_len)
            query, key = apply_rotary_pos_emb(
                query_rot, key_rot, cos, sin, offset=offset
            )
            query = torch.cat((query, query_pass), dim=-1)
            key = torch.cat((key, key_pass), dim=-1)

        # Cache QKV values
        if has_layer_past:
            past_key, past_value = layer_past[0], layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention
        # The fused kernel cannot add the KERPLE bias or apply a head mask, so it
        # is only used when neither is in play.
        use_fused = self.flash_attention and self.rpe is None and head_mask is None
        attn_weights = None
        if use_fused:
            # Keyword names follow the PyTorch API: `attn_mask` and `is_causal`.
            # The default scale is 1/sqrt(head_size), matching `norm_factor`.
            attn_output = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, is_causal=False
            )
        else:
            attn_output, attn_weights = self._attn(
                query, key, value, attention_mask, head_mask
            )

        # Reshape outputs
        attn_output = self._merge_heads(
            attn_output, self.num_attention_heads, self.head_size
        )
        attn_output = self.dense(attn_output)

        outputs = (attn_output, present)
        # The fused kernel never materializes the weights, so there is nothing to return.
        if output_attentions and attn_weights is not None:
            outputs += (attn_weights,)

        return outputs

    @classmethod
    def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        # tensor: [bs, seq_len, hidden_size]
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(new_shape)
        # -> [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor

    @classmethod
    def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # tensor [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(
            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
        )
        # -> [bs, seq_len, hidden_size]
        return tensor

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Explicit softmax attention; returns ``(attn_output, attn_weights)``.

        No causal mask is applied here; only the optional additive
        ``attention_mask`` and the optional KERPLE bias modify the scores.
        """
        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        query = query.view(
            batch_size * num_attention_heads, query_length, attn_head_size
        )
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        # scores = (q @ k^T) / sqrt(head_size)
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=(
                torch.tensor(
                    1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device
                )
                / self.norm_factor
            ),
        )
        attn_scores = attn_scores.view(
            batch_size, num_attention_heads, query_length, key_length
        )

        if self.rpe is not None:
            attn_scores = self.rpe(attn_scores)

        if attention_mask is not None:
            # Apply the attention mask
            attn_scores = attn_scores + attention_mask

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


class ConditionalLayerNorm(nn.Module):
    """Layer normalization whose scale and shift are predicted from a time embedding."""

    def __init__(self, config):
        super().__init__()
        cond_dim, width = config.time_embedding_size, config.hidden_size
        self.gamma_cond = nn.Linear(cond_dim, width)
        self.beta_cond = nn.Linear(cond_dim, width)
        self.epsilon = 1e-5

    def init_weights(self):
        """Reset both projections so the layer starts as a plain normalization (scale 1, shift 0)."""
        for tensor, fill in (
            (self.gamma_cond.weight, 0.0),
            (self.gamma_cond.bias, 1.0),
            (self.beta_cond.weight, 0.0),
            (self.beta_cond.bias, 0.0),
        ):
            torch.nn.init.constant_(tensor, fill)

    def forward(self, time_embedding, hidden_state):
        # Normalize over the last axis using the biased variance.
        centered = hidden_state - hidden_state.mean(dim=-1, keepdim=True)
        variance = (centered ** 2).mean(dim=-1, keepdim=True)
        normalized = centered / (variance + self.epsilon).sqrt()

        # One scale/shift vector per batch element, broadcast over the middle axis.
        batch = hidden_state.size(0)
        scale = self.gamma_cond(time_embedding).view(batch, 1, -1)
        shift = self.beta_cond(time_embedding).view(batch, 1, -1)
        return normalized * scale + shift


class TransformerBlock(nn.Module):
    """One residual layer: conditioned norm -> attention, then conditioned norm -> 4x MLP."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner_width = config.hidden_size * 4
        self.conditional_layer_norm_1 = ConditionalLayerNorm(config)
        self.attention = Attention(config)
        self.conditional_layer_norm_2 = ConditionalLayerNorm(config)
        self.up_proj = nn.Linear(width, inner_width)
        self.act = ACT2FN[config.hidden_act]
        self.down_proj = nn.Linear(inner_width, width)

    def forward(self, hidden_states, time_embeddings, attention_mask=None):
        # Attention sub-layer with residual connection.
        normed = self.conditional_layer_norm_1(time_embeddings, hidden_states)
        attention_out = self.attention(normed, attention_mask=attention_mask)[0]
        after_attention = hidden_states + attention_out

        # Feed-forward sub-layer with residual connection.
        normed = self.conditional_layer_norm_2(time_embeddings, after_attention)
        feed_forward_out = self.down_proj(self.act(self.up_proj(normed)))
        return after_attention + feed_forward_out


class DiffusionTransformer(nn.Module):
    """Transformer that predicts token logits from noised token embeddings.

    The network input at every position is the concatenation of four
    ``embedding_size``-wide vectors: the conditioning mask (expanded), the
    conditioning embeddings, the noisy embeddings and the self-conditioning
    estimate. Timesteps enter through the conditional layer norms of each block.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config.num_embeddings, config.embedding_size)
        # NOTE: the attribute name (including its spelling) is part of the state
        # dict keys, so it is kept as is.
        if config.time == "old":
            self.time_emmbeddings = LearnedSinusoidalPosEmb(config)
        else:
            self.time_emmbeddings = SinusoidalPosEmb(config)
        self.emb_to_hidden = nn.Linear(
            config.embedding_size * 4, config.hidden_size
        )  # *4 stands for conditioning_mask, token_embedding_c, token_embedding_x and self_conditioning
        self.encoders = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.classifier = nn.Linear(
            config.hidden_size, config.num_embeddings, bias=False
        )

        for m in self.modules():
            self._init_weights(m)

        # The generic init above also touched the conditional layer norms;
        # restore their identity initialization afterwards.
        for e in self.encoders:
            e.conditional_layer_norm_1.init_weights()
            e.conditional_layer_norm_2.init_weights()

    def get_embeddings(self, input_ids):
        """Look up token embeddings and rescale each to L2 norm ``sqrt(embedding_size)``."""
        embeddings = self.embeddings(input_ids)
        embeddings = F.normalize(embeddings, dim=-1) * math.sqrt(
            self.config.embedding_size
        )  # ~ N(0, 1)
        return embeddings

    def get_noisy_embeddings(
        self, timestamps, conditioning_mask, input_ids, drop_conditioning_embeddings
    ):
        """Build the conditioning and noised inputs from token ids.

        Args:
            timestamps: One noise level per batch element (reshaped to [batch, 1, 1]).
            conditioning_mask: Boolean mask over positions; ``True`` marks
                conditioning positions, which are kept clean and receive no noise.
            input_ids: Token ids to embed.
            drop_conditioning_embeddings: When true, the conditioning embeddings
                are replaced by zeros.

        Returns:
            ``(conditioning_hidden_states, noisy_hidden_states)``.
        """
        embeddings = self.get_embeddings(input_ids)
        if drop_conditioning_embeddings:
            conditioning_hidden_states = torch.zeros_like(embeddings)
        else:
            conditioning_hidden_states = embeddings * conditioning_mask.unsqueeze(-1)
        noise = (
            timestamps.view(-1, 1, 1)
            * torch.randn_like(embeddings)
            * (~conditioning_mask).unsqueeze(-1)
        )

        noisy_hidden_states = embeddings * (~conditioning_mask).unsqueeze(-1) + noise
        scale = 1 / torch.sqrt(timestamps**2 + 1)
        noisy_hidden_states = noisy_hidden_states * scale.view(
            -1, 1, 1
        )  # renormalize embeddings again. See Section 6.1
        return conditioning_hidden_states, noisy_hidden_states

    def forward(
        self,
        timestamps: torch.Tensor,
        conditioning_mask: Optional[torch.BoolTensor],
        input_ids: Optional[torch.LongTensor] = None,
        conditioning_hidden_states: Optional[torch.FloatTensor] = None,
        noisy_hidden_states: Optional[torch.FloatTensor] = None,
        self_conditioning_hidden_states: Optional[torch.FloatTensor] = None,
        output_denoised: bool = False,
        drop_conditioning_embeddings: bool = False,
        self_conditioning: bool = False,
        output_loss: bool = True,
        weighted_loss: bool = False,
        t_max: float = None,
        attention_mask: torch.Tensor = None,
    ) -> DiffusionOutput:
        """Predict token logits and, optionally, the loss and the denoised embeddings.

        Args:
            timestamps: One noise level per batch element.
            conditioning_mask: Boolean mask over positions (``True`` = conditioning).
            input_ids: Token ids. Required when ``conditioning_hidden_states`` is
                not given; also used as the loss targets.
            conditioning_hidden_states: Precomputed conditioning embeddings. When
                ``None``, both this and ``noisy_hidden_states`` are built from
                ``input_ids``.
            noisy_hidden_states: Precomputed noised embeddings; required together
                with ``conditioning_hidden_states``.
            self_conditioning_hidden_states: Previous denoised estimate; zeros when ``None``.
            output_denoised: Also return ``softmax(logits) @ normalized_embeddings``.
            drop_conditioning_embeddings: Zero the conditioning embeddings when
                they are built from ``input_ids``.
            self_conditioning: Run an extra gradient-free pass first and use its
                denoised output (on non-conditioning positions) as self-conditioning.
            output_loss: Compute the loss when ``input_ids`` is available.
            weighted_loss: Weight the per-sample cross entropy by ``timestamps / t_max``
                instead of averaging over non-conditioning positions.
            t_max: Normalizer of the loss weights; required when ``weighted_loss`` is set.
            attention_mask: Optional mask with 1 for positions to attend to and 0
                for masked positions; reshaped to [batch, -1].

        Raises:
            ValueError: On missing inputs, on a missing ``t_max`` with
                ``weighted_loss``, or on an empty batch with an attention mask.
        """
        if conditioning_hidden_states is None:
            if input_ids is None:
                raise ValueError(
                    "Either input_ids or conditioning_hidden_states/noisy_hidden_states must be provided"
                )
            conditioning_hidden_states, noisy_hidden_states = self.get_noisy_embeddings(
                timestamps=timestamps,
                conditioning_mask=conditioning_mask,
                input_ids=input_ids,
                drop_conditioning_embeddings=drop_conditioning_embeddings,
            )
        elif noisy_hidden_states is None:
            raise ValueError(
                "noisy_hidden_states must be provided together with conditioning_hidden_states"
            )

        if self_conditioning_hidden_states is None:
            self_conditioning_hidden_states = torch.zeros_like(
                conditioning_hidden_states
            )
        # Cast explicitly so the concatenation below mixes a single dtype.
        mask_inp = (
            conditioning_mask.unsqueeze(-1)
            .expand_as(conditioning_hidden_states)
            .to(conditioning_hidden_states.dtype)
        )

        batch_size = noisy_hidden_states.size(0)

        if self_conditioning:
            # First pass without gradients; it receives the raw (untransformed)
            # attention mask because it performs the transformation itself.
            with torch.no_grad():
                s_outputs = self.forward(
                    timestamps=timestamps,
                    conditioning_mask=conditioning_mask,
                    conditioning_hidden_states=conditioning_hidden_states,
                    noisy_hidden_states=noisy_hidden_states,
                    output_loss=False,
                    output_denoised=True,
                    attention_mask=attention_mask,
                )
                self_conditioning_hidden_states = s_outputs.denoised * (
                    ~conditioning_mask.unsqueeze(-1)
                )

        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            # nn.Module has no `dtype` attribute, so take the dtype from a parameter.
            mask_dtype = self.emb_to_hidden.weight.dtype
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 4D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and the dtype's smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=mask_dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(mask_dtype).min

        embeddings_cat = torch.cat(
            (
                mask_inp,
                conditioning_hidden_states,
                noisy_hidden_states,
                self_conditioning_hidden_states,
            ),
            dim=-1,
        )

        hidden_states = self.emb_to_hidden(embeddings_cat)
        time_embeddings = self.time_emmbeddings(timestamps)
        for encoder in self.encoders:
            hidden_states = encoder(
                hidden_states=hidden_states,
                time_embeddings=time_embeddings,
                attention_mask=attention_mask,
            )

        logits = self.classifier(hidden_states)

        loss = None
        if output_loss and input_ids is not None:
            if weighted_loss:
                if t_max is None:
                    raise ValueError("t_max must be provided when weighted_loss=True")
                # One cross-entropy value per sample (averaged over its positions).
                ce_not_weighted = torch.stack(
                    [
                        F.cross_entropy(sample_logits, sample_ids)
                        for sample_logits, sample_ids in zip(logits, input_ids)
                    ]
                )
                ce_weighted = ce_not_weighted * timestamps / t_max
                loss = ce_weighted.mean()
            else:
                # Only non-conditioning positions contribute to the loss.
                targets = input_ids.view(-1)
                keep = ~conditioning_mask.view(-1)
                logits_r = logits.view(targets.size(0), -1)[keep]
                loss = F.cross_entropy(logits_r, targets[keep])

        denoised = None
        if output_denoised:
            with torch.no_grad():
                probs = F.softmax(logits, dim=-1)
                normalized_embeddings = F.normalize(
                    self.embeddings.weight.data, dim=-1
                ) * math.sqrt(
                    self.config.embedding_size
                )  # ~ N(0, 1)
                denoised = probs @ normalized_embeddings  # still normalized
        return DiffusionOutput(logits=logits, loss=loss, denoised=denoised)

    def _init_weights(self, module):
        """Initialize one sub-module: normal weights for Linear/Embedding, identity for LayerNorm."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # NOTE(review): config.embedding_initializer_range exists but is not used
            # here; confirm whether the embedding table was meant to use it.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
