import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import to_dense_batch

from QWGT import TransformerEncoderLayerWithQuantumBias
from QWGR import GraphCTQWRecurrentModule


class CTQWformerLayer(nn.Module):
    """
    A single layer of the CTQWformer model.

    Combines a Transformer branch with quantum structural bias and a recurrent
    temporal encoder branch driven by CTQW evolution data. Each enabled branch
    yields one graph-level vector. The vectors are fused ('cat' or 'add') and
    passed through a feed-forward block.

    Args:
        hidden_dim (int): Hidden feature dimension.
        heads (int): Number of attention heads.
        time_steps (List[float]): Time steps used in CTQW simulation.
        dropout (float): Dropout rate.
        fusion (str): 'cat' (concatenate branch outputs) or 'add' (sum them).
        use_attention_bias (bool): Whether to use the quantum-biased Transformer branch.
        use_sequence_model (bool): Whether to use the CTQW recurrent branch.

    Raises:
        ValueError: If ``fusion`` is neither 'cat' nor 'add'.
    """

    _FUSION_MODES = ('cat', 'add')

    def __init__(self, hidden_dim, heads, time_steps, dropout=0.3, fusion='cat',
                 use_attention_bias=True, use_sequence_model=True):
        super().__init__()
        # Previously any unknown string silently behaved like 'add'.
        if fusion not in self._FUSION_MODES:
            raise ValueError(
                f"fusion must be one of {self._FUSION_MODES}, got {fusion!r}"
            )
        self.fusion = fusion
        self.use_attention_bias = use_attention_bias
        self.use_sequence_model = use_sequence_model

        # Submodules are created in the same order as before so parameter
        # initialisation (RNG consumption) and state_dict keys are unchanged.
        self.fnn_before = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        if use_attention_bias:
            self.transformer = TransformerEncoderLayerWithQuantumBias(hidden_dim, heads)

        if use_sequence_model:
            self.recurrent = GraphCTQWRecurrentModule(time_steps, hidden_dim)

        if fusion == 'cat':
            num_branches = sum(map(bool, (use_attention_bias, use_sequence_model)))
            # With both branches disabled, forward() still feeds one zero vector
            # of width hidden_dim, so reserve at least one slot.
            fusion_dim = hidden_dim * max(num_branches, 1)
        else:
            fusion_dim = hidden_dim

        self.fnn_after = nn.Sequential(
            nn.Linear(fusion_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

    def forward(self, x, qw_probs, batch):
        """
        Args:
            x (Tensor): Node features [N_total, H].
            qw_probs (Tensor): Quantum evolution tensor [B, T, N_max, N_max].
            batch (Tensor): Graph index of each node [N_total].

        Returns:
            Tensor: Updated graph embeddings [B, H].
        """
        B, _, N_max, _ = qw_probs.size()
        H = x.size(-1)

        # Pad the nodes onto the same (B, N_max) grid that qw_probs uses. If
        # these sizes are omitted, to_dense_batch infers them from `batch`.
        # The inferred sizes disagree with qw_probs whenever it is padded past
        # the largest graph in the batch.
        x_dense, mask = to_dense_batch(x, batch, max_num_nodes=N_max, batch_size=B)
        x_reshaped = self.fnn_before(x_dense)  # [B, N_max, H]

        if self.use_attention_bias:
            # NOTE(review): fnn_before's bias makes padded rows non-zero and no
            # padding mask is passed here. Confirm whether
            # TransformerEncoderLayerWithQuantumBias excludes padded nodes.
            x_trans = self.transformer(x_reshaped, qw_probs=qw_probs)  # [B, N_max, H]
            mask_f = mask.unsqueeze(-1).to(x_trans.dtype)
            # Mean over real nodes only; the clamp keeps empty graphs at 0 instead of NaN.
            x_trans_pool = (x_trans * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1e-9)
        else:
            x_trans_pool = torch.zeros((B, H), device=x.device, dtype=x.dtype)

        if self.use_sequence_model:
            x_recur = self.recurrent(qw_probs)  # assumed [B, H] — TODO confirm
        else:
            x_recur = torch.zeros((B, H), device=x.device, dtype=x.dtype)

        if self.fusion == 'cat':
            components = []
            if self.use_attention_bias:
                components.append(x_trans_pool)
            if self.use_sequence_model:
                components.append(x_recur)
            # Both branches disabled: fall back to the (zero) transformer slot
            # so the width matches fusion_dim.
            x_combined = torch.cat(components or [x_trans_pool], dim=-1)
        else:
            x_combined = x_trans_pool + x_recur

        return self.fnn_after(x_combined)  # [B, H]


class CTQWformer(nn.Module):
    """
    Full CTQWformer model for graph classification.

    Node features are embedded, passed through a stack of CTQWformerLayer
    modules (each producing one vector per graph), mean-pooled per graph and
    classified by an MLP head.

    Args:
        input_dim (int): Input node feature dimension.
        hidden_dim (int): Hidden layer dimension.
        num_classes (int): Number of output classes.
        time_steps (List[float]): Time steps for CTQW.
        num_layers (int): Number of stacked CTQWformer layers.
        heads (int): Attention heads.
        dropout (float): Dropout rate.
        fusion (str): 'cat' or 'add' fusion.
        use_attention_bias (bool): Use structural quantum bias or not.
        use_sequence_model (bool): Use the CTQW recurrent (sequence) branch or not.
    """
    def __init__(self, input_dim, hidden_dim, num_classes, time_steps,
                 num_layers=4, heads=4, dropout=0.2, fusion='cat',
                 use_attention_bias=True, use_sequence_model=True):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)

        self.layers = nn.ModuleList([
            CTQWformerLayer(hidden_dim, heads, time_steps, dropout, fusion,
                            use_attention_bias=use_attention_bias,
                            use_sequence_model=use_sequence_model)
            for _ in range(num_layers)
        ])

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, data, qw_probs):
        """
        Args:
            data (Data): Batched graph data (PyG Data object) with ``x`` and,
                optionally, ``batch``.
            qw_probs (Tensor): Quantum evolution tensor [B, T, N_max, N_max].

        Returns:
            Tensor: Prediction logits [B, num_classes].
        """
        x = self.embedding(data.x)  # [N_total, H]
        batch = getattr(data, 'batch', None)
        if batch is None:
            # Un-batched input: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        num_graphs = qw_probs.size(0)

        for layer in self.layers:
            graph_emb = layer(x, qw_probs, batch)  # [B, H]
            # Give every node its own graph's embedding so the next layer again
            # receives N_total rows aligned with `batch`. The previous
            # expand-to-B*N_max layout broke whenever graphs differed in size.
            x = graph_emb[batch]  # [N_total, H]

        # Per-graph mean over real nodes. With num_layers > 0 all nodes of a
        # graph share one vector, so this returns the last layer's embedding.
        graph_emb = global_mean_pool(x, batch, size=num_graphs)  # [B, H]

        return self.classifier(graph_emb)  # [B, num_classes]
