# Demonstration script: regression with a deep RKHM using the kernel k_n^{sep,q}, from the paper
# "Spectral Truncation Kernels: Noncommutativity in C*-algebraic Kernel Machines".

import numpy.random as nr
import numpy.linalg as alg
import numpy as np
np.set_printoptions(1000)
from statistics import mean
from scipy.linalg import toeplitz
import time
import tensorflow as tf
from tensorflow.keras import layers, initializers, activations, regularizers, constraints, backend, losses, models, datasets
from tensorflow.keras.models import Model
from tensorflow.python.keras.engine.base_layer import Layer

# ---- Problem sizes -------------------------------------------------------
n = 10         # spectral truncation parameter (Toeplitz blocks are n x n)
nn = 30        # number of grid points used to discretize each function

datanum = 300  # how many training samples to generate
testnum = 300  # how many test samples to generate

epochs = 50_000  # number of optimization steps

# Seed NumPy's global generator so the synthetic data are reproducible.
nr.seed(0)

class Kernel(layers.Layer):
  """Single RKHM layer: maps a Gram tensor ``G`` to ``G @ c`` at every grid point.

  The layer owns one complex coefficient function c_j per column of ``G``.
  Each c_j is stored as 2n-1 complex Fourier coefficients (real part in
  ``kernel``, imaginary part in ``kernel_imag``) for the basis
  exp(1j * tau * k * t), k = -(n-1), ..., n-1, sampled on the grid ``t``.

  NOTE: ``call`` reads the module-level names ``n``, ``nn`` and ``t``; ``t`` is
  only defined by the ``__main__`` section of this script.

  Args:
    tau: frequency scale of the Fourier basis.
    kernel_initializer: initializer used for both the real and imaginary parts.
    kernel_regularizer: regularizer applied to both weight tensors.
    kernel_constraint: constraint applied to both weight tensors.
    activity_regularizer: forwarded to the base ``Layer``.
    **kwargs: forwarded to the base ``Layer`` (``input_dim`` is converted to
      ``input_shape``).

  Attributes set by ``call``:
    cc: complex64 tensor of shape (num_coeffs, nn); the coefficient functions
      evaluated on the grid.
    reg: float32 scalar, the mean over grid points of |c^H G c|. It is only
      updated when ``G`` is square (e.g. the training Gram tensor).
  """

  def __init__(self,
               tau,
               kernel_initializer='glorot_uniform',
               kernel_regularizer=None,
               kernel_constraint=None,
               activity_regularizer=None,
               **kwargs):
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
      kwargs['input_shape'] = (kwargs.pop('input_dim'),)

    super(Kernel, self).__init__(
        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
    self.mask = None
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.kernel_imag_initializer = initializers.get(kernel_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)

    self.supports_masking = True
    self.tau = tau

  def build(self, input_shape):
    # G is batched over grid points and its last axis is contracted against
    # the coefficients in ``call``, so that axis fixes how many coefficient
    # functions we need. Fall back to the global ``datanum`` if it is unknown.
    num_coeffs = input_shape[-1]
    num_coeffs = datanum if num_coeffs is None else int(num_coeffs)

    # Pass ``name`` by keyword: in Keras 3 the first positional argument of
    # ``add_weight`` is ``shape``.
    self.kernel = self.add_weight(
        name='kernel',
        shape=[num_coeffs, 2*n-1],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=tf.float32,
        trainable=True)

    self.kernel_imag = self.add_weight(
        name='kernel_imag',
        shape=[num_coeffs, 2*n-1],
        initializer=self.kernel_imag_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=tf.float32,
        trainable=True)

    self.built = True

  def call(self, G):
    # Fourier basis: frequency k is stored at row k (mod 2n-1), so
    # non-negative frequencies come first and negative ones wrap to the end.
    func = np.zeros((2*n-1, nn), dtype=complex)
    for i in range(2*n-1):
      k = -n+1+i
      func[k, :] = np.exp(self.tau*1j*k*t)
    func = tf.constant(func, dtype=tf.complex64)

    coef = tf.complex(self.kernel, self.kernel_imag)
    # cc[j, s] = sum_k coef[j, k] * func[k, s]
    self.cc = tf.matmul(coef, func)

    # One coefficient column per grid point: (nn, num_coeffs, 1).
    c_cols = tf.expand_dims(tf.transpose(self.cc), -1)
    # Batched over grid points: (nn, rows, num_coeffs) @ (nn, num_coeffs, 1).
    tmp = tf.matmul(tf.cast(G, dtype=tf.complex64), c_cols)

    # c^H G c needs a square G (the training Gram tensor). For a rectangular
    # G (e.g. testnum != datanum) leave ``reg`` untouched instead of crashing.
    if G.shape[-2] == G.shape[-1]:
      self.reg = tf.reduce_mean(tf.abs(tf.matmul(c_cols, tmp, adjoint_a=True)))
    # (nn, rows): works for any number of rows, not just datanum.
    return tf.squeeze(tmp, axis=-1)

class DeepRKHM(Model):
  """Three stacked ``Kernel`` layers combined by elementwise products.

  For a Gram tensor ``G`` the model returns ``[u1, u2, u3]``:
    u1 = G c1,  u2 = u1 * (G c2),  u3 = u2 * (G c3),
  where c1, c2, c3 are the coefficient functions of layers with frequency
  scales sqrt(2), sqrt(3) and sqrt(5)/2, and ``*`` is elementwise.

  When ``G`` is square, ``call`` also sets ``self.reg`` to the mean over grid
  points of |c^H G^3 c| with c = c1 * c2 * c3 (elementwise). ``train`` uses it
  as a regularization term.
  """

  def __init__(self):
    super(DeepRKHM, self).__init__()

    self.l1 = Kernel(np.sqrt(2))
    self.l2 = Kernel(np.sqrt(3))
    self.l3 = Kernel(np.sqrt(5)/2)

  def call(self, G):
    u1 = self.l1(G)
    u2 = u1*self.l2(G)
    u3 = u2*self.l3(G)

    cc = self.l1.cc*self.l2.cc*self.l3.cc
    # One coefficient column per grid point: (nn, num_coeffs, 1).
    c_cols = tf.expand_dims(tf.transpose(cc), -1)
    # The regularizer needs a square G; skip it for rectangular (test) inputs.
    if G.shape[-2] == G.shape[-1]:
      Gc = tf.cast(G, dtype=tf.complex64)
      # G^3 c as three mat-vec products: O(nn*d^2) instead of forming
      # G @ (G @ G), which costs O(nn*d^3).
      tmp = tf.matmul(Gc, tf.matmul(Gc, tf.matmul(Gc, c_cols)))
      self.reg = tf.reduce_mean(tf.abs(tf.matmul(c_cols, tmp, adjoint_a=True)))

    return [u1, u2, u3]

@tf.function
def train(Model, G, Gtest, y, ytest, opt):
  """Run one optimization step; return the training and test losses.

  Args:
    Model: a model whose ``call`` returns a list of outputs and sets
      ``Model.reg`` (e.g. ``DeepRKHM``). Inside this function the parameter
      name shadows the keras ``Model`` class.
    G: Gram tensor of the training inputs.
    Gtest: Gram tensor of the test inputs against the training inputs.
    y: training targets, laid out like the model outputs.
    ytest: test targets, laid out like the model outputs.
    opt: a Keras optimizer.

  Returns:
    (loss, testloss): the unregularized training loss (float32) and the test
    loss. Both are evaluated with the weights as they were before this step.
  """
  with tf.GradientTape() as tape:
    u = Model.call(G)
    # Only the final output is supervised.
    weight = [0]*len(u)
    weight[-1] = 1
    loss = 0
    for w, u_i in zip(weight, u):
      if w == 0:
        # Skip unused terms: saves work and avoids 0 * NaN gradients
        # (tf.norm's gradient is NaN at a zero vector).
        continue
      loss = loss + w*tf.cast(tf.reduce_mean(tf.norm(u_i-y, axis=0)/nn*2*np.pi), dtype=tf.float32)
    lossreg = loss + 0.1*tf.cast(Model.reg, dtype=tf.float32)

  # Evaluated outside the tape (no gradient needed) but before the update,
  # so it reflects the same weights as ``loss``.
  testloss = tf.reduce_mean(tf.norm(Model.call(Gtest)[-1]-ytest, axis=0)/nn*2*np.pi)

  grad = tape.gradient(lossreg, Model.trainable_variables)
  opt.apply_gradients(zip(grad, Model.trainable_variables))

  return loss, testloss



def _truncated_gram(xs, ys, fou, fou2, S):
    """Build the operator-valued Gram tensor between two sample sets.

    For every pair (xs[i], ys[j]) the kernel vector
    k(s) = exp(-(|x0(s) - y0(s)|^2 + |x1(s) - y1(s)|^2)) is evaluated on the
    grid. Its Fourier coefficients are placed in an n x n Toeplitz matrix T
    (first column from ``fou``, first row from ``fou2``), and the result is
        G[:, i, j] = (1/n) * sum_{a,b} (T^H T)[a, b] * S[:, a, b] + |T[0, 0]|^2.

    Args:
        xs: array (p, 2, nn) of row samples (two coordinate functions each).
        ys: array (q, 2, nn) of column samples, same layout as ``xs``.
        fou: complex array (n, nn) with rows exp(-1j*l*t), l = 0..n-1.
        fou2: complex array (n, nn) with rows exp(+1j*l*t), l = 0..n-1.
        S: complex array (nn, n, n) of per-grid-point Toeplitz matrices.

    Returns:
        Complex ndarray of shape (nn, p, q).
    """
    nn_pts, n_trunc = S.shape[0], S.shape[1]
    # sum_{a,b} M[a,b] * S[s,a,b] for all s at once is a single mat-vec.
    S_flat = S.reshape(nn_pts, n_trunc*n_trunc)
    gram = np.zeros((nn_pts, len(xs), len(ys)), dtype=complex)
    for i, x in enumerate(xs):
        for j, y in enumerate(ys):
            # Compute the kernel vector once and reuse it for both Toeplitz edges.
            kvec = np.exp(-1*(abs(x[0]-y[0])**2+abs(x[1]-y[1])**2))
            T = toeplitz(kvec.dot(fou.T)/nn_pts, kvec.dot(fou2.T)/nn_pts)
            quad = S_flat @ T.T.conjugate().dot(T).ravel()
            gram[:, i, j] = 1/n_trunc*quad + 1*abs(T[0, 0])**2
    return gram


if __name__ == '__main__':

    # Grid on [0, 2*pi). Must stay a module-level name: Kernel.call reads `t`.
    t = np.arange(0, 2*np.pi, 2*np.pi/nn)

    # Inputs: two noisy coordinate functions (sin, cos) per sample.
    ydata = np.zeros((datanum, 2, nn))
    testdata = np.zeros((testnum, 2, nn))
    # Training and test draws are interleaved per index, which keeps the random
    # stream fixed for the default sizes. Each array is filled only up to its
    # own length, so testnum may differ from datanum.
    for i in range(max(datanum, testnum)):
        if i < datanum:
            ydata[i, 0, :] = np.sin(0.01*i*t)+0.01*nr.randn(nn)
            ydata[i, 1, :] = np.cos(0.01*i*t)+0.01*nr.randn(nn)
        if i < testnum:
            testdata[i, 0, :] = np.sin(0.01*(i)*t)+0.01*nr.randn(nn)
            testdata[i, 1, :] = np.cos(0.01*(i)*t)+0.01*nr.randn(nn)

    # Cyclic neighbour indices: x[..., index][s] = x[s-1], x[..., index2][s] = x[s+1].
    index = np.roll(np.arange(nn, dtype=np.int32), 1)
    index2 = np.roll(np.arange(nn, dtype=np.int32), -1)

    # Targets: sin(cos(.)) of both coordinates summed over each grid point and
    # its two cyclic neighbours.
    # NOTE(review): `0.001*nr.randn(nn)` broadcasts one noise vector to every
    # sample. Confirm this is intended rather than randn(datanum, nn).
    ydataori = np.sin(np.cos(ydata[:, 0, :]+ydata[:, 0, index]+ydata[:, 0, index2]+ydata[:, 1, :]+ydata[:, 1, index]+ydata[:, 1, index2]))+0.001*nr.randn(nn)
    testdataori = np.sin(np.cos(testdata[:, 0, :]+testdata[:, 0, index]+testdata[:, 0, index2]+testdata[:, 1, :]+testdata[:, 1, index]+testdata[:, 1, index2]))+0.001*nr.randn(nn)

    fou = np.zeros((n, nn), dtype=complex)
    fou2 = np.zeros((n, nn), dtype=complex)
    for l in range(n):
        fou[l, :] = np.exp(-1j*l*t)
        fou2[l, :] = np.exp(1j*l*t)
    S = np.zeros((nn, n, n), dtype=complex)
    for i in range(nn):
        S[i, :, :] = toeplitz(np.exp(1j*np.arange(0, n, 1)*t[i]), np.exp(-1j*np.arange(0, n, 1)*t[i]))

    GG = tf.constant(_truncated_gram(ydata, ydata, fou, fou2, S), dtype=tf.complex64)
    ydataori = tf.constant(ydataori.T, dtype=tf.complex64)

    GGtest = tf.constant(_truncated_gram(testdata, ydata, fou, fou2, S), dtype=tf.complex64)
    testdataori = tf.constant(testdataori.T, dtype=tf.complex64)

    # Lower-case name so the keras `Model` class is not shadowed at module level.
    model = DeepRKHM()
    opt = tf.keras.optimizers.Adam(1*1e-3)

    # Eager forward pass so the layers create their weights before `train` is traced.
    model.call(GG)

    log_every = 1  # print the losses every `log_every` epochs
    print("epochs", "  train loss", "  test loss")
    for epoch in range(1, epochs + 1):
        loss, loss2 = train(model, GG, GGtest, ydataori, testdataori, opt)

        if epoch % log_every == 0:
            print(epoch, loss.numpy().real, loss2.numpy().real, flush=True)
