'''
@package kernelreg.py

Implementation of
(1) Multivariate kernel ridge regression (MKRR)
(2) Manifold-based MKRR (M2KRR)

Both models learn a generic mapping from inputs X to outputs Y; what X and Y physically represent is left to the caller.
'''

import copy
import time
import numpy as np
import scipy.linalg as spl
from tqdm import tqdm

class MKRR:
    """
    Generic multivariate kernel ridge regression with a matrix-valued kernel.

    No structure of the inputs or outputs is exploited; this class mostly
    defines the interface (fit / predict / _proc_data / _form_K) that
    specialized subclasses override.  In practice, M2KRR is the variant used.

    Args:
        ker: Callable ``ker(x1, x2)`` returning the (Nout x Nout) kernel block
            between two input samples.
        nug: Nugget (ridge term) added to the diagonal of the Gram matrix.
    """
    def __init__(self, ker=None, nug=1e-10):
        self._fker = ker
        self._nugg = nug

        # Training data; populated by _proc_data
        self._datX = []
        self._datY = []

    def fit(self, X, Y):
        """
        Store the samples, assemble the regularized Gram matrix and solve for
        the regression weights ``self._c`` (flattened, length Ndat*Nout).
        """
        print("----Fitting...")
        self._proc_data(X, Y)

        print("------Form K...")
        tic = time.time()
        gram = self._form_K(self._datX, None)
        size = self._Ndat * self._Nout
        A = gram + self._nugg * np.eye(size)
        toc = time.time()
        print(f"    t = {toc - tic:4.3e}")

        print("------Linear solve...")
        tic = time.time()
        # Row-major flattening matches the (sample, output) block layout of K
        rhs = np.ravel(self._datY)
        self._c = spl.solve(A, rhs, assume_a='sym')
        toc = time.time()
        print(f"    t = {toc - tic:4.3e}")
        print("----Done")

    def predict(self, inp):
        """
        Evaluate the fitted model.

        Args:
            inp: (Npnt, Ninp) query points, or a single point of length Ninp.

        Returns:
            (Npnt, Nout) array of predictions.
        """
        pts = np.atleast_2d(inp)
        _, n_in = pts.shape
        assert n_in == self._Ninp
        cross = self._form_K(pts, self._datX)
        flat = cross.dot(self._c)
        return flat.reshape(-1, self._Nout)

    def _proc_data(self, X, Y):
        """Convert X and Y to arrays and record Ndat, Ninp and Nout."""
        self._datX = np.array(X)
        self._datY = np.array(Y)
        self._Ndat, self._Ninp = self._datX.shape
        n_rows, self._Nout = self._datY.shape
        assert self._Ndat == n_rows

    def _form_K(self, X1, X2):
        """
        Assemble the block kernel matrix.

        With ``X2 is None`` the (N*B x N*B) self-interaction matrix of X1 is
        built, computing only the upper blocks and mirroring them (transposed)
        into the lower half.  Otherwise the (N*B x M*B) cross matrix between
        X1 and X2 is built.  B is the number of outputs.
        """
        nb = self._Nout
        rows = len(X1)

        def blk(idx):
            # Index range of the idx-th (nb x nb) block along one axis
            return slice(idx * nb, (idx + 1) * nb)

        if X2 is not None:
            # Cross-interaction: block rows follow X1, block columns follow X2
            K = np.zeros((rows * nb, len(X2) * nb))
            for i, xi in enumerate(X1):
                for j, xj in enumerate(X2):
                    K[blk(i), blk(j)] = self._fker(xi, xj)
            return K

        # Self-interaction
        K = np.zeros((rows * nb, rows * nb))
        for i in tqdm(range(rows)):
            K[blk(i), blk(i)] = self._fker(X1[i], X1[i])
            for j in range(i + 1, rows):
                K[blk(i), blk(j)] = self._fker(X1[i], X1[j])
                K[blk(j), blk(i)] = K[blk(i), blk(j)].T
        return K

class M2KRR(MKRR):
    """
    Manifold-based multivariate kernel ridge regression.

    Inputs are assumed to lie on a manifold and outputs to be tangent vectors
    at those inputs.  Training outputs are projected onto per-sample tangent
    bases supplied by a manifold model.  A scalar kernel rho is combined with
    the basis overlaps T_i^T T_j to form the Gram matrix.  Predictions are
    mapped back to ambient coordinates through the tangent basis estimated at
    each query point.

    Args:
        man: Manifold class, instantiated as ``man(X, **manopt)``.  The instance
            must provide ``precompute()``, the per-sample tangent bases ``_T``,
            the tangent estimator ``_estimate_tangent(x)`` and the manifold
            dimension ``_Nman``.
        manopt: Keyword options for ``man``.  They are deep-copied, so later
            changes by the caller do not affect this model.
        ker: Scalar kernel.  With ``ifvec=False`` it is called as ``ker(x1, x2)``
            on single samples and must return a scalar.  With ``ifvec=True`` it
            is called on 2-D sample arrays and must return the kernel matrix.
        nug: Nugget added to the diagonal of the scalar kernel matrix.
        ifvec: Whether ``ker`` is vectorized over samples.
    """
    def __init__(self, man, manopt, ker=None, nug=1e-10, ifvec=False):
        self._MAN = man
        self._manopt = copy.deepcopy(manopt)
        super().__init__(ker=ker, nug=nug)

        # If scalar kernel evaluation is vectorized
        self._ifvec = ifvec
        if self._ifvec:
            # Kernel is evaluated in batch
            # Should be preferred whenever possible
            self._frho = self._ker_vec
        else:
            # Kernel is evaluated pair-wise by a double-loop
            # Can be very slow
            self._frho = self._ker_scl

    def fit(self, X, Y):
        """
        Fit the model.

        Args:
            X: (Ndat, Ninp) inputs on the manifold.
            Y: (Ndat, Ninp) outputs in ambient coordinates.  They are projected
                onto the tangent bases before fitting.
        """
        print("----Fitting...")
        self._proc_data(X, Y)

        print("------Form K...")
        t1 = time.time()
        K = self._form_K(self._datX, None)  # nugget already included
        t2 = time.time()
        print(f"    t = {t2-t1:4.3e}")

        print("------Linear solve...")
        t1 = time.time()
        # assume_a='sym' reads only the upper triangle (scipy default lower=False),
        # which is all the scalar-kernel path fills in.
        self._c = spl.solve(K, self._datY.reshape(-1), assume_a='sym').reshape(self._Ndat, 1, self._Nout)
        # Tc[i] = T_i c_i: per-sample coefficients in ambient coordinates, (Ndat, Ninp, 1)
        self._Tc = np.sum(self._T*self._c, axis=2).reshape(self._Ndat, self._Ninp, 1)
        t2 = time.time()
        print(f"    t = {t2-t1:4.3e}")
        print("----Done")

    def _proc_data(self, X, Y):
        """
        Store the data, build the manifold model and project Y onto it.

        After this call ``self._T`` has shape (Ndat, Ninp, Nout), where Nout is
        the manifold dimension, and ``self._datY`` holds the tangent coordinates
        with shape (Ndat, Nout).

        Raises:
            ValueError: if Y does not have Ninp columns (ambient coordinates).
        """
        super()._proc_data(X, Y)
        # Fail before the (possibly expensive) manifold construction
        if self._datY.shape[1] != self._Ninp:
            raise ValueError(
                f"Y must have Ninp={self._Ninp} columns (ambient coordinates), "
                f"got {self._datY.shape[1]}")

        # Manifold
        self._manifold = self._MAN(self._datX, **self._manopt)
        self._manifold.precompute()
        self._T = np.transpose(np.array(self._manifold._T), axes=(0, 2, 1))
        self._tangent = self._manifold._estimate_tangent
        self._Nout = self._manifold._Nman
        assert self._T.shape == (self._Ndat, self._Ninp, self._Nout)

        # Replace datY by its tangent coordinates: y_i <- T_i^T y_i, (Ndat, Nout)
        self._datY = np.einsum('nkb,nk->nb', self._T, self._datY)

    def predict(self, X):
        """
        Predict ambient-space tangent vectors at the query points.

        Args:
            X: (N, Ninp) query points, or a single point of length Ninp.

        Returns:
            The (N, 1, Ninp) batch of predictions with singleton axes removed.
            This is (N, Ninp) in general and (Ninp,) for a single query point.
        """
        _X = np.atleast_2d(X)
        assert _X.shape[1] == self._Ninp
        _s, _Ts = self._form_K(_X, self._datX)
        # Batched per-point product: (N, 1, B) @ (N, B, Ninp) -> (N, 1, Ninp).
        # np.dot would contract every query against every other query's basis,
        # producing (N, 1, N, Ninp) with only the diagonal meaningful.
        _D = np.matmul(_s, _Ts).squeeze()
        return _D

    def _form_K(self, X1, X2):
        """
        Assemble the kernel quantities.

        X2 is None (self-interaction): return the (N*B, N*B) Gram matrix of the
        training data, including the nugget.  It uses the stored bases
        ``self._T``, so X1 is expected to be the training inputs.

        Otherwise (cross-interaction, X1 = N new inputs, X2 = M data points):
        return ``(ks, Ts)``.  ``ks`` has shape (N, 1, B) and ``Ts`` stacks the
        tangent bases estimated at X1.  Only the pieces of the K-vector product
        are returned, to avoid forming the full cross matrix.
        """
        _N = len(X1)
        _B = self._Nout
        if X2 is None:
            # _tmp[i, :, j, :] = T_i^T T_j, shape (N, B, N, B)
            _tmp = np.dot(np.transpose(self._T, axes=(0, 2, 1)), self._T)
            # Diagonal blocks set exactly to I.  Presumably the tangent bases are
            # orthonormal, so this only removes round-off -- TODO confirm.
            _I = np.eye(_B)
            for _i in range(_N):
                _tmp[_i,:,_i,:] = _I
            # With the pair-wise kernel only the upper triangle of _rho is
            # filled; linalg.solve(assume_a='sym') reads only that triangle.
            _rho = self._frho(X1, None)
            # Out-of-place so the array returned by the kernel is not mutated
            _rho = _rho + np.eye(_N) * self._nugg
            K = np.reshape(_rho.reshape(_N, 1, _N, 1) * _tmp, (_N*_B, _N*_B))
            return K
        else:
            # Cross-interaction
            # New input: X1, N
            # Data: X2, M
            _M = len(X2)
            # Assumed (N, B, Ninp), same layout as the entries of manifold._T
            Ts = np.array([self._tangent(_x) for _x in X1])
            _rho = self._frho(X1, X2).reshape(_N, 1, _M, 1)

            # Ts.dot(Tc): N x B x Ndat x 1; weight by rho and sum over data -> N x 1 x B
            ks = np.sum(_rho * Ts.dot(self._Tc), axis=2).reshape(_N, 1, _B)

            return ks, Ts

    def _ker_scl(self, X1, X2):
        """
        Evaluate the scalar kernel pair-wise.

        With X2 None only the upper triangle (j >= i) of the (N, N) matrix is
        filled; the rest stays zero.  Otherwise the full (N, M) matrix is built.
        """
        _N = len(X1)
        if X2 is None:
            _rho = np.zeros((_N, _N))
            for _i in range(_N):
                for _j in range(_i, _N):
                    _rho[_i, _j] = self._fker(X1[_i], X1[_j])
        else:
            _M = len(X2)
            _rho = np.zeros((_N, _M))
            for _i in range(_N):
                for _j in range(_M):
                    _rho[_i, _j] = self._fker(X1[_i], X2[_j])
        return _rho

    def _ker_vec(self, X1, X2):
        """Evaluate a vectorized scalar kernel on 2-D sample arrays (X2 None -> X1 vs X1)."""
        if X2 is None:
            return self._fker(np.atleast_2d(X1), np.atleast_2d(X1))
        return self._fker(np.atleast_2d(X1), np.atleast_2d(X2))
