from typing import *

from torch import nn

from utils import assert_type


class MultiVariableSequence(nn.Module):
    """Sequential container whose stages may take extra positional arguments.

    Each stage is a ``(module, aux)`` pair. During ``forward`` the running value
    is threaded through the stages in order. A stage is called as
    ``module(value, *aux)``, or as ``module(value)`` when ``aux`` is ``None`` or
    empty. If a stage's module is an ``nn.ModuleList``, every layer in it is
    applied in turn, and each layer receives the same ``aux`` arguments.

    Attributes:
        list_module: ``nn.ModuleList`` holding the stage modules, registered
            as submodules.
        list_aux: plain Python list holding each stage's ``aux`` (or ``None``),
            index-aligned with ``list_module``. The caller's lists are stored by
            reference, not copied. Because this is a plain list rather than an
            nn.Module container, tensors inside it are not registered as
            parameters or buffers.
    """

    def __init__(self, modules_with_aux: List[Tuple[nn.Module, Optional[List[Any]]]]):
        """Collect the ``(module, aux)`` pairs.

        Args:
            modules_with_aux: sequence of ``(module, aux)`` pairs. ``module``
                must be an ``nn.Module``, and ``aux`` must be a ``list`` or
                ``None``. Both are checked with the project's ``assert_type``
                helper.
        """
        super().__init__()
        self.list_module = nn.ModuleList()
        self.list_aux = []
        for stage, extra_args in modules_with_aux:
            # Validate both halves of the pair before storing either of them.
            assert_type(stage, nn.Module)
            assert_type(extra_args, list, allow_none=True)

            self.list_module.append(stage)
            self.list_aux.append(extra_args)
        # endfor
    # enddef

    def forward(self, inputs):
        """Thread ``inputs`` through every stage and return the final value."""
        value = inputs
        for stage, extra_args in zip(self.list_module, self.list_aux):
            # No aux (None or empty) means each layer gets only the running value.
            has_extra = extra_args is not None and len(extra_args) != 0
            positional = extra_args if has_extra else ()

            # An nn.ModuleList stage is unrolled; any other module is a single layer.
            layers = stage if isinstance(stage, nn.ModuleList) else (stage,)
            for layer in layers:
                value = layer(value, *positional)
            # endfor
        # endfor
        return value
    # enddef
# endclass
