from torch.nn import Module

from coach.sensoriplexer import Sensoriplexer

#
# The coach wrapper module.
#
class Coach(Module):
    '''
    Wrap a model so it can be "coached" with a wider set of input signals.

    The wrapper is intended to be functionally equivalent to the wrapped
    model. The difference is that inputs are first routed through a
    `Sensoriplexer` before being handed to the model (see `forward`).
    '''

    def __init__(self, model: Module, *model_signals, **all_signals) -> None:
        '''
        Build a coach around `model`.

        `model` is the component to coach. It only needs to comply with
        Torch's Module interface.

        Intended usage:

            system = Coach(model)

        where `system` is functionally equivalent to `model`, augmented with
        the Coach capabilities.

        `model_signals` lists the names of the signals expected by the model,
        in the order the model takes them as positional arguments. The names
        must appear in `all_signals`. This list lets the coach know which
        signals are "native" to the model.

        `all_signals` (keyword arguments) specifies every input signal. This
        includes the ones listed in `model_signals`, plus extra ones the coach
        will allow the model to leverage. The signals are of the form:

            {
                's1': shape_tuple_1,
                ...
                'sN': shape_tuple_N
            }

        where `si` is a signal name string (e.g. `image`), and `shape_tuple_i`
        is a tuple describing the shape of the corresponding signal (e.g.
        `(28, 28, 1)` for a MNIST image). Any iterable shape is accepted; it
        is converted to a tuple before being registered.
        '''
        # torch.nn.Module.__init__ must run before any attribute assignment.
        # Assigning a sub-Module (such as `self.model`) earlier raises
        # "cannot assign module before Module.__init__() call". Running it
        # first also registers `model` as a submodule, so its parameters are
        # seen by `.parameters()`, `.to()`, etc.
        super().__init__()
        self.model = model
        # Signal names fed to the model, in positional order (see `forward`).
        self.model_signals = model_signals
        self.sensoriplexer = Sensoriplexer()
        for name, shape in all_signals.items():
            self.sensoriplexer.add(str(name), tuple(shape))

    def forward(self, signals: dict):
        '''
        Run the coached model on a dict of named signals.

        This is a somewhat lazy adaptation to the PyTorch `forward`
        convention.

        `signals` is a dict of the form:

            {
                's1': t1,
                ...
                'sN': tN
            }

        where `si` is a signal name string (e.g. `image`), and must be one
        registered at `__init__` time. `ti` is a tensor for the corresponding
        signal.

        The signals are passed to the sensoriplexer as keyword arguments. Its
        result is indexed by each name in `model_signals`, and those values are
        passed to the model positionally, in that order. Returns whatever the
        wrapped model returns.
        '''
        routed = self.sensoriplexer(**signals)
        # NOTE(review): a name in `model_signals` missing from `routed` fails
        # here rather than at construction time; the exact exception depends
        # on what the sensoriplexer returns — confirm.
        return self.model(*[routed[name] for name in self.model_signals])
