"""Group embedding modules with CVAE implementation."""

from __future__ import annotations

from typing import Dict, Iterable, Type

import torch
import torch.nn as nn
import torch.nn.functional as F


class CVAE(nn.Module):
    """Conditional variational autoencoder over action vectors.

    The encoder consumes a state concatenated with an action vector (along the
    last dimension) and produces the mean and log-variance of a diagonal
    Gaussian over a ``latent_dim``-sized code. The decoder consumes the state
    concatenated with a latent code and produces ``action_dim`` outputs
    (returned as ``logits`` by :meth:`forward`).
    """

    def __init__(self, state_dim: int, action_dim: int, latent_dim: int, hidden_dim: int = 128):
        """Create the encoder trunk, the two Gaussian heads and the decoder.

        Args:
            state_dim: Width of the state vector.
            action_dim: Width of the action vector and of the decoder output.
            latent_dim: Width of the latent code.
            hidden_dim: Width of every hidden layer.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.action_dim = action_dim

        # Encoder trunk: (state ++ action) -> hidden features.
        encoder_layers = [
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Two linear heads on top of the trunk: mean and log-variance.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder: (state ++ latent) -> action_dim outputs, no final activation.
        decoder_layers = [
            nn.Linear(state_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    @staticmethod
    def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Draw ``mu + eps * exp(0.5 * log_var)`` with ``eps`` standard normal.

        The noise tensor has the same shape, dtype and device as ``log_var``.
        """
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encode(self, state: torch.Tensor, action_onehot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, log_var)`` of the latent Gaussian for the given pair.

        ``state`` and ``action_onehot`` are concatenated along the last
        dimension before being passed through the encoder trunk.
        """
        features = self.encoder(torch.cat((state, action_onehot), dim=-1))
        return self.fc_mu(features), self.fc_log_var(features)

    def decode(self, state: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
        """Return the decoder output for ``state`` concatenated with ``latent``."""
        decoder_input = torch.cat((state, latent), dim=-1)
        return self.decoder(decoder_input)

    def forward(self, state: torch.Tensor, action_onehot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample a latent code, and decode.

        Returns:
            ``(logits, mu, log_var)`` where ``logits`` is the decoder output
            for a latent sampled via :meth:`reparameterize`.
        """
        mu, log_var = self.encode(state, action_onehot)
        sampled_latent = self.reparameterize(mu, log_var)
        return self.decode(state, sampled_latent), mu, log_var


class CVAEGroupEmbedding(nn.Module):
    """Training and inference wrapper around :class:`CVAE` for integer actions.

    Holds the CVAE, an Adam optimizer over its parameters, and the weight
    applied to the KL term of the training loss. Action indices are converted
    to one-hot vectors of width ``action_dim`` before reaching the model.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        latent_dim: int,
        lr: float,
        kl_weight: float,
        hidden_dim: int = 128,
        device: torch.device | str = "cpu",
    ) -> None:
        """Build the CVAE on ``device`` together with its Adam optimizer.

        Args:
            state_dim: Width of the state vector.
            action_dim: Number of discrete actions (one-hot width).
            latent_dim: Width of the latent code.
            lr: Learning rate passed to Adam.
            kl_weight: Multiplier on the KL term in :meth:`train_batch`.
            hidden_dim: Hidden width forwarded to :class:`CVAE`.
            device: Device the model lives on and inputs are moved to.
        """
        super().__init__()
        self.device = torch.device(device)
        self.latent_dim = latent_dim
        self.action_dim = action_dim
        self.model = CVAE(state_dim, action_dim, latent_dim, hidden_dim).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.kl_weight = kl_weight

    def _prepare_inputs(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Move inputs to ``self.device`` and return ``(states, actions, one_hot)``."""
        states = states.to(self.device)
        # F.one_hot and class-index F.cross_entropy only accept int64 indices,
        # so other integer dtypes (e.g. int32) are cast here instead of raising.
        actions = actions.to(device=self.device, dtype=torch.long)
        action_onehot = F.one_hot(actions, num_classes=self.action_dim).float()
        return states, actions, action_onehot

    def encode(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Return the encoder mean ``mu`` for each (state, action index) pair.

        The log-variance produced by the encoder is discarded.
        """
        states, _, action_onehot = self._prepare_inputs(states, actions)
        mu, _ = self.model.encode(states, action_onehot)
        return mu

    def sample_proto(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full CVAE forward pass.

        Returns:
            ``(logits, mu, log_var)`` exactly as produced by the wrapped model.
        """
        states, _, action_onehot = self._prepare_inputs(states, actions)
        logits, mu, log_var = self.model(states, action_onehot)
        return logits, mu, log_var

    def train_batch(self, states: torch.Tensor, actions: torch.Tensor) -> None:
        """Take one Adam step on ``cross_entropy + kl_weight * KL``.

        The KL term is the closed-form divergence between the encoder's
        diagonal Gaussian and a standard normal, summed over all elements and
        divided by the batch size (``states.size(0)``). An empty batch is
        skipped without touching the optimizer.
        """
        # With zero rows the KL term divides by zero and the optimizer would
        # still step; there is nothing to learn from, so do nothing.
        if states.size(0) == 0:
            return
        states, actions, action_onehot = self._prepare_inputs(states, actions)
        self.optimizer.zero_grad()
        logits, mu, log_var = self.model(states, action_onehot)
        recon_loss = F.cross_entropy(logits, actions)
        kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / states.size(0)
        (recon_loss + self.kl_weight * kld).backward()
        self.optimizer.step()

    def save(self, path: str) -> None:
        """Write the model weights to ``path`` under the key ``"model"``.

        Optimizer state is not saved.
        """
        torch.save({"model": self.model.state_dict()}, path)

    def load(self, path: str) -> None:
        """Load model weights written by :meth:`save`, mapped onto ``self.device``."""
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state["model"])


# Name -> implementing class. build_group_embedding lowercases the requested
# name before looking it up here, so keys must be lowercase to be reachable.
_GROUP_EMBED_REGISTRY: Dict[str, Type[nn.Module]] = dict(cvae=CVAEGroupEmbedding)


def register_group_embedding(name: str, cls: Type[nn.Module]) -> None:
    """Add ``cls`` to the group-embedding registry under ``name``.

    The key is stored lowercased: ``build_group_embedding`` lowercases the
    requested name before lookup, so a key containing uppercase characters
    could otherwise never be built.

    Args:
        name: Registry name; matched case-insensitively.
        cls: Class to instantiate when this name is built.

    Raises:
        ValueError: If the (lowercased) name is already registered.
    """
    key = name.lower()
    if key in _GROUP_EMBED_REGISTRY:
        raise ValueError(f"Group embedding '{name}' already registered")
    _GROUP_EMBED_REGISTRY[key] = cls


def available_group_embeddings() -> Iterable[str]:
    """Return the names currently in the group-embedding registry.

    The result is the registry dict's own keys view, not a copy.
    """
    registered_names = _GROUP_EMBED_REGISTRY.keys()
    return registered_names


def build_group_embedding(name: str, **kwargs) -> nn.Module:
    """Instantiate the group embedding registered under ``name``.

    The lookup uses ``name.lower()``; all keyword arguments are forwarded
    unchanged to the registered class.

    Raises:
        ValueError: If no class is registered under the lowercased name.
    """
    key = name.lower()
    if key in _GROUP_EMBED_REGISTRY:
        factory = _GROUP_EMBED_REGISTRY[key]
        return factory(**kwargs)
    raise ValueError(f"Unknown group embedding '{key}'")


# Names exported by a star-import of this module; the underscore-prefixed
# registry dict is not listed.
__all__ = [
    "CVAE",
    "CVAEGroupEmbedding",
    "register_group_embedding",
    "available_group_embeddings",
    "build_group_embedding",
]
