"""
Reaction-Diffusion, Neumann BC

Full-dimensional GMKRR training and prediction
Plotting of results

** Run `rd_pca.py` first to obtain PCA-T0/-T4 results.
"""
import pickle
import matplotlib.pyplot as plt
import numpy as np
import time

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynRegMan
from geometric_multivariate_kernel_ridge_regression.src.utils import make_se

# ------------------
# Data and experiment setup
# ------------------
# Load the pre-computed data set; the `with` block closes the file handle.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open('data/rddata.pkl', 'rb') as fh:
    data = pickle.load(fh)
dt = data['dt']      # presumably the snapshot time step (passed to DynRegMan)
xx = data['x']       # grid x-coordinates (used by contourf below)
yy = data['y']       # grid y-coordinates
tt = data['t']       # time stamps, one per snapshot row of `udata`
uu = data['udata']   # snapshots, shape (Nt, Nx) as checked below

# Each snapshot stacks two fields over the grid, hence 2*xx.size entries.
Nx, Nt = xx.size*2, len(tt)
if uu.shape != (Nt, Nx):
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise ValueError(f"Unexpected shape of 'udata': {uu.shape}, expected {(Nt, Nx)}")

Nn = xx.shape[0]   # points per grid direction; fields are reshaped to (Nn, Nn) for plotting
Nu = xx.size       # entries per field; columns [:Nu] and [Nu:] hold the two fields

N0 = 5000      # index of the first snapshot used for training and testing
Ntrain = 200   # number of consecutive training snapshots
Ntest  = 5000  # number of snapshots in the test trajectory

t_sim = tt[N0:N0+Ntest]
t_plt = t_sim - t_sim[0]   # test times shifted to start at zero, for plotting

# Training data
data_train = uu[N0:N0+Ntrain]

# Test data: a single trajectory and its initial condition
data_test = [uu[N0:N0+Ntest]]
a0_test   = [data_test[0][0]]

lbls = ['Truth', 'CANDyMan', 'GMKRR-Full', 'GMKRR-PCA-T0', 'GMKRR-PCA-T4']

# Explore
ifdim = 0    # Estimate intrinsic dimension; also obtaining the reference bandwidth
ifman = 1    # Visualize manifold structure, Fig. 4(h)

# Training
ifkrr = 1    # Train the model

# Plots
ifrms = 1    # Computation of reported error metrics, Table 1
iftim = 1    # Comparison of solutions at a time station, Fig. 4(d), Fig. 12(a)-(b)
iferr = 1    # Time evolution of trajectory errors, Fig. 4(l)
iftrj = 1    # Comparison of temporal response at spatial locations, Fig. 12(c)

if ifdim:
    # Probe the intrinsic dimension of the data manifold from the first 1000
    # snapshots; this is also how the reference kernel bandwidth is obtained.
    n_sub = 1000
    man = Manifold(uu[:n_sub], d=2)
    dim, (f, ax) = man.estimate_intrinsic_dim(bracket=[-20, 10], tol=0.2, ifplt=True)
    _, _, f = man.visualize_intrinsic_dim()

if ifman:
    from sklearn.manifold import Isomap

    # Embed the first Nm snapshots into 3-D with Isomap to visualize the
    # manifold structure (Fig. 4(h)), using K = sqrt(Nm) nearest neighbours.
    Nm = 1000
    dat = uu[:Nm]
    K = int(np.sqrt(Nm))

    isom = Isomap(n_neighbors=K, n_components=3)
    X = isom.fit_transform(dat).T   # transpose -> one row per embedding coordinate
    f = plt.figure()
    ax = f.add_subplot(projection='3d')
    # The marker must be given by keyword: the 4th positional parameter of
    # Axes3D.scatter is `zdir`, so a bare '.' was silently taken as zdir.
    ax.scatter(*X, marker='.', c=X[2])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    f.savefig('pics/rd_man.png', dpi=600, bbox_inches='tight')

# ------------------
# GMKRR
# ------------------
if ifkrr:
    # Fit GMKRR on the full-dimensional training snapshots, integrate from
    # each test initial condition over t_sim, print the fit and solve wall
    # times, and pickle the predicted trajectories.
    t1 = time.perf_counter()   # monotonic high-resolution clock for elapsed time
    order_fd = '1'             # forwarded to DynRegMan as `fd`
    manopt = {                 # options forwarded to Manifold; semantics defined there
        'd' : 1,
        'g' : 4,
        'T' : 0
    }
    regloc = {
        'man' : Manifold,
        'manopt' : manopt,
        'ker' : make_se(31000.6),   # NOTE(review): presumably SE kernel with bandwidth 31000.6 -- confirm in make_se
        'nug' : 1e-6,               # presumably the ridge (nugget) regularization
        'ifvec' : True
    }
    dloc = DynRegMan(M2KRR, regloc, fd=order_fd, dt=dt)
    dloc.fit([data_train])

    t2 = time.perf_counter()
    sol = [dloc.solve(_a, t_sim) for _a in a0_test]
    t3 = time.perf_counter()

    print(t2-t1, t3-t2)

    # `with` guarantees the output is flushed and the file closed.
    with open(f'./res/rd_{Ntrain}.dat', 'wb') as fh:
        pickle.dump(sol, fh)

# ------------------
# Error
# ------------------
if ifrms:
    def _load_result(path):
        """Unpickle and return the object stored at *path*, closing the file."""
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    IDX = 0
    a_tru = data_test[IDX]
    # The last row of the CANDyMan trajectory is dropped so it aligns with the
    # truth (presumably it stores one extra step -- confirm against its producer).
    a_can = _load_result('./res/rd.pkl')[0][:-1]
    a_ful = _load_result(f'./res/rd_{Ntrain}.dat')[IDX]
    a_pt0 = _load_result(f'./res/rdp_{Ntrain}_0.dat')[IDX]
    a_pt4 = _load_result(f'./res/rdp_{Ntrain}_4.dat')[IDX]

    # Per model: time-averaged RMSE over all state entries, and max abs error.
    for _lbl, _a in zip(lbls[1:], [a_can, a_ful, a_pt0, a_pt4]):
        tmp = a_tru - _a
        rmse = np.linalg.norm(tmp, axis=1) / np.sqrt(a_tru.shape[1])
        maxe = np.max(np.abs(tmp))
        print(_lbl, np.mean(rmse), maxe)

# ------------------
# Snapshot comparison at a time station
# ------------------
if iftim:
    def _load_result(path):
        """Unpickle and return the object stored at *path*, closing the file."""
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    def _split(traj):
        """Split stacked snapshots into the two fields (columns [:Nu] and [Nu:])."""
        return traj[:, :Nu], traj[:, Nu:]

    IDX = 0
    a_tru = _split(data_test[IDX])
    a_can = _split(_load_result('./res/rd.pkl')[0][:-1])   # drop last row to align with truth
    a_ful = _split(_load_result(f'./res/rd_{Ntrain}.dat')[IDX])
    a_pt0 = _split(_load_result(f'./res/rdp_{Ntrain}_0.dat')[IDX])
    a_pt4 = _split(_load_result(f'./res/rdp_{Ntrain}_4.dat')[IDX])

    tdx = -1   # compare the final snapshot of the test window
    for _i in [0, 1]:   # one figure per field
        _dat = [_traj[_i][tdx].reshape(Nn, Nn)
                for _traj in (a_tru, a_can, a_ful, a_pt0, a_pt4)]

        # Top row: truth and predictions; bottom row: truth minus prediction,
        # all drawn with the truth's contour levels so colours are comparable.
        f, ax = plt.subplots(nrows=2, ncols=5, sharex=True, sharey=True, figsize=(12, 4))
        cs = ax[0,0].contourf(xx, yy, _dat[0])
        for _d in range(1, len(_dat)):
            ax[0,_d].contourf(xx, yy, _dat[_d], levels=cs.levels)
            ax[0,_d].set_title(lbls[_d])
            ax[1,_d].contourf(xx, yy, _dat[0]-_dat[_d], levels=cs.levels)
            # RMSE over the Nn*Nn grid points: ||e||_F / sqrt(Nn^2) = ||e||_F / Nn
            _e = np.linalg.norm(_dat[0]-_dat[_d]) / Nn
            ax[1,_d].set_title(f"RMSE {_e:4.3e}")
        plt.colorbar(cs, ax=ax[0])
        plt.colorbar(cs, ax=ax[1])

        ax[0,0].set_title('Truth')
        for _d in range(1, len(_dat)):
            ax[1,_d].set_xlabel('$x$')
        ax[0,0].set_xlabel('$x$')
        ax[0,0].set_ylabel('$y$')
        ax[1,1].set_ylabel('$y$')
        ax[1,0].set_axis_off()   # no error panel for the truth itself

        f.savefig(f'pics/rd_tim_cmp_{_i}.png', dpi=600, bbox_inches='tight')

# ------------------
# Temporal error of trajectories
# ------------------
if iferr:
    def _load_result(path):
        """Unpickle and return the object stored at *path*, closing the file."""
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    FS = 18    # font size for labels, ticks and legend
    IDX = 0
    SKP = 10   # plot every SKP-th time step
    a_tru = data_test[IDX][::SKP]
    a_can = _load_result('./res/rd.pkl')[0][:-1][::SKP]   # drop last row to align with truth
    a_ful = _load_result(f'./res/rd_{Ntrain}.dat')[IDX][::SKP]
    a_pt0 = _load_result(f'./res/rdp_{Ntrain}_0.dat')[IDX][::SKP]
    a_pt4 = _load_result(f'./res/rdp_{Ntrain}_4.dat')[IDX][::SKP]
    dat = [a_can, a_ful, a_pt0, a_pt4]

    # Per-time-step RMSE over all state entries, one curve per model.
    stys = ['-', '-', '--', '-']
    f = plt.figure()
    for _a, _sty, _lbl in zip(dat, stys, lbls[1:]):
        err = np.linalg.norm(a_tru - _a, axis=1) / np.sqrt(a_tru.shape[1])
        plt.semilogy(t_plt[::SKP], err, _sty, label=_lbl)
    plt.legend(fontsize=FS, bbox_to_anchor=(0.42,0.43))
    plt.xlabel('$t$', fontsize=FS)
    plt.ylabel('RMSE', fontsize=FS)
    plt.ylim([1e-5, 0.5])
    ax = plt.gca()
    ax.tick_params(axis='both', which='major', labelsize=FS)
    f.savefig('pics/rd_err.png', dpi=600, bbox_inches='tight')

# ------------------
# Temporal response at a spatial location
# ------------------
if iftrj:
    def _load_result(path):
        """Unpickle and return the object stored at *path*, closing the file."""
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    IDX = 0
    a_tru = data_test[IDX]
    a_can = _load_result('./res/rd.pkl')[0][:-1]   # drop last row to align with truth
    a_ful = _load_result(f'./res/rd_{Ntrain}.dat')[IDX]
    a_pt0 = _load_result(f'./res/rdp_{Ntrain}_0.dat')[IDX]
    a_pt4 = _load_result(f'./res/rdp_{Ntrain}_4.dat')[IDX]
    preds = [a_can, a_ful, a_pt0, a_pt4]

    # Column xdx of the stacked state (a flat grid index within the first field)
    # is the probed location; print its coordinates for reference.
    xdx = 100
    print(xx.ravel()[xdx], yy.ravel()[xdx])

    f, ax = plt.subplots(ncols=3, sharey=True, figsize=(12, 3))
    for _ax in ax:
        _ax.plot(t_plt, a_tru[:,xdx], '--', label=lbls[0], color='0.5', lw=3)
        for _a, _lbl in zip(preds, lbls[1:]):
            _ax.plot(t_plt, _a[:,xdx], '-', label=_lbl, lw=1)
        _ax.grid()
        _ax.set_xlabel('$t$')
    ax[0].set_ylabel('$u$')
    # Three 20-time-unit zoom windows starting at t = 30, 130 and 230.
    for _ax, _t0 in zip(ax, (30, 130, 230)):
        _ax.set_xlim([_t0, _t0 + 20])
        _ax.set_xticks(list(range(_t0, _t0 + 21, 5)))

    ax[0].legend(loc="lower left", ncol=1, fontsize=10, fancybox=True, framealpha=0.5)

    f.savefig('pics/rd_trj.png', dpi=600, bbox_inches='tight')

# Display every figure created above.
plt.show()
