# This is a serial implementation of the MGRIT algorithm in support of
# the paper titled "Parallel Training of GRU Networks with A Multi-Grid Solver
# for Long Sequences," submitted to ICLR 2022.
#
# The MGRIT algorithm was originally presented in:
#
#   R. D. Falgout, S. Friedhoff, T. V. Kolev, S. P. MacLachlan, and J. B. Schroder.
#   Parallel time integration with multigrid. 
#   SIAM Journal on Scientific Computing, 2014
# 

import numpy as np
import numpy.linalg as la

# Problem definition for the ODE (network) that is solved below
steps = 128   # number of time steps
dim   = 10    # dimension of the ODE state (and of the data fed to it)
decay = 1.0   # coefficient of the implicit decay term; large values make the ODE stiff

# Randomly drawn weights and bias of the ODE right hand side. The draws
# are made in the order A, B, b from numpy's global random state.
A_weight = np.random.standard_normal((dim,dim))
B_weight = np.random.standard_normal((dim,dim))
b_bias   = np.random.standard_normal(dim)

def rnn_step(h_nminus1,d_n,dt):
  """
  Advance the hidden state by one time step of size 'dt' using a
  semi-implicit discretization of the ODE modeling an RNN

     h' = -decay * h + sigmoid(A * h + B * d + b)

  The 'decay' term is treated implicitly, which is where the
  division by (1 + decay*dt) comes from. The sigmoid term is
  evaluated at the previous state 'h_nminus1' and the data 'd_n'
  of the step being computed. Both 'h' and 'd' are time dependent.

  The weights 'A_weight', 'B_weight', the bias 'b_bias' and 'decay'
  are read from module scope. Returns the new hidden state.
  """
  pre_activation = A_weight.dot(h_nminus1)+B_weight.dot(d_n)+b_bias
  # dt * sigmoid(pre_activation), kept in the original evaluation order
  explicit_update = dt*1./(1.+np.exp(-pre_activation))
  implicit_scale = 1.0+decay*dt
  return (h_nminus1+explicit_update)/implicit_scale

def forward_prop(data,dt,forcing=None):
  """
  Sequentially propagate the hidden state through every time
  step of 'data', starting from a zero state at index 0.

  An optional external 'forcing' (same shape as 'data') is added
  to the state after each step; MGRIT uses it for the
  tau-correction. When it is omitted no forcing is applied.
  Row 0 of the forcing is never read.

  Returns an array shaped like 'data' holding the state at every
  time node.
  """
  if forcing is None:
    forcing = np.zeros(data.shape)
  n_nodes = data.shape[0]
  hall = np.zeros(data.shape)
  state = hall[0,:]
  for node in range(1,n_nodes):
    # each state depends on the previous one, so this sweep is serial
    state = rnn_step(state,data[node,:],dt) + forcing[node,:]
    hall[node,:] = state
  return hall

def residual(hguess,data,dt):
  """
  Compute the residual of a given guess at the solution
  to the ODE.

  Entry n (for n >= 1) is the mismatch between the guessed state
  at node n and the state obtained by taking one step from the
  guessed state at node n-1. Entry 0 is left at zero.

  Returns an array with one row per time node and 'dim' columns.
  """
  n_nodes = data.shape[0]
  res = np.zeros((n_nodes,dim))
  for node in range(1,n_nodes):
    # the entries are independent of each other
    propagated = rnn_step(hguess[node-1],data[node,:],dt)
    res[node,:] = hguess[node]-propagated
  return res

def relax(hinput,data,dt,cf,relax_type):
  """
  Perform F or FC-relaxation given the current guess at the
  solution to the ODE.

  The time nodes are grouped into intervals of 'cf' steps. Within
  each interval the state is propagated from the interval's first
  node (taken from 'hinput'). With relax_type 'f' the sweep stops
  one step before the next interval's first node; any other value
  also updates that node (FC-relaxation).

  Returns an updated copy; 'hinput' itself is not modified.
  """
  # an F-sweep is one step shorter than an FC-sweep
  sweep_length = cf-1 if relax_type=='f' else cf
  n_intervals = int((data.shape[0]-1)/cf)
  houtput = hinput.copy()
  for interval in range(n_intervals):
    # intervals only read 'hinput' at their start, so they are independent
    start = interval*cf
    h = hinput[start,:]
    for node in range(start+1,start+sweep_length+1):
      h = rnn_step(h,data[node,:],dt)
      houtput[node,:] = h
  return houtput

def restrict(h,cf):
  """
  Compute the restriction of the ODE solution by choosing
  a subsequence based on the coarsening factor: every
  'cf'-th entry along the first axis, starting with entry 0.
  """
  # plain strided selection, no arithmetic on the entries
  every_cf_th = slice(None,None,cf)
  return h[every_cf_th]

def prolong(h_c,cf):
  """
  Compute the prolongation of the coarse ODE solution
  using piecewise constant interpolation.

  'h_c' holds one row per coarse time node. The result has
  cf*(number of coarse steps)+1 rows: every fine node inside a
  coarse interval takes the value of the coarse node at the start
  of that interval, and the final fine node takes the value of the
  final coarse node.

  The trailing dimensions of the result follow 'h_c', and the
  result is always a new floating point array.
  """
  steps_c = h_c.shape[0]-1
  steps_f = cf*steps_c
  h = np.zeros((steps_f+1,)+h_c.shape[1:])

  # this computation can be parallelized: each coarse value is
  # copied to the cf fine nodes of the interval it starts
  h[:steps_f] = np.repeat(h_c[:steps_c],cf,axis=0)

  # the last coarse node starts no interval, so it has to be injected
  # explicitly; otherwise the final fine node would be left at zero
  h[steps_f] = h_c[steps_c]

  return h
  
def two_level_mg(h,data,dt,cf,relax_steps):
  """
  A two level multi-grid method to solve the RNN inspired ODE.
  One call performs a single two-level iteration.

  Parameters:
    h           - current guess at the solution, one row per time node
    data        - the data driving the ODE, one row per time node
    dt          - fine grid time step size
    cf          - coarsening factor; the coarse grid keeps every
                  cf-th time node and uses a step size of cf*dt
    relax_steps - number of FCF-relaxation sweeps applied before
                  the coarse grid correction

  Returns the improved guess; 'h' itself is not modified.

  Line numbers specified in the comments can be
  correlated with line numbers of Algorithm 2
  in the originating paper "Parallel Training of GRU..."
  """

  h_prime = h.copy()

  # Line 1: perform FCF-relaxation for multiple steps
  # (an FC sweep followed by an F sweep gives F-C-F)
  for _ in range(relax_steps):
    h_prime = relax(h_prime,data,dt,cf,relax_type='fc')
    h_prime = relax(h_prime,data,dt,cf,relax_type='f')

  # Lines 2 and 3: coarsen the fine grid vectors
  res_c     = restrict(residual(h_prime,data,dt),cf)
  h_c_prime = restrict(h_prime,cf)

  # Line 4: solve the coarse grid problem
  data_c    = restrict(data,cf)
  # coarse tau-correction: the coarse residual of the restricted guess
  # minus the restricted fine residual, used as forcing on the coarse grid
  correct_c = residual(h_c_prime,data_c,dt*cf)-res_c
  h_c_star = forward_prop(data_c,cf*dt,correct_c)

  # Line 5: update the fine grid solution
  h_pprime = h_prime+prolong(h_c_star-h_c_prime,cf)

  # Line 6: Do one sweep of F-relaxation to propagate the coarse solution.
  # This has to use the fine grid step size 'dt', like every other fine
  # grid sweep above, so that the result is consistent when dt != 1.
  return relax(h_pprime,data,dt,cf,relax_type='f')

if __name__ == '__main__':
  # solver configuration
  dt          = 1.0 # time step size
  cf          = 4   # coarsening factor
  mg_iters    = 8   # number of multigrid iterations
  relax_sweeps = 2  # FCF-relaxation sweeps per multigrid iteration

  # compute some contrived data: one sine wave per component with a
  # randomly drawn frequency, sampled at every time node
  timenodes   = np.linspace(0.0,steps,steps+1)
  frequencies = np.random.randn(dim)
  phase       = 3.0*np.pi*np.outer(timenodes,frequencies)/steps
  data        = np.sin(phase)
  data[0,:]   = 0.0

  # compute the exact solution to the ODE by serial forward propagation
  exact = forward_prop(data,dt)

  # Run a two level MGRIT, starting from a zero guess, and print the
  # error and residual after every iteration
  print('Two Level MG')
  two_level = np.zeros(exact.shape)
  for _ in range(mg_iters):
    two_level = two_level_mg(two_level,data,dt,cf,relax_sweeps)
    print(f'  error = {la.norm(two_level-exact):.4e}, residual = {la.norm(residual(two_level,data,dt)):.4e}')
