# net.py

import torch.nn as nn
import torch.nn.functional as F

# Public names exported by a star-import of this module: the three model classes.
__all__ = ['EncDecGaussian', 'Gaussian', 'AdvGaussian']


class EncDecGaussian(nn.Module):
    """Encoder-decoder MLP that maps inputs to an ``r``-dim code and back out.

    Encoder (``self.encoder``):
        ``Linear(ndim, hdl) -> PReLU -> BatchNorm1d(hdl) ->
        Linear(hdl, hdl // 2) -> PReLU -> BatchNorm1d(hdl // 2) ->
        Linear(hdl // 2, r)``
    Decoder (``self.decoder``):
        ``Linear(r, hdl) -> PReLU -> BatchNorm1d(hdl) ->
        Linear(hdl, 2 * hdl) -> PReLU -> Linear(2 * hdl, nout)``

    Args:
        ndim: Input feature width.
        nout: Output feature width of the decoder.
        r: Width of the latent code ``z``.
        hdl: Base hidden width; the encoder's second hidden layer uses
            ``hdl // 2`` and the decoder's second hidden layer uses ``2 * hdl``.

    Note:
        ``nn.Sequential`` names its children by index, so the layer order
        below determines the ``state_dict`` keys (``encoder.0.weight``, ...).
        Keep it stable to stay compatible with saved checkpoints.
    """

    def __init__(self, ndim, nout, r, hdl):
        super().__init__()

        # Floor division keeps the width an exact int (no float round-trip).
        half = hdl // 2

        self.encoder = nn.Sequential(
            nn.Linear(ndim, hdl),
            nn.PReLU(),
            nn.BatchNorm1d(hdl),
            nn.Linear(hdl, half),
            nn.PReLU(),
            nn.BatchNorm1d(half),
            nn.Linear(half, r),
        )

        self.decoder = nn.Sequential(
            nn.Linear(r, hdl),
            nn.PReLU(),
            nn.BatchNorm1d(hdl),
            # nn.Dropout(p=0.5),  # disabled; inserting a layer here would
            #                     # shift the Sequential indices / state_dict keys
            nn.Linear(hdl, 2 * hdl),
            nn.PReLU(),
            nn.Linear(2 * hdl, nout),
        )

    def forward(self, x):
        """Encode ``x`` and decode the resulting latent code.

        Args:
            x: Input batch whose last dimension is ``ndim``; presumably shape
                ``(batch, ndim)`` -- TODO confirm with callers. In training
                mode ``BatchNorm1d`` needs more than one sample per batch.

        Returns:
            Tuple ``(z, out)``: the latent code (last dim ``r``) and the
            decoder output (last dim ``nout``).
        """
        z = self.encoder(x)
        out = self.decoder(z)
        return z, out


class Gaussian(nn.Module):
    """Decoder-only MLP mapping an ``r``-dim input to ``nout`` outputs.

    Layout of ``self.decoder``:
        ``Linear(r, hdl) -> PReLU -> BatchNorm1d(hdl) ->
        Linear(hdl, 2 * hdl) -> PReLU -> Linear(2 * hdl, nout)``

    Args:
        ndim: Not used by this class (presumably kept so the constructor
            signature matches ``EncDecGaussian`` -- confirm with callers).
        nout: Output feature width.
        r: Input feature width.
        hdl: Base hidden width; the second hidden layer is ``2 * hdl`` wide.
    """

    def __init__(self, ndim, nout, r, hdl):
        super().__init__()

        wide = 2 * hdl
        layers = [
            nn.Linear(r, hdl),
            nn.PReLU(),
            nn.BatchNorm1d(hdl),
            # nn.Dropout(p=0.5) is disabled; inserting a layer here would
            # shift the Sequential indices that name state_dict entries.
            nn.Linear(hdl, wide),
            nn.PReLU(),
            nn.Linear(wide, nout),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (last dim ``r``) through the decoder; result has last dim ``nout``."""
        return self.decoder(x)


class AdvGaussian(nn.Module):
    """MLP that maps an ``r``-dim input to a single unbounded score.

    Layout of ``self.decoder``:
        ``Linear(r, hdl) -> PReLU -> BatchNorm1d(hdl) ->
        Linear(hdl, 2 * hdl) -> PReLU -> Linear(2 * hdl, 1)``
    The final layer has no activation, so the score is a raw real value.
    Presumably used as an adversarial critic/discriminator, per the class
    name -- confirm with callers.

    Args:
        ndim: Not used by this class.
        nout: Not used by this class; the output width is fixed at 1.
            (Both are presumably kept for constructor-signature parity with
            the other models -- confirm.)
        r: Input feature width.
        hdl: Base hidden width; the second hidden layer is ``2 * hdl`` wide.
    """

    def __init__(self, ndim, nout, r, hdl):
        super().__init__()

        double_hdl = hdl * 2
        stack = [
            nn.Linear(r, hdl),
            nn.PReLU(),
            nn.BatchNorm1d(hdl),
            # nn.Dropout(p=0.5) is disabled; inserting a layer here would
            # shift the Sequential indices that name state_dict entries.
            nn.Linear(hdl, double_hdl),
            nn.PReLU(),
            nn.Linear(double_hdl, 1),
        ]
        self.decoder = nn.Sequential(*stack)

    def forward(self, x):
        """Score ``x`` (last dim ``r``); the result has last dim 1."""
        score = self.decoder(x)
        return score

