# Demonstration script: recovery of masked MNIST images with the RKHM associated with
# the kernel k_n^{prod,q}, from "Spectral Truncation Kernels: Noncommutativity in
# C*-algebraic Kernel Machines".

import numpy.random as nr
import numpy.linalg as alg
import numpy as np
import copy
np.set_printoptions(1000)
import idx2numpy
from statistics import mean
from scipy.linalg import toeplitz

# ---- Experiment configuration -------------------------------------------
n = 50          # truncation parameter: size of the n x n Toeplitz matrices
nn = 28 * 28    # number of discretization points = pixels per MNIST image
eps = 0.01      # regularization added to the diagonal of the Gram matrices
st = 0          # index of the first training image in the dataset

datanum = 200   # number of training samples
testnum = 200   # number of test samples (taken right after the training ones)

def mask_center(images, lo=10, hi=18):
    """Return a copy of *images* with the square [lo:hi, lo:hi] of every image set to 0.

    images: array of shape (num_images, height, width). The input is not modified.
    """
    masked = np.array(images, copy=True)
    masked[:, lo:hi, lo:hi] = 0  # covers every image, whatever the batch size
    return masked


def kernel_matrices(X, Y, n_trunc, *, gamma=0.1, diag_weight=0.01):
    """Evaluate the function-valued kernel between every row of X and every row of Y.

    Each row is treated as a function sampled on the grid t_m = 2*pi*m/nn,
    m = 0..nn-1, where nn = X.shape[1]. For a pair (x, y):

      * h = exp(-gamma * |x - y|**2)                       (pointwise)
      * c_k = sum_m h[m] * exp(-1j*k*t_m) * 2*pi/nn        (k = -(n-1)..n-1)
      * T[a, b] = c_{a-b}                                  (n x n Toeplitz)

    The kernel value at t is
        (1/n) * sum_{a,b} (T^H T)[a, b] * exp(1j*(a-b)*t) + diag_weight*|T[0,0]|**2
      = (1/n) * ||T v(t)||**2 + diag_weight*|T[0,0]|**2,   v(t)_b = exp(-1j*b*t).
    The second (equivalent) form is used. It is real and non-negative, and it
    avoids storing T^H T broadcast against an (nn, n, n) phase array.

    Args:
        X: array (p, nn).
        Y: array (q, nn).
        n_trunc: truncation parameter n (size of T).
        gamma: width of the Gaussian h.
        diag_weight: weight of the |T[0,0]|**2 term.

    Returns:
        Real array G of shape (nn, p, q); G[m, i, j] is the kernel value of
        (X[i], Y[j]) at t_m.
    """
    num_points = X.shape[1]
    # Integer-indexed grid: exactly num_points samples (a float-step arange may give one extra).
    t = 2 * np.pi * np.arange(num_points) / num_points
    orders = np.arange(-(n_trunc - 1), n_trunc)            # k = -(n-1), ..., n-1
    waves = np.exp(-1j * np.outer(orders, t))              # waves[k + n - 1, m] = exp(-1j*k*t_m)
    basis = waves[n_trunc - 1:]                            # v(t)_b for b = 0..n-1, shape (n, nn)
    idx = np.arange(n_trunc)
    lag = idx[:, None] - idx[None, :] + (n_trunc - 1)      # row of `waves` holding order a - b

    G = np.empty((num_points, X.shape[0], Y.shape[0]))
    # One row of X at a time: peak memory is O(q * n * nn) instead of O(p * q * n * n).
    for i, x in enumerate(X):
        h = np.exp(-gamma * np.abs(x - Y) ** 2)            # (q, nn)
        coef = h @ waves.T * (2 * np.pi / num_points)      # (q, 2n-1) discrete Fourier coefficients
        T = coef[:, lag]                                   # (q, n, n) Toeplitz matrices
        Tv = T @ basis                                     # (q, n, nn)
        values = np.sum(np.abs(Tv) ** 2, axis=1) / n_trunc
        values += diag_weight * np.abs(T[:, 0, 0])[:, None] ** 2
        G[:, i, :] = values.T
    return G


def fit_predict(train_in, train_out, test_in, n_trunc, reg, *, gamma=0.1, diag_weight=0.01):
    """Solve the regularized kernel system at every grid point and predict on *test_in*.

    For each grid point m, solve (G_train[m] + reg*I) c[m] = train_out[:, m], then return
    pred[i, m] = sum_j G_test[m, i, j] * c[m, j].

    Args:
        train_in: array (p, nn) of training inputs.
        train_out: array (p, nn) of training targets.
        test_in: array (r, nn) of test inputs.
        n_trunc: truncation parameter n.
        reg: diagonal regularization added to each Gram matrix.
        gamma, diag_weight: forwarded to kernel_matrices.

    Returns:
        Array (r, nn) of predictions.
    """
    gram = kernel_matrices(train_in, train_in, n_trunc, gamma=gamma, diag_weight=diag_weight)
    gram += reg * np.eye(train_in.shape[0])                # broadcast over the nn grid points
    # Pass b as an explicit stack of column vectors (nn, p, 1). Since NumPy 2.0, solve() no
    # longer treats a b with b.ndim == a.ndim - 1 as a stack of vectors.
    coeffs = alg.solve(gram, train_out.T[..., None])[..., 0]  # (nn, p)
    cross = kernel_matrices(test_in, train_in, n_trunc, gamma=gamma, diag_weight=diag_weight)
    return np.einsum('mij,mj->im', cross, coeffs)


if __name__ == '__main__':
    images = idx2numpy.convert_from_file('./train-images.idx3-ubyte') / 255

    # Clean images: training targets and test ground truth.
    train_clean = images[st:st + datanum]
    test_clean = images[st + datanum:st + datanum + testnum]

    # Model inputs: the same images with the central 8x8 square blanked out.
    train_masked = mask_center(train_clean).reshape(datanum, nn)
    test_masked = mask_center(test_clean).reshape(testnum, nn)
    train_clean = train_clean.reshape(datanum, nn)
    test_clean = test_clean.reshape(testnum, nn)

    pred = fit_predict(train_masked, train_clean, test_masked, n, eps)  # (testnum, nn)

    # NOTE(review): the norm is taken over the sample axis (one value per pixel) and then
    # averaged over pixels. Confirm a per-image norm (axis=1) was not intended.
    print("Test loss: ", np.mean(alg.norm(pred - test_clean, axis=0) * 2 * np.pi / nn))
