
from collections import defaultdict, deque

import numpy as np
from scipy.cluster.hierarchy import to_tree


def _has_pairable_candidate(candidates, clusters, labels, node2partition) -> bool:
    """Return True if some candidate has a same-label partner in its partition cluster."""
    # (cluster key, partition index) -> distinct members of that cluster in that partition
    groups = defaultdict(set)
    for key, members in clusters.items():
        for node in members:
            groups[(key, node2partition.get(node))].add(node)
    for node in candidates:
        group = groups.get((labels[node], node2partition.get(node)), ())
        # a partner is any member of the group other than the node itself
        if len(group) > 1 or (len(group) == 1 and node not in group):
            return True
    return False


def get_samples_for_purity(clusters: dict[int, list[int]],
                           labels: np.ndarray,
                           partition: list[list[int]],
                           num_samples: int = 1, seed: int = 1) -> dict[int, list]:
    """Sample pairs of same-label nodes that fall in the same partition cluster.

    A node is drawn uniformly from the members of every ``clusters`` entry
    with at least two members. A second node is then drawn uniformly among
    the other members of ``clusters[labels[node]]``. The pair is kept only
    when both nodes belong to the same cluster of ``partition``. Otherwise it
    is discarded, and drawing continues until ``num_samples`` pairs are kept.

    Args:
        clusters: Mapping whose keys are looked up with ``labels[node]``
            (presumably label -> node ids carrying that label; confirm).
        labels: Array indexed by node id giving each node's label.
        partition: Leiden clusters, each a list of node ids.
        num_samples: Number of pairs to return.
        seed: Passed to ``np.random.seed``, which reseeds NumPy's global RNG.

    Returns:
        A ``defaultdict(list)`` mapping partition index to a list of
        ``(sample_idx, other_idx, label)`` tuples. Both indices are positions
        inside ``partition[partition_idx]``.

    Raises:
        ValueError: If ``num_samples > 0`` and no pair can ever be produced.
            This happens when no cluster has two members, or when no
            same-label nodes share a partition cluster. Without this check the
            sampling loop would never terminate.
    """
    np.random.seed(seed)

    # candidate nodes: members of every cluster with at least two members
    candidates = []
    for members in clusters.values():
        if len(members) >= 2:
            candidates.extend(members)
    candidates = np.array(candidates)

    # node id -> index of its Leiden cluster, and -> its position inside that
    # cluster; a node listed more than once keeps its last occurrence
    node2partition = {}
    node2position = {}
    for i, partition_cluster in enumerate(partition):
        for j, node in enumerate(partition_cluster):
            node2partition[node] = i
            node2position[node] = j

    final_samples = defaultdict(list)
    if num_samples <= 0:
        return final_samples
    if len(candidates) == 0:
        raise ValueError("no cluster has at least two members to sample from")
    if not _has_pairable_candidate(candidates, clusters, labels, node2partition):
        raise ValueError("no two nodes with the same label share a partition "
                         "cluster; sampling could never succeed")

    current_num_samples = 0
    while current_num_samples < num_samples:
        # draw up to 1000 candidates per round
        batch_size = min(num_samples - current_num_samples, 1000)
        samples = candidates[np.random.randint(len(candidates), size=batch_size)]
        samples_label = labels[samples]
        samples_partition = [node2partition[n] for n in samples]

        for sample, label, partition_idx in zip(samples, samples_label, samples_partition):
            # sample another node with the same label, uniformly among the rest
            others = [n for n in clusters[label] if n != sample]
            other_sample = others[np.random.randint(len(others), size=1)[0]]

            # the purity can only be computed when both nodes share a Leiden cluster
            if node2partition[other_sample] != partition_idx:
                continue
            final_samples[partition_idx].append(
                (node2position[sample], node2position[other_sample], label))
            current_num_samples += 1
            if current_num_samples >= num_samples:
                break

    return final_samples


class ClusterTree:
    """Parent-linked view of a SciPy hierarchical-clustering tree.

    Built from a linkage matrix ``Z`` via ``scipy.cluster.hierarchy.to_tree``.
    Every ``ClusterNode`` gets an extra ``parent`` attribute (``None`` for
    the root), so leaves can be walked up to the root.

    Attributes:
        root: Root ``ClusterNode`` of the tree.
        leaves: Mapping from leaf node id to its ``ClusterNode``, filled in
            breadth-first order.
    """

    def __init__(self, Z):
        self.root = to_tree(Z, rd=False)
        self.root.parent = None
        self.leaves = {}
        # breadth-first walk; deque.popleft() is O(1) unlike list.pop(0)
        nodes_to_explore = deque([self.root])
        while nodes_to_explore:
            node = nodes_to_explore.popleft()
            for child in (node.left, node.right):
                if child is not None:
                    child.parent = node
                    nodes_to_explore.append(child)
            if node.left is None and node.right is None:
                self.leaves[node.id] = node

    def get_path_to_root(self, node):
        """Return ``[node, parent, grandparent, ..., root]``."""
        path = []
        while node is not None:
            path.append(node)
            node = node.parent
        return path

    def find_common_ancestor(self, idx_a, idx_b):
        """Return the lowest node having both leaves ``idx_a`` and ``idx_b`` below it.

        Raises ``KeyError`` if either id is not a leaf. Returns ``None`` only
        if the two paths to the root share no node.
        """
        # set of ids built once: O(1) membership per step instead of
        # rebuilding a list for every ancestor of idx_b
        ancestor_ids_a = {node.id for node in self.get_path_to_root(self.leaves[idx_a])}
        for ancestor in self.get_path_to_root(self.leaves[idx_b]):
            if ancestor.id in ancestor_ids_a:
                return ancestor
        return None

    def get_descendants(self, ancestor) -> list[int]:
        """Return ids of all leaves under ``ancestor``, in breadth-first order."""
        descendants = []
        nodes_to_explore = deque([ancestor])
        while nodes_to_explore:
            node = nodes_to_explore.popleft()
            if node.left is not None:
                nodes_to_explore.append(node.left)
            if node.right is not None:
                nodes_to_explore.append(node.right)
            if node.left is None and node.right is None:
                descendants.append(node.id)
        return descendants


def compute_leiden_cluster_purity(linkage_z,
                                  leiden_cluster: list[int],
                                  samples: list[tuple[int, int, int]],
                                  labels: np.ndarray) -> float:
    """Sum the dendrogram purities of sampled pairs within one Leiden cluster.

    Each sample is ``(sample_idx, other_idx, label)``, with indices that are
    positions in ``leiden_cluster``. For each pair, the lowest common
    ancestor of the two leaves in the dendrogram ``linkage_z`` is found. The
    fraction of leaves under that ancestor whose label equals ``label`` is
    then added to the total.

    Args:
        linkage_z: Linkage matrix whose leaves are indexed by position in
            ``leiden_cluster`` (assumed; confirm with the caller).
        leiden_cluster: Node ids of the Leiden cluster.
        samples: Pairs as returned per partition by ``get_samples_for_purity``.
        labels: Array indexed by node id giving each node's label.

    Returns:
        The SUM (not the mean) of the per-pair purities, or ``0.0`` when
        ``samples`` is empty. Presumably the caller normalises by the total
        number of samples.
    """
    purity = 0.
    tree = ClusterTree(linkage_z)
    # convert once instead of once per sample
    cluster_nodes = np.asarray(leiden_cluster)
    for sample_idx, other_idx, label in samples:
        common_ancestor = tree.find_common_ancestor(sample_idx, other_idx)
        descendants = tree.get_descendants(common_ancestor)  # relative idx in the leiden cluster
        descendant_labels = labels[cluster_nodes[descendants]]  # labels of those nodes
        purity += np.mean(descendant_labels == label)
    return purity
