r'''
@package dynreg.py

Implementation of
(1) training on trajectory data, and
(2) time integrators of the learned model, including the proposed Euler+NC.

Learns the dynamics
    \dot{x} = f(x)
from trajectory data
    {x_1, x_2, ..., x_n}.
There can be multiple trajectories.
'''
import copy
import numpy as np
import scipy.integrate as spi
from tqdm import tqdm

# Finite difference schemes
def fd1(dat, dt):
    """
    First-order finite-difference estimate of the time derivative.

    Uses the forward difference (x[i+1] - x[i]) / dt at every sample except
    the last one, which uses the backward difference (x[-1] - x[-2]) / dt, so
    the output has one entry per input sample.

    Parameters
    ----------
    dat : array_like
        Samples stacked along axis 0 (time). Any trailing shape is allowed,
        including none (a 1-D scalar series).
    dt : float
        Constant sampling interval.

    Returns
    -------
    ndarray
        Same shape as ``dat``.

    Raises
    ------
    ValueError
        If fewer than 2 samples are given along axis 0.
    """
    dat = np.asarray(dat)
    if dat.ndim == 0 or dat.shape[0] < 2:
        raise ValueError("fd1 needs at least 2 samples along axis 0")
    diffs = dat[1:] - dat[:-1]
    # The backward difference at the last sample equals the last forward
    # difference; slicing with [-1:] keeps axis 0 so any trailing shape works.
    return np.concatenate([diffs, diffs[-1:]], axis=0) / dt
def fd2(dat, dt):
    """
    Second-order finite-difference estimate of the time derivative.

    Interior samples use central differences (x[i+1] - x[i-1]) / (2 dt).
    The first and last samples use the one-sided second-order formulas
    (-3 x[0] + 4 x[1] - x[2]) / (2 dt) and (x[-3] - 4 x[-2] + 3 x[-1]) / (2 dt).

    Parameters
    ----------
    dat : array_like
        Samples stacked along axis 0 (time). Any trailing shape is allowed,
        including none (a 1-D scalar series).
    dt : float
        Constant sampling interval.

    Returns
    -------
    ndarray
        Same shape as ``dat``.

    Raises
    ------
    ValueError
        If fewer than 3 samples are given along axis 0.
    """
    dat = np.asarray(dat)
    if dat.ndim == 0 or dat.shape[0] < 3:
        raise ValueError("fd2 needs at least 3 samples along axis 0")
    # Range slices (e.g. [0:1]) keep axis 0, so the pieces concatenate for
    # any trailing shape, including 1-D input.
    first = -3*dat[0:1] + 4*dat[1:2] - dat[2:3]
    inner = dat[2:] - dat[:-2]
    last = dat[-3:-2] - 4*dat[-2:-1] + 3*dat[-1:]
    return np.concatenate([first, inner, last], axis=0) / (2*dt)
# Lookup table from the `fd` option of DynReg to the finite-difference routine:
# '1' selects fd1 (first-order), '2' selects fd2 (second-order).
fdMap = dict([
    ('1', fd1),
    ('2', fd2),
])

class DynReg:
    """
    Generic implementation that fits dynamics by a generic regressor.

    The regressor class ``reg`` is instantiated as ``reg(**regopt)`` and must
    provide ``fit(X, Y)`` and ``predict(x)``.
    """
    def __init__(self, reg, regopt, fd='1', dt=1):
        """
        Parameters
        ----------
        reg : callable
            Regressor class (or factory), instantiated as ``reg(**regopt)`` in `fit`.
        regopt : dict
            Keyword options for the regressor. They are deep-copied, so later
            changes made by the caller do not affect this model.
        fd : str or int
            Key of the finite-difference scheme in ``fdMap`` ('1' or '2').
            Integers are converted with ``str``; unknown keys raise KeyError.
        dt : float
            Step size between consecutive samples - assumed constant.
        """
        self._REG = reg                           # Regressor object
        self._regopt = copy.deepcopy(regopt)      # Options for Regressor
        self._FD = fdMap[str(fd)]                 # Method for finite difference
        self._dt = dt                             # Step size - assuming constant

    def fit(self, trjs, dots=None):
        """
        Fit the model given trajectory data.

        Parameters
        ----------
        trjs : array_like
            Trajectory data. It is passed through ``np.atleast_3d`` and then
            read as (trajectory, time, state). A 2-D input of shape (A, B)
            therefore becomes A scalar trajectories of length B, and a 1-D
            input becomes a single scalar trajectory.
        dots : array_like or None
            Time derivatives with the same layout as ``trjs``. When None,
            xdot is computed by finite difference.

        Raises
        ------
        ValueError
            If ``dots`` is given and its shape does not match ``trjs``.
        """
        self._proc_data(trjs, dots)
        self._regressor = self._REG(**self._regopt)
        self._regressor.fit(self._X, self._Y)

    def predict(self, x):
        """
        Compute xdot given x, using the fitted regressor.
        """
        return self._regressor.predict(x)

    def solve(self, x0, ts, **odeopt):
        """
        The default time integrator, using the SciPy IVP solver.

        Integrates from ``ts[0]`` to ``ts[-1]``, reports the solution at
        ``ts``, and returns it with one row per output time. Extra keyword
        arguments are forwarded to ``scipy.integrate.solve_ivp``.
        """
        def func(t, x):
            # Go through self.predict so subclasses that override predict
            # are honored here as well.
            return self.predict(x)
        sol = spi.solve_ivp(func, [ts[0], ts[-1]], x0, t_eval=ts, **odeopt)
        return sol.y.T

    def _proc_data(self, trjs, dots):
        """
        Process training data into stacked arrays.

        Sets ``self._X`` (states) and ``self._Y`` (xdot's), stacked over all
        trajectories with one row per sample, plus ``self._Ndat`` (number of
        rows) and ``self._Ndim`` (state dimension).

        Raises
        ------
        ValueError
            If ``dots`` is given and its shape differs from that of ``trjs``.
        """
        _trjs = np.atleast_3d(trjs)
        if dots is None:
            # Differentiate each trajectory separately so differences never
            # straddle the boundary between two trajectories.
            _Y = [self._FD(_trj, self._dt) for _trj in _trjs]
        else:
            _dots = np.atleast_3d(dots)
            if _dots.shape != _trjs.shape:
                # zip() would otherwise silently drop unmatched trajectories.
                raise ValueError(
                    f"dots has shape {_dots.shape}, expected {_trjs.shape} "
                    "to match trjs"
                )
            _Y = list(_dots)
        self._X = np.vstack(list(_trjs))
        self._Y = np.vstack(_Y)
        self._Ndat, self._Ndim = self._X.shape

class DynRegMan(DynReg):
    """
    Implementation of various time integrators, including Euler+NC (FENC).

    The fixed-step schemes ('FENC', 'RK2', 'RK4') march through the given
    time grid, using the spacing between consecutive entries as the step.
    Any other `alg` is delegated to SciPy's ``solve_ivp``.
    """
    def fit(self, trjs, dots=None):
        """
        Fit the regressor, then cache the normal-correction routine of its manifold.

        Assumes that the regressor has a ``_manifold`` object exposing
        ``_estimate_normal``.
        """
        super().fit(trjs, dots)
        self._normal = self._regressor._manifold._estimate_normal

    def solve(self, x0, t_eval, alg='FENC', **odeopt):
        """
        Integrate the model from ``x0`` and return the states at ``t_eval``.

        Parameters
        ----------
        x0 : array_like
            Initial state. The fixed-step schemes return it unchanged as the
            first row of the result.
        t_eval : sequence of float
            Output times. The fixed-step schemes step between consecutive
            entries, so the grid need not be uniform.
        alg : str
            'FENC' (first-order Euler with Normal Correction), 'RK2' (midpoint
            Runge-Kutta), or 'RK4' (classical Runge-Kutta). Any other value
            uses ``scipy.integrate.solve_ivp``.
        **odeopt
            Forwarded to ``solve_ivp``; ignored by the fixed-step schemes.

        Returns
        -------
        ndarray
            States stacked row-wise, one row per output time.
        """
        print(f"====Solving [{t_eval[0]}, {t_eval[-1]}]")
        steppers = {
            'FENC': self._step_fenc,
            'RK2': self._step_rk2,
            'RK4': self._step_rk4,
        }
        if alg in steppers:
            return self._march(steppers[alg], x0, t_eval)
        # Otherwise use SciPy IVP solver, which defaults to RK45
        def func(t, x):
            return self.predict(x)
        sol = spi.solve_ivp(func, [t_eval[0], t_eval[-1]], x0, t_eval=t_eval, **odeopt)
        return sol.y.T

    def _march(self, step, x0, ts):
        """Advance x0 through the grid ts with `step(x, h)`; return all states stacked row-wise."""
        sol, x_old = [x0], x0
        for _i in tqdm(range(1, len(ts))):
            x_old = step(x_old, ts[_i] - ts[_i-1])
            sol.append(x_old)
        return np.vstack(sol)

    def _step_fenc(self, x, h):
        """First-order Euler step, followed by adding the correction from self._normal."""
        x1 = x + self.predict(x) * h
        return x1 + self._normal(x, x1)

    def _step_rk2(self, x, h):
        """Runge-Kutta-2 (midpoint) step."""
        dx1 = self.predict(x)
        dx2 = self.predict(x + dx1*(h/2))
        return x + h * dx2

    def _step_rk4(self, x, h):
        """Classical Runge-Kutta-4 step."""
        dx1 = self.predict(x)
        dx2 = self.predict(x + dx1*(h/2))
        dx3 = self.predict(x + dx2*(h/2))
        dx4 = self.predict(x + dx3*h)
        return x + h/6 * (dx1 + 2*dx2 + 2*dx3 + dx4)

class DynRegWrap(DynRegMan):
    """
    Dynamics model using a pretrained or an analytically-known model.

    No fitting is performed. `predict` evaluates the supplied right-hand side,
    and the normal correction used by the FENC integrator comes from the
    supplied manifold object.
    """
    def __init__(self, dyn, man, dt=1):
        """
        Parameters
        ----------
        dyn : callable
            Right-hand side with the ``fun(t, x)`` signature used by
            ``scipy.integrate.solve_ivp``.
        man : object
            Manifold object providing ``_estimate_normal``.
        dt : float
            Step size, stored as ``self._dt``.
        """
        # DynReg.__init__ is not called, so no regressor or finite-difference
        # scheme is configured; solving goes through the inherited DynRegMan.solve.
        self._dyn = dyn
        self._normal = man._estimate_normal
        self._dt = dt

    def fit(self, trjs, dots=None):
        raise NotImplementedError("Class DynRegWrap does not support fitting.")

    def predict(self, x):
        """
        Evaluate the wrapped dynamics at state x.

        The time argument passed to `dyn` is fixed to 0 and is ignored; this
        keeps compatibility with the solve_ivp signature. np.asarray lets
        `dyn` return any array_like (e.g. a list), as solve_ivp itself allows.
        """
        return np.asarray(self._dyn(0, x)).squeeze()
