from copy import deepcopy
import torch.nn as nn
import torch.nn.init as init


def create_target_network(network: nn.Module):
    """Return a frozen deep copy of ``network``.

    The copy is made with ``copy.deepcopy`` and every parameter of the copy
    has ``requires_grad`` set to ``False``. The module passed in is left
    untouched.

    Parameters:
        network (nn.Module) -- module to duplicate

    Returns:
        The copied module, with gradient tracking disabled on its parameters.
    """
    frozen_copy = deepcopy(network)
    # Module.requires_grad_ flips the flag on every parameter of the copy.
    frozen_copy.requires_grad_(False)
    return frozen_copy


def init_weights(
    net: nn.Module, init_type: str = "orthogonal", init_gain: float = 1.41
):
    """Initialize the weights of every Linear-like layer in a network, in place.

    A submodule is initialized when it has a ``weight`` attribute and its
    class name contains ``"Linear"``. Its weight is filled according to
    ``init_type`` and its bias, if present and not ``None``, is set to zero.
    All other submodules are left untouched.

    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | orthogonal
        init_gain (float)    -- standard deviation for ``normal``, gain for
                                ``orthogonal``. The default 1.41 approximates
                                ``sqrt(2)``.

    Raises:
        NotImplementedError: if ``init_type`` is not a supported method. This
            is checked up front, so a bad name is reported even when ``net``
            contains no Linear-like layer.
    """
    # Fail fast: validate once here rather than inside the per-module
    # callback, where a typo would go unnoticed on a network without any
    # matching layer.
    if init_type not in ("normal", "orthogonal"):
        raise NotImplementedError(
            "initialization method [%s] is not implemented" % init_type
        )

    def init_func(m):
        # Match by class name so that any layer named like "*Linear*" is covered.
        if not hasattr(m, "weight") or "Linear" not in m.__class__.__name__:
            return
        if init_type == "normal":
            init.normal_(m.weight, 0.0, init_gain)
        else:  # "orthogonal" -- the only other value that passes validation
            init.orthogonal_(m.weight, gain=init_gain)
        if getattr(m, "bias", None) is not None:
            init.constant_(m.bias, 0.0)

    # Module.apply calls init_func on every submodule and on net itself.
    net.apply(init_func)
