import math

import torch
import torch.nn as nn

from cpr.llm_transformer.model.embedding import PosEmbedding
from cpr.llm_transformer.model.fast_probformer import FastTransformerStack


class DecoderTransformer(nn.Module):
    """Decoder-only transformer: embedding -> transformer stack -> output layer.

    The stack is a ``FastTransformerStack`` built with causal attention
    enabled and cross attention disabled. The output layer is the very same
    object as the embedding's ``embed_seq`` attribute, so no separate output
    parameters are created.
    """

    def __init__(self, config):
        """Build the sub-modules from ``config`` and initialise the weights.

        Reads ``seq_vocab_size``, ``trg_vocab_size``, ``model_dim``,
        ``max_len``, ``pos_embedding``, ``rel_pos_enc`` and
        ``initializer_range`` from ``config``; ``config`` is also handed to
        ``FastTransformerStack`` unchanged.
        """
        super().__init__()

        # Source and target vocabularies are required to have the same size.
        assert config.seq_vocab_size == config.trg_vocab_size

        embedding = PosEmbedding(
            config.trg_vocab_size,
            config.model_dim,
            config.max_len,
            config.pos_embedding,
            config.rel_pos_enc,
            config.initializer_range,
        )
        self.trg_embed = embedding

        self.decoder = FastTransformerStack(config, cross_attn=False, causal_attn=True)

        # Reuse the embedding's ``embed_seq`` as the output layer.
        # NOTE(review): ``embed_seq`` is defined in PosEmbedding (not visible
        # here); forward() calls it on the decoder output, so it has to
        # accept latent activations -- confirm.
        self.output = embedding.embed_seq

        self.initialize(config.initializer_range, config.trg_vocab_size)

    def initialize(self, initializer_range, vocab_size):
        """Re-initialise the parameters in place, chosen by parameter name.

        Rules, checked in this order for every named parameter:

        * name contains ``'bias'``: set to zeros;
        * name contains ``'norm'`` or ``'ln'``: left untouched;
        * shape is exactly ``[1]``: left untouched;
        * otherwise, if ``initializer_range`` is truthy: normal with mean 0
          and std ``initializer_range``;
        * otherwise, if the name contains ``'embed_seq'``: normal with mean 0
          and std ``1 / sqrt(vocab_size)``;
        * otherwise: Xavier uniform.
        """
        scalar_shape = torch.Size([1])

        for name, param in self.named_parameters():
            if 'bias' in name:
                nn.init.zeros_(param)
                continue

            # Normalisation weights and single-element parameters keep
            # whatever value their own constructor gave them.
            if 'norm' in name or 'ln' in name:
                continue
            if param.shape == scalar_shape:
                continue

            if initializer_range:
                nn.init.normal_(param, mean=0.0, std=initializer_range)
            elif 'embed_seq' in name:
                nn.init.normal_(param, mean=0.0, std=1.0 / math.sqrt(vocab_size))
            else:
                nn.init.xavier_uniform_(param)

    def forward(self, trg_shf_seq=None):
        """Run embedding, decoder stack and output layer, in that order.

        Args:
            trg_shf_seq: input handed directly to ``self.trg_embed``
                (presumably the shifted target token sequence -- TODO
                confirm against the caller).

        Returns:
            The result of ``self.output`` applied to the decoder output
            (the logits).
        """
        return self.output(self.decoder(self.trg_embed(trg_shf_seq)))
