import numpy as np
import torch
from tqdm import tqdm
from sklearn.decomposition import PCA
from ot_utils import inv_ot
import ot_utils
import mds

from ot.gromov import entropic_gromov_wasserstein

from timeit import default_timer as timer


def _pca_rotate(Z, n_components, device):
    """Run PCA on embedding ``Z`` and return the result as a tensor on ``device``."""
    projected = PCA(n_components=n_components).fit_transform(Z.cpu().numpy())
    return torch.from_numpy(projected).to(device)


def joint_mds(D1, D2, w1=None, w2=None, a=None, b=None, alpha=1.0, n_components=2, max_iter=300, eps=0.0,
              tol=1e-3, min_eps=0.001, eps_annealing=True, alpha_annealing=False, gw_init=False, return_stress=False):
    """Jointly embed two dissimilarity matrices by alternating ``inv_ot`` and SMACOF.

    A block-diagonal ``(m + n, m + n)`` dissimilarity matrix is assembled from
    ``D1`` and ``D2`` together with a matching weight matrix. Each iteration:

    1. calls ``inv_ot`` on the current embeddings to obtain a coupling ``P``
       and a matrix ``O`` that is right-multiplied onto ``Z1``;
    2. writes ``alpha * P`` into the off-diagonal blocks of the weight matrix;
    3. runs ``mds.smacof`` on the joint problem, warm-started from
       ``[Z1 @ O; Z2]``.

    Iteration stops after ``max_iter`` rounds or as soon as the norm of the
    change in the joint embedding drops below ``tol``.

    Args:
        D1: Tensor of pairwise dissimilarities for the first domain; written
            into the top-left ``(m, m)`` block, with ``m = D1.shape[0]``.
        D2: Same for the second domain, bottom-right ``(n, n)`` block.
        w1: Optional weights for the ``D1`` block. Defaults to ``outer(a, a)``.
        w2: Optional weights for the ``D2`` block. Defaults to ``outer(b, b)``.
        a: Optional length-``m`` vector passed to the OT routines. Defaults to
            the uniform vector ``1 / m``.
        b: Optional length-``n`` vector, defaults to the uniform ``1 / n``.
        alpha: Multiplier applied to ``P`` when filling the cross-domain
            weight blocks.
        n_components: Embedding dimension passed to SMACOF and PCA.
        max_iter: Maximum number of alternating iterations.
        eps: Forwarded as ``eps`` to ``inv_ot`` / ``gromov_wasserstein``
            (presumably the entropic regularisation strength -- confirm in
            ``ot_utils``).
        tol: Convergence threshold on ``torch.norm(Z - Z_old)``.
        min_eps: Lower bound used when annealing ``eps``.
        eps_annealing: If true, ``eps`` becomes ``max(0.95 * eps, min_eps)``
            after every non-final iteration. Note that an ``eps`` starting
            below ``min_eps`` is therefore raised to ``min_eps``.
        alpha_annealing: If true, ``alpha`` becomes ``max(0.9 * alpha, 0.01)``
            after every non-final iteration.
        gw_init: If true, initialise from a Gromov-Wasserstein coupling and a
            joint SMACOF run; otherwise embed each domain separately.
        return_stress: If true, also return the last SMACOF stress value.

    Returns:
        ``(Z1, Z2, P)`` or, with ``return_stress``, ``(Z1, Z2, P, s)``.
        ``Z1`` / ``Z2`` are the first ``m`` / last ``n`` rows of the final
        joint embedding and ``P`` is the last coupling computed. When
        ``max_iter`` is 0 and ``gw_init`` is false no coupling or stress is
        ever computed, so ``P`` and ``s`` are ``None``.
    """
    m = D1.shape[0]
    n = D2.shape[0]

    if a is None:
        a = D1.new_ones((m,)) / m
    if b is None:
        b = D2.new_ones((n,)) / n
    if w1 is None:
        w1 = torch.outer(a, a)
    if w2 is None:
        w2 = torch.outer(b, b)

    # Joint weight matrix: within-domain blocks are fixed, the cross-domain
    # blocks are overwritten with alpha * P on every iteration.
    weights = D1.new_zeros((m + n, m + n))
    weights[:m, :m] = w1
    weights[m:, m:] = w2

    # Joint dissimilarity matrix; cross-domain entries stay at zero.
    D = D1.new_zeros((m + n, m + n))
    D[:m, :m] = D1
    D[m:, m:] = D2

    # Pre-bind so that a zero-iteration run without GW initialisation returns
    # cleanly instead of raising UnboundLocalError.
    P = None
    s = None

    # Initialization
    if gw_init:
        P = ot_utils.gromov_wasserstein(D1, D2, p=a, q=b, eps=eps, max_iter=20)
        weights[:m, m:] = alpha * P
        weights[m:, :m] = alpha * P.T
        Z, s = mds.smacof(D, n_components=n_components, n_init=1, weights=weights, eps=1e-09)
        Z = _pca_rotate(Z, n_components, D1.device)
        Z1 = Z[:m]
        Z2 = Z[m:]
        Z_old = Z
    else:
        # Each domain is embedded on its own; w1 / w2 are not passed to these
        # initial SMACOF calls. Both SMACOF runs happen before either PCA so
        # the call order matches the historical behaviour.
        Z1, _ = mds.smacof(D1, n_components=n_components, n_init=1)
        Z2, _ = mds.smacof(D2, n_components=n_components, n_init=1)
        Z1 = _pca_rotate(Z1, n_components, D1.device)
        Z2 = _pca_rotate(Z2, n_components, D2.device)
        Z_old = torch.vstack((Z1, Z2))

    ot_time = 0
    mds_time = 0
    pbar = tqdm(range(max_iter))

    for _ in pbar:
        tic = timer()
        P, O = inv_ot(Z1, Z2, a=a, b=b, eps=eps, max_iter=10)
        ot_time += timer() - tic

        tic = timer()
        weights[:m, m:] = alpha * P
        weights[m:, :m] = alpha * P.T
        # Warm start: first-domain rows are transformed by O, second-domain
        # rows are carried over unchanged from the previous iterate.
        Z = Z_old.clone()
        Z[:m] = Z1.mm(O)

        Z, s = mds.smacof(D, n_components=n_components, init=Z, n_init=1, weights=weights)
        mds_time += timer() - tic

        err = torch.norm(Z - Z_old)
        pbar.set_postfix({"eps": eps,"diff": err.item(), "stress": s.item()})

        # Adopt the new iterate *before* testing convergence so the returned
        # Z1 / Z2 always correspond to the returned stress, whether the loop
        # ends by convergence or by running out of iterations.
        Z_old = Z
        Z1 = Z[:m]
        Z2 = Z[m:]
        if err < tol:
            break

        if eps_annealing:
            eps = max(eps * 0.95, min_eps)
        if alpha_annealing:
            alpha = max(alpha * 0.9, 0.01)

    print("Inv OT time: {:.2f}s".format(ot_time))
    print("MDS time: {:.2f}s".format(mds_time))
    if return_stress:
        return Z1, Z2, P, s

    return Z1, Z2, P
