# Demonstration script for the deep Koopman-layered model for nonautonomous dynamical systems,
# including the computation of the eigenvalues of the Koopman layers. Reference:
# "Deep Koopman-layered Model with Universal Property Based on Toeplitz Matrices".

import numpy.random as nr
import random
import numpy.linalg as alg
import tensorflow as tf
from tensorflow.keras import layers, initializers, activations, regularizers, constraints, backend, losses, models, datasets
from tensorflow.keras.models import Model
from tensorflow.python.keras.engine.base_layer import Layer
import numpy as np
import copy
import cmath
np.set_printoptions(1000)

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
epochs=3000              # number of epochs
n=11                     # number of Fourier basis regarding the first and the second variables
n3=5                     # number of Fourier basis regarding the third variable
# Offsets that shift the index range 0..n-1 (resp. 0..n3-1) to signed frequencies.
# NOTE(review): Python's round() rounds halves to the nearest even integer, so
# n=11 yields nhalf=6 (frequencies -6..4) whereas n3=5 yields n3half=2
# (frequencies -2..2). Confirm that the asymmetric range for n is intended.
nhalf=round(n/2)
n3half=round(n3/2)
d=2                      # dimension of the dynamical system
itnum=50                 # number of iterations of the Krylov subspace method
sdatanum=1000            # number of time-series for training
stestnum=1000            # number of time-series for testing
datanum=20*sdatanum      # number of samples in each sub-dataset for training
testnum=20*stestnum      # number of samples in each sub-dataset for testing
dt=0.01                  # step size for the discretization to generate the dataset

# ---------------------------------------------------------------------------
# Frequency table: ind[i,j,k,:] = (i-nhalf, j-nhalf, k-n3half), stored as float32.
# ---------------------------------------------------------------------------
ind=np.moveaxis(np.indices((n,n,n3)),0,-1)-np.array([nhalf,nhalf,n3half])
ind=tf.constant(ind.astype(float),dtype=tf.float32)

# ---------------------------------------------------------------------------
# Diagonal multipliers of shape (n,n,n3), complex64:
#   D1[i,j,k] = 1j*(i-nhalf),  D2[i,j,k] = 1j*(j-nhalf),  D3[i,j,k] = 1j*(k-n3half)
# i.e. 1j times the corresponding component of `ind`.
# ---------------------------------------------------------------------------
D1=tf.complex(tf.zeros((n,n,n3)),ind[:,:,:,0])
D2=tf.complex(tf.zeros((n,n,n3)),ind[:,:,:,1])
D3=tf.complex(tf.zeros((n,n,n3)),ind[:,:,:,2])


def vhat(x,i):
  """Evaluate sin(i)*x[0] + cos(i)*x[1] as a float32 tensor.

  Args:
    x: indexable with at least two entries (x[0], x[1]); numbers or tensors.
    i: angle; a number or a tensor.

  Returns:
    A float32 tensor holding sin(i)*x[0] + cos(i)*x[1]. For i = pi/2 this is
    x[0] (up to rounding) and for i = 0 it is x[1].
  """
  # The expression is already a tensor, so cast it instead of wrapping it in
  # tf.constant: tf.constant does not accept symbolic tensors, which would make
  # this function unusable while tracing a tf.function.
  return tf.cast(tf.math.sin(i)*x[0]+tf.math.cos(i)*x[1],dtype=tf.float32)

class Koopman(Layer):
  """Layer that applies the exponential of a learned operator to a coefficient tensor.

  The operator acts on a complex tensor `q` of shape (n, n, n3) as

      q  ->  sum over dim in {0, 1, 2} of  conv_{dim,R-1}( ... conv_{dim,0}(D_dim * q) )

  where D_0, D_1, D_2 are the module-level multipliers D1, D2, D3 and every
  `conv` is a complex 3-D "SAME" convolution with a learned kernel of spatial
  size (2*M-1, 2*M-1, 2*M3-1). The action of its exponential on the input is
  approximated with `itnum` steps of a Krylov subspace (Arnoldi-type)
  iteration. Relies on the module-level globals n, n3, d, itnum, D1, D2, D3.

  Only part of each kernel is trainable; `call` mirrors the stored parts to
  assemble the full kernels, which are exposed as `self.kernel` (real part)
  and `self.kernel_imag` (imaginary part).

  Args:
    M: half-width parameter for the first two axes (kernel has 2*M-1 taps).
    M3: half-width parameter for the third axis (kernel has 2*M3-1 taps).
    R: number of convolutions applied in sequence to each of the three terms.
    kernel_initializer, kernel_initializer2, kernel_initializer3:
      initializers of the real-part weights kernel1, kernel2, kernel3.
    kernel_imag_initializer, kernel_imag_initializer2, kernel_imag_initializer3:
      initializers of the imaginary-part weights.
    kernel_regularizer: regularizer passed to every weight of the layer.
    kernel_constraint: constraint passed to every weight of the layer.
    activity_regularizer: forwarded to the base `Layer`.
    **kwargs: forwarded to the base `Layer`; `input_dim` is translated to
      `input_shape` when the latter is absent.

  Attributes set by `call`:
    kernel, kernel_imag: assembled kernels, shape (d+1, R, 2M-1, 2M-1, 2M3-1).
    w: eigenvalues of the itnum x itnum Hessenberg matrix.
    knorm: tf.norm(expm(H), 2) of that matrix.
  """

  def __init__(self,
               M,
               M3,
               R,
               kernel_initializer='glorot_uniform',
               kernel_initializer2='glorot_uniform',
               kernel_initializer3='glorot_uniform',
               kernel_imag_initializer='glorot_uniform',
               kernel_imag_initializer2='glorot_uniform',
               kernel_imag_initializer3='glorot_uniform',
               kernel_regularizer=None,
               kernel_constraint=None,
               activity_regularizer=None,
               **kwargs):
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
      kwargs['input_shape'] = (kwargs.pop('input_dim'),)

    super().__init__(
        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
    self.mask=None
    self.M=int(M)
    self.M3=int(M3)
    self.R=int(R)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.kernel_imag_initializer = initializers.get(kernel_imag_initializer)
    self.kernel_initializer2 = initializers.get(kernel_initializer2)
    self.kernel_imag_initializer2 = initializers.get(kernel_imag_initializer2)
    self.kernel_initializer3 = initializers.get(kernel_initializer3)
    self.kernel_imag_initializer3 = initializers.get(kernel_imag_initializer3)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)

    self.supports_masking = True

  def _add_kernel(self, name, shape, initializer):
    """Create one trainable float32 weight sharing the layer's regularizer/constraint."""
    return self.add_weight(
        name=name,
        shape=shape,
        initializer=initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=tf.float32,
        trainable=True)

  def build(self, input_shape):
    """Create the trainable kernel parts and the constant `t`.

    Leading axes of every kernel part are (d+1, R): one kernel per term of the
    operator and per convolution in the sequence. Each weight gets its own
    name (the real parts used to share the name 'kernel' and the imaginary
    parts 'kernel_imag').
    """
    M, M3, R = self.M, self.M3, self.R

    # Real parts: kernel1 holds the first M-1 slices along the first spatial
    # axis, kernel2 the first M-1 entries of the middle slice, kernel3 the
    # first M3 entries of the middle line.
    self.kernel1=self._add_kernel('kernel1',[d+1,R,M-1,2*M-1,2*M3-1],self.kernel_initializer)
    self.kernel2=self._add_kernel('kernel2',[d+1,R,1,M-1,2*M3-1],self.kernel_initializer2)
    self.kernel3=self._add_kernel('kernel3',[d+1,R,1,1,M3],self.kernel_initializer3)

    # Imaginary parts: same layout, except that the middle line has only M3-1
    # entries because `call` inserts a zero at its centre.
    self.kernel1_imag=self._add_kernel('kernel1_imag',[d+1,R,M-1,2*M-1,2*M3-1],self.kernel_imag_initializer)
    self.kernel2_imag=self._add_kernel('kernel2_imag',[d+1,R,1,M-1,2*M3-1],self.kernel_imag_initializer2)
    self.kernel3_imag=self._add_kernel('kernel3_imag',[d+1,R,1,1,M3-1],self.kernel_imag_initializer3)

    # Non-trainable constant initialised to 0.1; it is not read in this class.
    self.t=self.add_weight(
        name='t',
        shape=[1],
        initializer=initializers.Constant(value=0.1),
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=tf.float32,
        trainable=False)

    self.built = True

  def _complex_conv(self, u, dim, r):
    """Convolve the complex tensor `u` with kernel (dim, r).

    `u` must hold n*n*n3 complex values. The complex product is expanded into
    four real "SAME" convolutions:
      real = conv(Re u, K) - conv(Im u, K_imag)
      imag = conv(Re u, K_imag) + conv(Im u, K)
    Returns a complex tensor of shape (1, n, n, n3, 1).
    """
    in_shape=(1,n,n,n3,1)
    k_shape=(2*self.M-1,2*self.M-1,2*self.M3-1,1,1)
    u_re=tf.reshape(tf.math.real(u),in_shape)
    u_im=tf.reshape(tf.math.imag(u),in_shape)
    k_re=tf.reshape(self.kernel[dim,r,:,:,:],k_shape)
    k_im=tf.reshape(self.kernel_imag[dim,r,:,:,:],k_shape)

    def conv(x,k):
      return tf.nn.convolution(x,k,padding="SAME")

    return tf.complex(conv(u_re,k_re)-conv(u_im,k_im),
                      conv(u_re,k_im)+conv(u_im,k_re))

  def _apply_operator(self, q):
    """Apply the learned operator to `q` (shape (n, n, n3)); returns the same shape."""
    total=None
    for dim,D in enumerate((D1,D2,D3)):
      term=D*q
      for r in range(self.R):
        term=self._complex_conv(term,dim,r)
      total=term if total is None else total+term
    # Drop the batch and channel axes added for the convolutions.
    return total[0,:,:,:,0]

  def call(self, vv):
    """Approximate exp(operator) applied to `vv`.

    Args:
      vv: complex tensor with n*n*n3 elements (reshaped to (n, n, n3)).

    Returns:
      Complex tensor of shape (n, n, n3).

    Side effects:
      Sets self.kernel, self.kernel_imag, self.knorm and self.w.
    """
    # Assemble the full kernels from the stored parts. Real parts are mirrored
    # with the same sign; imaginary parts are mirrored with the opposite sign
    # and get a zero at the centre tap.
    real_line=tf.concat([self.kernel3,self.kernel3[:,:,:,:,self.M3-2::-1]],axis=4)
    real_slice=tf.concat([self.kernel2,real_line,self.kernel2[:,:,:,::-1,:]],axis=3)
    self.kernel=tf.concat([self.kernel1,real_slice,self.kernel1[:,:,::-1,::-1]],axis=2)

    imag_line=tf.concat([self.kernel3_imag,tf.zeros((d+1,self.R,1,1,1)),-self.kernel3_imag[:,:,:,:,::-1]],axis=4)
    imag_slice=tf.concat([self.kernel2_imag,imag_line,-self.kernel2_imag[:,:,:,::-1,:]],axis=3)
    self.kernel_imag=tf.concat([self.kernel1_imag,imag_slice,-self.kernel1_imag[:,:,::-1,::-1]],axis=2)

    flat=[n*n*n3]
    vnorm=tf.norm(tf.reshape(vv,flat))
    # Q collects the basis vectors along its last axis; it starts with vv/||vv||.
    Q=tf.reshape(vv/vnorm,(n,n,n3,1))

    columns=[]  # columns of the (itnum+1) x itnum Hessenberg matrix
    for itr in range(1,itnum+1):
      u=self._apply_operator(Q[:,:,:,itr-1])
      # Coefficients of u along the current basis, then remove those components.
      h=tf.tensordot(tf.math.conj(Q),u,[[0,1,2],[0,1,2]])
      u=u-tf.tensordot(Q,h,[[3],[0]])
      unorm=tf.norm(tf.reshape(u,flat))

      # NOTE(review): the textbook Arnoldi recurrence stores ||u|| on the
      # subdiagonal, whereas 1/||u|| is stored here. Kept unchanged to preserve
      # the numerical behaviour of the original script - confirm the intent.
      columns.append(tf.concat([tf.reshape(h,(itr,1)),
                                1/unorm*tf.ones((1,1),dtype=tf.complex64),
                                tf.zeros((itnum-itr,1),dtype=tf.complex64)],axis=0))

      Q=tf.concat([Q,tf.reshape(u/unorm,(n,n,n3,1))],axis=3)

    H=tf.concat(columns,axis=1)
    Hm=H[0:itnum,0:itnum]
    # The matrix exponential is needed for both the output and knorm; compute it once.
    expH=tf.linalg.expm(Hm)

    output=tf.tensordot(Q[:,:,:,0:itnum],expH[:,0],[[3],[0]])*vnorm
    # NOTE(review): tf.norm is called without `axis`; check against the tf.norm
    # documentation whether this yields the intended matrix norm.
    self.knorm=tf.norm(expH,2)
    self.w=tf.linalg.eigvals(Hm)

    return output

class ODENet(Model):
  """Stack of five Koopman layers applied one after another."""

  def __init__(self):
    super().__init__()

    # Every layer is built with M=3, M3=2, R=1.
    self.l1=Koopman(3,2,1)
    self.l2=Koopman(3,2,1)
    self.l3=Koopman(3,2,1)
    self.l4=Koopman(3,2,1)
    self.l5=Koopman(3,2,1)

  def call(self, vv):
    """Feed `vv` through l1..l5.

    Returns:
      A pair (outputs, eigenvalues): `outputs[k]` is the tensor produced after
      the first k+1 layers and `eigenvalues[k]` is the `w` attribute that layer
      k+1 set during its call.
    """
    outputs=[]
    eigenvalues=[]
    state=vv
    for layer in (self.l1,self.l2,self.l3,self.l4,self.l5):
      state=layer(state)
      outputs.append(state)
      eigenvalues.append(layer.w)

    return outputs,eigenvalues

def _series_loss(u, series, num):
  """Sum over layers of the mean squared reconstruction error on `series`.

  Args:
    u: list of complex coefficient tensors of shape (n, n, n3); u[i] is the
      output after i+1 layers.
    series: float tensor indexed as series[sample, time, component] with at
      least 6 time indices and 2 components.
    num: number of samples in `series`.

  For u[i], the Fourier-type sum  sum_k u[i][k] * exp(1j * k . (x, theta))  is
  evaluated at x = series[:, 4-i, :] with theta = pi/2 and theta = 0; the real
  parts are compared with components 0 and 1 of series[:, 5, :].
  """
  total=0
  for i,coeff in enumerate(u):
    start=series[:,4-i,:]
    values=[]
    for theta in (np.pi/2,0):
      point=tf.concat([start,theta*tf.ones((num,1))],axis=1)
      angle=tf.tensordot(ind,point,[[3],[1]])
      phase=tf.math.exp(tf.complex(tf.zeros((n,n,n3,num)),angle))
      # Broadcasting replaces the explicit outer product with a vector of ones.
      values.append(tf.reduce_sum(coeff[:,:,:,tf.newaxis]*phase,axis=(0,1,2)))
    total=total+tf.reduce_mean((tf.math.real(values[0])-series[:,5,0])**2+(tf.math.real(values[1])-series[:,5,1])**2)
  return total


@tf.function
def train(ODENet, label, testlabel, opt, v):
  """Run one optimisation step and report the train and test losses.

  Args:
    ODENet: model instance exposing `call`, `trainable_variables` and the
      layers l1..l5 (each with `kernel` and `kernel_imag` set by its call).
    label: training series, shape (datanum, 6, 2).
    testlabel: test series, shape (testnum, 6, 2).
    opt: optimizer providing `apply_gradients`.
    v: input coefficient tensor passed to the model.

  Returns:
    (loss, testloss, eig): unregularised training loss, test loss, and the
    list of per-layer eigenvalue tensors returned by the model.
  """
  with tf.GradientTape() as tape :
    tape.watch(ODENet.trainable_variables)
    u,eig=ODENet.call(v)

    loss=_series_loss(u,label,datanum)

    # Penalise the squared differences between the kernels of consecutive layers.
    chain=(ODENet.l1,ODENet.l2,ODENet.l3,ODENet.l4,ODENet.l5)
    penalty=0
    for first,second in zip(chain[:-1],chain[1:]):
      penalty=penalty+tf.norm(first.kernel-second.kernel)**2+tf.norm(first.kernel_imag-second.kernel_imag)**2
    lossreg=loss+0.01*tf.math.real(tf.reduce_mean(penalty))

  # The test loss is only reported, never differentiated, so it is evaluated
  # outside the tape to avoid recording its operations.
  testloss=_series_loss(u,testlabel,testnum)

  grad = tape.gradient(lossreg, ODENet.trainable_variables)

  opt.apply_gradients(zip(grad, ODENet.trainable_variables))
  return loss,testloss,eig


if __name__ == '__main__':

    # Random initial states, uniformly distributed in [-1, 1)^2.
    data = -1+2*nr.rand(sdatanum+stestnum,2)

    # Integrate every initial state for 120 time points with the explicit Euler scheme.
    tmp=np.zeros((sdatanum+stestnum,100+20,2))
    tmp[:,0,:]=data
    for i in range(100+19):
        # measure preserving system
        #tmp[:,i+1,0]=tmp[:,i,0]+dt*0.1*np.sin(tmp[:,i,1])*np.exp(0.1*np.cos(tmp[:,i,0]-i*dt)+0.1*np.cos(tmp[:,i,1]))
        #tmp[:,i+1,1]=tmp[:,i,1]-dt*0.1*np.sin(tmp[:,i,0])*np.exp(0.1*np.cos(tmp[:,i,0]-i*dt)+0.1*np.cos(tmp[:,i,1]))

        # damping oscillator with external force
        tmp[:,i+1,0]=tmp[:,i,0]+dt*(-0.1*tmp[:,i,0]-tmp[:,i,1]-0.1*np.sin(dt*i))
        tmp[:,i+1,1]=tmp[:,i,1]+dt*tmp[:,i,0]

    # Sub-dataset i keeps the 6 time points i, i+20, ..., i+100 of each series.
    tmp2=np.zeros((sdatanum*20,6,2))
    for i in range(20):
        tmp2[i*sdatanum:(i+1)*sdatanum,:,:]=tmp[0:sdatanum,i:120:20,:]

    tmp3=np.zeros((stestnum*20,6,2))
    for i in range(20):
        tmp3[i*stestnum:(i+1)*stestnum,:,:]=tmp[sdatanum:sdatanum+stestnum,i:120:20,:]

    label=tf.constant(tmp2,dtype=tf.float32)
    testlabel=tf.constant(tmp3,dtype=tf.float32)

    # Bind the instance to its own name so that the ODENet class is not shadowed.
    model = ODENet()
    opt = tf.keras.optimizers.Adam(1*1e-3)

    # Input coefficients: for every frequency triple in `ind`, average
    # vhat(x)*exp(-1j * frequency . x) over the n*n*n3 grid points x in [-pi, pi)^3.
    v=None
    for i in range(n):
        x1=-np.pi+2*np.pi/n*i
        for j in range(n):
            x2=-np.pi+2*np.pi/n*j
            for k in range(n3):
                x3=-np.pi+2*np.pi/n3*k
                x=tf.constant([x1,x2,x3],dtype=tf.float32)
                term=tf.cast(vhat([x1,x2],x3),dtype=tf.complex64)*tf.math.exp(tf.complex(tf.zeros((n,n,n3)),tf.tensordot(-ind,x,[[3],[0]])))/n/n/n3
                v=term if v is None else v+term

    # Call the layers once eagerly so that their weights exist before `train` is traced.
    model.l5(model.l4(model.l3(model.l2(model.l1(v)))))

    print("epochs", "  train loss", "  test loss")
    for epoch in range(1, epochs + 1):

        loss,loss2,eig= train(model, label, testlabel, opt, v)

        print(epoch,loss.numpy(),loss2.numpy(),flush=True)
        if epoch%1000==0:
            # save the eigenvalues of the approximations of the generators
            for i in range(5):
                np.savetxt("eig"+str(epoch)+"_"+str(i)+".txt",eig[i])
