"""
KS beating-travelling example

FFT-reduced GMKRR

We employ the same preprocessing steps as CANDyMan, which are available in its repository.

Detailed result comparison is done in ks_travel.py
"""
import pickle
import matplotlib.pyplot as plt
import numpy as np
import time

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynRegMan
from geometric_multivariate_kernel_ridge_regression.src.utils import make_se

# Load the pre-generated KS travelling-wave dataset.  The file is opened in a
# context manager so the handle is closed as soon as unpickling finishes.
# NOTE: pickle can execute arbitrary code on load -- only use trusted files.
with open('data/ksdata_travelling.pkl', 'rb') as fh:
    data = pickle.load(fh)
dt = data['dt']
nu = data['nu']
xx = data['x']
tt = data['t']
uu = data['udata']

# Snapshots are stored row-wise: one row per time sample, one column per
# grid point.
Nx, Nt = len(xx), len(tt)
assert uu.shape == (Nt, Nx)

Ntrain = 100      # number of leading snapshots used for training
Ntest  = 14000    # number of leading snapshots used for prediction/testing

# Simulation times and the same times shifted to start at zero (for plots)
t_sim = tt[:Ntest]
t_plt = t_sim - t_sim[0]

# Test data
data_test = [uu[:Ntest]]

# ----------------------
# Processing by FFT
# From CANDyMan
# ----------------------
# Spatial FFT of every snapshot (np.fft.fft transforms the last axis).
Xhat = np.fft.fft(uu)
# Phase of the first Fourier mode of each snapshot.
phi = np.angle(Xhat[:, 1])
# Integer wavenumbers in FFT ordering, derived from the grid size instead of
# being hard-coded for 64 points.  For an even grid the Nyquist mode keeps the
# positive sign, so a 64-point grid gives exactly
# concatenate((arange(33), arange(-31, 0))).
nx_fft = Xhat.shape[1]
wav = np.concatenate((
    np.arange(nx_fft // 2 + 1),
    np.arange(-((nx_fft - 1) // 2), 0),
)) # wavenumbers
# Rotate every mode k by -k*phi so that the first mode has zero phase, then
# return to physical space.
XhatShift = Xhat*np.exp(-1j*np.outer(phi, wav))
Xshift = np.real(np.fft.ifft(XhatShift))

# Phase increments between consecutive snapshots, wrapped back into
# [-pi, pi] to undo the 2*pi jumps of np.angle.
dphi = phi[1:] - phi[:-1]
dphi += (dphi < -np.pi)*2.0*np.pi - (dphi > np.pi)*2.0*np.pi
Xphi = dphi.reshape(-1,1)

# Rebuild the phase history from its increments and undo the shift to
# recover the field from (Xshift, dphi).
phi_recon = np.hstack([phi[0], phi[0]+np.cumsum(dphi)])
tmp = np.fft.fft(Xshift) * np.exp(1j*np.outer(phi_recon, wav))
data_recon = np.real(np.fft.ifft(tmp))

# Initial conditions: first shifted snapshot and first phase value
a0_spc = Xshift[0]
a0_phi = [phi[0]]

# Training set: Ntrain shifted snapshots and the Ntrain-1 phase increments
# between them.
data_train = Xshift[:Ntrain]
phi_train = Xphi[:Ntrain-1]
# ----------------------

# Stage switches: a non-zero value enables the corresponding stage.
ifest = 0    # Estimate dimension and reference bandwidth
ifkrr = 1    # Generate GMKRR-FFT models
iftim = 0    # Sanity check

if ifest:
    # Settings forwarded unchanged to the intrinsic-dimension estimator.
    est_opts = {'bracket': [-20, 10], 'tol': 0.2, 'ifplt': False}
    man = Manifold(Xshift, d=1)
    dim = man.estimate_intrinsic_dim(**est_opts)

# ------------------
# GMKRR
# ------------------
if ifkrr:
    # perf_counter is monotonic and high-resolution, so the elapsed times
    # printed below cannot be distorted by wall-clock adjustments.
    t1 = time.perf_counter()

    # Options shared by the manifold inside the regressor and by the
    # separate manifold used for the phase increments.
    manopt = {
        'd' : 1,
        'g' : 4,
        'T' : 0
    }
    # Regressor configuration handed to DynRegMan.
    # NOTE(review): 'ker' looks like a squared-exponential kernel with the
    # given bandwidth and 'nug' like a regularisation nugget -- confirm
    # against make_se / M2KRR.
    regloc = {
        'man' : Manifold,
        'manopt' : manopt,
        'ker' : make_se(61.47),
        'nug' : 1e-6,
        'ifvec' : True
    }
    # Dynamics model for the phase-aligned snapshots.
    dloc = DynRegMan(M2KRR, regloc, fd='1', dt=dt)
    dloc.fit([data_train])

    # Manifold over all but the last training snapshot, so that it has one
    # point per entry of phi_train (Ntrain-1 phase increments).
    mphi = Manifold(data_train[:-1], **manopt)
    mphi.precompute()
    t2 = time.perf_counter()

    # Predict the aligned state from its initial condition, then evaluate
    # the phase increment at every predicted state except the last.
    a_spc = dloc.solve(a0_spc, t_sim)
    a_phi = np.array([
        mphi.gmls(_a, phi_train) for _a in a_spc[:-1]
    ])

    # Accumulate the increments into a phase history and rotate the predicted
    # states back to the travelling frame.
    phi_recon = np.hstack([phi[0], phi[0]+np.cumsum(a_phi)])
    tmp = np.fft.fft(a_spc) * np.exp(1j*np.outer(phi_recon, wav))
    data_recon = np.real(np.fft.ifft(tmp))
    t3 = time.perf_counter()

    # Training time, prediction time (seconds)
    print(t2-t1, t3-t2)
    # Write through a context manager so the file is flushed and closed
    # before anything tries to read it back.
    with open(f'./res/ks_trf_{Ntrain}.dat', 'wb') as fh:
        pickle.dump([data_recon], fh)

# ------------------
# Temporal response
# ------------------
if iftim:
    IDX = 0
    a_tru = data_test[IDX]
    # Read the stored prediction back; the context manager closes the file
    # right after unpickling.
    # NOTE: pickle can execute arbitrary code on load -- only read result
    # files produced by this script.
    with open(f'./res/ks_trf_{Ntrain}.dat', 'rb') as fh:
        a_krr = pickle.load(fh)[IDX]

    X, T = np.meshgrid(xx, t_plt)

    # Truth, prediction and their difference side by side; all three panels
    # reuse the contour levels of the truth panel so one colorbar serves them.
    f, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 6))
    cs = ax[0].contourf(X, T, a_tru)
    ax[1].contourf(X, T, a_krr, levels=cs.levels)
    ax[2].contourf(X, T, a_tru-a_krr, levels=cs.levels)
    plt.colorbar(cs, ax=ax)

    # The error title reports the norm of the full difference array.
    titles = (
        'Truth',
        'Prediction',
        f'Error {np.linalg.norm(a_tru-a_krr):4.3e}',
    )
    for panel, title in zip(ax, titles):
        panel.set_title(title)
        panel.set_xlabel('$x$')
    # Only the left-most panel carries the time-axis label.
    ax[0].set_ylabel('$t$')

    f.savefig('pics/ks_trf_tim_cmp.png', dpi=600, bbox_inches='tight')

plt.show()
