import math
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
The Transformer Code is mainly based on the following sources:
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html
https://github.com/gzerveas/mvts_transformer
"""

def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor whose last dimension is the per-head feature size.
        k: Key tensor; its last two dimensions are transposed and multiplied
            with ``q``.
        v: Value tensor, combined using the attention weights.
        mask: Optional tensor broadcastable to the attention logits. Entries
            equal to 0 mark positions that must not be attended to.

    Returns:
        Tuple ``(values, attention)``: the attention-weighted combination of
        ``v`` and the softmax-normalised attention weights (over the last
        dimension).
    """
    d_k = q.size(-1)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Fill with the most negative finite value of the logits' dtype.
        # A fixed constant such as -9e15 is not representable in float16,
        # so it would overflow when the model runs in half precision.
        fill_value = torch.finfo(attn_logits.dtype).min
        attn_logits = attn_logits.masked_fill(mask == 0, fill_value)

    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class MultiheadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        input_dim: Size of the last dimension of the input.
        embed_dim: Total width over all heads; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        seed: Seed of the private random generator used to initialise the
            weights of both projections. Instances created with the same
            seed and shapes start with identical projection weights.
    """

    def __init__(self, input_dim, embed_dim, num_heads, seed = 42):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.seed = seed

        # One linear layer produces Q, K and V for every head at once.
        self.qkv_proj = nn.Linear(input_dim, 3*embed_dim, bias=True)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise the projection weights from a generator seeded with ``self.seed``.

        Both weight matrices are drawn Xavier-uniform; the output bias is
        zeroed. The bias of ``qkv_proj`` keeps the default ``nn.Linear``
        initialisation.
        """
        generator = torch.Generator()
        generator.manual_seed(self.seed)

        for proj in (self.qkv_proj, self.o_proj):
            self._seeded_xavier_uniform_(proj.weight, generator)
        self.o_proj.bias.data.fill_(0)

    @staticmethod
    def _seeded_xavier_uniform_(weight, generator):
        """Fill a 2D ``weight`` with Xavier-uniform values drawn from ``generator``."""
        fan_out, fan_in = weight.shape
        bound = math.sqrt(6.0 / (fan_in + fan_out))
        # Sample on the CPU (where the generator lives) and copy, so the
        # drawn values do not depend on the device holding the parameter.
        sample = torch.empty(weight.shape, device="cpu").uniform_(
            -bound, bound, generator=generator
        )
        with torch.no_grad():
            weight.copy_(sample)

    def forward(self, x, mask=None, return_attention=False):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape [Batch, SeqLen, input_dim].
            mask: Optional mask forwarded unchanged to ``scaled_dot_product``;
                it has to broadcast against logits of shape
                [Batch, Head, SeqLen, SeqLen].
            return_attention: When true, also return the attention weights.

        Returns:
            The projected output of shape [Batch, SeqLen, embed_dim], or the
            tuple ``(output, attention)`` if ``return_attention`` is set.
        """
        batch_size, seq_length, _ = x.size()

        ### Linear Projection of v,k,q for each Head ###
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from linear output
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        ### Scaled Dot-Product Attention ###
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]

        ### Concat all Heads ###
        values = values.reshape(batch_size, seq_length, self.embed_dim)

        ### Linear Output Projection ###
        o = self.o_proj(values)

        if return_attention:
            return o, attention
        return o


class MyBatchNorm1d(nn.Module):
    """Batch normalisation for inputs laid out as (batch, sequence_length, feature).

    Mean and variance are computed per feature over the first two dimensions,
    and the optional affine parameters are applied along the last dimension.

    Args:
        num_features: Size of the last (feature) dimension.
        eps: Value added to the variance before taking the square root.
        momentum: Weight of the current batch in the running statistics.
            ``None`` selects a cumulative moving average instead.
        affine: Whether to learn a per-feature scale and shift.
        track_running_stats: Whether to keep running statistics for use in
            evaluation mode. When false, batch statistics are used in both
            training and evaluation mode.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm1d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if self.affine:
            self.weight = nn.Parameter(torch.empty(num_features))
            self.bias = nn.Parameter(torch.empty(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            # These are statistics, not learnable weights, so the absent
            # entries are registered as buffers rather than parameters.
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the affine parameters and running statistics to their initial values."""
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def _update_running_stats(self, input, mean, var):
        """Fold the statistics of the current batch into the running buffers."""
        with torch.no_grad():
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
                factor = self.momentum

            # The running variance stores the unbiased estimate. With a single
            # sample that estimate is undefined (NaN), so the biased value is
            # kept in that case instead of poisoning the buffer.
            count = input.size(0) * input.size(1)
            unbiased_var = var * (count / (count - 1)) if count > 1 else var

            # copy_ updates the registered buffers in place, keeping their
            # identity, dtype and device.
            self.running_mean.copy_(factor * mean + (1 - factor) * self.running_mean)
            self.running_var.copy_(factor * unbiased_var + (1 - factor) * self.running_var)

    def forward(self, input):
        """Normalise ``input`` per feature.

        Args:
            input: Tensor of shape (batch, sequence_length, feature).

        Returns:
            Tensor of the same shape as ``input``.

        Raises:
            ValueError: If ``input`` is not 3-dimensional.
        """
        if input.dim() != 3:
            raise ValueError("Expected 3D input (batch, sequence_length, feature), "
                             "but got input with shape {}".format(input.shape))

        # Without tracked statistics there is nothing to fall back on in
        # evaluation mode, so batch statistics are used there as well.
        use_batch_stats = self.training or not self.track_running_stats

        if use_batch_stats:
            mean = input.mean([0, 1])
            var = input.var([0, 1], unbiased=False)  # biased variance normalises the batch
            if self.training and self.track_running_stats:
                self._update_running_stats(input, mean, var)
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean) / ((var + self.eps).sqrt())
        if self.affine:
            input = input * self.weight + self.bias

        return input
    

class EncoderBlock(nn.Module):
    """Transformer encoder block: self-attention followed by a two-layer MLP.

    Each sub-layer is followed by dropout, an optional residual connection
    and an optional normalisation layer.
    """

    def __init__(self, d_model, num_heads, dim_feedforward, dropout=0.2, activation = "relu", norm="layernorm", skipconnections = True, seed = 42):
        """
        Inputs:
            d_model - Dimensionality of the input and output of the block
            num_heads - Number of heads to use in the attention block
            dim_feedforward - Dimensionality of the hidden layer in the MLP
            dropout - Dropout probability to use in the dropout layers
            activation - "gelu" or "relu"; any other value falls back to relu
            norm - "batchnorm" or "layernorm"; any other value disables normalisation
            skipconnections - Whether to add each sub-layer's input to its output
            seed - Forwarded to the attention layer
        """
        super().__init__()
        self.activation = activation
        self.norm = norm
        self.skipconnections = skipconnections
        self.self_attn = MultiheadAttention(d_model, d_model, num_heads, seed = seed)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        if self.norm == "batchnorm":
            self.norm1 = MyBatchNorm1d(d_model)
            self.norm2 = MyBatchNorm1d(d_model)
        elif self.norm == "layernorm":
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
        else:
            print("no norm!")

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if self.activation == "gelu":
            self.act = nn.GELU()
        elif self.activation == "relu":
            self.act = nn.ReLU()
        else:
            print("no valid activation selection, falling back to relu")
            self.act = nn.ReLU()

    def forward(self, x, mask=None):
        """Run the attention and feed-forward sub-layers on ``x``.

        Inputs:
            x - Input tensor passed to the attention layer
            mask - Optional attention mask, forwarded to the attention layer
        """
        use_norm = self.norm in ("batchnorm", "layernorm")

        # Attention sub-layer
        attn_out = self.dropout1(self.self_attn(x, mask=mask))
        x = x + attn_out if self.skipconnections else attn_out
        if use_norm:
            x = self.norm1(x)

        # Feed-forward sub-layer. The MLP output is kept in its own variable
        # so that the residual adds the sub-layer *input*; overwriting x first
        # would add the MLP output to itself and drop the skip path.
        ff_out = self.linear2(self.dropout(self.act(self.linear1(x))))
        ff_out = self.dropout2(ff_out)
        x = x + ff_out if self.skipconnections else ff_out
        if use_norm:
            x = self.norm2(x)

        return x


# Handles layer stacking
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` encoder blocks applied one after another.

    Args:
        num_layers: Number of ``EncoderBlock`` instances to create.
        **block_args: Keyword arguments passed to every ``EncoderBlock``.
    """

    def __init__(self, num_layers, **block_args):
        super().__init__()
        self.layers = nn.ModuleList([EncoderBlock(**block_args) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        """Pass ``x`` through every layer in order, giving each the same ``mask``."""
        for layer in self.layers:
            x = layer(x, mask=mask)
        return x

    def get_attention_maps(self, x, mask=None):
        """Return the attention weights of every layer for input ``x``.

        The input is propagated through the stack with the same ``mask`` that
        ``forward`` would use, so the map of each layer is computed on the
        activations that layer receives in a regular forward pass.

        Returns:
            List with one attention tensor per layer, in layer order.
        """
        attention_maps = []
        for layer in self.layers:
            _, attn_map = layer.self_attn(x, mask=mask, return_attention=True)
            attention_maps.append(attn_map)
            # Propagate with the mask; dropping it here would feed later
            # layers activations that differ from the masked forward pass.
            x = layer(x, mask=mask)
        return attention_maps

    
class LearnablePositionalEncoding(nn.Module):
    """Adds a learned per-position embedding to the input, followed by dropout.

    Args:
        d_model: Size of the embedding dimension.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions for which an embedding is learned.
    """

    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super(LearnablePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.pe = nn.Parameter(torch.empty(1, max_len, d_model))
        nn.init.uniform_(self.pe, -0.02, 0.02)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]
        Raises:
            ValueError: if the sequence length exceeds the number of learned
                positions.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Fail with a clear message instead of an opaque broadcasting
            # error (or a silent broadcast when only one position exists).
            raise ValueError(
                "Sequence length {} exceeds the maximum supported length {}".format(
                    seq_len, max_len
                )
            )

        # The encoding has a leading dimension of 1 and is broadcast over the batch.
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)