from copy import deepcopy
from functools import partial
from typing import Literal

import torch
from konductor.registry import Registry
from mamba_ssm import Mamba2
from torch import Tensor, nn

from . import perceiver_io as pio
from ._encoders import (
    ENCODER_DICT,
    XAttnEncoder,
    EncodingMethod,
    fused_algorithm,
    make_causal_mask,
    piecewise_algorithm,
    sequential_algorithm,
)
from ._iadapter import pad_agent_mask

# Optional dependency: in this file `discounted_cumsum` is only called by
# UnitEncoderScan.forward when temporal accumulation is enabled with a
# cumsum_gamma other than 1.
try:
    from torch_discounted_cumsum_nd import discounted_cumsum
except AttributeError:
    # NOTE(review): only AttributeError is tolerated here; a missing package
    # (ImportError) still propagates. When this handler runs the name
    # `discounted_cumsum` stays undefined, so a later call raises NameError.
    print("failed to import discounted_cumsum, hopefully not needed")


# Registry that maps a string key to each unit-encoder class defined below
UNIT_ENCODERS = Registry("unit-encoders")


@UNIT_ENCODERS.register_module("v1")
class UnitEncoderV1(nn.Module):
    """Step-by-step unit encoder: cross attention followed by self attention.

    The latent is carried from one timestep to the next: at every step it is
    updated from the allied units and then from the enemy units, and the
    per-step results are stacked along a new leading dimension.

    Args:
        unit_dim (int): Channel depth of incoming unit features.
        latent_dim (int): Channel depth of the latent state.
        n_cross_layers (int): Number of layers in the cross attention block.
        n_self_layers (int): Number of layers in the self attention block.
        num_heads (int, optional): Number of attention heads. Defaults to 4.
    """

    def __init__(
        self,
        unit_dim: int,
        latent_dim: int,
        n_cross_layers: int,
        n_self_layers: int,
        num_heads: int = 4,
    ):
        super().__init__()
        cross_block = pio.CrossAttentionBlock(
            n_cross_layers, latent_dim, unit_dim, num_heads, dropout=0.0
        )
        self_block = pio.self_attention_block(
            n_self_layers, latent_dim, num_heads, dropout=0.0
        )
        self.unit_attn = pio.Sequential(cross_block, self_block)
        # The enemy branch starts as an independent copy of the unit branch
        self.enemy_attn = deepcopy(self.unit_attn)

    def forward(
        self,
        latent: Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor | None,
        enemy_mask: Tensor | None,
    ):
        """Encode a sequence of observations, indexed over the leading dim.

        ``enemy``/``enemy_mask`` are typed as optional, but ``enemy_attn`` is
        always built by ``__init__`` so they are asserted to be present at
        every step.

        Returns:
            Tensor: Latent after each step, stacked along dim 0.
        """
        units, units_mask = pad_agent_mask(units, units_mask)
        if enemy is not None:
            enemy, enemy_mask = pad_agent_mask(enemy, enemy_mask)

        # Masks are inverted before being handed to the attention blocks
        units_mask = ~units_mask
        if enemy_mask is not None:
            enemy_mask = ~enemy_mask

        n_steps = units.shape[0]
        history: list[Tensor] = []
        for step in range(n_steps):
            latent = self.unit_attn(latent, units[step], units_mask[step])
            if self.enemy_attn is not None:
                assert enemy is not None and enemy_mask is not None
                latent = self.enemy_attn(latent, enemy[step], enemy_mask[step])
            history.append(latent)

        return torch.stack(history, dim=0)


@UNIT_ENCODERS.register_module("scan")
class UnitEncoderScan(nn.Module):
    """Unit encoder that stacks attention blocks with a cumulative sum over time.

    Each of the ``n_blocks`` blocks updates the latent from the allied and
    enemy units, either with cross attention followed by optional self
    attention, or with a transformer encoder over the concatenated
    latent/unit tokens (``use_fused_bert``). If ``temporal`` is set the block
    output is then accumulated along the sequence dimension.

    Args:
        unit_dim (int): Channel depth of incoming unit features
        latent_dim (int): Channel depth of latent state
        n_blocks (int): Number of cross-temporal blocks
        num_heads (int, optional): Number of MHA heads. Defaults to 4.
        temporal (bool, optional): Enable temporal accumulation. Defaults to True.
        cumsum_gamma (float, optional): Weighting of past elements in cumsum.
            A value of 1 uses ``torch.cumsum``, any other value uses
            ``discounted_cumsum``. Defaults to 1.
        self_attn_block (int, optional): Number of self attention layers applied
            after cross attention in each block, 0 disables them (must be > 0
            with ``sample_once``). Defaults to 1.
        post_self_attn (int, optional): Apply self attention after temporal cumsum
            with n layers. Defaults to 0.
        cross_attn_block (int, optional): Number of cross attention layers per
            block, or transformer encoder layers per block with
            ``use_fused_bert``. Defaults to 1.
        dropout (float, optional): Dropout of the cross/self attention layers of
            the cross-sequential variant. Defaults to 0.0.
        pre_project (bool, optional): Project unit features to ``latent_dim``
            with a linear layer before attention. Defaults to True.
        use_fused_bert (bool, optional): Use the fused transformer encoder
            variant instead of cross attention, requires ``pre_project``.
            Defaults to False.
        sample_once (bool, optional): Only cross-attend to the units in the
            first block, later blocks are self attention only. Defaults to False.
    """

    def __init__(
        self,
        unit_dim: int,
        latent_dim: int,
        n_blocks: int,
        num_heads: int = 4,
        temporal: bool = True,
        cumsum_gamma: float = 1,
        self_attn_block: int = 1,
        post_self_attn: int = 0,
        cross_attn_block: int = 1,
        dropout: float = 0.0,
        pre_project: bool = True,
        use_fused_bert: bool = False,
        sample_once: bool = False,
    ):
        super().__init__()
        self._n_blocks = n_blocks
        self.sample_once = sample_once
        if pre_project:
            self.unit_proj = nn.Linear(unit_dim, latent_dim)
            kv_dim = latent_dim
        else:
            self.unit_proj = None
            kv_dim = unit_dim

        # Exactly one of encoder_block (fused) / unit_attn (cross-sequential)
        # is populated by the builders below. self_attn is declared here so
        # the attribute exists for both variants.
        self.encoder_block: nn.ModuleList | None = None
        self.unit_attn: nn.ModuleList | None = None
        self.self_attn: nn.ModuleList | None = None
        if use_fused_bert:
            assert pre_project, "Must reshape units to BERT first"
            self._build_bert_fused(latent_dim, cross_attn_block, num_heads, n_blocks)
        else:
            self._build_cross_sequential(
                kv_dim,
                latent_dim,
                num_heads,
                cross_attn_block,
                dropout,
                n_blocks,
                self_attn_block,
            )

        if post_self_attn > 0:
            self.post_self_attn = nn.ModuleList(
                pio.self_attention_block(
                    post_self_attn, latent_dim, num_heads, dropout=0.0
                )
                for _ in range(n_blocks)
            )
        else:
            self.post_self_attn = None

        # Enemy branches start as independent copies of the unit branches
        # (both are None for the fused / non-projected configurations)
        self.enemy_attn = deepcopy(self.unit_attn)
        self.enemy_proj = deepcopy(self.unit_proj)
        self.temporal = temporal
        self.gamma = cumsum_gamma

    def _build_cross_sequential(
        self,
        kv_dim: int,
        latent_dim: int,
        num_heads: int,
        cross_attn_block: int,
        dropout: float,
        n_blocks: int,
        self_attn_block: int,
    ):
        """Build cross-sequential model, greater depth and params"""

        def make_cross_attn():
            """Single layer when cross_attn_block == 1, otherwise a block"""
            if cross_attn_block == 1:
                return pio.cross_attention_layer(
                    latent_dim, kv_dim, num_heads, dropout=dropout
                )
            return pio.CrossAttentionBlock(
                cross_attn_block, latent_dim, kv_dim, num_heads, dropout=dropout
            )

        # With sample_once only the first block cross-attends to the units
        self.unit_attn = nn.ModuleList(
            make_cross_attn() for _ in range(1 if self.sample_once else n_blocks)
        )

        if self.sample_once:
            assert self_attn_block > 0

            def block_num(idx: int):
                """Subsequent blocks are 2*cross_attn_block to compensate"""
                if idx == 0:
                    return self_attn_block
                return self_attn_block + 2 * cross_attn_block

            self.self_attn = nn.ModuleList(
                pio.self_attention_block(
                    block_num(i), latent_dim, num_heads, dropout=dropout
                )
                for i in range(n_blocks)
            )
        elif self_attn_block == 0:
            self.self_attn = None
        else:
            self.self_attn = nn.ModuleList(
                pio.self_attention_block(
                    self_attn_block, latent_dim, num_heads, dropout=dropout
                )
                for _ in range(n_blocks)
            )

    def _build_bert_fused(
        self, latent_dim: int, cross_attn_block: int, num_heads: int, n_blocks: int
    ):
        """Build bert-fused model, shallower with less param but higher flops"""
        self.encoder_block = nn.ModuleList(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    latent_dim, num_heads, latent_dim * 4, batch_first=True
                ),
                cross_attn_block,
            )
            for _ in range(n_blocks)
        )

    def apply_cross_attn(
        self,
        index: int,
        latent: Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor,
        enemy_mask: Tensor,
    ):
        """Apply block ``index`` of the cross-sequential variant.

        Cross attention to units then enemies runs for every block, or only
        for block 0 when ``sample_once`` is set. Self attention follows when
        configured. Masks are expected to already be inverted by the caller.
        """
        assert self.unit_attn is not None
        assert self.enemy_attn is not None

        if not self.sample_once or index == 0:
            latent = self.unit_attn[index](latent, units, units_mask)
            latent = self.enemy_attn[index](latent, enemy, enemy_mask)

        if self.self_attn is not None:
            latent = self.self_attn[index](latent)

        return latent

    def _apply_bert_fused(
        self, index: int, latent: Tensor, units: Tensor, mask: Tensor
    ):
        """Mask should already have latent shape prefixed with zeros"""
        assert self.encoder_block is not None
        joined = torch.cat([latent, units], dim=1)
        encoded = self.encoder_block[index](joined, src_key_padding_mask=mask)
        # Only the leading latent tokens are kept, unit tokens are dropped
        return encoded[:, : latent.shape[1]]

    def forward(
        self,
        latent: Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor,
        enemy_mask: Tensor,
    ):
        """Input shapes are [S,B,N,C]

        ``latent`` is given without the sequence dimension and is repeated
        for every timestep. Masks mark valid entries and are inverted here.

        Returns:
            Tensor: Latent of every timestep with leading dims [S,B].
        """
        seq, bsz = units.shape[:2]

        if self.encoder_block is None:  # xattn-seq needs null padding
            units, units_mask = pad_agent_mask(units, units_mask)
            enemy, enemy_mask = pad_agent_mask(enemy, enemy_mask)

        # Invert 'valid' mask for MHA which is 'padding' mask
        units_mask = ~units_mask.reshape(-1, *units_mask.shape[2:])
        enemy_mask = ~enemy_mask.reshape(-1, *enemy_mask.shape[2:])

        # Fold sequence into batch so every timestep is attended independently
        units = units.reshape(-1, *units.shape[2:])
        if self.unit_proj is not None:
            units = self.unit_proj(units)
        enemy = enemy.reshape(-1, *enemy.shape[2:])
        if self.enemy_proj is not None:
            enemy = self.enemy_proj(enemy)

        # Expand latent to [S,B,N,C]
        latent = latent[None].repeat(seq, 1, 1, 1)

        if self.encoder_block is not None:
            # Fused variant sees [latent, units, enemy] as one token sequence,
            # so the padding mask is prefixed with zeros for the latent tokens
            units = torch.cat([units, enemy], dim=1)
            units_mask = torch.cat(
                [
                    units_mask.new_zeros(seq * bsz, latent.shape[2]),
                    units_mask,
                    enemy_mask,
                ],
                dim=1,
            )

        for idx in range(self._n_blocks):
            # reshape (not view) so a non-contiguous cumsum result is accepted
            latent = latent.reshape(-1, *latent.shape[2:])
            if self.encoder_block is None:
                latent = self.apply_cross_attn(
                    idx, latent, units, units_mask, enemy, enemy_mask
                )
            else:
                latent = self._apply_bert_fused(idx, latent, units, units_mask)
            latent = latent.reshape(seq, bsz, *latent.shape[1:])

            if self.temporal and self.gamma == 1:
                latent = torch.cumsum(latent, dim=0)
            elif self.temporal:
                latent = discounted_cumsum(latent, dim=0, gamma=self.gamma)

            if self.post_self_attn is not None:
                latent = latent.reshape(-1, *latent.shape[2:])
                latent = self.post_self_attn[idx](latent)
                latent = latent.reshape(seq, bsz, *latent.shape[1:])

        return latent

    def _inc_first(
        self,
        latent: Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor | None,
        enemy_mask: Tensor | None,
    ):
        """First incremental step, there is no history to accumulate yet.

        Returns:
            list[Tensor]: ``[initial latent, state of each block, final latent]``
        """
        next_latents: list[Tensor] = [latent]
        # Iterate over all blocks (as forward does). len(self.unit_attn) is
        # only 1 with sample_once, which would skip the remaining blocks.
        for idx in range(self._n_blocks):
            latent = self.apply_cross_attn(
                idx, latent, units, units_mask, enemy, enemy_mask
            )
            next_latents.append(latent)
            if self.post_self_attn is not None:
                latent = self.post_self_attn[idx](latent)
        next_latents.append(latent)
        return next_latents

    def _inc_temporal(
        self,
        latents: list[Tensor],
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor | None,
        enemy_mask: Tensor | None,
    ):
        """Every subsequent incremental step, accumulates onto the previous state.

        Args:
            latents: List returned by the previous incremental step, element 0
                is the initial latent and element ``idx + 1`` the state of
                block ``idx``.

        Returns:
            list[Tensor]: Same layout as ``latents`` for the current step.
        """
        latent = latents[0]
        next_latents: list[Tensor] = [latent]
        # Iterate over all blocks (as forward does), see _inc_first
        for idx in range(self._n_blocks):
            latent = self.apply_cross_attn(
                idx, latent, units, units_mask, enemy, enemy_mask
            )
            next_latents.append(latent)
            # The additions are deliberately in-place: they also update the
            # tensor just appended to next_latents, so the stored block state
            # is the accumulated value rather than this step's contribution.
            if self.temporal and self.gamma == 1:
                latent += latents[idx + 1]
            elif self.temporal:
                latent += latents[idx + 1] / self.gamma

            if self.post_self_attn is not None:
                latent = self.post_self_attn[idx](latent)
        next_latents.append(latent)
        return next_latents

    def inc_forward(
        self,
        latents: list[Tensor] | Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor | None,
        enemy_mask: Tensor | None,
    ):
        """Forward unit data in incremental mode

        Args:
            latents: The initial latent tensor on the first call, afterwards
                the list returned by the previous call.

        Returns:
            list[Tensor]: ``[initial latent, state of each block, final latent]``
        """
        # NOTE(review): unlike forward, no pad_agent_mask is applied here;
        # confirm whether callers are expected to pad beforehand.
        # Invert 'valid' mask for MHA which is 'padding' mask
        units_mask = ~units_mask
        if enemy_mask is not None:
            enemy_mask = ~enemy_mask

        # unit_proj is None when pre_project=False, mirror the guard in forward
        if self.unit_proj is not None:
            units = self.unit_proj(units)
        if self.enemy_proj is not None:
            assert enemy is not None and enemy_mask is not None
            enemy = self.enemy_proj(enemy)

        if isinstance(latents, Tensor):
            return self._inc_first(latents, units, units_mask, enemy, enemy_mask)
        return self._inc_temporal(latents, units, units_mask, enemy, enemy_mask)


# Backwards compatibility: the "v2" key is kept as an alias of the "scan" encoder
UNIT_ENCODERS.register_module("v2", UnitEncoderScan)


class BaseUnitEncoder(nn.Module):
    """Shared observation-encoding stage for the temporal unit encoders.

    Builds the per-observation encoder(s) and binds them to one of the
    encoding algorithms selected by ``encoding_method``:

    - ``fused``: a single encoder plus a learned two-row alliance embedding.
    - ``piecewise`` / ``sequential``: a second encoder, created as a copy of
      the first, is used for the enemy units.

    Args:
        unit_dim (int): Channel depth of incoming unit features.
        latent_dim (int): Channel depth of the latent state.
        n_enc_layers (int): Number of layers of the observation encoder.
        encoder_type (str): Key into ``ENCODER_DICT``.
        num_heads (int, optional): Number of attention heads. Defaults to 4.
        encoding_method (EncodingMethod | str, optional): Encoding algorithm,
            a string is looked up by member name. Defaults to fused.
    """

    def __init__(
        self,
        unit_dim: int,
        latent_dim: int,
        n_enc_layers: int,
        encoder_type: str,
        num_heads: int = 4,
        encoding_method: EncodingMethod = EncodingMethod.fused,
    ) -> None:
        if isinstance(encoding_method, str):
            encoding_method = EncodingMethod[encoding_method]
        super().__init__()

        encoder_cls = ENCODER_DICT[encoder_type]
        self.unit_enc = encoder_cls(unit_dim, latent_dim, n_enc_layers, num_heads)
        self.enemy_enc = None
        self.alliance_embedding = None

        if encoding_method is EncodingMethod.fused:
            # One embedding row per alliance, handed to the algorithm in forward
            embedding = torch.empty(2, unit_dim).normal_().clamp_(-2, 2)
            self.alliance_embedding = nn.Parameter(embedding)
            self.enc_algo = partial(fused_algorithm, encoder=self.unit_enc)
            return

        # The remaining methods use a dedicated encoder per alliance
        if encoding_method is EncodingMethod.piecewise:
            algorithm = piecewise_algorithm
        elif encoding_method is EncodingMethod.sequential:
            algorithm = sequential_algorithm
        else:
            raise RuntimeError(f"Unrecognised {encoding_method=}")

        self.enemy_enc = deepcopy(self.unit_enc)
        self.enc_algo = partial(
            algorithm, encoder_a=self.unit_enc, encoder_b=self.enemy_enc
        )

    def forward(
        self,
        latent: Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor,
        enemy_mask: Tensor,
    ):
        """Encode the observations with the configured algorithm.

        Features are passed through ``pad_agent_mask`` only when their
        encoder is an ``XAttnEncoder``. Both masks are inverted before being
        handed to the algorithm.
        """
        if isinstance(self.unit_enc, XAttnEncoder):
            units, units_mask = pad_agent_mask(units, units_mask)
        if isinstance(self.enemy_enc, XAttnEncoder):
            enemy, enemy_mask = pad_agent_mask(enemy, enemy_mask)

        algo_inputs = {
            "latent": latent,
            "feats_a": units,
            "mask_a": ~units_mask,
            "feats_b": enemy,
            "mask_b": ~enemy_mask,
        }
        if self.alliance_embedding is not None:
            # Only present for the fused method
            algo_inputs["embed_a"] = self.alliance_embedding[0]
            algo_inputs["embed_b"] = self.alliance_embedding[1]

        return self.enc_algo(**algo_inputs)


@UNIT_ENCODERS.register_module("recurrent")
class UnitEncoderRNN(BaseUnitEncoder):
    """Unit encoder with one recurrent network per latent slot.

    Observations are encoded by ``BaseUnitEncoder``, then every latent slot
    is run through its own RNN/LSTM/GRU along the sequence dimension.

    Args:
        unit_dim (int): Channel depth of incoming unit features.
        latent_dim (int): Channel depth of the latent state, also the RNN size.
        latent_num (int): Number of latent slots (must be even), one recurrent
            network is created per slot.
        n_enc_layers (int): Number of layers of the observation encoder.
        rnn_type (str): One of "gru", "lstm" or "rnn".
        num_heads (int, optional): Number of attention heads. Defaults to 4.
        learned_init (bool, optional): Learn the initial hidden (and, for
            LSTM, cell) state instead of passing None. Defaults to True.
        encoder_type (str, optional): Key into ``ENCODER_DICT``. Defaults to "bert".
        encoding_method (EncodingMethod, optional): Defaults to fused.
    """

    def __init__(
        self,
        unit_dim: int,
        latent_dim: int,
        latent_num: int,
        n_enc_layers: int,
        rnn_type: Literal["gru", "lstm", "rnn"],
        num_heads: int = 4,
        learned_init: bool = True,
        encoder_type: str = "bert",
        encoding_method: EncodingMethod = EncodingMethod.fused,
    ) -> None:
        assert latent_num % 2 == 0, "latent_num must be even"
        super().__init__(
            unit_dim, latent_dim, n_enc_layers, encoder_type, num_heads, encoding_method
        )
        rnn_classes = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}
        rnn_cls = rnn_classes[rnn_type]
        self.recurrent = nn.ModuleList(
            [
                rnn_cls(
                    input_size=latent_dim, hidden_size=latent_dim, batch_first=False
                )
                for _ in range(latent_num)
            ]
        )

        self.h0 = None
        self.c0 = None
        if learned_init:
            self.h0 = nn.Parameter(
                torch.empty(latent_num, latent_dim).normal_().clamp_(-2, 2)
            )
            # Only LSTM carries a cell state
            if rnn_type == "lstm":
                self.c0 = nn.Parameter(
                    torch.empty_like(self.h0).normal_().clamp_(-2, 2)
                )

    def _process_sequence(self, unit_feats: Tensor):
        """Run each latent slot through its recurrent network.

        Args:
            unit_feats: Encoded features, slot ``i`` is read as
                ``unit_feats[:, :, i]`` and the batch size from dim 1.

        Returns:
            list[Tensor]: Output sequence of every recurrent network.
        """
        batch = unit_feats.shape[1]

        def broadcast(state: Tensor) -> Tensor:
            """[latent_num, C] -> [latent_num, 1, batch, C]"""
            return state[:, None, None].expand(-1, -1, batch, -1).contiguous()

        hidden = None if self.h0 is None else broadcast(self.h0)
        cell = None
        if hidden is not None and self.c0 is not None:
            cell = broadcast(self.c0)

        outputs = []
        for slot, rnn in enumerate(self.recurrent):
            if hidden is None:
                state = None
            elif cell is not None:
                state = (hidden[slot], cell[slot])
            else:
                state = hidden[slot]
            # Element 0 of the result is the output sequence
            outputs.append(rnn(unit_feats[:, :, slot], state)[0])
        return outputs

    def forward(
        self,
        latent: Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor,
        enemy_mask: Tensor,
    ):
        """Encode observations then accumulate them with the recurrent networks."""
        encoded = super().forward(latent, units, units_mask, enemy, enemy_mask)
        per_slot = self._process_sequence(encoded)
        return torch.stack(per_slot, dim=2)  # [S,B,N,C]


@UNIT_ENCODERS.register_module("transformer")
class UnitEncoderTransformer(BaseUnitEncoder):
    """Transformer (encoder or decoder) summarizes each observation.
    Then causally masked transformer-encoder across time.

    Args:
        unit_dim (int): Channel depth of incoming unit features.
        latent_dim (int): Channel depth of the latent state.
        n_enc_layers (int): Number of layers of the observation encoder.
        n_blocks (int): Number of layers of the temporal transformer encoder.
        max_time (int): Number of rows of the learned time embedding, i.e. the
            longest sequence that can be given a distinct embedding per step.
        encoder_type (str): Key into ``ENCODER_DICT``.
        num_heads (int, optional): Number of attention heads. Defaults to 4.
        encoding_method (EncodingMethod, optional): Defaults to piecewise.
    """

    def __init__(
        self,
        unit_dim: int,
        latent_dim: int,
        n_enc_layers: int,
        n_blocks: int,
        max_time: int,
        encoder_type: Literal["encoder", "decoder"],
        num_heads: int = 4,
        encoding_method: EncodingMethod = EncodingMethod.piecewise,
    ):
        super().__init__(
            unit_dim, latent_dim, n_enc_layers, encoder_type, num_heads, encoding_method
        )
        self.temporal_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(latent_dim, num_heads, batch_first=True),
            n_blocks,
        )
        self.time_embed = nn.Parameter(
            torch.empty(max_time, latent_dim).normal_().clamp_(-2, 2)
        )

    def forward(
        self,
        latent: Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor,
        enemy_mask: Tensor,
    ):
        """Encode each observation, add the time embedding and attend over time.

        Returns:
            Tensor: Encoding laid out as [T,B,N,C].
        """
        feats = super().forward(latent, units, units_mask, enemy, enemy_mask)

        # Slice the embedding to the actual sequence length so sequences
        # shorter than max_time are accepted; expanding the full table only
        # works when T == max_time. Longer sequences still raise in expand_as.
        time_embed = self.time_embed[: feats.shape[0], None, None, :]
        feats = feats + time_embed.expand_as(feats)

        feats = feats.transpose(0, 1)  # [B,T,N,C]
        B, T, N, C = feats.shape
        # Flatten time and latent slots into one token axis for the encoder
        feats = feats.reshape(B, T * N, C)
        encoding: Tensor = self.temporal_encoder(
            feats, mask=make_causal_mask(T, N, device=feats.device)
        )
        encoding = encoding.reshape(B, T, N, C)
        encoding = encoding.transpose(0, 1)  # [T,B,N,C]

        return encoding


@UNIT_ENCODERS.register_module("mamba")
class MambaEncoder(BaseUnitEncoder):
    """Accumulates the latent state with one Mamba2 stack per latent slot.

    Observations are encoded by ``BaseUnitEncoder``, then every latent slot
    is run over time through its own sequence of ``n_blocks`` Mamba2 layers.

    Args:
        unit_dim (int): Channel depth of incoming unit features.
        latent_dim (int): Channel depth of the latent state.
        latent_num (int): Number of Mamba2 stacks, one per latent slot.
        n_enc_layers (int): Number of layers of the observation encoder.
        n_blocks (int): Number of Mamba2 layers in each stack.
        encoder_type (str): Key into ``ENCODER_DICT``.
        num_heads (int, optional): Used to derive the Mamba2 ``headdim``.
            Defaults to 4.
        m_state (int, optional): Mamba2 ``d_state``. Defaults to 64.
        m_conv (int, optional): Mamba2 ``d_conv``. Defaults to 4.
        m_expand (int, optional): Mamba2 ``expand``. Defaults to 2.
        encoding_method (EncodingMethod, optional): Defaults to piecewise.
    """

    def __init__(
        self,
        unit_dim: int,
        latent_dim: int,
        latent_num: int,
        n_enc_layers: int,
        n_blocks: int,
        encoder_type: Literal["encoder", "decoder"],
        num_heads: int = 4,
        m_state: int = 64,
        m_conv: int = 4,
        m_expand: int = 2,
        encoding_method: EncodingMethod = EncodingMethod.piecewise,
    ):
        super().__init__(
            unit_dim, latent_dim, n_enc_layers, encoder_type, num_heads, encoding_method
        )

        # NOTE(review): the factor 2 equals the default m_expand; presumably
        # it should follow m_expand when that is overridden - confirm.
        head_dim = 2 * latent_dim // num_heads

        def build_layer():
            """Create one Mamba2 layer with the configured hyper-parameters"""
            return Mamba2(
                latent_dim,
                d_state=m_state,
                d_conv=m_conv,
                expand=m_expand,
                headdim=head_dim,
            )

        stacks = []
        for _ in range(latent_num):
            layers = [build_layer() for _ in range(n_blocks)]
            stacks.append(nn.Sequential(*layers))
        self.temporal_encoder = nn.ModuleList(stacks)

    def forward(
        self,
        latent: Tensor,
        units: Tensor,
        units_mask: Tensor,
        enemy: Tensor,
        enemy_mask: Tensor,
    ):
        """Encode observations then run every latent slot through its stack.

        Returns:
            Tensor: Encoding laid out as [T,B,N,C].
        """
        encoded = super().forward(latent, units, units_mask, enemy, enemy_mask)
        # [T,B,N,C] -> N slices of [B,T,C], paired with the stacks in order
        per_slot = encoded.permute(2, 1, 0, 3).unbind(0)
        outputs = [
            stack(slot_seq)
            for stack, slot_seq in zip(self.temporal_encoder, per_slot)
        ]
        # [N,B,T,C] -> [T,B,N,C]
        return torch.stack(outputs, dim=0).permute(2, 1, 0, 3)
