from typing import Any, Type

import torch
from torch import Tensor, nn

from . import perceiver_io as pio
from ._iadapter import AgentIA, InputAdapter
from .goal_encoder import (
    Encoder,
    Encoder2,
    EncodingMethod,
    MambaEncoder,
    RNNEncoder,
    ScanEncoder,
    TransformerEncoder,
)


class Decoder(nn.Module):
    """Score associations between agents and targets through the latent state.

    Each adapted agent token cross-attends into the latent state to obtain an
    agent-specific latent vector. That vector and each adapted target token are
    linearly projected into a shared ``latent_target_proj_dim`` space, and their
    dot product gives one score per (agent, target) pair. Scores of masked
    targets are set to ``-inf``; unless ``logit_out`` is set, the scores are
    softmax-normalized over the target axis.
    """

    def __init__(
        self,
        input_adapter: InputAdapter,
        num_latent_channels: int,
        latent_target_proj_dim: int = 128,
        num_heads: int = 4,
        dropout: float = 0.0,
        logit_out: bool = False,
        **_kwargs,
    ) -> None:
        """Build the agent->latent cross-attention and the two projections.

        Args:
            input_adapter: Adapter applied to both agents and targets; its
                ``out_channels`` is the feature width of the adapted tokens.
            num_latent_channels: Channel width of the latent state.
            latent_target_proj_dim: Width of the shared space in which agent
                latents and targets are compared.
            num_heads: Passed to ``pio.cross_attention_layer``.
            dropout: Passed to ``pio.cross_attention_layer``.
            logit_out: If True, return the masked raw scores instead of
                softmax probabilities.
            **_kwargs: Ignored.
        """
        super().__init__()
        self.agent_adapter = input_adapter
        self.logit_out = logit_out
        self.agent_query = pio.cross_attention_layer(
            input_adapter.out_channels, num_latent_channels, num_heads, dropout
        )

        self.agent_latent_proj = nn.Linear(
            input_adapter.out_channels, latent_target_proj_dim, bias=False
        )
        self.target_proj = nn.Linear(
            input_adapter.out_channels, latent_target_proj_dim, bias=False
        )

    def match_latent_target(
        self,
        agent_latent: Tensor,
        targets: Tensor,
        targets_mask: Tensor,
    ) -> Tensor:
        """Score every agent latent against every target.

        Args:
            agent_latent: ``[B, Q, C]`` agent latents.
            targets: ``[B, K, C]`` adapted target tokens.
            targets_mask: ``[B, K]`` boolean mask; True marks targets whose
                score is forced to ``-inf``.

        Returns:
            ``[B, Q, K]`` scores: masked logits if ``logit_out``, otherwise
            softmax probabilities over ``K``. If every target of a row is
            masked, that row is all ``-inf`` and its softmax is NaN.
        """
        agent_proj = self.agent_latent_proj(agent_latent)
        target_proj = self.target_proj(targets)

        # [B, Q, P] x [B, K, P] -> [B, Q, K]
        attn = torch.einsum("bqc,bkc->bqk", agent_proj, target_proj)
        # [B, 1, K] broadcasts over the agent axis; out-of-place masked_fill
        # avoids materializing the expanded mask and an in-place index_put_.
        attn = attn.masked_fill(targets_mask.unsqueeze(1), -torch.inf)

        if self.logit_out:
            return attn

        return torch.softmax(attn, dim=-1)

    def forward(
        self,
        latents: Tensor,
        agents: Tensor,
        agents_mask: Tensor,
        targets: Tensor,
        targets_mask: Tensor,
    ) -> dict[str, Tensor]:
        """Compute agent-to-target association scores.

        If ``latents`` has rank 4, all inputs are treated as having leading
        ``[seq, batch]`` dims: they are flattened over those two dims, decoded,
        and the result is reshaped back to ``[seq, batch, ...]``. Inputs of any
        other rank are decoded as-is.

        Returns:
            ``{"agent_target": scores}`` where ``scores`` comes from
            ``match_latent_target``.
        """
        batch_sequence = latents.ndim == 4
        if batch_sequence:
            seq, batch = latents.shape[:2]
            # flatten(0, 1) == reshape(-1, *x.shape[2:]) but also works for
            # zero-element tensors, where reshape's -1 is ambiguous.
            latents = latents.flatten(0, 1)
            agents = agents.flatten(0, 1)
            agents_mask = agents_mask.flatten(0, 1)
            targets = targets.flatten(0, 1)
            targets_mask = targets_mask.flatten(0, 1)

        # The adapted agent mask is not needed: the cross-attention below is
        # called without a mask, so every agent produces a row of scores.
        agents, _ = self.agent_adapter(agents, agents_mask)
        agent_latent = self.agent_query(agents, latents)

        # Targets go through the same adapter as agents.
        targets, targets_mask = self.agent_adapter(targets, targets_mask)
        matching = self.match_latent_target(agent_latent, targets, targets_mask)

        if batch_sequence:
            matching = matching.reshape(seq, batch, *matching.shape[1:])

        return {"agent_target": matching}


class GoalPerciever(nn.Module):
    """Temporal encoder followed by an agent/target assignment :class:`Decoder`.

    An input adapter is shared by the encoder and the decoder. The encoder
    class is selected by ``encoder["type"]``.
    """

    def __init__(
        self,
        encoder: dict[str, Any],
        decoder: dict[str, Any],
        input_adapter: dict[str, Any],
        output_adapter: dict[str, Any],
    ):
        """Build the input adapter, the encoder and the decoder from configs.

        Args:
            encoder: Encoder config. ``"type"`` (default ``"v1"``) selects the
                encoder class; every entry (including ``"type"``) is forwarded
                as a keyword argument to its constructor. ``"transformer2"`` is
                an alias for ``"transformer"`` with ``encoding_method`` set to
                ``EncodingMethod.fused``. The caller's dict is not modified.
            decoder: Extra keyword arguments for :class:`Decoder`.
            input_adapter: ``{"type": ..., "args": {...}}``; only ``"agent"``
                (case-insensitive) is supported.
            output_adapter: Unused by this class.

        Raises:
            KeyError: If the input-adapter or encoder type is unknown.
        """
        super().__init__()
        adapter_types = {"agent": AgentIA}
        adapter_name = input_adapter["type"].lower()
        try:
            adapter_cls = adapter_types[adapter_name]
        except KeyError:
            raise KeyError(
                f"Unknown input adapter type {adapter_name!r}; "
                f"expected one of {sorted(adapter_types)}"
            ) from None
        ia_inst = adapter_cls(**input_adapter["args"])

        # Shallow copy so the alias handling below never mutates the caller's
        # config (which may be reused or logged as hyperparameters).
        encoder = dict(encoder)
        enc_type = encoder.get("type", "v1")
        if enc_type == "transformer2":
            enc_type = "transformer"
            encoder["encoding_method"] = EncodingMethod.fused

        encoder_types: dict[str, Type[Encoder]] = {
            "v1": Encoder,
            "v2": Encoder2,
            "scan": ScanEncoder,
            "recurrent": RNNEncoder,
            "transformer": TransformerEncoder,
            "mamba": MambaEncoder,
        }
        try:
            encoder_type = encoder_types[enc_type]
        except KeyError:
            raise KeyError(
                f"Unknown encoder type {enc_type!r}; "
                f"expected one of {sorted(encoder_types)}"
            ) from None
        self.encoder = encoder_type(input_adapter=ia_inst, **encoder)

        self.decoder = Decoder(
            input_adapter=ia_inst,
            num_latent_channels=self.encoder.latent_dim,
            **decoder,
        )

    def forward(
        self,
        agents: Tensor,
        agents_valid: Tensor,
        targets: Tensor,
        targets_valid: Tensor,
        **_ignore,
    ) -> dict[str, Tensor]:
        """Encode the scene over time and decode agent/target assignments.

        Inputs are reordered with ``moveaxis((2, 0), (0, 1))`` so the time
        axis leads (``[B, N, T, ...] -> [T, B, N, ...]``), and validity flags
        are inverted into padding masks (True = invalid).

        Returns:
            The decoder's output dict with every value permuted from
            ``[T, B, N, K]`` back to ``[B, N, T, K]``.
        """
        agents = agents.moveaxis((2, 0), (0, 1))
        agents_mask = ~agents_valid.moveaxis((2, 0), (0, 1)).bool()
        targets = targets.moveaxis((2, 0), (0, 1))
        targets_mask = ~targets_valid.moveaxis((2, 0), (0, 1)).bool()

        with torch.profiler.record_function("temporal-encoder"):
            latents = self.encoder(agents, agents_mask, targets, targets_mask)

        with torch.profiler.record_function("assign-decoder"):
            matchings: dict[str, Tensor] = self.decoder(
                latents, agents, agents_mask, targets, targets_mask
            )

        # [T,B,N,C] -> [B,N,T,C]
        return {k: v.permute(1, 2, 0, 3) for k, v in matchings.items()}
