import torch
from train.pytorch_wrapper.processors import IdentityProcessor


class Predictor(object):
    """
    Convenience wrapper for pytorch model at application time

    Pre-processes inputs, moves them to the model's device and evaluates the
    model without gradient tracking.

    Parameters
    ----------
    model : torch.nn.Module
        Model to wrap. It is called with the mapping produced by
        `input_processor` (values moved to `device`) and must return a
        mapping from layer keys to outputs. The model is switched to eval
        mode and moved to `device` on construction.
    param_file : str, optional, default: None
        Path to model parameters, loaded via `initialize_model`. If None the
        model keeps its current parameters.
    input_processor : pytorch_wrapper.processors.Processor, optional, default: IdentityProcessor
        prepares input for module calls; ``process`` must return a mapping
    device : str or torch.device, optional, default: None
        device where to process the model (e.g. 'cpu', 'cuda'); if None,
        'cuda' is used when available, otherwise 'cpu'
    """

    def __init__(self, model, param_file=None, input_processor=None, device=None):
        self.model = model
        # Build a fresh processor per instance instead of sharing one default
        # object that was evaluated once at class-definition time.
        if input_processor is None:
            input_processor = IdentityProcessor()
        self.input_processor = input_processor

        if device is not None:
            self.device = device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if param_file is not None:
            self.initialize_model(param_file)
        else:
            # forward() moves inputs to self.device, so the model must live
            # there (and be in eval mode) even without a parameter file.
            self._prepare_model()

    def initialize_model(self, param_file):
        """
        Initialize prediction model.

            - loads model parameters from file
            - sets model to eval mode
            - puts model to appropriate device

        Parameters
        ----------
        param_file : str
            Path to model parameters (a state dict, possibly containing only
            a subset of the model's parameters)
        """

        # Merge the file into the current state so a file holding only a
        # subset of parameters still satisfies the strict load_state_dict.
        # map_location allows GPU-saved checkpoints on CPU-only machines.
        # NOTE: torch.load unpickles the file -- only load trusted files.
        state_dict = self.model.state_dict()
        state_dict.update(torch.load(param_file, map_location=self.device))
        self.model.load_state_dict(state_dict)

        self._prepare_model()

    def _prepare_model(self):
        """Set the model to eval mode and move it to ``self.device``."""
        self.model.eval()
        self.model.to(self.device)

    def forward(self, inp):
        """
        Predict model output.

        Parameters
        ----------
        inp : any python object
            model input

        Returns
        -------
        output_dict
            whatever the model returns for the processed input (a mapping
            from layer keys to outputs, as used by `predict`)
        """

        # prepare input for model
        model_input = self.input_processor.process(inp)

        # call model
        with torch.no_grad():

            # put input to selected device; values without a ``.to`` method
            # (e.g. plain python numbers) are passed through unchanged
            for key in model_input:
                value = model_input[key]
                if hasattr(value, 'to'):
                    model_input[key] = value.to(self.device)

            # predict model output
            output_dict = self.model(model_input)

        return output_dict

    def predict(self, inp, layer_key=None, output_processor=None):
        """
        Predict model output.

        Parameters
        ----------
        inp : any python object
            model input
        layer_key : str, optional, default: None
            output key of desired layer; if None the first key of the model
            output is used
        output_processor : pytorch_wrapper.processors.Processor, optional, default: IdentityProcessor
            post-processes model output

        Returns
        -------
        activation
            the selected layer output after `output_processor.process`

        Raises
        ------
        IndexError
            if `layer_key` is None and the model output is empty
        KeyError
            if `layer_key` is not present in the model output
        """

        if output_processor is None:
            output_processor = IdentityProcessor()

        # compute model output
        output_dict = self.forward(inp)

        # select appropriate layer
        if layer_key is None:
            layer_key = list(output_dict)[0]
        activation = output_dict[layer_key]

        # post process output
        return output_processor.process(activation)
