"""
Various unit and regression tests for the package functionalities
"""

import matplotlib.pyplot as plt
import numpy as np
import scipy.integrate as spi

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold, ManifoldAnalytical
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import MKRR, M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynReg, DynRegMan, fd2
from geometric_multivariate_kernel_ridge_regression.src.utils import tangent_1circle, tangent_2torus

# Default radii of the sampled torus: `a` is the tube (minor) radius and
# `b` is the distance from the torus centre to the tube centre (major radius).
a, b = 1.0, 2.0


def torus_sample(Nsmp, r_minor=None, r_major=None):
    """Draw random points lying on a torus embedded in 3-D.

    Two angles per sample are drawn with ``np.random.rand`` (uniform on
    [0, 1)) and scaled by 2*pi, so the samples are uniform in the two
    angles, not in surface area.

    Parameters
    ----------
    Nsmp : int
        Number of points to draw.
    r_minor : float, optional
        Tube radius. Defaults to the module-level ``a``.
    r_major : float, optional
        Major radius. Defaults to the module-level ``b``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(Nsmp, 3)`` holding the x, y, z coordinates.
    """
    tube = a if r_minor is None else r_minor
    ring = b if r_major is None else r_major

    # A single (2, Nsmp) draw keeps the random stream identical regardless
    # of which radii are used.
    phi, theta = np.random.rand(2, Nsmp) * 2 * np.pi

    # Distance of each point from the z-axis.
    rho = tube * np.cos(phi) + ring
    x = rho * np.cos(theta)
    y = rho * np.sin(theta)
    z = tube * np.sin(phi)
    return np.vstack([x, y, z]).T

# Switches for the test sections below: a truthy value runs the section,
# a falsy value (e.g. 0) skips it.

# Estimation of intrinsic dimensions
ifdim = 1
# Estimation of normal vectors
ifman = 1
# Estimation of tangent vectors
iftan = 1
# Orientation of tangent vectors
ifort = 1
# Analytical model for unit circle
ifas1 = 1
# Analytical model for 2-torus
ifat2 = 1
# Generic GMLS fit
iffit = 1
# Generic MKRR fit
ifkrr = 1
# Simple test of GMKRR and time integrator
ifdyn = 1

if ifdim:
    # Intrinsic-dimension estimation on a 5000-point torus sample.
    dat = torus_sample(5000)
    man = Manifold(dat, d=2)

    # With ifplt=True the call returns the estimate together with a
    # (figure, axes) pair.
    est = man.estimate_intrinsic_dim(bracket=[-20, 5], tol=0.2, ifplt=True)
    dim, (f, ax) = est

    # Only the third item of the visualisation result is kept.
    _, _, f = man.visualize_intrinsic_dim()

if ifman:
    # ---- Case 1: 2-torus point cloud in 3-D --------------------------
    dat = torus_sample(5000)
    man = Manifold(dat, d=2)
    man.precompute()
    f, ax = man.plot3d(10, scl=0.5)
    # Two white dots at opposite corners; presumably to fix the axis extents.
    ax.plot([-2,2], [-2,2], [-2,2], 'w.')

    # Step away from the first sample along its tangent basis and add the
    # estimated normal offset at every step.
    base, T = man._data[0], man._T[0]
    steps = np.linspace(0, 1, 11)
    coef = np.vstack([steps, steps]).T      # same step size on both tangent directions
    x1 = base + coef.dot(T)
    xs = x1 + man._estimate_normal(base, x1)
    ax.plot(xs[:,0], xs[:,1], xs[:,2], 'r-')

    # ---- Case 2: unit circle in 2-D ----------------------------------
    t = np.random.rand(200)*2*np.pi
    dat = np.vstack([np.cos(t), np.sin(t)]).T
    man = Manifold(dat, d=1)
    man.precompute()
    f, ax = man.plot2d(8, scl=0.5)
    ax.set_aspect('equal', adjustable='box')

    # Same construction as above with a single tangent direction.
    base, T = man._data[0], man._T[0]
    coef = np.linspace(0, 1, 11).reshape(-1,1)
    x1 = base + coef.dot(T)
    xs = x1 + man._estimate_normal(base, x1)
    ax.plot(xs[:,0], xs[:,1], 'r-')

if iftan:
    # Random samples on the unit circle.
    t = np.random.rand(200)*2*np.pi
    dat = np.vstack([np.cos(t), np.sin(t)]).T
    # On the unit circle the position vector is also the unit normal.
    nrm = dat

    def cmpNrm(tru, man):
        """Return |n . t| per sample for normals `tru` and the tangents in `man._T`.

        Zero means the estimated tangent is exactly orthogonal to the normal.
        Assumes `np.array(man._T).squeeze()` has the same shape as `tru`
        -- TODO confirm against Manifold.
        """
        return np.abs(np.sum(tru*np.array(man._T).squeeze(), axis=1))

    # Tangent estimate with T=0 (labelled 'Local PCA' in the plot below).
    m1 = Manifold(dat, d=1, T=0)
    m1.precompute()
    f, ax = m1.plot2d(8, scl=0.5)
    ax.set_aspect('equal', adjustable='box')

    # Tangent estimate with T=3 (labelled 'GMLS estimate' in the plot below).
    m2 = Manifold(dat, d=1, T=3)
    m2.precompute()
    f, ax = m2.plot2d(8, scl=0.5)
    ax.set_aspect('equal', adjustable='box')

    # Compare the orthogonality error of the two estimates on a log scale.
    e1 = cmpNrm(nrm, m1)
    e2 = cmpNrm(nrm, m2)
    f = plt.figure()
    plt.semilogy(e1, 'b-', label='Local PCA')
    plt.semilogy(e2, 'r-', label='GMLS estimate')
    plt.legend()
    plt.xlabel("Sample index")
    # Raw string: "\c" is not a valid escape sequence in a normal literal.
    plt.ylabel(r"$n\cdot t$")

if ifort:
    # Unit-circle samples at random angles.
    t = np.random.rand(200)*2*np.pi
    cos_t, sin_t = np.cos(t), np.sin(t)
    dat = np.vstack([cos_t, sin_t]).T

    # Tangent estimation with the orientation option switched on.
    m1 = Manifold(dat, d=1, T=3, iforit=True)
    m1.precompute()

    # Plot with 200 as the first argument (equal to the sample count).
    f, ax = m1.plot2d(200, scl=0.5)
    ax.set_aspect('equal', adjustable='box')

if ifas1:
    # Unit-circle samples at random angles.
    t = np.random.rand(200)*2*np.pi
    dat = np.vstack([np.cos(t), np.sin(t)]).T

    # Manifold whose tangents are supplied by the analytical circle helper.
    man = ManifoldAnalytical(dat, d=1, g=4, fT=tangent_1circle)
    man.precompute()
    f, ax = man.plot2d(10, scl=0.5)
    ax.set_aspect('equal', adjustable='box')

    # Walk from the first sample along its tangent (blue line), then add
    # the estimated normal offset at each step (red line).
    base, T = man._data[0], man._T[0]
    s = np.linspace(0, 1, 11).reshape(-1,1)
    dx = s.dot(T)
    x1 = base + dx
    xs = x1 + man._estimate_normal(base, x1)
    ax.plot(xs[:,0], xs[:,1], 'r-')
    ax.plot(x1[:,0], x1[:,1], 'b-')

if ifat2:
    dat = torus_sample(2000)

    def fT(x):
        """Analytical torus tangent at `x`, with R fixed to 2."""
        # NOTE(review): R=2 presumably mirrors the major radius b = 2.0 used
        # by torus_sample -- confirm against tangent_2torus.
        return tangent_2torus(x, R=2)

    # NOTE(review): d=1 is passed although the data lie on a 2-torus; verify
    # whether this is intentional.
    man = ManifoldAnalytical(dat, d=1, fT=fT)
    man.precompute()
    f, ax = man.plot3d(30, scl=0.5)
    # Two white dots at opposite corners; presumably to fix the axis extents.
    ax.plot([-2,2], [-2,2], [-2,2], 'w.')

if iffit:
    # Independent training and query point sets on the torus
    # (training set drawn first).
    dat, tar = torus_sample(500), torus_sample(500)

    # Scalar field to recover: the Euclidean norm of sin() of the coordinates.
    Ytrn = np.linalg.norm(np.sin(dat), axis=1)
    Ytst = np.linalg.norm(np.sin(tar), axis=1)

    man = Manifold(dat, d=2, g=4)
    man.precompute()

    # One gmls call per query point.
    res = [man.gmls(query, Ytrn) for query in tar]

    # Blue circles: reference values; red squares: gmls results.
    f = plt.figure()
    plt.plot(Ytst, 'bo', markerfacecolor='none')
    plt.plot(res,  'rs', markerfacecolor='none')

if ifkrr:
    def func(inp):
        """Two-output test function of a 2-D input (one row per sample)."""
        pts = np.atleast_2d(inp)
        u, v = pts[:,0], pts[:,1]
        y1 = u**2 + np.sin(v)
        y2 = np.exp(u*v)
        # squeeze() collapses the result to 1-D for a single input point.
        return np.vstack([y1, y2]).T.squeeze()

    def fker(x, y):
        """2x2 diagonal kernel with the same Gaussian value on both outputs."""
        k = np.exp(-np.sum((x-y)**2)/0.5)
        return np.diag([k, k])

    # 30 random training inputs in the unit square.
    Xtrn = np.random.rand(30,2)
    Ytrn = func(Xtrn)

    # Regular Ns x Ns evaluation grid on the unit square.
    Ns = 21
    s = np.linspace(0,1,Ns)
    S, T = np.meshgrid(s, s)
    Xtst = np.vstack([S.reshape(-1), T.reshape(-1)]).T
    Ytst = func(Xtst)

    mkrr = MKRR(fker)
    mkrr.fit(Xtrn, Ytrn)
    Yprd = mkrr.predict(Xtst)

    # Relative error of the two-component output at every grid point.
    num = np.linalg.norm(Ytst-Yprd, axis=1)
    den = np.linalg.norm(Ytst, axis=1)
    err = num / den
    f = plt.figure()
    plt.plot(err, 'bo', markerfacecolor='none')

    # One row per output: truth | prediction | absolute error.
    f, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(12,8))
    for _i in range(2):
        yt = Ytst[:,_i].reshape(Ns, Ns)
        yp = Yprd[:,_i].reshape(Ns, Ns)
        cs1 = ax[_i,0].contourf(S, T, yt)
        # The prediction reuses the truth levels, so it also shares the
        # truth colour bar.
        ax[_i,1].contourf(S, T, yp, levels=cs1.levels)
        cs2 = ax[_i,2].contourf(S, T, np.abs(yt-yp))
        for _j, _cs in enumerate((cs1, cs1, cs2)):
            plt.colorbar(_cs, ax=ax[_i,_j])
    # Overlay the training inputs on the lower-left panel.
    ax[1,0].plot(Xtrn[:,0], Xtrn[:,1], 'wo', markerfacecolor='none')

    for _i in range(2):
        row = ax[_i]
        row[0].set_ylabel('x2')
        row[0].set_title(f'y{_i+1} truth')
        row[1].set_title(f'y{_i+1} pred')
        row[2].set_title(f'y{_i+1} error')
    for bottom in ax[1]:
        bottom.set_xlabel('x1')

if ifdyn:
    def dyn(t):
        """Points on the unit circle at angle t + 0.5*sin(t), one row per time."""
        phase = t + 0.5*np.sin(t)
        return np.vstack([np.cos(phase), np.sin(phase)]).T

    def fker_loc(x, y):
        """Scalar Gaussian kernel."""
        return np.exp(-np.sum((x-y)**2)/0.1)

    def fker_glo(x, y):
        """2x2 diagonal kernel carrying the same Gaussian value on both outputs."""
        k = np.exp(-np.sum((x-y)**2)/0.1)
        return np.diag([k, k])

    # Training trajectory on [0, 10] and test trajectory on [10, 20].
    Nt = 201
    ttrn = np.linspace(0, 10, Nt)
    ttst = np.linspace(10, 20, Nt)
    dt = ttrn[1]-ttrn[0]
    Xtrn = dyn(ttrn)
    Xtst = dyn(ttst)
    # Both grids span a length of 10 with Nt points, so the training step
    # is reused for the test trajectory.
    Ytrn = fd2(Xtrn, dt)
    Ytst = fd2(Xtst, dt)

    order_fd = '1'

    # f, ax = plt.subplots(ncols=3, nrows=1, figsize=(12,4))
    # ax[0].plot(Xtrn[:,0], Xtrn[:,1], 'b.')
    # ax[0].plot(Xtst[:,0], Xtst[:,1], 'r.')
    # ax[1].plot(ttrn, Xtrn[:,0], 'b-')
    # ax[1].plot(ttrn, Xtrn[:,1], 'r-')
    # ax[1].plot(ttst, Xtst[:,0], 'b--')
    # ax[1].plot(ttst, Xtst[:,1], 'r--')
    # ax[2].plot(ttrn, Ytrn[:,0], 'b-')
    # ax[2].plot(ttrn, Ytrn[:,1], 'r-')
    # ax[2].plot(ttst, Ytst[:,0], 'b--')
    # ax[2].plot(ttst, Ytst[:,1], 'r--')

    # Global version: MKRR with the matrix-valued kernel.
    regglo = dict(ker=fker_glo, nug=1e-6)
    dglo = DynReg(MKRR, regglo, fd=order_fd, dt=dt)
    dglo.fit([Xtrn])

    Yglo = dglo.predict(Xtst)
    Xglo = dglo.solve(Xtst[0], ttst)

    # Local version: M2KRR with the scalar kernel and a Manifold model.
    manopt = dict(d=1, g=2)
    regloc = dict(man=Manifold, manopt=manopt, ker=fker_loc, nug=1e-6)
    dloc = DynRegMan(M2KRR, regloc, fd=order_fd, dt=dt)
    dloc.fit([Xtrn])

    Yloc = dloc.predict(Xtst)
    Xloc = dloc.solve(Xtst[0], ttst)

    # Compare rate: solid = local, dotted = global, dashed black = reference;
    # blue is component 0, red is component 1.
    f = plt.figure()
    for _k, _c in enumerate(('b', 'r')):
        plt.plot(ttst, Yloc[:,_k], _c + '-')
        plt.plot(ttst, Yglo[:,_k], _c + ':')
        plt.plot(ttst, Ytst[:,_k], 'k--')

    # Compare prediction: trajectory in the plane (left) and each
    # component against time (right).
    f, ax = plt.subplots(ncols=2, figsize=(10,5))
    for _traj, _fmt in ((Xloc, 'b-'), (Xglo, 'b:'), (Xtst, 'k--')):
        ax[0].plot(_traj[:,0], _traj[:,1], _fmt)
    for _k, _c in enumerate(('b', 'r')):
        ax[1].plot(ttst, Xloc[:,_k], _c + '-')
        ax[1].plot(ttst, Xglo[:,_k], _c + ':')
        ax[1].plot(ttst, Xtst[:,_k], 'k--')

# Display every figure created by the enabled sections above.
plt.show()
