# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A streamable transformer."""

import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F


def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
    """Create a sinusoidal time embedding for the given positions.

    The last axis of the result holds `dim // 2` cosines followed by
    `dim // 2` sines. The periods are spaced geometrically, from `2 * pi`
    for the first frequency up to `2 * pi * max_period` for the last one.

    Args:
        positions (torch.Tensor): positions to embed. It is broadcast against a
            `[1, 1, dim // 2]` frequency tensor, so a `[B, T, 1]` or `[1, T, 1]`
            tensor gives a result in BTC format.
        dim (int): size of the embedding, must be even.
        max_period (float): scaling factor applied to the lowest frequency.

    Returns:
        torch.Tensor: embedding whose last dimension has size `dim`.
    """
    # We aim for BTC format
    assert dim % 2 == 0, "dim must be even to hold matching cos / sin halves"
    half_dim = dim // 2
    adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
    # With a single frequency (`dim == 2`), `half_dim - 1` is zero and the
    # exponent would be 0 / 0 = NaN, turning the whole embedding into NaN.
    # Clamping the divisor keeps every other `dim` unchanged.
    exponent_scale = max(half_dim - 1, 1)
    phase = positions / (max_period ** (adim / exponent_scale))
    return torch.cat([
        torch.cos(phase),
        torch.sin(phase),
    ], dim=-1)


class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Encoder layer whose self-attention also looks at cached past frames.

    Inputs are unpacked as 3-dimensional tensors and concatenated along
    dimension 1, i.e. the layer is expected to be built with `batch_first=True`
    and fed `[B, T, C]` tensors.
    """
    def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
        """Run the layer on new frames `x`, given cached frames `x_past`.

        Args:
            x (torch.Tensor): new frames to process.
            x_past (torch.Tensor): cached frames, placed before `x` along
                dimension 1 when building the attention keys and values.
            past_context (int): how many positions back a query may attend to.

        Returns:
            tuple: the layer output and the tensor that was fed to the
            self-attention block (normalized first when `norm_first` is set).
        """
        if not self.norm_first:
            # Post-norm: attention and feed-forward are normalized after the residual.
            attn_in = x
            hidden = self.norm1(attn_in + self._sa_block(attn_in, x_past, past_context))
            return self.norm2(hidden + self._ff_block(hidden)), attn_in

        # Pre-norm: each sub-block sees a normalized input, residuals stay raw.
        attn_in = self.norm1(x)
        hidden = x + self._sa_block(attn_in, x_past, past_context)
        return hidden + self._ff_block(self.norm2(hidden)), attn_in

    # self-attention block
    def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
        """Attend from `x` over `x_past` followed by `x`, within `past_context`."""
        _, num_new, _ = x.shape
        _, num_past, _ = x_past.shape

        # Keys and values are the same tensor: cached frames, then the new ones.
        memory = torch.cat([x_past, x], dim=1)

        # Absolute indices along the concatenated axis; queries are its tail.
        key_index = torch.arange(num_past + num_new, device=x.device)
        query_index = key_index[num_past:]
        lag = query_index.unsqueeze(1) - key_index.unsqueeze(0)

        # True marks a forbidden pair: a key in the future, or further back
        # than `past_context`.
        blocked = (lag < 0) | (lag > past_context)
        attended = self.self_attn(x, memory, memory,
                                  attn_mask=blocked,
                                  need_weights=False)[0]
        return self.dropout1(attended)


class StreamingTransformerEncoder(nn.Module):
    """TransformerEncoder with streaming support.

    Args:
        dim (int): dimension of the data, must be divisible by `num_heads`.
        hidden_scale (float): intermediate dimension of FF module is this times the dimension.
        num_heads (int): number of heads.
        num_layers (int): number of layers.
        max_period (float): maximum period of cosines in the positional embedding.
        past_context (int or None): receptive field for the causal mask, infinite if None.
        gelu (bool): if true uses GeLUs, otherwise use ReLUs.
        norm_in (bool): normalize the input.
        dropout (float): dropout probability.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, dim: int, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5,
                 max_period: float = 10000, past_context: tp.Optional[int] = 1000, gelu: bool = True,
                 norm_in: bool = True, dropout: float = 0., **kwargs):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)

        self.max_period = max_period
        self.past_context = past_context
        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
        else:
            self.norm_in = nn.Identity()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                StreamingTransformerEncoderLayer(
                    dim, num_heads, hidden_dim,
                    activation=activation, batch_first=True, dropout=dropout, **kwargs))

    def forward(self, x: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None,
                offset: tp.Union[int, torch.Tensor] = 0):
        """Encode a chunk of frames, carrying attention state across calls.

        Args:
            x (torch.Tensor): input chunk of shape `[B, T, C]`.
            states (list of torch.Tensor or None): one cached tensor per layer,
                as returned by a previous call. When None, every layer starts
                from a single all-zero frame.
            offset (int or torch.Tensor): position of the first frame of `x`,
                used for the positional embedding.

        Returns:
            tuple: the encoded chunk, the updated list of per-layer states, and
            the offset to pass along with the next chunk (`offset + T`).

        Raises:
            ValueError: if `states` holds fewer entries than there are layers.
        """
        _, T, C = x.shape
        if states is None:
            states = [torch.zeros_like(x[:, :1]) for _ in range(len(self.layers))]
        elif len(states) < len(self.layers):
            # `zip` below would otherwise silently skip the trailing layers.
            raise ValueError(
                f"Expected at least {len(self.layers)} states, got {len(states)}.")

        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
        pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)

        new_state: tp.List[torch.Tensor] = []
        x = self.norm_in(x)
        x = x + pos_emb

        for layer_state, layer in zip(states, self.layers):
            if self.past_context is None:
                # Infinite context: no query/key distance can exceed the total
                # number of cached plus new frames, so nothing gets masked out
                # besides future positions.
                layer_context = layer_state.shape[1] + T
            else:
                layer_context = self.past_context
            x, new_layer_state = layer(x, layer_state, layer_context)
            new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
            if self.past_context is not None:
                # Use an explicit start index rather than `-past_context:`,
                # which would keep the whole cache when `past_context == 0`.
                start = max(new_layer_state.shape[1] - self.past_context, 0)
                new_layer_state = new_layer_state[:, start:, :]
            new_state.append(new_layer_state)
        return x, new_state, offset + T
