# -*- coding: UTF-8 -*-

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def generate_reptive_identity_matrix(m, n):
    """Return an ``(m, n)`` matrix of ones on a repeating main diagonal.

    For ``m <= n`` the result is the ``m x m`` identity padded on the right
    with ``n - m`` zero columns. For ``m > n`` the ``n x n`` identity is
    stacked vertically enough times to cover ``m`` rows and then truncated,
    so row ``i`` holds its single 1 in column ``i % n``.

    Args:
        m: number of rows.
        n: number of columns.

    Returns:
        A float tensor of shape ``(m, n)`` in the default dtype.
    """
    if m <= n:
        # torch.eye already zero-pads the extra columns when m < n.
        return torch.eye(m, n)

    # Tall case: tile the square identity and cut it down to m rows.
    tiles = int(np.ceil(m / n))
    return torch.eye(n).repeat(tiles, 1)[:m, :]


def generate_reptive_neg_identity_matrix(m, n):
    """Return the negated counterpart of the repeating identity pattern.

    The result has shape ``(m, n)`` with -1 on a repeating main diagonal:

    * ``m < n``: ``-I_m`` followed by ``n - m`` zero columns;
    * ``m == n``: ``-I_m``;
    * ``m > n``: ``-I_n`` stacked vertically and truncated to ``m`` rows, so
      row ``i`` holds its -1 in column ``i % n``.

    Args:
        m: number of rows.
        n: number of columns.

    Returns:
        A float tensor of shape ``(m, n)`` in the default dtype.
    """
    if m > n:
        tiles = int(np.ceil(m / n))
        return -torch.eye(n).repeat(tiles, 1)[:m, :]

    negated_square = -torch.eye(m)
    if m == n:
        return negated_square
    # Wide case: pad with zero columns on the right to reach width n.
    return torch.cat([negated_square, torch.zeros(m, n - m)], dim=1)


def idi_fc_identity_init(tensor: Tensor, gain=1.0):
    """Overwrite a 2-D weight in place with a near-identity pattern.

    The base pattern has ones on the main diagonal. When the tensor has more
    rows than columns, the ``n x n`` identity is tiled vertically, so row ``i``
    gets its 1 in column ``i % n``. Every entry is then perturbed by
    multiplicative Gaussian noise (mean 0, std 1e-6), which leaves zeros at
    zero. Finally the pattern is scaled by ``gain``.

    Args:
        tensor: 2-D tensor to overwrite; the copy happens under
            ``torch.no_grad()``.
        gain: scale factor applied to the final pattern.

    Returns:
        The same ``tensor`` object, modified in place.

    Raises:
        ValueError: if ``tensor`` is not 2-D.
    """
    ndim = len(tensor.shape)
    if ndim != 2:
        raise ValueError("Number of input dim should be 2, not %d." % ndim)

    rows, cols = tensor.shape
    if rows <= cols:
        # Identity, zero-padded on the right when the matrix is wide.
        pattern = torch.eye(rows, cols)
    else:
        # Tall matrix: repeat the square identity down the rows.
        tiles = int(np.ceil(rows / cols))
        pattern = torch.eye(cols).repeat(tiles, 1)[:rows, :]

    # Noise is multiplied by the pattern, so only the ones get jittered.
    noise = torch.normal(0., 1e-6, (rows, cols))
    pattern = (pattern + noise * pattern) * gain

    with torch.no_grad():
        tensor.copy_(pattern)

    return tensor


def idi_conv_identity_init(tensor: Tensor, groups=1, gain=1.0):
    """Overwrite a 4-D convolution weight in place with a near-identity pattern.

    The weight ``(cout_total, cin, hk, wk)`` is treated per group as an
    ``(m, n)`` matrix with ``m = cout_total // groups`` and
    ``n = cin * hk * wk``. That matrix gets ones on the main diagonal:
    zero-padded on the right when ``m <= n``, or built by
    ``generate_reptive_identity_matrix`` when ``m > n``. It is then stacked
    once per group along the output-channel axis, perturbed by multiplicative
    Gaussian noise (mean 0, std 1e-6), scaled by ``gain``, and copied into
    ``tensor``.

    Args:
        tensor: weight tensor of shape ``(cout_total, cin, hk, wk)``. ``cin``
            is used as given. For ``nn.Conv2d`` weights it is already the
            per-group input-channel count.
        groups: number of convolution groups; must evenly divide
            ``cout_total``.
        gain: scale factor applied to the final pattern.

    Returns:
        The same ``tensor`` object, modified in place.

    Raises:
        ValueError: if ``tensor`` is not 4-D, or if ``groups`` is not a
            positive divisor of ``cout_total``.
    """
    tensor_shape = tensor.shape
    dim_num = len(tensor_shape)
    if dim_num != 4:
        raise ValueError("Number of input dim should be 4, not %d." % dim_num)

    cout_total, cin, hk, wk = tensor_shape
    if groups < 1 or cout_total % groups != 0:
        raise ValueError(
            "groups must be a positive divisor of the output channels (%d), not %d."
            % (cout_total, groups)
        )
    cout = cout_total // groups
    m, n = cout, cin * hk * wk
    if m <= n:
        # Identity, zero-padded on the right when the matrix is wide.
        tensor_res = torch.eye(m, n)
    else:
        tensor_res = generate_reptive_identity_matrix(m, n)

    # One copy of the per-group pattern for every group, stacked along the
    # output-channel axis -> shape (groups * m, n) == (cout_total, n).
    tensor_res = tensor_res.repeat(groups, 1)

    # The noise must match the stacked shape. Sizing it (m, n) made the
    # in-place multiply below fail for groups > 1.
    tensor_small = torch.normal(0., 1e-6, tuple(tensor_res.shape))
    tensor_small *= tensor_res
    tensor_res += tensor_small

    tensor_res *= gain

    with torch.no_grad():
        tensor.copy_(tensor_res.reshape(cout_total, cin, hk, wk))

    return tensor



def idi_fc_zero_init(tensor: Tensor, gain=1.0):
    """Overwrite a 2-D weight in place with a pattern whose rows sum to zero.

    For an ``(m, n)`` weight:

    * ``m < n``: ``[I_m | R]``, where ``R`` comes from
      ``generate_reptive_neg_identity_matrix(m, n - m)``. Each row's +1 in the
      leading block is cancelled by a single -1 in the trailing block.
    * ``m >= n``: start from ``I_n``, or the vertically tiled identity when
      ``m > n``. Subtract the same matrix shifted one column to the right
      (cyclically). Each row gets +1 at its diagonal column and -1 one column
      to its right. With ``n == 1`` the two cancel to 0.

    The pattern is scaled by ``gain`` before being copied in.

    Args:
        tensor: 2-D tensor to overwrite; the copy happens under
            ``torch.no_grad()``.
        gain: scale factor applied to the pattern.

    Returns:
        The same ``tensor`` object, modified in place.

    Raises:
        ValueError: if ``tensor`` is not 2-D.
    """
    ndim = tensor.dim()
    if ndim != 2:
        raise ValueError("Number of input dim should be 2, not %d." % ndim)

    rows, cols = tensor.shape
    if rows < cols:
        positive = torch.eye(rows)
        negative = generate_reptive_neg_identity_matrix(rows, cols - rows)
        pattern = torch.cat([positive, negative], dim=1)
    else:
        base = torch.eye(rows) if rows == cols else generate_reptive_identity_matrix(rows, cols)
        # Rolling columns by +1 moves column j-1 into column j (with wrap).
        pattern = base - torch.roll(base, shifts=1, dims=1)

    pattern *= gain
    with torch.no_grad():
        tensor.copy_(pattern)

    return tensor


def idi_conv_zero_init(tensor: Tensor, groups=1, gain=1.0):
    """Overwrite a 4-D conv weight in place with a pattern whose rows sum to zero.

    The weight ``(cout_total, cin, hk, wk)`` is treated per group as an
    ``(m, n)`` matrix with ``m = cout_total // groups`` and
    ``n = cin * hk * wk``:

    * ``m < n``: ``[I_m | R]``, where ``R`` comes from
      ``generate_reptive_neg_identity_matrix(m, n - m)``.
    * ``m >= n``: the (tiled) identity plus a negated copy of itself shifted
      one column to the right, cyclically.

    Every row of that matrix sums to zero. The matrix is stacked once per
    group along the output-channel axis, scaled by ``gain``, and copied into
    ``tensor``.

    Args:
        tensor: weight tensor of shape ``(cout_total, cin, hk, wk)``.
        groups: number of convolution groups; must evenly divide
            ``cout_total``.
        gain: scale factor applied to the pattern.

    Returns:
        The same ``tensor`` object, modified in place.

    Raises:
        ValueError: if ``tensor`` is not 4-D, or if ``groups`` is not a
            positive divisor of ``cout_total``.
    """
    tensor_shape = tensor.shape
    dim_num = len(tensor_shape)
    if dim_num != 4:
        raise ValueError("Number of input dim should be 4, not %d." % dim_num)

    cout_total, cin, hk, wk = tensor_shape
    if groups < 1 or cout_total % groups != 0:
        raise ValueError(
            "groups must be a positive divisor of the output channels (%d), not %d."
            % (cout_total, groups)
        )
    cout = cout_total // groups

    m, n = cout, cin * hk * wk
    if m < n:
        tensor_squre = torch.eye(m)
        tensor_else = generate_reptive_neg_identity_matrix(m, n - m)
        tensor_res = torch.cat([tensor_squre, tensor_else], dim=1)
    else:
        if m == n:
            tensor_identity = torch.eye(m)
        else:
            tensor_identity = generate_reptive_identity_matrix(m, n)
        # Cyclic right shift of the columns, negated: row i gets -1 one column
        # to the right of its +1, so every row sums to zero.
        tensor_extra = torch.cat([tensor_identity[:, -1:], tensor_identity[:, :-1]], dim=1) * -1.
        tensor_res = tensor_identity + tensor_extra

    # (m, n) -> (groups * m, n) == (cout_total, n): same pattern per group.
    tensor_res = tensor_res.repeat(groups, 1)

    tensor_res *= gain
    with torch.no_grad():
        # Must reshape to the full output-channel count. Using the per-group
        # count here broke every groups > 1 call.
        tensor.copy_(tensor_res.reshape(cout_total, cin, hk, wk))

    return tensor


if __name__ == '__main__':
    # Smoke test: run the initializers on a grouped conv, a linear layer and
    # a plain tensor.
    image_shape = (1, 4, 2, 2)
    image = torch.arange(np.prod(image_shape)).float().reshape(*image_shape)

    grouped_conv = nn.Conv2d(4, 4, 2, groups=2, bias=False)
    # NOTE(review): the conv is built with groups=2, but the initializer is
    # called with groups=1. The weight is therefore initialised as a single
    # ungrouped (4, 8) block. Confirm whether groups=grouped_conv.groups was
    # intended.
    idi_conv_identity_init(grouped_conv.weight, groups=1)
    conv_out = grouped_conv(image)

    linear = nn.Linear(4, 7)
    idi_fc_identity_init(linear.weight)

    zero_target = torch.zeros(7, 4)
    idi_fc_zero_init(zero_target)
