
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint

from workbench.gpt_model.modules.attention import AttentionLayer
from workbench.gpt_model.modules.feed_forward import FeedForward
from workbench.gpt_model.modules.norm import RMSNorm

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention followed by a feed-forward layer.

    Each sub-layer is applied to a normalized copy of the running hidden
    states, passed through dropout, and added back as a residual.
    """

    @staticmethod
    def _rotary_dim(head_dim):
        """Return the number of per-head channels given rotary embeddings.

        Half of each head is rotated. The result is an ``int`` (integer
        division) because it is a channel count, not a fraction.
        """
        return head_dim // 2

    def __init__(self, config, cross_attn, causal_attn):
        """Build the normalization, attention, and feed-forward sub-modules.

        Args:
            config: Object providing ``model_dim``, ``num_head``,
                ``ff_factor``, ``resi_dropout``, ``attn_dropout``,
                ``rms_norm``, ``ln_eps``, ``use_bias``, ``softmax_scale``,
                ``activation`` and ``glu``; ``add_unit_offset`` is read when
                ``rms_norm`` is truthy, ``learn_ln`` otherwise.
            cross_attn: Forwarded to ``MHA(cross_attn=...)`` and stored on
                the instance.
            causal_attn: Forwarded to ``MHA(causal=...)``.
        """
        super().__init__()

        model_dim = config.model_dim
        ff_dim = int(config.ff_factor * model_dim)

        self.cross_attn = cross_attn
        # Not read anywhere in this class; kept as a public attribute.
        self.residual_in_fp32 = True

        self.dropout1 = nn.Dropout(config.resi_dropout)
        self.dropout2 = nn.Dropout(config.resi_dropout)

        if config.rms_norm:
            self.norm1 = RMSNorm(model_dim, eps=config.ln_eps, add_unit_offset=config.add_unit_offset)
            self.norm2 = RMSNorm(model_dim, eps=config.ln_eps, add_unit_offset=config.add_unit_offset)
        else:
            self.norm1 = nn.LayerNorm(model_dim, eps=config.ln_eps, bias=config.use_bias, elementwise_affine=config.learn_ln)
            self.norm2 = nn.LayerNorm(model_dim, eps=config.ln_eps, bias=config.use_bias, elementwise_affine=config.learn_ln)

        head_dim = model_dim // config.num_head
        # A falsy ``config.softmax_scale`` disables scaling (factor 1.0);
        # otherwise the usual 1/sqrt(head_dim) factor is used. Per-layer
        # scaling by the layer index is not applied.
        softmax_scale = head_dim ** (-0.5) if config.softmax_scale else 1.0

        # Rotary settings are fixed here rather than read from ``config``.
        # NOTE(review): rotary embeddings are always enabled; flash_attn's MHA
        # may not accept them together with cross_attn=True -- confirm.
        rotary_emb_dim = self._rotary_dim(head_dim)
        rotary_emb_base = 10000.0
        rotary_emb_scale_base = None
        rotary_emb_interleaved = False

        self.self_attn = MHA(
            model_dim,
            config.num_head,
            cross_attn=cross_attn,
            qkv_proj_bias=config.use_bias,
            out_proj_bias=config.use_bias,
            dropout=config.attn_dropout,
            softmax_scale=softmax_scale,
            causal=causal_attn,
            layer_idx=None,
            dwconv=False,
            rotary_emb_dim=rotary_emb_dim,
            rotary_emb_base=rotary_emb_base,
            rotary_emb_scale_base=rotary_emb_scale_base,
            rotary_emb_interleaved=rotary_emb_interleaved,
            fused_bias_fc=True,
            use_flash_attn=True,
            return_residual=False,
            checkpointing=False,
            device=None,
            dtype=None,
        )

        self.transition = FeedForward(
            model_dim,
            ff_dim,
            use_bias=config.use_bias,
            activation=config.activation,
            glu=config.glu,
        )

    def forward(self, hidden_states, attn_mask=None):
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            hidden_states: Input to the block.
            attn_mask: Accepted for interface compatibility; it is currently
                not used by this method.

        Returns:
            The hidden states after both residual updates.
        """
        attn_out = self.self_attn(self.norm1(hidden_states))
        hidden_states = hidden_states + self.dropout1(attn_out)

        ff_out = self.transition(self.norm2(hidden_states))
        hidden_states = hidden_states + self.dropout2(ff_out)

        return hidden_states
