import torch
from torch.nn import Module, Parameter


class Traverse(Module):
    """Linear map over a signal made of ``num_signals`` concatenated blocks.

    The learnable ``kernel`` is a square matrix of side
    ``block_size * num_signals`` that right-multiplies the input in
    :meth:`forward`. The post-training analysis helpers operate on the square
    ``block_size x block_size`` blocks lying on the kernel's main diagonal.

    Args:
        block_size: Side length of each square block.
        num_signals: Number of blocks along each side of the kernel.
        dtype: Element dtype of the kernel.
        init: In-place initializer applied to the kernel
            (default ``torch.nn.init.xavier_normal_``).
        device: Device the kernel is allocated on (default CPU).
    """

    block_size:int
    num_signals:int

    def __init__(self, block_size:int, num_signals:int, dtype,
                 init=torch.nn.init.xavier_normal_, device=torch.device('cpu')):
        super().__init__()
        self.block_size = block_size
        # Stored so the block accessors can validate block ids.
        self.num_signals = num_signals
        size = block_size * num_signals
        # Allocate directly on the target device rather than allocating on CPU
        # and copying with ``.to(device)``.
        self.kernel = Parameter(
            torch.empty((size, size), dtype=dtype, device=device),
            requires_grad=True,
        )
        init(self.kernel)

    def forward(self, signal):
        """Return ``signal @ kernel``.

        The last dimension of ``signal`` must equal ``block_size * num_signals``.
        """
        # Author's note: I * x + sum(1 * R * x)
        # NOTE(review): presumably the diagonal blocks play the role of "I" and
        # the off-diagonal blocks the "R" coupling terms -- confirm with author.
        return signal @ self.kernel

    #
    # Utilities for post-training analysis.
    #
    def determinant(self, block_id:int):
        """Return the determinant of diagonal block ``block_id`` as a Python scalar."""
        return torch.det(self.diagonal_block(block_id)).item()

    def inverse(self, block_id:int):
        """Return the matrix inverse of diagonal block ``block_id``."""
        return torch.inverse(self.diagonal_block(block_id))

    def transpose(self, block_id:int):
        """Return the transpose of diagonal block ``block_id``."""
        return torch.transpose(self.diagonal_block(block_id), 0, 1)

    def diagonal_block(self, block_id:int):
        """Return the ``block_id``-th ``block_size x block_size`` diagonal block.

        Negative ids count from the end, as with Python sequences. The result
        is a detached view sharing storage with ``kernel``; it is not tracked
        by autograd.

        Raises:
            IndexError: if ``block_id`` is outside ``[-num_signals, num_signals)``.
        """
        index = block_id + self.num_signals if block_id < 0 else block_id
        if not 0 <= index < self.num_signals:
            # Without this check an out-of-range id silently yields an empty
            # slice (whose determinant torch reports as 1).
            raise IndexError(
                f"block_id {block_id} out of range for {self.num_signals} blocks")
        start = self.block_size * index
        stop = start + self.block_size
        return self.kernel.detach()[start:stop, start:stop]
