import numpy as np
from copy import deepcopy as dc
from lib.sinkhorn import peyre_expon, normalize, gromov_loss, gromov
from lib.util import pearson
from ot.gromov import gromov_wasserstein as ot_gw

add_prec = 1e-30
print_interval = 10


def cg(ws: np.ndarray, wt: np.ndarray, ps: np.ndarray, pt: np.ndarray, n_stage=10000, bound=0.2, eta_trans=1,
       trans_star=None, stop_prec=5e-7, use_gd=False):
    """Robust GW coupling via repeated conditional-gradient solves.

    Alternates between (a) computing the worst-case entrywise perturbation of
    the target similarity matrix ``wt`` within ``+/-bound`` at the current
    coupling and (b) re-solving the perturbed Gromov-Wasserstein problem with
    POT's conditional-gradient solver, warm-started from that coupling.

    ``eta_trans``, ``trans_star`` and ``use_gd`` are accepted only for
    signature parity with the sibling solvers and are unused here.

    Returns the final coupling and the perturbation evaluated at it.
    """

    def worst_pert(coupling):
        # Entrywise sign of the negated GW residual, scaled to +/-bound:
        # pushes each entry of wt in the direction that increases the loss.
        residual = coupling.T @ ws @ coupling - wt * (coupling.T @ coupling)
        return bound * np.where(residual <= 0, 1.0, -1.0)

    coupling = np.outer(ps, pt)  # independent (product) coupling as warm start
    best = np.inf
    for _ in range(n_stage):
        perturbed_wt = wt + worst_pert(coupling)
        coupling, info = ot_gw(C1=ws, C2=perturbed_wt, p=ps, q=pt, loss_fun='square_loss',
                               G0=coupling, log=True, armijo=False)
        loss = info['gw_dist']
        # NOTE(review): stopping compares against the best loss so far, not the
        # previous stage's loss — presumably deliberate for a non-monotone
        # minmax objective; confirm.
        if np.abs(best - loss) < stop_prec:
            break
        if best > loss:
            best = loss
    return coupling, worst_pert(coupling)


def gd(ws: np.ndarray, wt: np.ndarray, ps: np.ndarray, pt: np.ndarray, n_stage=10000, bound=0.2, eta_trans=1,
       trans_star=None, stop_prec=5e-7, use_gd=False):
    """Generator: robust Gromov-Wasserstein via alternating perturbation/descent.

    Each stage builds a bounded (``+/-bound``) adversarial perturbation of
    ``wt`` at the current coupling, then takes one step on the perturbed
    objective: plain gradient descent when ``use_gd`` is True, otherwise an
    exponentiated (mirror-descent) update followed by marginal normalization.

    Yields ``(coupling, perturbation)`` every ``print_interval`` stages, once
    at convergence (max-abs coupling change below ``stop_prec``), and a final
    time after the loop.  NOTE(review): after the convergence ``break`` the
    bottom ``yield trans`` re-emits the PRE-update coupling, not ``trans_new``
    — confirm the stale duplicate is intended.  Also ``n_stage=0`` would leave
    ``stage_i`` unbound at the post-loop check.
    """
    def inner(trans):
        # Entrywise sign of the negated GW residual, scaled to +/-bound:
        # the worst-case bounded perturbation of wt for this coupling.
        turn = trans.T @ ws @ trans - wt * (trans.T @ trans)
        return bound * (turn <= 0).astype(float) - bound * (turn > 0).astype(float)

    # print("eta_trans={}".format(eta_trans))
    trans = np.outer(ps, pt)  # independent (product) coupling as starting point
    for stage_i in range(n_stage):
        pert = inner(trans=trans)
        # print("pert {}, {}, {}".format(np.min(pert), np.median(pert), np.max(pert)))
        # print("pert={}".format(pert))
        r_wt = wt + pert  # perturbed target similarity matrix for this stage
        grad = peyre_expon(trans=trans, ps=ps, pt=pt, ws=ws, wt=r_wt)
        # print("grad calculated")
        if use_gd:
            # plain (Euclidean) gradient step
            tmp = trans - eta_trans * grad
        else:
            # mirror-descent / exponentiated step; add_prec keeps entries > 0
            tmp = trans * np.exp(-1 - eta_trans * grad) + add_prec
        del grad  # free the gradient before allocating the normalized coupling
        trans_new = normalize(tmp, ps, pt)  # re-impose the (ps, pt) marginals
        if stage_i % print_interval == 0:
            # loss = gromov_loss(trans=trans, ws=ws, wt=r_wt, ps=ps, pt=pt)
            # print("loss={}".format(loss))
            if trans_star is not None:
                # progress diagnostic: correlation with a reference coupling
                corr = pearson(trans, trans_star)
                print("corr={}".format(corr))
            yield trans_new, inner(trans_new)
        if np.max(np.abs(trans_new - trans)) < stop_prec:
            # converged: coupling changed less than stop_prec in sup-norm
            yield trans_new, inner(trans_new)
            break
        else:
            trans = trans_new
    # NOTE(review): this also fires when the break above happens exactly at the
    # last stage — the "has not converged" message can be a false positive.
    if stage_i == n_stage - 1:
        print("has not converged")
    yield trans, inner(trans)


def rgw(ws: np.ndarray, wt: np.ndarray, ps: np.ndarray, pt: np.ndarray, n_stage=100, bound=0.2, eta_trans=1,
        trans_star=None, stop_prec=5e-7, max_iter=100000, use_gd=False):
    """Generator: minmax alternating optimization for robust GW.

    Solves a plain GW problem for the initial coupling, then repeatedly (max
    step) perturbs ``wt`` adversarially within ``+/-bound`` and (min step)
    re-solves the perturbed problem warm-started from the current coupling,
    yielding the coupling after every stage.
    """

    def inner(trans):
        # Worst-case bounded perturbation: +bound where the negated GW
        # residual is <= 0, -bound elsewhere.
        turn = trans.T @ ws @ trans - wt * (trans.T @ trans)
        return np.where(turn <= 0, bound, -bound)

    def report(coupling, target):
        # Progress diagnostics: correlation with a reference coupling (if
        # provided), loss under the current target matrix, and the coupling.
        if trans_star is not None:
            corr = pearson(coupling, trans_star)
            print("corr={}".format(corr))
        loss = gromov_loss(trans=coupling, ws=ws, wt=target, ps=ps, pt=pt)
        print("loss={}".format(loss))
        print("trans={}".format(coupling))

    print("eta_trans={}".format(eta_trans))
    trans = gromov(ws=ws, wt=wt, ps=ps, pt=pt, stepsize=eta_trans, max_iter=max_iter, use_gd=use_gd)
    report(trans, wt)
    for _ in range(n_stage):
        pert = inner(trans=trans)
        print("pert={}".format(pert))
        perturbed_wt = wt + pert
        # NOTE: max_iter is deliberately not forwarded here (solver default),
        # matching the original behavior.
        trans = gromov(ws=ws, wt=perturbed_wt, ps=ps, pt=pt, stepsize=eta_trans, init=trans, use_gd=use_gd)
        yield trans
        report(trans, perturbed_wt)
    return trans
