import torch
import torch.nn as nn


class GraphCTQWRecurrentModule(nn.Module):
    """
    Encode a CTQW evolution tensor into one embedding vector per graph.

    Pipeline (see ``forward``):
      1. Take the diagonal of every ``N x N`` matrix, giving one length-``T``
         series per node.
      2. Project each node's length-``T`` series to ``hidden_dim`` with a
         linear layer. The time axis is consumed here.
      3. Run a bidirectional GRU over the resulting ``[B, N, H]`` tensor.
         Because ``batch_first=True``, the GRU's sequence axis is the *node*
         axis, so the result depends on the order in which nodes appear.
      4. Mean-pool the GRU outputs over nodes and apply a two-layer MLP.

    Args:
        time_steps (List[float]): Time steps used in CTQW simulation. Only
            the number of entries is used by this module.
        hidden_dim (int): Width of the linear embedding, of each GRU
            direction, and of the final embedding.

    Attributes:
        time_steps (int): Number of time steps ``T`` expected by ``forward``.
    """
    def __init__(self, time_steps, hidden_dim: int) -> None:
        super().__init__()
        # Only the count is stored; it fixes the input width of `embedding`.
        self.time_steps = len(time_steps)

        self.embedding = nn.Linear(self.time_steps, hidden_dim)
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        # The bidirectional GRU emits 2 * hidden_dim features per position.
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, Q: torch.Tensor) -> torch.Tensor:
        """
        Compute the graph-level embedding.

        Args:
            Q (Tensor): CTQW probability evolution tensor of shape [B, T, N, N],
                        where B is batch size, T is number of time steps
                        (must equal ``self.time_steps``), and N is number of
                        nodes.

        Returns:
            Tensor: Graph-level embedding of shape [B, H].

        Raises:
            ValueError: If ``Q`` is not 4-dimensional, if its last two
                dimensions differ, or if its time dimension does not match
                the number of time steps given at construction.
        """
        # Validate up front: torch.diagonal would silently accept a
        # non-square trailing pair and return a truncated diagonal.
        if Q.dim() != 4:
            raise ValueError(
                f"Q must have shape [B, T, N, N]; got {Q.dim()} dimension(s) "
                f"with shape {tuple(Q.shape)}"
            )
        _, num_steps, num_rows, num_cols = Q.shape
        if num_rows != num_cols:
            raise ValueError(
                "Q must be square in its last two dimensions; "
                f"got shape {tuple(Q.shape)}"
            )
        if num_steps != self.time_steps:
            raise ValueError(
                f"Q has {num_steps} time step(s) but the module was built "
                f"for {self.time_steps}"
            )

        # Step 1: Extract the diagonal entries of each N x N matrix over time
        node_time_series = torch.diagonal(Q, dim1=-2, dim2=-1)  # [B, T, N]
        node_time_series = node_time_series.permute(0, 2, 1)    # [B, N, T]

        # Step 2: Project each node's time series into hidden space
        x = self.embedding(node_time_series)  # [B, N, H]

        # Step 3: Bi-GRU over the node axis (dim 1 is the sequence dimension)
        output, _ = self.gru(x)               # [B, N, 2H]

        # Step 4: Aggregate over nodes via mean pooling
        graph_emb = output.mean(dim=1)        # [B, 2H]

        # Step 5: Final projection
        return self.readout(graph_emb)        # [B, H]
