import torch
# from torch.utils.cpp_extension import load
from torch.utils.cpp_extension import load

# Example driver for the custom cuDNN convolution extension built from
# cudnn_convolution.cpp: runs one forward convolution plus the weight- and
# input-gradient entry points, then prints the resulting tensor shapes.

# Every tensor below lives on the GPU. Fail fast with a clear message instead
# of JIT-compiling the extension first and only then failing on device use.
if not torch.cuda.is_available():
    raise RuntimeError("cudnn_convolution example requires a CUDA-capable GPU")

# load (JIT-compile) the PyTorch extension
cudnn_convolution = load(name="cudnn_convolution", sources=["cudnn_convolution.cpp"])

device = torch.device("cuda")

# create dummy input, convolutional weights and bias directly on the GPU
# (allocating on-device avoids a CPU allocation + host-to-device copy).
# NOTE(review): layouts presumably follow torch.nn.functional.conv2d, i.e.
# input (N, C_in, H, W) and weight (C_out, C_in / groups, kH, kW) — confirm
# against cudnn_convolution.cpp.
x      = torch.randn(128, 5, 32, 32, device=device)
weight = torch.randn(64, 5, 3, 3, device=device)
bias   = torch.randn(64, device=device)

stride   = (1, 1)
padding  = (1, 1)
dilation = (1, 1)
groups   = 1

# compute the result of convolution
# NOTE(review): the trailing booleans are presumably cuDNN tuning flags
# (e.g. benchmark / deterministic) — confirm their meaning and order against
# the signatures exported by cudnn_convolution.cpp.
output = cudnn_convolution.convolution(x, weight, bias, stride, padding, dilation, groups, False, False)

print(x.shape, weight.shape, output.shape)

# create dummy gradient w.r.t. the output; shaping it like `output` keeps it
# valid if any of the convolution hyperparameters above are changed.
grad_output = torch.randn_like(output)

# compute the gradient w.r.t. the weights and input
# NOTE(review): these take one more boolean than the forward call (possibly an
# allow_tf32-style flag) — verify against cudnn_convolution.cpp.
grad_weight = cudnn_convolution.convolution_backward_weight(x, weight.shape, grad_output, stride, padding, dilation, groups, False, False, False)
grad_input  = cudnn_convolution.convolution_backward_input(x.shape, weight, grad_output, stride, padding, dilation, groups, False, False, False)

print(grad_weight.shape)
print(grad_input.shape)