import numpy as np 
import torch
import ot 

#####################################################################################################################################
#####################################################################################################################################
### LINEAR METRIC ###
#####################################################################################################################################
#####################################################################################################################################

###
# Settings : (Ds,Rs,Ls) if primal, (Ds,Xs,prs,pls) if dual
###

def HS_metric(
        Ds,Rs,Ls,
        Dt,Rt,Lt,
        sampfreqs:int=1,
        sampfreqt:int=1):
    Ts = Rs.conj()[None, :, :] * Ls[:, None, :] * np.exp(Ds*sampfreqs)[:, None, None]
    Tt = Rt.conj()[None, :, :] * Lt[:, None, :] * np.exp(Dt*sampfreqt)[:, None, None]
    C = Ts - Tt
    return torch.linalg.norm(C, 'fro')

def operator_metric(
        Ds,Rs,Ls,
        Dt,Rt,Lt,
        sampfreqs:int=1,
        sampfreqt:int=1):
    Ts = Rs.conj()[None, :, :] * Ls[:, None, :] * np.exp(Ds*sampfreqs)[:, None, None]
    Tt = Rt.conj()[None, :, :] * Lt[:, None, :] * np.exp(Dt*sampfreqt)[:, None, None]
    C = Ts - Tt
    return torch.linalg.svdvals(C).max()

#####################################################################################################################################
#####################################################################################################################################
### OT METRIC ###
#####################################################################################################################################
#####################################################################################################################################

from kooporch import ot_score, ot_plan, ChordalCostFunction

def chordal_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt, 
        real_scale: float = 1.0,
        imag_scale: float = 1.0,
        alpha: float = 0.5,
        p: int = 2):
    cost_fn = ChordalCostFunction(real_scale,imag_scale,alpha,p)
    C = cost_fn(Ds, Rs, Ls, Dt, Rt, Lt)
    P = ot_plan(C)
    return ot_score(C, P, p)

def eigenvalue_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt, 
        sampfreqs:int=1,
        sampfreqt:int=1,
        p: int = 2):
    C = torch.abs(torch.exp(Ds*sampfreqs)[:,None] - torch.exp(Dt*sampfreqt)[None,:])**p
    P = ot_plan(C)
    return ot_score(C, P, p)  

def subspace_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt, 
        p: int = 2):
    cst_fn = ChordalCostFunction(real_scale=1.0,imag_scale=1.0,alpha=0.0,p=p)
    C = cst_fn(Ds, Rs, Ls, Dt, Rt, Lt)
    Ws = torch.abs(torch.exp(Ds))/torch.sum(torch.abs(torch.exp(Ds)))
    Wt = torch.abs(torch.exp(Dt))/torch.sum(torch.abs(torch.exp(Dt)))
    P = ot_plan(C,Ws,Wt)
    return ot_score(C, P, p)  

#####################################################################################################################################
#####################################################################################################################################
### RELATED WORK METRIC ###
#####################################################################################################################################
#####################################################################################################################################
def martin_distance(
        Ds, Rs, Ls, 
        Dt, Rt, Lt,
        sampfreqs:int=1,
        sampfreqt:int=1):
    ds = torch.exp(Ds*sampfreqs)
    dt = torch.exp(Dt*sampfreqt)
    num =torch.prod(torch.abs(1-ds.conj()[:,None]*dt[None,:])**2).real
    denoms = torch.prod(1-ds.conj()[:,None]*ds[None,:]).real
    denomt = torch.prod(1-dt.conj()[:,None]*dt[None,:]).real
    return torch.sqrt(torch.log(num/(denoms*denomt)))


def solve_equation_stable(A, B, C,scale:float=1.0):
    # solve X - scale*AXB = C for X
    m, n = C.shape
    # Build system: (I - kron(B.T, A)) vec(X) = vec(C)
    K = torch.kron(B.T, scale*A)
    M = torch.eye(m * n) - K
    c = C.reshape(-1, order='F')

    # Solve without inverting
    x = torch.linalg.solve(M, c, assume_a='gen')
    X = x.reshape(m, n, order='F')
    return X

def det_binet_cauchy_kernel(
        Ds, Rs, Ls, 
        Dt, Rt, Lt,
        sampfreqs:int=1,
        sampfreqt:int=1,
        scale:float=1.0):
    ds = torch.exp(Ds*sampfreqs)
    dt = torch.exp(Dt*sampfreqt)
    Ts = Rs.conj()[None, :, :] * Ls[:, None, :] * ds[:, None, None]
    Tt = Rt.conj()[None, :, :] * Lt[:, None, :] * dt[:, None, None]
    Id = torch.eye(Ts.shape[1])
    P = solve_equation_stable(Ts.real, Tt.real, Id, scale=scale)
    return torch.det(P)**2

def det_binet_cauchy_metric(
        Ds, Rs, Ls, 
        Dt, Rt, Lt,
        sampfreqs:int=1,
        sampfreqt:int=1,
        scale:float=1.0):
    vals = det_binet_cauchy_kernel(Ds, Rs, Ls, Ds, Rs, Ls, sampfreqs, sampfreqt, scale=scale)
    valt = det_binet_cauchy_kernel(Dt, Rt, Lt, Dt, Rt, Lt, sampfreqs, sampfreqt, scale=scale)
    valst = det_binet_cauchy_kernel(Ds, Rs, Ls, Dt, Rt, Lt, sampfreqs, sampfreqt, scale=scale)
    return torch.sqrt(torch.log(vals) + torch.log(valt) - 2*torch.log(valst))


    


