import logging
import math

import torch

import utils


class Model(torch.nn.Module, utils.Namable):
    """Classifier made of a feature extractor followed by a classifier head.

    The constructor only stores the configuration; `_feature_extractor` and
    `_classifier` stay None until `_init_params` has run. `init` runs it on
    a *new* instance and returns that instance, so the forward methods are
    only usable on objects obtained through `init`.

    `create_layers` and `create_classifier` are marked with `utils.abstract`
    and have no implementation here; concrete models are expected to
    override them.
    """

    def __init__(self, descriptive_name, input_size, targets):
        """Instantiate a Model class.

        Parameters:
        ===========
        descriptive_name: name returned by `get_descriptive_name`.
        input_size: tuple of int dimensions of input.
        targets: int number of classes to predict.
        """
        # Both bases are initialized explicitly rather than through super().
        torch.nn.Module.__init__(self)
        utils.Namable.__init__(self)
        self._descriptive_name = descriptive_name
        self._targets = targets
        self._input_size = tuple(input_size)
        # Built later by `_init_params`.
        self._feature_extractor = None
        self._classifier = None

    def get_descriptive_name(self):
        """Return the descriptive name given at construction."""
        return self._descriptive_name

    def init(self):
        """Return a new, initialized instance of the same class.

        The new instance is constructed with this object's descriptive name,
        input size and targets (so subclasses must keep that three-argument
        constructor), then gets its layers built and parameters reset.
        The receiver itself is not modified.
        """
        out = self.__class__(self._descriptive_name, self._input_size, self._targets)
        out._init_params()
        return out

    def get_paramcount(self):
        """Return int number of parameters in this model."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, X):
        """Return the classifier output for input `X`.

        Equivalent to `classify_hidden_features(extract_features(X))`.
        """
        return self.classify_hidden_features(self.extract_features(X))

    def extract_features(self, X):
        """Return the feature extractor's output for `X`.

        Parameters:
        ===========
        X: tensor whose shape, without its first dimension, must equal the
            input size given at construction.

        The shape check is an assert, so it is skipped under `python -O`.
        """
        assert X.shape[1:] == self._input_size, str(X.size())
        return self._feature_extractor(X)

    def classify_hidden_features(self, h):
        """Return the classifier's output for hidden features `h`.

        Parameters:
        ===========
        h: tensor of hidden features, as returned by `extract_features`.

        The result is asserted to have size `(h.size(0), targets)`; the
        check is skipped under `python -O`.
        """
        out = self._classifier(h)
        assert out.size() == (h.size(0), self._targets), str(out.size())
        return out

    # === PROTECTED ===

    @utils.abstract
    def create_layers(self, input_size):
        """Return list of torch.nn.Module objects, layers of this network.

        The returned layers are wrapped, in order, in a torch.nn.Sequential
        that becomes the feature extractor.

        Parameters:
        ===========
        input_size: tuple of int dimensions of input.
        """

    @utils.abstract
    def create_classifier(self, targets):
        """Return the classifier module of this network.

        Parameters:
        ===========
        targets: int number of classes to predict.
        """

    def _init_params(self):
        """Build the feature extractor and classifier, then reset parameters.

        Logs the resulting parameter count under this object's name.
        """
        self._feature_extractor = torch.nn.Sequential(
            *self.create_layers(self._input_size)
        )
        self._classifier = self.create_classifier(self._targets)
        self.reset_parameters()
        # Lazy %-style arguments: the message is only formatted if emitted.
        logging.info(
            "%s: initialized %s parameters.",
            self.get_name(), self.get_paramcount()
        )

    def reset_parameters(self):
        """Re-initialize the parameters of all contained modules in place.

        - Conv2d: weights drawn from a normal distribution with mean 0 and
          standard deviation sqrt(2 / (kernel_h * kernel_w * out_channels));
          the bias is left as it is.
        - BatchNorm2d: weight set to 1 and bias set to 0.
        - Linear: bias set to 0; the weight is left as it is.

        Parameters that do not exist (e.g. `Linear(..., bias=False)` or
        `BatchNorm2d(..., affine=False)`, where the attribute is None) are
        skipped instead of raising.
        """
        # NOTE: copied from https://github.com/owruby/shake-drop_pytorch/blob/master/models/shake_pyramidnet.py
        logging.info("Reset parameters according to He et al.")
        # In-place initialization must not be recorded by autograd.
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, torch.nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, torch.nn.BatchNorm2d):
                    # weight/bias are None when the layer is not affine.
                    if m.weight is not None:
                        m.weight.fill_(1)
                    if m.bias is not None:
                        m.bias.zero_()
                elif isinstance(m, torch.nn.Linear):
                    # bias is None when the layer was built with bias=False.
                    if m.bias is not None:
                        m.bias.zero_()
