"""
KS Beating example

GMKRR training
Plotting of results
"""
import pickle
import matplotlib.pyplot as plt
import numpy as np
import time

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynRegMan
from geometric_multivariate_kernel_ridge_regression.src.utils import make_se

# Load the pre-computed dataset; the context manager closes the file handle
# as soon as unpickling finishes.
# NOTE(review): pickle can execute arbitrary code on load -- only use trusted files.
with open('data/ksdata_beating.pkl', 'rb') as fh:
    data = pickle.load(fh)
dt = data['dt']      # time step; used below to build the simulation time axis
nu = data['nu']
xx = data['x']       # spatial grid
tt = data['t']       # time stamps of the stored snapshots
uu = data['udata']   # snapshots: one row per time stamp, one column per grid point

Nx, Nt = len(xx), len(tt)
# Explicit check instead of `assert`, which is stripped under `python -O`.
if uu.shape != (Nt, Nx):
    raise ValueError(f"Expected udata of shape {(Nt, Nx)}, got {uu.shape}")

# Number of leading snapshots used for training, and number of steps simulated.
Ntrain = 100
Ltest = 500

# Time axis of the simulated trajectory (also used as the plotting time axis).
t_sim = np.arange(Ltest)*dt
t_plt = t_sim

# Training data: the first Ntrain snapshots
data_train = uu[:Ntrain]

# Test data: the reference trajectory truncated to the simulation horizon, so
# that it aligns row-for-row with the Ltest-step predictions it is compared
# against and with the Ltest-point time axis t_plt used for plotting.
data_test = [uu[:Ltest]]
# Initial condition(s) for the test rollouts: first snapshot of each trajectory
a0_test   = [data_test[0][0]]

# Legend labels: ground truth, baseline, and the proposed method
lbls = ['Truth', 'CANDyMan', 'GMKRR (Ours)']

# Stage switches: set to 1 to run a stage, 0 to skip it.
# Explore
ifdim = 0    # Estimate intrinsic dimension; also obtaining the reference bandwidth
ifman = 1    # Visualize manifold structure, Fig. 4(e)

# Training
ifkrr = 1    # Train the model and save its predictions under ./res/

# Plots (these read the predictions saved by the training stage)
ifrms = 1    # Computation of reported error metrics, Table 1
iftim = 1    # Space-time comparison of solutions, Fig. 4(a), Fig. 10(a)
iferr = 1    # Time evolution of trajectory errors, Fig. 4(i)
iftrj = 1    # Comparison of temporal response at spatial locations, Fig. 10(b)

if ifdim:
    # Intrinsic-dimension study on the full snapshot record, with plotting.
    man = Manifold(uu, d=2)
    dim_result = man.estimate_intrinsic_dim(
        bracket=[-40, 5],
        tol=0.2,
        ifplt=True,
    )
    dim, (f, ax) = dim_result
    # Second diagnostic view; only its third return value (presumably the
    # figure handle -- confirm) is kept.
    _, _, f = man.visualize_intrinsic_dim()

if ifman:
    from sklearn.manifold import Isomap

    # Embed the first Nm snapshots into 3-D with Isomap to visualize the manifold.
    Nm = 1000
    dat = uu[:Nm]
    K = int(np.sqrt(Nm))   # neighborhood size: square root of the sample count

    isom = Isomap(n_neighbors=K, n_components=3)
    # fit_transform returns (n_samples, 3); transpose so X unpacks into x, y, z.
    X = isom.fit_transform(dat).T
    f = plt.figure()
    ax = plt.axes(projection='3d')
    ax.plot(*X, 'b-')
    # The marker must be a keyword: the 4th positional parameter of
    # Axes3D.scatter is `zdir`, so a positional '.' was silently used as zdir.
    ax.scatter(*X, marker='.', c=X[2])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    f.savefig('pics/ks_bea_man.png', dpi=600, bbox_inches='tight')

# ------------------
# GMKRR
# ------------------
if ifkrr:
    # perf_counter is the monotonic clock intended for measuring durations.
    t1 = time.perf_counter()
    order_fd = '1'
    # Options forwarded to the Manifold class used by the local regressor.
    manopt = {
        'd' : 1,
        'g' : 4,
        'T' : 0
    }
    regloc = {
        'man' : Manifold,
        'manopt' : manopt,
        # NOTE(review): presumably the SE-kernel bandwidth from the ifdim stage -- confirm
        'ker' : make_se(440.163),
        'nug' : 1e-6,
        'ifvec' : True
    }
    dloc = DynRegMan(M2KRR, regloc, fd=order_fd, dt=dt)

    dloc.fit([data_train])
    t2 = time.perf_counter()

    # Roll out the learned dynamics from each test initial condition.
    sol = []
    for _a in a0_test:
        # An alternative integrator is available: dloc.solve(_a, t_sim, alg='RK2')
        a_krr = dloc.solve(_a, t_sim)
        sol.append(a_krr)
    t3 = time.perf_counter()

    # Training time, then prediction time (seconds).
    print(t2-t1, t3-t2)

    # Context manager guarantees the pickle is flushed and the file closed.
    with open(f'./res/ks_bea_{Ntrain}.dat', 'wb') as fh:
        pickle.dump(sol, fh)

# ------------------
# Error
# ------------------
if ifrms:
    IDX = 0
    a_tru = data_test[IDX]
    # Baseline (CANDyMan) result: last stored entry with its final row and
    # column dropped (presumably to match the layout of data_test -- confirm).
    with open('./res/ksb.pkl', 'rb') as fh:
        a_can = pickle.load(fh)[-1][:-1,:-1]
    with open(f'./res/ks_bea_{Ntrain}.dat', 'rb') as fh:
        a_krr = pickle.load(fh)[IDX]

    # Per model: time-averaged spatial RMSE, and maximum pointwise absolute error.
    for _a in [a_can, a_krr]:
        tmp = a_tru-_a
        rmse = np.linalg.norm(tmp, axis=1) / np.sqrt(a_tru.shape[1])
        maxe = np.max(np.abs(tmp))
        print(np.mean(rmse), maxe)

# ------------------
# Space-time comparison of solutions
# ------------------
if iftim:
    IDX = 0
    a_tru = data_test[IDX]
    # Baseline (CANDyMan) result: last stored entry with its final row and
    # column dropped (presumably to match the layout of data_test -- confirm).
    with open('./res/ksb.pkl', 'rb') as fh:
        a_can = pickle.load(fh)[-1][:-1,:-1]
    with open(f'./res/ks_bea_{Ntrain}.dat', 'rb') as fh:
        a_krr = pickle.load(fh)[IDX]
    # Transpose so rows index space and columns index time, matching meshgrid(t, x).
    _dat = [a_tru.T, a_can.T, a_krr.T]

    T, X = np.meshgrid(t_plt, xx)

    # Top row: solutions; bottom row: errors vs. truth. All panels reuse the
    # truth's contour levels so colors are comparable.
    f, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(10, 4))
    cs = ax[0,0].contourf(T, X, _dat[0])
    for _d in range(1, 3):
        ax[0,_d].contourf(T, X, _dat[_d], levels=cs.levels)
        ax[0,_d].set_title(lbls[_d])
        ax[1,_d].contourf(T, X, _dat[0]-_dat[_d], levels=cs.levels)
        # Spatial RMSE at each time (mean over axis 0 = space), averaged over time.
        _e = np.mean(np.sqrt(np.mean((_dat[0]-_dat[_d])**2, axis=0)))
        ax[1,_d].set_title(f"RMSE {_e:4.3e}")
    plt.colorbar(cs, ax=ax[0])
    plt.colorbar(cs, ax=ax[1])

    ax[0,0].set_title('Truth')
    for _d in range(1, 3):
        ax[1,_d].set_xlabel('$t$')
    # The bottom-left panel is unused, so the truth panel carries its own labels.
    ax[0,0].set_xlabel('$t$')
    ax[0,0].set_ylabel('$x$')
    ax[1,1].set_ylabel('$x$')
    ax[1,0].set_axis_off()

    f.savefig('pics/ks_bea_tim_cmp.png', dpi=600, bbox_inches='tight')

# ------------------
# Temporal error of trajectories
# ------------------
if iferr:
    FS = 18   # font size for labels, ticks and legend
    IDX = 0
    a_tru = data_test[IDX]
    # Baseline (CANDyMan) result: last stored entry with its final row and
    # column dropped (presumably to match the layout of data_test -- confirm).
    with open('./res/ksb.pkl', 'rb') as fh:
        a_can = pickle.load(fh)[-1][:-1,:-1]
    with open(f'./res/ks_bea_{Ntrain}.dat', 'rb') as fh:
        a_krr = pickle.load(fh)[IDX]
    dat = [a_can, a_krr]

    # Spatial RMSE at every time step, on a log scale, for each model.
    f = plt.figure()
    for _i in range(2):
        err = np.linalg.norm(a_tru-dat[_i], axis=1) / np.sqrt(a_tru.shape[1])
        plt.semilogy(t_plt, err, '-', label=lbls[_i+1])
    plt.legend(loc=4, fontsize=FS)
    plt.xlabel('$t$', fontsize=FS)
    plt.ylabel('RMSE', fontsize=FS)
    ax = plt.gca()
    ax.tick_params(axis='both', which='major', labelsize=FS)
    f.savefig('pics/ks_bea_err.png', dpi=600, bbox_inches='tight')

# ------------------
# Time series at selected spatial locations
# ------------------
if iftrj:
    IDX = 0
    a_tru = data_test[IDX]
    # Baseline (CANDyMan) result: last stored entry with its final row and
    # column dropped (presumably to match the layout of data_test -- confirm).
    with open('./res/ksb.pkl', 'rb') as fh:
        a_can = pickle.load(fh)[-1][:-1,:-1]
    with open(f'./res/ks_bea_{Ntrain}.dat', 'rb') as fh:
        a_krr = pickle.load(fh)[IDX]

    xdx = [25, 50]   # spatial grid indices at which the time series are plotted

    f, ax = plt.subplots(ncols=2, figsize=(10, 3))
    for _i in range(2):
        ax[_i].plot(t_sim, a_tru[:,xdx[_i]], '--', label=lbls[0], color='0.5', lw=3)
        ax[_i].plot(t_sim, a_can[:,xdx[_i]], '-', label=lbls[1], lw=1.5)
        ax[_i].plot(t_sim, a_krr[:,xdx[_i]], '-', label=lbls[2], lw=1.5)

        # NOTE(review): normalizes x by 2*pi, i.e. assumes domain length L = 2*pi -- confirm
        ax[_i].set_title(f'x/L={xx[xdx[_i]]/(2*np.pi):3.2f}', fontsize=10)
        ax[_i].grid()
        # Zoom into the window t in [4, 5].
        ax[_i].set_xlim([4, 5])
        ax[_i].set_xlabel('$t$')
        ax[_i].set_xticks([4, 4.5, 5])
    ax[0].set_ylabel('$u$')
    ax[0].legend(loc="lower left", ncol=1, fontsize=10, fancybox=True, framealpha=0.5)

    f.savefig('pics/ks_bea_trj.png', dpi=600, bbox_inches='tight')

# Display every figure created by the enabled stages (blocks until windows close
# on interactive backends).
plt.show()
