# Demonstration script for a regression task with the RKHM (reproducing kernel
# Hilbert C*-module) associated with the kernel k_n^{sep,q}, from the paper
# "Spectral Truncation Kernels: Noncommutativity in C*-algebraic Kernel Machines".

import numpy.random as nr
import numpy.linalg as alg
import numpy as np
np.set_printoptions(1000)
from statistics import mean
from scipy.linalg import toeplitz

# Model and discretization parameters.
n = 16        # truncation parameter: the Toeplitz matrices built below are n x n
nn = 30       # number of grid points used to discretize [0, 2*pi)
eps = 0.01    # regularization weight (multiplies the identity in the linear system)

# Dataset sizes.
datanum = 1000    # number of training samples
testnum = 1000    # number of test samples

def _fourier_matrix(t, n):
    """Return the (n, len(t)) matrix F with F[l, m] = exp(-1j * l * t[m]), l = 0..n-1."""
    return np.exp(-1j * np.outer(np.arange(n), t))


def _shift_stack(t, n):
    """Return S of shape (len(t), n, n) with S[i, j, k] = exp(1j * (j - k) * t[i]).

    S[i] equals scipy's toeplitz(exp(1j*l*t[i]), exp(-1j*l*t[i])) for l = 0..n-1.
    """
    k = np.arange(n)
    diff = (k[:, None] - k[None, :])[None, :, :]
    return np.exp(1j * diff * np.asarray(t)[:, None, None])


def _toeplitz_from_samples(x, fou):
    """Build an n x n Toeplitz matrix from the grid samples x (n = fou.shape[0]).

    The first column is the quadrature sum x @ fou.T * 2*pi/len(x), i.e.
    sum_m x[m] * exp(-1j*l*t[m]) * 2*pi/len(x). The first row is the same sum
    with exp(+1j*l*t[m]).
    """
    w = 2 * np.pi / x.shape[-1]
    return toeplitz(x.dot(fou.T) * w, x.dot(fou.conj().T) * w)


def _to_grid(mat, S):
    """Return the vector (1/n) * sum_{j,k} mat[j, k] * S[i, j, k], indexed by grid point i."""
    return np.einsum('jk,ijk->i', mat, S) / mat.shape[0]


def _embed(x, fou, S):
    """Apply _toeplitz_from_samples then _to_grid to every 1-D signal x[..., :].

    Returns a complex array with the same shape as x.
    """
    out = np.empty(x.shape, dtype=complex)
    for idx in np.ndindex(x.shape[:-1]):
        out[idx] = _to_grid(_toeplitz_from_samples(x[idx], fou), S)
    return out


def _gaussian_gram(X, Y, gamma=0.1):
    """Return G[i, j] = exp(-gamma * sum |X[i] - Y[j]|**2), summing over all non-leading axes.

    Vectorized over Y for each row of X, which avoids a len(X) * len(Y) Python double loop.
    """
    Xf = X.reshape(len(X), -1)
    Yf = Y.reshape(len(Y), -1)
    G = np.empty((len(Xf), len(Yf)))
    for i, row in enumerate(Xf):
        G[i] = np.exp(-gamma * np.sum(np.abs(Yf - row) ** 2, axis=1))
    return G


def _make_inputs(num, t, noise=0.01):
    """Return x of shape (num, 2, len(t)) with x[i, 0] = sin(0.01*i*t), x[i, 1] = cos(0.01*i*t),
    each plus i.i.d. Gaussian noise scaled by `noise`."""
    i = np.arange(num)[:, None]
    x = np.empty((num, 2, len(t)))
    x[:, 0, :] = np.sin(0.01 * i * t) + noise * nr.randn(num, len(t))
    x[:, 1, :] = np.cos(0.01 * i * t) + noise * nr.randn(num, len(t))
    return x


def _make_targets(x, noise=0.001):
    """Return sin(cos(s)) + noise, where s sums both channels of x at each grid point
    and at its two cyclic neighbours.

    The noise is drawn independently for every sample and grid point. Previously a
    single length-nn noise vector was broadcast to all samples.
    """
    z = x.sum(axis=1)
    s = z + np.roll(z, 1, axis=-1) + np.roll(z, -1, axis=-1)
    return np.sin(np.cos(s)) + noise * nr.randn(*s.shape)


def _l2_loss(pred, target):
    """Mean over samples of the discretized squared L2 norm 2*pi/len * sum_t |pred - target|^2.

    pred and target have shape (num_samples, num_grid_points). The result does not
    grow with the number of samples.
    """
    w = 2 * np.pi / pred.shape[-1]
    return np.mean(np.sum(np.abs(pred - target) ** 2, axis=-1)) * w


def run(n, nn, eps, datanum, testnum, gamma=0.1):
    """Fit the regression model on synthetic data and return the test loss.

    Parameters
    ----------
    n : int
        Truncation parameter (size of the n x n Toeplitz matrices).
    nn : int
        Number of grid points on [0, 2*pi).
    eps : float
        Regularization: eps * I is added to each per-grid-point linear system.
    datanum, testnum : int
        Numbers of training and test samples (they may differ).
    gamma : float
        Gaussian kernel width in exp(-gamma * ||.||^2). The original script
        hard-coded 0.1.

    Returns
    -------
    float
        Mean over test samples of 2*pi/nn * sum_t |prediction - target|^2.
    """
    # linspace avoids np.arange's float-step pitfall of occasionally returning nn+1 points.
    t = np.linspace(0, 2 * np.pi, nn, endpoint=False)
    fou = _fourier_matrix(t, n)
    S = _shift_stack(t, n)

    # Test inputs use the same formula as the training inputs; only the noise differs.
    xtrain = _make_inputs(datanum, t)
    xtest = _make_inputs(testnum, t)
    ytrain = _make_targets(xtrain)
    ytest = _make_targets(xtest)

    emb_train = _embed(xtrain, fou, S)
    emb_test = _embed(xtest, fou, S)

    # Grid-wise multiplier from (T^2)^H T^2, where T is the Toeplitz matrix of exp(sin t).
    # It is computed once and reused for training and prediction.
    toe = _toeplitz_from_samples(np.exp(np.sin(t)), fou)
    toe2 = toe @ toe
    mult = _to_grid(toe2.conj().T @ toe2, S)  # shape (nn,)

    # One (datanum x datanum) system per grid point p: (mult[p] * G + eps * I) c[p] = y[:, p].
    G = _gaussian_gram(emb_train, emb_train, gamma)
    A = mult[:, None, None] * G
    A += eps * np.eye(datanum)
    # Explicit trailing axis: NumPy >= 2.0 treats a 2-D `b` as a stack of matrices, not vectors.
    c = alg.solve(A, ytrain.T[..., None])[..., 0]  # shape (nn, datanum)

    Gtest = _gaussian_gram(emb_test, emb_train, gamma)
    sol = mult[:, None] * (c @ Gtest.T)  # shape (nn, testnum)
    return _l2_loss(sol.T, ytest)


if __name__ == '__main__':
    print("Test loss: ", run(n, nn, eps, datanum, testnum))
