import torch
from lore.g.base import g

# Various g functions
class Linear(g):
    """g(sigma) = sum_i sigma_i, the plain sum of the singular values.

    When the singular values are non-negative, this is the nuclear (trace)
    norm of the matrix they came from.
    """

    def forward(self, singular_values):
        """Sum every entry of ``singular_values``.

        Args:
            singular_values: tensor of singular values (presumably
                non-negative; TODO confirm against callers of ``g``).

        Returns:
            A 0-dim tensor holding the total over all elements.
        """
        return singular_values.sum()
    
class SchattenPUnnormed(g):
    """g(sigma) = sum_i (sigma_i + eps) ** p, a smoothed Schatten-p sum.

    "Unnormed" means the final ``(...) ** (1 / p)`` root is not taken. The
    result is the sum of the powered singular values, not the Schatten-p
    norm itself.

    Args:
        p: exponent applied to each shifted singular value.
        eps: constant added to every singular value before exponentiation.
            With ``p < 1`` the gradient ``p * (sigma + eps) ** (p - 1)``
            would otherwise be infinite at ``sigma == 0``. It may be a
            Python number or a tensor.
    """

    def __init__(self, p, eps=1e-6):
        super().__init__()
        self.p = p
        # as_tensor behaves like torch.tensor for Python numbers. For an
        # existing tensor it avoids the copy-construct warning that
        # torch.tensor() emits.
        self.eps = torch.as_tensor(eps)

    def forward(self, singular_values):
        """Return ``sum((singular_values + eps) ** p)`` as a 0-dim tensor.

        Args:
            singular_values: tensor of singular values (presumably
                non-negative; TODO confirm against callers of ``g``).
        """
        return torch.sum((singular_values + self.eps) ** self.p)
