import math
from typing import Tuple

import torch
from jaxtyping import Float
from torch import Tensor, sparse_coo_tensor
from torch.linalg import solve_triangular
from torch_geometric.data import Batch


def get_adj(data: Batch, n: int) -> Float[Tensor, "b n n"]:
    """Build a dense, zero-padded adjacency tensor from a PyG batch.

    Every edge is written at position ``[graph, source, target]``, where
    ``source`` and ``target`` are node indices local to the edge's graph and
    the stored value is the matching entry of ``data.edge_attr``.

    :param data: PyG batch object; ``batch_size``, ``batch``, ``ptr``,
        ``edge_index`` and ``edge_attr`` are read.
    :param n: Max number of nodes in any graph (size of the last two dims).
    :return: Batch of adjacency matrices of shape ``[batch_size, n, n]``.
    """
    sources = data.edge_index[0]
    targets = data.edge_index[1]

    # graph id of every edge, looked up through its source node
    graph_of_edge = data.batch[sources]
    # offset that makes node indices start from 0 within each graph
    first_node = data.ptr[graph_of_edge]

    # coordinates are (graph, local source, local target)
    coords = torch.stack(
        [graph_of_edge, sources - first_node, targets - first_node]
    )
    adjacency = sparse_coo_tensor(coords, data.edge_attr, [data.batch_size, n, n])
    return adjacency.to_dense()


def get_A(
    data: Batch, n: int, attr_name="edge_attr"
) -> Float[Tensor, "b n_leaves n_parents"]:
    """Build the batched, zero-padded `A` matrix (parametrisation of Zügner et al.).

    Only edges whose source node is not flagged in ``data.parent_mask`` are
    used. Rows are graph-local source indices; columns are graph-local target
    indices counted from the first non-leaf node. The code assumes each graph
    stores its ``(n_nodes + 1) // 2`` leaves before its internal nodes.

    :param data: PyG batch object.
    :param n: Max number of nodes in one graph.
    :param attr_name: Name of the edge attribute on ``data`` used as values.
    :return: Batch of `A` matrices of shape
        ``[batch_size, (n + 1) // 2, (n - 1) // 2]``.
    """
    sources = data.edge_index[0]
    targets = data.edge_index[1]

    # keep only the edges that start at a leaf
    starts_at_leaf = ~data.parent_mask[sources]
    leaf_nodes = sources[starts_at_leaf]
    leaf_targets = targets[starts_at_leaf]

    # per node: number of leaves of the graph the node belongs to
    nodes_per_graph = data.ptr[1:] - data.ptr[:-1]
    leaves_per_node = ((nodes_per_graph + 1) // 2)[data.batch]

    graph_of_edge = data.batch[leaf_nodes]
    # offset that makes node indices start from 0 within each graph
    first_node = data.ptr[graph_of_edge]

    rows = leaf_nodes - first_node
    # columns additionally skip the leaves so the first parent is column 0
    cols = leaf_targets - first_node - leaves_per_node[leaf_nodes]

    coords = torch.stack([graph_of_edge, rows, cols])
    values = data[attr_name][starts_at_leaf]
    shape = [data.batch_size, (n + 1) // 2, (n - 1) // 2]
    return sparse_coo_tensor(coords, values, shape).to_dense()


def get_XA(data: Batch, n: int) -> Float[Tensor, "b n_leaves d"]:
    """Gather the features of leaf nodes into a zero-padded, batched matrix.

    A node counts as a leaf when it is not flagged in ``data.parent_mask``.
    Each leaf's feature row is written at ``[graph, local_index]``, where
    ``local_index`` is the node's index within its own graph. Positions that
    receive no leaf stay zero.

    :param data: PyG batch object; ``batch_size``, ``x``, ``parent_mask``,
        ``batch``, ``ptr`` and ``num_nodes`` are read.
    :param n: Max number of nodes in one graph.
    :return: Padded leaf feature matrix of shape
        ``[batch_size, (n + 1) // 2, d]`` with the dtype and device of
        ``data.x``.
    """
    b = data.batch_size
    n_leaves = (n + 1) // 2
    d = data.x.shape[1]

    # Allocate with the dtype (and device) of x: a default-float32 buffer
    # would not match non-float32 features in the indexed assignment below.
    XA = data.x.new_zeros((b, n_leaves, d))

    leaf_mask = ~data.parent_mask
    batch_index = data.batch[leaf_mask]
    # node index relative to the first node of the node's own graph
    node_index = (
        torch.arange(0, data.num_nodes, device=data.x.device) - data.ptr[data.batch]
    )
    node_index = node_index[leaf_mask]
    XA[batch_index, node_index, :] = data.x[leaf_mask]
    return XA


def get_B(
    data: Batch, n: int, attr_name="edge_attr"
) -> Float[Tensor, "b n_parents n_parents"]:
    """Build the batched, zero-padded `B` matrix of edges starting at internal nodes.

    Only edges whose source node is flagged in ``data.parent_mask`` are used.
    Rows and columns are graph-local node indices counted from the first
    non-leaf node. The code assumes each graph stores its
    ``(n_nodes + 1) // 2`` leaves before its internal nodes.

    :param data: PyG batch object.
    :param n: Max number of nodes in one graph.
    :param attr_name: Name of the edge attribute on ``data`` used as values.
    :return: Batch of matrices of shape
        ``[batch_size, (n - 1) // 2, (n - 1) // 2]``.
    """
    sources = data.edge_index[0]
    targets = data.edge_index[1]

    # keep only the edges that start at an internal (parent) node
    starts_at_parent = data.parent_mask[sources]
    parent_nodes = sources[starts_at_parent]
    parent_targets = targets[starts_at_parent]

    # per node: number of leaves of the graph the node belongs to
    nodes_per_graph = data.ptr[1:] - data.ptr[:-1]
    leaves_per_node = ((nodes_per_graph + 1) // 2)[data.batch]

    graph_of_edge = data.batch[parent_nodes]
    # combined offset: start of the graph plus its leaves, so that the first
    # internal node of every graph gets index 0
    offset = data.ptr[graph_of_edge] + leaves_per_node[parent_nodes]

    coords = torch.stack(
        [graph_of_edge, parent_nodes - offset, parent_targets - offset]
    )
    values = data[attr_name][starts_at_parent]
    n_parents = (n - 1) // 2
    shape = [data.batch_size, n_parents, n_parents]
    return sparse_coo_tensor(coords, values, shape).to_dense()


def compute_magnetic_laplacian(
    adj: Float[Tensor, "b n n"], q: float
) -> Float[Tensor, "b n n"]:
    """Compute the magnetic Laplacian of a batch of adjacency matrices.

    With ``S = adj + adj^T`` and ``T = adj - adj^T`` the result is
    ``diag(S.sum(-2)) - S * exp(1j * 2 * pi * q * T)``, where ``*`` is the
    element-wise product. The result is complex-valued.

    :param adj: Batch of adjacency matrices.
    :param q: Hyperparameter scaling the phase of the complex exponential.
    :return: Magnetic Laplacian.
    """
    adj_t = torch.transpose(adj, 1, 2)

    # symmetrised adjacency and the degrees derived from it
    symmetric = adj + adj_t
    degree_matrix = torch.diag_embed(symmetric.sum(dim=-2))

    # the antisymmetric part carries the edge direction as a complex phase
    phase = 1j * 2 * math.pi * q * (adj - adj_t)

    return degree_matrix - symmetric * torch.exp(phase)


def compute_anc(data: Batch, n: int) -> Float[Tensor, "b n n_parents"]:
    """Compute the ancestor matrix according to Zügner et al.

    Uses the ``"edge_attr_target"`` edge attribute to build `A` and `B`, then
    solves ``X (I - B) = A`` for the leaf block and ``X (I - B) = I`` (minus
    the identity) for the parent block. ``upper=True`` makes the solver read
    only the upper triangle of ``I - B``.

    :param data: PyG batch object.
    :param n: Max number of nodes in a graph.
    :return: Ancestor matrix; the leaf block stacked on top of the parent
        block along dim 1.
    """
    A = get_A(data, n, attr_name="edge_attr_target")
    B = get_B(data, n, attr_name="edge_attr_target")
    b, n_parents, _ = B.shape
    # Build the identity in B's dtype so both triangular solves receive
    # operands of one dtype even when the edge attribute is not float32.
    I = (
        torch.eye(n_parents, dtype=B.dtype, device=B.device)
        .unsqueeze(0)
        .expand(b, -1, -1)
    )
    system = I - B
    # left=False solves X @ system = rhs
    P_A_anc = solve_triangular(system, A, upper=True, left=False)
    P_B_anc = solve_triangular(system, I, upper=True, left=False) - I
    return torch.cat([P_A_anc, P_B_anc], dim=1)


def compute_P_anc(data: Batch, n: int) -> Float[Tensor, "b n n_parents"]:
    """Compute the probabilistic ancestor matrix according to Zügner et al.

    Uses the default edge attribute of ``get_A`` / ``get_B`` to build `A` and
    `B`, then solves ``X (I - B) = A`` for the leaf block and
    ``X (I - B) = I`` (minus the identity) for the parent block.
    ``upper=True`` makes the solver read only the upper triangle of ``I - B``.

    :param data: PyG batch object.
    :param n: Max number of nodes in a graph.
    :return: Probabilistic ancestor matrix; the leaf block stacked on top of
        the parent block along dim 1.
    """
    A = get_A(data, n)
    B = get_B(data, n)
    b, n_parents, _ = B.shape
    # Build the identity in B's dtype so both triangular solves receive
    # operands of one dtype even when the edge attribute is not float32.
    I = (
        torch.eye(n_parents, dtype=B.dtype, device=B.device)
        .unsqueeze(0)
        .expand(b, -1, -1)
    )
    system = I - B
    # left=False solves X @ system = rhs
    P_A_anc = solve_triangular(system, A, upper=True, left=False)
    P_B_anc = solve_triangular(system, I, upper=True, left=False) - I
    return torch.cat([P_A_anc, P_B_anc], dim=1)


def get_adj_row_indices(
    data: Batch,
) -> Tuple[Float[Tensor, "n_nodes"], Float[Tensor, "n_nodes"]]:
    """Get indices to access non-padded rows of batched adjacency matrices.

    :param data: PyG batch object; ``batch``, ``ptr`` and ``num_nodes`` are
        read.
    :return: Tuple ``(batch_index, node_index)`` with one entry per node:
        the graph the node belongs to and the node's index within that graph.
    """
    batch_index = data.batch

    # Take the device from the tensor that takes part in the arithmetic below
    # instead of requiring an unrelated ``data.t`` attribute to exist.
    node_index = torch.arange(0, data.num_nodes, device=batch_index.device)
    # subtract the offset of each node's graph to get a graph-local index
    node_index = node_index - data.ptr[batch_index]
    return batch_index, node_index


def get_adj_indices(
    data: Batch, n: int
) -> Tuple[
    Float[Tensor, "n_edges"], Float[Tensor, "n_edges"], Float[Tensor, "n_edges"]
]:
    """Get per-edge indices into the padded `A` and `B` matrices stacked along rows.

    The row of an edge is the graph-local index of its source node; for edges
    starting at a node flagged in ``data.parent_mask`` the row is additionally
    moved past the leaf padding (``(n + 1) // 2`` minus the graph's own leaf
    count). The column is the graph-local index of the target node counted
    from the first internal node.

    :param data: PyG batch object.
    :param n: Max number of nodes in a graph.
    :return: Tuple ``(edge_batch, row_index, column_index)``, one entry per
        edge.
    """
    sources = data.edge_index[0]
    targets = data.edge_index[1]

    edge_batch = data.batch[sources]
    # offset that makes node indices start from 0 within each graph
    first_node = data.ptr[edge_batch]

    # number of leaves of the graph that owns each edge's source node
    nodes_per_graph = data.ptr[1:] - data.ptr[:-1]
    leaves_at_source = ((nodes_per_graph + 1) // 2)[data.batch][sources]

    # columns start with the first internal node of the graph
    cols = targets - first_node - leaves_at_source

    # A is padded to (n + 1) // 2 rows before B is appended, so rows that
    # belong to B are pushed down by the amount of leaf padding
    leaf_padding = (n + 1) // 2 - leaves_at_source
    starts_at_parent = data.parent_mask[sources]
    row_shift = torch.where(
        starts_at_parent, leaf_padding, torch.zeros_like(leaf_padding)
    )
    rows = sources - first_node + row_shift
    return edge_batch, rows, cols


def get_P_anc_indices(
    data: Batch, n: int
) -> Tuple[
    Float[Tensor, "n_edges"], Float[Tensor, "n_edges"], Float[Tensor, "n_edges"]
]:
    """Get indices to access non-padded entries of the probabilistic ancestor matrix.

    With ``pad = (n + 1) // 2 - (leaf count of the edge's graph)``, the
    returned row is the graph-local source index, reduced by ``pad`` for
    edges starting at a node flagged in ``data.parent_mask``; the returned
    column is the graph-local target index reduced by ``pad`` for every edge.

    NOTE(review): ``get_adj_indices`` adds the padding to parent rows and
    subtracts the graph's leaf count from columns, whereas this function
    subtracts the padding in both places. Confirm this matches the layout of
    the matrix these indices are used with.

    :param data: PyG batch object.
    :param n: Max number of nodes in a graph.
    :return: Tuple ``(edge_batch, row_index, column_index)``, one entry per
        edge.
    """
    sources = data.edge_index[0]
    targets = data.edge_index[1]

    edge_batch = data.batch[sources]
    # offset that makes node indices start from 0 within each graph
    first_node = data.ptr[edge_batch]

    # number of leaves of the graph that owns each edge's source node
    nodes_per_graph = data.ptr[1:] - data.ptr[:-1]
    leaves_at_source = ((nodes_per_graph + 1) // 2)[data.batch][sources]

    # difference between the padded and the actual leaf count, per edge
    leaf_padding = (n + 1) // 2 - leaves_at_source

    # only edges that start at a parent node get their row shifted
    starts_at_parent = data.parent_mask[sources]
    row_shift = torch.where(
        starts_at_parent, leaf_padding, torch.zeros_like(leaf_padding)
    )

    rows = sources - first_node - row_shift
    cols = targets - first_node - leaf_padding
    return edge_batch, rows, cols
