import argparse
import os

import cupy as cp
import numpy as np
import scipy.linalg
import torch
from torch import nn
import torch.autograd

from utils import load_cifar, load_modelnet
from NTK_tool import Kernel_Lib


# Command-line options. The two yes/no switches are restricted to "yes"/"no" so that
# a typo (e.g. "Yes" or "true") is rejected instead of silently being treated as "no".
parser = argparse.ArgumentParser(description = 'Neural Tangent Kernel for point clouds (ModelNet)')
parser.add_argument('--depth', default = 15, type = int, help = 'depth of CNTK (#conv layers + 1)')
parser.add_argument('--depth_linear', default = 3, type = int, help = 'depth of the fully-connected part (passed to Kernel_Lib.NTK_goon)')
parser.add_argument('--gap', default = "yes", type = str, choices = ("yes", "no"), help = 'whether GAP (global average pooling) is used; "no" pools with the maximum instead of the mean')
parser.add_argument('--fix', default = "no", type = str, choices = ("yes", "no"), help = 'whether first layer and last layer are fixed (or trained) (see Section 4.2 in our paper)')
args = parser.parse_args()

# Module-level configuration read by xx()/xz() and the driver code below.
d = args.depth                  # number of layers in the Sigma/NTK recursion
d_linear = args.depth_linear    # forwarded to Kernel_Lib.NTK_goon
gap = (args.gap == "yes")       # True: mean pooling in xz(); False: max
fix = (args.fix == "yes")       # True: NTK accumulator starts at zero in xz()

def setup_seed(seed):
    """Seed the torch and NumPy random number generators for reproducibility.

    Seeds torch's CPU generator, the generators of all CUDA devices and
    NumPy's global generator with ``seed`` (in that order), then asks cuDNN
    to use deterministic algorithms.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
# Seed the torch (CPU + CUDA) and NumPy RNGs with a fixed value.
setup_seed(0)
# NOTE(review): train_num is not read anywhere else in this file (the commented-out
# subsetting code further down uses its own `num = 500`); presumably a leftover —
# confirm before removing.
train_num = 500
# CUDA kernel for the "convolution" step applied to a 1024 x 1024 matrix.
# Launch layout: conv_blocks = (2047,) blocks of conv_threads = (1024,) threads.
# Block b covers the diagonal x1 - x2 = b - 1023 with one thread per column x2;
# threads whose row x1 falls outside [0, 1023] do nothing. Together the blocks
# visit every element exactly once.
# With this kernel each element is copied unchanged (t[x1][x2] = s[x1][x2]), i.e. a
# width-1 filter, so the in-place calls conv3(..., (S, S)) leave S as it is.
# Each thread touches only its own element, so no shared memory and no
# __syncthreads() are needed. A barrier placed after an early `return` taken by
# only some threads of the block is not allowed by the CUDA programming model.
conv3 = cp.RawKernel(r'''
extern "C" __global__
void conv3(const float s[1024][1024], float t[1024][1024])
{
    int x1 = threadIdx.x + blockIdx.x - 1023;  // row on this block's diagonal
    int x2 = threadIdx.x;                      // column
    if (x1 < 0 || x1 > 1023)
        return;
    t[x1][x2] = s[x1][x2];
}''', 'conv3')
conv_blocks = (2047,)
conv_threads = (1024,)

#CUDA kernel for activation (one ReLU layer of the Sigma / NTK recursion).
# Launch layout: trans_blocks = (1024, 4) x trans_threads = (256,). Block (bx, by) and
# thread tx handle element (x1, x2) = (bx, tx + 256 * by), so every entry of a
# 1024 x 1024 matrix is visited exactly once.
# Per element, with rho = s[x1][x2] * il[x1] * ir[x2] (xx()/xz() pass il = 1 / l and
# ir = 1 / r, so rho = s / (l * r)) and theta = arccos(rho), rho clamped to [-1, 1]
# for acosf and rho^2 capped at 1 under the square root:
#   s <- (rho * (pi - theta) + sqrt(1 - rho^2)) * l * r / pi    (next Sigma)
#   t <- t * (pi - theta) / pi + (next Sigma)                    (NTK accumulation)
# These are the closed-form ReLU (arc-cosine) expectations with a factor-2 scale.
# One PI constant is used for numerator and denominator. Two differently truncated
# literals of pi round to different float32 values, which would make, e.g., the
# derivative term (PI - acosf(1)) / PI on the diagonal differ from exactly 1.
trans = cp.RawKernel(r'''
extern "C" __global__
void trans(float s[1024][1024], float t[1024][1024], const float l[1024], const float r[1024], const float il[1024], const float ir[1024])
{
    const float PI = 3.14159265358979f;
    int x1 = blockIdx.x;
    int x2 = threadIdx.x + (blockIdx.y << 8);
    float S = s[x1][x2], T = t[x1][x2], L = l[x1], R = r[x2], iL = il[x1], iR = ir[x2];
    S = S * iL * iR;                                  // normalized correlation rho
    float A = PI - acosf(max(min(S, 1.0f), -1.0f));   // pi - theta (rho clamped for acosf)
    float BS = (S * A + sqrtf(1.0f - min(S * S, 1.0f))) * L * R / PI;
    S = A / PI;                                       // derivative term
    t[x1][x2] = T * S + BS;
    s[x1][x2] = BS;
}''', 'trans')
trans_blocks = (1024, 4)
trans_threads = (256,)
# Alternative derivative normalization (denominator 9.4247... = 3 * pi instead of pi), kept for reference:
#S = (3.141592654f - acosf(max(min(S, 1.0f), -1.0f))) / 9.424777960769379f;

# #CUDA kernel for convolution operation, v2
# conv3 = cp.RawKernel(r'''
# extern "C" __global__
# void conv3(const float s[1024][1024], float t[1024][1024])
# {
#     int x1 = threadIdx.x*32 + threadIdx.y + blockIdx.x*64 + blockIdx.y - 2047;
#     int x2 = threadIdx.x*32 + threadIdx.y;

#     __shared__ float d[1024 + 2];
#     if (x2 == 0){
#         d[0] = d[1025] = 0;
#     }

#     if (x1 < 0 || x1 > 1023){
#         d[x2 + 1] = 0;
#         return;
#     }
#     else
#         d[x2 + 1] = s[x1][x2];
#     __syncthreads();

#     t[x1][x2] = d[x2] + d[x2 + 1] + d[x2 + 2];

# }''', 'conv3')
# conv_blocks = (64,64)
# conv_threads = (32,32)

# #CUDA kernel for activation
# trans = cp.RawKernel(r'''
# extern "C" __global__
# void trans(float s[1024][1024], float t[1024][1024], const float l[1024], const float r[1024], const float il[1024], const float ir[1024])
# {
#     int x1 = blockIdx.x*32 + blockIdx.y;
#     int x2 = threadIdx.x*16 + threadIdx.y + ((blockIdx.z >> 7) << 8);
#     float S = s[x1][x2], T = t[x1][x2], L = l[x1], R = r[x2], iL = il[x1], iR = ir[x2];
#     S = S * iL * iR;
#     float BS = (S * (3.141592654f - acosf(max(min(S, 1.0f), -1.0f))) + sqrtf(1.0f - min(S * S, 1.0f))) * L * R / 28.274333882308138f;
#     S = (3.141592654f - acosf(max(min(S, 1.0f), -1.0f))) / 28.274333882308138;
#     t[x1][x2] = T * S + BS;
#     s[x1][x2] = BS;

# }''', 'trans')
# trans_blocks = (32,32, 512)
# trans_threads = (16,16)


#Calculate the square roots of the diagonal entries of $\Sigma^{(h)}(x, x)$ and their reciprocals. See Section 4.3 in our paper.
def xx(x):
    """Per-layer normalizers of Sigma^{(h)}(x, x) for a single sample.

    Args:
        x: one sample; ``x.T @ x`` must have 1024 * 1024 entries (the driver
            passes rows of ``X``, reshaped to ``(channel, 1024)``).

    Returns:
        (RL, iRL): lists of length ``max(d, 2)``. Entry 0 is the placeholder
        1.0; entry h >= 1 is ``sqrt(diag(S))`` after h - 1 activation layers
        (a length-1024 array) and ``iRL[h]`` is its element-wise reciprocal.
        xz() reads these as ``Lx[i]`` / ``iLx[i]``.
    """
    RL = [1.0, ]
    iRL = [1.0, ]
    S = cp.matmul(x.T, x).reshape(1024, 1024)
    conv3(conv_blocks, conv_threads, (S, S))
    # trans() also updates an NTK accumulator, but only S (its diagonal) feeds the
    # return value here. Its second output therefore goes to a throw-away buffer.
    # The accumulator is never convolved, and no trans() runs after the last
    # normalizer is recorded.
    scratch = cp.zeros((1024, 1024), dtype = cp.float32)

    for _ in range(1, d - 1):
        L = cp.sqrt(cp.diag(S))
        iL = 1.0 / L
        RL.append(L)
        iRL.append(iL)
        trans(trans_blocks, trans_threads, (S, scratch, L, L, iL, iL))
        conv3(conv_blocks, conv_threads, (S, S))

    L = cp.sqrt(cp.diag(S))
    RL.append(L)
    iRL.append(1.0 / L)
    return RL, iRL

#Calculate the kernel value of x and z.
#Lx and Lz are the square roots of the diagonal entries of $\Sigma^{(h)}(x, x)$ and $\Sigma^{(h)}(z, z)$.
#iLx and iLz are the reciprocals of Lx and Lz.
def xz(x, z, Lx, Lz, iLx, iLz):
    """Pooled NTK and covariance values between samples ``x`` and ``z``.

    Args:
        x, z: samples laid out like the argument of xx() (``x.T @ z`` must
            have 1024 * 1024 entries).
        Lx, Lz: first return values of xx(x) / xx(z); entry i holds the
            square roots of the diagonal at layer i.
        iLx, iLz: second return values of xx(x) / xx(z) (reciprocals).

    Returns:
        (H_xz, Sigma_xz): the final NTK matrix and the final covariance
        matrix, each reduced to a scalar by the mean when ``gap`` is set,
        otherwise by the maximum.
    """
    # Cross Gram matrix followed by the first convolution.
    sigma = cp.matmul(x.T, z).reshape(1024, 1024)
    conv3(conv_blocks, conv_threads, (sigma, sigma))

    # NTK accumulator: zero when the first layer is fixed, otherwise seeded with Sigma.
    ntk = cp.zeros((1024, 1024), dtype = cp.float32)
    if not fix:
        ntk += sigma

    for layer in range(1, d - 1):
        trans(trans_blocks, trans_threads,
              (sigma, ntk, Lx[layer], Lz[layer], iLx[layer], iLz[layer]))
        for buf in (sigma, ntk):
            conv3(conv_blocks, conv_threads, (buf, buf))

    # Final activation uses the normalizers of the last layer.
    trans(trans_blocks, trans_threads,
          (sigma, ntk, Lx[-1], Lz[-1], iLx[-1], iLz[-1]))

    pool = cp.mean if gap else cp.max
    return pool(ntk), pool(sigma)

#Load ModelNet point clouds.
(X_train, y_train), (X_test, y_test) = load_modelnet()

# num = 500

# X_train = X_train[:num,:,:]
# y_train = y_train[:num]
# X_train = X_train[:(int)(num*0.7),:,:]
# y_train = y_train[:(int)(num*0.7)]
# X_test = X_test[:(int)(num*0.3),:,:]
# y_test = y_test[:(int)(num*0.3)]

print(X_train.shape)

channel = 3

X = np.concatenate((X_train, X_test), axis = 0)
N = X.shape[0]
N_train = X_train.shape[0]
N_test = X_test.shape[0]
# The RawKernels conv3/trans read every buffer as float32 and do no dtype checking,
# so force float32 here; the Gram matrices built in xx()/xz() inherit it.
# NOTE(review): this is a reshape, not a transpose. It assumes load_modelnet() already
# returns channel-first samples (N, 3, 1024). If it returns (N, 1024, 3), the
# coordinates get scrambled. Confirm against utils.load_modelnet.
X = cp.asarray(X, dtype = cp.float32).reshape(-1, channel, 1024)

#Calculate diagonal entries.
L = []
iL = []
for i in range(N):
    print("process:{0}/{1}".format(i,N))
    Lx, iLx = xx(X[i])
    L.append(Lx)
    iL.append(iLx)
print("Diagonal over")

#####Calculate kernel values.
#####Below we provide a naive implementation using for-loops.
#####Parallelize this part according to your specific computing environment to utilize multiple GPUs.
H = np.zeros((N, N), dtype = np.float32)
Sigma = np.zeros((N, N), dtype = np.float32)
for i in range(N):
    print("process:{0}/{1}".format(i,N))
    # The kernel is symmetric in its two arguments. xz(z, x) works on the transposed
    # Gram matrix with the normalizers swapped, and conv3/trans treat S and S^T alike.
    # So only the upper triangle is computed and mirrored, which halves the dominant
    # cost. The mirrored values can differ from a recomputation only by float rounding.
    for j in range(i, N):
        H[i][j], Sigma[i][j] = xz(X[i], X[j], L[i], L[j], iL[i], iL[j])
        H[j][i], Sigma[j][i] = H[i][j], Sigma[i][j]

#####
device = "cuda:0"
H = torch.from_numpy(H).to(device)
Sigma = torch.from_numpy(Sigma).to(device)
ker = Kernel_Lib()

H = ker.NTK_goon(H, Sigma, d_linear, fix)

#Solve kernel regression.
H = H.cpu().numpy()
print(H)
# np.save does not create missing directories. Create it up front so the expensive
# kernel computation is not lost to a FileNotFoundError here.
os.makedirs("./sample", exist_ok = True)
np.save("./sample/pointNTK_ModelNet10_m.npy", H)
# NOTE(review): 40 target columns (ModelNet40-sized) although the saved file name says
# ModelNet10. Labels must be < classes; confirm which dataset load_modelnet returns.
classes = 40
# Accept labels shaped (n,) or (n, 1); a (n, 1) y_test would otherwise broadcast
# the accuracy comparison below into an (n, n) matrix.
train_labels = np.asarray(y_train).reshape(-1)
test_labels = np.asarray(y_test).reshape(-1)
# Regression targets: 0.9 for the true class, -0.1 for every other class.
Y_train = np.full((N_train, classes), -0.1)
Y_train[np.arange(N_train), train_labels] = 0.9
u = H[N_train:, :N_train].dot(scipy.linalg.solve(H[:N_train, :N_train], Y_train))
print("test accuracy:", 1.0 * np.sum(np.argmax(u, axis = 1) == test_labels) / N_test)
