"""
Dynamics on distorted S1, lateral comparisons
"""
import copy
import pickle
import matplotlib.pyplot as plt
import numpy as np

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynRegMan, DynRegWrap
from geometric_multivariate_kernel_ridge_regression.src.utils import make_exp, make_m32, make_m52, make_se, RMSE

s5 = np.sqrt(5)


def _polar_tangent(x, K, D):
    """Return cos(theta) and dX/dtheta on the curve r(theta) = 1 + D*cos(K*theta).

    Only the polar angle of each row of `x` is used; the point is implicitly
    projected radially onto the curve.

    Parameters
    ----------
    x : array_like
        A single point (x1, x2) or an array with one point per row.
    K, D : float
        Frequency and amplitude of the radial distortion.

    Returns
    -------
    cos_t : ndarray, shape (N,)
        Cosine of the polar angle of every input point.
    tangent : ndarray, shape (2, N)
        Un-normalised tangent d(r*cos, r*sin)/dtheta, one column per point.
    """
    pts = np.atleast_2d(x)
    theta = np.arctan2(pts[:, 1], pts[:, 0])
    radius = 1 + D*np.cos(K*theta)
    dradius = -K*D*np.sin(K*theta)      # dr/dtheta
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    tangent = np.vstack([
        -radius*sin_t + dradius*cos_t,
        radius*cos_t + dradius*sin_t])
    return cos_t, tangent


def make_dyn(K, D, ret_fun=False):
    """Create the callables describing the dynamics on the distorted circle.

    The curve is r(theta) = 1 + D*cos(K*theta) and the angle evolves as
    \\dot{\\theta} = 3/2 - \\cos(\\theta).

    Parameters
    ----------
    K, D : float
        Frequency and amplitude of the radial distortion.
    ret_fun : bool, optional
        If True, additionally return the ambient-space vector field `fun`.

    Returns
    -------
    dyn : callable
        ``dyn(tt) -> (vv, uu)``: angle and Cartesian position at times `tt`
        (closed-form solution of the angular ODE).
    vec : callable
        ``vec(x) -> (N, 2)`` array of unit tangent vectors at the points `x`.
    fun : callable, only if `ret_fun`
        ``fun(t, x) -> (N, 2)`` array of velocities dx/dt; `t` is unused.
    """
    def dyn(tt, K=K, D=D):
        # Closed-form solution of \dot{\theta} = 3/2 - \cos(\theta)
        vv = 2 * np.arctan(np.tan(s5*tt/4)/s5)
        rr = 1 + D*np.cos(K*vv)
        uu = np.array([
            rr*np.cos(vv),
            rr*np.sin(vv)]).T
        return vv, uu

    def vec(x, K=K, D=D):
        _, tangent = _polar_tangent(x, K, D)
        unit = tangent.T
        unit /= np.linalg.norm(unit, axis=1).reshape(-1, 1)
        return unit

    if not ret_fun:
        return dyn, vec

    def fun(t, x, K=K, D=D):
        # Chain rule: dx/dt = dx/dtheta * dtheta/dt, dtheta/dt = 3/2 - cos(theta)
        cos_t, tangent = _polar_tangent(x, K, D)
        return (tangent * (1.5 - cos_t)).T

    return dyn, vec, fun

def make_opt(g, T, make_ker, l):
    """Assemble the regression option dictionary for one kernel choice.

    `g` and `T` are stored in the nested manifold options, and `make_ker(l)`
    is evaluated once to produce the kernel entry. The remaining entries
    (`d`, `iforit`, `nug`, `ifvec`) are fixed for every experiment in this
    script; their meaning is defined by the consumer of the dictionary.
    """
    return {
        'man': Manifold,
        'manopt': {'d': 1, 'g': g, 'T': T, 'iforit': True},
        'ker': make_ker(l),
        'nug': 1e-6,
        'ifvec': True,
    }

# Time grid of the reference trajectory: Nt uniformly spaced samples on [0, T]
T = 128
Nt = 3201
tt = np.linspace(0, T, Nt)
dt = tt[1]-tt[0]    # sampling step of the grid above

# The first Ntrain samples of a trajectory serve as training data, while
# predictions are rolled out over Ltest steps.
# NOTE: Ntrain is reassigned further down, before the convergence study.
Ntrain = (Nt-1)//16
Ltest = Nt

# Time stamps passed to the solvers (t_sim) and used as plot abscissa (t_plt)
t_sim = np.arange(Ltest)*dt
t_plt = t_sim

FS = 16    # base font size; the figures use FS-2 and FS-4

# Switches selecting which of the sections below is executed (0 = skip)
ifval = 0    # Cross-validation procedure for different kernels
ifkrr = 0    # Prediction using different kernels, Fig. 6
iftan = 0    # Convergence study, and ...
ifcnv = 0    # ... the plots, Fig. 7

# ----------------------------------
if ifval:
    # Random-search validation of (g, T, l) for one kernel family: fit on the
    # first Ntrain samples, roll out from the first sample over 2*Ntrain
    # steps and keep the draw with the smallest error.
    fdyn, fvec = make_dyn(3, 0.5)
    vv, uu = fdyn(tt)
    data_train = uu[:Ntrain]
    data_valid = uu[:Ntrain*2]
    t_val = np.arange(Ntrain*2)*dt

    # Since there are 3 hyperparameters, we use a random search instead
    best = np.inf
    # Bound up front: if no draw ever improves on `best` (e.g. every rollout
    # returns NaN/inf) the final print would otherwise raise a NameError.
    para = None
    for _ in range(100):
        _g = np.random.randint(3, 8)            # integer in [3, 7]
        _T = np.random.randint(3, 8)            # integer in [3, 7]
        _l = 4**(np.random.rand(1)*2 - 1)       # log-uniform in [1/4, 4), shape (1,)
        opt = make_opt(_g, _T, make_m32, _l)  # Change the kernel here
        dloc = DynRegMan(M2KRR, opt, fd='1', dt=dt)
        dloc.fit([data_train])
        a_krr = dloc.solve(data_valid[0], t_val)
        e_krr = RMSE(data_valid, a_krr)[1]
        if e_krr < best:
            best = e_krr
            para = [_g, _T, _l]
    print(best, para)

if ifkrr:
    # Compare four kernels: each model is fitted on the first Ntrain samples,
    # rolled out from the initial point over the full time grid, and the last
    # 4*Ntrain steps are plotted against the reference trajectory.
    fdyn, fvec = make_dyn(3, 0.5)
    vv, uu = fdyn(tt)
    data_train, data_test, a0_test = uu[:Ntrain], uu, uu[0]

    # One option set (g, T, kernel factory, length scale) per kernel label
    ttls = ['SE', 'M52', 'M32', 'M12']
    regs = [
        make_opt(5, 4, make_se,  1.90),
        make_opt(3, 4, make_m52, 0.28),
        make_opt(3, 4, make_m32, 0.56),
        make_opt(3, 5, make_exp, 0.65),
    ]

    tail = slice(-4*Ntrain, None)    # plotting window: the final 4*Ntrain steps
    errs = []
    f, ax = plt.subplots(nrows=2, sharex=True)
    for ttl, reg in zip(ttls, regs):
        dloc = DynRegMan(M2KRR, reg, fd='1', dt=dt)
        dloc.fit([data_train])
        a_krr = dloc.solve(a0_test, t_sim)
        for comp, axis in enumerate(ax):
            axis.plot(t_plt[tail], a_krr[tail, comp], '-', label=ttl)
        errs.append(RMSE(data_test, a_krr)[1])
    for comp, axis in enumerate(ax):
        axis.plot(t_plt[tail], data_test[tail, comp], 'k--', label='Truth')
    print(errs)

    # Axis cosmetics shared by both panels
    for axis, ylab in zip(ax, ('$x_1$', '$x_2$')):
        axis.set_ylabel(ylab, fontsize=FS-2)
        axis.tick_params(axis='both', which='major', labelsize=FS-2)
    ax[1].legend(fontsize=FS-4, ncol=3, loc="lower left", bbox_to_anchor=(0.3,0.9))
    ax[1].set_xlabel('t', fontsize=FS-2)
    f.savefig('./pics/s1_dyn_matern.png', dpi=600, bbox_inches='tight')

# ----------------------------------
# Settings shared by the convergence study and its plot:
# point clouds have idcs[k]*Ntrain samples, and Ts lists the values passed
# as both `g` and `T` to Manifold.
Ntrain = 40
Ts = [2, 3]
idcs = [8, 16, 32, 64, 128]
if iftan:
    # Convergence study with the analytical vector field: only the manifold
    # is estimated from samples, then short rollouts are compared to the truth.
    fdyn, fvec, ffun = make_dyn(3, 0.5, ret_fun=True)

    # 100 short reference trajectories (101 steps of size DT) started at
    # random times within [0, 4*pi/sqrt(5))
    DT = 1e-6
    TT = np.arange(101)*DT
    TR = np.random.rand(100)*4*np.pi/s5
    data_test = [fdyn(TT+_t)[1] for _t in TR]
    a0_test = [d[0] for d in data_test]
    t_sim = TT

    def _sq_errors(order, n_pts):
        """Squared rollout errors of all test trajectories for one setting."""
        # NOTE(review): linspace includes both 0 and 2*pi, so the first and
        # last samples coincide on the closed curve - confirm this is intended.
        theta = np.linspace(0, 2*np.pi, n_pts)
        radius = 1 + 0.5*np.cos(3*theta)
        cloud = np.array([
            radius*np.cos(theta),
            radius*np.sin(theta)]).T

        # Dynamics known analytically
        fman = Manifold(cloud, d=1, g=order, T=order, iforit=True)
        pred = DynRegWrap(ffun, fman, dt=DT)
        return [
            np.linalg.norm(ref - pred.solve(start, t_sim))**2
            for start, ref in zip(a0_test, data_test)]

    # Resulting array is indexed as [order in Ts, size in idcs, test trajectory]
    errs = np.array([
        [_sq_errors(_T, _i*Ntrain) for _i in idcs]
        for _T in Ts])
    np.save('./res/s1t.npy', errs)

if ifcnv:
    # Plot the RMSE stored by the convergence study against the number of
    # samples N, together with a (log(N)/N)^T reference slope per curve.
    errs = np.load('./res/s1t.npy')

    cs = ['k', 'r']
    Ns = Ntrain*np.array(idcs)
    f  = plt.figure()
    for _k, _T in enumerate(Ts):
        # Root of the mean squared error over the test trajectories
        val = np.sqrt(np.mean(errs[_k], axis=1))
        tmp = (np.log(Ns)/(Ns))**_T
        # Raw f-string: '\e' is not a valid escape sequence in a normal literal
        plt.loglog(Ns, val, 'o-', color=cs[_k], markerfacecolor='none', label=rf'$\ell={_T}$')
        # Reference slope, scaled to pass through the second data point
        plt.loglog(Ns, tmp/tmp[1]*val[1], '--', color=cs[_k], label=f'(log(N)/N)^{_T}')
    handles, labels = plt.gca().get_legend_handles_labels()
    # Group the legend: measured curves first, then the reference slopes
    order = [0, 2, 1, 3]
    plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order], ncol=2, fontsize=FS-4)
    plt.xlabel('N', fontsize=FS-2)
    plt.ylabel('RMSE', fontsize=FS-2)
    plt.xticks(ticks=Ns, labels=[str(int(_)) for _ in Ns], fontsize=FS-2)
    plt.yticks(fontsize=FS-2)
    f.savefig('./pics/s1_cnv.png', dpi=600, bbox_inches='tight')

plt.show()
