import torch
import torch.nn as nn
from torch.nn.functional import relu


class FC_no_bias(nn.Module):
    r"""Two-layer bias-free ReLU network :math:`f(x) = \sum_i a_i \sigma(w_i^T x)`.

    :math:`\sigma` is the ReLU, :math:`w_i` is the i-th row of ``fc1.weight``
    and :math:`a_i` is the i-th column of ``fc2.weight`` (one entry per output).

    Parameters
    ----------
    in_dim : int
        Number of input features.
    hid_dim : int
        Number of hidden ReLU units.
    out_dim : int
        Number of outputs.
    """

    def __init__(self, in_dim, hid_dim, out_dim):
        super().__init__()
        # Both layers are bias-free so the model is exactly sum_i a_i relu(w_i^T x)
        # and ``self.parameters()`` yields only (fc1.weight, fc2.weight).
        self.fc1 = nn.Linear(in_features=in_dim, out_features=hid_dim, bias=False)
        self.fc2 = nn.Linear(in_features=hid_dim, out_features=out_dim, bias=False)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))`` for an input whose last dimension is ``in_dim``."""
        hidden = relu(self.fc1(x))
        return self.fc2(hidden)


def penalty(model, lam):
    r"""
        Penalty of the model parameters, :math:`lam \sum_{k,j,i} |a_{kj} w_{ji}|`.

        For a single output (``out_dim == 1``) this is :math:`lam * \|a*w\|_1`,
        the L1 norm of the products of each first-layer weight with the
        second-layer weight of the same hidden unit. For several outputs every
        (output, hidden, input) path contributes.

        Parameters
        ----------
        model: FC_no_bias object
            Two layer NN as we defined. Its ``parameters()`` must yield exactly
            the first-layer weight ``w`` (hid_dim, in_dim) followed by the
            second-layer weight ``a`` (out_dim, hid_dim).
        lam: float
            Penalty parameter.

        Returns
        -------
        reg_loss : scalar tensor
            Penalty.

        Raises
        ------
        ValueError
            If ``model`` does not have exactly two parameter tensors.
    """
    w, a = list(model.parameters())
    # |a_kj * w_ji| = |a_kj| * |w_ji|, so the sum over (k, j, i) factorises per
    # hidden unit j into (sum_k |a_kj|) * (sum_i |w_ji|). Unlike broadcasting
    # ``w.T * a``, this is correct for any out_dim (the broadcast silently paired
    # the wrong entries when out_dim == in_dim > 1) and needs only O(hid) memory.
    a_mass = a.abs().sum(dim=0)  # per hidden unit: sum over outputs
    w_mass = w.abs().sum(dim=1)  # per hidden unit: sum over inputs
    reg_loss = lam * (a_mass * w_mass).sum()  # L_1 norm of the path weights
    return reg_loss


def train(model, train_loader, criterion, optimizer, lam):
    """Run one epoch of training for the LASSO-regularized two-layer network.

    Each batch is optimized on ``criterion(output, target) + penalty(model, lam)``.

    Parameters
    ----------
    model : FC_no_bias object
        Network to train; it is put into training mode.
    train_loader : iterable of (data, target) pairs with a ``len()``
        Training batches.
    criterion : callable
        Data-fitting loss, called as ``criterion(output.squeeze(), target)``.
    optimizer : torch.optim.Optimizer
        Optimizer over the model's parameters.
    lam : float
        Penalty parameter passed to ``penalty``.

    Returns
    -------
    float
        Mean over batches of the data-fitting loss only (penalty excluded),
        evaluated on the outputs computed before each optimizer step.
    """
    model.train()
    total_loss = 0.0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        data_loss = criterion(output.squeeze(), target)
        loss = data_loss + penalty(model, lam)
        loss.backward()
        optimizer.step()
        # Report only the data term (reusing the value already computed) so
        # the training loss is directly comparable with test().
        total_loss += data_loss.item()
    return total_loss / len(train_loader)


def test(model, test_loader, criterion):
    """Average loss of a trained model over ``test_loader``.

    The model is switched to evaluation mode (and left there), and no
    gradients are recorded.

    Parameters
    ----------
    model : torch.nn.Module
        Trained network.
    test_loader : iterable of (data, target) pairs with a ``len()``
        Test batches.
    criterion : callable
        Loss, called as ``criterion(output.squeeze(), target)``.

    Returns
    -------
    float
        Sum of the per-batch losses divided by the number of batches.
    """
    model.eval()  # same as model.train(False)
    total_loss = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            total_loss += criterion(output.squeeze(), target).item()  # sum up batch loss
    return total_loss / len(test_loader)


# Despite its name this is not a test case: keep pytest from collecting it (and
# failing on the missing ``model``/``test_loader``/``criterion`` fixtures) when
# this module's names are imported into a test file.
test.__test__ = False


if __name__ == '__main__':
    # Placeholder entry point: running this file directly does nothing; the
    # module only defines the model, penalty, and train/test helpers.
    pass
