from torch import nn
from torch.nn import functional as F

class SimpleNN(nn.Module):
    """Three-layer fully connected network with ReLU between the layers.

    With the default arguments the layer widths are 4096 -> 512 -> 512 -> 4096,
    identical to the previously hard-coded configuration.

    Args:
        in_features: Size of the last dimension of the input tensor.
        hidden_features: Width of the two hidden layers.
        out_features: Size of the last dimension of the output. Defaults to
            ``in_features`` so the network maps back into its input width.
    """

    def __init__(
        self,
        in_features: int = 4096,
        hidden_features: int = 512,
        out_features: "int | None" = None,
    ):
        super().__init__()
        if out_features is None:
            out_features = in_features
        # Layers are constructed in the same order as before (fc1, fc2, fc3)
        # so seeded parameter initialisation is unchanged.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply fc1 -> ReLU -> fc2 -> ReLU -> fc3 and return the result.

        ``nn.Linear`` acts on the last dimension, so ``x`` must have size
        ``in_features`` there; leading dimensions pass through unchanged.
        The final layer has no activation.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
class DouGen(nn.Module):
    """Pair of ``SimpleNN`` generators that are cross-applied to two inputs.

    ``gen2`` is applied to the first input and ``gen1`` to the second; each
    result is then fed through the other generator. This looks like a
    cycle-consistency setup (1 -> 2 -> 1 and 2 -> 1 -> 2) — NOTE(review):
    confirm against the training code.
    """

    def __init__(self):
        super().__init__()
        self.gen1 = SimpleNN()
        self.gen2 = SimpleNN()

    def forward(self, inputs1, inputs2):
        """Run both generators forward and back.

        Returns:
            A 4-tuple ``(gen2(inputs1), gen1(inputs2),
            gen1(gen2(inputs1)), gen2(gen1(inputs2)))``.
        """
        mapped_from_1 = self.gen2(inputs1)
        mapped_from_2 = self.gen1(inputs2)
        # Tuple elements are evaluated left to right, so the generators are
        # called in the same order as the four separate statements would be.
        return (
            mapped_from_1,
            mapped_from_2,
            self.gen1(mapped_from_1),
            self.gen2(mapped_from_2),
        )