import torch.nn as nn
import math


def _init_weights(config, module):
    """Initialize the weights."""
    if isinstance(module, nn.Linear):  # (nn.Linear,Conv1D)
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=config.initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

    # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
    #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
    #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
    #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
    #
    # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
    
    # for name, p in module.named_parameters():
    #     if name == "c_proj.weight":
    #         # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
    #         p.data.normal_(
    #             mean=0.0,
    #             std=(config.initializer_range / math.sqrt(2 * config.nlayers)),
    #         )
