"""
Dynamics on distorted S1

GMKRR training and prediction
Plotting of results - Run other baselines first
"""
import copy
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold, ManifoldAnalytical
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynRegMan
from geometric_multivariate_kernel_ridge_regression.src.utils import make_se, RMSE

s5 = np.sqrt(5)


def make_dyn(K, D):
    """Build the trajectory and unit-tangent functions of a distorted circle.

    The curve is given in polar form by r(theta) = 1 + D*cos(K*theta).
    The angle follows \\dot{\\theta} = 3/2 - \\cos(\\theta), evaluated through
    its closed-form solution that starts from theta = 0 at t = 0.

    Parameters
    ----------
    K : number
        Multiplier of the angle inside the radial distortion.
    D : number
        Amplitude of the radial distortion (D = 0 gives the unit circle).

    Returns
    -------
    dyn : callable
        ``dyn(tt)`` returns ``(theta, points)``: the angle at times ``tt``
        and the matching 2-D points on the curve.
    vec : callable
        ``vec(x)`` returns one unit tangent per row of ``x``, evaluated at
        the polar angle of that row.
    """
    def dyn(tt, K=K, D=D):
        # Closed-form angle, then polar -> Cartesian on the distorted circle.
        theta = 2 * np.arctan(np.tan(s5*tt/4)/s5)
        radius = 1 + D*np.cos(K*theta)
        points = np.array([radius*np.cos(theta), radius*np.sin(theta)]).T
        return theta, points

    def vec(x, K=K, D=D):
        # Accept a single point as well as an (N, 2) array of points.
        pts = np.atleast_2d(x)
        theta = np.arctan2(pts[:,1], pts[:,0])
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        radius = 1 + D*np.cos(K*theta)
        dradius = -K*D*np.sin(K*theta)    # dr/dtheta
        # d/dtheta of (r*cos, r*sin), normalized row by row below.
        tangent = np.vstack([
            dradius*cos_t - radius*sin_t,
            dradius*sin_t + radius*cos_t]).T
        tangent /= np.linalg.norm(tangent, axis=1, keepdims=True)
        return tangent

    return dyn, vec

def make_opt(g, T, ifana=False):
    """Assemble the option dictionary handed to the regressor.

    Parameters
    ----------
    g, T
        Values stored under the 'g' and 'T' keys of the manifold options.
        'T' is only stored when ``ifana`` is false.
    ifana : bool
        If true, use ``ManifoldAnalytical``, which assumes perfect tangent
        vectors. Otherwise use ``Manifold``, the more practical case where
        tangent vectors are estimated.

    Returns
    -------
    dict
        Options with keys 'man', 'manopt', 'ker', 'nug' and 'ifvec'.
    """
    # Entries common to both manifold variants.
    manopt = {'d' : 1, 'g' : g}
    if ifana:
        man = ManifoldAnalytical
    else:
        man = Manifold
        # Extra settings only used when tangent vectors are estimated.
        manopt['T'] = T
        manopt['iforit'] = True

    return {
        'man' : man,
        'manopt' : manopt,
        'ker' : make_se(1.0),
        'nug' : 1e-6,
        'ifvec' : True,
    }

# Time grid: Nt samples on [0, T] with uniform step dt
T = 128
Nt = 3201
tt = np.linspace(0, T, Nt)
dt = tt[1]-tt[0]

# The first Ntrain samples of a trajectory are used for training;
# predictions are made over Ltest steps (the whole time grid)
Ntrain = (Nt-1)//16
Ltest = Nt

# Time stamps used for simulation and as the x-axis of the plots
t_sim = np.arange(Ltest)*dt
t_plt = t_sim

# Base font size for the figures
FS = 16

# Section switches: set a flag to 1 to execute the matching section below

# Explore
ifdat = 0    # Visualize the manifold and vector field
ifval = 0    # Cross-validation procedure

# Production
ifrun = 0    # Run through chosen parameters of D

# Visualization
# Run the NODE and LDNet results first, or comment out the corresponding lines first
iftrj = 0    # Visualize the trajectories in the ambient space, Fig. 1(a)
ifplt = 0    # Plot and compare the prediction, Fig. 1(b) and Fig. 5

if ifdat:
    # Distorted circle with K=3; this is for D=0.3
    fdyn, fvec = make_dyn(3, 0.3)
    vv, uu = fdyn(tt)
    data_train, data_test = uu[:Ntrain], uu

    # Training points, with the tangent at every 10th point drawn as a
    # short segment of length 0.3
    f = plt.figure()
    plt.plot(data_train[:,0], data_train[:,1], 'b.')
    for _base in data_train[::10]:
        _tip = _base + 0.3*fvec(_base).squeeze()
        plt.plot([_base[0], _tip[0]], [_base[1], _tip[1]], 'k-')
    plt.gca().set_aspect('equal', adjustable='box')

    # Angle as a function of time
    f = plt.figure()
    plt.plot(t_plt, vv)

    # Both ambient coordinates as functions of time
    f = plt.figure()
    for _col in range(2):
        plt.plot(t_plt, data_test[:,_col], 'k--')

if ifval:
    # Grid search over the two manifold options (g, T) for one distortion
    # level. Uncomment exactly one of the make_dyn lines to choose D.
    # fdyn, fvec = make_dyn(3, 0)
    fdyn, fvec = make_dyn(3, 0.1)
    # fdyn, fvec = make_dyn(3, 0.3)
    # fdyn, fvec = make_dyn(3, 0.5)
    vv, uu = fdyn(tt)
    data_train = uu[:Ntrain]
    # Validate on a window twice as long as the training window
    data_valid = uu[:Ntrain*2]
    t_val = np.arange(len(data_valid))*dt

    # Candidate values handed to make_opt(g, T)
    g_vals = list(range(3, 8))
    T_vals = list(range(3, 8))

    # errs[0]: default solver, errs[1]: RK2 solver; axes 1 and 2 index
    # g_vals and T_vals respectively
    errs = np.zeros((2, len(g_vals), len(T_vals)))
    for _ig, _g in enumerate(g_vals):
        print(_ig)    # progress indicator
        for _iT, _T in enumerate(T_vals):
            opt = make_opt(_g, _T)
            dloc = DynRegMan(M2KRR, opt, fd='1', dt=dt)
            dloc.fit([data_train])

            a_krr = dloc.solve(data_valid[0], t_val)
            a_rk2 = dloc.solve(data_valid[0], t_val, alg='RK2')
            # Entry 1 of the RMSE return value is used as the score
            # NOTE(review): confirm which error measure that entry holds
            errs[0,_ig,_iT] = RMSE(data_valid, a_krr)[1]
            errs[1,_ig,_iT] = RMSE(data_valid, a_rk2)[1]

    print(errs[0])
    print(errs[1])

    # np.argmin on a 2-D grid returns a flat index, so decode it into the
    # (g, T) pair that produced the smallest error.
    for _name, _err in zip(('default', 'RK2'), errs):
        _flat = np.argmin(_err)
        _ig, _iT = np.unravel_index(_flat, _err.shape)
        print(_flat)
        print(f'{_name}: best g={g_vals[_ig]}, T={T_vals[_iT]}, err={_err[_ig, _iT]}')

# Each case pairs the make_dyn arguments (K, D) with the make_opt
# arguments (g, T) used for that distortion level
cases = [
    [(3,0.0), (6,3)],
    [(3,0.1), (7,5)],
    [(3,0.3), (6,5)],
    [(3,0.5), (5,4)]
]
if ifrun:
    # For every case: fit on the first Ntrain samples, then predict the
    # whole trajectory with the default solver and with RK2.
    res = []
    for _dynarg, _optarg in cases:
        fdyn, fvec = make_dyn(*_dynarg)
        vv, uu = fdyn(tt)
        data_train = uu[:Ntrain]
        data_test  = uu
        a0_test = uu[0]

        reg4 = make_opt(*_optarg)
        reg4['manopt']['iforit'] = True

        # perf_counter is monotonic, so the differences below are reliable
        # elapsed times for fit / default solve / RK2 solve.
        t1 = time.perf_counter()
        dloc = DynRegMan(M2KRR, reg4, fd='1', dt=dt)
        dloc.fit([data_train])
        t2 = time.perf_counter()
        a_krr = dloc.solve(a0_test, t_sim)
        t3 = time.perf_counter()
        a_rk2 = dloc.solve(a0_test, t_sim, alg='RK2')
        t4 = time.perf_counter()
        print(t2-t1, t3-t2, t4-t3)

        # Stored per case as [truth, default prediction, RK2 prediction]
        res.append([data_test, a_krr, a_rk2])
    np.save('./res/s1_gmk.npy', np.array(res))

if iftrj:
    # Scatter the first 150 points of each stored trajectory in the
    # ambient plane, one colour per case.
    # NOTE(review): './res/s1.pkl' is not written by this script; presumably
    # it comes from one of the baseline scripts - confirm.
    # NOTE: unpickling can execute arbitrary code; only load files produced
    # by trusted runs.
    with open('./res/s1.pkl', 'rb') as _fh:
        res = pickle.load(_fh)[:,0]
    clr = ['b', 'g', 'r', 'k']

    f = plt.figure(figsize=(6,6))
    for _color, _case, _traj in zip(clr, cases, res):
        _u = _traj[:150].T
        plt.plot(*_u, _color+'.', label=f'D={_case[0][1]}')
    ax = plt.gca()
    ax.set_aspect('equal', adjustable='box')
    ax.tick_params(axis='both', which='major', labelsize=FS)
    plt.legend(loc=1, fontsize=FS-2)
    plt.xlabel('$x_1$', fontsize=FS)
    plt.ylabel('$x_2$', fontsize=FS)
    f.savefig('./pics/s1_trj.png', dpi=600, bbox_inches='tight')

if ifplt:
    # Results of this script plus the NODE and LDNet baseline results
    tmp1 = np.load('./res/s1_gmk.npy')
    tmp2 = np.array([np.load(f'./res/s1_nde_{_i}.npy') for _i in [0,1,3,5]])
    tmp3 = np.array([np.load(f'./res/s1_ldn_{_i}.npy') for _i in [0,1,3,5]])

    # (line format, legend label) in drawing order
    _styles = [
        ('b-', 'Euler+NC'),
        ('r-', 'RK2'),
        ('g-', 'NODE'),
        ('m-', 'LDNet'),
        ('k--', 'Truth'),
    ]

    # Full time span, one row per case
    f, ax = plt.subplots(nrows=4, sharex=True, figsize=(10,6))
    for _i, _axis in enumerate(ax):
        a_tru, a_krr, a_rk2 = tmp1[_i]
        a_nde, a_ldn = tmp2[_i], tmp3[_i]

        _curves = (a_krr, a_rk2, a_nde, a_ldn, a_tru)
        for (_fmt, _lbl), _a in zip(_styles, _curves):
            _axis.plot(t_plt, _a[:,0], _fmt, label=_lbl)
        _axis.set_ylabel('$x_1$', fontsize=FS)
        _axis.set_title(f'D={cases[_i][0][1]}', fontsize=FS)
        _axis.tick_params(axis='both', which='major', labelsize=FS)
    ax[2].legend(loc=1, bbox_to_anchor=(1.1,1.5), fontsize=FS-4)
    ax[-1].set_xlabel('t', fontsize=FS)
    plt.subplots_adjust(hspace=0.4)
    f.savefig('./pics/s1_prd_fulltime.png', dpi=600, bbox_inches='tight')

    # Last STP time steps only, for cases 1 and 3
    STP = 800
    _tail = slice(-STP, None)
    f, ax = plt.subplots(nrows=2, sharex=True, figsize=(6,4))
    for _axis, _k in zip(ax, [1, 3]):
        a_tru, a_krr, a_rk2 = tmp1[_k]
        a_nde, a_ldn = tmp2[_k], tmp3[_k]

        _curves = (a_krr, a_rk2, a_nde, a_ldn, a_tru)
        for (_fmt, _lbl), _a in zip(_styles, _curves):
            _axis.plot(t_plt[_tail], _a[_tail,0], _fmt, label=_lbl)
        _axis.set_ylabel('$x_1$', fontsize=FS)
        _axis.set_title(f'D={cases[_k][0][1]}', fontsize=FS)
        _axis.tick_params(axis='both', which='major', labelsize=FS)
    ax[1].legend(loc=1, bbox_to_anchor=(1.0,1.7), fontsize=FS-4)
    ax[-1].set_xlabel('t', fontsize=FS)
    plt.subplots_adjust(hspace=0.4)
    f.savefig('./pics/s1_prd.png', dpi=600, bbox_inches='tight')

# Display every figure created by the sections enabled above
plt.show()
