import random
from copy import deepcopy
from typing import Any, Iterable, Literal, Type

import torch
from konductor.init import ModuleInitConfig
from konductor.utilities import comm
from mamba_ssm import Mamba2
from torch import Tensor, nn
from torch.distributed import scatter_object_list
from torch_discounted_cumsum_nd import discounted_cumsum

from . import perceiver_io as pio
from ._encoders import ENCODER_DICT, make_causal_mask
from ._iadapter import AgentIA, InputAdapter
from ._oadapter import ClassOccupancyOA, OccupancyOA, OutputAdapter


class MotionEncoder(nn.Module):
    r"""Recurrent perceiver encoder that rolls a latent array forward in time.

    At timestep 0 the learned latent array cross-attends to the agent tokens
    and is refined by a self-attention block (``input_layer``). At every later
    timestep the latent is advanced by ``propagate_layer`` and, when the
    timestep is one of the selected input indices, updated by cross-attending
    to that timestep's agent tokens (``update_layer``).

    The latent is emitted at every timestep listed in ``output_idx``. In
    training mode it is detached after each emission, so gradients do not
    flow back past the previous output.

    Args:
        input_adapter: input adapter wrapped (via ``pio.Adapted``) around both
            ``input_layer`` and ``update_layer``.
        num_latents: number of latent tokens.
        num_latent_channels: channel width of each latent token.
        input_indicies: timesteps (besides 0, which always initialises the
            latent) whose observations update the latent, or ``"all"``.
        num_cross_attention_heads: heads in the cross-attention layers.
        num_self_attention_heads: heads in the self-attention blocks.
        num_self_attention_layers_per_block: layers per self-attention block.
        dropout: dropout probability passed to every attention layer.
        random_input_indicies: in training mode, number of extra input
            timesteps sampled at random on each forward pass.

    Raises:
        ValueError: if ``input_indicies`` is a string other than ``"all"``.
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        input_indicies: Iterable[int] | Literal["all"],
        num_cross_attention_heads: int = 4,
        num_self_attention_heads: int = 4,
        num_self_attention_layers_per_block: int = 6,
        dropout: float = 0,
        random_input_indicies: int = 0,
    ) -> None:
        super().__init__()

        self.latent_dim = num_latent_channels

        input_layer = pio.Sequential(
            pio.cross_attention_layer(
                num_q_channels=num_latent_channels,
                num_kv_channels=input_adapter.out_channels,
                num_heads=num_cross_attention_heads,
                dropout=dropout,
            ),
            pio.self_attention_block(
                num_layers=num_self_attention_layers_per_block,
                num_channels=num_latent_channels,
                num_heads=num_self_attention_heads,
                dropout=dropout,
            ),
        )
        self.input_layer = pio.Adapted(input_adapter, input_layer)

        self.propagate_layer = pio.self_attention_block(
            num_layers=num_self_attention_layers_per_block,
            num_channels=num_latent_channels,
            num_heads=num_self_attention_heads,
            dropout=dropout,
        )

        update_layer = pio.cross_attention_layer(
            num_q_channels=self.latent_dim,
            num_kv_channels=input_adapter.out_channels,
            num_heads=num_cross_attention_heads,
            dropout=dropout,
        )
        self.update_layer = pio.Adapted(input_adapter, update_layer)

        if isinstance(input_indicies, str):
            # Validate eagerly: any other string would only fail later, in
            # forward(), via an assert that is stripped under ``python -O``.
            if input_indicies != "all":
                raise ValueError(
                    'input_indicies must be "all" or an iterable of ints, '
                    f"got {input_indicies!r}"
                )
            self.input_indicies = input_indicies
        else:
            self.input_indicies = set(input_indicies)

        self.random_input_indicies = random_input_indicies

        self.latent = nn.Parameter(torch.empty(num_latents, num_latent_channels))

        self._init_parameters()

    @torch.no_grad()
    def _init_parameters(self):
        """Initialise the learned latent array from N(0, 0.02), clamped to [-2, 2]."""
        self.latent.normal_(0.0, 0.02).clamp_(-2.0, 2.0)

    def init_step(self, output_idx: Tensor, agents: Tensor, agents_mask: Tensor):
        """Initialise the latent state from the first timestep's agent tokens.

        Returns:
            ``(out_latent, x_latent)``. ``out_latent`` holds the initial latent
            if timestep 0 is requested in ``output_idx``, otherwise it is empty.
            ``x_latent`` is the state the recurrence continues from; it is
            detached in training mode when it was emitted.
        """
        # repeat initial latent vector along batch dimension
        x_latent: Tensor = self.latent.unsqueeze(0).expand(agents.shape[0], -1, -1)
        x_latent = self.input_layer(x_latent, agents, agents_mask)

        out_latent: list[Tensor] = []
        if 0 in output_idx:
            out_latent.append(x_latent)
            if self.training:
                x_latent = x_latent.detach().clone()

        return out_latent, x_latent

    def get_input_indicies(self, max_idx: int) -> set[int]:
        """Return the set of timesteps whose observations update the latent.

        ``"all"`` selects every timestep below ``max_idx``. Otherwise the
        configured indices are used. In training mode, ``random_input_indicies``
        extra indices are drawn from ``[min, max)`` of the configured set. In
        distributed mode, rank 0's draw is scattered so that all ranks agree.
        """
        if self.input_indicies == "all":
            return set(range(max_idx))

        assert not isinstance(self.input_indicies, str)

        input_indicies = set(self.input_indicies)
        # min()/max() need a non-empty set. The condition depends only on
        # module config/mode, so (assuming identical config on every rank) all
        # ranks take the same branch and the collective below stays matched.
        if self.random_input_indicies > 0 and self.training and input_indicies:
            candidates = list(range(min(input_indicies), max(input_indicies)))
            random.shuffle(candidates)
            input_indicies.update(candidates[: self.random_input_indicies])

            # Scatter random samples consistently if in distributed mode
            if comm.in_distributed_mode():
                if comm.is_main_process():
                    dist_list = [input_indicies] * comm.get_world_size()
                else:
                    dist_list = None
                out = [None]
                scatter_object_list(out, dist_list, src=0)
                input_indicies = set(out[0])

        return input_indicies

    def forward(
        self, output_idx: Tensor, agents: Tensor, agents_mask: Tensor
    ) -> list[Tensor]:
        """Given an input of dims [Time Batch Tokens Channels]

        Returns one latent per timestep in ``output_idx``, in increasing
        timestep order.
        """
        out_latent, x_latent = self.init_step(output_idx, agents[0], agents_mask[0])

        max_idx = int(output_idx.max().item()) + 1
        input_indicies = self.get_input_indicies(max_idx)
        # A ``t in tensor`` test launches a comparison and syncs with the
        # device on every call. Do the per-timestep lookup against a host set.
        output_set = set(output_idx.flatten().tolist())

        for t_idx in range(1, max_idx):
            x_latent = self.propagate_layer(x_latent)
            if t_idx in input_indicies:
                x_latent = self.update_layer(
                    x_latent, agents[t_idx], agents_mask[t_idx]
                )
            if t_idx in output_set:
                out_latent.append(x_latent)
                if self.training:
                    # Truncate backprop at each emitted output.
                    x_latent = x_latent.detach().clone()

        return out_latent


class MotionEncoderDetach(MotionEncoder):
    """Motion encoder variant that cuts the gradient at every timestep.

    Timesteps not requested in ``output_idx`` are rolled forward with autograd
    disabled. In training mode, each emitted latent is also detached before
    the recurrence continues. As a result, each output only back-propagates
    through the propagate/update step that produced it.
    """

    def forward(
        self,
        output_idx: Tensor,
        agents: Tensor,
        agents_mask: Tensor,
    ) -> list[Tensor]:
        """Given an input of dims [Time Batch Tokens Channels]"""
        outputs, state = self.init_step(output_idx, agents[0], agents_mask[0])

        horizon = int(output_idx.max().item()) + 1
        update_steps = self.get_input_indicies(horizon)

        for step in range(1, horizon):
            emit = step in output_idx
            # Emitted steps keep the caller's grad mode; all others run with
            # autograd switched off.
            with torch.set_grad_enabled(emit and torch.is_grad_enabled()):
                state = self.propagate_layer(state)
                if step in update_steps:
                    state = self.update_layer(
                        state, agents[step], agents_mask[step]
                    )

            if not emit:
                continue

            outputs.append(state)
            if self.training:
                state = state.detach().clone()

        return outputs


class ScanEncoder(nn.Module):
    """Inclusive scan based encoder.

    Every (timestep, batch) pair is encoded independently. Learned latents
    cross-attend to that timestep's projected agent tokens and are then
    refined by self-attention. After each of the ``n_blocks`` blocks, the
    per-timestep latents are optionally accumulated over time (when
    ``temporal``):

    * with ``torch.cumsum`` (inclusive) when ``gamma == 1``;
    * otherwise with ``discounted_cumsum`` using factor ``gamma``.

    An optional post-scan self-attention block follows each accumulation.
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        num_heads: int = 4,
        n_blocks: int = 6,
        dropout: float = 0.0,
        self_attn_block: int = 1,
        post_self_attn: int = 0,
        cross_attn_block: int = 1,
        gamma: float = 2.0,
        temporal: bool = True,
        **_kwargs,
    ) -> None:
        super().__init__()
        self.gamma = gamma
        self.temporal = temporal
        self.latent_dim = num_latent_channels
        self.input_adapter = input_adapter

        # Projection of adapter features to the latent width, initialised to
        # the identity (eye weight, zero bias).
        self.agent_proj = nn.Linear(input_adapter.out_channels, self.latent_dim)
        nn.init.eye_(self.agent_proj.weight)
        nn.init.zeros_(self.agent_proj.bias)

        def make_cross_attn() -> nn.Module:
            """Build one cross-attention stage (a single layer, or a block)."""
            if cross_attn_block == 1:
                return pio.cross_attention_layer(
                    self.latent_dim, self.latent_dim, num_heads, dropout=dropout
                )
            return pio.CrossAttentionBlock(
                cross_attn_block,
                self.latent_dim,
                self.latent_dim,
                num_heads,
                dropout=dropout,
            )

        # Module construction order is unchanged, which keeps parameter
        # initialisation reproducible under a fixed seed.
        self.agent_attn = nn.ModuleList(make_cross_attn() for _ in range(n_blocks))

        self.self_attn = nn.ModuleList(
            pio.self_attention_block(
                self_attn_block, self.latent_dim, num_heads, dropout=dropout
            )
            for _ in range(n_blocks)
        )

        if post_self_attn > 0:
            self.post_self_attn = nn.ModuleList(
                pio.self_attention_block(
                    post_self_attn, self.latent_dim, num_heads, dropout=dropout
                )
                for _ in range(n_blocks)
            )
        else:
            self.post_self_attn = None

        self.latent = nn.Parameter(torch.empty(num_latents, num_latent_channels))
        nn.init.xavier_normal_(self.latent)

    def forward(self, output_idx: Tensor, agents: Tensor, agents_mask: Tensor):
        """Encode every timestep, scan over time and select ``output_idx[0]``.

        The first two dims of ``agents`` are read as (seq, batch). The latent
        stack is indexed along the seq dim with ``output_idx[0]``.
        NOTE(review): if ``output_idx[0]`` is a scalar, a single timestep
        (without a time dim) is returned — confirm callers pass a 2-D index.
        """
        seq, bsz = agents.shape[:2]

        agents, agents_mask = self.input_adapter(agents, agents_mask)
        # Fold time into batch: [S,B,...] -> [S*B,...]. Use reshape rather than
        # view so that a non-contiguous adapter output does not raise. It is
        # still a zero-copy view whenever view() would have worked.
        agents = self.agent_proj(agents.reshape(-1, *agents.shape[2:]))
        agents_mask = agents_mask.reshape(-1, *agents_mask.shape[2:])

        # Expand latent to [S,B,N,C]
        latent: Tensor = self.latent.expand(seq, bsz, -1, -1)
        blocks = zip(self.agent_attn, self.self_attn)
        for idx, (cross_attn, self_attn) in enumerate(blocks):
            latent = latent.reshape(-1, *latent.shape[2:])
            latent = cross_attn(latent, agents, agents_mask)
            latent = self_attn(latent)
            latent = latent.reshape(seq, bsz, *latent.shape[1:])

            if self.temporal:
                if self.gamma == 1:
                    latent = torch.cumsum(latent, dim=0)
                else:
                    latent = discounted_cumsum(latent, dim=0, gamma=self.gamma)

            if self.post_self_attn is not None:
                latent = latent.reshape(-1, *latent.shape[2:])
                latent = self.post_self_attn[idx](latent)
                latent = latent.reshape(seq, bsz, *latent.shape[1:])

        return latent[output_idx[0]]


class TransformerEncoder(nn.Module):
    """Transformer (encoder or decoder) summarizes each observation.
    Then causally masked transformer-encoder across time.

    Each timestep's agent tokens are summarised into ``num_latents`` tokens
    by ``ENCODER_DICT[encoder_type]``. A learned per-timestep embedding is
    added, and an ``nn.TransformerEncoder`` with a causal mask
    (``make_causal_mask``) runs over the flattened ``T * num_latents`` tokens.
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        n_enc_layers: int,
        n_blocks: int,
        max_time: int,
        encoder_type: Literal["encoder", "decoder"],
        num_heads: int = 4,
        **_kwargs,
    ):
        super().__init__()
        self.input_adapter = input_adapter
        self.latent_dim = num_latent_channels
        self.unit_enc = ENCODER_DICT[encoder_type](
            input_adapter.out_channels, num_latent_channels, n_enc_layers, num_heads
        )

        self.temporal_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                num_latent_channels, num_heads, batch_first=True
            ),
            n_blocks,
        )

        self.time_embed = nn.Parameter(
            torch.empty(max_time, num_latent_channels).normal_().clamp_(-2, 2)
        )
        self.latent = nn.Parameter(
            torch.empty(num_latents, num_latent_channels).normal_().clamp_(-2, 2)
        )

        # Set after the eval-mode fast path has run out of memory once; later
        # eval calls then go straight to the slow path.
        self._eval_fastpath_oom = False

    def _temporal_encode(self, feats: Tensor, mask: Tensor) -> Tensor:
        """Run the causal temporal transformer, with an eval-mode OOM fallback.

        In eval mode, PyTorch may take the fused "fast path" of
        ``nn.TransformerEncoder``, which can run out of memory when an
        explicit attention mask is given. See
        https://discuss.pytorch.org/t/nn-transformerencoder-oom-with-no-grad-eval-mask/210129

        On such an OOM, the call is retried with the MHA/transformer fast path
        disabled. The module's train/eval mode is never touched, so dropout
        stays off. OOMs raised in training mode propagate unchanged.
        """
        encoder = self.temporal_encoder
        if encoder.training:
            return encoder(feats, mask=mask)

        if not self._eval_fastpath_oom:
            try:
                return encoder(feats, mask=mask)
            except torch.OutOfMemoryError:
                self._eval_fastpath_oom = True
            # Retry outside the ``except`` block, so that the failed attempt's
            # traceback (and the tensors its frames hold) is released first.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        prev_fastpath = torch.backends.mha.get_fastpath_enabled()
        torch.backends.mha.set_fastpath_enabled(False)
        try:
            return encoder(feats, mask=mask)
        finally:
            torch.backends.mha.set_fastpath_enabled(prev_fastpath)

    def forward(self, output_idx: Tensor, agents: Tensor, agents_mask: Tensor):
        """Encode all timesteps causally and return those at ``output_idx[0]``.

        Output is time-first: the encoding is gathered as [B, T, N, C] and
        then indexed along T with ``output_idx[0]`` and transposed, which
        yields [T_out, B, N, C] when ``output_idx[0]`` is 1-D.
        """
        agents, agents_mask = self.input_adapter(agents, agents_mask)

        latent = self.latent.expand(agents.shape[1], -1, -1)
        unit_feats: Tensor = self.unit_enc(latent, agents, agents_mask)
        # Slice to the actual sequence length, so that inputs shorter than
        # max_time broadcast too. Longer inputs still fail in expand_as.
        n_time = unit_feats.shape[0]
        time_embed = self.time_embed[:n_time, None, None, :].expand_as(unit_feats)
        unit_feats = unit_feats + time_embed

        unit_feats = unit_feats.transpose(0, 1)  # [B,T,N,C]
        B, T, N, C = unit_feats.shape
        unit_feats = unit_feats.reshape(B, T * N, C)

        causal_mask = make_causal_mask(T, N, device=unit_feats.device)
        encoding = self._temporal_encode(unit_feats, causal_mask)

        encoding = encoding.reshape(B, T, N, C)[:, output_idx[0]]
        encoding = encoding.transpose(0, 1)  # [T,B,N,C]

        return encoding


class MambaEncoder(nn.Module):
    """N Mamba blocks in parallel to accumulate the latent state.

    Each timestep is summarised into ``num_latents`` tokens by
    ``ENCODER_DICT[encoder_type]``. Every latent token then has its own,
    independent stack of ``n_blocks`` Mamba2 layers, which scans that token's
    features over time.
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latents: int,
        num_latent_channels: int,
        n_enc_layers: int,
        n_blocks: int,
        encoder_type: Literal["encoder", "decoder"],
        num_heads: int = 4,
        m_state: int = 64,
        m_conv: int = 4,
        m_expand: int = 2,
        **_kwargs,
    ):
        super().__init__()
        self.input_adapter = input_adapter
        self.latent_dim = num_latent_channels
        self.unit_enc = ENCODER_DICT[encoder_type](
            input_adapter.out_channels, num_latent_channels, n_enc_layers, num_heads
        )

        # Mamba2's inner width is expand * d_model, split into heads of size
        # ``headdim``. Derive headdim from m_expand (previously a hard-coded 2),
        # so that there are num_heads heads for any expand factor. This is
        # identical for the default m_expand=2.
        headdim = m_expand * num_latent_channels // num_heads

        def make_mamba() -> nn.Module:
            """Build one Mamba2 layer with this encoder's hyper-parameters."""
            return Mamba2(
                num_latent_channels,
                d_state=m_state,
                d_conv=m_conv,
                expand=m_expand,
                headdim=headdim,
            )

        # One independent Sequential stack per latent token.
        self.temporal_encoder = nn.ModuleList(
            nn.Sequential(*(make_mamba() for _ in range(n_blocks)))
            for _ in range(num_latents)
        )

        self.latent = nn.Parameter(
            torch.empty(num_latents, num_latent_channels).normal_().clamp_(-2, 2)
        )

    def forward(self, output_idx: Tensor, agents: Tensor, agents_mask: Tensor):
        """Scan each latent token over time and return ``output_idx[0]`` steps.

        Raises:
            ValueError: if the unit encoder yields a different number of
                tokens than there are per-token Mamba stacks.
        """
        agents, agents_mask = self.input_adapter(agents, agents_mask)

        latent = self.latent.expand(agents.shape[1], -1, -1)
        unit_feats: Tensor = self.unit_enc(latent, agents, agents_mask)

        # [T,B,N,C] -> [N,B,T,C]: one [B,T,C] sequence per latent token
        per_token = unit_feats.permute(2, 1, 0, 3)

        # strict=True: a token/stack count mismatch must not silently truncate.
        encoding = torch.stack(
            [
                stack(seq)
                for stack, seq in zip(self.temporal_encoder, per_token, strict=True)
            ],
            dim=0,
        )

        # [N,B,T,C] -> [T,B,N,C]
        encoding = encoding[:, :, output_idx[0]].permute(2, 1, 0, 3)

        return encoding


class MotionPerceiver(nn.Module):
    """Temporal encoder plus perceiver decoder, built from config dicts.

    ``encoder`` must contain an ``"adapter"`` config, plus either a 1-based
    ``"version"`` or a ``"type"`` key (one of ``"recursive"``,
    ``"recursive-detach"``, ``"scan"``, ``"transformer"``, ``"mamba"``). The
    remaining keys are passed to the encoder class. ``decoder`` must contain
    an ``"adapter"`` config and ``"position_encoding_type"``. The remaining
    keys are passed to ``pio.PerceiverDecoder``.
    """

    def __init__(self, encoder: dict[str, Any], decoder: dict[str, Any]) -> None:
        super().__init__()
        # Setup Encoder
        encoder = deepcopy(encoder)
        in_adapt = ModuleInitConfig(**encoder.pop("adapter"))
        input_adapter = {"agent": AgentIA}[in_adapt.type.lower()](**in_adapt.args)

        enc_map: dict[str, Type[nn.Module]] = {
            "recursive": MotionEncoder,
            "recursive-detach": MotionEncoderDetach,
            "scan": ScanEncoder,
            "transformer": TransformerEncoder,
            "mamba": MambaEncoder,
        }

        if "version" in encoder:
            version = encoder.pop("version")
            encoders = list(enc_map.values())
            # Versions are 1-based. Reject 0 and negatives explicitly:
            # otherwise ``version - 1`` would silently wrap around to the end
            # of the list.
            if not 1 <= version <= len(encoders):
                raise IndexError(
                    f"Encoder version must be in [1, {len(encoders)}], got {version}"
                )
            enc_cls = encoders[version - 1]
        elif "type" in encoder:
            enc_cls = enc_map[encoder.pop("type")]
        else:
            raise KeyError("Missing 'version' or 'type' from encoder")

        self.encoder = enc_cls(input_adapter=input_adapter, **encoder)

        # Setup Decoder
        decoder = deepcopy(decoder)
        out_adapt = ModuleInitConfig(**decoder.pop("adapter"))
        if decoder["position_encoding_type"] == "fourier":
            out_adapt.args["num_output_channels"] = 4 * decoder["num_frequency_bands"]
        out_adapter: OutputAdapter = {
            "occupancy": OccupancyOA,
            "class-occupancy": ClassOccupancyOA,
        }[out_adapt.type.lower()](**out_adapt.args)

        self.decoder = pio.PerceiverDecoder(
            output_adapter=out_adapter,
            num_latent_channels=self.encoder.latent_dim,
            **decoder,
        )

    def forward(
        self, time_idx: Tensor, agents: Tensor, agents_valid: Tensor, **kwargs
    ) -> dict[str, Tensor]:
        """
        Format of x_latent, data and mask is T,B,N,C
        For the time being, the time_idxs to sample from are broadcast
        across the batch.

        ``agents``/``agents_valid`` have their dim 2 moved to the front and
        dim 0 moved to position 1 before encoding. The mask passed to the
        encoder is the inverse of ``agents_valid``. Extra ``**kwargs`` are
        accepted and ignored.

        Returns:
            Dict mapping each decoder output name to the per-time decoder
            outputs concatenated along dim 1.
        """
        # Named distinctly from the ignored **kwargs parameter.
        enc_inputs = {
            "output_idx": time_idx[0],
            "agents": agents.moveaxis((2, 0), (0, 1)),
            "agents_mask": ~agents_valid.moveaxis((2, 0), (0, 1)).bool(),
        }

        x_latents: list[Tensor] = self.encoder(**enc_inputs)

        # [T,{K:[B,H,W]}]
        x_logits = [self.decoder(x_latent) for x_latent in x_latents]

        # {K:[B,T,H,W]}
        out_logits = {
            name: torch.cat([x[name] for x in x_logits], dim=1)
            for name in self.decoder.names
        }

        return out_logits
