from __future__ import annotations

from typing import Dict, Iterable, Sequence, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttentionLayer(nn.Module):
    """Multi-head dot-product attention of every agent over its neighbors.

    For each agent the layer gathers ``num_neighbors`` neighbor representations
    through ``adjacency``, scores them against the agent's own projection,
    softmax-normalises the scores per head, averages the attended values over
    the heads and applies a final ReLU-activated linear projection.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        head_dim: Width of each attention head.
        output_dim: Size of the last dimension of the returned features.
        num_heads: Number of attention heads.
        num_neighbors: Number of neighbor slots per agent, i.e. the size of
            dimension 2 of ``adjacency``.
    """

    def __init__(
        self,
        input_dim: int,
        head_dim: int,
        output_dim: int,
        num_heads: int,
        num_neighbors: int,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        total_dim = head_dim * num_heads
        self.agent_proj = nn.Linear(input_dim, total_dim)
        self.neighbor_proj = nn.Linear(input_dim, total_dim)
        self.neighbor_hidden_proj = nn.Linear(input_dim, total_dim)
        # Heads are averaged (not concatenated) before this projection, so it
        # consumes ``head_dim`` features rather than ``total_dim``.
        self.out_proj = nn.Linear(head_dim, output_dim)
        self.num_neighbors = num_neighbors

    def _split_heads(self, projected: torch.Tensor, batch_size: int, num_agents: int) -> torch.Tensor:
        """Reshape (batch, agents, neighbors, heads * head_dim) to (batch, agents, heads, neighbors, head_dim)."""
        per_head = projected.view(batch_size, num_agents, self.num_neighbors, self.num_heads, self.head_dim)
        return per_head.permute(0, 1, 3, 2, 4)

    def forward(self, feats: torch.Tensor, adjacency: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend from each agent to its neighbors.

        Args:
            feats: Agent features of shape ``(batch, num_agents, input_dim)``.
            adjacency: Tensor of shape
                ``(batch, num_agents, num_neighbors, num_agents)``; row
                ``[b, a, n, :]`` holds the weights that combine agent features
                into the ``n``-th neighbor representation of agent ``a``.
                Any dtype convertible to ``feats.dtype`` is accepted (e.g. an
                integer or bool one-hot tensor).

        Returns:
            A tuple ``(out, att)`` where ``out`` has shape
            ``(batch, num_agents, output_dim)`` and ``att`` holds the attention
            weights with shape ``(batch, num_agents, num_heads, num_neighbors)``.
        """
        batch_size, num_agents, _ = feats.shape

        # einsum needs both operands in the same dtype; adjacency is often
        # built as an integer/bool one-hot tensor (or stays float32 while feats
        # are half precision), which would otherwise raise a dtype error.
        adjacency = adjacency.to(feats.dtype)
        neighbor_repr = torch.einsum("banm,bmd->band", adjacency, feats)

        # Query: (batch, agents, heads, 1, head_dim).
        agent_head = self.agent_proj(feats).view(batch_size, num_agents, self.num_heads, self.head_dim)
        agent_head = agent_head.unsqueeze(3)

        # Keys: (batch, agents, heads, neighbors, head_dim).
        neighbor_head = self._split_heads(self.neighbor_proj(neighbor_repr), batch_size, num_agents)

        # Scores: (batch, agents, heads, 1, neighbors), normalised over neighbors.
        att_logits = torch.matmul(agent_head, neighbor_head.transpose(-1, -2))
        att = F.softmax(att_logits, dim=-1)

        # Values: (batch, agents, heads, neighbors, head_dim).
        neighbor_hidden = self._split_heads(self.neighbor_hidden_proj(neighbor_repr), batch_size, num_agents)

        # Weighted sum of values, then average over heads and drop the
        # singleton query axis -> (batch, agents, head_dim).
        context = torch.matmul(att, neighbor_hidden)
        context = context.mean(dim=2).squeeze(2)

        out = F.relu(self.out_proj(context))
        return out, att.squeeze(3)


class CoLightGATQNetwork(nn.Module):
    """Q-network made of an MLP encoder, stacked attention layers and a linear head.

    Args:
        len_feature: Size of the last dimension of the input features.
        mlp_layers: Output widths of the ReLU-activated encoder layers; may be
            empty, in which case the raw features feed the attention stack.
        cnn_layers: One ``(head_dim, output_dim)`` pair per attention layer.
        num_actions: Number of Q-values produced per agent.
        num_agents: Accepted for interface compatibility; not used here.
        num_neighbors: Neighbor slots per agent, forwarded to every attention layer.
        num_heads: Attention heads per attention layer.
    """

    def __init__(
        self,
        len_feature: int,
        mlp_layers: Sequence[int],
        cnn_layers: Sequence[Sequence[int]],
        num_actions: int,
        num_agents: int,
        num_neighbors: int,
        num_heads: int = 5,
    ) -> None:
        super().__init__()

        # Encoder: consecutive widths are chained pairwise,
        # len_feature -> mlp_layers[0] -> mlp_layers[1] -> ...
        encoder_widths = [len_feature, *mlp_layers]
        self.mlp_layers = nn.ModuleList(
            [
                nn.Linear(fan_in, fan_out)
                for fan_in, fan_out in zip(encoder_widths, encoder_widths[1:])
            ]
        )

        # Attention stack: each layer consumes the previous layer's output
        # width, starting from the encoder's final width.
        attention_specs = [(head_width, out_width) for head_width, out_width in cnn_layers]
        attention_inputs = [encoder_widths[-1], *(out_width for _, out_width in attention_specs)]
        self.att_layers = nn.ModuleList(
            [
                MultiHeadAttentionLayer(
                    input_dim=in_width,
                    head_dim=head_width,
                    output_dim=out_width,
                    num_heads=num_heads,
                    num_neighbors=num_neighbors,
                )
                for in_width, (head_width, out_width) in zip(attention_inputs, attention_specs)
            ]
        )

        # Q head reads whatever width the last stage produced.
        self.output_layer = nn.Linear(attention_inputs[-1], num_actions)

    def forward(self, feats: torch.Tensor, adjacency: torch.Tensor) -> torch.Tensor:
        """Compute Q-values from agent features.

        Args:
            feats: Input features whose last dimension is ``len_feature``.
            adjacency: Neighbor-selection tensor handed unchanged to every
                attention layer.

        Returns:
            Tensor whose last dimension is ``num_actions``.
        """
        hidden = feats
        for linear in self.mlp_layers:
            hidden = torch.relu(linear(hidden))

        for attention in self.att_layers:
            # The attention weights returned alongside the features are not needed here.
            hidden, _unused_weights = attention(hidden, adjacency)

        q_values = self.output_layer(hidden)
        return q_values


# Name -> class registry of skill Q-network implementations.
# build_skill_q_network lowercases the requested name before looking it up here.
_SKILL_Q_REGISTRY: Dict[str, Type[nn.Module]] = {
    "colight_gat": CoLightGATQNetwork,
}


def register_skill_q_network(name: str, cls: Type[nn.Module]) -> None:
    """Register ``cls`` as a buildable skill Q-network under ``name``.

    Names are case-insensitive. The key is stored lowercased because
    ``build_skill_q_network`` lowercases the requested name before its lookup;
    a key stored with uppercase characters could never be found again.

    Args:
        name: Name to register the network under (matched ignoring case).
        cls: ``nn.Module`` subclass to instantiate for that name.

    Raises:
        ValueError: If a network is already registered under ``name``
            (ignoring case).
    """
    key = name.lower()
    if key in _SKILL_Q_REGISTRY:
        raise ValueError(f"Skill Q-network '{name}' already registered")
    _SKILL_Q_REGISTRY[key] = cls


def available_skill_q_networks() -> Iterable[str]:
    """Return the names of all currently registered skill Q-networks.

    The returned object is the registry's own key view, so it also reflects
    networks registered after this call.
    """
    registered_names = _SKILL_Q_REGISTRY.keys()
    return registered_names


def build_skill_q_network(name: str, **kwargs) -> nn.Module:
    """Instantiate the skill Q-network registered under ``name``.

    Args:
        name: Registered network name; it is lowercased before the lookup.
        **kwargs: Forwarded unchanged to the network class constructor.

    Returns:
        A new instance of the registered network class.

    Raises:
        ValueError: If the lowercased name is not in the registry.
    """
    key = name.lower()
    try:
        network_cls = _SKILL_Q_REGISTRY[key]
    except KeyError:
        # Report the normalised name, without chaining the internal KeyError.
        raise ValueError(f"Unknown skill Q-network '{key}'") from None
    return network_cls(**kwargs)


# Public API of this module; the registry dict itself is intentionally not exported.
__all__ = [
    "MultiHeadAttentionLayer",
    "CoLightGATQNetwork",
    "register_skill_q_network",
    "available_skill_q_networks",
    "build_skill_q_network",
]
