import torch.nn as nn

from flash_attn.ops.layer_norm import dropout_add_layer_norm
from cpr.llm_transformer.model.fast_probformer_block import FastProbFormerBlock


class FastTransformerStack(nn.Module):
    """Stack of ``FastProbFormerBlock`` layers followed by a fused final norm.

    Each block takes and returns an ``(activations, residual)`` pair, so the
    residual stream is threaded through the stack explicitly. After the last
    block, dropout, the residual add and the final LayerNorm are applied in a
    single call to flash-attn's fused ``dropout_add_layer_norm``.

    Args:
        config: Model configuration. This class reads ``n_layers``,
            ``resi_dropout``, ``model_dim`` and ``ln_eps``; the object is also
            forwarded unchanged to every ``FastProbFormerBlock``.
        cross_attn: Forwarded unchanged to every block.
        causal_attn: Forwarded unchanged to every block.
    """

    def __init__(self, config, cross_attn, causal_attn):
        super().__init__()

        # One block per layer; each block is told its position via layer_idx.
        self.layers = nn.ModuleList(
            FastProbFormerBlock(
                config=config,
                cross_attn=cross_attn,
                causal_attn=causal_attn,
                layer_idx=idx,
            )
            for idx in range(config.n_layers)
        )

        # Passed to the fused final norm. NOTE(review): hard-coded here rather
        # than read from config -- confirm it matches the residual dtype the
        # blocks produce.
        self.residual_in_fp32 = True

        # Neither module is invoked directly in forward(): ``drop_f`` only
        # supplies its probability, and ``ln_f`` only owns the weight/bias/eps
        # handed to the fused function. Keeping ``ln_f`` as a module registers
        # its weight and bias as parameters of this stack.
        self.drop_f = nn.Dropout(config.resi_dropout)
        self.ln_f = nn.LayerNorm(config.model_dim, eps=config.ln_eps)

    def forward(self, src_act):
        """Run all blocks, then apply fused dropout + residual add + LayerNorm.

        Args:
            src_act: Input activations for the first block. NOTE(review):
                presumably ``(batch, seq, model_dim)`` -- confirm with callers.

        Returns:
            The output of the final (fused) LayerNorm.
        """
        # The first block receives no residual; every block hands back the
        # updated activations and residual for the next one.
        residual = None
        for layer in self.layers:
            src_act, residual = layer(src_act, residual)

        # The fused op is a plain function, not a module, so it has no notion of
        # train/eval mode: dropout is disabled in eval by passing p=0.0.
        # prenorm=False -> only the normalized output is returned.
        src_act = dropout_add_layer_norm(
            src_act,
            residual,
            self.ln_f.weight,
            self.ln_f.bias,
            self.drop_f.p if self.training else 0.0,
            self.ln_f.eps,
            prenorm=False,
            residual_in_fp32=self.residual_in_fp32,
        )
        return src_act


