from abc import ABCMeta, abstractmethod

import torch

from XXX.uib.utils.safe_module import SafeModule


class DecoderInterface(SafeModule, metaclass=ABCMeta):
    """Abstract decoder stage that turns an encoding into output values.

    Concrete subclasses must override ``safe_forward``, ``reset`` and ``fit``;
    ``ABCMeta`` prevents instantiating a subclass that leaves any of them out.

    NOTE(review): ``safe_forward`` is presumably dispatched to by
    ``SafeModule`` (e.g. from its ``forward``) — confirm against SafeModule.
    """

    @abstractmethod
    def safe_forward(self, encoding):
        """Decode ``encoding`` and return the values *before* log-softmax.

        The base body is the identity, so a subclass may delegate to it with
        ``super().safe_forward(encoding)``.
        """
        decoded = encoding
        return decoded

    @abstractmethod
    def reset(self):
        """Reset the decoder; what gets reset is defined by the subclass."""

    @abstractmethod
    def fit(self, encodings, targets):
        """Fit the decoder on ``encodings`` / ``targets`` (may be a no-op in subclasses)."""


class PassthroughDecoder(DecoderInterface):
    """Decoder that delegates all decoding to a wrapped ``torch.nn.Module``.

    There is no separate fitting step: the wrapped module is trained through
    the ordinary forward/backward pass, so ``fit`` does nothing.
    """

    decoder: torch.nn.Module
    disable_reset: bool

    def __init__(self, decoder: torch.nn.Module, *, disable_reset=False):
        super().__init__()

        # Assumes SafeModule is a torch.nn.Module subclass, in which case this
        # assignment registers ``decoder`` as a submodule — TODO confirm.
        self.decoder = decoder
        # When True, ``reset`` leaves the wrapped module's parameters untouched.
        self.disable_reset = disable_reset

    def safe_forward(self, encoding):
        """Run ``encoding`` through the wrapped decoder module and return its output."""
        output = self.decoder(encoding)
        return output

    def reset(self):
        """Re-initialise every submodule of ``decoder`` that has ``reset_parameters``.

        Does nothing when ``disable_reset`` is set.
        """
        if self.disable_reset:
            return

        def _reinit(submodule):
            # Only modules exposing ``reset_parameters`` are re-initialised;
            # containers and parameter-free layers are skipped.
            if hasattr(submodule, "reset_parameters"):
                submodule.reset_parameters()

        # ``Module.apply`` calls ``_reinit`` on every descendant and on
        # ``self.decoder`` itself.
        self.decoder.apply(_reinit)

    def fit(self, encodings, targets):
        """No-op: training happens through the wrapped decoder module."""
        return None
