# A demonstration script for the deep Koopman-layered model from the paper
# "Deep Koopman-layered Model with Universal Property Based on Toeplitz Matrices".

import numpy.random as nr
import random
import numpy.linalg as alg
import tensorflow as tf
from tensorflow.keras import layers, initializers, activations, regularizers, constraints, backend, losses, models, datasets
from tensorflow.keras.models import Model
from tensorflow.python.keras.engine.base_layer import Layer
import numpy as np
import copy
import cmath
# NOTE(review): the first parameter of np.set_printoptions is `precision`; if a
# summarisation threshold was intended, it should be passed as threshold=.
np.set_printoptions(precision=1000)

# ---- Experiment configuration ------------------------------------------------
epochs = 20000     # number of training epochs
n = 11             # number of Fourier modes for each of the first two variables
n3 = 11            # number of Fourier modes for the third variable
# NOTE(review): Python's round() rounds halves to even, so round(11 / 2) == 6
# and the frequency indices built from these offsets run from -6 to 4 rather
# than -5 to 5. Confirm the asymmetric range is intended (n // 2 gives -5..5).
nhalf = round(n / 2)
n3half = round(n3 / 2)
d = 2              # dimension of the dynamical system
itnum = 50         # number of Arnoldi (Krylov subspace) iterations per layer
datanum = 1000     # number of time series used for training
testnum = 1000     # number of time series used for testing
dt = 0.01          # step size of the discretisation used to generate the data

# Frequency vector of every Fourier mode as a float32 tensor of shape
# (n, n, n3, 3): ind[i, j, k] == (i - nhalf, j - nhalf, k - n3half).
ind = tf.constant(
    np.stack(
        np.meshgrid(np.arange(n, dtype=float) - nhalf,
                    np.arange(n, dtype=float) - nhalf,
                    np.arange(n3, dtype=float) - n3half,
                    indexing="ij"),
        axis=-1),
    dtype=tf.float32)

# Fourier multipliers of the partial derivatives. With the basis
# exp(1j * k . x), differentiating in x_m multiplies a coefficient by 1j * k_m,
# so D1, D2 and D3 hold 1j * k_1, 1j * k_2 and 1j * k_3 on the (n, n, n3) grid.
_k12 = np.arange(n) - nhalf
_k3 = np.arange(n3) - n3half
_ones = np.ones((n, n, n3))
D1 = tf.constant(1j * _k12[:, np.newaxis, np.newaxis] * _ones, dtype=tf.complex64)
D2 = tf.constant(1j * _k12[np.newaxis, :, np.newaxis] * _ones, dtype=tf.complex64)
D3 = tf.constant(1j * _k3[np.newaxis, np.newaxis, :] * _ones, dtype=tf.complex64)
del _k12, _k3, _ones

def vhat(x, i):
  """Return sin(i) * x[0] + cos(i) * x[1] as a float32 tensor.

  This projects the planar point x onto the unit vector (sin i, cos i). The
  __main__ block calls it eagerly with i set to the third grid coordinate.
  """
  weight_first = tf.math.sin(i)
  weight_second = tf.math.cos(i)
  return tf.constant(weight_first * x[0] + weight_second * x[1], dtype=tf.float32)


class Koopman(Layer):
  """Koopman layer: applies exp(L) to a tensor of Fourier coefficients.

  The generator is L q = sum_m C_m(D_m q) for m = 0, 1, 2. Here D_m is the
  module-level Fourier multiplier D1/D2/D3, and C_m is a composition of R
  complex convolutions, each with its own (2M-1, 2M-1, 2M3-1) kernel.
  exp(L) vv is approximated with itnum Arnoldi steps as
  ||vv|| * Q expm(H) e_1, where Q is the orthonormal Krylov basis and H is
  the upper-Hessenberg matrix of Gram-Schmidt coefficients.

  Only half of each kernel is trainable (kernel1/2/3 and the *_imag
  counterparts). The other half is the point reflection through the kernel
  centre, with the imaginary part negated, so every kernel satisfies
  K(-p) = conj(K(p)). The centre entry is fixed to zero. (Presumably the
  intent is that each convolution keeps the coefficients of real-valued
  functions in that class.)

  Args:
    M: the kernel spans 2*M - 1 entries along the first two frequency axes.
    M3: the kernel spans 2*M3 - 1 entries along the third frequency axis.
    R: number of convolutions composed per generator component.
    kernel_initializer*, kernel_imag_initializer*: Keras initializers for the
      real/imaginary half-kernels (suffix 2/3 for kernel2/kernel3).
    kernel_regularizer, kernel_constraint: applied to every weight of the
      layer, including t.
    activity_regularizer: passed to the Keras base Layer.
  """

  def __init__(self,
               M,
               M3,
               R,
               kernel_initializer='glorot_uniform',
               kernel_initializer2='glorot_uniform',
               kernel_initializer3='glorot_uniform',
               kernel_imag_initializer='glorot_uniform',
               kernel_imag_initializer2='glorot_uniform',
               kernel_imag_initializer3='glorot_uniform',
               kernel_regularizer=None,
               kernel_constraint=None,
               activity_regularizer=None,
               **kwargs):
    # Accept the legacy Keras `input_dim` keyword as an alias for input_shape.
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
      kwargs['input_shape'] = (kwargs.pop('input_dim'),)

    super(Koopman, self).__init__(
        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
    self.mask = None
    self.M = int(M)
    self.M3 = int(M3)
    self.R = int(R)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.kernel_imag_initializer = initializers.get(kernel_imag_initializer)
    self.kernel_initializer2 = initializers.get(kernel_initializer2)
    self.kernel_imag_initializer2 = initializers.get(kernel_imag_initializer2)
    self.kernel_initializer3 = initializers.get(kernel_initializer3)
    self.kernel_imag_initializer3 = initializers.get(kernel_imag_initializer3)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)

    self.supports_masking = True

  def _add_half_kernel(self, name, shape, initializer):
    """Create one trainable float32 half-kernel with the shared options."""
    return self.add_weight(
        name,
        shape=shape,
        initializer=initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=tf.float32,
        trainable=True)

  def build(self, input_shape):
    """Create the trainable half-kernels and the non-trainable weight t."""
    # kernel1 covers axis-2 offsets < 0 (all axis-3/4 offsets); kernel2 covers
    # axis-2 offset 0 with axis-3 offsets < 0; kernel3 covers axis-2/3 offsets
    # 0 with axis-4 offsets < 0.
    shape1 = [d + 1, self.R, self.M - 1, 2 * self.M - 1, 2 * self.M3 - 1]
    shape2 = [d + 1, self.R, 1, self.M - 1, 2 * self.M3 - 1]
    shape3 = [d + 1, self.R, 1, 1, self.M3 - 1]

    # Keep this creation order: it fixes the order of trainable_variables and
    # the order in which the random initializers are drawn.
    self.kernel1 = self._add_half_kernel('kernel', shape1, self.kernel_initializer)
    self.kernel2 = self._add_half_kernel('kernel', shape2, self.kernel_initializer2)
    self.kernel3 = self._add_half_kernel('kernel', shape3, self.kernel_initializer3)
    self.kernel1_imag = self._add_half_kernel(
        'kernel_imag', shape1, self.kernel_imag_initializer)
    self.kernel2_imag = self._add_half_kernel(
        'kernel_imag', shape2, self.kernel_imag_initializer2)
    self.kernel3_imag = self._add_half_kernel(
        'kernel_imag', shape3, self.kernel_imag_initializer3)

    # NOTE(review): t is never read by call(); expm(H) is not scaled by it, so
    # the effective integration time is 1.
    self.t = self.add_weight(
        't',
        shape=[1],
        initializer=initializers.Constant(value=0.1),
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=tf.float32,
        trainable=False)

    self.built = True

  def _assemble_kernel(self, half1, half2, half3, sign):
    """Build the full kernel K with K(-p) = sign * K(p) and K(0) = 0.

    The missing half is the point reflection through the kernel centre, so
    every spatial axis (2, 3 and 4) of each trainable half is reversed.
    """
    centre = tf.zeros((d + 1, self.R, 1, 1, 1))
    line = tf.concat([half3, centre, sign * half3[:, :, :, :, ::-1]], axis=4)
    plane = tf.concat([half2, line, sign * half2[:, :, :, ::-1, ::-1]], axis=3)
    return tf.concat([half1, plane, sign * half1[:, :, ::-1, ::-1, ::-1]], axis=2)

  @staticmethod
  def _conv(x, k):
    """Real 3-D 'SAME' filtering; tf.nn.convolution does not flip k."""
    return tf.nn.convolution(x, k, padding="SAME")

  def _complex_conv(self, u, kernel_real, kernel_imag):
    """Filter complex u (n*n*n3 entries) by kernel_real + 1j * kernel_imag."""
    kshape = (2 * self.M - 1, 2 * self.M - 1, 2 * self.M3 - 1, 1, 1)
    kr = tf.reshape(kernel_real, kshape)
    ki = tf.reshape(kernel_imag, kshape)
    ur = tf.reshape(tf.math.real(u), (1, n, n, n3, 1))
    ui = tf.reshape(tf.math.imag(u), (1, n, n, n3, 1))
    # (ur + i ui) * (kr + i ki) = (ur kr - ui ki) + i (ur ki + ui kr)
    out = tf.complex(self._conv(ur, kr) - self._conv(ui, ki),
                     self._conv(ur, ki) + self._conv(ui, kr))
    return tf.reshape(out, (n, n, n3))

  def _apply_generator(self, q):
    """Return L q = sum_m C_m(D_m q) using the kernels assembled in call()."""
    total = None
    for m, multiplier in enumerate((D1, D2, D3)):
      w = multiplier * q
      for r in range(self.R):
        w = self._complex_conv(w, self.kernel[m, r], self.kernel_imag[m, r])
      total = w if total is None else total + w
    return total

  def call(self, vv):
    """Approximate exp(L) vv with itnum Arnoldi steps.

    Also stores the assembled kernels in self.kernel / self.kernel_imag and
    the Frobenius norm of expm(H) in self.knorm (read by ODENet).

    Args:
      vv: complex64 tensor with n*n*n3 entries (reshaped to (n, n, n3)).

    Returns:
      complex64 tensor of shape (n, n, n3).
    """
    self.kernel = self._assemble_kernel(
        self.kernel1, self.kernel2, self.kernel3, 1.0)
    self.kernel_imag = self._assemble_kernel(
        self.kernel1_imag, self.kernel2_imag, self.kernel3_imag, -1.0)

    size = n * n * n3
    vnorm = tf.cast(tf.norm(tf.reshape(vv, [size])), tf.complex64)
    Q = tf.reshape(vv / vnorm, (n, n, n3, 1))

    columns = []
    for itr in range(1, itnum + 1):
      u = self._apply_generator(Q[:, :, :, itr - 1])
      # Classical Gram-Schmidt against every basis vector built so far.
      h = tf.tensordot(tf.transpose(tf.math.conj(Q), (3, 0, 1, 2)), u,
                       [[1, 2, 3], [0, 1, 2]])
      u = u - tf.tensordot(Q, h, [[3], [0]])
      unorm = tf.cast(tf.norm(tf.reshape(u, [size])), tf.complex64)
      # Hessenberg column: the Gram-Schmidt coefficients, then the subdiagonal
      # entry ||u|| (since A q_k = sum_j h_jk q_j + ||u|| q_{k+1}), then zeros.
      columns.append(tf.concat([
          tf.reshape(h, (itr, 1)),
          tf.reshape(unorm, (1, 1)),
          tf.zeros((itnum - itr, 1), dtype=tf.complex64)], axis=0))
      # NOTE(review): assumes no Krylov breakdown (unorm != 0) within itnum steps.
      Q = tf.concat([Q, tf.reshape(u / unorm, (n, n, n3, 1))], axis=3)

    H = tf.concat(columns, axis=1)[0:itnum, 0:itnum]
    exp_h = tf.linalg.expm(H)
    # tf.norm with axis=None treats the matrix as a flat vector, so this is the
    # Frobenius norm of expm(H).
    self.knorm = tf.norm(exp_h, 2)
    return tf.tensordot(Q[:, :, :, 0:itnum], exp_h[:, 0], [[3], [0]]) * vnorm

class ODENet(Model):
  """Two stacked Koopman layers, each with M=3, M3=2, R=1.

  call(vv) returns (output of the second layer, l1.knorm + l2.knorm). Each
  layer sets its knorm attribute during its own forward pass.
  """

  def __init__(self):
    super(ODENet, self).__init__()
    self.l1 = Koopman(M=3, M3=2, R=1)
    self.l2 = Koopman(M=3, M3=2, R=1)

  def call(self, vv):
    hidden = self.l1(vv)
    output = self.l2(hidden)
    # knorm only exists after a layer has been called, so read it last.
    return output, self.l1.knorm + self.l2.knorm

def _evaluate_series(u, points, x3):
  """Evaluate the Fourier series with coefficients u at the points (x1, x2, x3).

  Args:
    u: complex64 coefficients on the (n, n, n3) frequency grid `ind`.
    points: float32 tensor of (x1, x2) evaluation points, assumed (N, 2).
    x3: value used for the third variable at every point.

  Returns:
    Complex tensor of length N holding sum_k u[k] exp(1j k . x), minus the
    term of the corner mode u[0, 0, 0].
  """
  coords = tf.concat([points, x3 * tf.ones_like(points[:, :1])], axis=1)
  phase = tf.tensordot(ind, coords, [[3], [1]])  # (n, n, n3, N)
  terms = u[:, :, :, tf.newaxis] * tf.math.exp(tf.complex(tf.zeros_like(phase), phase))
  # NOTE(review): the corner mode ind[0, 0, 0] is excluded from the sum; the
  # reason is not visible here.
  return tf.reduce_sum(terms, axis=(0, 1, 2)) - terms[0, 0, 0, :]


def _prediction_loss(u, points, labels):
  """Mean squared error of both predicted state components against labels.

  Component 0 is read off at x3 = pi/2 and component 1 at x3 = 0.
  """
  pred0 = tf.math.real(_evaluate_series(u, points, np.pi / 2))
  pred1 = tf.math.real(_evaluate_series(u, points, 0.0))
  return tf.reduce_mean((pred0 - labels[:, 0]) ** 2 + (pred1 - labels[:, 1]) ** 2)


@tf.function
def train(ODENet, traindata, testdata, label, testlabel, opt, v):
  """Run one optimizer step on the training loss.

  Args:
    ODENet: the model; called on v and expected to return (u, knorm).
    traindata, testdata: evaluation points for the train/test losses.
    label, testlabel: targets with two columns (one per state component).
    opt: Keras optimizer used to apply the gradients.
    v: Fourier coefficients fed to the model.

  Returns:
    (train loss, test loss), both computed from the pre-update prediction.
  """
  # GradientTape watches trainable variables automatically.
  with tf.GradientTape() as tape:
    u, knorm = ODENet(v)
    loss = _prediction_loss(u, traindata, label)
    lossreg = loss  # optional regularizer: + 0.00001 * tf.cast(knorm, dtype=tf.float32)

  grad = tape.gradient(lossreg, ODENet.trainable_variables)
  opt.apply_gradients(zip(grad, ODENet.trainable_variables))

  # The test loss needs no gradient, so it is computed outside the tape.
  testloss = _prediction_loss(u, testdata, testlabel)
  return loss, testloss


if __name__ == '__main__':

    # Initial states: datanum training points followed by testnum test points,
    # drawn uniformly from [0, 1)^2 and stored as float32.
    nr.seed(0)
    data = tf.constant(nr.rand(datanum + testnum, 2), dtype=tf.float32)
    traindata = data[0:datanum, :]
    testdata = data[datanum:datanum + testnum, :]

    # Explicit-Euler trajectories of the Van der Pol oscillator with mu = 3:
    #   x' = y,  y' = 3 (1 - x^2) y - x,
    # started from the float32-rounded initial states.
    steps = 100 + 20
    traj = np.zeros((datanum + testnum, steps, 2))
    traj[:, 0, :] = data.numpy()
    for s in range(steps - 1):
        pos, vel = traj[:, s, 0], traj[:, s, 1]
        traj[:, s + 1, 0] = pos + dt * vel
        traj[:, s + 1, 1] = vel + dt * (3 * (1 - pos ** 2) * vel - pos)

    # Targets: the state after 100 steps plus Gaussian noise (std 0.01). The
    # noise array for the test set must have testnum rows.
    label = tf.constant(traj[0:datanum, 100, :] + 0.01 * nr.randn(datanum, 2),
                        dtype=tf.float32)
    testlabel = tf.constant(traj[datanum:datanum + testnum, 100, :] + 0.01 * nr.randn(testnum, 2),
                            dtype=tf.float32)

    # Named `model` so the ODENet class is not shadowed by its instance.
    model = ODENet()
    opt = tf.keras.optimizers.Adam(1e-3)

    # Fourier coefficients of the observable vhat: a scaled DFT of its samples
    # on the uniform n x n x n3 grid over [-pi, pi)^3, using the frequencies ind.
    # Every grid point, including (i, j, k) == (0, 0, 0), contributes.
    v = tf.zeros((n, n, n3), dtype=tf.complex64)
    for i in range(n):
        for j in range(n):
            for k in range(n3):
                x1 = -np.pi + 2 * np.pi / n * i
                x2 = -np.pi + 2 * np.pi / n * j
                x3 = -np.pi + 2 * np.pi / n3 * k
                x = tf.constant([x1, x2, x3], dtype=tf.float32)
                sample = tf.cast(vhat([x1, x2], x3), dtype=tf.complex64)
                basis = tf.math.exp(tf.complex(tf.zeros((n, n, n3)),
                                               tf.tensordot(-ind, x, [[3], [0]])))
                v = v + sample * basis / n / n / n3

    # Call both layers once eagerly so their weights exist before train() is traced.
    model.l2(model.l1(v))

    print("epochs", "  train loss", "  test loss")
    for epoch in range(1, epochs + 1):
        loss, loss2 = train(model, traindata, testdata, label, testlabel, opt, v)
        print(epoch, loss.numpy(), loss2.numpy(), flush=True)
