"""
Reaction-Diffusion, Neumann BC

Reduced-dimensional GMKRR: the state is first projected onto its leading
PCA modes, and this dimension reduction accelerates the regression.

Detailed result comparison is done in rd.py
"""
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynRegMan
from geometric_multivariate_kernel_ridge_regression.src.utils import make_se

# Load the simulation data set. The `with` block closes the file right away
# instead of leaving the handle to the garbage collector.
# NOTE: pickle can execute arbitrary code on load; only use trusted files.
with open('data/rddata.pkl', 'rb') as fh:
    data = pickle.load(fh)
dt = data['dt']
xx = data['x']
yy = data['y']
tt = data['t']
uu = data['udata']

# Each snapshot stacks two fields on the xx grid, hence 2*xx.size entries.
Nx, Nt = xx.size*2, len(tt)
# Explicit check instead of `assert`, which is stripped under `python -O`.
if uu.shape != (Nt, Nx):
    raise ValueError(f"udata has shape {uu.shape}, expected {(Nt, Nx)}")

Nn = xx.shape[0]    # points per grid side; fields are reshaped to (Nn, Nn) for plots
Nu = xx.size        # entries per field: first field is [:Nu], second is [Nu:]

# Evaluation window: starting at snapshot N0, the first Ntrain snapshots are
# used for training and the first Ntest snapshots form the reference trajectory.
N0, Ntrain, Ntest = 5000, 200, 5000

t_sim = tt[N0:N0 + Ntest]   # time grid for the prediction
t_plt = t_sim

data_train = uu[N0:N0 + Ntrain]               # training snapshots
data_test = [uu[N0:N0 + Ntest]]               # list of test trajectories (one here)
a0_test = [traj[0] for traj in data_test]     # initial state of each test trajectory

# ---------------------------
# Preprocess by SVD (PCA without mean removal): the 6 leading right singular
# vectors of the training snapshots span the reduced space.
_Vh = np.linalg.svd(data_train, full_matrices=False)[2]
Proj = _Vh[:6].T                  # (Nx, 6) projection onto the reduced basis
svd_train = data_train @ Proj     # reduced training coordinates

# Relative reconstruction error of every training snapshot; report the worst.
recon_train = svd_train @ Proj.T
err_train = (np.linalg.norm(recon_train - data_train, axis=1)
             / np.linalg.norm(data_train, axis=1))
print(np.max(err_train))

# The same check on the test window is obtained by substituting data_test[0]
# for data_train above.

svd0_test = [data_test[0][0] @ Proj]   # reduced initial condition(s)

# ---------------------------

# Run switches (1 = on, 0 = off):
#   ifest - estimate the intrinsic dimension (and reference bandwidth)
#   ifkrr - build the GMKRR-PCA-T0/-T4 model and integrate it over t_sim
#   ift4  - use the T4 model instead of T0
#   iftim - sanity-check plot of prediction vs. truth at the last time step
ifest, ifkrr, ift4, iftim = 0, 1, 1, 0

if ifest:
    man = Manifold(svd_train, d=2)
    dim = man.estimate_intrinsic_dim(bracket=[-20, 10], tol=0.2, ifplt=False)

# ------------------
# GMKRR
# ------------------
if ifkrr:
    t1 = time.time()
    order_fd = '1'    # passed to DynRegMan as fd=...
    # T0 and T4 share d and g; T4 additionally sets T=4 and 'iforit'.
    manopt = {
        'd' : 1,
        'g' : 4,
        'T' : 4 if ift4 else 0,
    }
    if ift4:
        manopt['iforit'] = True
    regloc = {
        'man' : Manifold,
        'manopt' : manopt,
        'ker' : make_se(31319.5),   # presumably the reference bandwidth from the ifest step -- confirm
        'nug' : 1e-6,               # presumably a regularization nugget -- confirm
        'ifvec' : True
    }
    dloc = DynRegMan(M2KRR, regloc, fd=order_fd, dt=dt)
    dloc.fit([svd_train])

    t2 = time.time()
    sol = []
    for _a in svd0_test:
        # Integrate in reduced coordinates, then lift back to the full state.
        a_krr = dloc.solve(_a, t_sim)   # alternative integrator: alg='RK2'
        sol.append(a_krr.dot(Proj.T))
    t3 = time.time()

    print(t2-t1, t3-t2)    # fit time, prediction time [s]

    # Create the output directory so a long run does not crash at save time,
    # and close the file deterministically so the pickle is fully flushed.
    os.makedirs('./res', exist_ok=True)
    with open(f'./res/rdp_{Ntrain}_{manopt["T"]}.dat', 'wb') as fh:
        pickle.dump(sol, fh)

# ------------------
# Temporal response
# ------------------
if iftim:
    IDX = 0
    # Model order used in the result file name. This mirrors the T0/T4 choice
    # in the ifkrr block without depending on `manopt`, which only exists when
    # ifkrr is on.
    T_model = 4 if ift4 else 0
    a_tru = data_test[IDX]
    with open(f'./res/rdp_{Ntrain}_{T_model}.dat', 'rb') as fh:
        a_krr = pickle.load(fh)[IDX]

    tdx = -1    # snapshot to plot: the last time step
    f, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(10, 6))
    # Row 0 shows U (state entries [:Nu]); row 1 shows V (entries [Nu:]).
    # Columns: truth, prediction, truth - prediction.
    fields = (('U', slice(None, Nu)), ('V', slice(Nu, None)))
    for row, (name, comp) in enumerate(fields):
        _tru = a_tru[tdx, comp].reshape(Nn, Nn)
        _krr = a_krr[tdx, comp].reshape(Nn, Nn)
        cs = ax[row, 0].contourf(xx, yy, _tru)
        # Prediction and error reuse the truth's contour levels so that all
        # three panels share one color scale.
        ax[row, 1].contourf(xx, yy, _krr, levels=cs.levels)
        ax[row, 2].contourf(xx, yy, _tru-_krr, levels=cs.levels)
        ax[row, 2].set_title(f'Error {np.linalg.norm(_tru-_krr):4.3e}')
        plt.colorbar(cs, ax=ax[row])
        ax[row, 0].set_title(f'{name} - Truth')
        ax[row, 1].set_title('Prediction')
        ax[row, 0].set_ylabel('$y$')

    for col in range(3):
        ax[1, col].set_xlabel('$x$')

plt.show()
