import time
import math
import torch
import numpy as np
import MRA_n
#import MRA_scale2
import torch.nn as nn
from torch.cuda.amp import custom_fwd, custom_bwd
#from GPUtil import showUtilization as gpu_usage

# Choose the compute device once. `dev` is a torch.device when CUDA is present
# and the plain string "cpu" otherwise; `cuda_device` normalises both forms.
if not torch.cuda.is_available():
    dev = "cpu"
    print('no cuda')
else:
    dev = torch.device('cuda')
    print('has cuda')
cuda_device = torch.device(dev)

# x1: dim1 x dim2
# x2: dim1 x dim2
# y : dim1 x1
# y[i] = prod_j  sinc(x1[i][j] + 0.5*x2[i][j]^2 )

class MyFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x1, x2, scales):
        x1 = x1.contiguous()
        x2 = x2.contiguous()
        n_head = x1.size(1)
        #scales = 2*torch.ones(n_head, dtype = torch.int)
        #scales[0]=1
        #scales[-1] = 4
        Y = MRA_n.forward(x1, x2, scales)
        variables = x1, x2, Y, scales
        ctx.save_for_backward(*variables)
        return Y

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_Y):
        grad_Y = grad_Y.contiguous()
        x1, x2, Y, scales = ctx.saved_tensors
        grads = MRA_n.backward(x1, x2, Y, grad_Y, scales)
        grad_x1, grad_x2 = grads
        return grad_x1, grad_x2, None



bsz = 64
n_head=1
qlen=256
klen=256
d_head=16


# Two copies of the same random inputs (identical seed), kept as separate
# leaf tensors so their gradients accumulate independently:
# Q1/K1 go through the MRA kernel, and Q0/K0 go through the reference computations.
torch.manual_seed(5)
Q1 = torch.rand((bsz, n_head, qlen, d_head), device=cuda_device, dtype=torch.float, requires_grad=True)
K1 = torch.randn((bsz, n_head, klen, d_head), device=cuda_device, dtype=torch.float, requires_grad=True)

torch.manual_seed(5)
Q0 = torch.rand((bsz, n_head, qlen, d_head), device=cuda_device, dtype=torch.float, requires_grad=True)
K0 = torch.randn((bsz, n_head, klen, d_head), device=cuda_device, dtype=torch.float, requires_grad=True)
# No retain_grad() is needed above: these are leaf tensors, and leaves always get .grad.

# Scale-2 reference output. Adjacent rows of Q and K are averaged, the pooled
# product is formed, and the result is upsampled back with nearest neighbour.
# Only head 0 is pooled, which covers everything while n_head == 1.
Q = torch.nn.functional.avg_pool1d(Q0[:, 0, :, :].transpose(1, 2), kernel_size=2, stride=2).transpose(1, 2)
K = torch.nn.functional.avg_pool1d(K0[:, 0, :, :].transpose(1, 2), kernel_size=2, stride=2).transpose(1, 2)
Q.retain_grad()
K.retain_grad()

Y00 = torch.matmul(Q, K.transpose(1, 2)).unsqueeze(1)
print(Y00.shape)
Y0 = torch.nn.functional.interpolate(Y00, scale_factor=(2, 2), mode='nearest')

# All-zero regression targets, allocated directly on the compute device
# instead of being built on the CPU and copied over.
tensor_0 = torch.zeros((bsz, n_head, qlen, klen), dtype=torch.float, device=cuda_device)
tensor_1 = torch.zeros((bsz, n_head, qlen, klen), dtype=torch.float, device=cuda_device)



# 0.5 * sum-reduced MSE against an all-zero target equals 0.5 * ||Y||_F^2,
# so dLoss/dY == Y. The closed-form gradient checks below rely on this.
loss_fn = torch.nn.MSELoss(reduction='sum')


def _cuda_sync():
    """Wait for all queued CUDA work, so that wall-clock timings include it."""
    # CUDA kernels launch asynchronously. Without a sync, the timer would
    # mostly measure launch overhead rather than execution time.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# NOTE: backward() accumulates into .grad. The gradient checks below assume
# exactly one iteration per loop, so raising range(1) breaks them.

# Baseline: dense Q0 K0^T with its backward pass.
_cuda_sync()
start = time.perf_counter()
for _ in range(1):
    Y000 = torch.matmul(Q0, K0.transpose(2, 3))
    loss0 = 0.5 * loss_fn(Y000, tensor_0)
    loss0.backward(retain_graph=True)
_cuda_sync()
time0 = time.perf_counter() - start

# One pooling scale per head (all 2 for this single-head run).
scales = 2 * torch.ones(n_head, dtype=torch.int)

# MRA kernel forward and backward.
_cuda_sync()
start = time.perf_counter()
for _ in range(1):
    Y1 = MyFunction.apply(Q1, K1, scales)
    loss1 = 0.5 * loss_fn(Y1, tensor_1)
    loss1.backward(retain_graph=True)
_cuda_sync()
time1 = time.perf_counter() - start

print('time0 = ', time0, ', time1 = ', time1)

# Closed-form reference gradients for the single-head, scale-2 case.
# Let A be the block-diagonal operator that averages each pair of adjacent
# rows and broadcasts the mean back. The reference output is then
# Y0 = A Q K^T A. Because loss = 0.5 * ||Y||^2, dLoss/dY = Y, which gives
#   dL/dQ = A Y A K    and    dL/dK = A Y^T A Q.
# A is symmetric and idempotent, so substituting the full-resolution Y000
# for Y gives the same result (A Y000 A == Y0).
# These are verification-only tensors, so no autograd graph is built.
with torch.no_grad():
    # A Y000: average pairs of query rows, then repeat each mean twice.
    AY0 = torch.nn.functional.avg_pool1d(Y000[:, 0, :, :].transpose(1, 2), kernel_size=2, stride=2).transpose(1, 2)
    AY = torch.nn.functional.interpolate(AY0.unsqueeze(1), scale_factor=(2, 1), mode='nearest')
    # A K and A Q: the pooled K and Q, upsampled back to full length.
    AK = torch.nn.functional.interpolate(K.unsqueeze(1), scale_factor=(2, 1), mode='nearest')
    AQ = torch.nn.functional.interpolate(Q.unsqueeze(1), scale_factor=(2, 1), mode='nearest')
    grad_Q2 = torch.matmul(AY, AK)

    # (Y000 A)^T = A Y000^T: average pairs of key columns, repeat, then transpose.
    AY_t = torch.nn.functional.avg_pool1d(Y000[:, 0, :, :], kernel_size=2, stride=2)
    AY_t = torch.nn.functional.interpolate(AY_t.unsqueeze(1), scale_factor=(1, 2), mode='nearest').transpose(2, 3)
    grad_K2 = torch.matmul(AY_t, AQ)

print('Y_nearest =', Y0[0,0,0:4,0:4])
print('Y1        =', Y1[0,0,0:4,0:4])
print('norm(Y1-Y_nearest) = ', torch.norm(Y1-Y0, p='fro'), '\nnorm(Y_nearest)=', torch.norm(Y0,p='fro'))

print('\n')

print('grad_Q1     =',    Q1.grad[0,0,0,:])
print('grad_Q(formula)=', grad_Q2[0,0,0,:])
print('norm(grad_Q2-grad_Q1) = ', torch.norm(grad_Q2-Q1.grad, p='fro'), '\nnorm(grad_Q2)=', torch.norm(grad_Q2,p='fro'))

print('\n')

print('grad_K(formula) =', grad_K2[0,0,0:4,0:4])
print('grad_K1         =', K1.grad[0,0,0:4,0:4])
print('norm(grad_K1-grad_K(formula)) = ', torch.norm(grad_K2-K1.grad, p='fro'), '\nnorm(grad_K(formula))=', torch.norm(grad_K2,p='fro'))

#df/dK = (df/dY)^T x Q
#df/dK = (A grad_Y)^T x (AQ)



""" torch.manual_seed(5)
Q3 = torch.randn( (bsz, n_head, qlen, d_head), device = 0, dtype=torch.float, requires_grad=True)
K3 = torch.randn( (bsz, n_head, klen, d_head), device = 0, dtype=torch.float, requires_grad=True)
Q3.retain_grad();
K3.retain_grad();

AQ3 = torch.matmul(A,Q3)
AK3 = torch.matmul(A,K3)
Y3 = torch.matmul( AQ3, AK3.transpose(2,3))

tensor_3 = torch.zeros( (bsz, n_head, qlen, klen), dtype=torch.float).cuda()
loss3 = 0.5*loss_fn(Y3,tensor_3)
loss3.backward(retain_graph=True) """

#print('Y3=',Y3[0,0,0:4,0:4])
#print('grad(Q3)=', Q3.grad[0,0,0:4,0:4])
#print('grad(K3)=', K3.grad[0,0,0:4,0:4])

print('\n')

n_head=4
torch.manual_seed(5)
Q4 = torch.rand( (bsz, n_head, qlen, d_head), device = 0, dtype=torch.float, requires_grad=True)
K4 = torch.randn( (bsz, n_head, klen, d_head), device = 0, dtype=torch.float, requires_grad=True)
Q4.retain_grad();
K4.retain_grad();

I1 = torch.ones(256,1,1, device=0)
I2 = 0.5*torch.ones(128,2,2, device=0)
I4 = 0.25*torch.ones(64,4,4, device = 0)
A1 = torch.block_diag(*I1)
A2 = torch.block_diag(*I2)
A4 = torch.block_diag(*I4)

A = torch.cat( (A1.unsqueeze(0), A2.unsqueeze(0), A2.unsqueeze(0), A4.unsqueeze(0)) )

AQ4 = torch.matmul(A,Q4)
AK4 = torch.matmul(A,K4)
Y4_0 = torch.matmul( AQ4, AK4.transpose(2,3))
#Y = AQ (AK)^T
#df/dK = (df/dY)^T x Q 
#df/dK = (A grad_Y)^T x (AQ) 
#df/dQ = A^T grad_Y (AK)
AgradY = torch.matmul(A, Y4_0)
gradQ4 = torch.matmul( AgradY, AK4 )
gradK4 = torch.matmul( AgradY.transpose(2,3), AQ4 )

scales = 2*torch.ones(n_head, dtype = torch.int)
scales[0]=1
scales[-1] = 4

Y4 = MyFunction.apply(Q4,K4, scales)
tensor_4 = torch.zeros( (bsz, n_head, qlen, klen), dtype=torch.float).cuda()
loss4 = 0.5*loss_fn(Y4,tensor_4)
loss4.backward(retain_graph=True)

print('norm(Y4-Y4_0) = ', torch.norm(Y4-Y4_0, p='fro'), '\nnorm(Y4_0)=', torch.norm(Y4_0,p='fro')) 
print('norm(gradQ4-gradQ4_0) = ', torch.norm(gradQ4-Q4.grad, p='fro'), '\nnorm(gradQ4)=', torch.norm(gradQ4,p='fro')) 
print('norm(gradK4-gradK4_0) = ', torch.norm(gradK4-K4.grad, p='fro'), '\nnorm(gradK4)=', torch.norm(gradK4,p='fro')) 
print('gradQ4=',Q4.grad[0,0,0:4,0:4])
print('gradK4=',K4.grad[0,0,0:4,0:4])


