from __future__ import annotations

import re
from typing import Generator, overload

import torch
import torch.nn.functional as F


def pad_adjacency_matrix(
    adjacency_matrix: list[torch.Tensor],
    /,
    pad_value: float = 0.0,
) -> torch.Tensor:
    """
    Pad a list of adjacency matrices to a common square size and stack them.

    Each matrix is padded only on the bottom and on the right, so every
    existing entry keeps its ``[i, j]`` position.

    Args:
        adjacency_matrix (list[torch.Tensor]): 2-D adjacency matrices, possibly
            of different sizes. Any iterable of 2-D tensors is accepted.
        pad_value (float): The value to use for padding. Default is 0.0.

    Returns:
        torch.Tensor: Tensor of shape ``(len(adjacency_matrix), max_nodes, max_nodes)``,
        where ``max_nodes`` is the largest dimension found among the inputs.

    Raises:
        ValueError: If ``adjacency_matrix`` contains no matrices.
    """
    # Materialize once so generators work (the sequence is traversed twice).
    matrices = list(adjacency_matrix)
    if not matrices:
        raise ValueError("adjacency_matrix must contain at least one matrix")

    # Look at both dimensions so non-square inputs also end up square.
    max_nodes = max(max(adj.shape[0], adj.shape[1]) for adj in matrices)

    # F.pad lists padding from the LAST dimension backwards:
    # (cols_left, cols_right, rows_top, rows_bottom).
    padded = [
        F.pad(
            adj,
            (0, max_nodes - adj.shape[1], 0, max_nodes - adj.shape[0]),
            value=pad_value,
        )
        for adj in matrices
    ]
    return torch.stack(padded, dim=0)


_Edge = tuple[str, str, str]  # (src, edge_type, dst)

# Matches one edge of the form "(src, edge_type, dst)". ``src`` and ``dst`` are
# a word optionally followed by whitespace and digits (e.g. "room 2");
# ``edge_type`` is a single word. At least one whitespace character is
# required after each comma, so "(a,b,c)" is NOT recognised. The only commas
# the pattern admits are the two separators, so every match has exactly three
# fields.
_EDGE_PATTERN = re.compile(
    r"\((?P<src>\w+\s*\d*),\s+(?P<edge_type>\w+),\s+(?P<dst>\w+\s*\d*)\)"
)


def get_all_edges(graph: str, /) -> Generator[_Edge, None, None]:
    """
    Yield every ``(src, edge_type, dst)`` triple found in ``graph``.

    Text that does not match the edge format is silently skipped.

    Args:
        graph (str): Text containing zero or more edges written as
            ``(src, edge_type, dst)``.

    Yields:
        tuple[str, str, str]: ``(src, edge_type, dst)`` with surrounding
        whitespace stripped from each field, in order of appearance.
    """
    for match in _EDGE_PATTERN.finditer(graph):
        yield (
            match["src"].strip(),
            match["edge_type"].strip(),
            match["dst"].strip(),
        )


def map_observation_to_graph(
    graph: str | list[str], /
) -> tuple[list[str], torch.Tensor, list[str]]:
    """
    Build a relation-labelled adjacency matrix from one or more graph strings.

    Every ``(src, edge_type, dst)`` triple produced by ``get_all_edges`` becomes
    a directed edge ``src -> dst`` labelled with ``edge_type``. Edges from all
    strings are merged into a single graph.

    Args:
        graph (str | list[str]): A graph string, or a list of graph strings.

    Returns:
        tuple[list[str], torch.Tensor, list[str]]:
            - nodes: node names in order of first appearance; ``nodes[i]``
              owns row/column ``i`` of the matrix.
            - adjacency_matrix: ``int64`` tensor of shape
              ``(len(nodes), len(nodes))``. Entry ``[i, j]`` is the index into
              ``rels`` of the relation from ``nodes[i]`` to ``nodes[j]``, or 0
              when there is no edge.
            - rels: relation names. Index 0 is the ``"none"`` placeholder; the
              remaining relations follow in order of first appearance.

    Note:
        If the same ``(src, dst)`` pair appears with several relations, the one
        seen last wins. An edge whose relation is literally ``"none"`` maps to
        index 0 and is therefore indistinguishable from "no edge".
    """
    if isinstance(graph, str):
        graph = [graph]

    # Insertion-ordered dicts instead of sets: iteration order of a set of str
    # depends on PYTHONHASHSEED, which made relation ids (and the winner among
    # parallel edges) differ from one process to the next.
    node_index: dict[str, int] = {}
    rel_index: dict[str, int] = {"none": 0}
    edges: dict[tuple[int, int], int] = {}
    for g in graph:
        for node_from, rel, node_to in get_all_edges(g):
            i = node_index.setdefault(node_from, len(node_index))
            j = node_index.setdefault(node_to, len(node_index))
            edges[i, j] = rel_index.setdefault(rel, len(rel_index))

    num_nodes = len(node_index)
    adjacency_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.int64)
    for (i, j), rel_id in edges.items():
        adjacency_matrix[i, j] = rel_id

    return list(node_index), adjacency_matrix, list(rel_index)


def node_pruning(
    hidden_states: torch.Tensor,  # (batch_size, num_nodes, embed_dim)
    adjacency_matrix: torch.Tensor,  # (batch_size, num_nodes, num_nodes)
    *,
    dropout: float = 0.2,
):
    """
    Randomly drop whole nodes from a batch of graphs.

    Each node is dropped independently with probability ``dropout``. A dropped
    node has its hidden-state vector zeroed, and both its row and its column
    of the adjacency matrix zeroed.

    Args:
        hidden_states (torch.Tensor): Node features; the last dimension is the
            feature dimension.
        adjacency_matrix (torch.Tensor): Adjacency matrices matching the node
            dimensions of ``hidden_states``.
        dropout (float): Probability of dropping each node. Default is 0.2.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: New ``(hidden_states, adjacency_matrix)``
        tensors; the inputs are not modified.
    """
    # One draw per node, i.e. over every dimension except the feature one.
    noise = torch.rand(*hidden_states.shape[:-1], device=hidden_states.device)
    dropped = noise < dropout

    # Rows (outgoing, shape (..., N, 1)) OR columns (incoming, shape (..., 1, N));
    # masked_fill broadcasts the mask over the target tensor.
    edge_mask = dropped.unsqueeze(-1) | dropped.unsqueeze(-2)

    pruned_states = hidden_states.masked_fill(dropped.unsqueeze(-1), 0.0)
    pruned_adjacency = adjacency_matrix.masked_fill(edge_mask, 0.0)
    return pruned_states, pruned_adjacency


def edge_pruning(
    adjacency_matrix: torch.Tensor,  # (batch_size, num_nodes, num_nodes)
    *,
    dropout: float = 0.2,
):
    """
    Randomly drop individual edges from adjacency matrices.

    Every entry is zeroed independently with probability ``dropout``; entries
    ``[i, j]`` and ``[j, i]`` are sampled separately, so symmetry is not kept.

    Args:
        adjacency_matrix (torch.Tensor): Adjacency matrices to prune.
        dropout (float): Probability of dropping each entry. Default is 0.2.

    Returns:
        torch.Tensor: A new tensor with the same shape and dtype as the input.
    """
    # One uniform draw per entry; torch.rand is used rather than rand_like
    # because the adjacency matrix may be an integer tensor.
    noise = torch.rand(*adjacency_matrix.shape, device=adjacency_matrix.device)
    return adjacency_matrix.masked_fill(noise < dropout, 0.0)
