import torch
import torch.nn as nn
from ..utils import check_consistency


class Network(torch.nn.Module):
    """
    Network class with a standard forward method and the possibility
    to pass extra features.

    The wrapped model is evaluated on the input augmented with the
    output of every extra feature module.
    """

    def __init__(self, model, extra_features=None):
        """
        :param torch.nn.Module model: the model evaluated in the forward
            pass.
        :param extra_features: optional iterable of ``torch.nn.Module``
            objects whose outputs are appended to the input before the
            model is called. ``None`` (default) means no extra features.
        """
        super().__init__()

        # check model consistency
        check_consistency(model, nn.Module)
        self._model = model

        # check consistency and assign extra features
        if extra_features is None:
            self._extra_features = []
        else:
            # Materialize once: the iterable is traversed for validation
            # and then unpacked, so a one-shot iterable (e.g. a generator)
            # would otherwise be exhausted before being stored.
            extra_features = list(extra_features)
            for feat in extra_features:
                check_consistency(feat, nn.Module)
            # nn.Sequential is used as a container that registers the
            # feature modules as submodules; forward() iterates over it
            # and never calls it as a pipeline.
            self._extra_features = nn.Sequential(*extra_features)

        # check model works with inputs
        # TODO

    def forward(self, x):
        """
        Forward method for Network class. This class
        implements the standard forward method, and
        it adds the possibility to pass extra features.

        Features are applied in order and each one receives the input
        already augmented by the previous features.

        :param torch.tensor x: input of the network. If the object
            provides an ``append`` method it is used to attach each
            feature; otherwise the feature is concatenated along the
            last dimension.
        :return torch.tensor: output of the network
        """
        # extract features and append
        for feature in self._extra_features:
            extra = feature(x)
            append = getattr(x, "append", None)
            if callable(append):
                # input type knows how to append columns by itself
                x = append(extra)
            else:
                # plain torch.Tensor has no ``append``: stack the new
                # feature next to the existing ones
                x = torch.cat((x, extra), dim=-1)
        # perform forward pass
        return self._model(x)

    @property
    def model(self):
        """The wrapped model evaluated in the forward pass."""
        return self._model

    @property
    def extra_features(self):
        """
        The extra feature modules: an empty list when none were given,
        otherwise the ``torch.nn.Sequential`` container holding them.
        """
        return self._extra_features
