import numpy as np 
import ot 
from sklearn.utils.extmath import randomized_svd

#####################################################################################################################################
#####################################################################################################################################
### LINEAR METRIC ###
#####################################################################################################################################
#####################################################################################################################################

###
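# Argument conventions: primal (linear) metrics take (Ds, Rs, Ls) = (eigenvalues, right eigenvectors, left eigenvectors);
# dual (kernel) metrics take (Ds, Xs, prs, pls) = (eigenvalues, samples, right/left coefficient matrices).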
###

def hs_metric(
        Ds,Rs,Ls,
        Dt,Rt,Lt,
        sampfreqs:int=1,
        sampfreqt:int=1):
    """Hilbert-Schmidt (Frobenius-norm) distance between two reconstructed operators.

    Each operator is rebuilt from its spectral factors as
    ``R @ diag(exp(D * sampfreq)) @ L^H`` and the Frobenius norm of the
    difference is returned.

    Args:
        Ds, Rs, Ls: Source eigenvalues, right eigenvectors and left eigenvectors.
        Dt, Rt, Lt: Target eigenvalues, right eigenvectors and left eigenvectors.
        sampfreqs (int): Multiplier applied to Ds before exponentiation.
        sampfreqt (int): Multiplier applied to Dt before exponentiation.

    Returns:
        float: ``||T_s - T_t||_F``.
    """
    def _rebuild(eigvals, right, left, rate):
        # Scale the rows of L^H by exp(eigval * rate) instead of forming diag(...).
        weights = np.exp(eigvals*rate).reshape(-1, 1)
        return right @ (weights * left.conj().T)

    difference = _rebuild(Ds, Rs, Ls, sampfreqs) - _rebuild(Dt, Rt, Lt, sampfreqt)
    return np.linalg.norm(difference, 'fro')

def operator_metric(
        Ds,Rs,Ls,
        Dt,Rt,Lt,
        sampfreqs:int=1,
        sampfreqt:int=1,
        exact:bool=False,
        n_iter:int=5,
        random_state:int=None
        ):
    """Spectral-norm (operator 2-norm) distance between two reconstructed operators.

    Each operator is rebuilt as ``R @ diag(exp(D * sampfreq)) @ L^H`` and the
    largest singular value of the difference ``C = T_s - T_t`` is returned.

    Args:
        Ds, Rs, Ls: Source eigenvalues, right eigenvectors and left eigenvectors.
        Dt, Rt, Lt: Target eigenvalues, right eigenvectors and left eigenvectors.
        sampfreqs (int): Multiplier applied to Ds before exponentiation.
        sampfreqt (int): Multiplier applied to Dt before exponentiation.
        exact (bool): If True, use ``np.linalg.norm(C, 2)`` (full SVD);
            otherwise estimate with sklearn's ``randomized_svd``.
        n_iter (int): Power iterations forwarded to ``randomized_svd``.
        random_state: Seed / state forwarded to ``randomized_svd``.

    Returns:
        float: ``||T_s - T_t||_2`` (exact or randomized estimate).
    """
    Ts = Rs @ (np.exp(Ds*sampfreqs).reshape(-1, 1) * Ls.conj().T)
    Tt = Rt @ (np.exp(Dt*sampfreqt).reshape(-1, 1) * Lt.conj().T)
    C = Ts - Tt
    if exact:
        return np.linalg.norm(C, 2)

    # randomized_svd only handles real matrices. Imaginary parts that are pure
    # round-off are dropped (np.real_if_close). A genuinely complex C is replaced
    # by its real embedding [[Re, -Im], [Im, Re]]. That embedding has exactly the
    # singular values of C (each repeated twice), so the estimate stays consistent
    # with the exact branch instead of silently measuring Re(C) only.
    M = np.real_if_close(C)
    if np.iscomplexobj(M):
        M = np.block([[M.real, -M.imag], [M.imag, M.real]])
    _, S, _ = randomized_svd(M, n_components=1, n_iter=n_iter, random_state=random_state)
    return S[0]

#####################################################################################################################################
#####################################################################################################################################
### OT METRIC ###
#####################################################################################################################################
#####################################################################################################################################

def unitary_grassman_matrix(Ps:np.ndarray,Pt:np.ndarray)->np.ndarray:
    """Cosine-similarity matrix between the columns of two bases.

    Every column of ``Ps`` and ``Pt`` is scaled to unit Euclidean norm. The
    pairwise inner products are then returned, conjugating the source side.
    A zero column is not guarded against and yields inf/NaN entries.

    Args:
        Ps (np.ndarray): Source vectors, shape (l, n_ds); l: ambient dimension, n_ds: number of source vectors.
        Pt (np.ndarray): Target vectors, shape (l, n_dt); n_dt: number of target vectors.

    Returns:
        np.ndarray: Shape (n_ds, n_dt), entry (i, j) = conj(ps_i)^T pt_j / (|ps_i| |pt_j|).
    """
    def _unit_columns(basis):
        return basis / np.linalg.norm(basis, axis=0, keepdims=True)

    source_rows = _unit_columns(Ps).conj().T
    target_cols = _unit_columns(Pt)
    return np.einsum('ij,jk->ik', source_rows, target_cols)

def eigenvector_chordal_cost_matrix(Rs:np.ndarray,Ls:np.ndarray,Rt:np.ndarray,Lt:np.ndarray)->np.ndarray:
    """Pairwise chordal cost between source and target eigenvector pairs.

    For source mode i and target mode j the cost is
    ``sqrt(1 - clip(Re(cr_ij * cl_ij), 0, 1))``. Here ``cr`` and ``cl`` are the
    column-normalized inner products computed by ``unitary_grassman_matrix`` for
    the right and left eigenvectors respectively.

    Args:
        Rs (np.ndarray): Source right eigenvectors, shape (l, r_s); l: ambient dimension, r_s: number of source modes.
        Ls (np.ndarray): Source left eigenvectors, shape (l, r_s).
        Rt (np.ndarray): Target right eigenvectors, shape (l, r_t); r_t: number of target modes.
        Lt (np.ndarray): Target left eigenvectors, shape (l, r_t).

    Returns:
        np.ndarray: Cost matrix of shape (r_s, r_t) with entries in [0, 1] (NaN propagates).
    """
    Cr = unitary_grassman_matrix(Rs,Rt)
    Cl = unitary_grassman_matrix(Ls,Lt)
    # Bounds are passed positionally: np.clip only accepts min=/max= keywords
    # from NumPy 2.1 on; older versions raise TypeError on them.
    C = np.sqrt(1-np.clip((Cr*Cl).real, 0, 1))
    return C

def eigenvalue_cost_matrix(Ds:np.ndarray,Dt:np.ndarray,real_scale:float=1.0,imag_scale:float=1.0)->np.ndarray:
    """Pairwise distance between rescaled source and target eigenvalues.

    Real parts are multiplied by ``real_scale`` and imaginary parts by
    ``imag_scale``. The modulus of every source-minus-target difference is
    then taken.

    Args:
        Ds (np.ndarray): Source eigenvalues, shape (r_s,).
        Dt (np.ndarray): Target eigenvalues, shape (r_t,).
        real_scale (float): Factor applied to the real parts.
        imag_scale (float): Factor applied to the imaginary parts.

    Returns:
        np.ndarray: Real matrix of shape (r_s, r_t).
    """
    def _rescale(eigvals):
        return eigvals.real*real_scale + 1j* eigvals.imag*imag_scale

    src = _rescale(Ds)
    tgt = _rescale(Dt)
    return np.abs(src[:, np.newaxis] - tgt[np.newaxis, :])

def ChordalCostFunction(
    real_scale:float=1.0,
    imag_scale:float=1.0,
    alpha:float=0.5,
    p:int = 2):
    """Build a cost function that mixes eigenvalue and eigenvector (chordal) costs.

    The returned callable computes, for every source/target mode pair,
    ``(alpha * eigenvalue_cost + (1 - alpha) * chordal_cost) ** p``.

    Args:
        real_scale (float): Scale applied to the real part of the eigenvalues.
        imag_scale (float): Scale applied to the imaginary part of the eigenvalues.
        alpha (float): Weight of the eigenvalue cost; ``1 - alpha`` weights the eigenvector cost.
        p (int): Exponent applied to the mixed cost.

    Returns:
        callable: ``cost(Ds, Rs, Ls, Dt, Rt, Lt) -> np.ndarray`` of shape (len(Ds), len(Dt)).
    """
    def cost_function(
            Ds:np.ndarray,
            Rs:np.ndarray,
            Ls:np.ndarray,
            Dt:np.ndarray,
            Rt:np.ndarray,
            Lt:np.ndarray) -> np.ndarray:
        """Mixed chordal cost between two spectral decompositions.

        Args:
            Ds, Rs, Ls: Source eigenvalues, right eigenvectors, left eigenvectors.
            Dt, Rt, Lt: Target eigenvalues, right eigenvectors, left eigenvectors.

        Returns:
            np.ndarray: Cost matrix already raised to the power ``p``.
        """
        value_cost = eigenvalue_cost_matrix(Ds, Dt, real_scale=real_scale, imag_scale=imag_scale)
        vector_cost = eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt)
        mixed = alpha * value_cost + (1 - alpha) * vector_cost
        return mixed**p

    return cost_function

def ot_plan(C:np.ndarray,Ws:np.ndarray=None,Wt:np.ndarray=None)->np.ndarray:
    """Exact optimal-transport plan for a cost matrix, solved with POT's ``ot.emd``.

    Missing marginals default to uniform weights over the rows / columns of ``C``.

    Args:
        C (np.ndarray): Cost matrix, shape (n, m).
        Ws (np.ndarray, optional): Source masses, shape (n,). Uniform if None.
        Wt (np.ndarray, optional): Target masses, shape (m,). Uniform if None.

    Returns:
        np.ndarray: Transport plan, shape (n, m).
    """
    source_mass = Ws if Ws is not None else np.ones(C.shape[0]) / C.shape[0]
    target_mass = Wt if Wt is not None else np.ones(C.shape[1]) / C.shape[1]
    return ot.emd(source_mass, target_mass, C)

def ot_score(C:np.ndarray,P:np.ndarray, p:int=2)->float:
    """Transport cost ``(sum_ij C_ij * P_ij) ** (1 / p)``.

    ``C`` is expected to already hold costs raised to the power ``p``; this
    function only takes the p-th root of the total transported cost.

    Args:
        C (np.ndarray): Cost matrix, shape (n, m).
        P (np.ndarray): Transport plan, shape (n, m).
        p (int, optional): Root applied to the total cost. Defaults to 2.

    Returns:
        float: The OT distance.
    """
    total_cost = (C * P).sum()
    return total_cost ** (1 / p)

def chordal_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt, 
        real_scale: float = 1.0,
        imag_scale: float = 1.0,
        alpha: float = 0.5,
        p: int = 2):
    """OT distance between two spectral decompositions under the mixed chordal cost.

    Builds the cost with ``ChordalCostFunction``, solves uniform-marginal exact OT
    with ``ot_plan`` and returns ``ot_score`` (p-th root of the transported cost).
    """
    cost = ChordalCostFunction(real_scale, imag_scale, alpha, p)(Ds, Rs, Ls, Dt, Rt, Lt)
    return ot_score(cost, ot_plan(cost), p)

def eigenvalue_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt, 
        sampfreqs:int=1,
        sampfreqt:int=1,
        p: int = 2):
    """OT distance between the exponentiated eigenvalues only.

    The cost is ``|exp(Ds_i * sampfreqs) - exp(Dt_j * sampfreqt)| ** p`` with
    uniform marginals. The eigenvector arguments are accepted but unused.
    """
    lam_s = np.exp(Ds*sampfreqs)
    lam_t = np.exp(Dt*sampfreqt)
    cost = np.abs(lam_s[:,None] - lam_t[None,:])**p
    plan = ot_plan(cost)
    return ot_score(cost, plan, p)

def subspace_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt, 
        sampfreqs:int=1,
        sampfreqt:int=1,
        p: int = 2):
    """OT distance between eigenvector pairs, with masses given by mode magnitude.

    Cost: the chordal eigenvector cost only (``alpha=0``) raised to ``p``.
    Masses: ``|exp(D * sampfreq)|`` normalised to sum to one on each side.
    """
    def _mode_mass(eigvals, rate):
        magnitude = np.abs(np.exp(eigvals*rate))
        return magnitude / np.sum(magnitude)

    cost = ChordalCostFunction(real_scale=1.0, imag_scale=1.0, alpha=0.0, p=p)(Ds, Rs, Ls, Dt, Rt, Lt)
    plan = ot_plan(cost, _mode_mass(Ds, sampfreqs), _mode_mass(Dt, sampfreqt))
    return ot_score(cost, plan, p)

#####################################################################################################################################
#####################################################################################################################################
### RELATED WORK METRIC ###
#####################################################################################################################################
#####################################################################################################################################
def martin_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt,
        sampfreqs:int=1,
        sampfreqt:int=1):
    """Pole-based (Martin-style) distance between two linear systems.

    Only the eigenvalues are used. The poles are ``s = exp(Ds * sampfreqs)``
    and ``t = exp(Dt * sampfreqt)``, and

        d^2 = log( prod_{i,j} |1 - conj(s_i) t_j|^2
                   / (|prod_{i,j} (1 - conj(s_i) s_j)| * |prod_{i,j} (1 - conj(t_i) t_j)|) )

    is clipped below at 0. The eigenvector arguments (Rs, Ls, Rt, Lt) are
    accepted for a uniform metric signature but ignored.

    The products are evaluated as sums of logarithms. Products over n*m pole
    pairs would otherwise overflow to inf or underflow to 0 for moderately many
    poles and make the ratio meaningless.

    Returns:
        float: ``d``. A pole of modulus exactly 1 makes a denominator factor
        vanish and typically yields ``inf``.
    """
    s = np.exp(np.asarray(Ds) * sampfreqs)
    t = np.exp(np.asarray(Dt) * sampfreqt)

    def _log_abs_prod(a, b):
        # log prod_{i,j} |1 - conj(a_i) b_j|, computed as a sum of logs.
        return np.sum(np.log(np.abs(1 - a.conj()[:, None] * b[None, :])))

    log_ratio = 2 * _log_abs_prod(s, t) - _log_abs_prod(s, s) - _log_abs_prod(t, t)
    # Clip tiny negative round-off so sqrt stays real.
    return np.sqrt(np.clip(log_ratio, 0, None))

def solve_equation_stable(A, B, C,scale:float=1.0):
    """Solve the linear matrix equation ``X - scale * A @ X @ B = C`` for X.

    Uses the column-major vectorization identity
    ``vec(A X B) = kron(B.T, A) vec(X)`` to build the dense system
    ``(I - scale * kron(B.T, A)) vec(X) = vec(C)``. The system is solved
    directly, without forming an inverse.

    Args:
        A (np.ndarray): Left coefficient, shape (m, m).
        B (np.ndarray): Right coefficient, shape (n, n).
        C (np.ndarray): Right-hand side, shape (m, n).
        scale (float): Multiplier on the ``A X B`` term.

    Returns:
        np.ndarray: Solution X, shape (m, n).

    Raises:
        numpy.linalg.LinAlgError: If the (m*n, m*n) system matrix is singular.

    Note:
        The system matrix is dense with (m*n)^2 entries, so memory grows quickly.
    """
    m, n = C.shape
    K = np.kron(B.T, scale*A)
    M = np.eye(m * n) - K
    c = C.reshape(-1, order='F')
    # numpy.linalg.solve takes only (a, b). The former ``assume_a`` keyword
    # belongs to scipy.linalg.solve and made every call raise TypeError.
    x = np.linalg.solve(M, c)
    return x.reshape(m, n, order='F')

def det_binet_cauchy_kernel(
        Ds, Rs, Ls, 
        Dt, Rt, Lt,
        sampfreqs:int=1,
        sampfreqt:int=1,
        scale:float=1.0):
    """Determinant-based Binet-Cauchy-style kernel between two reconstructed operators.

    Both operators are rebuilt as ``R @ diag(exp(D * sampfreq)) @ L^H``. Their
    real parts are passed to ``solve_equation_stable`` to get P with
    ``P - scale * Re(T_s) @ P @ Re(T_t) = I``, and ``det(P) ** 2`` is returned.
    """
    def _operator(eigvals, right, left, rate):
        return right @ (np.exp(eigvals*rate).reshape(-1, 1) * left.conj().T)

    op_s = _operator(Ds, Rs, Ls, sampfreqs)
    op_t = _operator(Dt, Rt, Lt, sampfreqt)
    identity = np.eye(op_s.shape[1])
    solution = solve_equation_stable(op_s.real, op_t.real, identity, scale=scale)
    return np.linalg.det(solution)**2

def det_binet_cauchy_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt,
        sampfreqs:int=1,
        sampfreqt:int=1,
        scale:float=1.0):
    """Kernel-induced distance from ``det_binet_cauchy_kernel``.

    Returns ``sqrt(log k(s, s) + log k(t, t) - 2 log k(s, t))``. No clipping is
    applied, so a negative argument (e.g. round-off) yields NaN.

    Args:
        Ds, Rs, Ls: Source eigenvalues, right eigenvectors and left eigenvectors.
        Dt, Rt, Lt: Target eigenvalues, right eigenvectors and left eigenvectors.
        sampfreqs (int): Multiplier applied to Ds before exponentiation.
        sampfreqt (int): Multiplier applied to Dt before exponentiation.
        scale (float): Forwarded to the kernel.
    """
    # Each self-similarity term samples both of its arguments at that system's
    # own rate. Previously the source was paired with sampfreqt and the target
    # with sampfreqs, which was wrong whenever the two rates differ.
    vals = det_binet_cauchy_kernel(Ds, Rs, Ls, Ds, Rs, Ls, sampfreqs, sampfreqs, scale=scale)
    valt = det_binet_cauchy_kernel(Dt, Rt, Lt, Dt, Rt, Lt, sampfreqt, sampfreqt, scale=scale)
    valst = det_binet_cauchy_kernel(Ds, Rs, Ls, Dt, Rt, Lt, sampfreqs, sampfreqt, scale=scale)
    return np.sqrt(np.log(vals) + np.log(valt) - 2*np.log(valst))


#####################################################################################################################################
#####################################################################################################################################
### Kernel case ###
#####################################################################################################################################
#####################################################################################################################################
   
def RBF(sigma:float=1.0)->callable:
    """Build a Gaussian (RBF) kernel function with bandwidth ``sigma``.

    The returned function evaluates ``k(x, y) = exp(-||x - y||^2 / sigma^2)``
    for every pair of rows. Note there is no factor 2 in the denominator.

    Args:
        sigma (float): Bandwidth of the kernel; must be non-zero.

    Returns:
        callable: ``f(X, Y) -> np.ndarray`` computing the kernel matrix between two point sets.
    """
    def f(X:np.ndarray,Y:np.ndarray) -> np.ndarray:
        """Compute the RBF kernel matrix between two sets of points.

        Args:
            X (np.ndarray): First set of points, shape (n, d); n points of dimension d.
            Y (np.ndarray): Second set of points, shape (m, d).
                Array-likes (e.g. nested lists) are converted with ``np.asarray``.

        Returns:
            np.ndarray: Kernel matrix of shape (n, m).
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        # Broadcast to (n, m, d) differences; memory is O(n * m * d).
        sq_dist = np.sum((X[:, None, :] - Y[None, :, :])**2, axis=-1)
        return np.exp(-sq_dist/sigma**2)
    return f

def kernel_unitary_grassman_matrix(
        Xs:np.ndarray,
        ps:np.ndarray,
        Xt:np.ndarray, 
        pt:np.ndarray,
        K:callable) -> np.ndarray: 
    """Normalized kernel inner products between coefficient-defined vectors.

    Column a of ``ps`` and column b of ``pt`` are compared through the kernel ``K``.
    Entry (a, b) is ``ps_a^H K(Xs, Xt) pt_b`` divided by
    ``sqrt(Re(ps_a^H K(Xs, Xs) ps_a)) * sqrt(Re(pt_b^H K(Xt, Xt) pt_b))``.

    Args:
        Xs (np.ndarray): Source samples passed to ``K``.
        ps (np.ndarray): Source coefficients, shape (n_s, r_s) when ``K(Xs, Xs)`` is (n_s, n_s).
        Xt (np.ndarray): Target samples passed to ``K``.
        pt (np.ndarray): Target coefficients, shape (n_t, r_t).
        K (callable): Kernel returning a Gram matrix ``K(A, B)``.

    Returns:
        np.ndarray: Shape (r_s, r_t).
    """
    def _kernel_norm(coef, gram):
        weighted = coef.conj() * (gram @ coef)
        return np.sqrt(weighted.sum(axis=0, keepdims=True).real)

    norm_s = _kernel_norm(ps, K(Xs, Xs))
    norm_t = _kernel_norm(pt, K(Xt, Xt))
    cross = (ps.conj().T @ K(Xs, Xt)) @ pt
    return cross / (norm_s.T * norm_t)

def _kernel_unitary_grassman_matrix(
        ps:np.ndarray, 
        pt:np.ndarray,
        Ks:np.ndarray,
        kt:np.ndarray,
        Kst:np.ndarray) -> np.ndarray: 
    
    pKs = np.sqrt(np.sum(ps.conj() * (Ks @ ps),axis=0,keepdims=True).real)
    pKt = np.sqrt(np.sum(pt.conj() * (kt @ pt),axis=0,keepdims=True).real)
    pKst = ps.conj().T @ Kst @ pt
    return pKst/ (pKs.T*pKt)

def kernel_eigenvector_chordal_cost_matrix(
        Xs:np.ndarray,
        prs:np.ndarray,
        pls:np.ndarray, 
        Xt:np.ndarray,
        prt:np.ndarray, 
        plt:np.ndarray,
        K:callable):
    """Kernel (dual) counterpart of the eigenvector chordal cost matrix.

    Right/left eigenvectors are given by coefficient matrices (prs/pls on the
    samples Xs, prt/plt on Xt). Their inner products are evaluated through the
    kernel Gram matrices. The cost is
    ``sqrt(1 - clip(Re(cr * cl), 0, 1))`` elementwise.

    Args:
        Xs, Xt (np.ndarray): Source / target samples passed to ``K``.
        prs, pls (np.ndarray): Source right / left coefficient matrices.
        prt, plt (np.ndarray): Target right / left coefficient matrices.
        K (callable): Kernel returning a Gram matrix ``K(A, B)``.

    Returns:
        np.ndarray: Cost matrix of shape (r_s, r_t) with entries in [0, 1] (NaN propagates).
    """
    # Gram matrices are evaluated once and shared by the right and left overlaps.
    Ks = K(Xs,Xs)
    Kt = K(Xt,Xt)
    Kst = K(Xs,Xt)

    Cr = _kernel_unitary_grassman_matrix(prs,prt,Ks,Kt,Kst)
    Cl = _kernel_unitary_grassman_matrix(pls,plt,Ks,Kt,Kst)
    # Bounds are passed positionally: np.clip only accepts min=/max= keywords
    # from NumPy 2.1 on; older versions raise TypeError on them.
    C = np.sqrt(1-np.clip((Cr*Cl).real, 0, 1))
    return C

def eigenvalue_cost_matrix(Ds:np.ndarray,Dt:np.ndarray,real_scale:float=1.0,imag_scale:float=1.0)->np.ndarray:
    """Pairwise modulus distance between rescaled source and target eigenvalues.

    NOTE(review): this redefines the identical ``eigenvalue_cost_matrix``
    declared earlier in this module. Being later, this definition is the one
    in effect after import. Consider deleting one of the two copies.

    Args:
        Ds (np.ndarray): Source eigenvalues, shape (r_s,).
        Dt (np.ndarray): Target eigenvalues, shape (r_t,).
        real_scale (float): Factor applied to the real parts.
        imag_scale (float): Factor applied to the imaginary parts.

    Returns:
        np.ndarray: ``|scaled(Ds_i) - scaled(Dt_j)|``, shape (r_s, r_t).
    """
    scaled_source = Ds.real*real_scale + 1j* Ds.imag*imag_scale
    scaled_target = Dt.real*real_scale + 1j* Dt.imag*imag_scale
    gaps = np.expand_dims(scaled_source, 1) - np.expand_dims(scaled_target, 0)
    return np.abs(gaps)

def kernel_chordal_cost_matrix(
    Ds:np.ndarray,
    Xs:np.ndarray,
    prs:np.ndarray, 
    pls:np.ndarray,
    Dt:np.ndarray,
    Xt:np.ndarray,
    prt:np.ndarray,
    plt:np.ndarray,
    K:callable,
    real_scale:float=1.0,
    imag_scale:float=1.0,
    alpha:float=0.5,
    p:int = 2):
    """Mixed eigenvalue / kernel-chordal cost matrix.

    Returns ``(alpha * eigenvalue_cost + (1 - alpha) * kernel_chordal_cost) ** p``
    where the eigenvector part is evaluated through the kernel ``K``.
    """
    spectral = eigenvalue_cost_matrix(Ds, Dt, real_scale=real_scale, imag_scale=imag_scale)
    geometric = kernel_eigenvector_chordal_cost_matrix(Xs, prs, pls, Xt, prt, plt, K)
    blended = alpha * spectral + (1 - alpha) * geometric
    return blended**p

def KernelChordalCostFunction(
    K:callable,
    real_scale:float=1.0,
    imag_scale:float=1.0,
    alpha:float=0.5,
    p:int = 2):
    """Kernel (dual) analogue of ``ChordalCostFunction``.

    Binds the kernel and hyper-parameters. The returned callable computes the
    same matrix as ``kernel_chordal_cost_matrix`` with those settings.

    Args:
        K (callable): Kernel returning a Gram matrix ``K(A, B)``.
        real_scale (float): Scale applied to the real part of the eigenvalues.
        imag_scale (float): Scale applied to the imaginary part of the eigenvalues.
        alpha (float): Weight of the eigenvalue cost; ``1 - alpha`` weights the eigenvector cost.
        p (int): Exponent applied to the mixed cost.

    Returns:
        callable: ``cost(Ds, Xs, prs, pls, Dt, Xt, prt, plt) -> np.ndarray``.
    """
    def cost_function(
        Ds:np.ndarray,
        Xs:np.ndarray,
        prs:np.ndarray, 
        pls:np.ndarray,
        Dt:np.ndarray,
        Xt:np.ndarray,
        prt:np.ndarray,
        plt:np.ndarray) -> np.ndarray:
        return kernel_chordal_cost_matrix(
            Ds, Xs, prs, pls,
            Dt, Xt, prt, plt,
            K,
            real_scale=real_scale,
            imag_scale=imag_scale,
            alpha=alpha,
            p=p)
    return cost_function


def kernel_chordal_metric(
        Ds, Xs, prs, pls, 
        Dt, Xt, prt, plt,
        K:callable,
        real_scale: float = 1.0,
        imag_scale: float = 1.0,
        alpha: float = 0.5,
        p: int = 2):
    """OT distance between two kernel (dual) spectral decompositions.

    Builds the mixed kernel-chordal cost, solves uniform-marginal exact OT with
    ``ot_plan`` and returns ``ot_score`` (p-th root of the transported cost).
    """
    cost = KernelChordalCostFunction(K, real_scale, imag_scale, alpha, p)(Ds, Xs, prs, pls, Dt, Xt, prt, plt)
    return ot_score(cost, ot_plan(cost), p)

def KernelChordalMetric(
    K:callable,
    real_scale:float=1.0,
    imag_scale:float=1.0,
    alpha:float=0.5,
    p:int = 2):
    """Bind a kernel and cost hyper-parameters into a reusable metric callable.

    Returns:
        callable: ``metric(Ds, Xs, prs, pls, Dt, Xt, prt, plt) -> float`` that
        forwards to ``kernel_chordal_metric`` with the bound settings.
    """
    settings = dict(real_scale=real_scale, imag_scale=imag_scale, alpha=alpha, p=p)

    def metric(
        Ds:np.ndarray,
        Xs:np.ndarray,
        prs:np.ndarray, 
        pls:np.ndarray,
        Dt:np.ndarray,
        Xt:np.ndarray,
        prt:np.ndarray,
        plt:np.ndarray) -> float:
        return kernel_chordal_metric(Ds, Xs, prs, pls, Dt, Xt, prt, plt, K, **settings)
    return metric