from collections.abc import Iterable

import torch
from torch import nn

class DenseLayer(nn.Module):
    """Fully connected layer with optional batch norm, dropout and L1/L2 penalties.

    The forward pass applies, in order: the linear map, the activation, batch
    normalisation (if enabled) and dropout (if enabled).

    Args:
        out_features: Size of the layer output.
        in_features: Size of the layer input. When ``None`` an ``nn.LazyLinear``
            is used instead of ``nn.Linear``.
        bias: Passed through to the linear module.
        dropout: Dropout probability, or ``None`` to disable dropout.
        activation: Module applied right after the linear map. ``None`` selects
            ``nn.ReLU()``.
        batch_norm: When truthy, an ``nn.BatchNorm1d(out_features)`` is applied
            after the activation.
        l1: Coefficient of the L1 penalty, or ``None`` to disable it.
        l2: Coefficient of the squared-L2 penalty, or ``None`` to disable it.
    """

    def __init__(
            self, out_features, in_features=None, bias=True, dropout=None, activation=None,
            batch_norm=False, l1=None, l2=None):
        super().__init__()

        # Without a known input size, fall back to the lazy variant of the linear module.
        if in_features is None:
            linear = nn.LazyLinear(out_features, bias)
        else:
            linear = nn.Linear(in_features, out_features, bias)
        self._linear = linear

        if activation is None:
            activation = nn.ReLU()
        self._activation = activation

        self._batch_norm = nn.BatchNorm1d(out_features) if batch_norm else None
        self._dropout = None if dropout is None else nn.Dropout(dropout)

        self._l1 = l1
        self._l2 = l2

    def forward(self, x):
        """Return ``x`` after linear, activation, and the enabled optional stages."""
        out = self._activation(self._linear(x))
        # Batch norm comes before dropout; a stage stored as None is skipped.
        for stage in (self._batch_norm, self._dropout):
            if stage is not None:
                out = stage(out)

        return out

    def regularization(self):
        """Return the L1 plus squared-L2 penalty as a tensor.

        Each penalty runs over every parameter yielded by ``self.parameters()``
        and is scaled by its coefficient; a coefficient of ``None`` contributes
        ``torch.tensor(0.0)``.
        """
        def weighted_penalty(coefficient, elementwise):
            if coefficient is None:
                return torch.tensor(0.0)
            return coefficient * sum(elementwise(p).sum() for p in self.parameters())

        l1_term = weighted_penalty(self._l1, lambda p: p.abs())
        l2_term = weighted_penalty(self._l2, lambda p: p.pow(2))

        return l1_term + l2_term

class FFNNet(nn.Module):
    """Feed-forward network made of a sequence of ``DenseLayer`` modules.

    The first layer is created without ``in_features``; every later layer takes
    the previous layer's output size as its input size. The last layer uses
    ``nn.Identity()`` as activation, all other layers use ``DenseLayer``'s
    default activation.

    Args:
        layers: Output size of each layer, in order. Must not be empty.
        dropout: Dropout probability, either one value shared by all layers or
            an iterable with one entry per layer. ``None`` disables dropout.
        batch_norm: Batch-norm switch, either one value shared by all layers or
            an iterable with one entry per layer.
        l1: L1 coefficient, either one value shared by all layers or an
            iterable with one entry per layer.
        l2: L2 coefficient, either one value shared by all layers or an
            iterable with one entry per layer.

    Raises:
        ValueError: If ``layers`` is empty, or if an iterable option does not
            have exactly one entry per layer.
    """

    def __init__(self, layers, dropout=None, batch_norm=None, l1=None, l2=None):
        super().__init__()

        widths = list(layers)
        if not widths:
            raise ValueError("layers must contain at least one output size")
        depth = len(widths)

        self._l1 = self._per_layer(l1, depth, "l1")
        self._l2 = self._per_layer(l2, depth, "l2")
        self._dropout = self._per_layer(dropout, depth, "dropout")
        self._batch_norm = self._per_layer(batch_norm, depth, "batch_norm")

        stack = []
        for index, width in enumerate(widths):
            options = {
                "out_features": width,
                "dropout": self._dropout[index],
                "batch_norm": self._batch_norm[index],
                "l1": self._l1[index],
                "l2": self._l2[index],
            }
            # The first layer leaves in_features unset; later ones chain to the
            # previous layer's output size.
            if index > 0:
                options["in_features"] = widths[index - 1]
            # The output layer is left without a non-linearity.
            if index == depth - 1:
                options["activation"] = nn.Identity()
            stack.append(DenseLayer(**options))

        self._layers = nn.Sequential(*stack)

    @staticmethod
    def _per_layer(value, count, name):
        """Expand ``value`` into a list holding one setting per layer.

        A non-iterable value is repeated ``count`` times. An iterable is
        materialised into a list (so generators work too) and must contain
        exactly ``count`` entries.
        """
        if not isinstance(value, Iterable):
            return [value] * count

        values = list(value)
        if len(values) != count:
            raise ValueError(
                f"{name} has {len(values)} entries but the network has {count} layers")
        return values

    def forward(self, x):
        """Run ``x`` through all layers in order and return the result."""
        return self._layers(x)

    def regularization(self):
        """Return the sum of the regularization terms of all layers."""
        return sum(layer.regularization() for layer in self._layers)
