import torch
import numpy as np
from torch import nn

def _check_num_layers(num_layers, arg_name):
    """Return ``num_layers`` as an int; raise ValueError unless it is a positive integer."""
    # int(...) comparison keeps integral floats such as 2.0 working (they
    # matched the old ``== 2`` checks) while rejecting 2.5, 0 and negatives.
    if num_layers < 1 or num_layers != int(num_layers):
        raise ValueError(
            f"{arg_name} must be a positive integer, got {num_layers!r}"
        )
    return int(num_layers)


def _build_mlp(n_in, n_hidden, n_out, num_layers, activation_function):
    """Return the layer list for an MLP with ``num_layers`` hidden layers.

    Each hidden layer is an ``nn.Linear`` followed by ``activation_function``.
    A final ``nn.Linear(n_hidden, n_out)`` maps to the output size. Layers
    are created in forward order, so the ``nn.Sequential`` indices and the
    parameter-initialisation RNG draws match a hand-written layer list.
    """
    layers = [nn.Linear(n_in, n_hidden), activation_function]
    for _ in range(num_layers - 1):
        layers.extend([nn.Linear(n_hidden, n_hidden), activation_function])
    layers.append(nn.Linear(n_hidden, n_out))
    return layers


def gan_model_mog(n_latent, n_out, n_hidden, 
                  num_gen_layers, num_disc_layers,
                  activation_function):
    """Build the generator and discriminator MLPs for a GAN.

    The ``_mog`` suffix presumably refers to a mixture-of-Gaussians target
    distribution; nothing in this function depends on that.

    Args:
        n_latent: Input size of the generator (latent dimension).
        n_out: Output size of the generator, which is also the input size of
            the discriminator.
        n_hidden: Width of every hidden ``nn.Linear`` layer in both networks.
        num_gen_layers: Number of hidden (Linear + activation) layers in the
            generator; must be a positive integer.
        num_disc_layers: Number of hidden (Linear + activation) layers in the
            discriminator; must be a positive integer.
        activation_function: An ``nn.Module`` instance inserted after every
            hidden Linear layer. The SAME instance is reused throughout both
            networks, so a parameterised activation (e.g. ``nn.PReLU``) shares
            its parameters across all of those positions.

    Returns:
        ``(G, D)``. ``G`` is an ``nn.Sequential`` mapping
        ``n_latent -> n_out``. ``D`` is an ``nn.Sequential`` mapping
        ``n_out -> 1`` and ends in ``nn.Sigmoid``.

    Raises:
        ValueError: If ``num_gen_layers`` or ``num_disc_layers`` is not a
            positive integer. Previously, unsupported values crashed with an
            ``UnboundLocalError``.
    """
    num_gen_layers = _check_num_layers(num_gen_layers, "num_gen_layers")
    num_disc_layers = _check_num_layers(num_disc_layers, "num_disc_layers")

    # G is built before D, matching the original creation order so that
    # seeded weight initialisation is reproducible.
    G = nn.Sequential(*_build_mlp(n_latent, n_hidden, n_out,
                                  num_gen_layers, activation_function))
    D = nn.Sequential(*_build_mlp(n_out, n_hidden, 1,
                                  num_disc_layers, activation_function),
                      nn.Sigmoid())
    return G, D