"""
KS beating-travelling example.

Trains and runs prediction with full-dimensional GMKRR, then plots the results.
Run the other baselines first: their saved predictions under ./res are loaded
here for comparison.
"""
import pickle
import matplotlib.pyplot as plt
import numpy as np
import time

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynRegMan
from geometric_multivariate_kernel_ridge_regression.src.utils import make_se

# Load the dataset and close the file as soon as it has been read.
with open('data/ksdata_travelling.pkl', 'rb') as fh:
    data = pickle.load(fh)
dt = data['dt']
nu = data['nu']
xx = data['x']      # spatial grid
tt = data['t']      # time grid
uu = data['udata']  # snapshots: one row per time instant (checked below)

Nx, Nt = len(xx), len(tt)
# Explicit check instead of `assert`, so the validation also runs under `python -O`.
if uu.shape != (Nt, Nx):
    raise ValueError(
        f"udata has shape {uu.shape}; expected (len(t), len(x)) = {(Nt, Nx)}"
    )

# Snapshot counts on the raw time grid; every other snapshot is used below.
Ntrain = 6000
Ntest  = 14000

t_sim = tt[:Ntest:2]      # prediction time grid (stride 2)
t_plt = t_sim - t_sim[0]  # same grid shifted to start at t = 0, for plotting

# Training data: stride-2 subsampling, hence the doubled time step.
data_train = uu[:Ntrain:2]
dt_train = dt*2

# Test data: one reference trajectory and its initial condition.
data_test = [uu[:Ntest:2]]
a0_test   = [data_test[0][0]]

# Legend labels. Entry 0 is the reference; entries 1-5 are index-aligned with
# the model predictions compared in the plotting sections.
lbls = ['Truth', 'GMKRR - Full', 'GMKRR - FFT', 'CANDyMan', 'NODE', 'LDNet']

# Explore
ifdim = False   # Estimate intrinsic dimension; also obtaining the reference bandwidth
ifman = True    # Visualize manifold structure, Fig. 4(f)-(g)

# Training
ifkrr = True    # Train the model

# Plots
ifrms = True    # Computation of reported error metrics, Table 1
iftim = True    # Comparison of solutions at a time station, Fig. 4(b)-(c), Fig. 8
iferr = True    # Time evolution of trajectory errors, Fig. 4(j)-(k)

if ifdim:
    # Intrinsic-dimension study on the full snapshot set. Per the flag comment,
    # this is also where the reference kernel bandwidth was obtained.
    man = Manifold(uu, d=2)
    # NOTE(review): `bracket` is presumably the search interval of the estimator
    # (e.g. for a log-bandwidth) — confirm against Manifold.estimate_intrinsic_dim.
    dim, (f, ax) = man.estimate_intrinsic_dim(bracket=[-40, 5], tol=0.2, ifplt=True)
    # Report the estimate; previously it was computed and then discarded.
    print(f"Estimated intrinsic dimension: {dim}")
    _, _, f = man.visualize_intrinsic_dim()

if ifman:
    from sklearn.manifold import Isomap

    # Embed snapshots into 3 Isomap coordinates purely for visualisation.
    Nm = 1000
    K = int(np.sqrt(Nm))  # neighbour count: sqrt of the short-segment length
    isom = Isomap(n_neighbors=K, n_components=3)

    # Two views, both refitted with the same Isomap settings:
    #   - the first Nm snapshots, drawn as a connected curve;
    #   - every 4th snapshot of the full record, drawn as points coloured by
    #     the third embedding coordinate.
    for dat, as_curve in ((uu[:Nm], True), (uu[::4], False)):
        X = isom.fit_transform(dat).T
        f = plt.figure()
        ax = plt.axes(projection='3d')
        if as_curve:
            ax.plot(*X, 'b-')
        else:
            # The marker must be passed by keyword: the 4th positional
            # parameter of Axes3D.scatter is `zdir`, not the marker style.
            ax.scatter(*X, marker='.', c=X[2])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
        f.savefig(f'pics/ks_tra_man_{len(dat)}.png', dpi=600, bbox_inches='tight')

# ------------------
# GMKRR
# ------------------
if ifkrr:
    # perf_counter is the monotonic high-resolution clock meant for intervals.
    t1 = time.perf_counter()
    order_fd = '1'  # finite-difference option forwarded to DynRegMan as `fd`
    manopt = {
        'd' : 3,
        'g' : 3
    }
    regloc = {
        'man' : Manifold,
        'manopt' : manopt,
        # NOTE(review): presumably the reference bandwidth from the `ifdim`
        # study — confirm before changing.
        'ker' : make_se(98.0816),
        'nug' : 1e-6,
        'ifvec' : True
    }
    dloc = DynRegMan(M2KRR, regloc, fd=order_fd, dt=dt_train)
    dloc.fit([data_train])

    t2 = time.perf_counter()
    # Roll the learned model forward from each test initial condition over
    # the test time grid. A different integrator can be requested, e.g.
    # dloc.solve(_a, t_sim, alg='RK2').
    sol = [dloc.solve(_a, t_sim) for _a in a0_test]
    t3 = time.perf_counter()
    print(f"GMKRR fit time: {t2 - t1} s, prediction time: {t3 - t2} s")

    # Close (and flush) the result file deterministically; later sections reload it.
    with open(f'./res/ks_tra_{Ntrain}.dat', 'wb') as fh:
        pickle.dump(sol, fh)

# ------------------
# Error
# ------------------
if ifrms:
    def _load_pickle(path):
        """Return the object pickled in *path*, closing the file afterwards."""
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    IDX = 0
    a_tru = data_test[IDX]
    # Baseline outputs are thinned with [:-1:2] (drop the final sample, keep
    # every other one) so they line up with the stride-2 test trajectory.
    # NOTE(review): the CANDyMan array also drops its last spatial column —
    # presumably a duplicated periodic endpoint; confirm.
    a_can = _load_pickle('./res/kst.pkl')[-1][:-1:2, :-1]
    a_nde = _load_pickle('./res/ks_trf_nde.pkl')[0][:-1:2]
    a_ldn = _load_pickle('./res/ks_trf_ldn.pkl')[0][:-1:2]
    a_ful = _load_pickle(f'./res/ks_tra_{Ntrain}.dat')[IDX]
    a_red = _load_pickle('./res/ks_trf_100.dat')[IDX][::2]

    # Label every line of output; the print order differs from `lbls`.
    preds = [
        (lbls[3], a_can),
        (lbls[4], a_nde),
        (lbls[5], a_ldn),
        (lbls[1], a_ful),
        (lbls[2], a_red),
    ]
    for name, pred in preds:
        err = a_tru - pred
        # RMSE over the spatial grid for each snapshot (averaged over time
        # when printed), plus the worst pointwise error.
        rmse = np.linalg.norm(err, axis=1) / np.sqrt(a_tru.shape[1])
        maxe = np.max(np.abs(err))
        print(f"{name}: mean RMSE = {np.mean(rmse)}, max |error| = {maxe}")

# ------------------
# Temporal response
# ------------------
if iftim:
    def _load_pickle(path):
        """Return the object pickled in *path*, closing the file afterwards."""
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    Nplt = 7000  # number of stride-2 snapshots shown (e.g. 100 for a short-time zoom)
    IDX = 0
    a_tru = data_test[IDX][:Nplt]
    # Same [:-1:2] thinning as in the error-metric and error-history sections.
    a_can = _load_pickle('./res/kst.pkl')[-1][:-1:2, :-1][:Nplt]
    a_nde = _load_pickle('./res/ks_trf_nde.pkl')[0][:-1:2][:Nplt]
    a_ldn = _load_pickle('./res/ks_trf_ldn.pkl')[0][:-1:2][:Nplt]
    a_ful = _load_pickle(f'./res/ks_tra_{Ntrain}.dat')[IDX][:Nplt]
    a_red = _load_pickle('./res/ks_trf_100.dat')[IDX][::2][:Nplt]
    # Transposed to (space, time) for contourf over the (T, X) mesh; order matches `lbls`.
    _dat = [a_tru.T, a_ful.T, a_red.T, a_can.T, a_nde.T, a_ldn.T]

    T, X = np.meshgrid(t_plt[:Nplt], xx)

    # 4x3 layout (flattened index):
    #   row 0: truth (0), GMKRR-Full (1), GMKRR-FFT (2)
    #   row 1: blank (3), their errors (4, 5)
    #   row 2: CANDyMan (6), NODE (7), LDNet (8)
    #   row 3: their errors (9, 10, 11)
    f, AX = plt.subplots(nrows=4, ncols=3, sharex=True, sharey=True, figsize=(12, 7))
    ax = AX.flatten()
    cs = ax[0].contourf(T, X, _dat[0])
    for _d in range(1, len(_dat)):
        _k = 3 * (_d // 3)  # shift models 3-5 down past the first error row
        ax[_d+_k].contourf(T, X, _dat[_d], levels=cs.levels)
        ax[_d+_k].set_title(lbls[_d])
        # Errors reuse the truth's levels so every panel shares one colour scale.
        ax[_d+_k+3].contourf(T, X, _dat[0]-_dat[_d], levels=cs.levels)
        _e = np.mean(np.sqrt(np.mean((_dat[0]-_dat[_d])**2, axis=0)))
        ax[_d+_k+3].set_title(f"RMSE {_e:4.3e}")
    for _i in range(4):
        plt.colorbar(cs, ax=ax[3*_i:3*(_i+1)])

    ax[0].set_title('Truth')
    for _d in range(9, 12):
        ax[_d].set_xlabel('$t$')
    for _d in [0, 3, 6, 9]:
        ax[_d].set_ylabel('$x$')
    ax[3].set_axis_off()

    f.savefig(f'pics/ks_tra_tim_cmp_{Nplt}.png', dpi=600, bbox_inches='tight')

# ------------------
# Temporal error of trajectories
# ------------------
if iferr:
    def _load_pickle(path):
        """Return the object pickled in *path*, closing the file afterwards."""
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    FS = 18  # base font size
    IDX = 0
    # Window length and subsampling of the plotted history; use
    # END, SKP = 7000, 80 for the full horizon (also selects the y-limits below).
    END, SKP = 100, 1
    a_tru = data_test[IDX][:END:SKP]
    a_can = _load_pickle('./res/kst.pkl')[-1][:-1:2, :-1][:END:SKP]
    a_nde = _load_pickle('./res/ks_trf_nde.pkl')[0][:-1:2][:END:SKP]
    a_ldn = _load_pickle('./res/ks_trf_ldn.pkl')[0][:-1:2][:END:SKP]
    a_ful = _load_pickle(f'./res/ks_tra_{Ntrain}.dat')[IDX][:END:SKP]
    a_red = _load_pickle('./res/ks_trf_100.dat')[IDX][::2][:END:SKP]
    # Order matches lbls[1:].
    dat = [a_ful, a_red, a_can, a_nde, a_ldn]

    f = plt.figure()
    for _i, (pred, name) in enumerate(zip(dat, lbls[1:])):
        # Per-snapshot RMSE over the spatial grid.
        err = np.linalg.norm(a_tru-pred, axis=1) / np.sqrt(a_tru.shape[1])
        sty = '--' if _i > 2 else '-'  # NODE and LDNet drawn dashed
        plt.semilogy(t_plt[:END:SKP], err, sty, label=name)
    plt.legend(loc=4, fontsize=FS-4, ncol=2)
    plt.xlabel('$t$', fontsize=FS)
    plt.ylabel('RMSE', fontsize=FS)
    ax = plt.gca()
    ax.tick_params(axis='both', which='major', labelsize=FS)
    if END == 7000:
        ax.set_ylim([1e-2, 10])
    else:
        ax.set_ylim([1e-4, 1.0])

    f.savefig(f'pics/ks_tra_err_{END}.png', dpi=600, bbox_inches='tight')

plt.show()
