import numpy as np
from lib.matmul import diag_matmul_np, matmul_diag_np


def sinkhorn(C: np.ndarray, p: np.ndarray, q: np.ndarray, eta: float, T=10):
    """Approximate an entropy-regularised transport plan with ``T`` Sinkhorn iterations.

    Builds the Gibbs kernel ``K = exp(-C / eta)`` and hands it to
    ``normalize``, which runs exactly the alternating row/column scaling
    this function used to duplicate inline.

    :param C: cost matrix, shape (len(p), len(q)).
    :param p: target row marginal.
    :param q: target column marginal.
    :param eta: regularisation strength; small values make ``exp(-C / eta)``
        underflow towards zero, which can lead to division by zero in the scaling.
    :param T: number of scaling iterations, forwarded to ``normalize``.
    :return: the scaled kernel returned by ``normalize``.
    """
    # Single source of truth for the scaling loop: normalize() implements the
    # identical u/v iteration.
    return normalize(np.exp(-C / eta), p, q, T=T)


def normalize(K: np.ndarray, p: np.ndarray, q: np.ndarray, T=10):
    """Scale a non-negative kernel towards row marginal ``p`` and column marginal ``q``.

    Runs ``T`` alternating Sinkhorn updates ``v = q / (K^T u)`` then
    ``u = p / (K v)``, starting from ``u = 1``. It then combines ``u``, ``K``
    and ``v`` with the ``lib.matmul`` helpers, presumably giving
    ``diag(u) @ K @ diag(v)`` (confirm against lib.matmul). Under that
    reading the row sums equal ``p``, because ``u`` is updated last. The
    column sums only approach ``q`` as ``T`` grows.

    :param K: kernel matrix, shape (len(p), len(q)). A zero row or column sum
        makes the divisions below produce inf/nan.
    :param p: target row marginal.
    :param q: target column marginal.
    :param T: number of scaling iterations. ``T <= 0`` performs no scaling:
        ``u`` and ``v`` stay all-ones.
    :return: the rescaled matrix.
    """
    u = np.ones(p.shape)
    # Initialise v as well. Otherwise T <= 0 leaves it unbound and the final
    # product raises UnboundLocalError.
    v = np.ones(q.shape)
    for _ in range(T):
        v = q / (K.T @ u)
        u = p / (K @ v)
    scaled_rows = diag_matmul_np(u, K)
    return matmul_diag_np(scaled_rows, v)


def peyre_expon(trans: np.ndarray, ps: np.ndarray, pt: np.ndarray, ws: np.ndarray, wt: np.ndarray):
    """Square-loss Gromov-Wasserstein tensor product ``L(ws, wt) (x) trans``.

    Uses the factorisation of Peyre, Cuturi & Solomon (2016, Prop. 1) for
    ``L(a, b) = (a - b)^2``::

        out[i, j] = sum_k ws[i, k]^2 ps[k] + sum_l wt[j, l]^2 pt[l]
                    - 2 * (ws @ trans @ wt.T)[i, j]

    When ``trans`` has row marginal ``ps`` and column marginal ``pt``, this
    equals ``sum_{k,l} (ws[i, k] - wt[j, l])^2 * trans[k, l]``. ``gromov``
    uses it as the gradient-like term, and ``gromov_loss`` pairs it with
    ``trans`` to get the objective.

    :param trans: coupling matrix, shape (len(ps), len(pt)).
    :param ps: source marginal.
    :param pt: target marginal.
    :param ws: source similarity/distance matrix, shape (len(ps), len(ps)).
    :param wt: target similarity/distance matrix, shape (len(pt), len(pt)).
    :return: array of shape (len(ps), len(pt)).
    """
    # Marginal ("degree") terms: a column vector plus a row vector, broadcast
    # to (ns, nt).
    deg_terms = (ws ** 2 @ ps)[:, None] + (wt ** 2 @ pt)[None, :]
    # Cross term needs wt transposed. Without the transpose the result is only
    # correct for symmetric wt.
    cross = ws @ trans @ wt.T
    return deg_terms - 2 * cross


def gromov(ws: np.ndarray, wt: np.ndarray, ps: np.ndarray, pt: np.ndarray, stepsize=100, max_iter=100000,
           add_prec=1e-30, stop_prec=5e-7, init=None, use_gd=False):
    """Iteratively compute a Gromov-Wasserstein coupling between ``ws`` and ``wt``.

    Each iteration:

    1. evaluates ``grad = peyre_expon(trans, ...)``;
    2. takes a step, either multiplicative
       ``trans * exp(-1 - stepsize * grad) + add_prec`` or, with ``use_gd``,
       additive ``trans - stepsize * grad``;
    3. rescales the result towards marginals ``ps``/``pt`` with ``normalize``.

    The loop stops when the max-abs change between successive couplings
    drops below ``stop_prec``.

    :param ws: source similarity/distance matrix, shape (len(ps), len(ps)).
    :param wt: target similarity/distance matrix, shape (len(pt), len(pt)).
    :param ps: source marginal.
    :param pt: target marginal.
    :param stepsize: step size applied to ``grad``.
    :param max_iter: maximum number of iterations. ``0`` returns the initial
        coupling unchanged.
    :param add_prec: small constant added after the multiplicative step so no
        entry is exactly zero before normalisation.
    :param stop_prec: convergence tolerance on ``max |trans_new - trans|``.
    :param init: optional initial coupling. Defaults to ``outer(ps, pt)``.
        It is not modified in place.
    :param use_gd: use the additive step. Nothing here keeps that step's
        entries non-negative.
    :return: the most recent coupling. "has not converged" is printed when
        ``max_iter`` iterations pass without meeting ``stop_prec``.
    """
    trans = np.outer(ps, pt) if init is None else init
    for _ in range(max_iter):
        grad = peyre_expon(trans=trans, ps=ps, pt=pt, ws=ws, wt=wt)
        if use_gd:
            stepped = trans - stepsize * grad
        else:
            stepped = trans * np.exp(-1 - stepsize * grad) + add_prec
        trans_new = normalize(stepped, ps, pt)
        delta = np.max(np.abs(trans_new - trans))
        # Keep the freshest (normalised) iterate even on the converging step.
        trans = trans_new
        if delta < stop_prec:
            break
    else:
        # Reached only when the loop exhausted max_iter without breaking.
        print("has not converged")
    return trans


def gromov_loss(trans: np.ndarray, ws: np.ndarray, wt: np.ndarray, ps: np.ndarray, pt: np.ndarray):
    """Evaluate the objective ``sum(peyre_expon(trans, ...) * trans)``.

    When ``trans`` has row marginal ``ps`` and column marginal ``pt`` (and,
    at least, for symmetric ``ws``/``wt``), this is the square-loss
    Gromov-Wasserstein discrepancy
    ``sum_{i,j,k,l} (ws[i,k] - wt[j,l])^2 trans[i,j] trans[k,l]``.

    :return: scalar objective value.
    """
    cost = peyre_expon(trans=trans, ps=ps, pt=pt, ws=ws, wt=wt)
    weighted = cost * trans
    return weighted.sum()
