import math
from typing import Optional

import torch
from torch import Tensor, nn


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first input.

    The table ``pe`` has shape ``[max_len, 1, d_model]``:
    - Even feature indices hold ``sin(pos * 10000^(-2i/d_model))``.
    - Odd feature indices hold the matching ``cos`` term.

    It is registered as a buffer, so it is saved in the state dict and moves
    with the module (``.to(...)``) but is not a trainable parameter.
    """

    def __init__(self,
                 d_model: int,
                 dropout: float = 0.1,
                 max_len: int = 1000):
        """
        Args:
            d_model: Size of the last (feature) dimension of the inputs.
                Both even and odd values are supported.
            dropout: Dropout probability applied after adding the encoding.
            max_len: Number of positions precomputed in the table; inputs
                longer than this along dim 0 cannot be encoded.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even
        # column, so the cosine half uses only the first d_model // 2
        # frequencies (for even d_model this is all of them).
        pe[:, 0, 1::2] = torch.cos(position * div_term[:d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the positional encoding for the first ``x.size(0)`` positions.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim], where
                ``embedding_dim == d_model`` and ``seq_len <= max_len``.

        Returns:
            Tensor of the same shape as ``x``, with encodings added
            (broadcast over the batch dimension) and dropout applied.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build an additive causal attention mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` (on or below the diagonal)
    and ``-inf`` when ``j > i`` (strictly above the diagonal). When added to
    attention scores, this blocks each position from seeing later positions.
    """
    all_blocked = torch.full((sz, sz), float('-inf'))
    # Keep only the strictly-upper triangle; everything else becomes 0.
    return all_blocked.triu(diagonal=1)


class TransformerPredictor(nn.Module):
    """Encoder-decoder Transformer mapping observation sequences to outputs.

    The pipeline is:
    1. Linearly project source and target steps to ``embed_dim``.
    2. Add sinusoidal positional encodings.
    3. Run ``nn.Transformer``.
    4. Project each decoder output step to ``out_dim`` with a linear head.

    The public interface is batch-first. Internally the tensors are permuted
    to the sequence-first layout that ``nn.Transformer`` (constructed with
    its default ``batch_first=False``) and ``PositionalEncoding`` expect.
    """

    def __init__(self,
                 enc_in_dim: int,
                 dec_in_dim: int,
                 out_dim: int,
                 embed_dim: int,
                 n_head: int,
                 ff_dim: int,
                 n_enc_layers: int,
                 n_dec_layers: int,
                 dropout: float = 0.2,
                 max_len: int = 1000):
        """
        Args:
            enc_in_dim: Feature size of each source (encoder input) step.
            dec_in_dim: Feature size of each target (decoder input) step.
            out_dim: Feature size of each output step.
            embed_dim: Model width (``d_model``). ``nn.Transformer``
                requires it to be divisible by ``n_head``.
            n_head: Number of attention heads.
            ff_dim: Hidden size of the feed-forward sublayers.
            n_enc_layers: Number of encoder layers.
            n_dec_layers: Number of decoder layers.
            dropout: Dropout used in the positional encoding and Transformer.
            max_len: Longest source/target sequence length the positional
                encoding table supports.
        """
        super().__init__()
        self.src_embed_fn = nn.Linear(enc_in_dim, embed_dim)
        self.tgt_embed_fn = nn.Linear(dec_in_dim, embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)
        # A single encoder instance is shared by source and target; it holds
        # no trainable parameters, only the fixed table and a dropout layer.
        self.pos_encoder = PositionalEncoding(d_model=embed_dim,
                                              dropout=dropout,
                                              max_len=max_len)
        self.transformer = nn.Transformer(
            d_model=embed_dim,
            nhead=n_head,
            num_encoder_layers=n_enc_layers,
            num_decoder_layers=n_dec_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
        )

    def forward(
        self,
        source_obs: Tensor,
        target_obs: Tensor,
        tgt_look_ahead_mask: Tensor,
        src_pad_mask: Optional[Tensor] = None,
        tgt_pad_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the encoder-decoder on batch-first inputs.

        Args:
            source_obs: [batch, src_len, enc_in_dim] encoder inputs.
            target_obs: [batch, tgt_len, dec_in_dim] decoder inputs.
            tgt_look_ahead_mask: [tgt_len, tgt_len] mask for decoder
                self-attention, e.g. ``generate_square_subsequent_mask(tgt_len)``.
            src_pad_mask: Optional [batch, src_len] key-padding mask for the
                source. It is applied both to encoder self-attention and to
                the decoder's cross-attention over the encoder output, so
                padded source steps are ignored everywhere.
            tgt_pad_mask: Optional [batch, tgt_len] key-padding mask for the
                target (decoder self-attention).

        Returns:
            [batch, tgt_len, out_dim] action prediction.
        """
        # batch-first -> sequence-first
        source_obs = torch.permute(source_obs, (1, 0, 2))
        target_obs = torch.permute(target_obs, (1, 0, 2))

        source_ebd = self.pos_encoder(self.src_embed_fn(source_obs))
        target_ebd = self.pos_encoder(self.tgt_embed_fn(target_obs))

        out = self.transformer(
            source_ebd,
            target_ebd,
            src_mask=None,
            tgt_mask=tgt_look_ahead_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            # Without this, decoder cross-attention would still attend to
            # the encoder outputs at padded source positions.
            memory_key_padding_mask=src_pad_mask,
        )
        out = self.head(out)
        # sequence-first -> batch-first
        return torch.permute(out, (1, 0, 2))
