from copy import deepcopy
from functools import partial
from typing import Any, Literal

import torch
from mamba_ssm import Mamba2
from torch import Tensor, nn
from torch_discounted_cumsum_nd import discounted_cumsum

from . import perceiver_io as pio
from ._encoders import (
    ENCODER_DICT,
    EncodingMethod,
    fused_algorithm,
    piecewise_algorithm,
    sequential_algorithm,
)
from ._iadapter import InputAdapter
from .sc2_encoder import make_causal_mask


class Encoder(nn.Module):
    """
    Construct latent state by querying agent and target observations recursively
    Agents are queried first by the latent state before the targets are queried.

    Args:
        input_adapter: Adapter wrapped (via ``pio.Adapted``) around both the
            agent and the target cross-attention blocks; the same instance is
            shared by both. Its ``out_channels`` sizes the key/value input.
        num_latents: Number of learned latent vectors.
        num_latent_channels: Channel width of each latent vector.
        num_heads: Number of attention heads for every attention block.
        num_block_layers: Layer count passed to each attention block.
        dropout: Dropout passed to each attention block.
        **_kwargs: Ignored; lets callers pass a shared config.
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        num_heads: int = 4,
        num_block_layers: int = 6,
        dropout: float = 0.0,
        **_kwargs,
    ) -> None:
        super().__init__()

        self.latent_dim = num_latent_channels
        agent_layer = pio.CrossAttentionBlock(
            num_block_layers,
            num_latent_channels,
            input_adapter.out_channels,
            num_heads,
            dropout,
        )
        self.agent_attn = pio.Adapted(input_adapter, agent_layer)

        target_layer = pio.CrossAttentionBlock(
            num_block_layers,
            num_latent_channels,
            input_adapter.out_channels,
            num_heads,
            dropout,
        )
        self.target_attn = pio.Adapted(input_adapter, target_layer)

        self.propagate_layer = pio.self_attention_block(
            num_block_layers, num_latent_channels, num_heads, dropout
        )

        self.latent = nn.Parameter(torch.empty(num_latents, num_latent_channels))
        # torch.empty only allocates storage, so without an explicit init the
        # learned latent would start from arbitrary memory (possibly NaN/inf).
        # Use the same scheme as ScanEncoder's latent.
        nn.init.xavier_normal_(self.latent)

    def run_attn(
        self,
        x_latent: Tensor,
        agents: Tensor,
        agents_mask: Tensor,
        targets: Tensor,
        targets_mask: Tensor,
    ):
        """Do cross attention between latent state and agents then targets"""
        x_latent = self.agent_attn(x_latent, agents, agents_mask)
        x_latent = self.target_attn(x_latent, targets, targets_mask)
        return x_latent

    def forward(
        self, agents: Tensor, agents_mask: Tensor, targets: Tensor, targets_mask: Tensor
    ):
        """Recursively apply cross attention between latent state, agents and targets

        All four inputs are indexed along their leading dimension as time; the
        second dimension of ``agents`` is taken as the batch size.

        Returns:
            The latent state after every timestep, stacked along a new
            leading dimension.
        """
        tidxs, batch_sz = agents.shape[:2]
        x_latent = self.latent.unsqueeze(0).expand(batch_sz, -1, -1)
        # First step starts from the learned latent, so there is nothing to
        # propagate yet.
        x_latent = self.run_attn(
            x_latent, agents[0], agents_mask[0], targets[0], targets_mask[0]
        )

        x_latents = [x_latent]
        for tidx in range(1, tidxs):
            # Carry the previous latent forward before reading the new
            # observations.
            x_latent = self.propagate_layer(x_latent)
            x_latent = self.run_attn(
                x_latent,
                agents[tidx],
                agents_mask[tidx],
                targets[tidx],
                targets_mask[tidx],
            )
            x_latents.append(x_latent)

        x_latents = torch.stack(x_latents, dim=0)
        return x_latents


class Encoder2(Encoder):
    """Variant of ``Encoder`` whose latent queries targets before agents"""

    def run_attn(
        self,
        x_latent: Tensor,
        agents: Tensor,
        agents_mask: Tensor,
        targets: Tensor,
        targets_mask: Tensor,
    ):
        """Do cross attention between latent state and targets then agents"""
        after_targets = self.target_attn(x_latent, targets, targets_mask)
        return self.agent_attn(after_targets, agents, agents_mask)


class ScanEncoder(nn.Module):
    """Inclusive-scan based encoder

    Each block refines a per-timestep latent from that timestep's agent and
    target observations and then, when ``temporal`` is set, accumulates the
    latents along the sequence dimension with a (discounted) cumulative sum.

    Args:
        input_adapter: Applied to agents and targets (with their masks) at the
            start of ``forward``; its ``out_channels`` sizes the projections.
        num_latents: Number of learned latent vectors.
        num_latent_channels: Channel width of each latent vector.
        num_heads: Number of heads for every attention module built here.
        n_blocks: Number of blocks applied in ``forward``.
        dropout: Dropout passed to the ``pio`` attention modules.
        self_attn_block: Self-attention layers following the cross attention
            in each block; 0 disables them (must be > 0 with ``sample_once``).
        post_self_attn: Self-attention layers applied after the scan in each
            block; 0 disables them.
        cross_attn_block: Cross-attention layers per block, or transformer
            layers per block when ``use_fused_bert`` is set.
        gamma: Passed to ``discounted_cumsum``; a value of exactly 1 uses
            ``torch.cumsum`` instead.
        temporal: Whether to scan over the sequence dimension at all.
        sample_once: Only cross-attend to observations in the first block.
        pre_project: Linearly project adapted observations to the latent width
            before attention.
        use_fused_bert: Replace cross attention with a transformer encoder run
            over the concatenation of latents, agents and targets.
        use_pw_algo: Split the latents in two halves; the first half attends
            to agents and the second half to targets.
        **_kwargs: Ignored; lets callers pass a shared config.
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        num_heads: int = 4,
        n_blocks: int = 6,
        dropout: float = 0.0,
        self_attn_block: int = 1,
        post_self_attn: int = 0,
        cross_attn_block: int = 1,
        gamma: float = 2.0,
        temporal: bool = True,
        sample_once: bool = False,
        pre_project: bool = True,
        use_fused_bert: bool = False,
        use_pw_algo: bool = False,
        **_kwargs,
    ) -> None:
        super().__init__()
        self.gamma = gamma
        self.temporal = temporal
        self.latent_dim = num_latent_channels
        self.input_adapter = input_adapter
        self.sample_once = sample_once
        self._n_blocks = n_blocks
        self._use_pw = use_pw_algo

        if pre_project:
            self.agent_proj = nn.Linear(input_adapter.out_channels, self.latent_dim)
            kv_dim = self.latent_dim
        else:
            self.agent_proj = None
            kv_dim = input_adapter.out_channels

        self.encoder_block: nn.ModuleList | None = None
        self.agent_attn: nn.ModuleList | None = None
        # Declared up-front so the attribute exists in fused-BERT mode too,
        # where _build_cross (the only place that fills it) is never called.
        self.self_attn: nn.ModuleList | None = None
        if use_fused_bert:
            assert not use_pw_algo, "use_pw_algo and use_fused_bert contradict"
            assert pre_project, "Must reshape units to BERT first"
            self._build_bert(cross_attn_block, num_heads, n_blocks)
        else:
            self._build_cross(
                kv_dim,
                num_heads,
                cross_attn_block,
                dropout,
                sample_once,
                n_blocks,
                self_attn_block,
            )

        # Targets get their own copies of the agent modules (both are None
        # when the corresponding agent module is None).
        self.target_attn = deepcopy(self.agent_attn)
        self.target_proj = deepcopy(self.agent_proj)

        if post_self_attn > 0:
            self.post_self_attn = nn.ModuleList(
                pio.self_attention_block(
                    post_self_attn, self.latent_dim, num_heads, dropout=dropout
                )
                for _ in range(n_blocks)
            )
        else:
            self.post_self_attn = None

        self.latent = nn.Parameter(torch.empty(num_latents, num_latent_channels))
        nn.init.xavier_normal_(self.latent)

    def _build_cross(
        self,
        kv_dim: int,
        num_heads: int,
        cross_attn_block: int,
        dropout: float,
        sample_once: bool,
        n_blocks: int,
        self_attn_block: int,
    ):
        """Build cross-sequential model, greater depth and params"""

        def make_cross_attn():
            """Create one cross-attention module (single layer or block)"""
            if cross_attn_block == 1:
                return pio.cross_attention_layer(
                    self.latent_dim, kv_dim, num_heads, dropout=dropout
                )
            return pio.CrossAttentionBlock(
                cross_attn_block, self.latent_dim, kv_dim, num_heads, dropout=dropout
            )

        # With sample_once only block 0 ever cross-attends, so one module
        # is enough.
        self.agent_attn = nn.ModuleList(
            make_cross_attn() for _ in range(1 if sample_once else n_blocks)
        )

        if sample_once:
            assert self_attn_block > 0

            def block_num(idx: int):
                """Subsequent blocks are 2*cross_attn_block to compensate"""
                return (
                    self_attn_block
                    if idx == 0
                    else self_attn_block + 2 * cross_attn_block
                )

            self.self_attn = nn.ModuleList(
                pio.self_attention_block(
                    block_num(i),
                    self.latent_dim,
                    num_heads,
                    dropout=dropout,
                )
                for i in range(n_blocks)
            )
        elif self_attn_block == 0:
            self.self_attn = None
        else:
            self.self_attn = nn.ModuleList(
                pio.self_attention_block(
                    self_attn_block, self.latent_dim, num_heads, dropout=dropout
                )
                for _ in range(n_blocks)
            )

    def _build_bert(self, cross_attn_block: int, num_heads: int, n_blocks: int):
        """Build bert-fused model, shallower with less param but higher flops"""
        self.encoder_block = nn.ModuleList(
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    self.latent_dim, num_heads, self.latent_dim * 4, batch_first=True
                ),
                cross_attn_block,
            )
            for _ in range(n_blocks)
        )

    def _apply_sequential_cross(
        self,
        index: int,
        latent: Tensor,
        agents: Tensor,
        agents_mask: Tensor,
        targets: Tensor,
        targets_mask: Tensor,
    ):
        """Cross-attend the whole latent to agents then targets, then self-attend"""
        assert self.agent_attn is not None
        assert self.target_attn is not None
        if not self.sample_once or index == 0:
            latent = self.agent_attn[index](latent, agents, agents_mask)
            latent = self.target_attn[index](latent, targets, targets_mask)

        if self.self_attn is not None:
            latent = self.self_attn[index](latent)

        return latent

    def _apply_pw_cross(
        self,
        index: int,
        latent: Tensor,
        agents: Tensor,
        agents_mask: Tensor,
        targets: Tensor,
        targets_mask: Tensor,
    ):
        """First half of the latents attends agents, second half attends targets"""
        assert self.agent_attn is not None
        assert self.target_attn is not None
        half_l = latent.shape[1] // 2
        if not self.sample_once or index == 0:
            a_latent = self.agent_attn[index](latent[:, :half_l], agents, agents_mask)
            t_latent = self.target_attn[index](
                latent[:, half_l:], targets, targets_mask
            )
            latent = torch.cat([a_latent, t_latent], dim=1)

        if self.self_attn is not None:
            latent = self.self_attn[index](latent)

        return latent

    def _apply_bert_fused(
        self, index: int, latent: Tensor, agents: Tensor, mask: Tensor
    ):
        """Mask should already have latent shape prefixed with zeros"""
        assert self.encoder_block is not None
        joined = torch.cat([latent, agents], dim=1)
        encoded = self.encoder_block[index](joined, src_key_padding_mask=mask)
        # Only the latent positions are kept; the unit tokens are discarded.
        return encoded[:, : latent.shape[1]]

    def forward(
        self,
        agents: Tensor,
        agents_mask: Tensor,
        targets: Tensor,
        targets_mask: Tensor,
    ):
        """Encode a sequence of observations into per-timestep latents

        The first two dimensions of ``agents`` are taken as (sequence, batch)
        and are merged while attention runs, then restored for the scan.

        Returns:
            Latent tensor with leading dimensions (sequence, batch).
        """
        seq, bsz = agents.shape[:2]

        agents, agents_mask = self.input_adapter(agents, agents_mask)
        # reshape (not view) so non-contiguous adapter output is accepted,
        # matching how the masks are flattened.
        agents = agents.reshape(-1, *agents.shape[2:])
        agents_mask = agents_mask.reshape(-1, *agents_mask.shape[2:])
        if self.agent_proj is not None:
            agents = self.agent_proj(agents)

        targets, targets_mask = self.input_adapter(targets, targets_mask)
        targets = targets.reshape(-1, *targets.shape[2:])
        targets_mask = targets_mask.reshape(-1, *targets_mask.shape[2:])
        if self.target_proj is not None:
            targets = self.target_proj(targets)

        # Expand latent to [S,B,N,C]
        latent = self.latent.expand(seq, bsz, -1, -1)

        if self.encoder_block is not None:
            # Fused mode: agents and targets become one token sequence and
            # the mask gains a zero prefix covering the latent tokens.
            agents = torch.cat([agents, targets], dim=1)
            agents_mask = torch.cat(
                [
                    agents_mask.new_zeros(seq * bsz, latent.shape[2]),
                    agents_mask,
                    targets_mask,
                ],
                dim=1,
            )

        for idx in range(self._n_blocks):
            latent = latent.reshape(-1, *latent.shape[2:])
            if self._use_pw:
                latent = self._apply_pw_cross(
                    idx, latent, agents, agents_mask, targets, targets_mask
                )
            elif self.encoder_block is None:
                latent = self._apply_sequential_cross(
                    idx, latent, agents, agents_mask, targets, targets_mask
                )
            else:
                latent = self._apply_bert_fused(idx, latent, agents, agents_mask)
            latent = latent.reshape(seq, bsz, *latent.shape[1:])

            # Inclusive scan over the sequence dimension.
            if self.temporal and self.gamma == 1:
                latent = torch.cumsum(latent, dim=0)
            elif self.temporal:
                latent = discounted_cumsum(latent, dim=0, gamma=self.gamma)

            if self.post_self_attn is not None:
                latent = latent.reshape(-1, *latent.shape[2:])
                latent = self.post_self_attn[idx](latent)
                latent = latent.reshape(seq, bsz, *latent.shape[1:])

        return latent


class BaseUnitEncoder(nn.Module):
    """Shared observation-encoding stage for the unit encoders

    Builds a unit encoder from ``ENCODER_DICT``, a learned latent and the
    encoding algorithm selected by ``encoding_method`` (given either as an
    ``EncodingMethod`` member or as its name).
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        n_enc_layers: int,
        encoder_type: str,
        num_heads: int = 4,
        encoding_method: EncodingMethod = EncodingMethod.fused,
    ) -> None:
        super().__init__()
        # Accept the enum member's name as well as the member itself.
        encoding_method = (
            EncodingMethod[encoding_method]
            if isinstance(encoding_method, str)
            else encoding_method
        )
        self.input_adapter = input_adapter
        self.latent_dim = num_latent_channels

        encoder_factory = ENCODER_DICT[encoder_type]
        self.unit_enc = encoder_factory(
            input_adapter.out_channels, num_latent_channels, n_enc_layers, num_heads
        )

        latent_init = torch.empty(num_latents, num_latent_channels)
        self.latent = nn.Parameter(latent_init.normal_().clamp_(-2, 2))

        self.target_enc = None
        self.alliance_embedding = None
        if encoding_method is EncodingMethod.fused:
            # One encoder for both sides; a learned embedding per side is
            # handed to the algorithm in forward().
            embed_init = torch.empty(2, input_adapter.out_channels)
            self.alliance_embedding = nn.Parameter(
                embed_init.normal_().clamp_(-2, 2)
            )
            self.enc_algo = partial(fused_algorithm, encoder=self.unit_enc)
        else:
            if encoding_method is EncodingMethod.piecewise:
                two_encoder_algo = piecewise_algorithm
            elif encoding_method is EncodingMethod.sequential:
                two_encoder_algo = sequential_algorithm
            else:
                raise RuntimeError(f"Unrecognised {encoding_method=}")

            # Both remaining methods use a second encoder copied from the first.
            self.target_enc = deepcopy(self.unit_enc)
            self.enc_algo = partial(
                two_encoder_algo, encoder_a=self.unit_enc, encoder_b=self.target_enc
            )

    def forward(
        self,
        agents: Tensor,
        agents_mask: Tensor,
        targets: Tensor,
        targets_mask: Tensor,
    ):
        """Adapt both observation sets and run the selected encoding algorithm"""
        agents, agents_mask = self.input_adapter(agents, agents_mask)
        targets, targets_mask = self.input_adapter(targets, targets_mask)

        algo_kwargs = {
            # Latent is broadcast over the second dimension of the adapted agents.
            "latent": self.latent[None].expand(agents.shape[1], -1, -1),
            "feats_a": agents,
            "mask_a": agents_mask,
            "feats_b": targets,
            "mask_b": targets_mask,
        }
        if self.alliance_embedding is not None:
            algo_kwargs["embed_a"] = self.alliance_embedding[0]
            algo_kwargs["embed_b"] = self.alliance_embedding[1]

        return self.enc_algo(**algo_kwargs)


class RNNEncoder(BaseUnitEncoder):
    """Recurrent latent-state encoder.

    The base class summarises each observation into latent features; one
    independent recurrent module per latent then runs over the sequence."""

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        n_enc_layers: int,
        rnn_type: Literal["gru", "lstm", "rnn"],
        encoder_type: Literal["encoder", "decoder"] = "encoder",
        num_heads: int = 4,
        learned_init: bool = True,
        encoding_method=EncodingMethod.fused,
        **_kwargs,
    ) -> None:
        super().__init__(
            input_adapter,
            num_latents,
            num_latent_channels,
            n_enc_layers,
            encoder_type,
            num_heads,
            encoding_method,
        )
        assert num_latents % 2 == 0, "latent_num must be even"

        cell_cls = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}[rnn_type]
        cells = [
            cell_cls(
                input_size=self.latent_dim,
                hidden_size=self.latent_dim,
                batch_first=False,
            )
            for _ in range(num_latents)
        ]
        self.recurrent = nn.ModuleList(cells)

        # Optional learned initial states, one row per latent; the cell state
        # only exists for LSTMs.
        self.h0 = None
        self.c0 = None
        if not learned_init:
            return

        self.h0 = nn.Parameter(
            torch.empty(num_latents, self.latent_dim).normal_().clamp_(-2, 2)
        )
        if rnn_type == "lstm":
            self.c0 = nn.Parameter(torch.empty_like(self.h0).normal_().clamp_(-2, 2))

    def _process_sequence(self, unit_feats: Tensor):
        """Run each latent's recurrent module over its slice of ``unit_feats``

        Returns a list with one output sequence per latent.
        """
        batch = unit_feats.shape[1]

        def per_batch(state: Tensor) -> Tensor:
            """Expand a [N,C] initial state to [N,1,batch,C]"""
            return state[:, None, None].expand(-1, -1, batch, -1).contiguous()

        hidden = None if self.h0 is None else per_batch(self.h0)
        cell = None
        if hidden is not None and self.c0 is not None:
            cell = per_batch(self.c0)

        outputs = []
        for i, rnn in enumerate(self.recurrent):
            if hidden is None:
                start = None
            elif cell is not None:
                start = (hidden[i], cell[i])
            else:
                start = hidden[i]
            outputs.append(rnn(unit_feats[:, :, i], start)[0])
        return outputs

    def forward(
        self, agents: Tensor, agents_mask: Tensor, targets: Tensor, targets_mask: Tensor
    ):
        """Encode observations, then accumulate each latent over the sequence"""
        encoded = super().forward(agents, agents_mask, targets, targets_mask)
        per_latent = self._process_sequence(encoded)
        return torch.stack(per_latent, dim=2)  # [S,B,N,C]


class TransformerEncoder(BaseUnitEncoder):
    """Transformer (encoder or decoder) summarizes each observation.
    Then causally masked transformer-encoder across time.

    Args:
        input_adapter, num_latents, num_latent_channels, n_enc_layers,
        encoder_type, num_heads, encoding_method: Forwarded to
            ``BaseUnitEncoder``.
        n_blocks: Number of layers in the temporal transformer encoder.
        max_time: Number of rows in the learned time embedding, i.e. the
            longest sequence that can be embedded.
        **_kwargs: Ignored; lets callers pass a shared config.
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        n_enc_layers: int,
        n_blocks: int,
        max_time: int,
        encoder_type: Literal["encoder", "decoder"],
        num_heads: int = 4,
        encoding_method=EncodingMethod.piecewise,
        **_kwargs,
    ):
        super().__init__(
            input_adapter,
            num_latents,
            num_latent_channels,
            n_enc_layers,
            encoder_type,
            num_heads,
            encoding_method,
        )
        self.temporal_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                num_latent_channels, num_heads, batch_first=True
            ),
            n_blocks,
        )
        self.time_embed = nn.Parameter(
            torch.empty(max_time, num_latent_channels).normal_().clamp_(-2, 2)
        )

    def load_state_dict(
        self, state_dict: dict[str, Any], strict: bool = True, assign: bool = False
    ):
        """Load parameters; currently a plain pass-through to the parent"""

        return super().load_state_dict(state_dict, strict, assign)

    def forward(
        self, agents: Tensor, agents_mask: Tensor, targets: Tensor, targets_mask: Tensor
    ):
        """Encode observations then mix them causally across time

        Returns:
            Encoding laid out as [T,B,N,C].
        """
        feats = super().forward(agents, agents_mask, targets, targets_mask)
        # Only the first T rows of the embedding table are used, so sequences
        # shorter than max_time are accepted; broadcasting covers the
        # batch and latent dimensions.
        feats = feats + self.time_embed[: feats.shape[0], None, None, :]

        feats = feats.transpose(0, 1)  # [B,T,N,C]
        B, T, N, C = feats.shape
        feats = feats.reshape(B, T * N, C)

        def run_temporal() -> Tensor:
            """Run the temporal encoder with a freshly built causal mask"""
            return self.temporal_encoder(
                feats, mask=make_causal_mask(T, N, device=feats.device)
            )

        try:
            encoding: Tensor = run_temporal()
        except torch.OutOfMemoryError:
            # https://discuss.pytorch.org/t/nn-transformerencoder-oom-with-no-grad-eval-mask/210129
            if self.temporal_encoder.training:
                raise
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            # NOTE(review): the encoder is left in train mode after this
            # retry, so its dropout presumably stays active for later
            # eval-time calls - confirm this trade-off is intended.
            self.temporal_encoder.train()
            encoding = run_temporal()

        encoding = encoding.reshape(B, T, N, C)
        encoding = encoding.transpose(0, 1)  # [T,B,N,C]

        return encoding


class MambaEncoder(BaseUnitEncoder):
    """Accumulate the latent state over time with Mamba2 blocks.

    By default every latent gets its own stack of ``n_blocks`` Mamba2 layers.
    With ``squash_observation`` the latents are flattened into one vector per
    timestep and a single Mamba2 layer processes it."""

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        n_enc_layers: int,
        n_blocks: int,
        encoder_type: Literal["encoder", "decoder"],
        num_heads: int = 4,
        m_state: int = 64,
        m_conv: int = 4,
        m_expand: int = 2,
        squash_observation: bool = False,
        encoding_method=EncodingMethod.piecewise,
        **_kwargs,
    ):
        super().__init__(
            input_adapter,
            num_latents,
            num_latent_channels,
            n_enc_layers,
            encoder_type,
            num_heads,
            encoding_method,
        )

        self.squash_observation = squash_observation

        # Settings shared by every Mamba2 layer built below.
        # NOTE(review): the factor 2 in headdim is fixed and does not follow
        # m_expand - confirm that is intended for non-default m_expand.
        mamba_kwargs = {
            "d_state": m_state,
            "d_conv": m_conv,
            "expand": m_expand,
            "headdim": 2 * self.latent_dim // num_heads,
        }

        if squash_observation:
            self.temporal_encoder = Mamba2(
                num_latent_channels * num_latents, **mamba_kwargs
            )
        else:
            stacks = []
            for _ in range(num_latents):
                layers = [
                    Mamba2(num_latent_channels, **mamba_kwargs)
                    for _ in range(n_blocks)
                ]
                stacks.append(nn.Sequential(*layers))
            self.temporal_encoder = nn.ModuleList(stacks)

    def forward(
        self, agents: Tensor, agents_mask: Tensor, targets: Tensor, targets_mask: Tensor
    ):
        """Encode observations, then run the temporal Mamba2 stage over time"""
        feats = super().forward(agents, agents_mask, targets, targets_mask)

        if not isinstance(self.temporal_encoder, nn.ModuleList):
            # Squashed mode: [T,B,N,C] -> [B,T,NC] -> Mamba2 -> [T,B,N,C]
            original_shape = feats.size()
            flattened = feats.flatten(-2).transpose(0, 1)
            squashed: Tensor = self.temporal_encoder(flattened)
            return squashed.transpose(0, 1).reshape(original_shape)

        # Per-latent mode: [T,B,N,C] -> [N,B,T,C] so each stack sees one latent
        by_latent = feats.permute(2, 1, 0, 3)
        per_latent_out = []
        for stack, latent_seq in zip(self.temporal_encoder, by_latent):
            per_latent_out.append(stack(latent_seq))

        # [N,B,T,C] -> [T,B,N,C]
        return torch.stack(per_latent_out, dim=0).permute(2, 1, 0, 3)
