import torch
from torch import nn


class DiagonalNet(nn.Module):
    """Linear predictor whose weight vector is ``u ** L - v ** L``.

    Two trainable vectors ``u`` and ``v`` of length ``dimD`` are combined
    elementwise into an effective weight ``w = u ** L - v ** L``; the forward
    pass computes ``x @ w``.

    Both vectors start at the constant ``alpha / sqrt(2 * dimD)``, so the
    concatenation ``(u, v)`` has Euclidean norm ``|alpha|`` (up to rounding)
    and the effective weight is zero at initialisation because ``u == v``.

    Args:
        alpha: Initialisation scale.
        L: Exponent applied elementwise to ``u`` and ``v``.
        dimD: Length of ``u`` and ``v``.
    """

    def __init__(self, alpha, L, dimD):
        super().__init__()
        init_scale = alpha / ((dimD * 2) ** 0.5)
        # Wrapping in nn.Parameter registers the tensor with the module so it
        # is trainable and tracked by autograd. Each parameter is built from
        # its own freshly allocated tensor so u and v never share storage.
        self.u = nn.Parameter(init_scale * torch.ones(dimD))
        self.v = nn.Parameter(init_scale * torch.ones(dimD))
        self.L = L

    def get_w(self):
        """Return the effective weight vector ``u ** L - v ** L``."""
        positive_part = self.u.pow(self.L)
        negative_part = self.v.pow(self.L)
        return positive_part - negative_part

    def forward(self, x):
        """Compute ``x @ w`` and append a trailing singleton dimension.

        Args:
            x: Input whose last dimension must be compatible with the
                length-``dimD`` weight vector (presumably ``(batch, dimD)``
                -- confirm against callers).

        Returns:
            The matrix product of ``x`` with ``get_w()``, unsqueezed on the
            last axis.
        """
        weights = self.get_w()
        projection = x @ weights
        return projection.unsqueeze(-1)

