import torch
import itertools

# Gaussian Hellinger-Kantorovich (GHK)

# matrix sqrt and batched variant 
def _sqrtm(x):
    """Return a matrix square root of ``x`` via a general eigendecomposition.

    ``torch.linalg.eig`` is used rather than ``eigh`` because some of the
    matrices passed in have positive eigenvalues but are not symmetric.
    The result is complex-valued, since ``eig`` returns complex factors.
    """
    eigvals, eigvecs = torch.linalg.eig(x)
    # Square-root the spectrum, then map back with the (pseudo-)inverse basis.
    root_spectrum = torch.diag(eigvals**0.5)
    return eigvecs @ root_spectrum @ torch.linalg.pinv(eigvecs)
# Batched variant of _sqrtm: vmap maps it over dim 0 of the input and stacks
# the per-matrix results along dim 0 (the vmap defaults, spelled out).
_bsqrtm = torch.vmap(_sqrtm, in_dims=0, out_dims=0)

def _sqrtm_psd(x):
    """Return the square root of a symmetric/Hermitian matrix ``x``.

    The matrix is decomposed with ``torch.linalg.eigh``; eigenvalues below
    zero are clamped to zero before the square root is taken, so slightly
    indefinite inputs still give a real, finite result.

    ``eigh`` returns an orthonormal eigenvector matrix, so its inverse is
    simply its conjugate transpose -- no SVD-based ``pinv`` is required.
    Scaling the eigenvector columns by broadcasting (instead of building a
    diagonal matrix) also makes the function work on inputs with leading
    batch dimensions.

    Args:
        x: tensor of shape ``(..., n, n)``; only one triangle is read by
            ``eigh``, so ``x`` is expected to be symmetric/Hermitian.

    Returns:
        Tensor of the same shape as ``x``.
    """
    eigvals, eigvecs = torch.linalg.eigh(x)
    root_vals = torch.clamp_min(eigvals, 0)**0.5
    # (V * s) scales column j of V by s_j, i.e. V @ diag(s).
    scaled = eigvecs * root_vals.unsqueeze(-2)
    return scaled @ eigvecs.transpose(-2, -1).conj()
# Batched variant of _sqrtm_psd: vmap maps it over dim 0 of the input and
# stacks the per-matrix results along dim 0 (the vmap defaults, spelled out).
_bsqrtm_psd = torch.vmap(_sqrtm_psd, in_dims=0, out_dims=0)

def check_psd(x):
    """Check that ``x`` has strictly positive eigenvalues and is symmetric.

    The eigenvalue test uses ``torch.linalg.eigvalsh`` and requires the
    smallest eigenvalue to be ``> 0``; symmetry is accepted when the largest
    absolute entry of ``x - x.T`` is below ``1e-7``.

    Returns:
        A boolean tensor (not a Python ``bool``): the eigenvalue test result
        when it fails, otherwise the symmetry test result.
    """
    eigs_positive = torch.linalg.eigvalsh(x).min() > 0
    if not eigs_positive:
        return eigs_positive
    asymmetry = (x - x.T).abs().max()
    return asymmetry < 1e-7

def entropicGHK(m_a, a, A, m_b, b, B, sigma=1.0, gamma=1.0):
    """Entropic Gaussian Hellinger-Kantorovich quantity between two Gaussians.

    Args:
        m_a, m_b: scalar weights of the two inputs (presumably total masses
            -- confirm against callers).
        a, b: 1-D tensors of length ``d`` (presumably the means).
        A, B: ``(d, d)`` tensors (presumably the covariances); ``det(A @ B)``
            appears in a denominator, so they must be non-singular.
        sigma: scalar entering as ``sigma**2`` (presumably the entropic
            regularisation scale -- confirm).
        gamma: scalar weight (presumably the marginal-relaxation strength
            -- confirm).

    Returns:
        Real part of
        ``gamma*(m_a + m_b) + 2*sigma**2*m_a*m_b - 2*(sigma**2 + gamma)*m_pi``.
        With 1-D ``a``/``b`` and scalar weights the result has shape ``(1,)``
        because the quadratic form in ``a - b`` keeps a trailing axis.
    """
    tau = gamma / (2 * sigma**2 + gamma)
    lamda = sigma**2 + gamma / 2
    d = len(a)
    # Build the identity with A's dtype and device. A default-dtype CPU
    # identity breaks for inputs on another device and rounds scalars such as
    # `lamda` to float32 before they are promoted back for float64 inputs.
    Id = torch.eye(d, dtype=A.dtype, device=A.device)
    X = A + B + lamda * Id
    Atilde = gamma / 2 * (Id - lamda * torch.linalg.pinv(A + lamda * Id))
    Btilde = gamma / 2 * (Id - lamda * torch.linalg.pinv(B + lamda * Id))
    AtBt = Atilde @ Btilde  # reused three times below
    # _sqrtm returns a complex tensor; only the real part is returned at the end.
    C = _sqrtm((1 / tau) * AtBt + (sigma**4 / 4) * Id) - (sigma**2 / 2) * Id
    prefactor = sigma ** ((d * sigma**2) / (gamma + sigma**2))
    det_ratio = (torch.linalg.det(AtBt)**tau) / torch.linalg.det(A @ B)
    _u = (m_a * m_b * torch.linalg.det(C) * det_ratio**0.5) ** (1 / (1 + tau))
    diff = a - b
    quad = diff @ torch.linalg.pinv(X) @ diff[:, None]
    _v = torch.exp(-quad / (2 * (1 + tau))) / torch.linalg.det(C - (2 / gamma) * AtBt)**0.5
    m_pi = prefactor * _u * _v
    dist = gamma * (m_a + m_b) + 2 * sigma**2 * (m_a * m_b) - 2 * (sigma**2 + gamma) * m_pi
    return torch.real(dist)

def GHK(m_a, a, A, m_b, b, B, gamma=1.0):
    """Gaussian Hellinger-Kantorovich quantity between two Gaussians.

    Args:
        m_a, m_b: scalar weights of the two inputs (presumably total masses
            -- confirm against callers).
        a, b: 1-D tensors of length ``d`` (presumably the means).
        A, B: ``(d, d)`` tensors (presumably the covariances).
        gamma: scalar weight; enters as ``gamma / 2`` and ``2 / gamma``.

    Returns:
        Real part of
        ``gamma * (m_a + m_b - 2*sqrt(m_a*m_b*exp(-q/2) / det(J)))`` where
        ``q`` is the quadratic form of ``a - b`` in ``pinv(A + B + gamma/2*I)``.
        With 1-D ``a``/``b`` and scalar weights the result has shape ``(1,)``.
    """
    d = len(a)
    # Build the identity with A's dtype and device. A default-dtype CPU
    # identity breaks for inputs on another device and rounds `gamma / 2` to
    # float32 before it is promoted back for float64 inputs.
    Id = torch.eye(d, dtype=A.dtype, device=A.device)
    X = A + B + (gamma / 2) * Id
    Ahat = (2 / gamma) * A + Id
    Bhat = (2 / gamma) * B + Id
    # _sqrtm returns complex tensors; only the real part is returned at the end.
    inner = _sqrtm(A @ torch.linalg.pinv(Ahat) @ torch.linalg.pinv(Bhat) @ B)
    J = _sqrtm(torch.matmul(Ahat, Bhat)) @ (Id - (2 / gamma) * inner)
    diff = a - b
    quad = diff @ torch.linalg.pinv(X) @ diff[:, None]
    overlap = (m_a * m_b * torch.exp(-quad / 2) / torch.linalg.det(J))**0.5
    dist = gamma * (m_a + m_b - 2 * overlap)
    return torch.real(dist)

def _eye_like(tensor):
    """Return an identity-patterned tensor with ``tensor``'s shape, dtype and device.

    The output buffer is allocated with ``torch.empty_like`` and filled by
    ``torch.eye``, whose size arguments are the dimensions of ``tensor``.
    """
    dims = tensor.size()
    buffer = torch.empty_like(tensor)
    return torch.eye(*dims, out=buffer)

def diff_GHK(m_a, a, A, dg_a, da, dA, gamma=1.0):
    """Quadratic form in the perturbation ``(dg_a, da, dA)`` at ``(m_a, a, A)``.

    Computes ``m_a / q * (_a + _b + _c)`` with ``q = 2 / gamma`` and

    * ``_a = da^T pinv(2*A + I/q) da``,
    * ``_b = 0.5 * dg_a**2``,
    * ``_c = q * sum_ij w_ij * tr(W_i dA W_j dA)``, where ``W_i = v_i v_i^T``
      are the eigen-projectors of ``A`` (from ``eigh``) and
      ``w_ij = U_i*Uq_j / (Uq_j*U_i + Uq_i*U_j)**2``, ``Uq = 1 + q*U``.

    Since ``tr(W_i dA W_j dA) = M_ij * M_ji`` with ``M = V^T dA V``, ``_c`` is
    evaluated with a single change of basis instead of a Python double loop
    over d**2 products of d x d matrices.

    Args:
        m_a: scalar weight.
        a: unused by the computation; kept for interface compatibility.
        A: symmetric ``(d, d)`` tensor (decomposed with ``eigh``).
        dg_a: scalar perturbation entering ``_b``.
        da: 1-D tensor of length ``d``.
        dA: ``(d, d)`` tensor.
        gamma: scalar weight.

    Returns:
        Tensor of shape ``(1,)`` (the quadratic form in ``da`` keeps a
        trailing axis).
    """
    q = 2 / gamma
    U, V = torch.linalg.eigh(A)
    Uq = 1 + q * U
    Id = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    X = 2 * A + (1 / q) * Id
    _a = da @ torch.linalg.pinv(X) @ da[:, None]
    _b = 0.5 * dg_a**2
    # dA expressed in the eigenbasis of A: M[i, j] = v_i^T dA v_j.
    M = V.T @ dA @ V
    cross = U[:, None] * Uq[None, :]  # cross[i, j] = U[i] * Uq[j]
    weights = cross / (cross + cross.T)**2
    _c = q * (weights * M * M.T).sum()
    return m_a / q * (_a + _b + _c)

def get_dS(state, params):
    """Right-hand side for the ``S`` component of the state.

    ``params`` is ``((A, e), (F, c, b), D)`` and ``state`` is ``(S, x, g)``;
    only ``A``, ``F``, ``D`` and ``S`` are used here.

    Returns ``A S + S A^T + 2 D + S (F + F^T) S``.
    """
    (A, _), (F, _, _), D = params
    S, _, _ = state
    linear_part = A @ S + S @ A.T
    quadratic_part = S @ (F + F.T) @ S
    return linear_part + 2*D + quadratic_part
def get_dx(state, params):
    """Right-hand side for the ``x`` component of the state.

    ``params`` is ``((A, e), (F, c, b), D)`` and ``state`` is ``(S, x, g)``;
    ``b``, ``D`` and ``g`` are not used here.

    Returns ``x (A + S (F + F^T))^T + c S + e``.
    """
    (A, e), (F, c, _), _ = params
    S, x, _ = state
    effective_drift = A + S @ (F + F.T)
    return x @ effective_drift.T + c @ S + e
def get_dg(state, params):
    """Right-hand side for the ``g`` component of the state.

    ``params`` is ``((A, e), (F, c, b), D)`` and ``state`` is ``(S, x, g)``;
    only ``F``, ``c``, ``b``, ``S`` and ``x`` are used here.

    Returns ``0.5 x^T (F + F^T) x + c.x + b + 0.5 tr((F + F^T) S)``; for a
    1-D ``x`` the result keeps a trailing axis and has shape ``(1,)``.
    """
    (_, _), (F, c, b), _ = params
    S, x, _ = state
    sym_F = F + F.T
    quadratic = (0.5*x) @ (sym_F @ x[:, None])
    linear = torch.dot(c, x)
    trace_term = 0.5*torch.trace(sym_F @ S)
    return quadratic + linear + b + trace_term

import torchdiffeq

def F_ode(t, y, params = None):
    """ODE right-hand side for the joint state ``y = (S, x, g)``.

    ``t`` is accepted for the solver's calling convention but is not used.
    The three components are evaluated in order by ``get_dS``, ``get_dx`` and
    ``get_dg`` and returned as a tuple ``(dS, dx, dg)``.
    """
    component_rhs = (get_dS, get_dx, get_dg)
    return tuple(rhs(y, params) for rhs in component_rhs)

def fit_b(A, e, F, c, D, S0, x0, bmin = -100, bmax = 100, T = 10, iters = 50, g_target = 0.0, **kwargs):
    """Fit the constant growth term ``b`` by bisection.

    For each candidate ``b`` the ODE defined by ``F_ode`` with parameters
    ``((A, e), (F, c, b), D)`` is integrated from the initial state
    ``(S0, x0, 0.0)`` over ``T`` evenly spaced time points in ``[0, 1]``, and
    the final value of the third state component ``g`` is compared with
    ``g_target``. The bracket update assumes the final ``g`` increases with
    ``b`` (a final value below the target raises the lower bound).

    Args:
        A, e, F, c, D: model parameters, forwarded unchanged to ``F_ode``.
        S0, x0: initial values of the first two state components.
        bmin, bmax: initial bracket for ``b``.
        T: number of time points passed to the solver.
        iters: maximum number of bisection steps.
        g_target: desired final value of ``g``.
        **kwargs: forwarded to ``torchdiffeq.odeint``.

    Returns:
        The last ``b`` that was actually evaluated, i.e. the one whose error
        was checked.

    Raises:
        RuntimeError: if the final error is not ``<= 0.1`` (including NaN or
            when no iteration was run).
    """
    ts = torch.linspace(0, 1.0, T)
    g0 = torch.scalar_tensor(0.0)
    b = (bmax + bmin)/2
    err = float("inf")  # stays infinite if iters <= 0, so the check below raises
    for _ in range(iters):
        # Re-bisect at the top of the loop so that `b` always names the value
        # that was evaluated, and that value is what gets returned.
        b = (bmax + bmin)/2
        params = ((A, e), (F, c, b), D)
        sol = torchdiffeq.odeint(lambda t, y: F_ode(t, y, params), (S0, x0, g0), ts, **kwargs)
        g_final = sol[2][-1]
        err = abs(g_final - g_target)
        if err < 0.05:
            break
        if g_final < g_target:
            bmin = b
        else:
            bmax = b
    # `not err <= 0.1` (rather than `err > 0.1`) also rejects a NaN error.
    if not err <= 0.1:
        raise RuntimeError(f"Bisection search failed to converge: err={err}")
    return b
