
import collections
import functools
from typing import Iterable, Optional, Type, Union
import math
import torch
from torch import nn



class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first input, then apply dropout.

    The encoding table is precomputed once for ``max_len`` positions and stored
    as the non-trainable buffer ``pe`` with shape ``[1, max_len, d_model]``.
    Even feature columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: size of the feature (embedding) dimension. May be odd.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(1, max_len, d_model)  # batch first
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # drop the last frequency to keep the assignment shapes compatible.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + positional_encoding)``.

        Args:
            x: torch.Tensor, shape [batch_size, seq_len, embedding_dim]. An
                input with fewer than three dimensions gets a leading batch
                dimension of size 1 prepended first.

        Returns:
            Tensor with the same shape as the (possibly unsqueezed) input.

        Raises:
            RuntimeError: if the sequence length exceeds the number of
                precomputed positions (``max_len``).
        """
        if x.dim() < 3:
            x = x.unsqueeze(0)

        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise RuntimeError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional encoding"
            )

        x = x + self.pe[:, :seq_len]
        return self.dropout(x)


class TransformerEnc(nn.Module):
    """Transformer encoder over batch-first sequences with a local causal attention window.

    Inputs are projected linearly from ``input_dim`` to ``d_model``, combined
    with positional encodings, and passed through a stack of
    ``nn.TransformerEncoderLayer``s. Each position may attend only to itself
    and the ``max_history_len - 1`` positions immediately before it.

    Args:
        input_dim: feature size of the raw input.
        d_model: model (embedding) dimension used inside the encoder.
        nhead: number of attention heads.
        d_hid: hidden size of each layer's feed-forward network.
        dropout: dropout probability for the positional encoder and the layers.
        nlayers: number of stacked encoder layers.
        max_history_len: size of the causal attention window, counting the
            current position.
    """

    def __init__(self, input_dim, d_model,
            nhead=4,
            d_hid=128,
            dropout=0.1,
            nlayers=2,
            max_history_len=10):
        super().__init__()

        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.embedding = nn.Linear(input_dim, d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model,
                        nhead=nhead,
                        dim_feedforward=d_hid,
                        dropout=dropout,
                        batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer,
                        num_layers=nlayers)

        self.max_history_len = max_history_len

    def forward(self, x, pad_mask=None):
        """Encode ``x`` under the local causal mask.

        Args:
            x: input tensor; indexed here as [batch_size, seq_len, input_dim].
            pad_mask: optional [batch_size, seq_len] mask that is
                nonzero/True for real values and zero/False for padding
                (the opposite of the transformer convention).

        Returns:
            Tensor of shape [batch_size, seq_len, d_model]. Any NaN entries
            in the encoder output are replaced by 0.
        """
        x = self.embedding(x)
        src = self.pos_encoder(x)

        seqlen = src.size(1)
        # Build the mask on the same device as the data instead of assuming a
        # specific GPU, so the module also works on CPU and on other devices.
        att_mask = self.generate_local_causal_attention_mask(seqlen, device=src.device)
        if pad_mask is not None:
            # pad_mask is 1 if true value, 0 if padding. opposite to transformer notation.
            # logical_not also handles integer/float 0-1 masks, where `~` would
            # be a bitwise NOT (or an error) rather than a logical inversion.
            pad_mask = torch.logical_not(pad_mask)
        output = self.transformer_encoder(src, mask=att_mask, src_key_padding_mask=pad_mask)

        # Replace NaNs with zeros (e.g. rows whose attention window is entirely masked).
        # NOTE(review): with more than one layer, NaNs produced by a fully
        # masked row may propagate to other positions before this fill - verify.
        output = output.masked_fill(torch.isnan(output), 0)

        return output

    def generate_local_causal_attention_mask(self, seqlen: int, device=None):
        """Generate a banded causal additive attention mask.

        Entry ``[i, j]`` is 0 when ``i - max_history_len < j <= i`` (query ``i``
        may attend to key ``j``) and ``-inf`` otherwise, i.e. for future
        positions and for positions ``max_history_len`` or more steps back.

        Args:
            seqlen: sequence length; the mask has shape [seqlen, seqlen].
            device: optional device on which to create the mask. Defaults to
                the current default device.
        """
        full_pad = torch.full((seqlen, seqlen), float('-inf'), device=device)
        future = torch.triu(full_pad, diagonal=1)
        too_old = torch.tril(full_pad, diagonal=-self.max_history_len)
        return future + too_old
