import torch
import torch.nn as nn


class GAINDisc(nn.Module):
    """Three-layer MLP mapping a blended input plus a side tensor to ``inp_size`` outputs.

    NOTE(review): the name suggests this is the discriminator of a GAIN
    (Generative Adversarial Imputation Nets) setup -- confirm with the caller.

    Args:
        inp_size: Number of features per sample after flattening every
            non-batch dimension. ``fc1`` takes ``2 * inp_size`` inputs because
            the blended data and ``h`` are concatenated.
    """

    def __init__(self, inp_size: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(inp_size * 2, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, inp_size)
        self.relu = nn.ReLU()
        # Not used by forward(); kept so existing callers that access the
        # attribute keep working.
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def init_weight(self) -> None:
        """Re-initialise the weights of all linear layers with Xavier-normal.

        Biases keep the default ``nn.Linear`` initialisation.
        """
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x, m, g, h):
        """Run the network on ``m * x + (1 - m) * g`` concatenated with ``h``.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened to ``inp_size`` features.
            m: Elementwise blending weight, same shape as ``x``
                (presumably a 0/1 observation mask -- confirm with caller).
            g: Tensor used where ``m`` is 0, same shape as ``x``
                (presumably the generator output -- confirm with caller).
            h: Tensor concatenated to the blended input, same shape as ``x``
                (presumably the GAIN hint -- confirm with caller).

        Returns:
            The raw output of ``fc3`` reshaped to the original shape of ``x``.
            No sigmoid is applied here.
        """
        org_shape = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted; contiguous inputs behave as before.
        x = x.reshape(x.shape[0], -1)
        m = m.reshape(m.shape[0], -1)
        g = g.reshape(g.shape[0], -1)
        h = h.reshape(h.shape[0], -1)

        # Take x where m is 1 and g where m is 0.
        inp = m * x + (1 - m) * g
        inp = torch.cat((inp, h), dim=1)
        out = self.relu(self.fc1(inp))
        out = self.relu(self.fc2(out))
        # NOTE(review): output is left without a sigmoid -- presumably the
        # loss expects logits; confirm.
        out = self.fc3(out)

        return out.reshape(*org_shape)


class GAINGen(torch.nn.Module):
    """Three-layer MLP mapping a blended input plus ``m`` to ``inp_size`` outputs in (0, 1).

    NOTE(review): the name suggests this is the generator of a GAIN
    (Generative Adversarial Imputation Nets) setup -- confirm with the caller.

    Args:
        inp_size: Number of features per sample after flattening every
            non-batch dimension. ``fc1`` takes ``2 * inp_size`` inputs because
            the blended data and ``m`` are concatenated.
    """

    def __init__(self, inp_size: int) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(inp_size * 2, 256)
        self.fc2 = torch.nn.Linear(256, 128)
        self.fc3 = torch.nn.Linear(128, inp_size)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.init_weight()

    def init_weight(self) -> None:
        """Re-initialise the weights of all linear layers with Xavier-normal.

        Biases keep the default ``torch.nn.Linear`` initialisation.
        """
        for layer in (self.fc1, self.fc2, self.fc3):
            torch.nn.init.xavier_normal_(layer.weight)

    def forward(self, x, z, m):
        """Run the network on ``m * x + (1 - m) * z`` concatenated with ``m``.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened to ``inp_size`` features.
            z: Tensor used where ``m`` is 0, same shape as ``x``
                (presumably random noise -- confirm with caller).
            m: Elementwise blending weight, same shape as ``x``
                (presumably a 0/1 observation mask -- confirm with caller).

        Returns:
            The sigmoid of the ``fc3`` output, reshaped to the original shape
            of ``x``.
        """
        org_shape = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted; contiguous inputs behave as before.
        x = x.reshape(x.shape[0], -1)
        m = m.reshape(m.shape[0], -1)
        z = z.reshape(z.shape[0], -1)

        # Take x where m is 1 and z where m is 0.
        inp = m * x + (1 - m) * z
        inp = torch.cat((inp, m), dim=1)
        out = self.relu(self.fc1(inp))
        out = self.relu(self.fc2(out))
        out = self.sigmoid(self.fc3(out))

        return out.reshape(*org_shape)