import torch
from torch import nn


def build_net_for_imca():
    """Build the networks used for the IMCA configuration.

    Returns:
        A ``(treatment_net, instrumental_net, covariate_net)`` tuple:

        * ``treatment_net`` -- ``nn.Sequential`` MLP taking 16 input
          features to 6 output features through hidden widths 32, 64 and
          128. Every linear layer, including the last, is followed by a
          ``LeakyReLU``.
        * ``instrumental_net`` -- ``nn.Sequential`` MLP taking 12 input
          features to 6 output features through hidden widths 32, 64 and
          128. ``ReLU`` sits between the linear layers, and the output
          passes through ``BatchNorm1d(6)``.
        * ``covariate_net`` -- always ``None``; no covariate network is
          built for this configuration.
    """
    hidden_widths = (32, 64, 128)
    feature_dim = 6

    # Treatment branch: Linear -> LeakyReLU for every stage, including the
    # final projection to ``feature_dim``. Layers are created in the same
    # order as a hand-written Sequential so seeded initialisation and
    # state_dict keys ("0.weight", "2.weight", ...) are unchanged.
    treatment_layers = []
    prev_width = 16
    for width in hidden_widths + (feature_dim,):
        treatment_layers.append(nn.Linear(prev_width, width))
        treatment_layers.append(nn.LeakyReLU())
        prev_width = width
    treatment_net = nn.Sequential(*treatment_layers)

    # Instrumental branch: Linear -> ReLU for the hidden stages, then a
    # final Linear whose output is batch-normalised instead of activated.
    instrumental_layers = []
    prev_width = 12
    for width in hidden_widths:
        instrumental_layers.append(nn.Linear(prev_width, width))
        instrumental_layers.append(nn.ReLU())
        prev_width = width
    instrumental_layers.append(nn.Linear(prev_width, feature_dim))
    instrumental_layers.append(nn.BatchNorm1d(feature_dim))
    instrumental_net = nn.Sequential(*instrumental_layers)

    return treatment_net, instrumental_net, None
