import torch
from torch import nn
from ..nn_structure.mixture_density_net import MixtureDensityNet


class ResponseModel(nn.Module):
    """Feed-forward regressor over concatenated treatment and covariate features.

    The network is a stack of ``Linear -> ReLU -> Dropout`` stages followed by
    a final ``Linear`` layer with a single output unit. With the default
    arguments the layer widths are ``12 -> 128 -> 64 -> 32 -> 1`` and the
    module indices inside ``self.net`` (0, 3, 6, 9 for the linear layers) are
    the same as in the original hard-coded definition, so existing
    ``state_dict`` checkpoints keep loading.

    Args:
        dropout_ratio: Drop probability handed to every ``nn.Dropout`` layer.
        input_dim: Width of the concatenated ``[treatment, covariate]``
            feature vector fed to the first linear layer. Defaults to 12.
        hidden_dims: Output widths of the hidden linear layers, in order.
            Defaults to ``(128, 64, 32)``.
    """

    def __init__(self, dropout_ratio, input_dim=12, hidden_dims=(128, 64, 32)):
        super(ResponseModel, self).__init__()
        layers = []
        in_features = input_dim
        # Layers are created strictly front-to-back so that seeded weight
        # initialisation matches the original hand-written Sequential.
        for out_features in hidden_dims:
            layers += [nn.Linear(in_features, out_features),
                       nn.ReLU(),
                       nn.Dropout(dropout_ratio)]
            in_features = out_features
        # Final projection to a single output per sample.
        layers.append(nn.Linear(in_features, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, treatment, covariate):
        """Run the network on ``treatment`` and ``covariate`` joined along dim 1.

        Args:
            treatment: Tensor concatenated first along dimension 1.
            covariate: Tensor concatenated second along dimension 1.
                The combined size of dimension 1 must equal the configured
                ``input_dim`` (12 by default).

        Returns:
            The output of ``self.net`` for the concatenated features; its
            last dimension has size 1.
        """
        feature = torch.cat([treatment, covariate], dim=1)
        return self.net(feature)
    
def build_net_for_imca():
    """Construct the pair of freshly initialised networks used for "imca".

    The first network takes 12 input features, narrows them through linear
    layers of width 128, 64 and 32 (with batch normalisation and dropout of
    0.2 applied at the 32-wide stage) and ends in a ``MixtureDensityNet``.
    The second network is a ``ResponseModel`` built with a dropout ratio
    of 0.2.

    Returns:
        tuple: ``(instrumental_net, response_net)`` in that order.
    """
    instrumental_layers = [
        nn.Linear(12, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 32),
        nn.BatchNorm1d(32),
        nn.ReLU(),
        nn.Dropout(0.2),
        # NOTE(review): the first argument matches the preceding 32-wide
        # layer, so it is presumably the input width; the meaning of 16 and
        # 10 is defined by the project's MixtureDensityNet -- confirm there.
        MixtureDensityNet(32, 16, 10),
    ]
    instrumental_net = nn.Sequential(*instrumental_layers)

    # Built after the instrumental network, as in the original ordering.
    response_net = ResponseModel(0.2)

    return instrumental_net, response_net
