# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
from idi_initialization import idi_conv_identity_init

if __name__ == '__main__':
    # Demo: run a small ramp-valued image through a 3x3 convolution whose
    # kernel was set by idi_conv_identity_init, then print the input and
    # every output channel so the effect of the initialization is visible.

    # channel counts and spatial side length of the toy example
    n_in, n_out = 1, 3
    side = 4

    # bias-free 3x3 convolution; padding=1 keeps the spatial size unchanged
    layer = nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, padding=1, bias=False)

    # IDInit "patch-maintain" initialization of the kernel. The helper lives in
    # idi_initialization; its return value is ignored here, so it is expected
    # to modify the weight tensor in place.
    idi_conv_identity_init(layer.weight)

    # input values 1.0 .. side*side*n_in laid out as (batch=1, channel, H, W)
    n_values = side * side * n_in
    sample = torch.arange(1, n_values + 1, dtype=torch.float).reshape(1, n_in, side, side)

    # forward pass through the initialized convolution
    output = layer(sample)

    # show each input channel (each slice keeps the leading batch dimension)
    print("Print Data:")
    for channel in sample.detach().unbind(dim=1):
        print(channel)

    # show each output channel, cast to int32 for compact display
    print("Print Result:")
    for channel in output.detach().unbind(dim=1):
        print(channel.int())
