from . import *
from .Encoder import Encoder
from .Decoder import Decoder
from .Noise import Noise

class EncoderDecoder(nn.Module):
    """Chain an Encoder, a Noise stage and a Decoder into one module.

    The three submodules are registered in the order ``encoder``, ``noise``,
    ``decoder``. Their attribute names become the ``state_dict`` key
    prefixes.

    Args:
        noise_layers: Forwarded unchanged to the ``Noise`` constructor.
            Its expected type and meaning are defined by ``Noise``
            (not visible here).
    """

    def __init__(self, noise_layers):
        super().__init__()
        # Registration order is kept identical to preserve parameter ordering.
        self.encoder = Encoder()
        self.noise = Noise(noise_layers)
        self.decoder = Decoder()

    def forward(self, image, message):
        """Run image/message through encoder, noise stage and decoder.

        Args:
            image: Passed to the encoder together with ``message``. It is
                also handed to the noise stage.
            message: Passed to the encoder only.

        Returns:
            A 3-tuple ``(encoded_image, noised_image, decoded_message)``:
            the encoder output, the noise-stage output, and the decoder
            output computed from the noise-stage output.
        """
        encoded = self.encoder(image, message)
        # The noise stage receives a two-element list: the encoder output
        # first, then the original image (presumably so that noise layers
        # can reference the original; confirm against Noise).
        noised = self.noise([encoded, image])
        return encoded, noised, self.decoder(noised)