"""
Visualization utility functions and classes for TensorGalerkin
"""

import torch
import torch.nn as nn


class TensorModule(nn.Module):
    """A PyTorch module that wraps a single trainable tensor parameter.

    Calling the module takes no inputs and returns the wrapped parameter
    itself, so it can serve as a learnable constant inside a larger model.

    Args:
        *shape: Dimensions of the parameter, given either as separate ints
            (``TensorModule(2, 3)``) or as a single sequence
            (``TensorModule((2, 3))``, ``TensorModule(torch.Size([2, 3]))``).
            Passing no arguments yields a 0-dim (scalar) parameter.

    The parameter is initialised with samples from ``torch.rand``
    (uniform on ``[0, 1)``).
    """

    def __init__(self, *shape):
        super().__init__()
        # Mirror torch.zeros / torch.rand, which accept either varargs ints or
        # one sequence of dims; without this, TensorModule((2, 3)) would pass a
        # nested tuple to torch.zeros and raise TypeError.
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        self.tensor = nn.Parameter(torch.zeros(shape))
        # Explicit weight (rather than relying on the default) keeps the RNG
        # draw identical to previous versions and stays compatible with
        # subclasses that override reset_parameters(weight).
        self.reset_parameters(torch.rand(shape))

    def reset_parameters(self, weight=None):
        """Overwrite the parameter's values in place.

        Args:
            weight: Tensor whose values are copied into the parameter. It must
                be broadcastable to the parameter's shape and is cast to the
                parameter's dtype/device by ``copy_``. If ``None``, fresh values
                are drawn uniformly from ``[0, 1)`` with the parameter's
                dtype and device.
        """
        if weight is None:
            weight = torch.rand_like(self.tensor)
        # Update under no_grad instead of through ``.data``. Autograd still does
        # not record the copy, but the parameter's version counter is bumped.
        # A graph that saved the old values therefore raises an error instead
        # of silently producing wrong gradients.
        with torch.no_grad():
            self.tensor.copy_(weight)

    def forward(self):
        """Return the wrapped parameter (no inputs; no copy is made)."""
        return self.tensor