import math
import torch
from torch.nn.init import _make_deprecate


# From https://github.com/yl-1993/ConvDeltaOrthogonal-Init/blob/c8ae2a1d582e64bf9649f97e8e92b12733f8b2eb/_ext/nn/init.py#L10

""" The implementation below corresponds to Tensorflow implementation.
    Refer to https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/init_ops.py for details.
"""


def conv_delta_orthogonal_(tensor, gain=1.):
    r"""Fill ``tensor`` in-place with a delta-orthogonal convolution kernel.

    Every spatial position of the kernel is set to zero except the center
    one, which receives a random matrix with orthonormal columns scaled by
    ``sqrt(gain)``. See algorithm 2 in [Xiao et al., 2018]:
    https://arxiv.org/abs/1806.05393

    The tensor is indexed as ``(out_channels, in_channels, *kernel_size)``
    and the "center" along each kernel dimension of size ``k`` is the index
    ``(k - 1) // 2`` (the lower of the two middle positions when ``k`` is
    even).

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`3 \leq n \leq 5`.
            ``tensor.size(1)`` must not exceed ``tensor.size(0)``.
        gain: Scaling parameter. The orthogonal matrix is multiplied by
            ``sqrt(gain)`` (not by ``gain`` itself). Default is 1.

    Returns:
        The same ``tensor`` object, modified in-place.

    Raises:
        ValueError: if ``tensor`` has fewer than 3 or more than 5 dimensions,
            if ``tensor.size(1) > tensor.size(0)``, or if ``gain`` is negative
            (raised by ``math.sqrt``). In all of these cases ``tensor`` is
            left unmodified.

    Examples:
        >>> w = torch.empty(5, 4, 3, 3)
        >>> conv_delta_orthogonal_(w)
    """
    ndim = tensor.ndimension()
    if ndim < 3 or ndim > 5:
        raise ValueError("The tensor to initialize must be at least "
                         "three-dimensional and at most five-dimensional")

    out_channels, in_channels = tensor.size(0), tensor.size(1)
    if in_channels > out_channels:
        raise ValueError("In_channels cannot be greater than out_channels.")

    # Evaluate the scale before touching ``tensor`` so that an invalid
    # (negative) gain fails without leaving the tensor half-initialized.
    scale = math.sqrt(gain)

    # ``torch.qr`` is deprecated in favour of ``torch.linalg.qr``; fall back
    # to the old entry point only on PyTorch releases that lack the new one.
    qr = getattr(getattr(torch, "linalg", None), "qr", None) or torch.qr

    with torch.no_grad():
        # Random Gaussian matrix on the same dtype/device as ``tensor``.
        a = tensor.new_empty((out_channels, out_channels)).normal_(0, 1)
        q, r = qr(a)
        # Flip column signs by sign(diag(R)) so that Q is uniformly
        # distributed rather than biased by the QR sign convention.
        q = q * torch.diag(r, 0).sign()
        # Keep the first ``in_channels`` orthonormal columns.
        q = q[:, :in_channels]

        tensor.zero_()
        # Center index along every kernel dimension (1, 2 or 3 of them).
        center = tuple((k - 1) // 2 for k in tensor.shape[2:])
        tensor[(slice(None), slice(None)) + center] = q
        tensor.mul_(scale)
    return tensor
