import math

import torch
from torch import nn


class GCN(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        hidden_size: int | None = None,
        context_size: int | None = None,
        *,
        relational: bool = False,
        dropout: float = 0.0,
        batch_norm: bool = False,
        layer_norm: bool = False,
        layer_norm_eps: float = 1e-5,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.hidden_size = hidden_size or embed_dim
        self.context_size = context_size or embed_dim

        self.message = nn.Linear(self.hidden_size, self.embed_dim, bias=False)
        self.update = nn.Linear(self.embed_dim, self.hidden_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        self.q_hidden = nn.Linear(self.hidden_size, self.embed_dim, bias=False)
        self.q_hidden_dropout = nn.Dropout(dropout)

        self.q_ctx = nn.Linear(self.context_size, self.embed_dim, bias=False)
        self.q_ctx_dropout = nn.Dropout(dropout)

        self.k_hidden = nn.Linear(self.hidden_size, self.embed_dim, bias=False)
        self.k_hidden_dropout = nn.Dropout(dropout)

        self.k_ctx = nn.Linear(self.context_size, self.embed_dim, bias=False)
        self.k_ctx_dropout = nn.Dropout(dropout)

        if relational:
            self.rel_proj = nn.Linear(self.embed_dim, 1, bias=False)
        else:
            self.rel_proj = None

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(embed_dim)
        else:
            self.batch_norm = None
        if layer_norm:
            self.layer_norm = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        else:
            self.layer_norm = None

        self.activation = nn.Sigmoid()

    def adjustment(self, h: torch.Tensor, context: torch.Tensor):
        """Compute the adjustment factor for the GCN.

        Args:
            h (torch.Tensor): The input tensor of shape (batch_size, num_nodes, embed_dim).
            context (torch.Tensor): The context tensor of shape (batch_size, length, embed_dim).

        Returns:
            torch.Tensor: The adjustment factor of shape (batch_size, num_nodes, num_nodes).
        """
        q: torch.Tensor = self.q_ctx(context)
        k: torch.Tensor = self.k_ctx(context)

        q = self.q_hidden_dropout(q)
        k = self.k_hidden_dropout(k)

        hidden_q: torch.Tensor = self.q_hidden(h)
        hidden_k: torch.Tensor = self.k_hidden(h)

        hidden_q = self.q_hidden_dropout(hidden_q)
        hidden_k = self.k_hidden_dropout(hidden_k)

        q = hidden_q @ q.transpose(-2, -1)
        k = k @ hidden_k.transpose(-2, -1)

        attn_weights = q @ k
        attn_weights = attn_weights / math.sqrt(self.embed_dim)
        attn_weights = torch.relu(attn_weights)

        return attn_weights

    def forward(
        self,
        h: torch.Tensor,
        adj: torch.Tensor,
        rel: torch.Tensor | None = None,
        *,
        context: torch.Tensor | None = None,
    ):
        h = h.to(self.message.weight.device)
        adj = adj.to(self.message.weight.device)
        rel = rel.to(self.message.weight.device) if rel is not None else None

        if h.shape[-1] != self.embed_dim:
            raise ValueError(
                f"The input to the GCN does not match the embedding dimension. "
                f"Got {h.shape[-1]}D tensor."
            )
        if adj.shape[-2] != adj.shape[-1]:
            raise ValueError(
                f"The adjacency matrix to the GCN is not square. "
                f"Got {adj.shape[-2]}x{adj.shape[-1]} tensor."
            )
        if adj.shape[-2] != h.shape[-2]:
            raise ValueError(
                f"The adjacency matrix to the GCN does not match the input. "
                f"Got {adj.shape[-2]}x{adj.shape[-1]} tensor and "
                f"{h.shape[-2]}x{h.shape[-1]} tensor."
            )
        if self.rel_proj is not None and rel is None:
            raise ValueError(
                "The relation input to the GCN is required when using relational GCN."
            )
        if rel is not None and (
            rel.dtype in (torch.int8, torch.int16, torch.int32, torch.int64)
        ):
            raise ValueError(
                f"The relation input to the GCN must be an integer tensor. "
                f"Got {rel.dtype} tensor."
            )
        if rel is not None and rel.shape[-1] != self.embed_dim:
            raise ValueError(
                f"The relation input to the GCN does not match the embedding dimension. "
                f"Got {rel.shape[-1]}D tensor."
            )
        if rel is not None and adj.max() >= rel.shape[-2]:
            raise ValueError(
                f"The relation input to the GCN does not match the adjacency matrix. "
                f"Got {adj.max()} index, but relation matrix has only "
                f"{rel.shape[-2]} relations."
            )

        if self.rel_proj is not None:
            adj_shape = adj.shape

            rel = self.rel_proj(rel)
            assert rel is not None, "Relation projection failed."
            rel = rel.squeeze(-1) / math.sqrt(self.embed_dim)

            rel = rel.view(-1, rel.shape[-1])
            adj = adj.view(-1, adj.shape[-2], adj.shape[-1])

            batch_idx = torch.arange(adj.shape[0]).view(-1, 1, 1)
            adj = adj.bool().to(rel.dtype) + rel[batch_idx, adj]

            adj = adj.view(*adj_shape)
        else:
            adj = adj.bool().to(h.dtype)

        if context is not None:
            gate = self.adjustment(h, context)
            adj = adj * gate

        adj = adj + torch.eye(adj.shape[-2], device=adj.device)
        deg = torch.sum(adj, dim=-1)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0.0
        deg_inv_sqrt = torch.diag_embed(deg_inv_sqrt)
        adj = deg_inv_sqrt @ adj @ deg_inv_sqrt

        h = self.message(h)
        out = adj @ h
        out = self.activation(self.update(out))
        if self.batch_norm is not None:
            out = self.batch_norm(out)
        if self.layer_norm is not None:
            out = self.layer_norm(out)
        out = self.dropout(out)
        return h + out
