import torch 


def chain_of_cliques(n_cliques=1, dim=3, overlap=0, permute=False):
    """Build the index sets for a chain of equally sized, overlapping cliques.

    Clique ``i`` takes ``dim`` consecutive entries, starting at offset
    ``i * (dim - overlap)``, from a vector of ``n`` node ids, where
    ``n = (dim - overlap) * (n_cliques - 1) + dim``. Consecutive cliques
    therefore share ``overlap`` entries.

    Args:
        n_cliques: Number of cliques in the chain; must be at least 1.
        dim: Number of node ids in every clique.
        overlap: Number of ids shared by two consecutive cliques; must be
            strictly smaller than ``dim``.
        permute: If true, the node ids are a random permutation
            (``torch.randperm``) of ``0..n-1``; otherwise they are
            ``0..n-1`` in increasing order.

    Returns:
        Integer tensor of shape ``(n_cliques, dim)``; row ``i`` holds the
        node ids of clique ``i``.

    Raises:
        ValueError: If ``n_cliques < 1`` or ``dim <= overlap``.
    """
    # Explicit raises instead of ``assert``: assertions are stripped under
    # ``python -O``, and a zero stride would silently yield identical cliques.
    if n_cliques < 1:
        raise ValueError(f"n_cliques must be at least 1, got {n_cliques}")
    if dim <= overlap:
        raise ValueError(
            f"dim must be greater than overlap, got dim={dim}, overlap={overlap}"
        )

    # Distance between the first entries of two consecutive cliques.
    stride = dim - overlap
    n = stride * (n_cliques - 1) + dim
    perm = torch.randperm(n) if permute else torch.arange(n)

    starts = range(0, stride * n_cliques, stride)
    return torch.stack([perm[start : start + dim] for start in starts], dim=0)


def separate_latents(x, index_matrix):
    """Gather, for every clique, the latent entries selected by its index row.

    Args:
        x: Tensor indexed as ``(batch, n_latents)``.
        index_matrix: Integer index tensor indexed as ``(n_cliques, dim)``,
            e.g. the output of ``chain_of_cliques``; every entry must be a
            valid index into the last dimension of ``x``.

    Returns:
        Tensor of shape ``(batch, n_cliques, dim)`` where
        ``out[b, c, k] == x[b, index_matrix[c, k]]``.
    """
    batch = x.shape[0]
    n_cliques = index_matrix.shape[0]

    # ``gather`` needs source and index to agree on the leading dimensions.
    # ``expand`` creates stride-0 views, so neither ``x`` nor the index
    # matrix is physically copied (``repeat`` would materialise every copy).
    source = x[:, None, :].expand(-1, n_cliques, -1)
    index = index_matrix[None, ...].expand(batch, -1, -1)

    return torch.gather(source, 2, index)