from typing import Optional, Any
import math

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules import MultiheadAttention, Linear, Dropout, BatchNorm1d, TransformerEncoderLayer
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from models.TransformerLayer import Encoder, EncoderLayer, LunaEncoder
from models.Attention import AttentionLayer, MrsAttentionLayer, Flow_Attention, MRA_head_Attention, \
    MRA2_Attention, Linear_transfomrer_Attention, FMM_transfomrer_Attention, \
    Performer_Attention, Linformer_Attention, LunaEncoderLayer, MrsLunaEncoderLayer


# From https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class FixedPositionalEncoding(nn.Module):
    r"""Add fixed sinusoidal position information to a sequence of embeddings.

    The encoding has the same width as the embeddings so that the two can be
    summed. Sine and cosine functions of different frequencies are used:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})

    where ``pos`` is the position in the sequence and ``i`` indexes pairs of
    embedding columns.

    Args:
        d_model: the embed dim (required). Both even and odd values are supported.
        dropout: the dropout value applied after the sum (default=0.1).
        max_len: the max. length of the incoming sequence (default=1024).
        scale_factor: constant the whole encoding table is multiplied by (default=1.0).
    """

    def __init__(self, d_model, dropout=0.1, max_len=1024, scale_factor=1.0):
        super(FixedPositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # positional encoding
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so the
        # cosine half must be trimmed; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = scale_factor * pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)  # this stores the variable in the state_dict (used for non-trainable variables)

    def forward(self, x):
        r"""Add the positional encoding to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        # The table's singleton middle axis broadcasts over the batch dimension.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class LearnablePositionalEncoding(nn.Module):
    r"""Add a trainable per-position embedding to a sequence of embeddings.

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value applied after the sum (default=0.1).
        max_len: the max. length of the incoming sequence (default=1024).
    """

    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # One embedding row per position. Positions are always 0 ... max_len,
        # so the table is sliced directly instead of going through a look-up.
        table = torch.empty(max_len, 1, d_model)
        nn.init.uniform_(table, -0.02, 0.02)
        self.pe = nn.Parameter(table)  # requires_grad defaults to True

    def forward(self, x):
        r"""Add the learned positional embedding to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """
        seq_len = x.size(0)
        with_position = x + self.pe[:seq_len]
        return self.dropout(with_position)


def get_pos_encoder(pos_encoding):
    """Return the positional-encoding class selected by ``pos_encoding``.

    Args:
        pos_encoding: ``"learnable"`` or ``"fixed"``.

    Returns:
        ``LearnablePositionalEncoding`` or ``FixedPositionalEncoding`` (the
        class itself, not an instance).

    Raises:
        NotImplementedError: for any other value.
    """
    if pos_encoding not in ("learnable", "fixed"):
        raise NotImplementedError("pos_encoding should be 'learnable'/'fixed', not '{}'".format(pos_encoding))

    return LearnablePositionalEncoding if pos_encoding == "learnable" else FixedPositionalEncoding


def _get_activation_fn(activation):
    """Map an activation name to the matching ``torch.nn.functional`` callable.

    Args:
        activation: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu``.

    Raises:
        ValueError: for any other value.
    """
    known = (("relu", F.relu), ("gelu", F.gelu))
    for label, fn in known:
        if activation == label:
            return fn
    raise ValueError("activation should be relu/gelu, not {}".format(activation))


class Classiregressor(nn.Module):
    """Sequence classifier/regressor built on a selectable attention encoder.

    Each time step is projected to ``d_model``, a positional encoding is added,
    the sequence is passed through the encoder chosen by ``model_type``, and the
    activated, padding-masked encoder output of all ``max_len`` time steps is
    flattened and mapped to ``num_classes`` values by a single linear layer.

    Args:
        model_type: which encoder to build. One of ``"flowformer"``,
            ``"mra_head"``, ``"MRA2_sparse"``, ``"MRA2_full"``,
            ``"linear_transformer"``, ``"linformer"``, ``"fmm_transformer"``,
            ``"performer"``, ``"luna_transformer"``, or the ``"*_mra_head"``
            variants ``"linear_transformer_mra_head"``, ``"linformer_mra_head"``,
            ``"fmm_transformer_mra_head"``, ``"performer_mra_head"``,
            ``"luna_transformer_mra_head"``.
        feat_dim: number of input features per time step.
        max_len: sequence length. The output layer is sized
            ``d_model * max_len``, so inputs to ``forward`` must have exactly
            this many time steps.
        d_model: width of the encoder embeddings.
        n_heads: number of attention heads.
        num_layers: number of stacked encoder layers.
        d_ff: feed-forward width handed to every encoder layer.
        num_classes: size of the output vector.
        dropout: dropout rate used in the encoder, the attention modules and
            before the output layer.
        pos_encoding: ``"fixed"`` or ``"learnable"``.
        activation: ``"relu"`` or ``"gelu"``; used inside the encoder layers and
            on the encoder output.
        norm: accepted for interface compatibility; not used by this class.
        freeze: when truthy, dropout inside the positional encoder is disabled
            (its rate is multiplied by ``1 - freeze``).
        factor: accepted for interface compatibility; not used by this class.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """

    def __init__(self, model_type, feat_dim, max_len, d_model=512,
                 n_heads=8, num_layers=3, d_ff=512, num_classes=100,
                 dropout=0.0, pos_encoding='fixed', activation='gelu', norm='BatchNorm', freeze=False, factor=5):
        super(Classiregressor, self).__init__()

        # Resolve the encoder recipe before allocating anything, so that a
        # misspelled model_type fails here instead of much later in forward().
        build_encoder = self._encoder_factory(model_type, d_model, n_heads, num_layers,
                                              d_ff, max_len, dropout, activation)

        self.max_len = max_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.project_inp = nn.Linear(feat_dim, d_model)
        self.pos_enc = get_pos_encoder(pos_encoding)(d_model, dropout=dropout * (1.0 - freeze), max_len=max_len)
        self.encoder = build_encoder()

        self.act = _get_activation_fn(activation)
        self.dropout = nn.Dropout(dropout)

        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.output_layer = self.build_output_module(d_model, max_len, num_classes)

    @staticmethod
    def _encoder_factory(model_type, d_model, n_heads, num_layers, d_ff, max_len, dropout, activation):
        """Return a zero-argument callable that builds the encoder for ``model_type``.

        Nothing is constructed until the returned callable is invoked.

        Raises:
            ValueError: if ``model_type`` is not a supported name.
        """
        # model_type -> builder of ONE self-attention module (inner attention
        # wrapped in its attention layer). Lambdas keep construction lazy, so
        # only the selected recipe is ever instantiated, once per layer.
        attention_builders = {
            "flowformer": lambda: AttentionLayer(
                Flow_Attention(attention_dropout=dropout),
                d_model, n_heads),
            "mra_head": lambda: AttentionLayer(
                MRA_head_Attention(attention_dropout=dropout),
                d_model, n_heads),
            "MRA2_sparse": lambda: AttentionLayer(
                MRA2_Attention(attention_dropout=dropout, mode='sparse'),
                d_model, n_heads),
            "MRA2_full": lambda: AttentionLayer(
                MRA2_Attention(attention_dropout=dropout, mode='full'),
                d_model, n_heads),
            "linear_transformer": lambda: AttentionLayer(
                Linear_transfomrer_Attention(attention_dropout=dropout),
                d_model, n_heads),
            "linformer": lambda: AttentionLayer(
                Linformer_Attention(attention_dropout=dropout, num_head=n_heads, head_dim=d_model // n_heads,
                                    linformer_k=256, max_seq_len=max_len),
                d_model, n_heads),
            "fmm_transformer": lambda: AttentionLayer(
                FMM_transfomrer_Attention(attention_dropout=dropout, head_dim=d_model // n_heads,
                                          diag_size=5, num_head=n_heads, kernels=["elu", "elu_flip"],
                                          sparse_ratio=4.5),
                d_model, n_heads),
            "performer": lambda: AttentionLayer(
                Performer_Attention(attention_dropout=dropout, head_dim=d_model // n_heads, rp_dim=256),
                d_model, n_heads),
            # "*_mra_head" variants: same inner attention inside MrsAttentionLayer.
            # Where the inner attention takes num_head, these variants pass 1.
            "linear_transformer_mra_head": lambda: MrsAttentionLayer(
                Linear_transfomrer_Attention(attention_dropout=dropout),
                d_model, n_heads),
            "linformer_mra_head": lambda: MrsAttentionLayer(
                Linformer_Attention(attention_dropout=dropout, num_head=1, head_dim=d_model // n_heads,
                                    linformer_k=256, max_seq_len=max_len),
                d_model, n_heads),
            "fmm_transformer_mra_head": lambda: MrsAttentionLayer(
                FMM_transfomrer_Attention(attention_dropout=dropout, head_dim=d_model // n_heads,
                                          diag_size=5, num_head=1, kernels=["elu", "elu_flip"],
                                          sparse_ratio=4.5),
                d_model, n_heads),
            "performer_mra_head": lambda: MrsAttentionLayer(
                Performer_Attention(attention_dropout=dropout, head_dim=d_model // n_heads, rp_dim=256),
                d_model, n_heads),
        }
        # model_type -> builder of ONE complete Luna encoder layer.
        luna_layer_builders = {
            "luna_transformer": lambda: LunaEncoderLayer(
                n_heads=n_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, activation=activation),
            "luna_transformer_mra_head": lambda: MrsLunaEncoderLayer(
                n_heads=n_heads, d_model=d_model, d_ff=d_ff, dropout=dropout, activation=activation),
        }

        if model_type in attention_builders:
            make_attention = attention_builders[model_type]

            def build_standard_encoder():
                layers = [
                    EncoderLayer(make_attention(), d_model, d_ff, dropout=dropout, activation=activation)
                    for _ in range(num_layers)
                ]
                return Encoder(layers, norm_layer=torch.nn.LayerNorm(d_model))

            return build_standard_encoder

        if model_type in luna_layer_builders:
            make_layer = luna_layer_builders[model_type]

            def build_luna_encoder():
                layers = [make_layer() for _ in range(num_layers)]
                return LunaEncoder(
                    layers,
                    d_model=d_model,
                    drop_out=dropout,
                    project_embedding_length=256,
                    norm_layer=torch.nn.LayerNorm(d_model)
                )

            return build_luna_encoder

        supported = sorted(list(attention_builders) + list(luna_layer_builders))
        raise ValueError("model_type should be one of {}, not '{}'".format(supported, model_type))

    def build_output_module(self, d_model, max_len, num_classes):
        """Create the linear head mapping the flattened sequence to ``num_classes``.

        Args:
            d_model: embedding width of each time step.
            max_len: number of time steps that get flattened together.
            num_classes: number of output values.

        Returns:
            ``nn.Linear(d_model * max_len, num_classes)``.
        """
        output_layer = nn.Linear(d_model * max_len, num_classes)
        # no softmax (or log softmax), because CrossEntropyLoss does this internally. If probabilities are needed,
        # add F.log_softmax and use NLLoss
        return output_layer

    def forward(self, x_enc, padding_masks):
        """Compute one output vector per input sequence.

        Args:
            x_enc: input of shape (batch_size, seq_length, feat_dim);
                ``seq_length`` must equal ``max_len`` for the flattened features
                to fit the output layer.
            padding_masks: (batch_size, seq_length) tensor multiplied into the
                embeddings; entries equal to 0 zero-out the corresponding time
                steps (padding).

        Returns:
            Tensor of shape (batch_size, num_classes).
        """
        # The positional encoders are sequence-first, so move to
        # [seq_length, batch_size, feat_dim] for projection + encoding.
        inp = x_enc.permute(1, 0, 2)
        inp = self.project_inp(inp) * math.sqrt(self.d_model)  # project input vectors to d_model dimensional space
        inp = self.pos_enc(inp)  # add positional encoding
        inp = inp.permute(1, 0, 2)  # back to [batch_size, seq_length, d_model]

        enc_out = self.encoder(inp)
        if isinstance(enc_out, tuple):  # Get unpacked Luna output
            enc_out = enc_out[0]

        # The encoder embeddings don't include a final non-linearity. The
        # activation is element-wise, so no axis shuffling is needed around it.
        output = self.dropout(self.act(enc_out))  # (batch_size, seq_length, d_model)

        # Output
        output = output * padding_masks.unsqueeze(-1)  # zero-out padding embeddings
        output = output.reshape(output.shape[0], -1)  # (batch_size, seq_length * d_model)
        output = self.output_layer(output)  # (batch_size, num_classes)

        return output
