from typing import List, Optional

import torch
from torch import nn

class MLP(torch.nn.Module):
    """Fully connected network with three SELU-activated hidden layers.

    The first linear layer takes ``data_dim + conditional_dim`` features,
    plus one extra feature when ``time_varying`` is true. The last linear
    layer emits ``(conditional_dim + 1) * data_dim`` values per sample, which
    ``forward`` reshapes to ``(-1, data_dim, conditional_dim + 1)``.

    Args:
        data_dim: Size of the data part of the input.
        w: Width of every hidden layer.
        time_varying: Whether the input carries one additional feature
            (presumably a time coordinate -- confirm against callers).
        conditional_dim: Size of the conditioning part of the input.
    """

    def __init__(self, data_dim, w=64, time_varying=False, conditional_dim=1):
        super().__init__()
        self.time_varying = time_varying
        self.conditional_dim = conditional_dim
        self.data_dim = data_dim

        time_features = 1 if time_varying else 0
        # Input width followed by the three hidden widths.
        widths = [data_dim + time_features + conditional_dim, w, w, w]

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.SELU())
        # Output projection has no activation after it.
        stack.append(nn.Linear(w, (conditional_dim + 1) * data_dim))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run the network and reshape to ``(-1, data_dim, conditional_dim + 1)``."""
        flat = self.net(x)
        return flat.reshape(-1, self.data_dim, self.conditional_dim + 1)

class FlowNet(torch.nn.Module):
    """Wrap a network and reshape its output to a 3-D tensor.

    The wrapped module must expose ``data_dim`` and ``conditional_dim``
    attributes; both are copied onto this wrapper at construction time.

    Args:
        mynet: Module to wrap. It is registered as the ``net`` submodule.
    """

    def __init__(self, mynet):
        super().__init__()
        self.net = mynet
        self.data_dim = mynet.data_dim
        self.conditional_dim = mynet.conditional_dim

    def forward(self, x):
        """Apply the wrapped net and reshape to ``(-1, data_dim, conditional_dim + 1)``."""
        target_shape = (-1, self.data_dim, self.conditional_dim + 1)
        raw = self.net(x)
        return raw.reshape(*target_shape)

class DeepMLP4EFM(nn.Sequential):
    """Sequential MLP with ELU activations between linear layers.

    The stack is ``Linear -> ELU -> ... -> Linear``: one linear layer per
    consecutive pair in ``(in_features, *hidden_features, out_features)``,
    with an ELU after every linear layer except the last.

    Args:
        data_dim: Size of the data part of the input.
        conditional_dim: Size of the conditioning part of the input.
        time_varying: When true, the input carries one additional feature
            (presumably a time coordinate -- confirm against callers).
        hidden_features: Widths of the hidden layers. Defaults to six layers
            of width 128 when omitted or ``None``.

    Attributes:
        in_features: ``data_dim + conditional_dim`` (+1 if ``time_varying``).
        out_features: ``(conditional_dim + 1) * data_dim``.
        hidden_features: Per-instance list of the hidden widths in use.
    """

    def __init__(
        self,
        data_dim: int,
        conditional_dim: int,
        time_varying: bool = False,
        hidden_features: Optional[List[int]] = None,
    ):
        # A fresh list per instance: a shared mutable default would let a
        # mutation of one model's ``hidden_features`` leak into later models.
        if hidden_features is None:
            hidden = [128] * 6
        else:
            hidden = list(hidden_features)

        in_features = data_dim + (1 if time_varying else 0) + conditional_dim
        out_features = (conditional_dim + 1) * data_dim

        widths = [in_features, *hidden, out_features]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ELU()])
        layers.pop()  # no activation after the output layer

        # Initialise the module machinery before assigning any attributes.
        super().__init__(*layers)

        self.data_dim = data_dim
        self.conditional_dim = conditional_dim
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_features = hidden



class GradModel(torch.nn.Module):
    """Expose the input-gradient of a scalar-summed ``action`` as a module.

    Args:
        action: Callable applied to the input tensor; its output is summed
            to a scalar before differentiation.
    """

    def __init__(self, action):
        super().__init__()
        self.action = action

    def forward(self, x):
        """Return d(sum(action(x)))/dx with the last column removed.

        ``x`` is flagged as requiring grad in place. The gradient is built
        with ``create_graph=True`` so it can itself be differentiated.
        The dropped last column is presumably a time/extra coordinate --
        confirm against callers.
        """
        x = x.requires_grad_(True)
        total = self.action(x).sum()
        (full_grad,) = torch.autograd.grad(total, x, create_graph=True)
        return full_grad[:, :-1]