"""
Decomposition of batched, multi-featured graph tensors into irreducible-representation (irrep) components.
"""
import torch
from torch import Tensor
from typing import List


def batched_matrix_decomposition(A: Tensor, siamese: str = '') -> Tensor:
    """
    Decompose a batch of square matrices into their irrep projections.

    Every n x n matrix is split into seven pieces (a0..a6) that sum back to it:
      a0  constant diagonal matrix holding the mean of the diagonal,
      a1  constant off-diagonal matrix holding the mean off-diagonal entry,
      a2  diagonal matrix whose diagonal sums to zero,
      a3  matrix whose columns are constant (a duplicated row vector),
      a4  matrix whose rows are constant (a duplicated column vector),
      a5  antisymmetric part of the zero-diagonal, zero row/column-sum remainder,
      a6  symmetric part of that same remainder.
    The one-dimensional pieces a0 and a1 are optionally split further across
    the graph axis (see ``siamese``).

    Args:
        A: The batch tensor of shape [batch, num_graphs, num_features, n, n].
        siamese: With ``'SchurNet'``, a0 and a1 are each split into their mean
            over the graph axis (dim 1) and the deviation from that mean. Any
            other value keeps a0 and a1 whole and fills the deviation slots with
            zeros.

    Returns:
        A tensor of shape [batch, num_graphs, num_features, n, n, 9] whose last
        axis stacks [b_0, b_1, b_2, b_3, a2, a3, a4, a5, a6], where
        b_0 + b_1 == a0 and b_2 + b_3 == a1. The nine components sum to ``A``.

    Raises:
        ValueError: If ``A`` is not a 5-D batch of square matrices, or if
            ``n < 3``. The projection formulas divide by ``n * (n - 1)`` and
            ``n * (n - 2)``, so smaller ``n`` would yield NaN/inf.
    """
    if A.dim() != 5 or A.shape[-1] != A.shape[-2]:
        raise ValueError(
            f"expected A of shape [batch, num_graphs, num_features, n, n], got {tuple(A.shape)}"
        )
    n = A.shape[-1]
    if n < 3:
        raise ValueError(f"batched_matrix_decomposition requires n >= 3, got n={n}")

    # A single (n, n) identity broadcast against the batch replaces a fully
    # repeated copy. Floating inputs keep their dtype (no half -> float32 promotion).
    eye_dtype = A.dtype if A.is_floating_point() else torch.get_default_dtype()
    eye = torch.eye(n, dtype=eye_dtype, device=A.device)

    # a0: constant diagonal matrix. Read the trace directly rather than via
    # A * eye, which would turn inf off-diagonal entries into NaN.
    trace = A.diagonal(dim1=-2, dim2=-1).sum(-1)
    a0 = (trace / n)[..., None, None] * eye
    # a1: constant off-diagonal matrix with the mean of the n^2 - n off-diagonal entries.
    off_diag_sum = A.sum(dim=(-2, -1)) - trace
    a1 = (off_diag_sum / (n ** 2 - n))[..., None, None] * (1 - eye)

    # Remainder has zero trace and zero off-diagonal sum.
    A_hat = A - a0 - a1
    row_sum = A_hat.sum(dim=-1)
    col_sum = A_hat.sum(dim=-2)
    diag = A_hat.diagonal(dim1=-2, dim2=-1)

    denom = n ** 2 - 2 * n
    # r: vector duplicated along rows; c: vector duplicated along columns.
    r = (n * diag - (n - 1) * row_sum - col_sum) / denom
    c = (n * diag - (n - 1) * col_sum - row_sum) / denom

    # a2: sum-zero diagonal matrix.
    a2 = torch.diag_embed(r + c + diag)
    # a3[..., i, j] = -c[j] (constant columns); a4[..., i, j] = -r[i] (constant rows).
    a3 = -c.unsqueeze(-2).expand_as(A_hat)
    a4 = -r.unsqueeze(-1).expand_as(A_hat)

    # What is left has zero diagonal and zero row/column sums.
    A_tilde = A_hat - a2 - a3 - a4
    A_tilde_t = A_tilde.transpose(-1, -2)
    a5 = 0.5 * (A_tilde - A_tilde_t)  # antisymmetric, row sum zero
    a6 = 0.5 * (A_tilde + A_tilde_t)  # symmetric, row sum zero

    if siamese == 'SchurNet':
        # Split the two one-dimensional spaces into graph-mean and deviation.
        b_0 = a0.mean(dim=1, keepdim=True).expand_as(a0)
        b_1 = a0 - b_0
        b_2 = a1.mean(dim=1, keepdim=True).expand_as(a1)
        b_3 = a1 - b_2
    else:
        b_0, b_1 = a0, torch.zeros_like(a0)
        b_2, b_3 = a1, torch.zeros_like(a1)

    return torch.stack([b_0, b_1, b_2, b_3, a2, a3, a4, a5, a6], dim=-1)


def get_zero_sum_vectors(X: Tensor) -> Tensor:
    """
    Return the sum of the three n-1 dimensional (vector-valued) components of ``X``.

    ``X`` is decomposed with ``batched_matrix_decomposition``. Each of the three
    components built from a length-n vector is collapsed back to that vector:
      * component 6 (rows constant, entry [i, j] = -r[i]): column 0 gives -r;
      * component 5 (columns constant, entry [i, j] = -c[j]): row 0 gives -c;
      * component 4 (diagonal matrix): its diagonal.
    The three vectors are added together.

    Args:
        X: The matrix batch to decompose, of shape
            [batch, num_graphs, num_features, n, n].

    Returns:
        The summed vectors, of shape [batch, num_graphs, num_features, n]. The
        feature axis is dropped when it has size 1.
    """
    # Components 4-6 do not depend on the siamese mode, so any value other
    # than 'SchurNet' works; it must still be passed because it is required.
    irreps = batched_matrix_decomposition(X, siamese='')
    # Index from the end with '...' so the selection is valid for any size of
    # the leading axes. Squeezing before indexing broke the rank when num_features == 1.
    vector_sum = (
        irreps[..., :, 0, 6]
        + irreps[..., 0, :, 5]
        + irreps[..., 4].diagonal(dim1=-2, dim2=-1)
    )
    return vector_sum.squeeze(2)


# Isomorphism bookkeeping: record which decomposition components are isomorphic, as a 0/1 indicator matrix.
def fill_indices(mat: Tensor, indices: List):
    """
    Mark every ordered pair drawn from ``indices`` with a 1 in ``mat``.

    For each ``row`` and ``col`` in ``indices`` (including ``row == col``),
    ``mat[row, col]`` is set to 1. The update happens in place, and the same
    ``mat`` object is returned so calls can be chained.

    Args:
        mat: The matrix to modify in place.
        indices: Positions that should all be linked to one another.

    Returns: ``mat`` after the update.
    """
    index_pairs = [(row, col) for row in indices for col in indices]
    for row, col in index_pairs:
        mat[row, col] = 1
    return mat


def get_iso_matrix() -> Tensor:
    """
    Build the 9 x 9 indicator of mutually isomorphic decomposition components.

    Entry [i, j] is 1 when components i and j (in the stacking order
    [b_0, b_1, b_2, b_3, a2, a3, a4, a5, a6]) are treated as isomorphic,
    including i == j. Otherwise the entry is 0.

    Returns: A float tensor of shape [9, 9].
    """
    # Groups of isomorphic components:
    #   (4, 5, 6) - the three n-1 dimensional representations;
    #   (0, 2)    - the first pair of one-dimensional components (W0+ / W1+);
    #   (1, 3)    - the second pair of one-dimensional components.
    isomorphic_groups = ((4, 5, 6), (0, 2), (1, 3))
    iso_matrix = torch.eye(9, 9)
    for group in isomorphic_groups:
        members = torch.tensor(group)
        # Broadcasting a column against a row index covers every (i, j) pair.
        iso_matrix[members.unsqueeze(1), members] = 1
    return iso_matrix
