import copy
import math
from functools import partial
from collections import namedtuple
from typing import Optional, Union, Callable

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor

from torchvision.ops import StochasticDepth

from einops import rearrange, repeat

import hydra
import omegaconf

from timm.models.layers import PatchEmbed, lecun_normal_

from src.models.modules.masking import LengthMask
from src.models.modules.seq_common import ClassificationHead, PositionalEncoding, Mlp
from src.models.ssm.s4 import S4
from src.models.modules.seq_common import MHA

try:
    from src.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None


# Adapted from torch.nn.TransformerEncoderLayer:
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
class S4EncoderLayer(nn.Module):
    r"""One residual block: a sequence mixer followed by an optional MLP.

    The mixer is selected from ``layer_idx``:

    * ``attn_cls`` (``MHA`` when ``attn_cls`` is ``None``) if ``layer_idx`` is listed in
      ``attn_layer_idx``;
    * ``ssm_cls`` configured with ``ssm2_cfg`` if ``layer_idx`` is listed in ``ssm2_layer_idx``;
    * ``ssm_cls`` configured with ``ssm_cfg`` otherwise.

    Args:
        d_model: feature size handed to the mixer, the MLP and the LayerNorms.
        d_inner: hidden size handed to the MLP.
        ssm_cls: class used to build the SSM mixer.
        ssm_cfg: keyword arguments for the default SSM mixer (``None`` means none).
        ssm2_cfg: keyword arguments for the SSM mixer of layers in ``ssm2_layer_idx``.
        attn_cls: class used to build the attention mixer; ``MHA`` when ``None``.
        attn_cfg: keyword arguments for the attention mixer. ``'causal'`` defaults to
            ``True`` when absent. The mapping is copied and never modified.
        mlp_cls: class called as ``mlp_cls(d_model, d_inner, device=..., dtype=...)``;
            when ``None`` an ``Mlp`` using ``activation`` is built instead.
        resid_dropout: probability handed to ``dropout_cls`` for both residual branches.
        dropout_cls: dropout module class.
        drop_path: optional sequence indexed by ``layer_idx`` giving the
            ``StochasticDepth`` probability of this layer.
        activation: callable, or the string ``"relu"``/``"gelu"``. Only the default
            ``Mlp`` receives it.
        layer_norm_eps: epsilon of the LayerNorms.
        batch_first: must be truthy (asserted).
        norm_first: selects the pre-norm calling convention of :meth:`forward`.
        fused_dropout_add_ln: use ``dropout_add_layer_norm`` instead of the unfused
            dropout + add + LayerNorm sequence.
        layer_idx: index of this layer in the stack.
        ssm2_layer_idx: layer indices that use ``ssm2_cfg``.
        attn_layer_idx: layer indices that use attention instead of an SSM.
        device, dtype: forwarded to the LayerNorms and the MLP.

    Raises:
        ImportError: if ``fused_dropout_add_ln`` is requested but
            ``dropout_add_layer_norm`` could not be imported.
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model, d_inner=2048, ssm_cls=S4, ssm_cfg=None,
                 ssm2_cfg=None, attn_cls=None, attn_cfg=None,
                 mlp_cls=None, resid_dropout=0.1, dropout_cls=nn.Dropout,
                 drop_path=None, activation=F.gelu,
                 layer_norm_eps=1e-5, batch_first=False, norm_first=False, fused_dropout_add_ln=False,
                 layer_idx=None, ssm2_layer_idx=None, attn_layer_idx=None, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        assert batch_first
        super().__init__()
        self.norm_first = norm_first
        self.layer_idx = layer_idx
        self.drop_path = (StochasticDepth(drop_path[layer_idx], mode='row')
                          if drop_path is not None else nn.Identity())
        self.fused_dropout_add_ln = fused_dropout_add_ln
        if fused_dropout_add_ln:
            # Fail at construction time rather than with "'NoneType' is not callable"
            # in the first forward pass.
            if dropout_add_layer_norm is None:
                raise ImportError('fused_dropout_add_ln=True requires src.ops.layer_norm.'
                                  'dropout_add_layer_norm, which could not be imported')
            assert dropout_cls is nn.Dropout
            assert drop_path is None or drop_path[layer_idx] == 0.0, 'FusedDropoutAddLN does not support DropPath'

        if attn_layer_idx is not None and layer_idx in attn_layer_idx:
            # Work on a private copy: the same attn_cfg object is shared by every layer
            # of the stack, so popping from it directly would strip 'causal' for all
            # attention layers after the first one (and crash when attn_cfg is None).
            attn_kwargs = dict(attn_cfg) if attn_cfg is not None else {}
            causal = attn_kwargs.pop('causal', True)
            mixer_cls = attn_cls if attn_cls is not None else MHA
            self.mixer = mixer_cls(d_model, causal=causal, **attn_kwargs)
        elif ssm2_layer_idx is not None and layer_idx in ssm2_layer_idx:
            self.mixer = ssm_cls(d_model, **(ssm2_cfg if ssm2_cfg is not None else {}))
        else:
            self.mixer = ssm_cls(d_model, **(ssm_cfg if ssm_cfg is not None else {}))

        # Legacy string support for the activation function (validated in both cases).
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        if mlp_cls is None:
            self.mlp = Mlp(d_model, d_inner, act_fn=activation, **factory_kwargs)
        else:
            # A custom MLP class is not handed the activation.
            self.mlp = mlp_cls(d_model, d_inner, **factory_kwargs)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = dropout_cls(resid_dropout)
        # An identity MLP turns the block into a mixer-only layer without a second norm.
        if not isinstance(self.mlp, nn.Identity):
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.dropout2 = dropout_cls(resid_dropout)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.gelu
        super(S4EncoderLayer, self).__setstate__(state)

    def _residual_add_norm(self, branch_out, residual, norm, dropout, prenorm):
        """Apply dropout (+ drop-path) to ``branch_out``, add ``residual`` and LayerNorm.

        Returns ``(normed, summed)`` when ``prenorm`` is true and ``normed`` otherwise.
        """
        if self.fused_dropout_add_ln:
            return dropout_add_layer_norm(
                branch_out, residual, norm.weight, norm.bias,
                dropout.p if self.training else 0.0, norm.eps, prenorm=prenorm
            )
        summed = self.drop_path(dropout(branch_out)) + residual
        normed = norm(summed.to(dtype=norm.weight.dtype))
        return (normed, summed) if prenorm else normed

    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, **kwargs) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: must be ``None`` for post-norm; for pre-norm it is the running
                residual stream, with ``hidden_states = LayerNorm(residual)``.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            ``(hidden_states, residual)`` when ``norm_first`` is set, otherwise just the
            normalized ``hidden_states``.
        """
        has_mlp = not isinstance(self.mlp, nn.Identity)
        if self.norm_first:
            assert residual is not None
            hidden_states, residual = self._residual_add_norm(
                self.mixer(hidden_states), residual, self.norm1, self.dropout1, prenorm=True)
            if has_mlp:
                hidden_states, residual = self._residual_add_norm(
                    self.mlp(hidden_states), residual, self.norm2, self.dropout2, prenorm=True)
            return hidden_states, residual

        assert residual is None
        hidden_states = self._residual_add_norm(
            self.mixer(hidden_states), hidden_states, self.norm1, self.dropout1, prenorm=False)
        if has_mlp:
            hidden_states = self._residual_add_norm(
                self.mlp(hidden_states), hidden_states, self.norm2, self.dropout2, prenorm=False)
        return hidden_states


class S4Encoder(nn.Module):
    r"""Stack of ``num_layers`` encoder layers with an optional LayerNorm.

    Args:
        encoder_layer: factory called as ``encoder_layer(norm_first=..., layer_idx=i)``
            once per layer (required).
        num_layers: number of layers to build (required).
        norm_first: selects the pre-norm data flow in :meth:`forward`.
        norm: optional normalization module. Applied after the last layer in post-norm
            mode and before the first layer in pre-norm mode.
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm_first=False, norm=None):
        super().__init__()
        self.norm_first = norm_first
        stack = nn.ModuleList()
        for idx in range(num_layers):
            stack.append(encoder_layer(norm_first=norm_first, layer_idx=idx))
        self.layers = stack

        self.num_layers = num_layers
        self.norm = norm

    def forward(self, hidden_states: Tensor, **kwargs) -> Tensor:
        r"""Run the input through every layer in order.

        Args:
            hidden_states: the sequence to the encoder (required).
            **kwargs: forwarded unchanged to each layer.

        Returns:
            The hidden states produced by the last layer (followed by ``norm`` in
            post-norm mode when a norm was given).
        """
        if not self.norm_first:
            # Post-norm: layers map hidden states to hidden states, norm goes last.
            for block in self.layers:
                hidden_states = block(hidden_states, **kwargs)
            if self.norm is None:
                return hidden_states
            return self.norm(hidden_states)

        # Pre-norm: the un-normalized input seeds the residual stream and the optional
        # norm produces the first normalized hidden states.
        residual = hidden_states
        if self.norm is not None:
            hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
        for block in self.layers:
            hidden_states, residual = block(hidden_states, residual, **kwargs)
        return hidden_states


class S4Sequence(nn.Module):
    r"""Encoder-only sequence model: an :class:`S4Encoder` built from
    :class:`S4EncoderLayer` blocks.

    Args:
        d_model: feature size of the inputs and of every layer (default=512).
        n_layer: number of encoder layers (default=6).
        d_inner: hidden size of the per-layer MLP (default=2048).
        ssm_cfg: keyword arguments for the SSM mixer of each layer.
        mlp_cls: optional MLP class handed to each layer.
        resid_dropout: residual dropout probability of each layer (default=0.1).
        dropout_cls: dropout module class handed to each layer.
        activation: activation of the per-layer MLP, either a string
            ("relu" or "gelu") or a unary callable. Default: gelu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: stored on the module and handed to each layer.
        norm_first: selects the pre-norm variant of the encoder and its layers.
        encoder_final_norm: if ``True``, a ``LayerNorm(d_model)`` is handed to the
            encoder as its ``norm``.
        drop_path: optional per-layer stochastic-depth rates handed to each layer.
        device, dtype: used for the encoder-level LayerNorm only.
        **kwargs: extra keyword arguments forwarded to every :class:`S4EncoderLayer`.
    """

    def __init__(self, d_model: int = 512, n_layer: int = 6, d_inner: int = 2048,
                 ssm_cfg=None, mlp_cls=None, resid_dropout: float = 0.1, dropout_cls=nn.Dropout,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.gelu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 encoder_final_norm: bool = True, drop_path=None,
                 device=None, dtype=None, **kwargs) -> None:
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.d_model = d_model
        self.batch_first = batch_first

        # The encoder fills in ``norm_first`` and ``layer_idx`` when it calls the factory.
        layer_options = dict(
            d_inner=d_inner,
            ssm_cfg=ssm_cfg,
            mlp_cls=mlp_cls,
            resid_dropout=resid_dropout,
            dropout_cls=dropout_cls,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            drop_path=drop_path,
        )
        layer_factory = partial(S4EncoderLayer, d_model, **layer_options, **kwargs)

        final_norm = None
        if encoder_final_norm:
            final_norm = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.encoder = S4Encoder(layer_factory, n_layer, norm_first=norm_first, norm=final_norm)

    def forward(self, src: Tensor, **kwargs) -> Tensor:
        r"""Encode ``src``.

        Args:
            src: the sequence to the encoder (required).
            **kwargs: accepted for interface compatibility; not passed to the encoder.

        Returns:
            The output of the encoder for ``src``.
        """
        return self.encoder(src)


class S4wHead(nn.Module):
    """:class:`S4Sequence` wrapped with an optional input embedding and output head.

    Args:
        d_model, n_layer, d_inner, ssm_cfg, mlp_cls, resid_dropout, dropout_cls,
        activation, layer_norm_eps, batch_first, norm_first, encoder_final_norm:
            forwarded to :class:`S4Sequence`.
        embedding_cls: called as ``embedding_cls(d_model)``; ``nn.Identity`` when ``None``.
        head_cls: called as ``head_cls(d_model, batch_first=batch_first)``;
            ``nn.Identity`` when ``None``.
        **kwargs: extra keyword arguments forwarded to :class:`S4Sequence`.
    """

    def __init__(self, d_model: int, n_layer: int, d_inner: int, ssm_cfg=None,
                 mlp_cls=None, embedding_cls=None, head_cls=None,
                 norm_first=False, resid_dropout: float = 0.1, dropout_cls=nn.Dropout,
                 activation: str = "gelu", layer_norm_eps: float = 1e-5,
                 batch_first: bool = False, encoder_final_norm: bool = False, **kwargs) -> None:
        super().__init__()
        self.embedding = embedding_cls(d_model) if embedding_cls is not None else nn.Identity()
        self.batch_first = batch_first
        self.s4seq = S4Sequence(d_model, n_layer, d_inner, ssm_cfg, mlp_cls, resid_dropout, dropout_cls,
                                activation, layer_norm_eps, batch_first, norm_first,
                                encoder_final_norm, **kwargs)
        self.head = (head_cls(d_model, batch_first=batch_first)
                     if head_cls is not None else nn.Identity())
        self.tie_weights()

    def tie_weights(self):
        """Hook for subclasses that share parameters; a no-op here."""
        pass

    def forward_features(self, src: Tensor, lengths=None, **kwargs) -> tuple:
        """Embed ``src`` and run it through the sequence model.

        Args:
            src: input handed to the embedding.
            lengths: optional sequence lengths; when given, a ``LengthMask`` covering
                the sequence dimension of ``src`` is built from them.
            **kwargs: forwarded to the sequence model.

        Returns:
            ``(features, src_key_padding_mask)``; the mask is ``None`` when ``lengths``
            is ``None``.
        """
        if lengths is not None:
            # The sequence axis is 1 for batch-first inputs and 0 otherwise; it is read
            # from the raw input, before the embedding is applied.
            src_key_padding_mask = LengthMask(lengths,
                                              max_len=src.size(1 if self.batch_first else 0),
                                              device=src.device)
        else:
            src_key_padding_mask = None
        src = self.embedding(src)
        features = self.s4seq(src, **kwargs)
        return features, src_key_padding_mask

    def forward(self, src: Tensor, lengths=None, **kwargs) -> Tensor:
        """Return the head output for ``src``.

        When no ``head_cls`` was given, the features are returned unchanged.
        """
        features, src_key_padding_mask = self.forward_features(src, lengths=lengths, **kwargs)
        if isinstance(self.head, nn.Identity):
            # nn.Identity.forward only takes the input tensor; handing it the
            # key_padding_mask keyword would raise a TypeError.
            return self.head(features)
        return self.head(features, key_padding_mask=src_key_padding_mask)


class S4DualHeads(S4wHead):
    """Variant of :class:`S4wHead` whose head consumes two encoded sequences.

    Both inputs go through the same ``forward_features`` of this module.
    """

    def forward(self, src1: Tensor, src2: Tensor,
                lengths1=None, lengths2=None,
                **kwargs) -> Tensor:
        """Encode ``src1`` then ``src2`` and hand both results to the head.

        The head is called as ``head(features1, features2, key_padding_mask1=...,
        key_padding_mask2=...)``.
        """
        first = self.forward_features(src1, lengths=lengths1, **kwargs)
        second = self.forward_features(src2, lengths=lengths2, **kwargs)
        feats_a, mask_a = first
        feats_b, mask_b = second
        return self.head(feats_a, feats_b,
                         key_padding_mask1=mask_a,
                         key_padding_mask2=mask_b)

def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True,
                  glu_act=True):
    """GPT-2 style initialization, meant to be used through ``nn.Module.apply``.

    ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
    ``N(0, initializer_range)`` and Linear biases are zeroed. When
    ``rescale_prenorm_residual`` is set, parameters of ``module`` named
    ``out_proj.weight``, ``fc2.weight`` or ``output_linear.0.weight`` are re-drawn with
    the standard deviation divided by ``sqrt(2 * n_layer)``.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, std=initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    if not rescale_prenorm_residual:
        return

    # GPT-2 scheme: scale the weights of residual layers at initialization by 1/sqrt(N),
    # N being the number of residual layers (two per block, hence 2 * n_layer).
    #   -- GPT-2 :: https://openai.com/blog/better-language-models/
    # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
    residual_projections = ("out_proj.weight", "fc2.weight")
    glu_projection = "output_linear.0.weight"
    for name, p in module.named_parameters():
        feeds_glu = name == glu_projection
        if not feeds_glu and name not in residual_projections:
            continue
        scaled_std = initializer_range / math.sqrt(2 * n_layer)
        if feeds_glu and glu_act:
            # Only the first half of the output rows is re-drawn, with twice the std,
            # since sigmoid scales it down by 0.5 on average.
            half = p.shape[0] // 2
            nn.init.normal_(p[:half], mean=0.0, std=scaled_std * 2)
        else:
            nn.init.normal_(p, mean=0.0, std=scaled_std)


class S4LM(nn.Module):
    """Language model: token embedding -> :class:`S4Sequence` -> tied linear LM head.

    Args:
        d_model, n_layer, d_inner, ssm_cfg, mlp_cls, resid_dropout, dropout_cls,
        activation, layer_norm_eps, batch_first, norm_first:
            forwarded to :class:`S4Sequence`.
        vocab_size: number of tokens; rounded up to a multiple of 8 when
            ``pad_vocab_size_multiple_8`` is set.
        pos_emb_cls: optional class called as ``pos_emb_cls(d_model)`` whose output is
            added to the token embeddings.
        embed_dropout: dropout probability applied to the embeddings.
        initializer_cfg: extra keyword arguments for ``_init_weights``.
        **kwargs: extra keyword arguments forwarded to :class:`S4Sequence`.
    """

    # Built once so that every forward call returns the same output type instead of
    # creating a new namedtuple class per call.
    CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])

    def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, ssm_cfg=None,
                 mlp_cls=None, pos_emb_cls=None, norm_first=False, resid_dropout: float = 0.1,
                 embed_dropout: float = 0.1, dropout_cls=nn.Dropout, activation: str = "gelu",
                 layer_norm_eps: float = 1e-5, batch_first: bool = False,
                 initializer_cfg=None, pad_vocab_size_multiple_8: bool = False, **kwargs) -> None:
        super().__init__()
        self.batch_first = batch_first
        if pad_vocab_size_multiple_8:
            # Round up to the next multiple of 8 (adds 0 when already aligned).
            vocab_size += -vocab_size % 8
        self.embedding = torch.nn.Embedding(vocab_size, d_model)
        if pos_emb_cls is not None:
            self.pos_embedding = pos_emb_cls(d_model)
        else:
            self.register_parameter('pos_embedding', None)
        self.drop = dropout_cls(embed_dropout)
        # encoder_final_norm=True to match GPT2 model
        self.s4seq = S4Sequence(d_model, n_layer, d_inner, ssm_cfg, mlp_cls, resid_dropout, dropout_cls,
                                activation, layer_norm_eps, batch_first, norm_first,
                                encoder_final_norm=True, **kwargs)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)  # Following GPT2, bias=False
        # Without custom weight init, the loss starts out very high (~395 instead of ~11)
        self.apply(partial(_init_weights, n_layer=n_layer,
                           **(initializer_cfg if initializer_cfg is not None else {})))
        self.tie_weights()

    def tie_weights(self):
        """Make the LM head share its weight matrix with the token embedding."""
        self.lm_head.weight = self.embedding.weight

    def forward_features(self, src: Tensor) -> Tensor:
        """Return the sequence-model output for the token ids ``src``."""
        hidden_states = self.embedding(src)
        if self.pos_embedding is not None:
            position_ids = torch.arange(src.shape[-1], dtype=torch.long, device=src.device)
            hidden_states = hidden_states + self.pos_embedding(position_ids)
        hidden_states = self.drop(hidden_states).float()  # Force residual in fp32
        return self.s4seq(hidden_states)

    def forward(self, src: Tensor) -> Tensor:
        """Return a ``CausalLMOutput`` namedtuple whose ``logits`` field holds the LM-head output."""
        return self.CausalLMOutput(self.lm_head(self.forward_features(src)))

    def generate(self, src: Tensor, mode: str = 'slow', max_tokens: int = 10,
                 end_tokens: Optional[list] = None, **kwargs) -> Tensor:
        """Greedy decoding that re-encodes the whole sequence for every new token.

        Args:
            src: prompt token ids with a leading batch dimension of exactly 1.
            mode: unused.
            max_tokens: maximum number of tokens to generate.
            end_tokens: token ids that stop generation; the stop token itself is
                included in the result. ``None`` means no stop tokens.
            **kwargs: unused.

        Returns:
            The generated token ids with shape ``(1, n_generated)``; an empty
            ``(1, 0)`` tensor when ``max_tokens`` is not positive.
        """
        assert src.shape[0] == 1
        stop_ids = end_tokens if end_tokens is not None else []
        tokens = src
        generated = []
        # Decoding only needs forward passes; do not record autograd graphs.
        with torch.no_grad():
            while len(generated) < max_tokens:
                logits = self.lm_head(self.forward_features(tokens))[:, -1]
                next_token = torch.argmax(logits, dim=-1)
                generated.append(next_token)
                if next_token.item() in stop_ids:
                    break
                tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=-1)
        if not generated:
            # torch.cat rejects an empty list.
            return torch.empty((1, 0), dtype=torch.long, device=src.device)
        return torch.cat(generated).unsqueeze(0)


def _init_vit_weights(m, n: str = '', head_bias: float = 0.):
    """ViT weight initialization for a single module.

    Only ``nn.Linear`` modules are touched:

    * if the module name ``n`` starts with ``'head'``, the weight is zeroed and the
      bias (when present) is set to ``head_bias``;
    * otherwise the bias (when present) is zeroed and the weight is drawn from a
      truncated normal with ``std=0.02``.

    Args:
        m: module to initialize.
        n: name of the module; stays ``''`` when the function is used through
            ``nn.Module.apply``, which passes the module only.
        head_bias: constant used for the bias of head layers.
    """
    if not isinstance(m, nn.Linear):
        return
    if n.startswith('head'):
        nn.init.zeros_(m.weight)
        # A head built with bias=False has no bias tensor to fill.
        if m.bias is not None:
            nn.init.constant_(m.bias, head_bias)
        return
    if m.bias is not None:
        nn.init.zeros_(m.bias)
    nn.init.trunc_normal_(m.weight, std=.02)


class S4ViT(nn.Module):
    """Image classifier: patch embedding, a prepended class token, a batch-first
    :class:`S4Sequence` and a linear head applied to the class-token position.

    Args:
        d_model, n_layer, d_inner, ssm_cfg, mlp_cls, norm_first, resid_dropout,
        dropout_cls, activation, layer_norm_eps:
            forwarded to :class:`S4Sequence`.
        img_size, patch_size, in_chans: forwarded to ``PatchEmbed``.
        num_classes: output size of the classification head.
        embed_dropout: dropout probability applied after the class token is prepended.
        drop_path_rate: stochastic-depth rate of the last layer; per-layer rates are
            spaced linearly from 0 up to this value.
        **kwargs: extra keyword arguments forwarded to :class:`S4Sequence`.
    """

    def __init__(self, d_model: int, n_layer: int, d_inner: int, img_size=224, patch_size=16,
                 in_chans=3, num_classes=1000, ssm_cfg=None,
                 mlp_cls=None, norm_first=True, resid_dropout: float = 0.0,
                 embed_dropout: float = 0.0, dropout_cls=nn.Dropout,
                 drop_path_rate=0.0, activation: str = "gelu",
                 layer_norm_eps: float = 1e-5, **kwargs) -> None:
        super().__init__()
        self.embedding = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                                    embed_dim=d_model)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.drop = dropout_cls(embed_dropout)
        # Stochastic depth decay rule: one rate per layer, linearly increasing.
        layer_drop_rates = torch.linspace(0, drop_path_rate, n_layer).tolist()
        self.s4seq = S4Sequence(
            d_model=d_model,
            n_layer=n_layer,
            d_inner=d_inner,
            ssm_cfg=ssm_cfg,
            mlp_cls=mlp_cls,
            resid_dropout=resid_dropout,
            dropout_cls=dropout_cls,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=True,
            norm_first=norm_first,
            encoder_final_norm=True,
            drop_path=layer_drop_rates,
            **kwargs,
        )
        self.head = nn.Linear(d_model, num_classes)
        # Weight init. NOTE: ``apply`` hands over modules without their names, so
        # ``_init_vit_weights`` runs with its default ``n=''`` for every module.
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(_init_vit_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters reported as exempt from weight decay."""
        return {'cls_token'}

    def forward_features(self, x: Tensor) -> Tensor:
        """Return the sequence-model output at index 0 of dim 1 (the class token)."""
        patches = self.embedding(x)
        # One copy of the class token per batch element, placed in front of the patches.
        cls_tokens = self.cls_token.expand(patches.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)
        tokens = self.drop(tokens).float()  # Force residual in fp32
        encoded = self.s4seq(tokens)
        return encoded[:, 0]

    def forward(self, src: Tensor) -> Tensor:
        """Return the classification-head output for the images ``src``."""
        return self.head(self.forward_features(src))


class S4LMManifest(nn.Module):
    '''S4LM wrapper for Manifest.

    Assumes that config has the following parameters:
    {
        tokenizer_name: str,
        tokenizer_params: dict (optional),
        model_config: dict (model config),
        model_name: str (optional),
    }

    The model is instantiated with hydra from ``model_config``, moved to CUDA and put
    in eval mode. If ``weights_path`` is given, its ``'state_dict'`` entry is loaded
    strictly after a leading ``'model.'`` is removed from every key.
    '''
    def __init__(self, config_path: str, weights_path: str = None):
        super().__init__()
        self.config = omegaconf.OmegaConf.load(config_path)
        # model_name is optional; attribute access on a missing key can raise,
        # depending on the OmegaConf version.
        self.model_name = self.config.get('model_name', None)
        self.model = hydra.utils.instantiate(self.config.model_config).cuda()
        self.model.eval()
        if self.config.tokenizer_name == 'grokking':
            from src.datamodules.grokking import GrokkingDataModule
            tokenizer_params = self.config.get('tokenizer_params', None) or {}
            self.tokenizer = GrokkingDataModule(**tokenizer_params).tokenizer
        else:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.tokenizer_name,
                model_max_length=8192,
                truncation_side='left'
            )
        if weights_path is not None:
            # Load onto the CPU first so the checkpoint does not depend on the device
            # it was saved from; load_state_dict copies into the CUDA parameters.
            # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
            checkpoint = torch.load(weights_path, map_location='cpu')
            weights = self._strip_model_prefix(checkpoint['state_dict'])
            self.model.load_state_dict(weights, strict=True)

    @staticmethod
    def _strip_model_prefix(state_dict, prefix: str = 'model.'):
        """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

        Only the start of a key is touched, so later occurrences of ``prefix`` inside
        a key are preserved.
        """
        return {(key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()}

    def get_model_name(self):
        """Return the configured model name, or ``'ZooModel'`` when it is unset/empty."""
        return self.model_name or 'ZooModel'

    def generate(
        self, prompt: str,
        model_config: str = None,  # passed from manifest, ignore
        model_name: str = None,  # passed from manifest, ignore
        model_path: str = None,  # passed from manifest, ignore
        stop_at_newline: bool = True,
        **kwargs
    ) -> list:
        """Generate a continuation of ``prompt``.

        Args:
            prompt: text to continue.
            model_config, model_name, model_path: accepted and ignored.
            stop_at_newline: for non-grokking tokenizers, also stop at the first token
                of the encoding of ``'\\n'``.
            **kwargs: forwarded to ``self.model.generate``.

        Returns:
            A one-element list holding the decoded generation.
        """
        if self.config.tokenizer_name == 'grokking':
            # Build the ids as integers directly instead of round-tripping via float32.
            input_ids = torch.as_tensor(
                self.tokenizer.tokenize(prompt)['input_ids'], dtype=torch.long
            ).unsqueeze(0).cuda()
            with torch.no_grad():
                model_response = self.model.generate(input_ids, **kwargs)[0]
            return [self.tokenizer.decode(model_response)]

        end_tokens = [self.tokenizer.eos_token_id]
        if stop_at_newline:
            end_tokens.append(self.tokenizer.encode('\n')[0])
        input_ids = torch.tensor(self.tokenizer.encode(prompt)).unsqueeze(0).cuda()
        with torch.no_grad():
            model_response = self.model.generate(input_ids, end_tokens=end_tokens, **kwargs)[0]
        return [self.tokenizer.decode(model_response)]

    def forward(
        self, src: Tensor
    ) -> Tensor:
        """Run the wrapped model on ``src``."""
        # Go through __call__ so that module hooks on the wrapped model are honoured.
        return self.model(src)

def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones


def _get_activation_fn(activation):
    """Map the legacy activation names ``"relu"``/``"gelu"`` to their functions.

    Raises:
        RuntimeError: for any other value.
    """
    known = (("relu", F.relu), ("gelu", F.gelu))
    for label, fn in known:
        if activation == label:
            return fn
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
