import torch
from layers import TransformerFFN
from layers.attention import AttentionCore
import math


def gpt2_init(model: torch.nn.Module, n_layers: int) -> torch.nn.Module:
    """Re-initialize the parameters of ``model`` in place and return it.

    Two passes are made over ``model.modules()``:

    1. Every ``torch.nn.Linear`` and ``torch.nn.Embedding`` weight is drawn
       from a normal distribution with std 0.02; Linear biases are zeroed.
       Every ``torch.nn.LayerNorm`` gets weight 1 and bias 0 (when present).
    2. For each ``TransformerFFN`` the ``linear2`` weight, and for each
       ``AttentionCore`` the ``o`` weight, is redrawn from a normal
       distribution with the smaller std ``0.02 / sqrt(2 * n_layers)``.

    Args:
        model: Module whose sub-modules are initialized (mutated in place).
        n_layers: Positive layer count used to scale the std of the second
            pass.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: If ``n_layers`` is not positive. The check happens before
            any parameter is modified, so a bad argument cannot leave the
            model partially re-initialized.
    """
    # Validate and compute the scaled std up front: doing this after the first
    # pass (as a division by sqrt(2 * n_layers)) would fail only once half of
    # the model had already been overwritten.
    if n_layers <= 0:
        raise ValueError(
            f"n_layers must be a positive integer, got {n_layers!r}"
        )
    resstd = 0.02 / math.sqrt(2 * n_layers)

    for m in model.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, torch.nn.Embedding):
            torch.nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, torch.nn.LayerNorm):
            # weight/bias are None when the norm has no affine parameters.
            if m.weight is not None:
                torch.nn.init.ones_(m.weight)

            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    # Kept as a separate pass so these overrides are applied last.
    # NOTE(review): presumably `linear2` and `o` are nn.Linear layers already
    # touched by the loop above; merging the loops would let the generic
    # Linear init overwrite the scaled one -- confirm against layers/.
    for m in model.modules():
        if isinstance(m, TransformerFFN):
            torch.nn.init.normal_(m.linear2.weight, mean=0.0, std=resstd)
        elif isinstance(m, AttentionCore):
            torch.nn.init.normal_(m.o.weight, mean=0.0, std=resstd)

    return model
