from typing import Literal, Union, Callable
from numpy.typing import ArrayLike
from logging import Logger
from math import sqrt
import numpy as np

from src.utils import EigResult,ModeResult,add_diagonal_,rank_reveal,weighted_norm,fuzzy_parse_complex
from scipy.linalg import eig, eigh
from scipy.sparse.linalg import eigsh

def primal_fit_to(
    Z : ArrayLike,  # Feature matrix of equally spaced trajectory data, shape [num_training_points, features]
    dt: float,  # Time step
    step: int =1, # Multiple of time-step separating inputs Z[:-step] and outputs Z[step:]
    tikhonov_reg: float = 0.,  # Tikhonov (ridge) regularization parameter, can be 0
    rank: Union[int,None] = None,  # Rank of the estimator (None: no rank reduction)
    symmetry: Literal[
        "symmetric", "antisymmetric", None
    ] = None,
    # Whether the generator is self-adjoint ("symmetric"), skew-adjoint ("antisymmetric") or neither (None)
    svd_solver: Literal[
        "arnoldi", "full"
    ] = "arnoldi",  # SVD solver to use. 'arnoldi' is faster but might be numerically unstable.
) -> EigResult:
    """Fit an estimator of the evolution over ``step * dt`` in primal (feature) space.

    Link with the paper (careful: the paper's ``Z`` is NOT the ``Z`` argument here).
    Notation, written as ``code = paper = formula``:

        C = C = S^* S      covariance of the inputs Z[:-step], ridge-regularized
        H = T = S^* Z      cross-covariance between Z[:-step] and Z[step:]
        V                  whitening basis; without ``rank`` it satisfies
                           V V^T = C^{-1} on the numerically non-zero spectrum,
                           with ``rank`` it is the reduced-rank basis
        W = V^T T V        small matrix whose eigenpairs are computed

    For ``symmetry="symmetric"`` / ``"antisymmetric"``, ``H`` is replaced by its
    symmetric / antisymmetric part before the decomposition.

    Returns:
        EigResult dict with keys

        - ``"values"``: eigenvalues mapped to continuous time, ``log(lambda) / (step*dt)``.
          For ``"antisymmetric"`` the mapping is ``1j * arcsinh(lambda) / (step*dt)``.
        - ``"right"``: right eigenvectors in feature space, scaled to unit Euclidean
          norm (near-zero columns are zeroed).
        - ``"left"``: left eigenfunction coefficients over the training samples,
          rescaled against ``"right"`` for biorthogonality.
        - ``"bias"``: ``bias_sigma / weighted_norm(right, C)`` per eigenpair.
          ``bias_sigma`` is 1 when no rank reduction happens.

    Raises:
        ValueError: if ``rank`` is given and ``svd_solver`` is unknown.
    """
    import logging

    Z = np.asarray(Z)
    # Number of (input, output) pairs separated by `step` samples
    npts = Z.shape[0] - step
    eps = 1000.0 * np.finfo(Z.dtype).eps
    penalty = max(eps, tikhonov_reg)

    Z_in, Z_out = Z[:-step], Z[step:]
    C = Z_in.T @ Z_in / npts  # S^* S
    H = Z_in.T @ Z_out / npts  # T = S^* Z
    if symmetry == 'symmetric':
        H += H.T
        H /= 2
    elif symmetry == 'antisymmetric':
        H -= H.T
        H /= 2

    add_diagonal_(C, penalty)

    if rank is not None:
        # Find U via a generalized eigenvalue problem equivalent to the SVD.
        # If C is ill-conditioned this might be slow.
        if svd_solver == "arnoldi":
            # Adding a small buffer to the Arnoldi-computed eigenvalues.
            num_arnoldi_eigs = min(rank + 5, npts)
            values, vectors = eigsh(H @ H.T, k=num_arnoldi_eigs, M=C)
        elif svd_solver == "full":
            # C is reused below (normalization and spectral bias), so LAPACK must
            # not be allowed to overwrite it (no overwrite_b).
            values, vectors = eigh(H @ H.T, C, overwrite_a=True)
        else:
            raise ValueError(f"Unknown svd_solver: {svd_solver}")

        if rank == Z.shape[1]:
            # Full-rank estimator: the rank-reduction bias is not estimated.
            numerically_nonzero_values_idxs = rank_reveal(values, rank, ignore_warnings=False)
            # NOTE(review): this branch takes sqrt (singular values) while the other
            # keeps squared singular values; the warning below compares both with
            # tikhonov_reg -- confirm which is intended.
            values = np.sqrt(values[numerically_nonzero_values_idxs])
            bias_sigma = 1
            vectors = vectors[:, numerically_nonzero_values_idxs]
        else:
            # Keep rank+1 values: the extra (last) one gives the rank-reduction bias.
            numerically_nonzero_values_idxs = rank_reveal(values, rank + 1, ignore_warnings=False)
            bias_sigma = sqrt(np.abs(values[numerically_nonzero_values_idxs][-1]))
            values = values[numerically_nonzero_values_idxs][:-1]
            vectors = vectors[:, numerically_nonzero_values_idxs][:, :-1]

        # Warn if any retained eigenvalue is smaller than the regularization strength.
        if not np.all(np.abs(values) >= tikhonov_reg):
            logging.getLogger(__name__).warning(
                f"Warning: {(np.abs(values) < tikhonov_reg).sum()} out of the {len(values)} squared singular values are smaller than the regularization strength {tikhonov_reg:.2e}. Consider reducing the regularization strength to avoid overfitting."
            )

        # Eigenvector normalization w.r.t. the regularized covariance C
        vecs_norm = weighted_norm(vectors, C)

        stable_values_idxs = rank_reveal(
            vecs_norm, rank, rcond=1000.0 * np.finfo(values.dtype).eps
        )
        V = vectors[:, stable_values_idxs] / vecs_norm[stable_values_idxs]
        values = values[stable_values_idxs]
    else:
        bias_sigma = 1
        # Whitening V = C^{-1/2} on the numerically non-zero spectrum. C is needed
        # again for the spectral bias, so it must not be overwritten here.
        values, vectors = eigh(C, overwrite_a=False)
        numerically_nonzero_values_idxs = rank_reveal(values, Z.shape[1], ignore_warnings=False)
        values = np.sqrt(values[numerically_nonzero_values_idxs])
        vectors = vectors[:, numerically_nonzero_values_idxs]
        V = vectors / values

    W = np.linalg.multi_dot([V.T, H, V])

    if symmetry == 'symmetric':
        values, vr = eigh(W, overwrite_a=True)
        vl = vr
    elif symmetry == 'antisymmetric':
        # W is real antisymmetric, so W / 1j is Hermitian.
        values, vr = eigh(W / 1j, overwrite_a=True)
        vl = vr
    else:
        values, vl, vr = eig(W, left=True, right=True)
        values = fuzzy_parse_complex(values)

    r_perm = np.argsort(values)
    vr = vr[:, r_perm]
    vl = vl[:, r_perm]
    values = values[r_perm]

    # Transforming the eigenvalues to continuous time
    if symmetry == 'antisymmetric':
        # When L^* = -L, the hyperbolic sine of the generator is learned.
        # NOTE(review): dual_fit_to scales its shift by 1/(2*step) and uses
        # arcsinh(step * values) on 1j*values; confirm the two variants agree.
        values = 1j * np.arcsinh(values) / (step * dt)
    else:
        # symmetric: cosh(L dt) is learned, which equals exp(L dt) for L^* = L;
        # None: the transfer operator exp(L step dt) is learned without constraints.
        values = np.log(values) / (step * dt)

    # Normalize right eigenvectors in feature space; near-zero columns become 0.
    rcond = 1000.0 * np.finfo(Z.dtype).eps
    vr = V @ vr
    r_norm = np.linalg.norm(vr, axis=0)
    r_norm = np.where(np.abs(r_norm) < rcond, np.inf, r_norm)
    vr = vr / r_norm

    # Biorthogonalization: ZZ holds the time-shifted data matching the learned operator.
    ZZ = np.zeros_like(Z)
    if symmetry == 'symmetric':
        ZZ[step:] += Z[:-step]
        ZZ[:-step] += Z[step:]
        ZZ /= 2
    elif symmetry == 'antisymmetric':
        ZZ[step:] += Z[:-step]
        ZZ[:-step] -= Z[step:]
        ZZ = ZZ / (2 * 1j)
    else:
        ZZ[step:] += Z[:-step]
    vl = np.linalg.multi_dot([ZZ, V, vl]) / sqrt(npts)

    l_norm = np.sum((Z.T @ vl).conj() * vr / sqrt(npts), axis=0).conj()
    l_norm = np.where(np.abs(l_norm) < rcond, np.inf, l_norm)
    # sqrt(1 + step/npts) compensates for the sqrt(num_training_points) scaling
    # used when left eigenfunctions are evaluated on kernel data.
    vl = sqrt(1 + step / npts) * vl / l_norm

    # Spectral bias
    bias = bias_sigma / weighted_norm(vr, C)

    result: EigResult = {"values": values, "left": vl , "right": vr, "bias": bias}
    return result


def dual_fit_to(
    K : ArrayLike,  # Kernel matrix for equally spaced trajectory data
    dt: float,  # Time step
    step: int = 1, # time step for learning, time scale of the learning algorithm is step dt
    tikhonov_reg: float = 0.,  # Tikhonov (ridge) regularization parameter, can be 0
    rank: Union[int,None] = None,  # Rank of the estimator
    symmetry: Literal[
        "symmetric", "antisymmetric", None
    ] = None,
    # Whether the generator is self-adjoint ("symmetric"), skew-adjoint ("antisymmetric") or neither (None)
    svd_solver: Literal[
        "arnoldi", "full"
    ] = "arnoldi",  # SVD solver to use. 'arnoldi' is faster but might be numerically unstable.
) -> EigResult:
    """Fit a reduced-rank estimator of the evolution over ``step * dt`` in dual (kernel) space.

    ``K`` is the kernel matrix over the whole trajectory. It has shape
    ``[npts + step, npts + step]``, matching the shift matrix ``M`` built below.
    ``K`` is modified in place during the solve: the ridge penalty is added to
    its diagonal and removed afterwards. A ``finally`` clause guarantees the
    removal even if the solver raises.

    Returns:
        EigResult dict with keys

        - ``"values"``: eigenvalues mapped to continuous time, ``log(lambda) / (step*dt)``.
          For ``"antisymmetric"`` the mapping is ``arcsinh(step * lambda) / (step*dt)``.
        - ``"right"``: coefficients of the right eigenfunctions over the training
          samples, normalized with ``weighted_norm(., K / npts)``.
        - ``"left"``: coefficients of the left eigenfunctions, rescaled for
          biorthogonality.
        - ``"bias"``: the per-eigenpair spectral bias.

        Both ``"right"`` and ``"left"`` are multiplied by ``sqrt(1 + step / npts)``.

    Raises:
        NotImplementedError: if ``rank`` is None (full-rank kernel method).
        ValueError: if ``svd_solver`` is unknown.
    """
    import logging

    if rank is None:
        raise NotImplementedError("Error: Full rank kernel method not supported.")

    K = np.asarray(K)
    npts = K.shape[0] - step   # Number of (input, output) pairs
    eps = 1000.0 * np.finfo(K.dtype).eps
    penalty = max(eps, tikhonov_reg) * npts

    # M shifts sample indices by `step` (ones on the step-th superdiagonal).
    M = np.diag(np.ones(npts), step)
    if symmetry == 'symmetric':
        M += M.T
        M /= 2
    elif symmetry == 'antisymmetric':
        M -= M.T
        M /= 2 * step

    # Find U via a generalized eigenvalue problem equivalent to the SVD.
    # If K is ill-conditioned this might be slow.
    # NOTE(review): A is in general not symmetric while eigsh assumes a symmetric
    # operator -- confirm whether a non-symmetric solver is needed here.
    A = np.linalg.multi_dot([M, K / sqrt(npts), M.T, K / sqrt(npts)])

    add_diagonal_(K, penalty)
    try:
        if svd_solver == "arnoldi":
            # Adding a small buffer to the Arnoldi-computed eigenvalues.
            num_arnoldi_eigs = min(rank + 5, npts)
            values, vectors = eigsh(A, k=num_arnoldi_eigs, M=K)
        elif svd_solver == "full":
            # K belongs to the caller and is used again below: never let LAPACK
            # overwrite it (no overwrite_b). A is a local temporary.
            values, vectors = eig(A, K, overwrite_a=True)
        else:
            raise ValueError(f"Unknown svd_solver: {svd_solver}")
    finally:
        # Remove the penalty from K (in place), also when the solver raised.
        add_diagonal_(K, -penalty)

    # Keep rank+1 values: the extra (last) one gives the rank-reduction bias.
    numerically_nonzero_values_idxs = rank_reveal(values, rank + 1, ignore_warnings=False)
    bias_sigma_sq = np.abs(values[numerically_nonzero_values_idxs][-1])
    values = values[numerically_nonzero_values_idxs][:-1]
    vectors = vectors[:, numerically_nonzero_values_idxs][:, :-1]
    # Warn if any retained eigenvalue is smaller than the regularization strength.
    if not np.all(np.abs(values) >= tikhonov_reg):
        logging.getLogger(__name__).warning(
            f"Warning: {(np.abs(values) < tikhonov_reg).sum()} out of the {len(values)} squared singular values are smaller than the regularization strength {tikhonov_reg:.2e}. Consider reducing the regularization strength to avoid overfitting."
        )

    # Eigenvector normalization
    K_vecs = np.dot(K / sqrt(npts), vectors)
    vecs_norm = np.sqrt(
        np.sum(
            K_vecs**2 + tikhonov_reg * K_vecs * vectors * sqrt(npts),
            axis=0,
        )
    )

    stable_values_idxs = rank_reveal(
        vecs_norm, rank, rcond=1000.0 * np.finfo(values.dtype).eps
    )
    U = vectors[:, stable_values_idxs] / vecs_norm[stable_values_idxs]
    values = values[stable_values_idxs]

    V = K @ U

    # Eigenvalue decomposition of the reduced operator
    W = np.linalg.multi_dot([V.T, M, V]) / npts

    if symmetry == 'symmetric':
        # W is reused below for l_norm, so it must not be overwritten.
        values, vr_ = eigh(W)
        vl_ = vr_
    elif symmetry == 'antisymmetric':
        # W / 1j is a Hermitian temporary; overwriting it is harmless.
        values, vr_ = eigh(W / 1j, overwrite_a=True)
        values = 1j * values
        vl_ = vr_
    else:
        values, vl_, vr_ = eig(W, left=True, right=True)
        values = fuzzy_parse_complex(values)
    r_perm = np.argsort(values)
    vr_ = vr_[:, r_perm]
    vl_ = vl_[:, r_perm]
    values = values[r_perm]

    # Transforming the eigenvalues to continuous time
    if symmetry == 'antisymmetric':
        # When L^* = -L, sinh(L step dt)/step is learned.
        values = np.arcsinh(step * values) / (step * dt)
    else:
        # symmetric: cosh(L dt) is learned, equal to exp(L dt) for L^* = L;
        # None: the transfer operator exp(L step dt) is learned without constraints.
        values = np.log(values) / (step * dt)

    rcond = 1000.0 * np.finfo(U.dtype).eps
    # Normalization of the right eigenfunctions; near-zero columns become 0.
    vr = U @ vr_
    r_norm = weighted_norm(vr, K / npts)
    r_norm = np.where(np.abs(r_norm) < rcond, np.inf, r_norm)
    vr = vr / r_norm
    vr_ = vr_ / r_norm

    # Biorthogonalization
    if symmetry == 'symmetric':
        vl = np.linalg.multi_dot([M, V, vl_]) / sqrt(npts)
        vl_ /= sqrt(npts)
        l_norm = np.sum(vl_.conj() * (W @ vr_), axis=0).conj()
    elif symmetry == 'antisymmetric':
        vl = np.linalg.multi_dot([-M.T / 1j, V, vl_]) / sqrt(npts)
        vl_ /= sqrt(npts)
        l_norm = np.sum(vl_.conj() * (W @ vr_), axis=0) / 1j
    else:
        vl = np.linalg.multi_dot([M.T, V, vl_]) / sqrt(npts)
        vl_ /= sqrt(npts)
        l_norm = np.sum(vl_.conj() * (W @ vr_), axis=0).conj()
    l_norm = np.where(np.abs(l_norm) < rcond, np.inf, l_norm)
    vl = vl / l_norm

    # Spectral bias
    Kvr = K @ vr / sqrt(npts)
    bias = np.sqrt(npts * bias_sigma_sq / np.sum(Kvr * Kvr.conj(), axis=0).real)

    # Correcting the normalization of eigenfunctions
    vr *= sqrt(1 + step / npts)
    vl *= sqrt(1 + step / npts)

    result: EigResult = {"values": values, "left": vl , "right": vr, "bias": bias}
    return result



def evaluate_eigenfunction(
    eig_result: EigResult,
    which: Literal["left", "right"],
    data: ArrayLike,  # Feature matrix of the shape [num_evaluation_points, features] or kernel matrix of the shape [num_evaluation_points, num_training_points] 
    dual: bool = False,  # Whether the algorithm is dual or primal
):
    """Evaluate the stored left or right eigenfunction coefficients on ``data``.

    Primal right eigenfunctions are a plain product ``data @ eig_result["right"]``.
    Left eigenfunctions, and every eigenfunction of a dual fit, are evaluated from
    a kernel matrix and scaled by ``1 / sqrt(data.shape[1])``.
    """
    coefficients = eig_result[which]
    kernel_evaluation = dual or which == "left"
    if not kernel_evaluation:
        return data @ coefficients
    scale = data.shape[1] ** (-0.5)
    return np.linalg.multi_dot([scale * data, coefficients])


def modes(
    eig_result: EigResult,
    initial_conditions: ArrayLike,  # Feature matrix of the shape [num_initial_conditions, features] or kernel matrix of the shape [num_initial_conditions, num_training_points]
    obs_train: ArrayLike,  # Observable to be predicted evaluated on the trajectory data, shape [num_training_points, obs_features]
    dual: bool = False,  # Whether the algorithm is dual or primal
) -> ModeResult:
    """Compute the modes of a training observable for the given initial conditions.

    Each mode is built in two steps. First, ``obs_train`` is projected onto a left
    eigenfunction:
    ``sum_n conj(left[n, r]) * obs_train[n, ...] / sqrt(num_training_points)``.
    Second, that projection is multiplied by the right eigenfunction evaluated at
    each initial condition.

    Returns:
        ModeResult dict with keys

        - ``"decay_rates"``: ``-Re(values)``.
        - ``"frequencies"``: ``Im(values) / (2*pi)``.
        - ``"modes"``: array of shape ``[rank, num_initial_conditions, *obs_features]``.

    Raises:
        ValueError: if ``obs_train`` has more than 8 feature axes.
    """
    feature_letters = 'abcdefgh'  # einsum labels for feature axes ('n' and 'r' are reserved)
    obs_train = np.asarray(obs_train)
    initial_conditions = np.asarray(initial_conditions)

    evals = eig_result["values"]
    levecs = eig_result["left"]
    npts = obs_train.shape[0]
    num_feature_axes = obs_train.ndim - 1
    if num_feature_axes > len(feature_letters):
        raise ValueError(
            f"obs_train has {num_feature_axes} feature axes; at most {len(feature_letters)} are supported."
        )

    if initial_conditions.ndim == 1:
        initial_conditions = np.expand_dims(initial_conditions, axis=0)
    # [rank, num_initial_conditions]
    conditioning = evaluate_eigenfunction(eig_result, "right", initial_conditions, dual=dual).T

    feat = feature_letters[:num_feature_axes]
    # Projection of the observable on the left eigenfunctions: [rank, *features]
    modes_ = np.einsum(f'nr,n{feat}->r{feat}', levecs.conj(), obs_train) / sqrt(npts)
    modes_ = np.expand_dims(modes_, axis=1)
    # Give `conditioning` trailing singleton axes so it broadcasts over features.
    dims_to_add = modes_.ndim - conditioning.ndim
    if dims_to_add > 0:
        conditioning = np.expand_dims(conditioning, axis=tuple(range(-dims_to_add, 0)))
    mode_values = conditioning * modes_  # [rank, num_init_cond, *obs_features]
    result: ModeResult = {"decay_rates": -evals.real, "frequencies": evals.imag / (2 * np.pi), "modes": mode_values}
    return result

def kmd_filter(kmd: ModeResult,
           indices: Union[ArrayLike,None] = None,
           decay_rates: Union[Callable[[ArrayLike], ArrayLike],None] = None,
           frequencies: Union[Callable[[ArrayLike], ArrayLike],None] = None,
           eps: float = 1e-6,
           real:bool = True, # If the states and observables are real valued, return only positive frequencies
)-> ModeResult:
    """Select a subset of modes from a ModeResult.

    A mode is kept only if all of the following hold:

    - its norm (over all non-leading axes) exceeds ``eps``;
    - it is listed in ``indices``, when ``indices`` is given;
    - the ``decay_rates`` and ``frequencies`` predicates, when given, are truthy
      for it. The predicates receive the full arrays and are interpreted as
      boolean masks.

    With ``real=True``, modes with negative frequency are dropped and
    positive-frequency modes are doubled. This presumably accounts for their
    dropped complex-conjugate partner, since only the real part of a
    prediction is used.

    The input ``kmd`` is not modified.
    """
    all_modes = kmd["modes"]
    all_rates = kmd["decay_rates"]
    all_freqs = kmd["frequencies"]

    feature_axes = tuple(range(1, all_modes.ndim))
    mask = np.sqrt(np.sum(np.abs(all_modes * all_modes.conj()), axis=feature_axes)) > eps
    # Combine masks with logical AND on booleans: an integer/float predicate result
    # must never turn the mask into a fancy (positional) index.
    if indices is not None:
        mask &= np.isin(np.arange(all_modes.shape[0]), indices)
    if decay_rates is not None:
        mask &= np.asarray(decay_rates(all_rates), dtype=bool)
    if frequencies is not None:
        mask &= np.asarray(frequencies(all_freqs), dtype=bool)

    # Boolean indexing returns copies, so the in-place doubling below is local.
    rates = all_rates[mask]
    freqs = all_freqs[mask]
    selected_modes = all_modes[mask]
    if not real:
        result: ModeResult = {"decay_rates": rates, "frequencies": freqs, "modes": selected_modes}
        return result

    selected_modes[freqs > 0] *= 2
    keep = np.logical_not(freqs < 0)
    result: ModeResult = {"decay_rates": rates[keep], "frequencies": freqs[keep], "modes": selected_modes[keep]}
    return result



def predict(
    t: Union[float,ArrayLike],  # time in the same units as dt
    mode_result: ModeResult,
) -> ArrayLike: # shape [num_init_cond, features] or if num_t>1 [num_init_cond, num_time, features]
    """Reconstruct the observable at time(s) ``t`` from a ModeResult.

    The prediction is ``sum_r exp(lambda_r * t) * modes[r]``, with
    ``lambda_r = -decay_rate_r + 2*pi*1j*frequency_r``. Only the real part is
    returned.

    ``t`` may be a Python/NumPy scalar (int or float), a list, or a 1-D array.
    If only one time point or one initial condition is present, every singleton
    axis is squeezed, including singleton feature axes.

    Raises:
        ValueError: if the modes have more than 8 feature axes.
    """
    feature_letters = 'abcdefgh'  # einsum labels for feature axes ('r', 's', 'm' are reserved)
    t = np.atleast_1d(np.asarray(t))
    evals = -mode_result["decay_rates"] + 2 * np.pi * 1j * mode_result["frequencies"]
    to_evals = np.exp(evals[:, None] * t[None, :])  # [rank, time_steps]

    num_feature_axes = max(mode_result["modes"].ndim - 2, 0)
    if num_feature_axes > len(feature_letters):
        raise ValueError(
            f"modes have {num_feature_axes} feature axes; at most {len(feature_letters)} are supported."
        )
    feat = feature_letters[:num_feature_axes]
    predictions = np.einsum(f'rs,rm{feat}->ms{feat}', to_evals, mode_result["modes"])
    # If only one time point or one initial condition is requested, remove unnecessary dims
    if predictions.shape[0] == 1 or predictions.shape[1] == 1:
        predictions = np.squeeze(predictions)
    return predictions.real