"""Agent embedding modules (LSTM only)."""

from __future__ import annotations

from typing import Dict, Iterable, Type

import torch
import torch.nn as nn


class LSTMAgentEmbedding(nn.Module):
    """Encode each agent's trajectory with an LSTM and project the last step.

    Every agent's time series is processed independently. The agent axis is
    folded into the batch axis, the LSTM runs over the time axis, and the LSTM
    output at the final time step is projected to ``embed_dim``.

    Args:
        input_dim: Size of the per-agent, per-time-step feature vector.
        hidden_dim: Hidden state size of the LSTM.
        embed_dim: Size of the output embedding produced for each agent.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout applied between stacked LSTM layers. It is forced to
            0.0 for a single layer, because ``nn.LSTM`` only applies dropout
            between layers.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        embed_dim: int,
        num_layers: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        # nn.LSTM dropout acts between layers only; passing a non-zero value
        # with a single layer has no effect, so pass 0.0 in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        """Embed every agent in ``seq``.

        Args:
            seq: Tensor of shape ``(batch, seq_len, num_agents, input_dim)``.

        Returns:
            Tensor of shape ``(batch, num_agents, embed_dim)``.

        Raises:
            ValueError: If ``seq`` is not 4-dimensional.
        """
        if seq.dim() != 4:
            raise ValueError("Expected input of shape (batch, seq_len, num_agents, input_dim)")

        batch_size, seq_len, num_agents, input_dim = seq.shape
        # Bring the agent axis next to the batch axis *before* flattening.
        # Reshaping the (batch, seq_len, num_agents, ...) layout directly would
        # interleave time steps and agents into meaningless sequences.
        per_agent = seq.permute(0, 2, 1, 3).reshape(batch_size * num_agents, seq_len, input_dim)

        outputs, _ = self.lstm(per_agent)
        last_hidden = outputs[:, -1, :]  # LSTM output at the final time step
        embeddings = self.proj(last_hidden)
        # Row i corresponds to batch (i // num_agents), agent (i % num_agents).
        return embeddings.view(batch_size, num_agents, -1)


# Registered embedding names mapped to the nn.Module classes that
# build_agent_embedding instantiates.
_AGENT_EMBED_REGISTRY: Dict[str, Type[nn.Module]] = dict(lstm=LSTMAgentEmbedding)


def register_agent_embedding(name: str, cls: Type[nn.Module]) -> None:
    """Register ``cls`` under ``name`` for later use by ``build_agent_embedding``.

    The name is stored lower-cased because ``build_agent_embedding``
    lower-cases the requested name before looking it up. A mixed-case key
    would otherwise never be found.

    Args:
        name: Registry name; compared case-insensitively.
        cls: The ``nn.Module`` subclass to instantiate for this name.

    Raises:
        ValueError: If an embedding is already registered under ``name``
            (ignoring case).
    """
    key = name.lower()
    if key in _AGENT_EMBED_REGISTRY:
        raise ValueError(f"Agent embedding '{name}' already registered")
    _AGENT_EMBED_REGISTRY[key] = cls


def available_agent_embeddings() -> Iterable[str]:
    """Return the names under which agent embeddings are registered.

    The result is the registry's live keys view, so it reflects any
    registrations made after the call.
    """
    registry = _AGENT_EMBED_REGISTRY
    return registry.keys()


def build_agent_embedding(name: str, **kwargs) -> nn.Module:
    """Instantiate the agent embedding registered under ``name``.

    The requested name is lower-cased before lookup. All keyword arguments
    are forwarded unchanged to the registered class's constructor.

    Raises:
        ValueError: If no embedding is registered under the lower-cased name.
    """
    key = name.lower()
    try:
        factory = _AGENT_EMBED_REGISTRY[key]
    except KeyError:
        raise ValueError(f"Unknown agent embedding '{key}'") from None
    return factory(**kwargs)


# Short alias for LSTMAgentEmbedding.
LSTM = LSTMAgentEmbedding

# Names exported by ``from <module> import *``.
__all__ = [
    # Embedding classes
    "LSTMAgentEmbedding",
    # Registry helpers
    "register_agent_embedding",
    "available_agent_embeddings",
    "build_agent_embedding",
    # Aliases
    "LSTM",
]
