import math
from typing import Tuple

import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch.linalg import solve_triangular


class ATT(nn.Module):
    """ATT Model.

    Embeds leaf and internal nodes separately, refines both sets of node
    features with a stack of residual blocks, and reads out one
    ``n_internal``-wide prediction row per node.
    """

    def __init__(
        self, n_blocks: int, d_features: int, d_hidden: int, n: int, step_emb: int
    ) -> None:
        """Initialize the ATT model.

        :param n_blocks: Number of residual blocks.
        :param d_features: Number of features.
        :param d_hidden: Number of hidden channels.
        :param n: Maximal number of nodes in any graph. The number of internal
            nodes used to size the layers is derived as ``(n - 1) // 2``.
        :param step_emb: Number of embeddings for sinusoidal position embeddings.
        """
        super().__init__()

        n_internal = (n - 1) // 2
        # Leaf input: raw features, the leaf's row of A, and the time and
        # node-order embeddings (d_hidden each).
        d_XA = d_features + n_internal + 2 * d_hidden
        self.init_leaves = FeedForward(
            n_layers=2, d_input=d_XA, d_hidden=d_hidden, d_output=d_hidden
        )
        # Internal input: aggregated leaf state (d_hidden), the node's row of
        # B, and the time and node-order embeddings (d_hidden each).
        d_XB = n_internal + 3 * d_hidden
        self.init_internal = FeedForward(
            n_layers=2, d_input=d_XB, d_hidden=d_hidden, d_output=d_hidden
        )
        self.out_leaves = FeedForward(
            n_layers=2, d_input=d_hidden, d_hidden=d_hidden, d_output=n_internal
        )
        self.out_internal = FeedForward(
            n_layers=2, d_input=d_hidden, d_hidden=d_hidden, d_output=n_internal
        )
        self.time_emb = SinusoidalPositionEmbeddings(step_emb)
        self.time_net = FeedForward(
            n_layers=1, d_input=step_emb, d_hidden=d_hidden, d_output=d_hidden
        )
        self.node_emb = SinusoidalPositionEmbeddings(step_emb)
        self.node_net = FeedForward(
            n_layers=1, d_input=step_emb, d_hidden=d_hidden, d_output=d_hidden
        )
        residual_blocks = [ResidualBlock(d_hidden) for _ in range(n_blocks)]
        self.residual_blocks = nn.ModuleList(residual_blocks)

    def forward(
        self,
        XA: Float[Tensor, "b n_leaves d"],
        A: Float[Tensor, "b n_leaves n_internal"],
        B: Float[Tensor, "b n_internal n_internal"],
        t: Float[Tensor, "b 1"],
    ) -> Tuple[Float[Tensor, "b n_leaves n_internal"], Float[Tensor, "b n_internal n_internal"]]:
        """Perform a forward pass with the ATT model.

        :param XA: Node features.
        :param A: Batched adjacency matrix for leaves.
        :param B: Batched adjacency matrix for internal nodes.
        :param t: One time value per batch element; both ``(b,)`` and
            ``(b, 1)`` layouts are accepted.
        :return: Predicted adjacency matrices for leaves and internal nodes.
        """
        b, n_leaves, n_internal = A.shape

        # Solve X (I - B) = A for X, i.e. X = A (I - B)^-1, with I - B handled
        # as an upper-triangular system. The identity broadcasts over the batch.
        # NOTE(review): assumes B is strictly upper triangular - confirm
        # against the caller.
        identity = torch.eye(n_internal, device=B.device)
        AB_anc = solve_triangular(identity - B, A, upper=True, left=False)

        # Flatten to (b,) first: a (b, 1) input would otherwise gain an extra
        # axis inside the embedding and break the per-node broadcast below.
        t = t.reshape(-1)
        t = self.time_net(self.time_emb(t)).unsqueeze(1)
        t_leaf = t.expand(-1, n_leaves, -1)
        t_internal = t.expand(-1, n_internal, -1)

        # The node index is embedded with the same networks for leaves and
        # internal nodes, then shared across the batch.
        order_leaf = torch.arange(n_leaves, device=XA.device)
        order_leaf = self.node_net(self.node_emb(order_leaf))[None].expand(b, -1, -1)
        order_internal = torch.arange(n_internal, device=XA.device)
        order_internal = self.node_net(self.node_emb(order_internal))[None].expand(b, -1, -1)

        XA = torch.cat([XA, A, t_leaf, order_leaf], dim=2)
        xa = self.init_leaves(XA)

        # Internal nodes start from leaf states aggregated through AB_anc^T.
        XB = torch.matmul(AB_anc.permute(0, 2, 1), xa)
        XB = torch.cat([XB, B, t_internal, order_internal], dim=2)
        xb = self.init_internal(XB)

        for block in self.residual_blocks:
            xa, xb = block(xa, xb, A, B, AB_anc)

        A_pred = self.out_leaves(xa)
        B_pred = self.out_internal(xb)
        return A_pred, B_pred


class ResidualBlock(nn.Module):
    """Residual block used by the ATT model.

    Updates leaf and internal node features from messages passed along the
    adjacency matrices ``A`` and ``B`` and the matrix ``AB_anc``, adds a skip
    connection, and batch-normalizes the result.
    """

    def __init__(
        self,
        d_hidden: int,
    ) -> None:
        """Initialize the Residual Block.

        :param d_hidden: Number of hidden channels.
        """
        super().__init__()
        self.init_leaves = FeedForward(
            n_layers=1, d_input=d_hidden, d_hidden=d_hidden, d_output=d_hidden
        )
        self.init_internal = FeedForward(
            n_layers=1, d_input=d_hidden, d_hidden=d_hidden, d_output=d_hidden
        )
        self.post_leaf_aggr = FeedForward(
            n_layers=1, d_input=d_hidden, d_hidden=d_hidden, d_output=d_hidden
        )
        self.post_int_aggr = FeedForward(
            n_layers=1, d_input=d_hidden, d_hidden=d_hidden, d_output=d_hidden
        )
        self.post_global_aggr = FeedForward(
            n_layers=1, d_input=d_hidden, d_hidden=d_hidden, d_output=d_hidden
        )
        self.norm_xa = nn.BatchNorm1d(d_hidden)
        self.norm_xb = nn.BatchNorm1d(d_hidden)

    def forward(
        self,
        xa: Float[Tensor, "b n_leaves d"],
        xb: Float[Tensor, "b n_internal d"],
        A: Float[Tensor, "b n_leaves n_internal"],
        B: Float[Tensor, "b n_internal n_internal"],
        AB_anc: Float[Tensor, "b n_leaves n_internal"],
    ) -> Tuple[Float[Tensor, "b n_leaves d"], Float[Tensor, "b n_internal d"]]:
        """Forward pass with Residual Block.

        :param xa: Node features for leaves.
        :param xb: Node features for internal nodes.
        :param A: Batched adjacency matrix for leaves.
        :param B: Batched adjacency matrix for internal nodes.
        :param AB_anc: Batched probabilistic ancestor matrix for leaves.
        :return: Updated node features for leaves and internal nodes.
        """
        A_T = A.transpose(1, 2)
        B_T = B.transpose(1, 2)
        AB_anc_T = AB_anc.transpose(1, 2)

        # Messages into leaves: from leaves via A A^T, from internal nodes via A B^T.
        haa = torch.matmul(torch.matmul(A, A_T), xa)
        hab = torch.matmul(torch.matmul(A, B_T), xb)
        # Messages into internal nodes: from internal nodes via B B^T, from
        # leaves via AB_anc^T.
        hbb = torch.matmul(torch.matmul(B, B_T), xb)
        hba = torch.matmul(AB_anc_T, xa)
        # (AB_anc AB_anc^T) xa == AB_anc (AB_anc^T xa): reuse hba rather than
        # materializing an n_leaves x n_leaves product.
        h_aba = torch.matmul(AB_anc, hba)

        ha = self.init_leaves(xa) + self.post_leaf_aggr(haa + hab) + self.post_global_aggr(h_aba)
        hb = self.init_internal(xb) + self.post_int_aggr(hba + hbb)

        # Skip connections.
        xa = ha + xa
        xb = hb + xb

        # BatchNorm1d is applied to (rows, channels), so batch and node axes
        # are flattened together for the normalization and restored afterwards.
        leaf_shape = xa.shape
        internal_shape = xb.shape
        xa = self.norm_xa(xa.reshape(-1, leaf_shape[-1])).reshape(leaf_shape)
        xb = self.norm_xb(xb.reshape(-1, internal_shape[-1])).reshape(internal_shape)
        return xa, xb


class FeedForward(nn.Module):
    """Fully connected network with SELU activations between linear layers."""

    def __init__(self, n_layers: int, d_input: int, d_hidden: int, d_output: int) -> None:
        """Build the stack of linear layers.

        The network has ``n_layers`` hidden layers of width ``d_hidden``, each
        followed by a SELU, and a final linear projection to ``d_output``
        without an activation.

        :param n_layers: Number of hidden layers.
        :param d_input: Input dimension.
        :param d_hidden: Hidden dimension.
        :param d_output: Output dimension.
        """
        super().__init__()

        widths = [d_input, *([d_hidden] * n_layers), d_output]
        modules = []
        for index, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            # Every linear layer except the output projection is activated.
            if index < n_layers:
                modules.append(nn.SELU())
        self.net = nn.Sequential(*modules)

    def forward(self, x: Float[Tensor, "n_nodes d_input"]) -> Float[Tensor, "n_nodes d_output"]:
        """Apply the network to ``x`` along its last dimension.

        :param x: Input features.
        :return: Output of the final linear layer.
        """
        return self.net(x)


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal Position Embeddings.

    Each input value is multiplied by ``dim // 2`` geometrically spaced
    frequencies; the sines of those products are followed by their cosines
    along the last dimension.
    """

    def __init__(self, dim: int) -> None:
        """Initialize the Sinusoidal Position Embeddings.

        :param dim: Dimension of the embeddings.
        """
        super().__init__()

        self.dim = dim

    def forward(self, time: Float[Tensor, "n 1"]) -> Float[Tensor, "n d"]:
        """Compute the Sinusoidal Position Embeddings.

        :param time: Time feature.
        :return: Sinusoidal Position Embeddings whose last dimension has size
            ``dim``. For an odd ``dim`` the last column is zero padding.
        """
        device = time.device
        half_dim = self.dim // 2
        # With two or more frequencies they run from 1 down to 1/10000. The
        # denominator is clamped so that a single frequency (dim of 2 or 3)
        # does not divide by zero.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        embeddings = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * (dim // 2) columns; pad one zero column so
            # the output width always equals ``dim``.
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings
