import torch
import numpy as np
from lib.srGW import md_semirelaxed_gromov_wasserstein
from h2tools.cluster_tree import ClusterTree


def srgw_partition(C: torch.Tensor, mu: torch.Tensor, n_par: int, gamma_entropy: float, device='cpu',
                   dtype=torch.float32) -> np.ndarray:
    """
    Assign every row of ``C`` to one of ``n_par`` clusters.

    An ``n_par x n_par`` identity matrix is used as the target structure and
    handed, together with ``C`` and ``mu``, to
    ``md_semirelaxed_gromov_wasserstein`` (called with ``init_mode="random"``).
    Each row of the returned transport plan is then labelled with the column
    holding its largest entry.

    :param C: source structure matrix, passed to the solver as ``C1``
        (presumably the square adjacency/weight matrix of the graph -- confirm
        against the solver).
    :param mu: source distribution, passed to the solver as ``p``
        (presumably one weight per node, summing to 1 -- confirm).
    :param n_par: the number of clusters the graph is partitioned into
    :param gamma_entropy: forwarded unchanged to the solver as ``gamma_entropy``.
    :param device: device on which the identity target matrix is created; also
        forwarded to the solver.
    :param dtype: dtype of the identity target matrix; also forwarded to the
        solver.
    :return: 1-D NumPy integer array with one cluster label per row of the
        transport plan.
    """
    C_bary = torch.eye(n_par, device=device, dtype=dtype)
    # The second return value (named ``loss`` by the original code) is not used here.
    trans, _ = md_semirelaxed_gromov_wasserstein(C1=C, p=mu, C2=C_bary, gamma_entropy=gamma_entropy, device=device,
                                                 dtype=dtype, init_mode="random")
    # TODO use multiple optimization techniques to find the best partitioning (according to loss and AMI)
    labels = torch.argmax(trans, dim=1)
    # Convert in one shot: a per-element ``int(tensor)`` loop would trigger one
    # host/device transfer per node.
    return labels.detach().cpu().numpy()


class adj_data(object):
    """
    This is container for data for the problem.
    First of all, it requires methods check_far, compute_aux, divide and __len__.
    Functions check_far, compute_aux, and divide must have exactly the same parameters, as required.

    Far Field ==>> low-rank ==>> intra-cluster ==>> two trees have at least one same entry
    """

    def __init__(self, w: np.ndarray, gamma_entropy=0.1, device='cpu', dtype=torch.float32):
        """
        Store the matrix and the solver settings.

        :param w: NumPy matrix; ``divide`` indexes its rows and columns with
            the same index set (presumably a square adjacency/weight matrix --
            confirm with the caller).
        :param gamma_entropy: forwarded to ``srgw_partition`` by ``divide``.
        :param device: torch device that holds the tensor copy of ``w``.
        :param dtype: torch dtype of the tensor copy of ``w``.
        """
        # Tensor copy used for the partitioning computations.
        self.A = torch.from_numpy(w).to(device=device, dtype=dtype)
        # The original NumPy matrix is kept as-is.
        self.matrix = w
        self.count = w.shape[0]
        self.gamma_entropy, self.device, self.dtype = gamma_entropy, device, dtype

    def check_far(self, bb0, bb1) -> bool:
        """
        Return True when ``bb0`` and ``bb1`` share at least one entry.

        :param bb0: auxiliary data of the first cluster (the index collection
            returned by ``compute_aux``).
        :param bb1: auxiliary data of the second cluster.
        :return: ``True`` if some element occurs in both, otherwise ``False``
            (also ``False`` when either collection is empty).
        """
        # Vectorised membership test instead of a nested Python loop.
        return bool(np.isin(bb0, bb1).any())

    def compute_aux(self, indices):
        """Return the auxiliary data of a cluster, which is the index collection itself."""
        return indices

    def divide(self, indices: np.ndarray):
        """
        Divide a cluster into two subclusters.

        :param indices: indices of the rows/columns of the stored matrix that
            form the cluster.
        :return: a pair ``(order, bounds)``. ``order`` holds positions *within*
            ``indices`` (not the global indices), with the members of
            subcluster 0 first and those of subcluster 1 after them.
            ``bounds`` is ``[0, k, total]``: ``order[0:k]`` is subcluster 0 and
            ``order[k:total]`` is subcluster 1.
        """
        # Sub-matrix restricted to the cluster.
        C = self.A[indices][:, indices]
        # Row sums normalised to sum to 1.
        # NOTE(review): if every entry of the sub-matrix is zero this divides
        # 0 by 0 and yields NaN -- confirm whether callers can hit that case.
        mu = torch.sum(C, dim=1).to(device=self.device, dtype=self.dtype)
        mu = mu / torch.sum(mu)
        n_par = 2
        # Forward this instance's device and dtype; otherwise the solver would
        # silently fall back to its 'cpu' / float32 defaults and mismatch C.
        membership = srgw_partition(C=C, mu=mu, n_par=n_par, gamma_entropy=self.gamma_entropy,
                                    device=self.device, dtype=self.dtype)
        clu0 = np.where(membership == 0)[0]
        clu1 = np.where(membership == 1)[0]
        k = len(clu0)
        return np.concatenate([clu0, clu1]), [0, k, k + len(clu1)]

    def __len__(self):
        """Return the number of rows of the stored matrix."""
        return self.count

# class adj_cluster_tree(ClusterTree):
#     def __init__(self, data: adj_data, block_size: int):
#         super(adj_cluster_tree, self).__init__(data=data, block_size=block_size)
#
#     # def is_far(self, i, other_tree, j):
#     #     pass
#
#     def divide(self, key):
#         """
#
#         :param key:
#         :return:
#         """
