import torch.nn as nn


class FakeL(nn.Module):
    """Toy module mapping a state tensor to ``s0**2 + s1**2 + 0.1``.

    Only the first two entries of the last axis of the input are read.
    Any further trailing entries are ignored. The last axis is reduced
    away, so an input of shape ``(..., d)`` with ``d >= 2`` yields an
    output of shape ``(...)``.

    NOTE(review): presumably a hand-written stand-in for a learned scalar
    function (e.g. in tests) — confirm against callers.
    """

    def __init__(self):
        super().__init__()
        # Not used by forward(). Presumably kept so the module owns at
        # least one parameter (for .parameters()/optimizers/.to()) —
        # TODO confirm intent.
        self.fc = nn.Linear(1, 1)

    def forward(self, s):
        """Return ``s[..., 0]**2 + s[..., 1]**2 + 0.1`` elementwise."""
        first = s[..., 0]
        second = s[..., 1]
        # Strictly positive offset: the output is never below 0.1.
        return first ** 2 + second ** 2 + 0.1


class FakeL2(FakeL):
    """Variant of ``FakeL`` with an anisotropic quadratic form.

    Computes ``(s0 / 0.6)**2 + 0.01 * sum(s[..., 1:]**2)`` over the last
    axis. The first coordinate is weighted far more heavily than the rest.
    ``__init__`` (including the unused ``fc`` layer) is inherited unchanged.
    """

    def forward(self, s):
        """Return the weighted sum of squares of ``s`` over its last axis.

        Works for any last-axis size ``>= 1``. With size 1 the trailing
        slice is empty and contributes zero.
        """
        scaled_head = s[..., 0] / 0.6
        tail_energy = s[..., 1:].pow(2).sum(dim=-1)
        return scaled_head ** 2 + 0.01 * tail_energy
