# Demonstration script: a regression task using the RKHM (reproducing kernel
# Hilbert C*-module) associated with the kernel k_n^{prod,q}, from the paper
# "Spectral Truncation Kernels: Noncommutativity in C*-algebraic Kernel Machines".

import numpy.random as nr
import numpy.linalg as alg
import numpy as np
np.set_printoptions(1000)
from statistics import mean
from scipy.linalg import toeplitz

# Experiment configuration, read by the __main__ block below.
n = 11          # truncation parameter: size of the n x n Toeplitz matrices
nn = 30         # number of grid points discretizing [0, 2*pi)
lam = 0.01      # ridge regularization weight added to each Gram matrix

datanum = 1000  # how many training samples to generate
testnum = 1000  # how many test samples to generate

def make_inputs(num, t):
    """Generate ``num`` noisy (sin, cos) curve pairs sampled on the grid ``t``.

    Sample ``i`` is ``(sin(0.01*i*t), cos(0.01*i*t))``, and each entry gets
    independent Gaussian noise with standard deviation 0.01.

    Returns an array of shape ``(num, 2, len(t))``.
    """
    freq = 0.01 * np.arange(num)[:, None]  # (num, 1): frequency of sample i
    x = np.empty((num, 2, t.size))
    x[:, 0, :] = np.sin(freq * t) + 0.01 * nr.randn(num, t.size)
    x[:, 1, :] = np.cos(freq * t) + 0.01 * nr.randn(num, t.size)
    return x


def make_targets(x):
    """Compute the regression targets for inputs ``x`` of shape ``(num, 2, nn)``.

    At grid point m, the target is sin(cos(s)), where s sums both input channels
    at points m-1, m and m+1 (wrapping cyclically). Gaussian noise with standard
    deviation 0.001 is then added, drawn independently for every sample and grid
    point. The previous script drew a single length-nn noise vector and shared it
    across all samples.

    Returns an array of shape ``(num, nn)``.
    """
    s = x[:, 0, :] + x[:, 1, :]
    # np.roll(s, 1)[m] == s[m-1] and np.roll(s, -1)[m] == s[m+1], cyclically.
    local = s + np.roll(s, 1, axis=1) + np.roll(s, -1, axis=1)
    return np.sin(np.cos(local)) + 0.001 * nr.randn(*s.shape)


def gram_tensor(xs, ys, t, n):
    """Evaluate the truncated kernel at every grid point for every (x, y) pair.

    For a pair (x_i, y_j), let g = exp(-|x_i[0]-y_j[0]|^2 - |x_i[1]-y_j[1]|^2)
    on the grid. T is the n x n Toeplitz matrix with entries
    T[k, l] = (2*pi/nn) * sum_m g[m] * exp(-1j*(k-l)*t[m]).
    The result is
        G[m, i, j] = (1/n) * sum_{k,l} (T^H T)[k, l] * exp(1j*(k-l)*t[m])
                     + |T[0, 0]|^2
                   = (1/n) * ||T v_m||^2 + |T[0, 0]|^2,
    where v_m[k] = exp(-1j*k*t[m]). The second form is used because it is
    cheaper and real-valued.

    Parameters
    ----------
    xs : ndarray, shape (N1, 2, nn)
    ys : ndarray, shape (N2, 2, nn)
    t : ndarray, shape (nn,)
        The grid points.
    n : int
        Truncation parameter (Toeplitz matrix size).

    Returns
    -------
    ndarray, shape (nn, N1, N2), float
    """
    nn_ = t.size
    k = np.arange(n)
    lag = k[:, None] - k[None, :]      # lag[k, l] = k - l
    p = np.arange(-(n - 1), n)         # every lag value that occurs in ``lag``
    # Riemann-sum Fourier coefficients for every lag: g @ fourier -> (N2, 2n-1).
    fourier = np.exp(-1j * np.outer(t, p)) * (2 * np.pi / nn_)
    probe = np.exp(-1j * np.outer(k, t))  # probe[k, m] = v_m[k]
    out = np.empty((nn_, len(xs), len(ys)))
    # Process one row of the Gram tensor at a time. Materializing every
    # (pair, n, n) Toeplitz matrix at once would take gigabytes.
    for i, x in enumerate(xs):
        g = np.exp(-np.sum((x - ys) ** 2, axis=1))  # (N2, nn)
        coef = g @ fourier                           # (N2, 2n-1)
        toe = coef[:, lag + (n - 1)]                 # (N2, n, n)
        tv = toe @ probe                             # (N2, n, nn): T v_m for all m
        row = np.sum(np.abs(tv) ** 2, axis=1) / n + np.abs(toe[:, 0, 0])[:, None] ** 2
        out[:, i, :] = row.T
    return out


def run(n, nn, lam, datanum, testnum):
    """Generate data, fit the kernel ridge model, and return the test loss.

    One regularized linear system (G_m + lam*I) c_m = y_m is solved per grid
    point m. Predictions on the test set are G_test_m @ c_m.
    """
    # Integer-based grid: a float-step np.arange can produce nn + 1 points.
    t = 2 * np.pi * np.arange(nn) / nn
    xtrain = make_inputs(datanum, t)
    xtest = make_inputs(testnum, t)  # sized by testnum, not datanum
    ytrain = make_targets(xtrain)
    ytest = make_targets(xtest)

    G = gram_tensor(xtrain, xtrain, t, n)  # (nn, datanum, datanum)
    # The explicit trailing axis makes b a stack of column vectors. NumPy >= 2.0
    # would otherwise treat a 2-D b as a single matrix and raise a shape error.
    c = alg.solve(G + lam * np.eye(datanum), ytrain.T[..., None])  # (nn, datanum, 1)
    del G  # free the training Gram tensor before building the test one

    Gtest = gram_tensor(xtest, xtrain, t, n)  # (nn, testnum, datanum)
    sol = np.matmul(Gtest, c)[..., 0]         # (nn, testnum)
    # NOTE(review): norm over axis=0 sums over test samples, and the mean then
    # runs over the nn grid points. The result is sum(err**2) * 2*pi / nn**2,
    # not a per-sample average. Kept as in the original script so reported
    # numbers stay comparable; confirm the intended normalization.
    return np.mean(alg.norm(sol.T - ytest, axis=0) ** 2) / nn * 2 * np.pi


if __name__ == '__main__':
    print("Test loss: ", run(n, nn, lam, datanum, testnum))
