import torch
import torch.nn as nn
import torch.utils.checkpoint

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP
from flash_attn.ops.layer_norm import dropout_add_layer_norm


class FastProbFormerBlock(nn.Module):
    """Pre-norm transformer block built from flash_attn fused components.

    The block owns two LayerNorm parameter sets, an ``MHA`` attention module
    and a ``FusedMLP`` feed-forward module. In ``forward`` the LayerNorm
    parameters are handed to the fused ``dropout_add_layer_norm`` op instead
    of being applied through the ``nn.LayerNorm`` modules themselves.
    """

    def __init__(self, config, cross_attn, causal_attn, layer_idx):
        """Build the norms, the attention module and the feed-forward module.

        Args:
            config: Object providing the attributes ``ff_factor``,
                ``model_dim``, ``precision``, ``unpadded``, ``resi_dropout``,
                ``ln_eps``, ``num_head``, ``scale_attn_weights``,
                ``scale_attn_by_inverse_layer_idx``, ``use_bias``,
                ``attn_dropout`` and ``checkpoint_lvl``. The attributes
                ``rotary_emb_fraction`` (default 0.0), ``rotary_emb_base``
                (default 10000.0), ``rotary_emb_scale_base`` (default None)
                and ``rotary_emb_interleaved`` (default False) are optional.
            cross_attn: Forwarded to ``MHA`` as ``cross_attn`` and stored on
                the block.
            causal_attn: Forwarded to ``MHA`` as ``causal``.
            layer_idx: Index of this block; forwarded to ``MHA`` and, when
                ``config.scale_attn_by_inverse_layer_idx`` is set, used to
                divide the softmax scale by ``layer_idx + 1``.
        """
        super().__init__()

        ff_dim = int(config.ff_factor * config.model_dim)

        # Pick the fused GEMM + GELU algorithm from the CUDA version PyTorch
        # was built with (None on builds without CUDA).
        heuristic = self._select_gemm_gelu_heuristic(config.precision, torch.version.cuda)

        self.cross_attn = cross_attn
        self.unpadded = config.unpadded

        # Passed straight through to dropout_add_layer_norm in forward().
        self.residual_in_fp32 = True

        # These modules are used as parameter/hyper-parameter holders only:
        # forward() reads ``.p``, ``.weight``, ``.bias`` and ``.eps`` and lets
        # the fused op do the actual dropout and normalisation.
        self.dropout1 = nn.Dropout(config.resi_dropout)
        self.dropout2 = nn.Dropout(config.resi_dropout)
        self.norm1 = nn.LayerNorm(config.model_dim, eps=config.ln_eps)
        self.norm2 = nn.LayerNorm(config.model_dim, eps=config.ln_eps)

        head_dim = config.model_dim // config.num_head
        softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
        if config.scale_attn_by_inverse_layer_idx:
            assert layer_idx is not None, \
                "layer_idx is required when scale_attn_by_inverse_layer_idx is set"
            softmax_scale /= float(layer_idx + 1)

        # Rotary embedding covers a fraction of each head; 0 when the config
        # does not define ``rotary_emb_fraction``.
        rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
        rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
        rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
        rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)

        self.self_attn = MHA(config.model_dim, config.num_head, cross_attn=cross_attn,
                             qkv_proj_bias=config.use_bias, out_proj_bias=config.use_bias,
                             dropout=config.attn_dropout, softmax_scale=softmax_scale, causal=causal_attn,
                             layer_idx=layer_idx, dwconv=False,
                             rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
                             rotary_emb_scale_base=rotary_emb_scale_base,
                             rotary_emb_interleaved=rotary_emb_interleaved,
                             fused_bias_fc=True, use_flash_attn=True,
                             return_residual=False, checkpointing=False, device=None, dtype=None)

        self.transition = FusedMLP(config.model_dim, ff_dim, out_features=config.model_dim,
                                   bias1=config.use_bias, bias2=config.use_bias, activation='gelu_approx',
                                   return_residual=False, checkpoint_lvl=config.checkpoint_lvl, heuristic=heuristic,
                                   device=None, dtype=None)

    @staticmethod
    def _parse_cuda_version(cuda_version):
        """Return ``(major, minor)`` for a dotted version string, else ``None``.

        ``None``, empty and non-numeric inputs yield ``None``. A missing minor
        component is treated as 0; components after the minor are ignored.
        """
        if not cuda_version:
            return None
        fields = str(cuda_version).split('.')
        try:
            major = int(fields[0])
            minor = int(fields[1]) if len(fields) > 1 else 0
        except ValueError:
            return None
        return major, minor

    @staticmethod
    def _select_gemm_gelu_heuristic(precision, cuda_version):
        """Choose the ``heuristic`` argument passed to ``FusedMLP``.

        Meaning of the value:
            -1: don't fuse gemm + gelu (separate kernels)
            0..4: algorithm heuristic used by the fused gemm + gelu
            For CUDA >= 11.8, heuristic=0 is preferred for both fp16 and bf16.
            For CUDA <= 11.7, heuristic=1 for fp16 and heuristic=-1 for bf16.

        Args:
            precision: ``'bf16'`` selects the bf16 rule; any other value is
                treated like fp16.
            cuda_version: Version string such as ``torch.version.cuda``;
                may be ``None``.

        Returns:
            The integer heuristic. When the version is missing or cannot be
            parsed, -1 is returned so that the unfused path is used.
        """
        parsed = FastProbFormerBlock._parse_cuda_version(cuda_version)
        if parsed is None:
            return -1
        # Tuple comparison keeps e.g. "11.10" newer than "11.7", which a
        # float conversion of the version string would get wrong.
        if parsed <= (11, 7):
            return -1 if precision == 'bf16' else 1
        return 0

    def forward(self, src_act, residual):
        """Run attention and the feed-forward module with fused pre-norms.

        Args:
            src_act: Activation passed as the first argument of the fused
                ``dropout_add_layer_norm`` op.
            residual: Residual passed as the second argument of that op.
                NOTE(review): presumably may be None for the first block,
                as flash_attn allows; confirm against the caller.

        Returns:
            Tuple ``(src_act, residual)``: the output of the feed-forward
            module and the residual returned by the second fused op. The
            feed-forward output is not added to the residual in this method.
        """
        # Dropout probability is forced to 0 outside training mode because
        # the nn.Dropout modules themselves are never called.
        src_act, residual = dropout_add_layer_norm(
            src_act, residual, self.norm1.weight, self.norm1.bias,
            self.dropout1.p if self.training else 0.0, self.norm1.eps,
            rowscale=None, prenorm=True, residual_in_fp32=self.residual_in_fp32
        )
        src_act = self.self_attn(src_act)

        src_act, residual = dropout_add_layer_norm(
            src_act, residual, self.norm2.weight, self.norm2.bias,
            self.dropout2.p if self.training else 0.0, self.norm2.eps,
            rowscale=None, prenorm=True, residual_in_fp32=self.residual_in_fp32
        )
        src_act = self.transition(src_act)

        return src_act, residual
