# Demonstration script for an operator-learning task using the RKHM (reproducing kernel
# Hilbert C*-module) associated with the kernel k_n^{prod,q}, from the paper
# "Spectral Truncation Kernels: Noncommutativity in C*-algebraic Kernel Machines".

import numpy.random as nr
import numpy.linalg as alg
import numpy as np
np.set_printoptions(1000)
from statistics import mean
from scipy.linalg import toeplitz
import matplotlib.pyplot as plt

# Experiment sizes. NOTE: the script builds Gram stacks of shape (nn, datanum, datanum)
# in complex128, i.e. nn*datanum**2*16 bytes (about 2 GB at the defaults below).
n=5  # truncation parameter: size of the n x n Toeplitz matrices built from Fourier coefficients
nn=128  # number of discretization points of [0, 2*pi) (grid spacing 2*pi/nn)
lam=0.01  # regularization parameter, added to the diagonal of every Gram matrix

datanum=1000  # number of training samples
testnum=200  # number of test samples

def _euler_evolve(u0, A, B, dt, nsteps, nu=0.1):
    """Advance each row of u0 by nsteps explicit Euler steps of u' = -u*(A u) + nu*(B u).

    With the backward-difference matrix A and second-difference matrix B built in the
    script below, this is a finite-difference form of u_t = -u*u_x + nu*u_xx.

    Parameters
    ----------
    u0 : (num, nn) array, one sampled function per row. It is not modified.
    A, B : (nn, nn) arrays applied to every row.
    dt : Euler time step.
    nsteps : number of Euler steps to take.
    nu : coefficient of the B term (0.1 in this script).

    Returns
    -------
    A new (num, nn) array holding the states after nsteps steps.
    """
    u = np.array(u0)  # fresh array, so the caller's data is never aliased
    for _ in range(nsteps):
        u = u + dt*(-A.dot(u.T).T*u + nu*B.dot(u.T).T)
    return u


def _truncated_gram(xs, ys, S, fou, fou2):
    """Evaluate the Gram values of the script's kernel at every grid point.

    For a pair (x, y), let f = exp(-|x - y|**2) pointwise on the grid. T is the n x n
    Toeplitz matrix whose first column holds the Fourier coefficients
    sum(f*exp(-1j*l*t))*dx and whose first row holds sum(f*exp(1j*l*t))*dx, for
    l = 0..n-1 (the same layout as scipy.linalg.toeplitz(col, row)). The value at grid
    index k is

        (1/n) * sum_{a,b} (T^H T)[a, b] * S[k, a, b] + |T[0, 0]|**2.

    Parameters
    ----------
    xs : (p, nn) real array of sampled functions (rows).
    ys : (m, nn) real array of sampled functions (rows).
    S : (nn, n, n) complex array of per-grid-point matrices.
    fou, fou2 : (n, nn) complex arrays with rows exp(-1j*l*t) and exp(1j*l*t).

    Returns
    -------
    G : (nn, p, m) complex array, where G[:, i, j] is the value for (xs[i], ys[j]).
    """
    n, nn = fou.shape
    dx = 2*np.pi/nn
    # lag[a, b] = a - b. Entries on or below the diagonal come from the column
    # coefficients and entries above it from the row coefficients (toeplitz layout).
    lag = np.subtract.outer(np.arange(n), np.arange(n))
    lower = lag >= 0
    abs_lag = np.abs(lag)
    # Flattening S turns sum_{a,b} M[a, b]*S[k, a, b] into a single matrix product.
    S_flat = S.reshape(nn, n*n)
    G = np.zeros((nn, xs.shape[0], ys.shape[0]), dtype=complex)
    for i, x in enumerate(xs):
        f = np.exp(-1*(np.abs(x - ys)**2))                     # (m, nn)
        col = f.dot(fou.T)*dx                                 # (m, n)
        row = f.dot(fou2.T)*dx                                # (m, n)
        T = np.where(lower, col[:, abs_lag], row[:, abs_lag])  # (m, n, n) Toeplitz stack
        M = np.matmul(np.conj(np.swapaxes(T, 1, 2)), T)        # T^H T for every j
        G[:, i, :] = S_flat.dot(M.reshape(ys.shape[0], n*n).T)/n + np.abs(col[:, 0])**2
    return G


def _solve_regularized(G, lam, rhs):
    """Solve (G[k] + lam*I) c[k] = rhs[k] independently for every leading index k.

    Parameters
    ----------
    G : (K, m, m) stack of square matrices. It is not modified.
    lam : scalar added to every diagonal entry.
    rhs : (K, m) stack of right-hand-side vectors.

    Returns
    -------
    c : (K, m) array of solutions.
    """
    reg = G + lam*np.eye(G.shape[-1])
    # Give rhs an explicit trailing column axis. Since NumPy 2.0, linalg.solve treats
    # any b with ndim > 1 as a stack of matrices, so a bare (K, m) b no longer
    # broadcasts as K vectors.
    return alg.solve(reg, rhs[..., np.newaxis])[..., 0]


if __name__ == '__main__':

    dt = 0.01     # explicit Euler time step
    nsteps = 100  # targets are the states after nsteps Euler steps (time nsteps*dt)
    dx = 2*np.pi/nn
    # Backward difference (A) and second difference (B) on the grid. The first and last
    # rows have no wrap-around, so values outside the grid are taken as zero.
    A = (np.eye(nn) - np.diag(np.ones(nn-1), k=-1))/dx
    B = -(2*np.eye(nn) - np.diag(np.ones(nn-1), k=-1) - np.diag(np.ones(nn-1), k=1))/dx/dx
    # Covariance 625*(25*I - B)^(-2) of the Gaussian random initial conditions.
    C = alg.inv(-B + 25*np.eye(nn))
    C = 625*C.dot(C)
    t = dx*np.arange(nn)  # exactly nn grid points (a float-step arange may give nn+1)

    # Inputs: one sampled function per row. One call per set factorizes C only once.
    ydata = nr.multivariate_normal(np.zeros(nn), C, size=datanum)
    testdata = nr.multivariate_normal(np.zeros(nn), C, size=testnum)

    # Outputs: evolve training and test inputs together (rows evolve independently).
    evolved = _euler_evolve(np.vstack([ydata, testdata]), A, B, dt, nsteps)
    ydataori = evolved[:datanum]
    testdataori = evolved[datanum:]

    lags = np.arange(n)
    fou = np.exp(np.outer(-1j*lags, t))   # rows exp(-1j*l*t), l = 0..n-1
    fou2 = np.exp(np.outer(1j*lags, t))   # rows exp(+1j*l*t)
    # S[k][a, b] = exp(1j*(a-b)*t[k]), built with scipy's Toeplitz layout.
    S = np.array([toeplitz(np.exp(1j*lags*tk), np.exp(-1j*lags*tk)) for tk in t])

    # Training Gram stack (nn, datanum, datanum): about 2 GB of complex128 at default sizes.
    GG = _truncated_gram(ydata, ydata, S, fou, fou2)
    c = _solve_regularized(GG, lam, ydataori.T)  # (nn, datanum)
    del GG  # release the training stack before building the test stack

    GG = _truncated_gram(testdata, ydata, S, fou, fou2)
    sol = np.matmul(GG, c[:, :, np.newaxis])[:, :, 0]  # (nn, testnum)
    print("Test loss: ", np.mean(np.abs(sol.T-testdataori)))
