"""
======================================
Comparing imbalanced pointclouds in 3D
======================================

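This example shows how an optimal-transport coupling matches two pointclouds. The coupling is computed with WaveletOT (the original UGW version of this example, which matched pointclouds with imbalanced weights, is kept commented out at the end of the file; here both clouds carry uniform weights). Each cloud is projected to 2D with PCA and drawn on its own plane of a 3D plot, with lines joining each selected source point to its best-matched target point.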

"""

import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
import torch
from sklearn.cluster import KMeans
from src.wot import WaveletOT
from sklearn.decomposition import PCA
from src.utils import construct_affinity_matrix

# Per-point type labels of the two clouds. plot_density_matching reads them to
# color points and to decide whether a matched pair shares the same label.
X1_types = np.loadtxt("data/simulations/s1_label1.txt")
X2_types = np.loadtxt("data/simulations/s1_label2.txt")

# Output directory <cwd>/output/plots. makedirs creates both levels in one call
# and tolerates directories that already exist (no check-then-create race).
# Currently only referenced by the disabled UGW sweep at the end of the file.
path = os.path.join(os.getcwd(), "output", "plots")
os.makedirs(path, exist_ok=True)

# Color for each type label. np.loadtxt yields floats (e.g. 1.0), which hash and
# compare equal to these int keys, so lookups work without conversion.
type_to_color = {1: "blue", 2: "magenta", 3: "indigo"}

def plot_density_matching(
    pi, a, x, b, y, idx, alpha, linewidth, x_types=None, y_types=None
):
    """Draw two 2D point clouds on parallel planes of a 3D plot and link matches.

    The source cloud ``x`` is drawn on the plane z = 6 and the target cloud
    ``y`` on the plane z = 0, each point colored by its type label through
    ``type_to_color``. For every source index ``i`` in ``idx`` a line joins
    ``x[i]`` to the target point receiving the largest mass in row ``pi[i]``:
    green when both points share a label, red otherwise.

    Parameters
    ----------
    pi : array-like of shape (len(x), len(y))
        Coupling between the two clouds; only its row-wise argmax is used.
    a, b : 1D array-like
        Weights of ``x`` and ``y``. Marker areas are ``10 * (w / mean(a)) ** 2``,
        so a point carrying the average source mass is drawn with area 10.
    x, y : array-like of shape (n, >=2)
        Point coordinates; only the first two columns are plotted.
    idx : iterable of int
        Rows of ``pi`` (indices into ``x``) whose best match is drawn.
    alpha : float
        Line opacity is ``0.75 * alpha``.
    linewidth : float
        Width of the matching lines.
    x_types, y_types : array-like, optional
        Labels of ``x`` and ``y`` (keys of ``type_to_color``). Default to the
        module-level ``X1_types`` / ``X2_types``.
    """
    if x_types is None:
        x_types = X1_types
    if y_types is None:
        y_types = X2_types
    pi = np.asarray(pi)
    x = np.asarray(x)
    y = np.asarray(y)
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)

    plt.figure(figsize=(6.0, 6.0))
    ax = plt.axes(projection="3d")
    ax.set_xlim(-6, 2)
    ax.set_ylim(-3, 5)
    ax.set_zlim(-1, 5)

    # Heights of the two planes the clouds are drawn on.
    z_x, z_y = 6.0, 0.0
    # Sizes are taken relative to one scalar reference mass. Dividing b by a
    # elementwise (as before) paired unrelated points and raised a broadcasting
    # error whenever the clouds had different sizes.
    ref_mass = a.mean()

    ax.scatter(
        x[:, 0],
        x[:, 1],
        z_x,
        c=[type_to_color[t] for t in x_types[: len(x)]],
        s=10 * (a / ref_mass) ** 2,
        zorder=1,
    )
    ax.scatter(
        y[:, 0],
        y[:, 1],
        z_y,
        # One color per *target* point (previously sized by len(x)).
        c=[type_to_color[t] for t in y_types[: len(y)]],
        s=10 * (b / ref_mass) ** 2,
        zorder=1,
    )

    # Draw only the strongest match of each selected source point.
    for i in idx:
        j = int(np.argmax(pi[i]))
        color = "g" if x_types[i] == y_types[j] else "r"
        ax.plot(
            [x[i, 0], y[j, 0]],
            [x[i, 1], y[j, 1]],
            [z_x, z_y],
            c=color,
            alpha=0.75 * alpha,
            linewidth=linewidth,
            zorder=0,
        )
    plt.tight_layout()


if __name__ == "__main__":
    # Experiment parameters. The active path uses n_clust, n_scales and
    # compute_balanced; the others belong to the disabled UGW sweep and the
    # commented-out data generator kept below.
    n1 = 1000
    dim = 2
    rho = 0.5
    eps = 0.01
    n_clust = 50
    ratio = 0.7
    compute_balanced = True
    n_scales = 20

    clean_X1 = np.genfromtxt("data/simulations/s1_mapped1.txt")
    clean_X2 = np.genfromtxt("data/simulations/s1_mapped2.txt")
    # Generate gaussian mixtures translated from each other
    # a, x, b, y = generate_data(n1, ratio)

    # Mean pairwise distance of each clean cloud, used to scale optional noise.
    X1_avg_dist = torch.cdist(
        torch.from_numpy(clean_X1), torch.from_numpy(clean_X1)
    ).mean().item()
    # Fixed: previously computed from clean_X1 (copy-paste error).
    X2_avg_dist = torch.cdist(
        torch.from_numpy(clean_X2), torch.from_numpy(clean_X2)
    ).mean().item()

    X1_var = 0.1 * X1_avg_dist
    X2_var = 0.1 * X2_avg_dist

    np.random.seed(5)
    # Noise is currently disabled (std 0.0). Pass X1_var / X2_var as the scale
    # to enable it. The draws are kept so the global RNG stream consumed by
    # later steps (e.g. KMeans initialisation) is unchanged.
    X1_noise = np.random.normal(0.0, 0.0, clean_X1.shape)
    X2_noise = np.random.normal(0.0, 0.0, clean_X2.shape)

    X1 = clean_X1 + X1_noise
    X2 = clean_X2 + X2_noise

    def _pca2_standardized(X):
        """Project X on its first two principal components, z-score each, shift by -1."""
        pcs = PCA(n_components=2).fit_transform(X)
        return (pcs - pcs.mean(axis=0)) / pcs.std(axis=0) - 1

    X1_pc2 = _pca2_standardized(X1)
    X2_pc2 = _pca2_standardized(X2)

    # Uniform weights. They are only used for plotting, so keep them on the CPU
    # instead of requiring a GPU.
    a = torch.ones(X1.shape[0]) / X1.shape[0]
    b = torch.ones(X2.shape[0]) / X2.shape[0]

    wot = WaveletOT(X1, X2, n_scales=n_scales)
    wot.solve()

    # aligned_point_X1 = wot.project(to_X2=False)
    # aligned_point_X2 = wot.project()
    pi_b = wot.coupling
    # The original called .cpu().numpy(), so this is presumably a torch tensor;
    # plain arrays are accepted too.
    if torch.is_tensor(pi_b):
        pi_np = pi_b.detach().cpu().numpy()
    else:
        pi_np = np.asarray(pi_b)

    # One representative source point per cluster: the point closest to each
    # KMeans center (column-wise argmin of the point-to-center distances).
    clf = KMeans(n_clusters=n_clust)
    clf.fit(X1_pc2)
    idx = clf.transform(X1_pc2).argmin(axis=0).astype(int)

    if compute_balanced:
        # pi_b = entropic_gromov_wasserstein(dx.cuda().float(), dy.cuda().float(), a, b, epsilon=0.001, loss_fun="square_loss")
        plot_density_matching(
            pi_np, a.numpy(), X1_pc2, b.numpy(), X2_pc2, idx,
            alpha=1.0, linewidth=0.5,
        )
        # Save before show(): once the window is closed, the figure is gone
        # and a later savefig would write a blank image. No legend is drawn
        # because no plotted artist carries a label.
        plt.savefig("fig_matching_plan_unnoised_wot.png")
        plt.show()
        plt.close()

    # dx, dy = torch.from_numpy(dx), torch.from_numpy(dy)

    # rho_list = [0.1]
    # peps_list = [2, 1, 0, -1, -2, -3]
    # for rho in rho_list:
    #     pi = None
    #     for p in peps_list:
    #         eps = 10 ** p
    #         print(f"Params = {rho, eps}")
    #         a, b = torch.from_numpy(a), torch.from_numpy(b)
    #         pi = log_ugw_sinkhorn(
    #             a,
    #             dx,
    #             b,
    #             dy,
    #             init=pi,
    #             eps=eps,
    #             rho=rho,
    #             rho2=rho,
    #             nits_plan=1000,
    #             tol_plan=1e-5,
    #             nits_sinkhorn=1000,
    #             tol_sinkhorn=1e-5,
    #         )
    #         print(f"Sum of transport plans = {pi.sum().item()}")

    #         # Plot matchings between measures
    #         a, b = a.data.numpy(), b.data.numpy()
    #         pi_ = pi.data.numpy()
    #         plot_density_matching(
    #             pi_, a, x, b, y, idx, alpha=1.0, linewidth=1.0
    #         )
    #         plt.legend()
    #         plt.savefig(
    #             path + f"/fig_matching_plan_ugw_"
    #             f"rho{rho}_eps{eps}_ratio{ratio}.png"
    #         )
    #         plt.show()