from functools import wraps
import importlib


def model_constructor(f):
    """ Decorator for network-building functions.

    The decorated function 'f' is called unchanged. Before its result is handed back, a 'constructor' attribute is
    attached to the network. That attribute is a 'NetConstructor' recording the name and module of 'f' together with
    the positional and keyword arguments of this call. Storing it next to the weights lets the same network be
    re-created later from a checkpoint via NetConstructor.get().

    If 'f' returns a tuple or list, the network is taken to be its FIRST ELEMENT, and only that element receives the
    'constructor' attribute. Any other return value is treated as the network itself. The original return value is
    passed through to the caller.
    """
    @wraps(f)
    def wrapper(*args, **kwds):
        # Capture the call recipe before running 'f'.
        recipe = NetConstructor(f.__name__, f.__module__, args, kwds)
        result = f(*args, **kwds)
        # For multi-value returns, the network is expected to be the first element of the sequence.
        target = result[0] if isinstance(result, (tuple, list)) else result
        target.constructor = recipe
        return result
    return wrapper


class NetConstructor:
    """ Class to construct networks. Takes as input the function name (e.g. atom_resnet18), the name of the module
    which contains the network function (e.g. ltr.models.bbreg.atom) and the arguments for the network
    function. The class object can then be stored along with the network weights to re-construct the network."""
    def __init__(self, fun_name, fun_module, args, kwds):
        """
        args:
            fun_name - name of the function which returns the network (looked up as an attribute of fun_module)
            fun_module - dotted name of the module which contains the network function
            args - positional arguments which are passed to the network function
            kwds - keyword arguments which are passed to the network function
        """
        self.fun_name = fun_name
        self.fun_module = fun_module
        self.args = args
        self.kwds = kwds

    def __repr__(self):
        """ Show the recorded recipe, so a stored constructor can be inspected without rebuilding the network. """
        return (f"{type(self).__name__}(fun_name={self.fun_name!r}, fun_module={self.fun_module!r}, "
                f"args={self.args!r}, kwds={self.kwds!r})")

    def get(self):
        """ Rebuild the network by calling the network function with the correct arguments.

        Imports 'fun_module', looks up 'fun_name' in it and calls that function with the stored positional and
        keyword arguments. Whatever the function returns is returned unchanged.

        raises:
            ImportError - if 'fun_module' cannot be imported.
            AttributeError - if the module has no attribute named 'fun_name'.
        """
        net_module = importlib.import_module(self.fun_module)
        net_fun = getattr(net_module, self.fun_name)
        return net_fun(*self.args, **self.kwds)
