import torch
from itertools import product


def compute_seq_id_clusters(idmat: torch.Tensor) -> torch.Tensor:
    """Cluster items whose pairwise score in ``idmat`` equals exactly 1.

    Two items share a cluster when they are linked by a chain of pairs
    ``(i, j)`` with ``idmat[i, j] == 1`` (in either direction). The clusters
    are the connected components of that graph. Each item is labelled with
    the smallest index in its component, so an unmatched item keeps its own
    index as its label.

    Args:
        idmat: square ``(n, n)`` matrix of pairwise scores, e.g. sequence
            identity. Presumably 1.0 means "identical"; confirm with callers.

    Returns:
        ``torch.long`` tensor of shape ``(n,)`` on ``idmat.device`` with one
        cluster label per item.
    """
    n = idmat.shape[0]
    device = idmat.device
    labels = torch.arange(n, dtype=torch.long, device=device)
    if n == 0:
        return labels

    # Symmetrise so the result does not depend on which triangle holds the
    # match, and add self-loops so each node's own label stays a candidate.
    adj = idmat == 1.
    adj = adj | adj.T | torch.eye(n, dtype=torch.bool, device=device)

    # Label propagation: every node repeatedly adopts the minimum label among
    # its neighbours. Labels only ever decrease and are bounded below by 0,
    # so the loop terminates. At that point every node carries the minimum
    # index of its connected component (transitive, order-independent).
    while True:
        candidates = labels.unsqueeze(0).expand(n, n).masked_fill(~adj, n)
        new_labels = candidates.min(dim=1).values
        if torch.equal(new_labels, labels):
            return labels
        labels = new_labels


from sklearn.cluster import AffinityPropagation

@torch.no_grad()
def compute_sem_clusters(simmat: torch.Tensor, **kwargs) -> torch.Tensor:
    """Cluster items from a precomputed similarity matrix via affinity propagation.

    Args:
        simmat: square ``(n, n)`` similarity matrix, passed to scikit-learn as
            a precomputed affinity. It may live on any device and may require
            grad; it is detached and copied to the CPU before clustering.
        **kwargs: extra keyword arguments forwarded to
            ``sklearn.cluster.AffinityPropagation`` (e.g. ``damping``,
            ``preference``, ``random_state``). ``affinity`` is always
            ``'precomputed'`` and must not be passed here.

    Returns:
        Integer tensor of shape ``(n,)`` with one cluster label per item, on
        ``simmat.device``. scikit-learn documents that labels can be -1 when
        the algorithm does not converge.
    """
    # Tensor.numpy() refuses CUDA tensors and tensors that require grad
    # (no_grad does not clear requires_grad on inputs), so detach and move
    # to host memory first.
    affinity = simmat.detach().cpu().numpy()
    clustering = AffinityPropagation(affinity='precomputed', **kwargs).fit(affinity)
    return torch.from_numpy(clustering.labels_).to(simmat.device)
