import torch
from torch_geometric.utils import coalesce, remove_self_loops
    
from ..graph import Graph


def edge_pruning(
    max_indegree: int,
    edge_index:   torch.LongTensor,
    pos:          torch.Tensor,
    edge_attr:    torch.Tensor = None,
) -> tuple[torch.LongTensor, torch.Tensor | None]:
    """Removes the longest edges incident to nodes with in-degree greater than max_indegree until the in-degree is equal to max_indegree.

    Args:
        max_indegree (int): Maximum in-degree allowed.
        edge_index (torch.LongTensor): Edge index tensor. Shape (2, num_edges).
        pos (torch.Tensor): Node position tensor. Shape (num_nodes, dim).
        edge_attr (torch.Tensor, optional): Edge attribute tensor. Shape (num_edges, num_edge_features). Defaults to None.

    Returns:
        tuple[torch.LongTensor, torch.Tensor | None]: The pruned edge index and edge attribute tensors (if provided).

    """
    assert edge_index.size(0) == 2, f'Expected edge_index to have shape (2, num_edges), got {edge_index.size()}'
    assert max_indegree > 0, f'Expected max_indegree to be greater than 0, got {max_indegree}'
    device = edge_index.device
    num_nodes = pos.size(0)
    indegree = torch.bincount(edge_index[1], minlength=num_nodes) # Compute the in-degree of each node
    mask = indegree > max_indegree                                # Mask of nodes with in-degree greater than max_in_degree. Shape (num_nodes,)
    # If there are no nodes with in-degree greater than max_in_degree, return the input edge_index and edge_attr
    if mask.sum() == 0:
        return edge_index, edge_attr
    masked_nodes = torch.arange(num_nodes, device=device)[mask]   # Nodes with in-degree greater than max_in_degree. Shape (num_masked_nodes,)
    senders = edge_index[0].split(indegree.tolist())              # Senders of each node. Shape (num_nodes,)
    masked_senders = [senders[i] for i in masked_nodes.tolist()]  # Senders of nodes with in-degree greater than max_in_degree. Shape (num_masked_nodes,)
    # Compute the edges to be removed
    edges_to_be_removed = torch.zeros(edge_index.size(1), dtype=torch.bool, device=device)
    # Iterate over the nodes with in-degree greater than max_in_degree
    for i, s in zip(masked_nodes, masked_senders):
        num_to_be_removed = indegree[i] - max_indegree # Number of edges to be removed
        # Get the neighbourhood-wise index of the longest edges
        lengths = torch.norm(pos[s] - pos[i], dim=1)                          # Shape (num_senders,)
        indices = torch.argsort(lengths, descending=True)[:num_to_be_removed] # Shape (num_to_be_removed,)
        # Compute the global index of the longest edges
        indices = indegree[:i].sum() + indices
        edges_to_be_removed[indices] = True
    # Remove the edges
    edge_index = edge_index[:, ~edges_to_be_removed]
    if edge_attr is not None:
        edge_attr = edge_attr[~edges_to_be_removed]
    return edge_index, edge_attr

    
def cells_to_edge_index(
    cells:        torch.LongTensor, # Shape (num_cells, max_num_nodes_per_cell)
    max_indegree: int = None,
    pos:          torch.Tensor = None,
) -> torch.LongTensor:
    """Builds an undirected edge index from a (padded) cell list.

    Every pair of consecutive node slots of a cell becomes an edge. Pairs that contain a
    negative index and pairs that link a node to itself are discarded. Each remaining edge is
    added in both directions and duplicates are merged with `coalesce`.

    Args:
        cells (torch.LongTensor): Cell list tensor. Shape (num_cells, max_num_nodes_per_cell).
        max_indegree (int, optional): Maximum in-degree allowed. If given, the result is passed through `edge_pruning`. Defaults to None.
        pos (torch.Tensor, optional): Node position tensor. Shape (num_nodes, dim). Required when `max_indegree` is given. Defaults to None.

    Returns:
        torch.LongTensor: Edge index tensor. Shape (2, num_edges).

    """
    nodes_per_cell = cells.shape[1]
    # One (num_cells, 2) slab per pair of adjacent columns, stacked and transposed to (2, num_pairs)
    consecutive_pairs = [cells[:, k:k + 2] for k in range(nodes_per_cell - 1)]
    pairs = torch.cat(consecutive_pairs, dim=0).T
    src, dst = pairs[0], pairs[1]
    # Discard pairs touching a negative index as well as self-loops
    keep = (src >= 0) & (dst >= 0) & (src != dst)
    pairs = pairs[:, keep]
    # Add the reversed copy of every edge to make the graph undirected
    both_ways = torch.cat([pairs, pairs.flip(0)], dim=1)
    # Merge duplicated edges; sort_by_row=False orders them by the second row
    both_ways = coalesce(both_ways, sort_by_row=False)
    if max_indegree is None:
        return both_ways
    # Edge pruning
    assert pos is not None, 'Expected pos to be provided when max_indegree is not None'
    pruned, _ = edge_pruning(max_indegree, both_ways, pos, edge_attr=None)
    return pruned


# topology- and geometry-aware graph coarsening
def topology_geometry_graph_coarsening(
    pos:        torch.Tensor,
    edge_index: torch.Tensor,
    max_iter:   int = 5,
    alpha:      float = 1.0,
) -> tuple[torch.BoolTensor, torch.LongTensor]:
    """
    Guillard-style graph coarsening with a topology–geometry aware coarse-node ranking.

    The algorithm:
      1) Computes node indegree (topology).
      2) Estimates a local geometric density via mean distance to in-neighbors (geometry).
      3) Ranks nodes by a joint score: score = indegree * (local_density ** alpha).
         Lower score => higher priority to be selected as a coarse node.
         Nodes without in-neighbors are ranked last (score = +inf).
      4) Greedily selects coarse nodes, suppressing their in-neighbors.
      5) Assigns each fine node to the nearest coarse in-neighbor if possible.
      6) Propagates parent assignments through neighbors until all nodes are assigned.
      7) Builds an HR->LR index map (each HR node maps to a coarse-node index).

    Args:
        pos (torch.Tensor): Node positions, shape (num_nodes, dim).
        edge_index (torch.Tensor): Edge list in COO format, shape (2, num_edges),
            interpreted as row -> col (source -> target). The edges may be in any order;
            they are grouped by target internally.
        max_iter (int): Max number of propagation sweeps for unresolved parent assignments.
        alpha (float): Exponent that controls the influence of geometric density.

    Returns:
        coarse_mask (torch.BoolTensor): Shape (num_nodes,). True if node is kept as coarse.
        idxHR_to_idxLR (torch.LongTensor): Shape (num_nodes,). HR node -> LR coarse index.

    Raises:
        RuntimeError: If some node still has no parent after `max_iter` propagation sweeps.
    """
    num_nodes = pos.size(0)
    device = pos.device
    row, col = edge_index  # row: sources, col: targets

    # Group the edges by target so that the per-node splits below are valid for any input
    # ordering. The sort is stable, so already-grouped inputs are left untouched.
    col, perm = torch.sort(col, stable=True)
    row = row[perm]

    # ---------------------------
    # Step 1: Compute indegree.
    # ---------------------------
    # indegree[i] = number of incoming edges to node i.
    indegree = col.bincount(minlength=num_nodes)
    counts = indegree.tolist()

    # senders[i] contains all in-neighbors of i; sender_dists[i] the matching edge lengths.
    senders = row.split(counts)
    sender_dists = torch.norm(pos[col] - pos[row], dim=1).split(counts)

    # ---------------------------------------------------
    # Step 2: Estimate local geometric density per node.
    # ---------------------------------------------------
    # local_density[i] = mean Euclidean distance from i to its in-neighbors (+inf if it has none).
    local_density = torch.full((num_nodes,), float('inf'), device=device)
    for i, dists in enumerate(sender_dists):
        if dists.numel() > 0:
            local_density[i] = dists.mean()

    # ---------------------------------------------
    # Step 3: Topology–geometry joint ranking score
    # ---------------------------------------------
    # Lower score => more likely to become a coarse node.
    score = indegree.float() * (local_density ** alpha)
    # For nodes without in-neighbors the product is 0 * inf = NaN (or 0 when alpha <= 0).
    # Rank them last explicitly instead of relying on how NaN is sorted. Their rank does not
    # affect the selection: they have no in-neighbors to suppress.
    score = torch.where(indegree > 0, score, torch.full_like(score, float('inf')))
    sorted_nodes = torch.argsort(score, descending=False)

    # ---------------------------------------------
    # Step 4: Greedy coarse node selection (MIS-like)
    # ---------------------------------------------
    # Start with every node coarse. Visit nodes from best (lowest score) to worst; a node that
    # is still coarse when visited suppresses all of its in-neighbors.
    coarse_mask = torch.ones(num_nodes, dtype=torch.bool, device=device)
    for i in sorted_nodes.tolist():
        if coarse_mask[i]:
            coarse_mask[senders[i]] = False

    # --------------------------------------------------
    # Step 5: Assign each node a "parent" coarse node id.
    # --------------------------------------------------
    # parents[i] = the coarse representative of node i (node index in HR space).
    # Coarse nodes are their own parent; unresolved fine nodes are marked with -1.
    parents = torch.arange(num_nodes, device=device)
    for i in (~coarse_mask).nonzero().flatten().tolist():
        s = senders[i]
        is_coarse = coarse_mask[s]
        if is_coarse.any():
            # Nearest coarse in-neighbor
            parents[i] = s[is_coarse][sender_dists[i][is_coarse].argmin()]
        else:
            # No coarse in-neighbor (this also covers nodes without any in-neighbor)
            parents[i] = -1

    # ---------------------------------------------------------
    # Step 6: Propagate assignments for unresolved nodes (-1).
    # ---------------------------------------------------------
    # A node that cannot directly see a coarse in-neighbor inherits the parent of its nearest
    # in-neighbor that already has one. Updates are applied in place, so a node resolved early
    # in a sweep can be used by the nodes visited later in the same sweep.
    for _ in range(max_iter):
        unresolved = (parents == -1).nonzero().flatten().tolist()
        if not unresolved:
            break
        for i in unresolved:
            s = senders[i]
            resolved = parents[s] != -1
            if resolved.any():
                nearest = s[resolved][sender_dists[i][resolved].argmin()]
                parents[i] = parents[nearest]

    # Only fail if something is still unresolved: succeeding exactly on the last allowed sweep
    # is not an error.
    if (parents == -1).any():
        raise RuntimeError("Max iterations reached during parent assignment")

    # ---------------------------------------------------------
    # Step 7: Build HR -> LR mapping (coarse nodes get 0..C-1).
    # ---------------------------------------------------------
    # First assign consecutive LR indices to coarse nodes, then map every node via its parent.
    idxHR_to_idxLR = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    idxHR_to_idxLR[coarse_mask] = torch.arange(coarse_mask.sum(), device=device)

    # For each HR node i, its LR id is the LR id of its coarse parent.
    idxHR_to_idxLR = idxHR_to_idxLR[parents]

    return coarse_mask, idxHR_to_idxLR




def pool_edges(
    coarse_mask:       torch.BoolTensor,
    idxHR_to_idxLR:    torch.LongTensor,
    edge_index:        torch.LongTensor,
    max_indegree:      int = None,
    pos:               torch.Tensor = None,
) -> torch.LongTensor:
    """Pools the edges of a graph to create a lower-resolution graph. Every high-resolution edge is re-expressed
    between the low-resolution parents of its end points; edges that collapse onto a single low-resolution node
    are dropped and duplicated edges are merged, i.e., the spatial connectivity is preserved.

    Args:
        coarse_mask (torch.BoolTensor): Mask of the lower-resolution nodes. Shape (num_nodes,).
        idxHR_to_idxLR (torch.LongTensor): Mapping from high-resolution to low-resolution indices. Shape (num_nodes,).
        edge_index (torch.LongTensor): Edge index tensor of the high-resolution graph. Shape (2, num_edges).
        max_indegree (int, optional): Maximum in-degree allowed in the lower-resolution graph. Defaults to None.
        pos (torch.Tensor, optional): Position of the lower-resolution nodes (it is indexed with low-resolution
            indices during the pruning). Required when `max_indegree` is given. Defaults to None.

    Returns:
        torch.LongTensor: Edge index tensor of the lower-resolution graph. Shape (2, num_edges).
    """
    coarse_num_nodes = int(coarse_mask.sum()) # Number of lower resolution nodes
    # Express the edges in terms of the lower resolution indices
    coarse_edge_index = idxHR_to_idxLR[edge_index]
    # Remove the resulting self-loops
    coarse_edge_index = remove_self_loops(coarse_edge_index)[0]
    # Aggregate the resulting edges
    coarse_edge_index = coalesce(coarse_edge_index, num_nodes=coarse_num_nodes, sort_by_row=False)
    # Edge pruning
    if max_indegree is not None:
        assert pos is not None, 'Expected pos to be provided when max_indegree is not None'
        coarse_edge_index = edge_pruning(max_indegree, coarse_edge_index, pos)[0]
    # Checks. An empty edge set is valid (min/max of an empty tensor would raise), and the
    # highest low-resolution index does not have to appear in an edge (that node may have no
    # neighbours left after pooling), so only the index range is verified.
    if coarse_edge_index.numel() > 0:
        assert coarse_edge_index.min() >= 0, 'Found a negative index in the pooled edge index'
        assert coarse_edge_index.max() < coarse_num_nodes, 'Found an out-of-range index in the pooled edge index'
    return coarse_edge_index



class MeshCoarsening_TGC:
    """Transform that performs multiple mesh coarsenings on a graph. Each coarsening keeps a subset of the nodes (the coarse nodes),
    assigns every node to one of the kept nodes, and merges the edges to preserve the spatial connectivity.

    The results of the coarsening from level `i` to level `i+1` are stored on the graph as `coarse_mask_{i+1}`, `pos_{i+1}`,
    `idx{i}_to_idx{i+1}`, `edge_index_{i+1}`, `edge_attr_{i+1}`, `e_{i}{i+1}` and `batch_{i+1}`. Level 1 is the input graph
    (`pos`, `edge_index`, `batch`).

    Args:
        num_scales (int): Number of scales. Must be greater than 1.
        max_indegree (int, optional): Maximum in-degree allowed. Defaults to None.
        rel_pos_scaling (list[float, None], optional): Scaling factors for the relative positions, one entry per scale
            (`num_scales` entries; `None` entries disable the scaling). Entry `i` divides the relative positions along the
            edges of level `i+1`, and entry `i-1` divides the relative positions between level `i` and level `i+1`.
            Defaults to None (no scaling).
        scalar_rel_pos (bool, optional): Whether to use the scalar relative position (distance) or the vector relative position. Defaults to False (vector relative position).
    """

    def __init__(
        self,
        num_scales:      int,
        max_indegree:    int | None                = None,
        rel_pos_scaling: list[float | None] | None = None,
        scalar_rel_pos:  bool                      = False,
    ) -> None:
        assert num_scales > 1, f'Expected num_scales to be greater than 1, got {num_scales}'
        assert max_indegree is None or max_indegree > 0, f'Expected max_indegree to be greater than 0, got {max_indegree}'
        if rel_pos_scaling is None:
            # __call__ reads the entries 0 .. num_scales-1, so the default needs num_scales entries
            rel_pos_scaling = [None] * num_scales
        assert len(rel_pos_scaling) == num_scales, f'Expected rel_pos_scaling to have length {num_scales}, got {len(rel_pos_scaling)}'
        self.num_scales      = num_scales
        self.max_indegree    = max_indegree
        self.rel_pos_scaling = rel_pos_scaling
        self.scalar_rel_pos  = scalar_rel_pos

    @staticmethod
    def _coarsen_level(
        pos_1:                 torch.Tensor,
        edge_index_1:          torch.LongTensor,
        batch_1:               torch.LongTensor | None,
        max_indegree:          int | None,
        rel_pos_scaling_lr:    float | None,
        rel_pos_scaling_hr_lr: float | None,
        scalar_rel_pos:        bool,
    ) -> tuple:
        """Coarsens one level: selects the coarse nodes, pools the edges and builds the relative-position features.

        Returns:
            tuple: (coarse_mask, pos_2, idx1_to_idx2, edge_index_2, edge_attr_2, e_12, batch_2).
        """
        coarse_mask, idx1_to_idx2 = topology_geometry_graph_coarsening(pos_1, edge_index_1)
        # Low-resolution indices follow the order of the True entries of the mask
        pos_2 = pos_1[coarse_mask]
        batch_2 = batch_1[coarse_mask] if batch_1 is not None else None
        edge_index_2 = pool_edges(coarse_mask, idx1_to_idx2, edge_index_1, max_indegree, pos_2)
        # NOTE(review): the relative positions are taken as receiver minus sender
        # (LR edges: pos[target] - pos[source]; HR -> LR: parent - child). The convention
        # is an assumption -- confirm against the models that consume these features.
        edge_attr_2 = pos_2[edge_index_2[1]] - pos_2[edge_index_2[0]]
        e_12 = pos_2[idx1_to_idx2] - pos_1
        if rel_pos_scaling_lr is not None:
            edge_attr_2 = edge_attr_2 / rel_pos_scaling_lr
        if rel_pos_scaling_hr_lr is not None:
            e_12 = e_12 / rel_pos_scaling_hr_lr
        if scalar_rel_pos:
            edge_attr_2 = edge_attr_2.norm(dim=1, keepdim=True)
            e_12 = e_12.norm(dim=1, keepdim=True)
        return coarse_mask, pos_2, idx1_to_idx2, edge_index_2, edge_attr_2, e_12, batch_2

    def __call__(
        self,
        graph: Graph
    ) -> Graph:
        if graph.batch is None:
            graph.batch = torch.zeros(graph.pos.size(0), dtype=torch.long, device=graph.pos.device)
        for i in range(1, self.num_scales):
            # Level 1 lives in the un-suffixed attributes (pos, edge_index, batch)
            suffix = '' if i == 1 else f'_{i}'
            coarse_mask, pos, idx_to_parent, edge_index, edge_attr, e, batch = self._coarsen_level(
                pos_1                 = getattr(graph, f'pos{suffix}'),
                edge_index_1          = getattr(graph, f'edge_index{suffix}'),
                batch_1               = getattr(graph, f'batch{suffix}'),
                max_indegree          = self.max_indegree,
                rel_pos_scaling_lr    = self.rel_pos_scaling[i],
                rel_pos_scaling_hr_lr = self.rel_pos_scaling[i-1],
                scalar_rel_pos        = self.scalar_rel_pos,
            )
            setattr(graph, f'coarse_mask_{i+1}', coarse_mask)
            setattr(graph, f'pos_{i+1}', pos)
            setattr(graph, f'idx{i}_to_idx{i+1}', idx_to_parent)
            setattr(graph, f'edge_index_{i+1}', edge_index)
            setattr(graph, f'edge_attr_{i+1}', edge_attr)
            setattr(graph, f'e_{i}{i+1}', e)
            if batch is not None: setattr(graph, f'batch_{i+1}', batch)
        return graph
    

def compute_cell_properties(
    graph: Graph,
) -> Graph:
    """Computes the properties of the cells in a graph: the centroid, area, normal, and the normal at each node.

    The following attributes are written to the graph: `num_nodes_per_cell`, `cell_centroid`, `cell_area`,
    `cell_normal` and `normal`. Negative entries of a cell are treated as padding and ignored.

    Args:
        graph (Graph): Input graph. It must have a `cell_list` attribute (an iterable with one 1-D index tensor
            per cell, each with at least 3 valid nodes) and 3-D positions in `pos`.

    Returns:
        Graph: The input graph, modified in place.
    """
    assert hasattr(graph, 'cell_list'), 'Expected graph to have a cell_list attribute'
    device = graph.pos.device
    # Drop the padding (negative) indices so that they never index `pos`
    cells = [torch.as_tensor(cell) for cell in graph.cell_list]
    cells = [cell[cell >= 0] for cell in cells]
    num_cells = len(cells)
    graph.num_nodes_per_cell = torch.tensor([cell.numel() for cell in cells], device=device)
    # Get the centroid of each cell
    graph.cell_centroid = torch.stack([graph.pos[cell].mean(dim=0) for cell in cells], dim=0)
    # Get the area of each cell as the sum of the triangles (centroid, node, next node)
    graph.cell_area = torch.zeros(num_cells, device=device)
    for idx, cell in enumerate(cells):
        spokes = graph.pos[cell] - graph.cell_centroid[idx]  # Shape (num_nodes_in_cell, 3)
        next_spokes = spokes.roll(-1, dims=0)                # Spoke of the next node (wraps around)
        graph.cell_area[idx] = 0.5 * torch.cross(spokes, next_spokes, dim=1).norm(dim=1).sum()
    # Find the normal vector for each cell from its first three nodes
    p0 = torch.stack([graph.pos[cell[0]] for cell in cells], dim=0)
    p1 = torch.stack([graph.pos[cell[1]] for cell in cells], dim=0)
    p2 = torch.stack([graph.pos[cell[2]] for cell in cells], dim=0)
    # `dim` must be explicit: without it torch.cross uses the first dimension of size 3,
    # which is the cell dimension whenever there are exactly 3 cells.
    graph.cell_normal = torch.cross(p1 - p0, p2 - p0, dim=1)
    # NOTE: a degenerate cell (collinear first three nodes) has a zero normal and yields NaN here
    graph.cell_normal = graph.cell_normal / torch.norm(graph.cell_normal, dim=1, keepdim=True)
    # Direct it outwards:
    # The mean over all nodes of (centroid - node) equals centroid - mean(node); computing it
    # this way avoids materialising a (num_nodes, num_cells, 3) tensor.
    v = graph.cell_centroid - graph.pos.mean(dim=0)
    # Find the dot product between the normal and the vector
    dot = torch.sum(v * graph.cell_normal, dim=1)
    # If the dot product is negative, invert the normal
    graph.cell_normal = torch.where((dot < 0).unsqueeze(1), -graph.cell_normal, graph.cell_normal)
    # Find the normal at each point by averaging the normals of the cells that share that point
    node_index = torch.cat(cells).to(device)
    normal_per_entry = graph.cell_normal.repeat_interleave(graph.num_nodes_per_cell, dim=0)
    graph.normal = torch.zeros_like(graph.pos).index_add_(0, node_index, normal_per_entry)
    # NOTE: nodes that belong to no cell keep a zero vector and become NaN after the normalisation
    graph.normal = graph.normal / torch.norm(graph.normal, dim=1, keepdim=True)
    return graph


class ComputeNormals:
    """Transform that computes the normal vector at each node in a graph.

    The transform writes `cell_centroid`, `cell_normal_12`, `cell_normal_2n`, `cell_normal_1n`, `cell_normal` and
    `normal` to the graph. Negative entries of a cell are treated as padding and ignored.

    Args:
        del_cell_list (bool, optional): Whether to delete the cell_list attribute after computing the normals. Defaults to True.
    """

    def __init__(
        self,
        del_cell_list: bool = True
    ) -> None:
        self.del_cell_list = del_cell_list

    def __call__(
        self,
        graph: Graph,
    ) -> Graph:
        assert hasattr(graph, 'cell_list'), 'Expected graph to have a cell_list attribute'
        # Drop the padding (negative) indices; otherwise `cell[-1]` below would pick the
        # padding value and `pos[-1]` would be used as the last vertex of the cell.
        cells = [torch.as_tensor(cell) for cell in graph.cell_list]
        cells = [cell[cell >= 0] for cell in cells]
        # Get the centroid of each cell
        graph.cell_centroid = torch.stack([graph.pos[cell].mean(dim=0) for cell in cells], dim=0)
        # Find the normal vector for each cell
        p0 = torch.stack([graph.pos[cell[0]] for cell in cells], dim=0)
        p1 = torch.stack([graph.pos[cell[1]] for cell in cells], dim=0)
        p2 = torch.stack([graph.pos[cell[2]] for cell in cells], dim=0)
        pn = torch.stack([graph.pos[cell[-1]] for cell in cells], dim=0)
        v1 = p1 - p0
        v2 = p2 - p0
        vn = pn - p0
        # `dim` must be explicit: without it torch.cross uses the first dimension of size 3,
        # which is the cell dimension whenever there are exactly 3 cells.
        graph.cell_normal_12 = torch.cross(v1, v2, dim=1)
        graph.cell_normal_2n = torch.cross(v2, vn, dim=1)
        graph.cell_normal_1n = torch.cross(v1, vn, dim=1)
        # Pick the candidate with the largest norm
        cell_normal_norm_12 = torch.norm(graph.cell_normal_12, dim=1, keepdim=True)
        cell_normal_norm_2n = torch.norm(graph.cell_normal_2n, dim=1, keepdim=True)
        cell_normal_norm_1n = torch.norm(graph.cell_normal_1n, dim=1, keepdim=True)
        graph.cell_normal = torch.where(
            cell_normal_norm_12 > cell_normal_norm_2n,
            graph.cell_normal_12,
            graph.cell_normal_2n,
        )
        graph.cell_normal = torch.where(
            cell_normal_norm_1n > torch.norm(graph.cell_normal, dim=1, keepdim=True),
            graph.cell_normal_1n,
            graph.cell_normal,
        )
        graph.cell_normal = graph.cell_normal / torch.norm(graph.cell_normal, dim=1, keepdim=True)
        # Direct it outwards:
        # The mean over all nodes of (centroid - node) equals centroid - mean(node); computing it
        # this way avoids materialising a (num_nodes, num_cells, dim) tensor.
        v = graph.cell_centroid - graph.pos.mean(dim=0)
        # Find the dot product between the normal and the vector
        dot = torch.sum(v * graph.cell_normal, dim=1)
        # If the dot product is negative, invert the normal
        graph.cell_normal = torch.where((dot < 0).unsqueeze(1), -graph.cell_normal, graph.cell_normal)
        # Find the normal at each point by averaging the normals of the cells that share that point
        graph.normal = torch.zeros_like(graph.pos)
        for cell, cell_normal in zip(cells, graph.cell_normal):
            graph.normal[cell] += cell_normal
        # NOTE: nodes that belong to no cell keep a zero vector and become NaN after the normalisation
        graph.normal = graph.normal / torch.norm(graph.normal, dim=1, keepdim=True)
        if self.del_cell_list:
            delattr(graph, 'cell_list')
        return graph
