import torch
import torch.nn.functional as F
import math

import decompose_fp_cuda
import pre_align_linear_cuda

# Problem size: `a` is built as (M, K), `b` as (N, K), and the output `c`
# as (M, N), so K is the reduction length of the linear layer.
M = N = 1000
# NOTE(review): 3 * 3 * 512 looks like a 3x3 kernel over 512 channels
# flattened into one reduction axis -- confirm with the kernel author.
K = 3 * 3 * 512
# K = 2  # small alternative reduction length, left here for debugging

# Both values are forwarded unchanged to pre_align_linear_cuda.forward;
# their exact meaning is defined by that extension, not by this script.
num_extra_bits = 2
rounding_mode = 0

# Operands are sampled on the CPU and then copied to the GPU.
# `a`: Gaussian values, shape (M, K).
a = torch.randn(M, K).cuda()
# `b`: sign of Gaussian samples, shape (N, K); entries are -1 or +1
# (0 only if a sample is exactly zero).
b = torch.randn(N, K).sign().cuda()
# `c`: zero-initialised (M, N) buffer handed to the custom kernel.
c = torch.zeros(M, N, device="cuda")

# Largest absolute value in each row of `a`, shape (M,); the accompanying
# argmax indices are not needed.
max_a = a.abs().max(dim=-1).values

#print("A")
#print(a)
#print("B")
#print(b)

#sign = torch.zeros(M, K).cuda().int()
#exp = torch.zeros(M, K).cuda().int()
#frac = torch.zeros(M, K).cuda().int()
#decompose_fp_cuda.forward(sign, exp, frac, a)
#print("A sign")
#print(sign)
#print("A exp")
#print(exp)
#print("A frac")
#print(frac)

#print("e_aligned")
#sign = torch.zeros(M).cuda().int()
#exp = torch.zeros(M).cuda().int()
#frac = torch.zeros(M).cuda().int()
#decompose_fp_cuda.forward(sign, exp, frac, max_a)
#print(exp)

# Run the kernel under test. Its return value is ignored and `c` is read
# afterwards, so it presumably writes the (M, N) result into `c` in place
# -- TODO confirm against the extension source.
pre_align_linear_cuda.forward(
    max_a,
    a,
    b,
    c,
    num_extra_bits,
    rounding_mode,
)

# Reference results for the same product, F.linear(a, b) == a @ b.T:
# one computed on the CPU (moved back to the GPU for comparison) and one
# computed directly on the GPU.
torch_cpu = F.linear(a.cpu(), b.cpu()).cuda()
torch_gpu = F.linear(a, b)


#print("pre_align_linear_cuda result")
#print(c)
#
#print("answer")
#print(torch_cpu)

# Report the element-wise relative error against the CPU reference for
# both the custom kernel output and PyTorch's own GPU result: first the
# maximum, then the mean (sum divided by element count).
# NOTE(review): the ratio divides by `torch_cpu` itself, so reference
# entries close to zero can dominate both statistics.
for label, result in (("diff_pre_align", c), ("diff_torch", torch_gpu)):
    print(label)
    diff = ((result - torch_cpu) / torch_cpu).abs()
    print(diff.max())
    print(diff.sum() / diff.numel())
