import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class MLP(pl.LightningModule):
    """Fully connected network trained with an MSE loss.

    The network is a stack of ``nn.Linear`` layers with an ``nn.ReLU`` between
    consecutive layers and no activation after the final layer.

    Args:
        input_size: Number of input features of the first linear layer.
        output_size: Number of output features of the last linear layer.
        hidden: Widths of the hidden layers, in order, as a sequence of ints
            (list or tuple). An empty sequence yields a single linear layer.
        lr: Learning rate given to the Adam optimizer.
    """

    def __init__(self, input_size, output_size, hidden, lr):
        super().__init__()
        # Store the constructor arguments in ``self.hparams`` and in saved
        # checkpoints, so ``MLP.load_from_checkpoint(path)`` can rebuild the
        # model without the caller re-supplying them.
        self.save_hyperparameters()

        self.lr = lr

        # Layer widths from input to output, e.g. [in, h1, h2, out].
        # Unpacking (instead of list concatenation) also accepts a tuple.
        dims = [input_size, *hidden, output_size]

        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(d_in, d_out))
            layers.append(nn.ReLU())
        # Remove the trailing ReLU so the final Linear output is returned
        # without an activation.
        layers.pop()
        self.layers = nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        """Apply the layer stack to ``x``.

        Extra keyword arguments are accepted for call-site compatibility and
        ignored.
        """
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the MSE between ``self(x_s)`` and ``x_full``.

        Args:
            batch: A pair ``(x_s, x_full)``. ``x_s`` is the network input and
                ``x_full`` is the target the output is compared against.
                NOTE(review): the names suggest a partial/subsampled view and
                the full signal, i.e. a reconstruction task. Confirm against
                the data module.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The scalar loss tensor, which Lightning backpropagates.
        """
        x_s, x_full = batch
        x_recon = self(x_s)
        loss = F.mse_loss(x_recon, x_full)

        self.log('loss', loss, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters at ``self.lr``."""
        # ``self.parameters()`` includes ``self.layers`` plus any parameters a
        # subclass registers, unlike ``self.layers.parameters()``.
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
