
import numpy as np
from collections import defaultdict

import torch
from torch_geometric.loader import NeighborLoader


def sample_positive_pairs(cluster_to_nodes: dict[int, list[int]],
                          num_anchors: int,
                          alpha: float = 0.5,
                          min_size_clusters: int = 2,
                          max_size_clusters: int = 10000,
                          seed: int = None) -> torch.Tensor:
    """Sample (anchor, positive) node pairs whose two nodes share a cluster.

    A cluster is drawn for every anchor with probability
    ``alpha * size / total_size + (1 - alpha) / num_clusters`` (renormalised),
    i.e. a blend between size-proportional and uniform cluster sampling.
    Two different positions of the chosen cluster's node list are then drawn
    uniformly at random.

    Args:
        cluster_to_nodes: Mapping from cluster id to the node ids it contains.
        num_anchors: Number of pairs to sample.
        alpha: Blend weight between size-proportional (``1.0``) and uniform
            (``0.0``) cluster sampling.
        min_size_clusters: Smallest cluster size eligible for sampling. Values
            below 2 are treated as 2, because a pair needs two distinct
            members.
        max_size_clusters: Largest cluster size eligible for sampling.
        seed: Seed forwarded to ``numpy.random.default_rng``.

    Returns:
        A ``(num_anchors, 2)`` ``torch.long`` tensor of node ids, or an empty
        ``(0, 2)`` tensor when ``num_anchors <= 0`` or no cluster is eligible.
    """
    if num_anchors <= 0:
        return torch.empty((0, 2), dtype=torch.long)

    rng = np.random.default_rng(seed)

    # A cluster with fewer than two nodes can never yield two distinct
    # positions: the rejection loop below would spin forever on a singleton
    # and ``rng.integers(0, 0)`` would raise on an empty cluster.
    effective_min_size = max(2, min_size_clusters)

    clusters = {
        key: np.asarray(members, dtype=np.int64)
        for key, members in cluster_to_nodes.items()
        if effective_min_size <= len(members) <= max_size_clusters
    }
    if not clusters:
        return torch.empty((0, 2), dtype=torch.long)

    cluster_keys = np.fromiter(clusters.keys(), dtype=np.int64)
    cluster_sizes = np.fromiter((len(clusters[k]) for k in cluster_keys), dtype=np.int64)
    num_nodes = cluster_sizes.sum()

    # Mixture of a uniform and a size-proportional distribution over clusters.
    uniform_prob = (1.0 - alpha) / len(cluster_keys)
    size_prob = alpha * (cluster_sizes / num_nodes)
    cluster_probs = size_prob + uniform_prob
    cluster_probs /= cluster_probs.sum()

    chosen_cluster_idx = rng.choice(len(cluster_keys), size=num_anchors, p=cluster_probs)

    pairs = np.empty((num_anchors, 2), dtype=np.int64)

    # Handle all anchors that landed in the same cluster in one vectorised step.
    uniq, inv = np.unique(chosen_cluster_idx, return_inverse=True)
    for u_pos, u in enumerate(uniq):
        mask = (inv == u_pos)
        count = int(mask.sum())
        nodes = clusters[cluster_keys[u]]
        n = nodes.shape[0]
        a = rng.integers(0, n, size=count, dtype=np.int64)
        b = rng.integers(0, n, size=count, dtype=np.int64)
        # Redraw the second position until it differs from the first;
        # terminates with probability 1 because n >= 2.
        same = (a == b)
        while same.any():
            b[same] = rng.integers(0, n, size=int(same.sum()), dtype=np.int64)
            same = (a == b)
        pairs[mask, 0] = nodes[a]
        pairs[mask, 1] = nodes[b]

    return torch.from_numpy(pairs).to(torch.long)


def sample_negative_pairs_from_positives(
    pos_pairs: torch.Tensor,
    node_to_cluster: torch.Tensor,
    cluster_to_nodes: dict[int, list[int]],
    num_neg_per_anchor: int = 5,
    alpha: float = 0.5,
    min_size_clusters: int = 2,
    seed: int = None,
) -> torch.Tensor:
    """Sample (anchor, negative) pairs for the anchors of ``pos_pairs``.

    The anchors are the first column of ``pos_pairs``. For every anchor,
    ``num_neg_per_anchor`` nodes are drawn from clusters other than the
    anchor's own cluster. Candidate clusters are weighted by
    ``alpha * size / total_size + (1 - alpha) / num_clusters``; the anchor's
    cluster gets probability zero and the rest is renormalised. A node is
    then picked uniformly inside each drawn cluster.

    Args:
        pos_pairs: ``(M, 2)`` tensor of node ids; only column 0 is read.
        node_to_cluster: Tensor indexed by node id giving the cluster id.
        cluster_to_nodes: Mapping from cluster id to the node ids it contains.
        num_neg_per_anchor: Number of negatives drawn per anchor.
        alpha: Blend weight between size-proportional (``1.0``) and uniform
            (``0.0``) cluster sampling.
        min_size_clusters: Smallest cluster size eligible as a source of
            negatives. Values below 1 are treated as 1, because an empty
            cluster has no node to draw.
        seed: Seed forwarded to ``numpy.random.default_rng``.

    Returns:
        A ``(M * num_neg_per_anchor, 2)`` ``torch.long`` tensor on the CPU in
        which the rows of anchor ``i`` occupy the contiguous slice
        ``[i * K, (i + 1) * K)``. Anchors for which no valid negative
        distribution exists contribute no rows. An empty ``(0, 2)`` tensor is
        returned when there is nothing to sample or fewer than two candidate
        clusters remain.
    """
    if pos_pairs.numel() == 0 or num_neg_per_anchor <= 0:
        return torch.empty((0, 2), dtype=torch.long)

    rng = np.random.default_rng(seed)

    # Empty clusters cannot provide a negative node (``rng.integers(0, 0)``
    # raises), so at least one member is always required.
    effective_min_size = max(1, min_size_clusters)
    clusters = {
        key: np.asarray(members, dtype=np.int64)
        for key, members in cluster_to_nodes.items()
        if len(members) >= effective_min_size
    }
    # With zero or one candidate cluster there is no "other" cluster to use.
    if len(clusters) <= 1:
        return torch.empty((0, 2), dtype=torch.long)

    cluster_keys = np.fromiter(clusters.keys(), dtype=np.int64)
    cluster_sizes = np.fromiter((len(clusters[k]) for k in cluster_keys), dtype=np.int64)

    num_nodes_all = int(cluster_sizes.sum())
    uniform_prob = (1.0 - alpha) / cluster_keys.size
    size_prob = alpha * (cluster_sizes / num_nodes_all)
    base_probs = size_prob + uniform_prob
    base_probs = base_probs / base_probs.sum()

    cluster_idx_of = {int(ck): i for i, ck in enumerate(cluster_keys)}

    anchors = pos_pairs[:, 0].detach().cpu().to(torch.long).numpy()
    node_to_cluster_np = node_to_cluster.detach().cpu().to(torch.long).numpy()
    anchor_clusters = node_to_cluster_np[anchors]

    M = anchors.shape[0]
    K = num_neg_per_anchor

    out = np.empty((M * K, 2), dtype=np.int64)
    # Rows that were actually written; anything else is uninitialised memory
    # and must not be handed back to the caller.
    filled = np.zeros(M * K, dtype=bool)
    starts = np.arange(M, dtype=np.int64) * K
    within_anchor = np.arange(K, dtype=np.int64)

    # Group anchor positions by cluster (first-appearance order) so that the
    # masked distribution is built once per cluster rather than per anchor.
    by_cluster = {}
    for i, c in enumerate(anchor_clusters):
        by_cluster.setdefault(int(c), []).append(i)

    for c, idxs in by_cluster.items():
        idxs = np.asarray(idxs, dtype=np.int64)
        m = idxs.size
        total_draws = m * K

        probs = base_probs.copy()
        # ``j`` is None when the anchor's cluster was filtered out above; the
        # base distribution then already excludes it.
        j = cluster_idx_of.get(c)
        if j is not None and probs[j] > 0:
            mass = 1.0 - probs[j]
            if mass <= 0:
                # All probability sits on the anchor's own cluster.
                continue
            probs[j] = 0.0
            probs /= mass

        if probs.sum() == 0 or not np.isfinite(probs).all():
            continue

        chosen_cluster_idx = rng.choice(len(cluster_keys), size=total_draws, p=probs)

        neg_nodes = np.empty(total_draws, dtype=np.int64)
        uniq, inv = np.unique(chosen_cluster_idx, return_inverse=True)
        for u_pos, u in enumerate(uniq):
            mask = (inv == u_pos)
            cnt = int(mask.sum())
            nodes = clusters[cluster_keys[u]]
            pick = rng.integers(0, len(nodes), size=cnt, dtype=np.int64)
            neg_nodes[mask] = nodes[pick]

        # Output rows of anchor i are starts[i] + 0 .. starts[i] + K - 1.
        positions = np.repeat(starts[idxs], K) + np.tile(within_anchor, m)
        out[positions, 0] = np.repeat(anchors[idxs], K)
        out[positions, 1] = neg_nodes
        filled[positions] = True

    if not filled.all():
        out = out[filled]

    return torch.from_numpy(out).to(torch.long)


def _pairs_to_local_index(pairs: torch.Tensor, n_id: torch.Tensor) -> torch.Tensor:
    """Translate global node ids in ``pairs`` to their positions inside ``n_id``."""
    if pairs.numel() == 0:
        return torch.empty((0, 2), dtype=torch.long)
    if n_id.numel() == 0:
        raise IndexError("cannot map pair nodes into an empty batch")

    # A stable sort keeps equal ids in original order, so the left-most
    # searchsorted hit is the first occurrence of the id in ``n_id``.
    sorted_ids, order = torch.sort(n_id, stable=True)
    wanted = pairs.to(device=sorted_ids.device, dtype=sorted_ids.dtype).contiguous()
    slots = torch.searchsorted(sorted_ids, wanted).clamp(max=sorted_ids.numel() - 1)
    if not torch.equal(sorted_ids[slots], wanted):
        raise IndexError("some pair nodes are missing from the sampled batch")
    return order[slots]


def get_validation_pairs(graph, num_validation_anchors: int, num_neg_per_anchor: int,
                         alpha_sampling: float,
                         num_neighbors_loader,
                         min_size_clusters: int = 2,
                         max_size_clusters: int = 10000,
                         seed: int = None,
                         **kwargs
                         ):
    """Build one validation mini-batch together with its positive/negative pairs.

    Nodes are grouped by their label in ``graph.y``, positive and negative
    pairs are sampled from those groups (both samplers receive the same
    ``seed``), and a single ``NeighborLoader`` batch seeded with every node
    that occurs in a pair is drawn.

    Args:
        graph: Graph object exposing per-node cluster labels as ``graph.y``
            and accepted by ``NeighborLoader``.
        num_validation_anchors: Number of positive pairs (anchors) to sample.
        num_neg_per_anchor: Number of negative pairs per anchor.
        alpha_sampling: ``alpha`` forwarded to both pair samplers.
        num_neighbors_loader: ``num_neighbors`` for ``NeighborLoader``;
            ``[10, 5]`` is used when ``None``.
        min_size_clusters: Minimum cluster size forwarded to both samplers.
        max_size_clusters: Maximum cluster size forwarded to the positive
            sampler.
        seed: Seed forwarded to both samplers.
        **kwargs: Accepted and ignored.

    Returns:
        A tuple ``(batch, pos_pairs_local, negative_pairs_local)`` where the
        two pair tensors have shape ``(P, 2)`` and index into ``batch.n_id``.
        A pair set that is empty is returned as a ``(0, 2)`` tensor.

    Raises:
        ValueError: If no pair could be sampled, so there is no node to seed
            the batch with.
    """
    clusters = defaultdict(list)
    for i, cluster in enumerate(graph.y.tolist()):
        clusters[cluster].append(i)

    # sample the pairs
    pos_pairs = sample_positive_pairs(cluster_to_nodes=clusters, num_anchors=num_validation_anchors,
                                      alpha=alpha_sampling, seed=seed, min_size_clusters=min_size_clusters,
                                      max_size_clusters=max_size_clusters)
    neg_pairs = sample_negative_pairs_from_positives(cluster_to_nodes=clusters,
                                                     node_to_cluster=graph.y, pos_pairs=pos_pairs,
                                                     num_neg_per_anchor=num_neg_per_anchor, alpha=alpha_sampling,
                                                     min_size_clusters=min_size_clusters, seed=seed)

    # graph nodes that will be used
    nodes = torch.unique(torch.cat([pos_pairs, neg_pairs], dim=0).view(-1))
    if nodes.numel() == 0:
        # A loader with batch_size=0 would fail with an unrelated message.
        raise ValueError("no validation pairs could be sampled; "
                         "check num_validation_anchors and the cluster size limits")

    # create the batch
    if num_neighbors_loader is None:
        num_neighbors_loader = [10, 5]
    loader = NeighborLoader(graph, num_neighbors=num_neighbors_loader, batch_size=len(nodes), input_nodes=nodes)
    batch = next(iter(loader))

    # get the local indexes of the pairs in the batch
    pos_pairs_local = _pairs_to_local_index(pos_pairs, batch.n_id)
    negative_pairs_local = _pairs_to_local_index(neg_pairs, batch.n_id)

    return batch, pos_pairs_local, negative_pairs_local
