import numpy as np
from lib.sinkhorn import peyre_expon, normalize

# Emit an intermediate transport plan from `gromov` every this many iterations.
print_interval = 100


def gromov(ws: np.ndarray, wt: np.ndarray, ps: np.ndarray, pt: np.ndarray, stepsize=100, max_iter=100000,
           add_prec=1e-30, stop_prec=5e-7, init=None):
    """Iteratively refine a transport plan between two weighted spaces (generator).

    Each iteration computes ``grad = peyre_expon(...)`` on the current plan,
    applies the multiplicative update ``trans * exp(-1 - stepsize * grad)``
    (plus ``add_prec``), and passes the result through ``normalize(tmp, ps, pt)``.
    NOTE(review): this reads like an entropic / mirror-descent Gromov-Wasserstein
    step with ``normalize`` projecting onto couplings with marginals ``ps`` and
    ``pt`` -- confirm against ``lib.sinkhorn``.

    Args:
        ws, wt: matrices handed to ``peyre_expon`` for the source / target
            side (presumably intra-space similarity or distance matrices --
            confirm).
        ps, pt: source / target weight vectors; the default initial plan is
            ``np.outer(ps, pt)``.
        stepsize: multiplier on the gradient inside the exponential update.
        max_iter: maximum number of iterations.
        add_prec: constant added before normalization (presumably to keep
            entries strictly positive).
        stop_prec: stop once the max absolute entrywise change between
            consecutive plans falls below this value.
        init: optional initial plan; used as given (not copied, not mutated).

    Yields:
        The current plan after iterations 0, ``print_interval``,
        ``2 * print_interval``, ... and finally the last computed plan (unless
        that same array was already the most recent yield).

    Returns:
        The last computed plan (reachable via ``yield from`` or
        ``StopIteration.value``).
    """
    trans = np.outer(ps, pt) if init is None else init
    converged = False
    last_yielded = None
    for iter_num in range(max_iter):
        grad = peyre_expon(trans=trans, ps=ps, pt=pt, ws=ws, wt=wt)
        tmp = trans * np.exp(-1 - stepsize * grad) + add_prec
        del grad  # release the gradient array before normalize allocates
        trans_new = normalize(tmp, ps, pt)
        converged = np.max(np.abs(trans_new - trans)) < stop_prec
        # Always keep the newest iterate, including the one that converged.
        trans = trans_new
        if converged:
            break
        if iter_num % print_interval == 0:
            last_yielded = trans
            yield trans
    # Judged by the stop criterion, not the loop index, so convergence on the
    # final iteration is not misreported (and max_iter <= 0 does not crash).
    if not converged:
        print("has not converged")
    # Make the final plan visible to plain `for` consumers as well.
    if trans is not last_yielded:
        yield trans
    return trans
