import torch
import torch.nn as nn


def create_mlp(memory_length, weights, out=1):
    """Build a three-layer fully connected network with ReLU activations.

    Args:
        memory_length: Added to 2 to obtain the input feature count.
        weights: Width of both hidden layers.
        out: Number of output features (default 1).

    Returns:
        An ``nn.Sequential`` of Linear -> ReLU -> Linear -> ReLU -> Linear.
    """
    in_features = 2 + memory_length
    layers = [
        nn.Linear(in_features, weights),
        nn.ReLU(),
        nn.Linear(weights, weights),
        nn.ReLU(),
        nn.Linear(weights, out),
    ]
    return nn.Sequential(*layers)

def create_mlp_2(memory_length, weights, out=2):
    """Build a three-layer fully connected network with ReLU activations.

    Args:
        memory_length: Doubled and added to 3 to obtain the input feature
            count.
        weights: Width of both hidden layers.
        out: Number of output features (default 2).

    Returns:
        An ``nn.Sequential`` of Linear -> ReLU -> Linear -> ReLU -> Linear.
    """
    input_width = 3 + memory_length * 2
    # Layers are instantiated in network order, first to last.
    input_layer = nn.Linear(input_width, weights)
    hidden_layer = nn.Linear(weights, weights)
    output_layer = nn.Linear(weights, out)
    return nn.Sequential(
        input_layer,
        nn.ReLU(),
        hidden_layer,
        nn.ReLU(),
        output_layer,
    )
def create_mlp_3(memory_length, weights, out=2):
    """Build a five-layer fully connected network with ReLU activations.

    Args:
        memory_length: Doubled and added to 3 to obtain the input feature
            count.
        weights: Width of all four hidden layers.
        out: Number of output features (default 2).

    Returns:
        An ``nn.Sequential`` with four Linear+ReLU pairs followed by a final
        Linear layer.
    """
    # Widths seen by the four hidden Linear layers: input, then `weights`.
    widths = [3 + memory_length * 2] + [weights] * 4
    stack = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        stack.append(nn.Linear(fan_in, fan_out))
        stack.append(nn.ReLU())
    # The output projection has no activation after it.
    stack.append(nn.Linear(weights, out))
    return nn.Sequential(*stack)



def create_mlp_jump(memory_length, weights, no_modes=2):
    """Build a three-layer fully connected network with ReLU activations.

    Args:
        memory_length: Added to 2 to obtain the input feature count.
        weights: Width of both hidden layers.
        no_modes: Sets the output width, which is
            ``1 + no_modes*2 + no_modes`` (default 2, giving 7 outputs).

    Returns:
        An ``nn.Sequential`` of Linear -> ReLU -> Linear -> ReLU -> Linear.
    """
    n_inputs = 2 + memory_length
    n_outputs = 1 + no_modes * 2 + no_modes
    modules = [
        nn.Linear(n_inputs, weights),
        nn.ReLU(),
        nn.Linear(weights, weights),
        nn.ReLU(),
        nn.Linear(weights, n_outputs),
    ]
    return nn.Sequential(*modules)

def create_mlp_jump_gauss(memory_length, weights, no_modes=1):
    """Build a three-layer fully connected network with three outputs.

    Args:
        memory_length: Added to 2 to obtain the input feature count.
        weights: Width of both hidden layers.
        no_modes: Accepted but not used by this function; the output width
            is always 3.

    Returns:
        An ``nn.Sequential`` of Linear -> ReLU -> Linear -> ReLU -> Linear.
    """
    modules = []
    modules.append(nn.Linear(2 + memory_length, weights))
    modules.append(nn.ReLU())
    modules.append(nn.Linear(weights, weights))
    modules.append(nn.ReLU())
    modules.append(nn.Linear(weights, 3))
    return nn.Sequential(*modules)
def create_mlp_jump_gauss_2(memory_length, weights, no_modes=1):
    """Build a three-layer fully connected network with three outputs.

    Args:
        memory_length: Doubled and added to 3 to obtain the input feature
            count.
        weights: Width of both hidden layers.
        no_modes: Accepted but not used by this function; the output width
            is always 3.

    Returns:
        An ``nn.Sequential`` of Linear -> ReLU -> Linear -> ReLU -> Linear.
    """
    input_width = 3 + memory_length * 2
    # Layers are instantiated in network order, first to last.
    first = nn.Linear(input_width, weights)
    middle = nn.Linear(weights, weights)
    head = nn.Linear(weights, 3)
    return nn.Sequential(first, nn.ReLU(), middle, nn.ReLU(), head)
def create_mlp_jump_gauss_3(memory_length, weights, no_modes=1):
    """Build a five-layer fully connected network with three outputs.

    Args:
        memory_length: Doubled and added to 3 to obtain the input feature
            count.
        weights: Width of all four hidden layers.
        no_modes: Accepted but not used by this function; the output width
            is always 3.

    Returns:
        An ``nn.Sequential`` with four Linear+ReLU pairs followed by a final
        Linear layer.
    """
    stack = [nn.Linear(3 + memory_length * 2, weights), nn.ReLU()]
    # Three further hidden layers of identical width.
    for _ in range(3):
        stack.append(nn.Linear(weights, weights))
        stack.append(nn.ReLU())
    stack.append(nn.Linear(weights, 3))
    return nn.Sequential(*stack)
