import enum

import torch
from torch import Tensor, nn

from . import perceiver_io as pio


class EncodingMethod(enum.Enum):
    """Identifiers for the available two-stream encoding strategies.

    The values are the integers ``enum.auto()`` would hand out (1, 2, 3), written
    explicitly. Member names mirror the ``*_algorithm`` functions of this module;
    the code that dispatches on this enum is not part of this file section.
    """

    piecewise = 1
    sequential = 2
    fused = 3


def piecewise_algorithm(
    latent: Tensor,
    feats_a: Tensor,
    mask_a: Tensor,
    encoder_a: nn.Module,
    feats_b: Tensor,
    mask_b: Tensor,
    encoder_b: nn.Module,
):
    """Encode two feature streams with separate halves of the latent tokens.

    Args:
        latent: Query tokens whose dim 0 is either ``T * B`` already, or is tiled
            ``T`` times along dim 0 to get there. Dim 1 is the token axis.
        feats_a: First feature stream; its two leading dims are taken as ``T, B``.
        mask_a: Mask handed to ``encoder_a`` untouched.
        encoder_a: Called as ``encoder_a(latent_part, feats_a, mask_a)``.
        feats_b: Second feature stream.
        mask_b: Mask handed to ``encoder_b`` untouched.
        encoder_b: Called as ``encoder_b(latent_part, feats_b, mask_b)``.

    Returns:
        The two encoder outputs concatenated on the token axis (``encoder_a``'s
        output first) and reshaped to ``[T, B, N, C]``.
    """
    n_steps, n_batch = feats_a.shape[:2]

    # Tile the query tokens across time unless they are already [T*B, ...]
    queries = latent
    if queries.shape[0] != n_steps * n_batch:
        queries = queries.repeat(n_steps, 1, 1)

    # Token axis is split in two: one half per feature stream
    half = queries.shape[1] // 2
    first_half = queries[:, :half]
    second_half = queries[:, half:]

    # NOTE: encoder_a consumes the *second* half and encoder_b the *first*, yet
    # encoder_a's output is placed first below. This swap is a known quirk that is
    # kept on purpose; the result is not fed back in recursively.
    out_a = encoder_a(second_half, feats_a, mask_a)
    out_b = encoder_b(first_half, feats_b, mask_b)

    merged = torch.cat((out_a, out_b), dim=1)
    return merged.reshape(n_steps, n_batch, *merged.shape[1:])  # [T,B,N,C]


def sequential_algorithm(
    latent: Tensor,
    feats_a: Tensor,
    mask_a: Tensor,
    encoder_a: nn.Module,
    feats_b: Tensor,
    mask_b: Tensor,
    encoder_b: nn.Module,
):
    """Refine the latent tokens with ``feats_a`` first, then with ``feats_b``.

    Args:
        latent: Query tokens whose dim 0 is either ``T * B`` already, or is tiled
            ``T`` times along dim 0 to get there.
        feats_a: First feature stream; its two leading dims are taken as ``T, B``.
        mask_a: Mask handed to ``encoder_a`` untouched.
        encoder_a: Called as ``encoder_a(latent, feats_a, mask_a)``.
        feats_b: Second feature stream.
        mask_b: Mask handed to ``encoder_b`` untouched.
        encoder_b: Called on the output of ``encoder_a`` together with
            ``feats_b`` and ``mask_b``.

    Returns:
        Output of ``encoder_b`` reshaped to ``[T, B, N, C]``.
    """
    n_steps, n_batch = feats_a.shape[:2]

    # Tile the query tokens across time unless they are already [T*B, ...]
    needs_tiling = latent.shape[0] != n_steps * n_batch
    queries = latent.repeat(n_steps, 1, 1) if needs_tiling else latent

    # Chain the two encoders: the second one sees what the first produced
    after_a = encoder_a(queries, feats_a, mask_a)
    after_b = encoder_b(after_a, feats_b, mask_b)

    return after_b.reshape(n_steps, n_batch, *after_b.shape[1:])  # [T,B,N,C]


def fused_algorithm(
    latent: Tensor,
    feats_a: Tensor,
    mask_a: Tensor,
    embed_a: Tensor,
    feats_b: Tensor,
    mask_b: Tensor,
    embed_b: Tensor,
    encoder: nn.Module,
):
    """Tag both feature streams with an embedding, join them, and encode once.

    Args:
        latent: Query tokens whose dim 0 is either ``T * B`` already, or is tiled
            ``T`` times along dim 0 to get there.
        feats_a: First feature stream; its two leading dims are taken as ``T, B``
            and it is concatenated with ``feats_b`` on dim 2.
        mask_a: Mask for ``feats_a``; concatenated with ``mask_b`` on dim 2.
        embed_a: Embedding added to every entry of ``feats_a`` after gaining three
            leading singleton dims and being expanded to ``feats_a``'s shape.
        feats_b: Second feature stream.
        mask_b: Mask for ``feats_b``.
        embed_b: Embedding added to ``feats_b`` in the same way as ``embed_a``.
        encoder: Called as ``encoder(latent, fused_feats, fused_mask)``.

    Returns:
        Encoder output reshaped to ``[T, B, N, C]``.
    """
    n_steps, n_batch = feats_a.shape[:2]

    # Mark each stream with its own embedding so they stay distinguishable once joined
    tagged_a = feats_a + embed_a[None, None, None].expand(feats_a.shape)
    tagged_b = feats_b + embed_b[None, None, None].expand(feats_b.shape)

    # Join features and masks along the token axis
    fused_feats = torch.cat((tagged_a, tagged_b), dim=2)
    fused_mask = torch.cat((mask_a, mask_b), dim=2)

    # Tile the query tokens across time unless they are already [T*B, ...]
    queries = latent
    if queries.shape[0] != n_steps * n_batch:
        queries = queries.repeat(n_steps, 1, 1)

    encoded = encoder(queries, fused_feats, fused_mask)
    return encoded.reshape(n_steps, n_batch, *encoded.shape[1:])  # [T,B,N,C]


def make_causal_mask(time: int, n_tokens: int, device: torch.device | None):
    """Make causal mask where multiple tokens are in each timestep"""
    mask = (
        torch.arange(time * n_tokens, device=device)
        .unsqueeze(0)
        .repeat(time * n_tokens, 1)
        // n_tokens
    )
    indices = torch.arange(time * n_tokens, device=device).unsqueeze(-1) // n_tokens
    final = mask > indices  # True is not allowed to attend
    return final


class BERTEncoder(nn.Module):
    """BERT-style encoder to summarize state into [CLS] token(s)"""

    def __init__(
        self,
        unit_dim: int,
        hidden_dim: int,
        n_enc_layers: int,
        num_heads: int = 4,
    ):
        super().__init__()
        self.proj = nn.Linear(unit_dim, hidden_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                hidden_dim,
                num_heads,
                dim_feedforward=hidden_dim * 4,
                batch_first=True,
                norm_first=True,
            ),
            n_enc_layers,
        )

    def forward(self, latent: Tensor, units: Tensor, mask: Tensor | None):
        """Encode a batched sequence of units [T, B, N, C], latent should be already [T*B,N,C]"""
        N = units.shape[2]
        # First project units to hidden_dim
        units = self.proj(units.flatten(0, 2))
        units = units.reshape(-1, N, units.shape[-1])  # [T * B, N, C]

        # Add [CLS] token(s)
        units = torch.cat([latent, units], dim=-2)

        if mask is not None and mask.ndim == 3:
            mask = mask.flatten(0, 1)  # [T * B, N]
            mask = torch.cat([mask.new_zeros(latent.shape[:-1]), mask], dim=-1)

        encoded: Tensor = self.encoder(units, src_key_padding_mask=mask)
        encoded = encoded[:, : latent.shape[1]]  # Only keep [CLS] token(s)
        return encoded


class XAttnEncoder(nn.Module):
    """Cross-attention encoder to extract observation data as multiple query tokens"""

    def __init__(
        self,
        unit_dim: int,
        hidden_dim: int,
        n_enc_layers: int,
        num_heads: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.proj = nn.Linear(unit_dim, hidden_dim)
        self.decoder = pio.CrossAttentionBlock(
            n_enc_layers,
            hidden_dim,
            hidden_dim,
            dropout=dropout,
            num_heads=num_heads,
            dim_feedforward=hidden_dim * 4,
        )

    def forward(self, latent: Tensor, units: Tensor, mask: Tensor | None):
        """Decode a batched sequence of units [T, B, N, C], latent should be already [T*B,N,C]"""
        N = units.shape[2]
        # First project units to hidden_dim
        units = self.proj(units.flatten(0, 2))
        units = units.reshape(-1, N, units.shape[-1])  # [T*B,N,C]

        if mask is not None and mask.ndim == 3:
            mask = mask.flatten(0, 1)  # [T*B,N]

        decoded: Tensor = self.decoder(latent, units, pad_mask=mask)

        return decoded


# Registry of encoder classes by name. "encoder" and "bert" both resolve to
# BERTEncoder; "decoder" and "xattn" both resolve to XAttnEncoder.
ENCODER_DICT: dict[str, type[nn.Module]] = dict(
    (
        ("encoder", BERTEncoder),
        ("decoder", XAttnEncoder),
        ("bert", BERTEncoder),
        ("xattn", XAttnEncoder),
    )
)
