"""
    Main utils script.
    This is part of our official implementation of our paper: "Virtual Nodes Go Temporal".
"""

from typing import Tuple, Optional, Literal, Dict
import math
import torch
from torch import nn
from torch.nn import functional as F

# For the Louvain clustering
try:
    import networkx as nx
    _HAS_NX = True
except Exception:
    _HAS_NX = False

# For the GAT-Based implementation/aggregation in the VNs.
try:
    from torch_geometric.nn import GATConv
    _HAS_PYG = True
except Exception:
    _HAS_PYG = False


# -------------------------------
# Utilities
# -------------------------------

def _device_of(x):
    """
        Simply extract the device of a given tensor.
        Helpful to avoid some device-related issues.
    """
    return x.device if x.is_cuda or x.device.type != 'cpu' else torch.device('cpu')

def degrees_from_edge_index(edge_index: torch.Tensor,
                            num_nodes: int,
                            device: Optional[torch.device] = None) -> torch.Tensor:
    """
        Return the degree vector (undirected) as a float tensor of length num_nodes.
        Every column (src, dst) of edge_index adds 1 to the degree of both src and dst,
        so an edge stored in both directions counts twice per endpoint and a self-loop
        counts twice for its node.
        ---
        Args:
            edge_index: [2, E] tensor of node indices.
            num_nodes: length of the returned vector.
            device: where to build the result; defaults to edge_index.device.
        Returns:
            Float tensor [num_nodes] on `device`.
    """
    device = device or edge_index.device
    # index_add_ needs the indices on the same device as the accumulator.
    src, dst = edge_index.to(device)
    deg = torch.zeros(num_nodes, device=device)
    ones = torch.ones(src.numel(), device=device)
    deg.index_add_(0, src, ones)
    deg.index_add_(0, dst, ones)
    return deg

def adjacency_random_projection(edge_index: torch.Tensor,
                                num_nodes: int,
                                dproj: int = 64,
                                device: Optional[torch.device] = None,
                                generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """
    Computes a random projection of the adjacency matrix A by summing random features of neighbors,
    i.e. X ≈ A * R, normalised by node degree.
    R is a standard-normal [num_nodes, dproj] matrix. Each edge contributes in both directions, and each
    row is divided by the node's undirected degree (clamped to at least 1).
    This is useful for the k-means clustering.
    ---
    Args:
        edge_index: [2, E] tensor of node indices.
        num_nodes: number of nodes (rows of the output).
        dproj: projection width.
        device: where to compute; defaults to edge_index.device.
        generator: optional torch.Generator used to draw R, making the projection reproducible.
            It must live on `device`. When None (default) the global RNG is used.
    Returns:
        Float tensor [num_nodes, dproj].
    """
    device = device or edge_index.device
    # Keep the indices next to R/X so index_add_ does not mix devices.
    edge_index = edge_index.to(device)
    R = torch.randn(num_nodes, dproj, device=device, generator=generator)
    X = torch.zeros_like(R)
    src, dst = edge_index
    X.index_add_(0, src, R[dst])
    X.index_add_(0, dst, R[src])
    deg = degrees_from_edge_index(edge_index, num_nodes, device).clamp_min_(1.0).unsqueeze(1)
    return X / deg

def degree_normalized_mean(x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """
    Mean of incoming neighbor features: Z = D^{-1} A X, where D is the in-degree (row sums of A).
    For every edge (src, dst), x[src] is added to node dst. The sum is then divided by the number of
    edges arriving at dst, clamped to at least 1 so nodes with no incoming edge get zeros.
    Dividing by the in-degree makes this an actual mean over the aggregated messages. The undirected
    degree counts both endpoints of each edge; with a symmetric edge_index it is twice the in-degree
    and would halve every output.
    This is useful for the case where we use degree-based aggregations.
    ---
    Args:
        x: node features indexed by node along dim 0 (any trailing shape).
        edge_index: [2, E] tensor of (src, dst) node indices.
    Returns:
        Tensor with the same leading size as x.
    """
    num_nodes = x.size(0)
    src, dst = edge_index
    out = torch.zeros_like(x)
    out.index_add_(0, dst, x[src])
    deg_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    in_deg = torch.zeros(num_nodes, device=x.device, dtype=deg_dtype)
    in_deg.index_add_(0, dst, torch.ones(dst.numel(), device=x.device, dtype=deg_dtype))
    # Shape [N, 1, ..., 1] so the division broadcasts over x's trailing dimensions.
    in_deg = in_deg.clamp_min_(1.0).view(-1, *([1] * (x.dim() - 1)))
    return out / in_deg


class MeanDegMixer(nn.Module):
    """
    Degree-normalized mean aggregator to construct the VNs.
    forward() averages neighbor features with `degree_normalized_mean` and then applies `self.lin`.
    `self.lin` is an nn.Linear(in_dim, out_dim) when projection is requested or the dimensions differ,
    and an nn.Identity otherwise.
    """
    def __init__(self, in_dim: int, out_dim: int, project: bool = True, bias: bool = False):
        super().__init__()
        needs_projection = project or in_dim != out_dim
        self.project = needs_projection
        if needs_projection:
            self.lin = nn.Linear(in_dim, out_dim, bias=bias)
        else:
            self.lin = nn.Identity()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        return self.lin(degree_normalized_mean(x, edge_index))


def kmeans_pp_init(X: torch.Tensor, k: int, seed: int = 0) -> torch.Tensor:
    """
    This function initializes the k centers from the rows of X following the k-means++ approach.
    The first center is a uniformly random row. Each following center is sampled with probability
    proportional to its squared distance to the closest center chosen so far. All sampling uses a
    private torch.Generator seeded with `seed`, so the result is reproducible and the global RNG is
    not consumed.
    ---
    Args:
        X: [n, d] floating-point data matrix.
        k: number of centers.
        seed: seed of the private generator.
    Returns:
        [k, d] tensor of centers with X's dtype and device.
    """
    g = torch.Generator(device=X.device).manual_seed(seed)
    n = X.size(0)
    # Allocate with X's dtype. A float32 buffer would silently downcast float64 data, and the
    # cdist calls below would then compare tensors of different precision.
    centers = torch.empty(k, X.size(1), device=X.device, dtype=X.dtype)
    # First center: uniform over rows.
    idx = torch.randint(0, n, (1,), generator=g, device=X.device)
    centers[0] = X[idx]
    # Generate the rest of the centers.
    closest_dist_sq = torch.cdist(X, centers[0:1], p=2).square().squeeze(1)
    for c in range(1, k):
        # D^2 weighting. The 1e-12 floor keeps multinomial valid when all distances are 0 (duplicate points).
        probs = (closest_dist_sq / closest_dist_sq.sum().clamp_min(1e-12)).clamp_min(1e-12)
        idx = torch.multinomial(probs, 1, generator=g)
        centers[c] = X[idx]
        d2 = torch.cdist(X, centers[c:c+1], p=2).square().squeeze(1)
        closest_dist_sq = torch.minimum(closest_dist_sq, d2)
    return centers

def kmeans_torch(X: torch.Tensor, k: int, iters: int = 20, seed: int = 0) -> torch.Tensor:
    """
    Implementation of the K-Means algorithm (Lloyd's iterations, k-means++ initialization).
    Simply takes a matrix X: [n, d] and returns assignment labels: [n] in [0..k-1].
    Each iteration assigns every row to its nearest center and stops as soon as the assignment no
    longer changes. Otherwise it moves each non-empty cluster's center to the mean of its rows;
    empty clusters keep their previous center.
    At least one assignment pass always runs, so iters <= 0 returns the nearest-initial-center labels.
    ---
    The full algorithm is provided in Algorithm 2 of our paper, please refer to that one.
    """
    centers = kmeans_pp_init(X, k, seed=seed)
    # None until the first pass. Starting from all-zeros (as before) made an all-zero first
    # assignment look "converged" before any center update had happened.
    labels: Optional[torch.Tensor] = None

    for _ in range(max(iters, 1)):
        # Assign each row to its nearest center.
        d2 = torch.cdist(X, centers, p=2).square()
        new_labels = torch.argmin(d2, dim=1)
        if labels is not None and torch.equal(new_labels, labels):
            break
        labels = new_labels
        # Update: move each non-empty cluster's center to its mean.
        for j in range(k):
            mask = (labels == j)
            if mask.any():
                centers[j] = X[mask].mean(dim=0)
    return labels

def scatter_mean_by_assignment(X: torch.Tensor, assignment: torch.Tensor, k: int) -> torch.Tensor:
    """
    Average the rows of X per group: row j of the result is the mean of the rows i with assignment[i] == j.
    Groups that receive no row yield a zero row (the count is clamped to 1).
    The result has shape [k, X.size(1)] and X's dtype and device.
    """
    sums = X.new_zeros((k, X.size(1)))
    sums.index_add_(0, assignment, X)
    counts = X.new_zeros((k, 1))
    counts.index_add_(0, assignment, X.new_ones((X.size(0), 1)))
    return sums / counts.clamp_min(1.0)

def build_vn_edge_index(assignment: torch.Tensor,
                        num_nodes: int,
                        k: int,
                        connect: Literal["clique", "star", "ring"] = "clique",
                        device: Optional[torch.device] = None) -> torch.Tensor:
    """
    This function builds the VN augmentation as an edge_index over num_nodes + k nodes.
    VN j gets node id num_nodes + j. Following our paper:
        - Each child node is connected (in both directions) to its corresponding VN.
        - The VNs are connected to each other according to `connect` ("clique" by default, as in the paper).
        - "star" links VN 0 to every other VN; "ring" links consecutive VNs cyclically (both directions).
    No VN-VN self-loops or duplicate VN-VN edges are produced, also for the small-k ring cases.
    ---
    Returns:
        Long tensor [2, 2 * num_nodes + E_vv], columns ordered node->VN, VN->node, VN-VN.
    Raises:
        ValueError: for an unknown `connect` value.
    """
    device = device or assignment.device
    # Every piece below is stacked together, so the assignment must live on `device` too.
    assignment = assignment.to(device=device, dtype=torch.long)
    n = num_nodes
    vn_offset = n
    nodes = torch.arange(n, device=device)
    vn_ids = torch.arange(vn_offset, vn_offset + k, device=device)
    node2vn = vn_offset + assignment

    e_nv = torch.stack([nodes, node2vn], dim=0)
    e_vn = torch.stack([node2vn, nodes], dim=0)

    # This is the paper's default, all VNs are connected.
    if connect == "clique":
        I, J = torch.meshgrid(vn_ids, vn_ids, indexing="ij")
        mask = (I != J)
        e_vv = torch.stack([I[mask], J[mask]], dim=0)

    # Other possible connections (not included in the paper).
    elif connect == "star":
        center = vn_ids[:1]
        leaves = vn_ids[1:]
        e1 = torch.stack([center.repeat_interleave(leaves.numel()), leaves.repeat(center.numel())], dim=0)
        e_vv = torch.cat([e1, e1.flip(0)], dim=1)
    elif connect == "ring":
        if k <= 1:
            # A one-VN "ring" would only consist of self-loops.
            e_vv = torch.empty(2, 0, dtype=torch.long, device=device)
        elif k == 2:
            # The forward and backward rings coincide; emit each direction once.
            e_vv = torch.stack([vn_ids, vn_ids.flip(0)], dim=0)
        else:
            nxt = torch.roll(vn_ids, shifts=-1)
            e_vv = torch.cat([torch.stack([vn_ids, nxt], dim=0),
                              torch.stack([nxt, vn_ids], dim=0)], dim=1)
    else:
        raise ValueError(f"Unknown VN connectivity: {connect}")

    edge_index = torch.cat([e_nv, e_vn, e_vv], dim=1).long()
    return edge_index


def assignment_kmeans_adjacency(edge_index: torch.Tensor,
                                num_nodes: int,
                                k: int,
                                dproj: int = 64,
                                iters: int = 20,
                                seed: int = 0) -> torch.Tensor:
    """
    Main clustering algorithm (K-Means based): cluster the nodes of edge_index into k groups.
    Nodes are embedded with `adjacency_random_projection` (dproj columns, on edge_index's device)
    and those rows are clustered with `kmeans_torch`.
    ---
    Note: `seed` is forwarded to the k-means++ seeding only. The random projection is drawn without
    an explicit generator, i.e. from the global torch RNG.
    Returns:
        Long tensor [num_nodes] of labels in 0..k-1.
    """
    features = adjacency_random_projection(edge_index, num_nodes, dproj=dproj, device=edge_index.device)
    return kmeans_torch(features, k=k, iters=iters, seed=seed)

def assignment_random(num_nodes: int, k: int, seed: int = 0, device=None) -> torch.Tensor:
    """
    Random clustering: each node gets a label drawn uniformly from 0..k-1.
    A dedicated generator seeded with `seed` is used (on `device`, or the CPU when device is not given),
    so the same seed reproduces the same assignment.
    --
    This is the "Random" baseline the paper compares against K-Means.
    """
    gen_device = device if device else torch.device('cpu')
    rng = torch.Generator(device=gen_device)
    rng.manual_seed(seed)
    return torch.randint(0, k, (num_nodes,), generator=rng, device=device)

def _greedy_modularity_labels(edge_index: torch.Tensor, num_nodes: int) -> Tuple[torch.Tensor, int]:
    """Label each node with its greedy-modularity community id (CPU tensor); return (labels, #communities)."""
    # Plain Python ints keep NetworkX node ids identical to range(num_nodes).
    src, dst = edge_index.detach().cpu().tolist()
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(zip(src, dst))
    comms = list(nx.algorithms.community.greedy_modularity_communities(G))

    labels = torch.full((num_nodes,), -1, dtype=torch.long)
    for cid, members in enumerate(comms):
        labels[list(members)] = cid
    # The communities should partition G. As a safeguard, any node left unlabelled becomes its own
    # singleton community instead of leaking -1 into the result.
    next_cid = len(comms)
    for u in torch.where(labels == -1)[0].tolist():
        labels[u] = next_cid
        next_cid += 1
    return labels, next_cid


def _merge_communities(labels: torch.Tensor, c: int, k: int) -> torch.Tensor:
    """Reduce c > k communities to k: keep the k largest, fold the others into them round-robin."""
    sizes = torch.bincount(labels, minlength=c).tolist()
    # Sort by size, largest first; sorted() stays stable with reverse=True, so ties keep ascending id.
    order = sorted(range(c), key=lambda cid: sizes[cid], reverse=True)
    new_id = torch.empty(c, dtype=torch.long)
    for rank, cid in enumerate(order):
        # Ranks 0..k-1 are the kept communities, relabelled by size rank. The extra community at
        # rank r >= k joins kept community (r - k) % k, which equals r % k.
        new_id[cid] = rank % k
    return new_id[labels]


def _split_communities(labels: torch.Tensor, c: int, k: int) -> torch.Tensor:
    """Grow c < k communities towards k groups by moving half of a group's nodes into a new group."""
    sizes = torch.bincount(labels, minlength=c).tolist()
    order = sorted(range(c), key=lambda cid: sizes[cid], reverse=True)
    next_label = c
    # First pass: split each original community at most once, largest first
    # (the first half of its nodes, in ascending node id, moves to the new label).
    for cid in order:
        if next_label >= k:
            break
        members = torch.where(labels == cid)[0]
        if members.numel() <= 1:
            continue
        labels[members[: members.numel() // 2]] = next_label
        next_label += 1
    # Still short: keep halving the currently largest group. A split never empties a group, so
    # this reaches exactly k groups whenever num_nodes >= k.
    while next_label < k:
        counts = torch.bincount(labels, minlength=next_label)
        largest = int(torch.argmax(counts))
        if int(counts[largest]) <= 1:
            break  # every group is a singleton: there are fewer nodes than k
        members = torch.where(labels == largest)[0]
        labels[members[: members.numel() // 2]] = next_label
        next_label += 1
    return labels


def assignment_louvain_greedy(edge_index: torch.Tensor, num_nodes: int, k: int) -> torch.Tensor:
    """
    Community-based assignment of the nodes into k groups.
    ---
        Despite its name, the communities come from NetworkX's
        `greedy_modularity_communities` (greedy modularity maximisation), so networkx must be installed.
        The raw communities are then adjusted to k groups:
            - more than k: the k largest are kept (relabelled 0..k-1 by size) and every other
              community is merged into one of them round-robin;
            - fewer than k: groups are split in half until there are k groups (fewer only if
              num_nodes < k).
    Args:
        edge_index: [2, E] tensor of edges, treated as undirected.
        num_nodes: number of nodes; node ids are 0..num_nodes-1.
        k: desired number of groups (must be positive).
    Returns:
        Long tensor [num_nodes] of labels, on edge_index.device.
    Raises:
        RuntimeError: if networkx is not available.
        ValueError: if k is not positive.
    """
    if not _HAS_NX:
        raise RuntimeError("networkx not available")
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k}")
    if num_nodes == 0:
        return torch.empty(0, dtype=torch.long, device=edge_index.device)

    labels, c = _greedy_modularity_labels(edge_index, num_nodes)
    if c > k:
        labels = _merge_communities(labels, c, k)
    elif c < k:
        labels = _split_communities(labels, c, k)
    return labels.to(edge_index.device)

class VNMixer(nn.Module):
    """
    This is our VNMixer, which does the message passing over the nodes and the VNs using a stack of GATConv layers.
    ---
    Note that you need torch_geometric installed for the GATConv.
    Layer widths (per head, `heads` heads, concatenated):
        - layers == 1: a single layer in_dim -> out_dim // heads.
        - layers > 1: the first and all hidden layers output hdim = max(out_dim // heads, 32);
          the last layer outputs out_dim // heads.
    The final width is therefore (out_dim // heads) * heads, which equals out_dim only when heads
    divides out_dim. ELU is applied between layers, not after the last. `dropout` is passed to GATConv.
    """
    def __init__(self, in_dim: int, out_dim: int, heads: int = 2, layers: int = 1, dropout: float = 0.0):
        super().__init__()
        if not _HAS_PYG:
            raise RuntimeError("torch_geometric is required for GATConv.")
        self.layers = nn.ModuleList()
        out_per_head = out_dim // heads
        hdim = out_per_head if layers == 1 else max(out_per_head, 32)
        self.layers.append(GATConv(in_dim, hdim, heads=heads, add_self_loops=False, dropout=dropout))
        for i in range(layers - 1):
            # Hidden layers keep width hdim so the next layer's declared input (hdim * heads) matches.
            # Previously every later layer emitted out_per_head, which broke stacks of 3+ layers
            # whenever out_per_head != hdim.
            is_last = (i == layers - 2)
            out_channels = out_per_head if is_last else hdim
            self.layers.append(GATConv(hdim * heads, out_channels, heads=heads, add_self_loops=False, dropout=dropout))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        h = x
        for i, conv in enumerate(self.layers):
            h = conv(h, edge_index)
            if i < len(self.layers) - 1:
                h = F.elu(h)
        return h


# Assignment strategies accepted by build_vn_graph_and_mix to compute the node -> VN assignment Π.
AssignmentMethod = Literal["kmeans_adj", "random", "louvain"]
# VN-to-VN connectivity patterns understood by build_vn_edge_index.
VNConnect = Literal["clique", "star", "ring"] # Note that the default that we use in the paper is "clique"

def build_vn_graph_and_mix(memory_bank: torch.Tensor,
                           edge_index_snapshot: torch.Tensor,
                           num_nodes: int,
                           k: int,
                           in_dim: int,
                           out_dim: int,
                           method: AssignmentMethod = "kmeans_adj",
                           connect: VNConnect = "clique",
                           mixer: Optional[nn.Module] = None,
                           method_kwargs: Optional[Dict] = None,
                           seed: int = 0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    This is the full function that is used in the main script (following Algorithm 1 in the paper).
        - Compute the assignment Π (node -> VN) with `method`.
        - Build the VN-augmented edge_index (VN j has id num_nodes + j).
        - Initialise each VN with the mean memory of its nodes and run the mixer over nodes + VNs.
    ---
    Args:
        memory_bank: [num_nodes, D] node memory; row i belongs to node i.
        edge_index_snapshot: [2, E] edges of the current snapshot, used by "kmeans_adj" and "louvain".
        num_nodes, k: number of real nodes and of VNs.
        in_dim, out_dim: only used to build the default mixer.
        method_kwargs: "dproj" and "iters" for "kmeans_adj".
        seed: forwarded to the "kmeans_adj" and "random" assignments.
    Returns:
        (z_nodes, assignment, vn_edge_index): mixer output for the real nodes only (VN rows are
        dropped), the assignment (on memory_bank's device), and the augmented edge_index.
    Note:
        When `mixer` is None a freshly initialised VNMixer is created on every call and discarded,
        so its weights are never trained. Pass a persistent mixer module when learning.
    """
    device = memory_bank.device
    method_kwargs = method_kwargs or {}

    # 1) Assignment Π (Algorithm 1/2 in the paper)
    if method == "kmeans_adj":
        assignment = assignment_kmeans_adjacency(edge_index_snapshot, num_nodes, k,
                                                 dproj=method_kwargs.get("dproj", 64),
                                                 iters=method_kwargs.get("iters", 20),
                                                 seed=seed)
    elif method == "random":
        assignment = assignment_random(num_nodes, k, seed=seed, device=device)
    elif method == "louvain":
        if not _HAS_NX:
            raise RuntimeError("Louvain requires networkx.")
        assignment = assignment_louvain_greedy(edge_index_snapshot, num_nodes, k)
    else:
        raise ValueError(f"Unknown assignment method: {method}")

    # The kmeans/louvain assignments live on edge_index_snapshot's device. Move them next to
    # memory_bank so the scatter and edge construction below do not mix devices.
    assignment = assignment.to(device)

    # 2) Build the VNs.
    vn_edge_index = build_vn_edge_index(assignment, num_nodes, k, connect=connect, device=device)

    # 3) Init VNs using the mean-aggregation and then use the mixer.
    vn_init = scatter_mean_by_assignment(memory_bank, assignment, k)
    x_aug = torch.cat([memory_bank, vn_init], dim=0)
    if mixer is None:
        mixer = VNMixer(in_dim, out_dim, heads=2, layers=1, dropout=0.0).to(device)
    z_aug = mixer(x_aug, vn_edge_index)
    z_nodes = z_aug[:num_nodes]

    return z_nodes, assignment, vn_edge_index
