import math
from typing import Optional

import torch
from einops import rearrange
from torch import Tensor, nn


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then dropout.

    A table of shape ``[max_len, 1, d_model]`` is precomputed and stored as
    the (non-trainable) buffer ``pe``. Even feature indices hold
    ``sin(pos * w_k)`` and odd feature indices hold ``cos(pos * w_k)`` with
    ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Size of the feature (embedding) dimension. Both even and
            odd values are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed in the table; inputs to
            ``forward`` must not be longer than this.
    """

    def __init__(self,
                 d_model: int,
                 dropout: float = 0.1,
                 max_len: int = 1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # [max_len, ceil(d_model / 2)]: one angle per (position, frequency).
        angles = position * div_term

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd index than even index,
        # so drop the last frequency for the cosine half. For even d_model
        # the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the positional table to ``x`` and apply dropout.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]. The first
                ``seq_len`` rows of the table are broadcast over the batch
                dimension.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Build an additive causal mask of shape ``[sz, sz]``.

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``.
    """
    all_blocked = torch.full((sz, sz), float('-inf'))
    # Keep only the strictly-upper triangle; the rest is zeroed out.
    return all_blocked.triu(diagonal=1)


class TransformerPredictor(nn.Module):
    """Encoder-decoder transformer over two observation sequences.

    Source and target observations are linearly projected to ``embed_dim``,
    given positional encodings, passed through ``nn.Transformer`` and finally
    projected to ``out_dim`` by a linear head.

    Args:
        enc_in_dim: Feature size of the source (encoder-side) observations.
        dec_in_dim: Feature size of the target (decoder-side) observations.
        out_dim: Feature size of the prediction produced by the head.
        embed_dim: Model width used inside the transformer.
        n_head: Number of attention heads.
        ff_dim: Hidden size of the transformer feed-forward layers.
        n_enc_layers: Number of encoder layers.
        n_dec_layers: Number of decoder layers.
        dropout: Dropout probability shared by the positional encoder and
            the transformer.
        image_observation: When true, image encoders (and optionally
            decoders) are additionally created as submodules. They are not
            used by ``forward``.
        use_image_decoder: When true (and ``image_observation`` is true),
            source/target image decoders are created as well.
        image_state_dim: Output size of the image encoders and latent size
            of the image decoders; must be positive when
            ``image_observation`` is true.
        coord_conv: Forwarded to the image encoders.
        pretrained: Forwarded to the image encoders.
    """

    def __init__(self,
                 enc_in_dim: int,
                 dec_in_dim: int,
                 out_dim: int,
                 embed_dim: int,
                 n_head: int,
                 ff_dim: int,
                 n_enc_layers: int,
                 n_dec_layers: int,
                 dropout: float = 0.2,
                 image_observation: bool = False,
                 use_image_decoder: bool = True,
                 image_state_dim: int = -1,
                 coord_conv: bool = False,
                 pretrained: bool = False):
        super().__init__()

        # Input projections for each stream and the output projection.
        self.src_embed_fn = nn.Linear(enc_in_dim, embed_dim)
        self.tgt_embed_fn = nn.Linear(dec_in_dim, embed_dim)
        self.head = nn.Linear(embed_dim, out_dim)

        # A single positional encoder is shared by both streams.
        self.pos_encoder = PositionalEncoding(d_model=embed_dim,
                                              dropout=dropout)

        transformer_kwargs = dict(
            d_model=embed_dim,
            nhead=n_head,
            num_encoder_layers=n_enc_layers,
            num_decoder_layers=n_dec_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.Transformer(**transformer_kwargs)

        # Everything below is only needed for image observations.
        if not image_observation:
            return

        assert image_state_dim > 0
        # Imported lazily so the project-local conv models are only required
        # when image observations are actually enabled.
        from common.ours.models import ConvDecoder, ConvEncoder

        encoder_kwargs = dict(
            in_channels=3,
            out_dim=image_state_dim,
            pretrained=pretrained,
            coord_conv=coord_conv,
        )
        self.source_image_encoder = ConvEncoder(**encoder_kwargs)
        self.target_image_encoder = ConvEncoder(**encoder_kwargs)
        # Alias: `image_encoder` refers to the very same module object as
        # `target_image_encoder` (no extra parameters are created).
        self.image_encoder = self.target_image_encoder

        if use_image_decoder:
            self.source_image_decoder = ConvDecoder(
                image_latent_dim=image_state_dim)
            self.target_image_decoder = ConvDecoder(
                image_latent_dim=image_state_dim)

    def forward(
        self,
        source_obs: Tensor,
        target_obs: Tensor,
        tgt_look_ahead_mask: Tensor,
        src_pad_mask: Optional[Tensor] = None,
        tgt_pad_mask: Optional[Tensor] = None,
    ):
        """Run the encoder-decoder transformer and project to ``out_dim``.

        Args:
            source_obs: [batch, seq_len, obs_dim] encoder-side observations.
            target_obs: [batch, seq_len, obs_dim] decoder-side observations.
            tgt_look_ahead_mask: Passed to the transformer as ``tgt_mask``
                (attention mask over target positions).
            src_pad_mask: Optional; passed as ``src_key_padding_mask``.
            tgt_pad_mask: Optional; passed as ``tgt_key_padding_mask``.

        Returns:
            Action prediction: the head applied to the transformer output.
        """
        # The positional encoder works on sequence-first tensors, while the
        # transformer is configured batch-first, hence the round trip.
        src_seq_first = rearrange(source_obs, 'b s d -> s b d')
        tgt_seq_first = rearrange(target_obs, 'b s d -> s b d')

        src_tokens = self.pos_encoder(self.src_embed_fn(src_seq_first))
        tgt_tokens = self.pos_encoder(self.tgt_embed_fn(tgt_seq_first))

        src_tokens = rearrange(src_tokens, 's b d -> b s d')
        tgt_tokens = rearrange(tgt_tokens, 's b d -> b s d')

        decoded = self.transformer(
            src_tokens,
            tgt_tokens,
            src_mask=None,
            tgt_mask=tgt_look_ahead_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
        )
        return self.head(decoded)
