import torch
import torch.nn as nn
import numpy as np
from torch_geometric.utils import to_dense_adj


class T_Ham(nn.Module):
    """
    Learnable Hamiltonian module using node features.

    Each edge (u, v) is scored by a two-layer MLP applied to the concatenated
    endpoint features [x_u, x_v]; a sigmoid maps the score into (0, 1). The
    scores form a weighted adjacency matrix A, which is symmetrised as
    (A + A^T) / 2, and the module returns the graph Laplacian H = D - A.

    Args:
        in_dim (int): Dimension of node features.
        hidden_dim (int): Hidden dimension for edge MLP.
        activation (nn.Module or None): Activation applied after the first
            linear layer. ``None`` (the default) creates a fresh
            ``nn.ReLU()`` for this instance.
    """

    def __init__(self, in_dim=1, hidden_dim=16, activation=None):
        super().__init__()
        self.linear1 = nn.Linear(in_dim * 2, hidden_dim)  # Input: concatenation of two node features
        self.linear2 = nn.Linear(hidden_dim, 1)
        # A module instance used as a default argument would be shared by every
        # T_Ham ever constructed; build a private one instead.
        self.activation = activation if activation is not None else nn.ReLU()

    def forward(self, edge_index, num_nodes, x=None):
        """
        Compute the Laplacian Hamiltonian of the graph.

        Args:
            edge_index (LongTensor): Tensor of shape [2, E] with edge indices;
                every index must be < num_nodes.
            num_nodes (int): Number of nodes N in the graph.
            x (Tensor or None): Node features of shape [N, in_dim]. If None,
                all-ones features of width in_dim are used, created with the
                same dtype as the module's parameters.

        Returns:
            Tensor: Hamiltonian H of shape [N, N], H = D - A with A the
            symmetrised learned adjacency and D its diagonal degree matrix.
            Each row of H sums to zero.
        """
        device = edge_index.device

        if x is None:
            # Width and dtype must match what linear1 expects.
            in_dim = self.linear1.in_features // 2
            x = torch.ones((num_nodes, in_dim), device=device, dtype=self.linear1.weight.dtype)

        row, col = edge_index  # Edge sources and targets
        edge_feat = torch.cat([x[row], x[col]], dim=1)  # Shape: [E, 2 * in_dim]

        h = self.activation(self.linear1(edge_feat))  # Shape: [E, hidden_dim]
        # Squeeze only the trailing feature axis so a single-edge graph keeps
        # shape [1] instead of collapsing to a 0-d tensor.
        edge_weight = torch.sigmoid(self.linear2(h)).squeeze(-1)  # Shape: [E], values in (0, 1)

        # Dense weighted adjacency; duplicate (row, col) pairs are summed.
        adj = edge_weight.new_zeros((num_nodes, num_nodes)).index_put(
            (row, col), edge_weight, accumulate=True
        )

        # Ensure symmetry
        adj = (adj + adj.T) / 2

        # Laplacian: H = D - A
        deg = torch.diag(adj.sum(dim=1))
        return deg - adj


class CTQWEncoder(nn.Module):
    """
    Continuous-Time Quantum Walk (CTQW) encoder module.

    For each requested time t, builds the evolution operator
    U_t = exp(-i H t) and records the transition probabilities |U_t|^2
    for every basis state.

    Args:
        time_steps (Iterable[float]): Time points at which to evaluate the walk.
    """

    def __init__(self, time_steps):
        super().__init__()
        self.time_steps = time_steps

    def forward(self, hamiltonian):
        """
        Simulate quantum evolution for all basis states.

        Args:
            hamiltonian (Tensor): Real symmetric Hamiltonian H of shape [N, N].
                float32 and float64 are both accepted; the output uses the
                matching real precision.

        Returns:
            Tensor: Probability tensor Q of shape [T, N, N] with
            Q[t, i, j] = |U_t[i, j]|^2, where U_t = exp(-i H t).

            Column j of U_t is the state evolved from |j⟩, so Q[t, :, j] is the
            distribution reached from |j⟩. For a real symmetric H, U_t is
            itself symmetric. Q[t, i, j] is therefore also the probability of
            reaching |j⟩ from |i⟩.

            If time_steps is empty, an empty [0, N, N] tensor is returned.
        """
        num_nodes = hamiltonian.size(0)

        probs = []
        for t in self.time_steps:
            U_t = torch.matrix_exp(-1j * hamiltonian * t)  # e^{-iHt}, Shape: [N, N]
            # Evolving the identity (all basis states) gives U_t itself, so no
            # extra matmul is needed. Avoiding it also sidesteps a
            # complex64/complex128 dtype mismatch for float64 inputs.
            probs.append(U_t.real.square() + U_t.imag.square())  # |U_t|^2

        if not probs:
            # Match the real dtype the loop would have produced.
            real_dtype = (-1j * hamiltonian[:0, :0]).real.dtype
            return torch.zeros(
                (0, num_nodes, num_nodes), dtype=real_dtype, device=hamiltonian.device
            )

        return torch.stack(probs, dim=0)  # Shape: [T, N, N]
