# gnn-meta-graph/src/meta_graph.py

import torch
import torch.nn.functional as F
from torch_geometric.data import Data

"""
meta_graph.py

Utilities for attention-based subgraph extraction and construction of meta-graphs.
Includes:
- Node importance scoring from attention weights
- Subgraph construction via thresholded importance
- Meta-graph creation using intra- and inter-modality similarity
"""


def compute_node_attention_scores(attention_scores, edge_index, num_nodes):
    """
    Compute node-level weights using extracted multi-head attention scores.

    Each edge's attention is averaged over heads and then added to BOTH of
    its endpoints, so a self-loop contributes twice to its node.

    Args:
        attention_scores (torch.Tensor): Multi-head edge attention scores of shape [num_edges, num_heads].
        edge_index (torch.Tensor): Edge indices of shape [2, num_edges].
        num_nodes (int): Number of nodes in the graph.

    Returns:
        torch.Tensor: Node attention weights of shape [num_nodes], allocated
        on the device of ``attention_scores`` with the default floating
        dtype. The result is detached from the autograd graph.
    """
    # The two inputs may disagree on the number of edges; keep the common prefix.
    num_edges = min(edge_index.shape[1], attention_scores.shape[0])
    edge_index = edge_index[:, :num_edges].to(
        device=attention_scores.device, dtype=torch.long
    )
    attention_scores = attention_scores[:num_edges]

    node_scores = torch.zeros(num_nodes, device=attention_scores.device)

    # Aggregate attention scores across heads -> shape [num_edges].
    # detach(): the result is a plain score vector, not part of the graph.
    edge_scores = attention_scores.detach().mean(dim=1).to(node_scores.dtype)

    # Drop edges that reference nodes outside [0, num_nodes). The lower bound
    # matters: a negative index would otherwise wrap around to a real node.
    src, dst = edge_index[0], edge_index[1]
    valid_mask = (src >= 0) & (src < num_nodes) & (dst >= 0) & (dst < num_nodes)
    src = src[valid_mask]
    dst = dst[valid_mask]
    edge_scores = edge_scores[valid_mask]

    # Scatter-add each edge score onto both endpoints in one pass instead of
    # looping over edges in Python.
    node_scores.index_add_(0, src, edge_scores)
    node_scores.index_add_(0, dst, edge_scores)

    return node_scores


def normalize_attention_scores(node_scores):
    """
    Min-max scale node scores so the smallest maps to 0 and the largest to just under 1.

    A small epsilon (1e-5) is added to the denominator so that a constant
    input yields all zeros instead of a division by zero.

    Args:
        node_scores (torch.Tensor): Raw node scores.

    Returns:
        torch.Tensor: Scaled scores with the same shape as ``node_scores``.
        An empty input is returned as an empty copy.
    """
    # min()/max() raise on empty tensors; there is nothing to scale anyway.
    if node_scores.numel() == 0:
        return node_scores.clone()

    min_score = node_scores.min()
    score_range = node_scores.max() - min_score
    return (node_scores - min_score) / (score_range + 1e-5)


def adjust_edge_index(edge_index, node_subset):
    """
    Remap edge_index so it only contains nodes from node_subset.

    Edges are kept only when BOTH endpoints are in ``node_subset``. Kept
    endpoints are renumbered to their position within ``node_subset``.

    Args:
        edge_index (torch.Tensor): Edge indices of shape [2, num_edges].
        node_subset (torch.Tensor): 1-D tensor of original node indices to keep.

    Returns:
        torch.Tensor: Long tensor of shape [2, num_kept_edges] on the same
        device as ``edge_index`` (shape [2, 0] when no edge survives).
    """
    device = edge_index.device

    # One bulk tolist() instead of an .item() call per node.
    node_map = {old_idx: new_idx for new_idx, old_idx in enumerate(node_subset.tolist())}

    remapped = [
        (node_map[src], node_map[dst])
        for src, dst in edge_index.t().tolist()
        if src in node_map and dst in node_map
    ]

    if not remapped:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    # Build the result on the input's device so it can be combined with the
    # graph's other tensors without a device mismatch.
    return torch.tensor(remapped, dtype=torch.long, device=device).t().contiguous()



def construct_subgraph(graph, significant_nodes, edge_weight_threshold=0.5):
    """
    Build a subgraph containing only the nodes flagged in ``significant_nodes``.

    ``graph.x`` is transposed before rows are selected, so it is presumably
    stored as [num_features, num_nodes] -- TODO confirm against the encoder
    that produces these graphs. The transpose is done on a local view; the
    input ``graph`` is not modified.

    Args:
        graph: Graph object exposing ``x``, ``edge_index`` and ``y``.
        significant_nodes (torch.Tensor): Mask whose non-zero positions mark
            the nodes to keep.
        edge_weight_threshold (float): Currently unused; kept so existing
            callers that pass it keep working.

    Returns:
        Data: Subgraph with the selected node rows, edges remapped to the new
        node numbering, and the original label ``graph.y``.

    Raises:
        ValueError: If a selected node index exceeds the number of rows in
            the transposed feature matrix.
    """
    node_indices = significant_nodes.nonzero(as_tuple=True)[0]  # Get selected node indices

    # Use a local transposed view instead of writing it back to graph.x;
    # otherwise every call would flip the caller's feature matrix.
    x = graph.x.T

    if node_indices.numel() == 0:
        print("⚠️ No significant nodes found. Selecting a fallback node.")
        # Keep the fallback index on the same device as the mask-derived indices.
        node_indices = torch.tensor([0], device=significant_nodes.device)

    max_index = x.size(0) - 1
    highest_selected = node_indices.max().item()
    if highest_selected > max_index:
        raise ValueError(
            f"🚨 Error: Selected node index {highest_selected} exceeds "
            f"maximum valid index {max_index}."
        )

    # Subset **nodes** (rows), keeping **all features** (columns)
    new_x = x[node_indices, :]
    # Adjust edge_index to reference the selected nodes
    new_edge_index = adjust_edge_index(graph.edge_index, node_indices)
    subgraph = Data(x=new_x, edge_index=new_edge_index)
    subgraph.y = graph.y  # Keep the original label

    return subgraph


def construct_subgraphs(node_attention_scores, raw_graphs, encoded_graphs, threshold):
    """
    Construct subgraphs based on node attention scores.

    The same ``node_attention_scores`` vector is normalized once and the
    resulting mask is applied to every encoded graph.

    Args:
        node_attention_scores (torch.Tensor): Node attention weights.
        raw_graphs (list): Original graph data. Only used to pair with
            ``encoded_graphs``; the shorter of the two lists bounds how many
            subgraphs are produced.
        encoded_graphs (list): Encoded graph data with transformed features.
        threshold (float): Threshold for selecting important nodes; nodes
            whose normalized score is >= threshold are kept.

    Returns:
        list: List of constructed subgraphs.
    """
    paired_graphs = list(zip(raw_graphs, encoded_graphs))
    if not paired_graphs:
        # Nothing to build; skip the normalization as well.
        return []

    # The mask does not depend on the graph, so compute it once up front
    # rather than once per loop iteration.
    normalized_scores = normalize_attention_scores(node_attention_scores)
    significant_nodes = normalized_scores >= threshold

    return [
        construct_subgraph(encoded_graph, significant_nodes)
        for _raw_graph, encoded_graph in paired_graphs
    ]


def compute_global_similarity_matrix(graphs):
    """
    Compute pairwise cosine similarity between all nodes of all given graphs.

    Node features of every graph are stacked row-wise (in the order the
    graphs are given) and compared against each other.

    Args:
        graphs: Iterable of graph objects whose ``x`` is a 2-D tensor of
            shape [num_nodes_i, num_features]; all graphs must share the
            same feature dimension.

    Returns:
        torch.Tensor: Matrix of shape [total_nodes, total_nodes] where entry
        (i, j) is the cosine similarity of stacked node i and stacked node j.
        All-zero feature rows have similarity 0 with every node.
    """
    all_node_embeddings = torch.cat([graph.x for graph in graphs], dim=0)

    # Normalize each row once and take a matrix product. This yields the same
    # cosine similarities as broadcasting two [N, 1, D] / [1, N, D] views, but
    # without materializing an [N, N, D] intermediate.
    unit_embeddings = F.normalize(all_node_embeddings, p=2, dim=1, eps=1e-8)
    return unit_embeddings @ unit_embeddings.t()


def threshold_similarity(similarity_matrix, threshold=0.5):
    """
    Turn a similarity matrix into an edge list of sufficiently similar pairs.

    Args:
        similarity_matrix (torch.Tensor): Pairwise similarity scores.
        threshold (float): Entries strictly greater than this value become edges.

    Returns:
        torch.Tensor: Contiguous index tensor with one column per selected
        entry; for a 2-D input the shape is [2, num_selected], with row 0
        holding row indices and row 1 holding column indices.
    """
    above_threshold = similarity_matrix > threshold
    selected_pairs = torch.nonzero(above_threshold)  # one row per selected entry
    return selected_pairs.transpose(0, 1).contiguous()

def construct_meta_graph(graphs, similarity_matrix, threshold=0.5):
    """
    Merge several graphs into one meta-graph.

    Node features are stacked in the order the graphs are given. The edge set
    is the union (by concatenation, duplicates are not removed) of:
      * each graph's own edges, shifted by the number of nodes that precede
        that graph in the stacked feature matrix, and
      * every pair whose entry in ``similarity_matrix`` exceeds ``threshold``.

    Args:
        graphs: Sequence of graph objects exposing ``x`` and ``edge_index``.
        similarity_matrix (torch.Tensor): Pairwise similarity over the stacked nodes.
        threshold (float): Similarity cut-off passed to ``threshold_similarity``.

    Returns:
        Data: Meta-graph with the stacked features and the combined edges.
    """
    node_features = torch.cat([g.x for g in graphs], dim=0)
    total_nodes = node_features.shape[0]

    # Shift each graph's edges into the stacked node numbering.
    shifted_edges = []
    node_offset = 0
    for g in graphs:
        shifted_edges.append(g.edge_index + node_offset)
        node_offset += g.x.shape[0]

    if shifted_edges:
        intra_graph_edges = torch.cat(shifted_edges, dim=1)
    else:
        intra_graph_edges = torch.empty((2, 0), dtype=torch.long)

    # Similarity-derived edges, restricted to indices that exist in the
    # stacked feature matrix.
    similarity_edges = threshold_similarity(similarity_matrix, threshold)
    sources, targets = similarity_edges[0], similarity_edges[1]
    in_range = (sources < total_nodes) & (targets < total_nodes)
    similarity_edges = similarity_edges[:, in_range]

    # Only concatenate when there are structural edges to prepend.
    if intra_graph_edges.numel() == 0:
        combined_edges = similarity_edges
    else:
        combined_edges = torch.cat([intra_graph_edges, similarity_edges], dim=1)

    return Data(x=node_features, edge_index=combined_edges)


def construct_meta_graphs(subgraphs_list, threshold=0.5):
    """
    Build one meta-graph per sample by merging that sample's subgraphs.

    ``subgraphs_list`` holds one list of subgraphs per source; position ``i``
    in each inner list is taken to belong to the same sample. The number of
    samples is taken from the first inner list.

    Args:
        subgraphs_list (list[list]): One list of subgraphs per source.
        threshold (float): Similarity cut-off used when adding similarity
            edges to each meta-graph. Defaults to 0.5.

    Returns:
        list: One meta-graph per sample; empty when ``subgraphs_list`` is
        empty. Each meta-graph's label ``y`` is copied from the first
        source's subgraph for that sample.
    """
    # No sources at all: nothing to merge (avoids indexing subgraphs_list[0]).
    if not subgraphs_list:
        return []

    meta_graphs = []
    for i, first_subgraph in enumerate(subgraphs_list[0]):
        # Collect the i-th subgraph of every source once and reuse it.
        sample_subgraphs = [subgraphs[i] for subgraphs in subgraphs_list]
        similarity_matrix = compute_global_similarity_matrix(sample_subgraphs)
        meta_graph = construct_meta_graph(
            sample_subgraphs, similarity_matrix, threshold=threshold
        )
        meta_graph.y = first_subgraph.y  # Assign label from one of the subgraphs
        meta_graphs.append(meta_graph)
    return meta_graphs

