import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import torch as torch

class SingleAE(nn.Module):
    """Plain single-input autoencoder made of two MLPs.

    The encoder maps ``input_dim -> hidden_dim -> hidden_dim -> z_dim`` using
    LeakyReLU(0.1) activations and Dropout(0.5). The decoder mirrors it back to
    ``input_dim`` and ends in a Sigmoid, so reconstructions lie in (0, 1).

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the hidden layers.
        z_dim: Size of the latent code.
    """

    def __init__(self, input_dim, hidden_dim, z_dim):
        super().__init__()
        slope, drop = 0.1, 0.5
        # The encoder's final Linear has no activation, so the decoder opens
        # with LeakyReLU + Dropout applied to the latent code. Layer order is
        # significant: it fixes the Sequential indices (state_dict keys) and
        # the order in which seeded weight initialisation happens.
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim), nn.LeakyReLU(slope), nn.Dropout(drop),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(slope), nn.Dropout(drop),
            nn.Linear(hidden_dim, z_dim),
        ]
        decoder_layers = [
            nn.LeakyReLU(slope), nn.Dropout(drop),
            nn.Linear(z_dim, hidden_dim), nn.LeakyReLU(slope), nn.Dropout(drop),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(slope),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        ]
        self.encoder1 = nn.Sequential(*encoder_layers)
        self.decoder1 = nn.Sequential(*decoder_layers)

    def forward(self, x1):
        """Return the reconstruction ``decoder1(encoder1(x1))``; the latent code is not returned."""
        return self.decoder1(self.encoder1(x1))

class AE(nn.Module):
    """Two-view masked autoencoder built from MLP encoder/decoder pairs.

    Each view ``x`` is multiplied element-wise by its mask ``m`` before
    encoding. The reconstruction is multiplied by the same mask, so entries
    where ``m == 0`` are zero on both the input and the output side.

    Args:
        input_dim: Feature size of each view (both views use the same size).
        hidden_dim: Width of the two hidden layers in every encoder/decoder.
        z_dim: Size of the latent code.
        share_weights: Keyword-only; controls which networks view 2 uses.
            - ``True`` (default, the behaviour of earlier versions of this
              class): view 2 is also passed through ``encoder1``/``decoder1``.
            - ``False``: view 2 uses its own ``encoder2``/``decoder2``.

    ``encoder2``/``decoder2`` are always constructed, even in shared mode.
    This keeps the order of seeded weight initialisation and the
    ``state_dict`` keys identical to earlier versions, so existing
    checkpoints still load. In shared mode they receive no gradients.
    """

    def __init__(self, input_dim, hidden_dim, z_dim, *, share_weights=True):
        super(AE, self).__init__()
        self.share_weights = share_weights
        # Construction order (enc1, dec1, enc2, dec2) is kept so that seeded
        # initialisation draws the same random weights as before.
        self.encoder1 = self._make_encoder(input_dim, hidden_dim, z_dim)
        self.decoder1 = self._make_decoder(input_dim, hidden_dim, z_dim)
        self.encoder2 = self._make_encoder(input_dim, hidden_dim, z_dim)
        self.decoder2 = self._make_decoder(input_dim, hidden_dim, z_dim)

    @staticmethod
    def _make_encoder(input_dim, hidden_dim, z_dim, p=0.5):
        """Build the MLP encoder ``input_dim -> hidden -> hidden -> z_dim``.

        The layer order fixes the Sequential indices, and therefore the
        ``state_dict`` keys (``0``, ``3``, ``6`` for the Linears).
        """
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(p),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(p),
            nn.Linear(hidden_dim, z_dim),
        )

    @staticmethod
    def _make_decoder(input_dim, hidden_dim, z_dim, p=0.5):
        """Build the MLP decoder ``z_dim -> hidden -> hidden -> input_dim``.

        It opens with LeakyReLU + Dropout because the encoder's last Linear is
        not activated. It ends in a Sigmoid, so outputs lie in (0, 1).
        """
        return nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Dropout(p),
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(p),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    @staticmethod
    def _masked_autoencode(x, m, encoder, decoder):
        """Mask ``x``, encode, decode, re-mask; return ``(z, xhat)``."""
        # Plain multiplication replaces the former ``x*m + 0*(1-m)``: the
        # added term was always zero, and ``1 - m`` raised for bool masks.
        z = encoder(x * m)
        xhat = decoder(z) * m
        return z, xhat

    def forward(self, x1, x2, m1, m2):
        """Autoencode both masked views.

        Args:
            x1, x2: View inputs whose last dimension is ``input_dim``.
            m1, m2: Masks multiplied element-wise into the matching view and
                its reconstruction. Presumably 0/1 indicators (float or bool)
                broadcastable to the view; confirm against callers.

        Returns:
            ``(z1, z2, xhat1, xhat2)``: latent codes and masked reconstructions.
        """
        # getattr default keeps instances pickled before share_weights existed working.
        if getattr(self, "share_weights", True):
            encoder2, decoder2 = self.encoder1, self.decoder1
        else:
            encoder2, decoder2 = self.encoder2, self.decoder2
        z1, xhat1 = self._masked_autoencode(x1, m1, self.encoder1, self.decoder1)
        z2, xhat2 = self._masked_autoencode(x2, m2, encoder2, decoder2)
        return z1, z2, xhat1, xhat2



class AE_CIFAR10(nn.Module):
    """Two-view masked autoencoder variant with dropout disabled (p=0).

    Each view ``x`` is multiplied element-wise by its mask ``m`` before
    encoding. The reconstruction is multiplied by the same mask, so entries
    where ``m == 0`` are zero on both the input and the output side.

    Args:
        input_dim: Feature size of each view (both views use the same size).
        hidden_dim: Width of the two hidden layers in every encoder/decoder.
        z_dim: Size of the latent code.
        share_weights: Keyword-only; controls which networks view 2 uses.
            - ``True`` (default, the behaviour of earlier versions of this
              class): view 2 is also passed through ``encoder1``/``decoder1``.
            - ``False``: view 2 uses its own ``encoder2``/``decoder2``.

    ``encoder2``/``decoder2`` are always constructed, even in shared mode.
    This keeps the order of seeded weight initialisation and the
    ``state_dict`` keys identical to earlier versions, so existing
    checkpoints still load. In shared mode they receive no gradients.
    """

    def __init__(self, input_dim, hidden_dim, z_dim, *, share_weights=True):
        super(AE_CIFAR10, self).__init__()
        self.share_weights = share_weights
        # Construction order (enc1, dec1, enc2, dec2) is kept so that seeded
        # initialisation draws the same random weights as before.
        self.encoder1 = self._make_encoder(input_dim, hidden_dim, z_dim)
        self.decoder1 = self._make_decoder(input_dim, hidden_dim, z_dim)
        self.encoder2 = self._make_encoder(input_dim, hidden_dim, z_dim)
        self.decoder2 = self._make_decoder(input_dim, hidden_dim, z_dim)

    @staticmethod
    def _make_encoder(input_dim, hidden_dim, z_dim, p=0):
        """Build the MLP encoder ``input_dim -> hidden -> hidden -> z_dim``.

        The Dropout(p=0) layers do nothing, but they are kept as placeholders.
        Removing them would shift the Sequential indices and change the
        ``state_dict`` keys (``0``, ``3``, ``6`` for the Linears).
        """
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(p),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(p),
            nn.Linear(hidden_dim, z_dim),
        )

    @staticmethod
    def _make_decoder(input_dim, hidden_dim, z_dim, p=0):
        """Build the MLP decoder ``z_dim -> hidden -> hidden -> input_dim``.

        It opens with LeakyReLU + Dropout because the encoder's last Linear is
        not activated. It ends in a Sigmoid, so outputs lie in (0, 1).
        """
        return nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Dropout(p),
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(p),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    @staticmethod
    def _masked_autoencode(x, m, encoder, decoder):
        """Mask ``x``, encode, decode, re-mask; return ``(z, xhat)``."""
        # Plain multiplication replaces the former ``x*m + 0*(1-m)``: the
        # added term was always zero, and ``1 - m`` raised for bool masks.
        z = encoder(x * m)
        xhat = decoder(z) * m
        return z, xhat

    def forward(self, x1, x2, m1, m2):
        """Autoencode both masked views.

        Args:
            x1, x2: View inputs whose last dimension is ``input_dim``.
            m1, m2: Masks multiplied element-wise into the matching view and
                its reconstruction. Presumably 0/1 indicators (float or bool)
                broadcastable to the view; confirm against callers.

        Returns:
            ``(z1, z2, xhat1, xhat2)``: latent codes and masked reconstructions.
        """
        # getattr default keeps instances pickled before share_weights existed working.
        if getattr(self, "share_weights", True):
            encoder2, decoder2 = self.encoder1, self.decoder1
        else:
            encoder2, decoder2 = self.encoder2, self.decoder2
        z1, xhat1 = self._masked_autoencode(x1, m1, self.encoder1, self.decoder1)
        z2, xhat2 = self._masked_autoencode(x2, m2, encoder2, decoder2)
        return z1, z2, xhat1, xhat2






