
import numpy as np
import igraph as ig
import leidenalg

from scipy.spatial.distance import pdist
from torch_geometric.data import Data


def create_leiden_partition(graph: Data, resolution_parameter: float = 1.,
                            max_comm_size: int = 65000,
                            filtered_nodes: set = None,
                            ) -> list[list[int]]:
    """Cluster the nodes of ``graph`` with Leiden and return the communities.

    The edge list in ``graph.edge_index`` is loaded into an undirected igraph
    graph with one vertex per row of ``graph.x``, then partitioned with
    ``leidenalg.find_partition`` using ``RBConfigurationVertexPartition``.

    Args:
        graph: Object exposing ``x`` (only ``x.shape[0]`` is read, as the
            vertex count) and ``edge_index`` (moved to CPU, converted to
            NumPy and unpacked into exactly two rows of endpoints).
        resolution_parameter: Forwarded unchanged to ``find_partition``.
        max_comm_size: Forwarded unchanged to ``find_partition``.
        filtered_nodes: Optional container of vertex indices. When given,
            every community is restricted to the vertices found in it and
            communities left empty are dropped. ``None`` disables filtering.

    Returns:
        The communities as lists of vertex indices, ordered from the largest
        to the smallest (ties keep the order produced by the partition).
    """
    node_count = graph.x.shape[0]
    sources, targets = graph.edge_index.cpu().numpy()

    leiden_graph = ig.Graph(
        edges=list(zip(sources, targets)),
        n=node_count,
        directed=False,
    )
    communities = leidenalg.find_partition(
        leiden_graph,
        leidenalg.RBConfigurationVertexPartition,
        resolution_parameter=resolution_parameter,
        max_comm_size=max_comm_size,
    )

    if filtered_nodes is None:
        clusters = communities
    else:
        # Keep only the requested vertices; discard communities that end up empty.
        clusters = []
        for members in communities:
            kept = [node for node in members if node in filtered_nodes]
            if kept:
                clusters.append(kept)

    # sorted() is stable, so equally sized communities keep their relative order.
    return sorted(clusters, key=len, reverse=True)


def blockwise_pdist(X, block_size=10000, dtype=np.float32):
    """Compute condensed pairwise cosine distances, processing rows in blocks.

    The output uses the same condensed layout as ``scipy.spatial.distance.pdist``:
    the distance between rows ``i < j`` is stored at
    ``n * i - i * (i + 1) // 2 + (j - i - 1)``.

    Args:
        X: 2-D array-like with one observation per row. It is converted to a
            NumPy array of ``dtype``.
        block_size: Maximum number of rows handled at once. Inputs with at
            most ``block_size`` rows are passed to ``pdist`` in a single call.
        dtype: Working dtype. ``X`` is cast to it, and the blockwise result
            buffer is allocated with it.

    Returns:
        1-D array of length ``n * (n - 1) // 2``. On the blockwise path its
        dtype is ``dtype``; on the single-call path it is whatever ``pdist``
        returns.

    Raises:
        ValueError: If ``block_size`` is smaller than 1.

    Notes:
        Distances between rows of different blocks are obtained from dot
        products of unit-normalised rows, so the input does not need to be
        normalised beforehand. On that path a row with zero norm yields NaN.
    """
    if block_size < 1:
        raise ValueError(
            f"block_size must be a positive integer, got {block_size!r}"
        )

    X = np.asarray(X, dtype=dtype)
    n = X.shape[0]

    if n <= block_size:
        return pdist(X, metric='cosine')

    result = np.empty(n * (n - 1) // 2, dtype=dtype)

    # Normalise every row once so that a plain dot product between rows of two
    # different blocks is their cosine similarity, matching what pdist computes
    # inside a block. Zero-norm rows become NaN; silence the numpy warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        unit = X / np.linalg.norm(X, axis=1, keepdims=True)

    for i_start in range(0, n, block_size):
        i_end = min(i_start + block_size, n)
        block_len = i_end - i_start

        # Position in `result` of the pair (i, i + 1) for every row i of the
        # block. All pairs (i, j > i) follow contiguously from there, which
        # lets us write slices instead of building large index arrays.
        # Plain Python ints are used so the offsets cannot overflow.
        row_starts = [i * n - i * (i + 1) // 2 for i in range(i_start, i_end)]

        # Intra-block: pdist already returns the pairs row by row, so each
        # row's run is copied to its place in the global condensed vector.
        if block_len > 1:
            dists = pdist(X[i_start:i_end], metric="cosine")
            src = 0
            for r in range(block_len - 1):
                count = block_len - r - 1
                dst = row_starts[r]
                result[dst:dst + count] = dists[src:src + count]
                src += count

        # Inter-block: only blocks located after the current one are visited,
        # so every pair satisfies i < j and is written exactly once.
        unit_i = unit[i_start:i_end]
        for j_start in range(i_end, n, block_size):
            j_end = min(j_start + block_size, n)
            width = j_end - j_start
            # Clip to absorb rounding that pushes |cosine| slightly above 1.
            D = 1 - np.clip(np.dot(unit_i, unit[j_start:j_end].T), -1.0, 1.0)
            for r, base in enumerate(row_starts):
                # Offset of the pair (i, j_start) within row i's run.
                dst = base + j_start - (i_start + r) - 1
                result[dst:dst + width] = D[r]

    return result
