"""
The cavity flow problem

GMKRR training and prediction
Comparison with SINDy-type methods
Plotting of results

Disclaimer:
The SINDy-related implementation is obtained from the online documentation of PySINDy
```
https://pysindy.readthedocs.io/en/latest/examples/14_cavity_flow/example.html
```
We retain their implementation as much as possible.
The components of the script that are mostly from PySINDy are marked with **PySINDy**
"""
import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from scipy import linalg

import pysindy as ps

from geometric_multivariate_kernel_ridge_regression.src.manifold import Manifold
from geometric_multivariate_kernel_ridge_regression.src.kernelreg import M2KRR
from geometric_multivariate_kernel_ridge_regression.src.dynreg import DynRegMan
from geometric_multivariate_kernel_ridge_regression.src.utils import make_se

def cmpErr(sol, ref=None):
    """Per-time-step prediction error of each trajectory against its reference.

    Parameters
    ----------
    sol : sequence of 2-D arrays
        Predicted trajectories, indexed (time step, mode), one per reference
        trajectory. A prediction may hold fewer modes than the reference
        (e.g. a SINDy model on a truncated POD basis) or fewer time steps
        (e.g. an integration that stopped early); the missing entries are
        treated as zero. Steps or modes beyond the reference are ignored.
    ref : sequence of 2-D arrays, optional
        Reference trajectories. Defaults to the first ``Jtest`` entries of the
        module-level ``data_test``.

    Returns
    -------
    numpy.ndarray
        Shape (number of reference trajectories, reference length): the
        Euclidean norm over modes of (prediction - reference) at each step.

    Raises
    ------
    ValueError
        If ``sol`` holds fewer trajectories than ``ref``.
    """
    if ref is None:
        ref = data_test[:Jtest]
    if len(sol) < len(ref):
        raise ValueError(
            f"expected at least {len(ref)} predicted trajectories, got {len(sol)}")

    err = []
    for _r, _s in zip(ref, sol):
        _r = np.asarray(_r)
        _s = np.asarray(_s)
        # Zero-pad (or crop) the prediction to the reference's shape so that
        # truncated bases and prematurely stopped integrations still compare.
        _n_t = min(_s.shape[0], _r.shape[0])
        _n_m = min(_s.shape[1], _r.shape[1])
        _tmp = np.zeros_like(_r)
        _tmp[:_n_t, :_n_m] = _s[:_n_t, :_n_m]
        err.append(np.linalg.norm(_tmp - _r, axis=1))
    return np.array(err)

data = loadmat("data/cavityPOD.mat")

t_dns = data['t'].flatten()       # DNS time stamps
dt_dns = t_dns[1] - t_dns[0]      # DNS sampling interval (taken from the first two stamps)
a_dns = data['a']                 # POD coefficients, indexed (time step, mode)
singular_vals = data['svs'].flatten()
E_dns = np.sum(a_dns**2, axis=1)  # Energy: sum of squared POD coefficients per time step
a_dim = a_dns.shape[1]            # Number of POD modes

Ntrain = 4000   # Number of leading snapshots used for training
Ltest = 2000    # Length (in steps) of each test trajectory
Jtest = 13      # Number of test trajectories

dt_rom = 1e-2
# Exactly Ltest prediction times, so predictions line up step-by-step with each
# test trajectory (as before, this assumes dt_rom equals the DNS sampling step).
t_sim = np.arange(Ltest) * dt_rom
t_ene = np.arange(0, 300, dt_rom)

# Training data
data_train = a_dns[:Ntrain]

# Test data: consecutive, non-overlapping segments of Ltest steps after the
# training window. An incomplete trailing segment is dropped rather than
# making the reshape fail.
n_seg = (len(a_dns) - Ntrain) // Ltest
data_test = a_dns[Ntrain:Ntrain + n_seg * Ltest].reshape(n_seg, Ltest, a_dim)
data_test = data_test[:Jtest]
# Initial conditions, one row per test trajectory. No squeeze(): with a single
# test trajectory it would collapse this to 1-D and break the loops over rows.
a0_test = data_test[:, 0, :]

# Convergence study in N: from every block of SKP consecutive training
# samples, keep the first k samples for each k in idx.
SKP = 40
idx = list(range(5, 41, 5))

FS = 18  # Base font size for figures

# Training switches
#   iftst = 0 -> long-term predictions
#   iftst = 1 -> predictions used for statistics of the prediction error
iftst = 0
ifnai = 1   # SINDy
ifmsi = 1   # M-SINDy
ifkrr = 1   # GMKRR

# Plot switches
ifcnv = 0   # Fig. 9: convergence of GMKRR vs. number of training samples
iftim = 0   # Fig. 8: comparison of temporal responses
iferr = 0   # Fig. 3(a): prediction error over a test dataset of trajectories
ifene = 0   # Fig. 3(b): energy preservation over the long term

# ------------------
# Naive SINDy
# **PySINDy**
# ------------------
if ifnai:
    # The original implementation seems to be using all 30000 data points.
    # To be fair, we use 4000 points, same as ours.
    # Change to `Ntrain_ref = 30000` for PySINDy results.

    Ntrain_ref = 4000
    r = 6  # POD truncation: keep the first r modes
    x_train = a_dns[:Ntrain_ref, :r]
    t_train = t_dns[:Ntrain_ref]

    optimizer = ps.STLSQ(threshold=1)
    library = ps.PolynomialLibrary(degree=3)
    model = ps.SINDy(
        optimizer=optimizer,
        feature_library=library
    )
    model.fit(x_train, t_train, quiet=True)
    model.print()

    if iftst:
        # Predict each test trajectory from its initial condition (first r modes)
        sol = []
        for _i, _a in enumerate(a0_test):
            print(_i)  # progress
            a_nai = model.simulate(_a[:r], t_sim)
            sol.append(a_nai)
        # `with` guarantees the file is flushed and closed
        with open(f'./res/cav_cmp_nai_{Ntrain_ref}.dat', 'wb') as fh:
            pickle.dump(sol, fh)
    else:
        # Long-term prediction from the first training snapshot
        a_nai = model.simulate(x_train[0], t_ene)
        with open(f'./res/cav_ene_nai_{Ntrain_ref}.dat', 'wb') as fh:
            pickle.dump(a_nai, fh)

# ------------------
# Manifold SINDy, with DMD
# **PySINDy**
# We abstracted the methods for cmpl2real, real2cmpl, and reconstruct
# ------------------
if ifmsi:
    # The original implementation seems to be using all 30000 data points.
    # To be fair, we use 4000 points, same as ours.
    # Change to `Ntrain_ref = 30000` for PySINDy results.

    Ntrain_ref = 4000
    # -------------------------
    # First find DMD modes
    # -------------------------
    # Training snapshots. Named `snaps` so the module-level `data`
    # (the loaded .mat contents) is not overwritten.
    snaps = np.array(a_dns[:Ntrain_ref])

    X = snaps[:-1, :].T  # One column per time step
    Y = snaps[1:, :].T   # Time-shifted matrix so that Y = AX
    A = Y @ linalg.pinv(X)

    # Eigendecomposition of low-rank matrix (note these are discrete-time eigenvalues)
    dt_evals, V = linalg.eig(A)
    evals = np.log(dt_evals) / dt_dns  # Continuous-time eigenvalues

    # Order the DMD modes by their mean squared amplitude over the training data
    a_dmd = (linalg.inv(V) @ snaps.T).T
    E_dmd = np.mean(abs(a_dmd) ** 2, axis=0)

    sort_idx = np.argsort(-E_dmd)
    V = V[:, sort_idx]
    evals = evals[sort_idx]

    # -------------------------
    # Then fit the manifold representation
    # -------------------------
    def cmpl2real(dat):
        """Map complex DMD amplitudes (rows = time steps) to real coordinates.

        Columns are handled in pairs (2k, 2k+1): column 2k receives the real
        part and column 2k+1 the imaginary part of the amplitude in column 2k.
        """
        res = np.zeros(dat.shape)
        res[:, ::2] = np.real(dat[:, ::2])
        res[:, 1::2] = np.imag(dat[:, ::2])
        return res

    def real2cmpl(dat):
        """Inverse of `cmpl2real`: rebuild each complex-conjugate column pair.

        Uses complex128 so the round trip keeps the float64 data in double
        precision (complex64 would silently truncate it to single precision).
        """
        res = np.zeros_like(dat, dtype=np.complex128)
        res[:, ::2] = dat[:, ::2] + 1j * dat[:, 1::2]
        res[:, 1::2] = dat[:, ::2] - 1j * dat[:, 1::2]
        return res

    a_dmd = linalg.solve(V, snaps.T).T       # Amplitudes in the sorted DMD basis
    a0_dmd = linalg.solve(V, a0_test.T).T    # Test initial conditions, same basis
    a_dmd_real = cmpl2real(a_dmd)
    a0_dmd_real = cmpl2real(a0_dmd)

    active_idx = [0, 1, 4, 5]  # Dynamically active modes

    # "stable" modes whose amplitudes are functions of a_dmd[active_idx];
    # sized by a_dim to match the arrays built in `reconstruct`.
    stable_idx = [i for i in range(a_dim) if i not in active_idx]

    x_train = a_dmd_real[:, active_idx]  # Inputs to the candidate functions for SINDy
    rhs = a_dmd_real[:, stable_idx]      # Targets for the sparse regression problem

    manifold_library = ps.PolynomialLibrary(degree=5, include_bias=False)
    Theta = manifold_library.fit_transform(x_train)  # Construct the polynomial library

    # Very ill-conditioned Theta so need to be careful here:
    # the FROLS kappa is scaled by the condition number of Theta.
    manifold = ps.FROLS(kappa=1e-3 / np.linalg.cond(Theta), alpha=1e-3, max_iter=10)
    manifold.fit(Theta, rhs)  # Solve sparse regression problem
    # Count of nonzero coefficients per row of coef_ (presumably one row per stable mode)
    print(np.linalg.norm(manifold.coef_, 0, axis=1))

    # -------------------------
    # Then fit the manifold SINDy
    # -------------------------
    # Construct training data with negative symmetry
    x_train = [a_dmd_real[:, active_idx], -a_dmd_real[:, active_idx]]
    t_train = [t_dns[:Ntrain_ref], t_dns[:Ntrain_ref]]

    # Fit the model
    optimizer = ps.STLSQ(threshold=1e-3)
    library = ps.PolynomialLibrary(degree=3, include_bias=False)
    model = ps.SINDy(
        optimizer=optimizer,
        feature_library=library,
        feature_names=["a_1", "a_2", "a_5", "a_6"]
    )
    model.fit(x_train, t_train, multiple_trajectories=True, quiet=True)
    model.print()

    # -------------------------
    # Lastly predict and map back to original ambient space
    # -------------------------
    def reconstruct(a_dmd_active):
        """Complete active complex DMD amplitudes with the manifold model.

        The stable-mode amplitudes are predicted from the active ones through
        `manifold`; returns all a_dim complex DMD amplitudes per time step.
        """
        a_pred_real = np.zeros((len(a_dmd_active), a_dim))
        a_pred_real[:, active_idx] = cmpl2real(a_dmd_active)
        a_pred_real[:, stable_idx] = manifold.predict(
            manifold_library.transform(a_pred_real[:, active_idx]))
        return real2cmpl(a_pred_real)

    def _simulate_ambient(a0_active_real, t):
        """Integrate M-SINDy from real active amplitudes and map back through V."""
        a_pred = real2cmpl(model.simulate(a0_active_real, t))
        return np.real(V @ reconstruct(a_pred).T).T

    if iftst:
        sol = []
        for _i, _a in enumerate(a0_dmd_real):
            print(_i)  # progress
            sol.append(_simulate_ambient(_a[active_idx], t_sim))
        with open(f'./res/cav_cmp_msi_{Ntrain_ref}.dat', 'wb') as fh:
            pickle.dump(sol, fh)
    else:
        a_msi = _simulate_ambient(a_dmd_real[0, active_idx], t_ene)
        with open(f'./res/cav_ene_msi_{Ntrain_ref}.dat', 'wb') as fh:
            pickle.dump(a_msi, fh)

# ------------------
# GMKRR, convergence on N
# ------------------
if ifkrr:
    order_fd = '1'  # Passed as `fd` to DynRegMan (presumably the finite-difference order)
    manopt = {
        'd' : 2,
        'g' : 3,
        'T' : 0
    }
    regloc = {
        'man' : Manifold,
        'manopt' : manopt,
        'ker' : make_se(10.0),
        'nug' : 1e-6,
        'ifvec' : True
    }
    dloc = DynRegMan(M2KRR, regloc, fd=order_fd, dt=dt_dns)

    if iftst:
        # Convergence study: for each k in idx, train on the first k samples of
        # every block of SKP consecutive training samples. data_train[jdx] is
        # indexed (block, step within block, mode), i.e. Ntrain // SKP short
        # segments of length k. Requires Ntrain to be a multiple of SKP.
        blocks = np.arange(Ntrain).reshape(Ntrain // SKP, SKP)
        for _i in idx:
            jdx = blocks[:, :_i]
            dloc.fit(data_train[jdx])
            sol = [dloc.solve(_a, t_sim) for _a in a0_test]
            with open(f'./res/cav_cnv_{_i}.dat', 'wb') as fh:
                pickle.dump(sol, fh)
    else:
        # Long-term prediction from a model trained on one segment of n_ene samples
        n_ene = 4000
        dloc.fit([data_train[:n_ene]])
        a_krr = dloc.solve(data_train[0], t_ene)
        with open(f'./res/cav_ene_krr_{n_ene}.dat', 'wb') as fh:
            pickle.dump(a_krr, fh)

# ------------------
# Convergence on N
# ------------------
if ifcnv:
    errs = []
    for _i in idx:
        with open(f'./res/cav_cnv_{_i}.dat', 'rb') as fh:
            sol = pickle.load(fh)
        tmp = cmpErr(sol)
        # RMS over test trajectories at each time step, then the worst time step
        errs.append(np.max(np.sqrt(np.mean(tmp**2, axis=0))))
    nums = np.array(idx)*Ntrain/SKP  # Number of training samples per case

    # Linear fit of log(err) against log(N) to estimate the convergence order
    coe = np.polyfit(np.log(nums), np.log(errs), deg=1)
    ref = np.exp(coe[1]) * nums**coe[0]

    f, ax = plt.subplots()
    ax.loglog(nums, errs, 'ko-', label='Test error', markerfacecolor='none')
    ax.loglog(nums, ref,  'r--', label=f'Reference, order={-coe[0]:3.2f}')
    ax.legend(fontsize=10)
    ax.set_xlabel('$N$')
    xtcks = [nums[0], nums[1], nums[-1]]
    plt.xticks(ticks=xtcks, labels=[str(int(_)) for _ in xtcks], fontsize=12)
    ax.set_ylabel('Max RMSE')
    ytcks = [errs[0], errs[1], errs[-1]]
    plt.yticks(ticks=ytcks, labels=[f"{_:2.1e}" for _ in ytcks], fontsize=12)
    plt.minorticks_off()
    plt.grid()

    f.savefig('pics/cav_cnv_N.png', dpi=600, bbox_inches='tight')

# ------------------
# Temporal response
# ------------------
if iftim:
    IDX = 0  # Index of the test trajectory to plot
    a_tru = data_test[IDX]
    with open('./res/cav_cmp_nai_4000.dat', 'rb') as fh:
        a_na3 = pickle.load(fh)[IDX]
    with open('./res/cav_cmp_msi_4000.dat', 'rb') as fh:
        a_ms3 = pickle.load(fh)[IDX]
    with open('./res/cav_cnv_40.dat', 'rb') as fh:
        a_krr = pickle.load(fh)[IDX]

    f, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))
    # Rows show modes 1 and 3; the two columns show the same mode over an
    # early and a late time window (set by the xlim calls below).
    for _i in range(2):
        _k = _i*2
        for _j in range(2):
            ax[_i,_j].plot(t_sim, a_tru[:,_k], '-', color='0.5', label='DNS', lw=3)
            ax[_i,_j].plot(t_sim, a_na3[:,_k], 'b:',  label='SINDy', lw=1)
            ax[_i,_j].plot(t_sim, a_ms3[:,_k], 'r-',  label='M-SINDy', lw=1)
            ax[_i,_j].plot(t_sim, a_krr[:,_k], 'k-',  label='GMKRR (Ours)', lw=1.5)
            ax[_i,_j].grid()
        ax[_i,0].set_ylabel(f"Mode {_k+1}")
        ax[_i,0].set_xlim([0, 1])
        ax[_i,1].set_xlim([19, 20])
    ax[1,0].set_ylim([-0.03, 0.03])
    ax[-1,0].set_xlabel('$t$')
    ax[-1,1].set_xlabel('$t$')
    ax[-1,0].legend(loc="upper left", ncol=2, fontsize=10, bbox_to_anchor=(0.0,1.12), fancybox=True, framealpha=0.5)

    f.savefig('pics/cav_tim_cmp.png', dpi=600, bbox_inches='tight')

# ------------------
# Temporal error of trajectories
# ------------------
if iferr:
    def procErr(f):
        """Load predictions from file path `f` and summarize their errors.

        Returns the mean, max and min over test trajectories of the
        per-time-step error computed by `cmpErr`.
        """
        with open(f, 'rb') as fh:
            sol = pickle.load(fh)
        tmp = cmpErr(sol)
        return np.mean(tmp, axis=0), np.max(tmp, axis=0), np.min(tmp, axis=0)

    avr_na3, max_na3, min_na3 = procErr('./res/cav_cmp_nai_4000.dat')
    avr_ms3, max_ms3, min_ms3 = procErr('./res/cav_cmp_msi_4000.dat')
    avr_krr, max_krr, min_krr = procErr('./res/cav_cnv_40.dat')

    f = plt.figure()
    # Shaded bands: min-max range over test trajectories; lines: the mean
    plt.fill_between(t_sim, min_na3, max_na3, color='b', alpha=0.3, edgecolor='none')
    plt.fill_between(t_sim, min_ms3, max_ms3, color='r', alpha=0.3, edgecolor='none')
    plt.fill_between(t_sim, min_krr, max_krr, color='k', alpha=0.3, edgecolor='none')
    plt.plot(t_sim, avr_na3, 'b-', label='SINDy')
    plt.plot(t_sim, avr_ms3, 'r-', label='M-SINDy')
    plt.plot(t_sim, avr_krr, 'k-', label='GMKRR (Ours)')
    plt.gca().tick_params(axis='both', which='major', labelsize=FS)
    plt.legend(loc=2, fontsize=FS)
    plt.xlabel('$t$', fontsize=FS)
    plt.ylabel('RMSE', fontsize=FS)

    f.savefig('pics/cav_err_cmp.png', dpi=600, bbox_inches='tight')

# ------------------
# Energy - long-term response
# ------------------
if ifene:
    def loaddat(f):
        """Load a long-term prediction from path `f` plus its energy per time step.

        The energy is the sum of squared coefficients along axis 1.
        """
        with open(f, 'rb') as fh:
            a = pickle.load(fh)
        return a, np.sum(a**2, axis=1)
    a_ms3, E_ms3 = loaddat('./res/cav_ene_msi_4000.dat')
    a_krr, E_krr = loaddat('./res/cav_ene_krr_4000.dat')

    f, ax = plt.subplots(ncols=4, sharey=True, figsize=(10, 3))
    for _i in range(4):
        ax[_i].semilogy(t_dns, E_dns, '-', color='0.5', label='DNS', lw=3)
        ax[_i].semilogy(t_ene, E_ms3, 'g-',  label='M-SINDy', lw=1)
        ax[_i].semilogy(t_ene, E_krr, 'k-',  label='GMKRR (Ours)', lw=1.5)
        ax[_i].set_ylim([1e-3, 4e-3])
        ax[_i].grid()
        ax[_i].set_xlabel('$t$', fontsize=FS-4)
        ax[_i].tick_params(axis='both', which='both', labelsize=FS-5)
    ax[0].set_ylabel('$E$', fontsize=FS-4)
    # Four time windows spanning the long-term prediction
    ax[0].set_xlim([0, 4])
    ax[1].set_xlim([150, 154])
    ax[2].set_xlim([220, 224])
    ax[3].set_xlim([296, 300])
    ax[3].legend(loc="upper left", ncol=2, fontsize=FS-4, bbox_to_anchor=(-3.65,0.32), fancybox=True, framealpha=0.5)

    f.savefig('pics/cav_ene_cmp.png', dpi=600, bbox_inches='tight')

plt.show()