import math
from functools import partial
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

import hydra

from einops import reduce, rearrange

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
    flash_attn_unpadded_qkvpacked_func = None

try:
    from src.ops.fused_dense import FusedDenseTD
except ImportError:
    FusedDenseTD = None



def pooling(x, pooling_mode='CLS', key_padding_mask=None, batch_first=True):
    """Collapse the sequence dimension of ``x`` into one vector per batch element.

    Arguments
    ---------
        x: (B, S, ...) if batch_first else (S, B, ...)
        pooling_mode: one of 'MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'.
            MEAN / SUM reduce over the sequence dimension, CLS takes the first
            position, LAST takes the last position, FLATTEN merges every
            non-batch dimension into one.
        key_padding_mask: optional object exposing a ``bool_matrix`` attribute of
            shape (B, S). Positions where ``bool_matrix`` is False are zeroed out
            and excluded from the MEAN / SUM reduction. Only consulted for
            MEAN / SUM; must be None for LAST.
        batch_first: whether the batch dimension of ``x`` comes first.

    Returns
    -------
        Tensor whose leading dimension is the batch dimension.

    Raises
    ------
        NotImplementedError: if ``pooling_mode`` is not one of the supported modes.
    """
    if pooling_mode not in ('MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'):
        raise NotImplementedError('pooling_mode must be MEAN, SUM, CLS, LAST, FLATTEN')
    seq_dim = 1 if batch_first else 0
    if pooling_mode in ('MEAN', 'SUM'):
        if key_padding_mask is None:
            return x.sum(dim=seq_dim) if pooling_mode == 'SUM' else x.mean(dim=seq_dim)
        keep = key_padding_mask.bool_matrix  # (B, S)
        if not batch_first:
            keep = keep.transpose(0, 1)  # (S, B), to line up with x
        # Append singleton dims so the mask broadcasts over every trailing dim of x.
        keep = keep.reshape(tuple(keep.shape) + (1,) * (x.dim() - 2))
        total = x.masked_fill(~keep, 0).sum(dim=seq_dim)
        if pooling_mode == 'SUM':
            return total
        # Average over the valid tokens only: dividing by the padded length would
        # shrink the mean of short sequences. The clamp keeps a fully padded row
        # at zero instead of producing NaN.
        num_valid = keep.sum(dim=seq_dim).clamp(min=1)
        return total / num_valid
    if pooling_mode == 'CLS':
        return x[:, 0] if batch_first else x[0]
    if pooling_mode == 'LAST':
        # With padding the last position is not the last real token, so masks are rejected.
        assert key_padding_mask is None
        return x[:, -1] if batch_first else x[-1]
    # FLATTEN
    return rearrange(x, 'b ... -> b (...)' if batch_first else 's b ... -> b (s ...)')


class ClassificationHeadLinear(nn.Module):
    """Sequence-level classification head: pool over the sequence, then one linear layer.

    Arguments
    ---------
        d_model: input feature size of the output projection.
        num_classes: number of output logits.
        pooling_mode: one of 'MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'.
        batch_first: layout of the incoming hidden states (see ``forward``).
        **kwargs: accepted and ignored.
    """

    def __init__(self, d_model, num_classes, pooling_mode='MEAN',
                 batch_first=False, **kwargs):
        super().__init__()
        supported_modes = ('MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN')
        assert pooling_mode in supported_modes, 'pooling_mode not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.out_proj = nn.Linear(d_model, num_classes)

    def forward(self, hidden_states, key_padding_mask=None, **kwargs):
        """Return the class logits for each batch element.

            hidden_states: (B, S, D) if batch_first else (S, B, D)
            key_padding_mask: forwarded unchanged to ``pooling``.
        """
        pooled = pooling(
            hidden_states,
            pooling_mode=self.pooling_mode,
            key_padding_mask=key_padding_mask,
            batch_first=self.batch_first,
        )
        return self.out_proj(pooled)


# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/models/reformer/modeling_reformer.py
class ClassificationHead(nn.Module):
    """Sequence-level classification head with one hidden layer.

    Pipeline: pool -> dropout -> dense -> ReLU -> dropout -> out_proj.
    The same dropout module is applied at both dropout points.

    Arguments
    ---------
        d_model: input feature size of the hidden (dense) layer.
        d_inner: width of the hidden layer.
        num_classes: number of output logits.
        dropout: dropout probability.
        pooling_mode: one of 'MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'.
        batch_first: layout of the incoming hidden states (see ``forward``).
    """

    def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',
                 batch_first=False):
        super().__init__()
        supported_modes = ('MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN')
        assert pooling_mode in supported_modes, 'pooling_mode not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.dense = nn.Linear(d_model, d_inner)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_inner, num_classes)

    def forward(self, hidden_states, key_padding_mask=None, **kwargs):
        """Return the class logits for each batch element.

            hidden_states: (B, S, D) if batch_first else (S, B, D)
            key_padding_mask: forwarded unchanged to ``pooling``.
        """
        pooled = pooling(
            hidden_states,
            pooling_mode=self.pooling_mode,
            key_padding_mask=key_padding_mask,
            batch_first=self.batch_first,
        )
        # Huggingface uses tanh here; this head uses relu.
        inner = torch.relu(self.dense(self.dropout(pooled)))
        return self.out_proj(self.dropout(inner))


class ClassificationHeadDual(nn.Module):
    """Classification head for a pair of sequences.

    Each sequence is pooled separately; the two pooled vectors are combined and
    passed through dropout -> dense -> ReLU -> dropout -> out_proj.

    Arguments
    ---------
        d_model: feature size of each pooled vector.
        d_inner: width of the hidden layer.
        num_classes: number of output logits.
        dropout: dropout probability.
        pooling_mode: one of 'MEAN', 'SUM', 'CLS'.
        batch_first: layout of the incoming hidden states (see ``forward``).
        interaction: 'NLI' concatenates [x1, x2, x1 * x2, x1 - x2] (4 * d_model
            features); None concatenates [x1, x2] (2 * d_model features).
    """

    def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',
                 batch_first=False, interaction='NLI'):
        super().__init__()
        assert pooling_mode in ('MEAN', 'SUM', 'CLS'), 'pooling_mode not supported'
        assert interaction in (None, 'NLI'), 'interaction not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.interaction = interaction
        if self.interaction == 'NLI':
            num_parts = 4
        else:
            num_parts = 2
        self.dense = nn.Linear(d_model * num_parts, d_inner)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_inner, num_classes)

    def forward(self, hidden_states1, hidden_states2,
                key_padding_mask1=None, key_padding_mask2=None, **kwargs):
        """Return the class logits for each pair in the batch.

            hidden_states1, hidden_states2: (B, S, D) if batch_first else (S, B, D)
            key_padding_mask1, key_padding_mask2: forwarded unchanged to ``pooling``.
        """
        x1 = pooling(hidden_states1, pooling_mode=self.pooling_mode,
                     key_padding_mask=key_padding_mask1, batch_first=self.batch_first)
        x2 = pooling(hidden_states2, pooling_mode=self.pooling_mode,
                     key_padding_mask=key_padding_mask2, batch_first=self.batch_first)
        if self.interaction == 'NLI':
            parts = [x1, x2, x1 * x2, x1 - x2]
        else:
            parts = [x1, x2]
        combined = torch.cat(parts, dim=-1)
        # Huggingface uses tanh here; this head uses relu.
        inner = torch.relu(self.dense(self.dropout(combined)))
        return self.out_proj(self.dropout(inner))


class LMHead(nn.Module):
    """Language-modeling head: a single linear projection to ``num_classes`` logits.

    Arguments
    ---------
        d_model: input feature size.
        num_classes: number of output logits per position.
        batch_first: accepted for interface compatibility; not used by this head.
        bias: whether the projection has a bias term.
    """

    # Result type of ``forward``. Built once here rather than on every call, so
    # every output shares one type and no class is re-created per step.
    CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])

    def __init__(self, d_model, num_classes, batch_first=True, bias=True):
        super().__init__()
        self.lm_head = nn.Linear(d_model, num_classes, bias=bias)

    def forward(self, hidden_states, **kwargs):
        """Project the hidden states to logits.

            hidden_states: (B, S, D) if batch_first else (S, B, D)

        Returns a ``CausalLMOutput`` namedtuple whose only field, ``logits``,
        holds the projected tensor. Extra keyword arguments are ignored.
        """
        return self.CausalLMOutput(self.lm_head(hidden_states))


def sinusoidal_init_(tensor):
    """Fill ``tensor`` in place with sinusoidal position encodings and return it.

        tensor: (max_len, d_model)

    Even columns receive sin(pos * w_i) and odd columns cos(pos * w_i), with
    w_i = 10000^(-2i / d_model). ``d_model`` may be odd: the trailing even
    column then has no matching cos column. The write runs under ``no_grad``
    so that a parameter that requires grad can be initialized as well.
    """
    max_len, d_model = tensor.shape
    position = torch.arange(0.0, max_len).unsqueeze(1)  # (max_len, 1)
    div_term = torch.exp(-math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model)
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    with torch.no_grad():
        tensor[:, 0::2] = torch.sin(angles)
        # There are only d_model // 2 odd columns, one fewer than the even
        # columns when d_model is odd.
        tensor[:, 1::2] = torch.cos(angles[:, :d_model // 2])
    return tensor


# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class PositionalEncoding(nn.Module):
    r"""Add a per-position encoding to the token embeddings, then apply dropout.

    The encoding table has the same feature size as the embeddings so the two
    can be summed. With ``initializer=None`` the table is filled by
    ``sinusoidal_init_`` (sine and cosine functions of different frequencies)
    and stored as a buffer:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})

    where pos is the token position and i is the embedding index. Otherwise
    the table is filled by calling ``initializer`` through ``hydra.utils.call``
    and stored as a learnable ``nn.Parameter``.

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
        batch_first: if True the table is stored as (1, max_len, d_model),
            otherwise as (max_len, 1, d_model).
        initializer: optional hydra-callable spec that fills the table in place.
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False, initializer=None):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        table = torch.empty(max_len, d_model)
        layout = 's d -> 1 s d' if self.batch_first else 's d -> s 1 d'
        learnable = initializer is not None
        if learnable:
            hydra.utils.call(initializer, table)
        else:
            sinusoidal_init_(table)
        table = rearrange(table, layout)
        if learnable:
            self.pe = nn.Parameter(table)
        else:
            self.register_buffer('pe', table)

    def forward(self, x):
        r"""Add the encodings of the first ``S`` positions to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]
            output: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]
        Examples:
            >>> output = pos_encoder(x)
        """
        if self.batch_first:
            positions = self.pe[:, :x.size(1)]
        else:
            positions = self.pe[:x.size(0)]
        return self.dropout(x + positions)


# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
class Mlp(nn.Module):
    """Two-layer MLP: fc1 -> activation -> dropout -> fc2 -> dropout.

    As used in Vision Transformer, MLP-Mixer and related networks.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 act_fn=None, drop=0., device=None, dtype=None):
        """Build the two linear layers, the activation and the two dropouts.

        hidden_features and out_features fall back to in_features when unset.
        drop may be a single probability or a pair (after fc1, after fc2).

        TD [2021-10-27] act_fn takes precedence over act_layer if set.
        This supports the Pytorch 1.10 Transformer interface, which constructs
        the activation *function* rather than the activation *layer*.
        """
        super().__init__()
        tensor_opts = {'device': device, 'dtype': dtype}
        hidden_features = hidden_features if hidden_features else in_features
        out_features = out_features if out_features else in_features
        drop_probs = _pair(drop)
        self.fc1 = nn.Linear(in_features, hidden_features, **tensor_opts)
        if act_fn is None:
            self.act = act_layer()
        else:
            self.act = act_fn
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, **tensor_opts)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        """Apply fc1, activation, dropout, fc2 and dropout to the last dim of x."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


class _CallableActivation(nn.Module):
    """Wrap a plain activation function so it can be placed in ``nn.Sequential``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


class MlpBig(nn.Module):
    """Deeper variant of ``Mlp`` with four hidden layers whose width doubles each time.

    Layer stack: four (Linear -> activation -> Dropout) stages of widths
    hidden_features, 2x, 4x and 8x, followed by a final Linear -> Dropout that
    maps to out_features.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 act_fn=None, drop=0., device=None, dtype=None):
        """Build the layer stack.

        hidden_features and out_features fall back to in_features when unset.
        As in ``Mlp``, act_fn takes precedence over act_layer if set. A module
        passed as act_fn is reused at every stage; a plain function is wrapped
        in a small module so it can live inside ``nn.Sequential``.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        def make_activation():
            if act_fn is None:
                return act_layer()
            if isinstance(act_fn, nn.Module):
                return act_fn
            return _CallableActivation(act_fn)

        cur_hidden_features = hidden_features
        layers = []
        for _ in range(4):
            layers.append(nn.Linear(in_features, cur_hidden_features, **factory_kwargs))
            layers.append(make_activation())
            layers.append(nn.Dropout(drop))
            in_features = cur_hidden_features
            cur_hidden_features *= 2
        layers.append(nn.Linear(in_features, out_features, **factory_kwargs))
        layers.append(nn.Dropout(drop))
        self.fwd = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through the full layer stack."""
        return self.fwd(x)

class GluMlp(nn.Module):
    """MLP with GLU-style gating.

    fc1 produces hidden_features channels that are split in half: the first
    half is the value, the second half goes through the activation and gates it.
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
        super().__init__()
        hidden_features = hidden_features if hidden_features else in_features
        out_features = out_features if out_features else in_features
        # The hidden channels are chunked into (value, gate) halves in forward.
        assert hidden_features % 2 == 0
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features // 2, out_features)
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        """Re-initialize the gate half of fc1: bias set to 1, weights near zero."""
        gate_start = self.fc1.bias.shape[0] // 2
        gate_bias = self.fc1.bias[gate_start:]
        gate_weight = self.fc1.weight[gate_start:]
        nn.init.ones_(gate_bias)
        nn.init.normal_(gate_weight, std=1e-6)

    def forward(self, x):
        """Apply fc1, gate the value half, dropout, fc2, dropout."""
        value, gates = self.fc1(x).chunk(2, dim=-1)
        gated = self.drop(value * self.act(gates))
        return self.drop(self.fc2(gated))


class GatedMlp(nn.Module):
    """MLP as used in gMLP: fc1 -> activation -> dropout -> gate -> fc2 -> dropout.

    Without a gate_layer the gate is an identity. With one, the gate is built
    as gate_layer(hidden_features) and fc2 takes hidden_features // 2 inputs.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 gate_layer=None, drop=0.):
        super().__init__()
        hidden_features = hidden_features if hidden_features else in_features
        out_features = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        if gate_layer is None:
            self.gate = nn.Identity()
            fc2_in_features = hidden_features
        else:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            # FIXME base reduction on gate property?
            fc2_in_features = hidden_features // 2
        self.fc2 = nn.Linear(fc2_in_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1, activation, dropout, gate, fc2 and dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        gated = self.gate(hidden)
        return self.drop(self.fc2(gated))


class ConvMlp(nn.Module):
    """MLP built from 1x1 convolutions, so the spatial dims are kept.

    Pipeline: fc1 -> norm -> activation -> dropout -> fc2. No dropout is
    applied after fc2.
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
        super().__init__()
        hidden_features = hidden_features if hidden_features else in_features
        out_features = out_features if out_features else in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
        if norm_layer:
            self.norm = norm_layer(hidden_features)
        else:
            self.norm = nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1, norm, activation, dropout and fc2."""
        hidden = self.act(self.norm(self.fc1(x)))
        return self.fc2(self.drop(hidden))


class FlashAttention(nn.Module):
    """Scaled dot product attention with softmax, computed by the flash-attn kernel.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      Passed through to the kernel unchanged (default: None).
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0). Only applied in training mode.
        device, dtype: accepted for interface compatibility; not used here.
    """
    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        # The import at the top of the file falls back to None when flash_attn is missing.
        assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, causal=False):
        """Run multihead softmax attention on packed, equal-length sequences.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
                 Must be a CUDA tensor in float16 or bfloat16.
            causal: forwarded to the kernel.

        Returns the kernel output with batch and sequence dims unflattened
        again, i.e. with leading dims (B, S).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        batch_size = qkv.shape[0]
        seqlen = qkv.shape[1]
        # The kernel is called on a flattened (B * S, ...) layout.
        packed_qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        # Sequence boundaries in the flattened layout: [0, S, 2S, ..., B * S].
        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                                  dtype=torch.int32, device=qkv.device)
        dropout_p = self.dropout_p if self.training else 0.0
        packed_out = flash_attn_unpadded_qkvpacked_func(
            packed_qkv, cu_seqlens, seqlen, dropout_p,
            softmax_scale=self.softmax_scale, causal=causal
        )
        return rearrange(packed_out, '(b s) ... -> b s ...', b=batch_size)


class SelfAttention(nn.Module):
    """Scaled dot product attention with softmax, written in plain PyTorch.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0). Only applied in training mode.
        device, dtype: accepted for interface compatibility; not used here.
    """
    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, causal=False):
        """Run multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if True, each position only attends to itself and earlier
                    positions.

        Returns a tensor of shape (B, S, H, D).
        """
        seqlen = qkv.shape[1]
        query, key, value = torch.unbind(qkv, dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
        logits = torch.einsum('bthd,bshd->bhts', query, key * scale)
        if causal:
            # triu is not implemented for BFloat16 on CUDA, so the mask is built
            # in the default float dtype and cast afterwards.
            future_penalty = torch.full((seqlen, seqlen), -10000.0, device=logits.device)
            future_penalty = torch.triu(future_penalty, 1)
            logits = logits + future_penalty.to(dtype=logits.dtype)
        probs = torch.softmax(logits, dim=-1, dtype=value.dtype)
        drop_p = self.dropout_p if self.training else 0.0
        probs = F.dropout(probs, drop_p)
        return torch.einsum('bhts,bshd->bthd', probs, value)


class MHA(nn.Module):
    """Multi-head self-attention with a single fused QKV projection.

    Arguments
    ---------
        embed_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads.
        bias: whether the QKV projection has a bias. The output projection
              always keeps the linear layer's default bias setting.
        batch_first: must be True; only (batch, seqlen, hidden_dim) input is supported.
        dropout: attention dropout rate, forwarded to the inner attention module.
        softmax_scale: forwarded to the inner attention module.
        causal: whether the inner attention is called with causal=True.
        fused_bias_fc: use FusedDenseTD instead of nn.Linear for the projections.
        use_flash_attn: use FlashAttention instead of SelfAttention.
        checkpointing: run the inner attention under activation checkpointing.
        bf16: forwarded to FusedDenseTD when fused_bias_fc is set.
        device, dtype: factory kwargs for the projections and the inner attention.
        **kwargs: accepted and ignored.

    Raises
    ------
        ImportError: if fused_bias_fc is set but FusedDenseTD could not be imported.
    """

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, dropout=0.0,
                 softmax_scale=None, causal=False, fused_bias_fc=False, use_flash_attn=False,
                 checkpointing=False, bf16=False, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.checkpointing = checkpointing

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if fused_bias_fc and FusedDenseTD is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else partial(FusedDenseTD, bf16=bf16)
        # One projection produces q, k and v at once; they are split apart in forward.
        self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        inner_attn_cls = FlashAttention if use_flash_attn else SelfAttention
        self.inner_attn = inner_attn_cls(softmax_scale=softmax_scale, attention_dropout=dropout,
                                         **factory_kwargs)
        # The output projection deliberately does not receive `bias`: it always
        # has a bias (for now).
        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

    def _inner_attention(self, qkv):
        """Call the inner attention with this module's causal flag (used for checkpointing)."""
        return self.inner_attn(qkv, causal=self.causal)

    def forward(self, x):
        """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)

        Returns the output projection applied to the merged attention heads.
        """
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        if not self.checkpointing:
            context = self.inner_attn(qkv, causal=self.causal)
        else:
            # Import the submodule explicitly: `import torch` alone does not load
            # `torch.utils.checkpoint` on every PyTorch version.
            from torch.utils.checkpoint import checkpoint
            context = checkpoint(self._inner_attention, qkv)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
