import torch


class LinearRegression(torch.nn.Module):
    """Two stacked ``torch.nn.Linear`` layers with no activation in between.

    Args:
        dim: Base number of input features.
        out_dim: Number of output features; falls back to ``dim`` when ``None``.
        w: Accepted for signature compatibility; not used by this model.
        time_varying: When truthy, the first layer expects one extra input
            feature (presumably a time coordinate -- confirm against caller).
        super_cool: When truthy, the first layer expects ``dim`` additional
            input features.
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False, super_cool=False):
        super().__init__()
        self.time_varying = time_varying
        self.super_cool = super_cool

        target_dim = dim if out_dim is None else out_dim

        # Width of the first layer's input: the base features plus any
        # optional extras requested through the flags.
        in_features = dim
        if time_varying:
            in_features += 1
        if super_cool:
            in_features += dim

        layers = [
            torch.nn.Linear(in_features, target_dim),
            torch.nn.Linear(target_dim, target_dim),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply both linear layers to ``x`` and return the result."""
        return self.net(x)



class MLP(torch.nn.Module):
    """Fully connected network: four linear layers with SELU between them.

    Args:
        dim: Base number of input features.
        out_dim: Number of output features; falls back to ``dim`` when ``None``.
        w: Width of the hidden layers.
        time_varying: When truthy, the first layer expects one extra input
            feature (presumably a time coordinate -- confirm against caller).
        super_cool: When truthy (the default), the first layer expects ``dim``
            additional input features.
    """

    def __init__(self, dim, out_dim=None, w=32, time_varying=False, super_cool=True):
        super().__init__()
        self.time_varying = time_varying
        # Keep the flag on the instance, like ``time_varying``, so callers can
        # inspect how the input width was derived.
        self.super_cool = super_cool
        if out_dim is None:
            out_dim = dim

        # Width of the first layer's input: base features plus optional extras.
        in_dim = dim + (1 if time_varying else 0) + (dim if super_cool else 0)

        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, out_dim),
        )

    def forward(self, x):
        """Run ``x`` through the network and return the result."""
        return self.net(x)


class GradModel(torch.nn.Module):
    """Module whose output is the input-gradient of a wrapped callable.

    Args:
        action: Callable applied to the input tensor; its output is summed to
            a scalar before differentiation.
    """

    def __init__(self, action):
        super().__init__()
        self.action = action

    def forward(self, x):
        """Return d(sum(action(x)))/dx with the last column removed.

        ``x`` must have at least two dimensions, since the result is sliced as
        ``[:, :-1]``. The dropped column is presumably a time coordinate
        appended to the input -- confirm against caller.

        Note: ``x`` is marked as requiring grad in place, so the caller's
        tensor is affected as well.
        """
        # Force gradient tracking on so the derivative can be computed even
        # when the caller runs inside ``torch.no_grad()`` (e.g. at evaluation
        # time); without this, autograd.grad raises because action(x) has no
        # graph.
        with torch.enable_grad():
            x = x.requires_grad_(True)
            # create_graph=True keeps the gradient itself differentiable.
            grad = torch.autograd.grad(torch.sum(self.action(x)), x, create_graph=True)[0]
        return grad[:, :-1]
