import math

import torch
import torch.nn as nn
from sampling import ConditionalSampler, SamplingMethod, select_sampling_method


import time

d = False  # module-wide debug switch consulted by debug()


def debug(msg):
    """Print ``msg`` to stdout, but only while the module flag ``d`` is truthy."""
    if not d:
        return
    print(msg)

# Default compute device: the first CUDA GPU when one is available, else CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Module-level sampler, created at import time and shared by every
# RBMContrastiveDivergenceLayer (its methods call `sampler.sample(...)`).
# ConditionalSampler comes from the project `sampling` module; its internals
# are not visible here.
sampler = ConditionalSampler(device)

import torch

EPS = 1e-8  # Small numerical-stability constant. NOTE(review): not referenced in the visible code; confirm other modules use it.


class RBMContrastiveDivergenceLayer(nn.Module):
    """
    Improved RBM layer optimized for ogbn-arxiv dataset.

    Per call (see ``forward``): LayerNorm on the input -> (train-only) dropout
    -> GCN-style neighbour aggregation mixed with the input via ``residual``
    -> RBM visible->hidden pass (through the module-level ``sampler``)
    -> optional BatchNorm.

    Args:
        in_features: Visible-layer size (input feature dimension).
        out_features: Hidden-layer size (output feature dimension).
        k_steps: Number of Gibbs steps performed by ``contrastive_divergence``.
        dropout: Input dropout probability; no dropout module is created when
            ``dropout <= 0.001``.
        alpha: Weight of the self term added in ``aggregate_neighbors``.
        residual: Mixing factor between the node's own features and the
            aggregated neighbourhood in ``forward``.
        forward_sampling: Sampling-method name (resolved with
            ``select_sampling_method``) for the hidden conditional.
        backward_sampling: Sampling-method name for the visible conditional.
        use_batch_norm: Apply ``BatchNorm1d`` to the output.
        temperature: Temperature forwarded to ``sampler.sample``.
    """

    def __init__(self, in_features, out_features, k_steps=1, dropout=0.5, alpha=0.1, residual=0.5,
                 forward_sampling='gumbel_softmax', backward_sampling='sigmoid',
                 use_batch_norm=True, temperature=0.5):

        super(RBMContrastiveDivergenceLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.k_steps = k_steps  # Number of Gibbs steps (CD-k); defaults to CD-1
        self.dropout = dropout
        self.alpha = alpha
        self.residual = residual
        self.temperature = temperature

        # Resolve the sampling-method names once; used for the hidden (forward)
        # and visible (backward) conditionals respectively.
        self.forward_sampling = select_sampling_method(forward_sampling)
        self.backward_sampling = select_sampling_method(backward_sampling)

        # RBM parameters. NOTE: this small-Gaussian init differs from the Xavier
        # init that reset_parameters() applies.
        self.W = nn.Parameter(torch.randn(in_features, out_features) * 0.01)
        self.visible_bias = nn.Parameter(torch.zeros(in_features))
        self.hidden_bias = nn.Parameter(torch.zeros(out_features))

        # Learnable scalar scaling the aggregated neighbourhood term
        self.neighbor_weight = nn.Parameter(torch.tensor([1.0]))

        # Batch normalization of the output for stability
        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.bn = nn.BatchNorm1d(out_features)

        # The dropout module only exists for non-negligible rates; forward() checks for it
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(dropout)

        # Layer normalization on input (helps with deep networks)
        self.input_norm = nn.LayerNorm(in_features)

    def aggregate_neighbors(self, x, edge_index):
        """
        Symmetric-normalised neighbour aggregation (similar to GCN) plus a self term.

        Computes ``out[i] = sum over edges (i, j) of x[j] / sqrt(deg(i) * deg(j))
        + alpha * x[i]``, where ``deg`` counts occurrences of a node in
        ``edge_index[0]`` (clamped to at least 1).

        Args:
            x: Node features; row ``i`` belongs to node ``i``.
            edge_index: Pair ``(row, col)`` of index tensors; messages flow from
                ``col`` into ``row``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        row, col = edge_index
        num_nodes = x.size(0)

        # Degree is accumulated in float32 so integer counts stay exact even
        # when x is half precision.
        degree = torch.zeros(num_nodes, device=x.device, dtype=torch.float)
        degree.scatter_add_(0, row, torch.ones(row.size(0), device=x.device))
        degree = degree.clamp(min=1)

        # Symmetric normalization: D^(-0.5) A D^(-0.5)
        deg_inv_sqrt = degree.pow(-0.5)
        # Cast to x's dtype: index_add_ requires matching dtypes, and a float32
        # norm would otherwise break float16/bfloat16 inputs.
        norm = (deg_inv_sqrt[row] * deg_inv_sqrt[col]).to(x.dtype)

        # Message passing
        out = torch.zeros_like(x)
        out.index_add_(0, row, x[col] * norm.unsqueeze(-1))

        # Self contribution weighted by alpha
        return out + self.alpha * x

    def rbm_forward_pass(self, visible):
        """
        Visible -> hidden pass with temperature control.

        Returns:
            ``(hidden_prob, hidden_activation)``, where ``hidden_activation`` is
            ``hidden_bias + visible @ W`` and ``hidden_prob`` is the result of
            ``sampler.sample`` with ``self.forward_sampling``.
        """
        hidden_activation = self.hidden_bias + torch.matmul(visible, self.W)
        hidden_prob = sampler.sample(
            hidden_activation,
            self.forward_sampling,
            temperature=self.temperature,
            force_clamp=True
        )
        return hidden_prob, hidden_activation

    def rbm_backward_pass(self, hidden):
        """
        Hidden -> visible pass (uses the transposed weight matrix).

        Returns:
            ``(visible_prob, visible_activation)``, where ``visible_activation``
            is ``visible_bias + hidden @ W.T``.
        """
        visible_activation = self.visible_bias + torch.matmul(hidden, self.W.t())
        visible_prob = sampler.sample(
            visible_activation,
            self.backward_sampling,
            temperature=self.temperature,
            force_clamp=True
        )
        return visible_prob, visible_activation

    def contrastive_divergence(self, visible):
        """
        Perform k-step contrastive divergence (k=1 is standard CD-1).

        Returns:
            ``(pos_hidden_prob, neg_hidden_prob, neg_visible)``. The positive
            phase is computed from ``visible``; the Gibbs chain starts from a
            detached copy, so no gradient flows back into ``visible`` through it.
        """
        # Positive phase
        pos_hidden_prob, _ = self.rbm_forward_pass(visible)

        # Negative phase (k-step Gibbs sampling)
        neg_visible = visible.clone().detach()  # Detach to prevent gradient flow

        for _ in range(self.k_steps):
            neg_hidden_prob, _ = self.rbm_forward_pass(neg_visible)
            neg_hidden_sample = sampler.sample(
                neg_hidden_prob,
                self.forward_sampling,
                temperature=self.temperature,
                force_clamp=True,
                hard=True
            )
            neg_visible_prob, _ = self.rbm_backward_pass(neg_hidden_sample)
            neg_visible = sampler.sample(
                neg_visible_prob,
                self.backward_sampling,
                temperature=self.temperature,
                force_clamp=True,
                hard=True
            )

        neg_hidden_prob, _ = self.rbm_forward_pass(neg_visible)

        return pos_hidden_prob, neg_hidden_prob, neg_visible

    def forward(self, x, edge_index):
        """
        Neighbour aggregation followed by the RBM visible->hidden transformation.

        Args:
            x: Node features with ``in_features`` columns.
            edge_index: ``(row, col)`` edge index pair (see ``aggregate_neighbors``).

        Returns:
            Hidden representation with ``out_features`` columns (BatchNorm-ed
            when ``use_batch_norm``).
        """
        # Input normalization
        x = self.input_norm(x)

        # Apply dropout to input
        if hasattr(self, 'dropout_layer') and self.training:
            x = self.dropout_layer(x)

        # Aggregate neighbor information
        aggregated_x = self.aggregate_neighbors(x, edge_index)

        # Combine original and aggregated features
        combined_input = self.residual * x + (1 - self.residual) * self.neighbor_weight * aggregated_x

        # Apply RBM transformation
        if self.training:
            # NOTE(review): only the positive-phase probabilities are used; the
            # negative chain computed here is discarded. It is kept because
            # `sampler` is opaque and may rely on these calls.
            pos_hidden, neg_hidden, neg_visible = self.contrastive_divergence(combined_input)
            output = pos_hidden
        else:
            # During inference, just use forward pass
            output, _ = self.rbm_forward_pass(combined_input)

        # Apply batch normalization
        if self.use_batch_norm:
            output = self.bn(output)

        return output

    def _free_energy(self, visible):
        """Return the RBM free energy ``-v . b_v - sum_j softplus(b_h + v W)_j`` per row."""
        hidden_activation = self.hidden_bias + torch.matmul(visible, self.W)
        # softplus(a) == log(1 + exp(a)), but it does not overflow to inf for large a
        return -F.softplus(hidden_activation).sum(dim=1) - torch.matmul(visible, self.visible_bias)

    def compute_cd_loss(self, visible):
        """
        Compute the contrastive divergence loss for RBM training.

        Can be added to the main loss during training (not in the scope of the article).

        Returns:
            ``mean(F(v_pos) - F(v_neg)) + 0.1 * MSE(v_neg, v_pos)``, where
            ``F`` is the free energy and ``v_neg`` is the end of the Gibbs chain.
        """
        pos_hidden_prob, neg_hidden_prob, neg_visible = self.contrastive_divergence(visible)

        # Reconstruction loss
        recon_loss = F.mse_loss(neg_visible, visible)

        # Free energy difference
        cd_loss = torch.mean(self._free_energy(visible) - self._free_energy(neg_visible))

        return cd_loss + 0.1 * recon_loss

    def reset_parameters(self):
        """Re-initialise parameters (Xavier-uniform W with gain 0.5, zero biases)."""
        nn.init.xavier_uniform_(self.W, gain=0.5)
        nn.init.zeros_(self.visible_bias)
        nn.init.zeros_(self.hidden_bias)
        nn.init.constant_(self.neighbor_weight, 1.0)

        if self.use_batch_norm:
            self.bn.reset_parameters()

        self.input_norm.reset_parameters()



class RBMConvNet(torch.nn.Module):
    """
    Network using RBM Contrastive Divergence layers instead of GCN convolutions.

    Architecture: an input ``RBMContrastiveDivergenceLayer``
    (``num_features -> hidden_dim``), ``num_layers - 2`` intermediate layers
    (``hidden_dim -> hidden_dim``) and an output layer
    (``hidden_dim -> num_classes``). ELU follows every non-output layer, and
    ``log_softmax`` is applied to the output.

    Note:
        The input and output layers use fixed hyperparameters
        (``k_steps=1``, ``alpha=0.1``, ``residual=0.5``, ``temperature=0.5`` and
        the layer's default sampling methods). The ``k_steps``, ``alpha``,
        ``residual``, ``forward_sampling`` and ``backward_sampling`` arguments
        only configure the intermediate layers, which are created when
        ``num_layers > 2``.
    """

    def __init__(self, dataset, hidden_dim=64, dropout=0.6, k_steps=1, alpha=0.5, residual=0.35,
                 forward_sampling='gumbel_softmax', backward_sampling='sigmoid', num_layers=2):
        super(RBMConvNet, self).__init__()

        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        # First layer (fixed hyperparameters, see class docstring)
        self.layers.append(RBMContrastiveDivergenceLayer(
            dataset.num_features, hidden_dim, k_steps=1,
            dropout=dropout,
            alpha=0.1,
            residual=0.5,
            temperature=0.5,
            use_batch_norm=True
        ))

        # Additional hidden layers if specified
        for _ in range(num_layers - 2):
            self.layers.append(RBMContrastiveDivergenceLayer(
                hidden_dim, hidden_dim, k_steps=k_steps, dropout=dropout / 2,
                alpha=alpha, residual=residual, forward_sampling=forward_sampling,
                backward_sampling=backward_sampling
            ))

        # Output layer (fixed hyperparameters, see class docstring)
        self.layers.append(RBMContrastiveDivergenceLayer(
            hidden_dim, dataset.num_classes, k_steps=1,
            dropout=0.0,  # No dropout on output
            alpha=0.1,
            residual=0.5,
            temperature=0.5,
            use_batch_norm=False  # No BN
        ))

        # Network-level batch norms. They are currently NOT applied in forward()
        # (each layer has its own BN); they are kept so existing state_dicts load.
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)
        ])

    def reset_parameters(self):
        """Re-initialise every RBM layer and the (unused) network-level batch norms."""
        for layer in self.layers:
            layer.reset_parameters()
        for bn in self.batch_norms:
            bn.reset_parameters()

    def forward(self, data, return_loss=False, loss_type="reconstruction"):
        """
        Run the network on a graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            return_loss: Accepted for interface compatibility; currently unused.
            loss_type: Accepted for interface compatibility; currently unused.

        Returns:
            Per-node log-probabilities over ``num_classes``.
        """
        x, edge_index = data.x, data.edge_index

        # Hidden layers, each followed by ELU
        for layer in self.layers[:-1]:
            x = F.elu(layer(x, edge_index))

        # Output layer
        x = self.layers[-1](x, edge_index)
        return F.log_softmax(x, dim=1)


import torch
import torch.nn as nn
import torch.nn.functional as F


class EnhancedRBMLayer(nn.Module):
    """
    Enhanced RBM layer with improved sampling and aggregation.
    No attention - pure RBM philosophy.

    Per call (see ``forward``): LayerNorm -> (train-only) dropout -> GCN-style
    neighbour aggregation -> RBM visible->hidden pass -> residual connection
    (through ``skip_proj`` when ``in_features != out_features``) -> optional
    BatchNorm.

    Args:
        in_features: Visible-layer size (input feature dimension).
        out_features: Hidden-layer size (output feature dimension).
        k_steps: Number of Gibbs steps used by ``contrastive_divergence``.
        dropout: Input dropout probability; no dropout module is created when 0.
        alpha: Extra weight added to the node's own normalised features.
        residual: Scale of the skip connection added to the RBM output.
        temperature: Temperature of the sigmoid relaxation.
        use_batch_norm: Apply ``BatchNorm1d`` to the output.
    """

    # Clamp used when converting probabilities back to logits
    _PROB_EPS = 1e-6

    def __init__(self, in_features, out_features, k_steps=5, dropout=0.5,
                 alpha=0.1, residual=0.5, temperature=0.5, use_batch_norm=True):
        super(EnhancedRBMLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.k_steps = k_steps
        self.dropout = dropout
        self.alpha = alpha
        self.residual = residual
        self.temperature = temperature

        # RBM parameters (initialised in reset_parameters)
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        self.visible_bias = nn.Parameter(torch.zeros(in_features))
        self.hidden_bias = nn.Parameter(torch.zeros(out_features))

        # Learnable aggregation weights (separate for self and neighbors)
        self.neighbor_weight = nn.Parameter(torch.ones(1))
        self.self_weight = nn.Parameter(torch.ones(1))

        # Normalization layers
        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.bn = nn.BatchNorm1d(out_features)
        self.layer_norm = nn.LayerNorm(in_features)

        # Dropout module only exists for positive rates; forward() checks for it
        if dropout > 0:
            self.dropout_layer = nn.Dropout(dropout)

        # Skip connection projection if dimensions don't match
        self.skip_proj = None
        if in_features != out_features:
            self.skip_proj = nn.Linear(in_features, out_features, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform (gain 0.5) weights, zero biases, unit aggregation weights."""
        nn.init.xavier_uniform_(self.W, gain=0.5)
        nn.init.zeros_(self.visible_bias)
        nn.init.zeros_(self.hidden_bias)
        nn.init.ones_(self.neighbor_weight)
        nn.init.ones_(self.self_weight)

        if self.skip_proj is not None:
            nn.init.xavier_uniform_(self.skip_proj.weight, gain=0.5)

        if self.use_batch_norm:
            self.bn.reset_parameters()
        self.layer_norm.reset_parameters()

    def aggregate_neighbors(self, x, edge_index):
        """
        Symmetric-normalised neighbour aggregation (GCN-style, no self loop).

        Computes ``out[i] = sum over edges (i, j) of x[j] / sqrt(deg(i) * deg(j))``,
        where ``deg`` counts occurrences of a node in ``edge_index[0]`` (clamped
        to at least 1).

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        row, col = edge_index
        num_nodes = x.size(0)

        # Degree is accumulated in float32 so integer counts stay exact even
        # for half-precision inputs.
        degree = torch.zeros(num_nodes, device=x.device, dtype=torch.float)
        degree.scatter_add_(0, row, torch.ones(row.size(0), device=x.device))
        degree = degree.clamp(min=1)

        # Symmetric normalization: D^(-0.5) A D^(-0.5); cast so index_add_ sees matching dtypes
        deg_inv_sqrt = degree.pow(-0.5)
        norm = (deg_inv_sqrt[row] * deg_inv_sqrt[col]).to(x.dtype)

        # Message passing
        out = torch.zeros_like(x)
        out.index_add_(0, row, x[col] * norm.unsqueeze(-1))

        return out

    @staticmethod
    def _logistic_noise(like):
        """
        Return Logistic(0, 1) noise with the shape, dtype and device of ``like``.

        ``log(u) - log(1 - u)`` is the difference of two independent Gumbel(0, 1)
        draws, which is the noise used by the binary Gumbel-Softmax (Binary
        Concrete) relaxation. Unlike a single Gumbel draw (mean ~0.577), it has
        mean 0, so training activations are not shifted relative to the
        noise-free inference path.
        """
        u = torch.rand_like(like).clamp_(min=torch.finfo(like.dtype).tiny)
        return torch.log(u) - torch.log1p(-u)

    def _relaxed_sigmoid(self, activation):
        """Temperature sigmoid of ``activation``, with Logistic noise added in training mode."""
        if self.training:
            activation = activation + self._logistic_noise(activation)
        return torch.sigmoid(activation / self.temperature)

    def rbm_forward_pass(self, visible):
        """
        Visible -> hidden pass with binary Gumbel-Softmax style sampling.

        Returns:
            ``(hidden_prob, hidden_activation)``, where
            ``hidden_activation = hidden_bias + visible @ W`` and ``hidden_prob``
            is ``sigmoid((activation + noise) / temperature)`` in training mode
            and ``sigmoid(activation / temperature)`` in eval mode.
        """
        hidden_activation = self.hidden_bias + torch.matmul(visible, self.W)
        return self._relaxed_sigmoid(hidden_activation), hidden_activation

    def rbm_backward_pass(self, hidden):
        """
        Hidden -> visible pass through the transposed weight matrix.

        Returns:
            ``(visible_prob, visible_activation)`` with the same train/eval
            behaviour as ``rbm_forward_pass``.
        """
        visible_activation = self.visible_bias + torch.matmul(hidden, self.W.t())
        return self._relaxed_sigmoid(visible_activation), visible_activation

    def _gumbel_softmax(self, logits, tau=1.0, hard=True):
        """
        Binary Gumbel-Softmax sample from Bernoulli ``logits``.

        With ``hard=True``, the forward value is 1 exactly when
        ``logits + noise > 0``, which happens with probability
        ``sigmoid(logits)``. Gradients flow through the relaxed
        ``sigmoid((logits + noise) / tau)`` (straight-through estimator).
        """
        y_soft = torch.sigmoid((logits + self._logistic_noise(logits)) / tau)

        if not hard:
            return y_soft

        # Straight-through estimator
        y_hard = (y_soft > 0.5).to(y_soft.dtype)
        return y_hard - y_soft.detach() + y_soft

    def contrastive_divergence(self, visible):
        """
        Contrastive divergence with ``k_steps`` Gibbs steps.

        Returns:
            ``(pos_hidden_prob, neg_hidden_prob, neg_visible)``. The chain starts
            from ``visible.detach()``. Each step draws hard {0, 1} samples whose
            probability of being 1 equals the layer's probability for that unit.
        """
        # Positive phase
        pos_hidden_prob, _ = self.rbm_forward_pass(visible)

        # Negative phase (k-step Gibbs sampling)
        neg_visible = visible.detach()

        for _ in range(self.k_steps):
            neg_hidden_prob, _ = self.rbm_forward_pass(neg_visible)
            # Probabilities must be turned into logits before sampling. Passing
            # them in as logits would make even p=0 units fire >50% of the time.
            neg_hidden_sample = self._gumbel_softmax(
                torch.logit(neg_hidden_prob, eps=self._PROB_EPS),
                tau=self.temperature, hard=True
            )

            neg_visible_prob, _ = self.rbm_backward_pass(neg_hidden_sample)
            neg_visible = self._gumbel_softmax(
                torch.logit(neg_visible_prob, eps=self._PROB_EPS),
                tau=self.temperature, hard=True
            )

        # Final negative hidden
        neg_hidden_prob, _ = self.rbm_forward_pass(neg_visible)

        return pos_hidden_prob, neg_hidden_prob, neg_visible

    def forward(self, x, edge_index):
        """
        Forward pass with graph structure.

        Args:
            x: Node features with ``in_features`` columns.
            edge_index: ``(row, col)`` edge index pair (see ``aggregate_neighbors``).

        Returns:
            ``out_features``-wide representation:
            ``hidden_prob + residual * skip(x)``, BatchNorm-ed when
            ``use_batch_norm``.
        """
        identity = x

        # Layer normalization for stability
        x = self.layer_norm(x)

        # Dropout on input
        if hasattr(self, 'dropout_layer') and self.training:
            x = self.dropout_layer(x)

        # Aggregate neighbor information
        aggregated_x = self.aggregate_neighbors(x, edge_index)

        # Learnable combination: (self_weight + alpha) * x + neighbor_weight * neighbors
        combined = self.self_weight * x + self.neighbor_weight * aggregated_x + self.alpha * x

        # RBM transformation. Only the positive-phase hidden probabilities form the
        # output, so the k-step negative chain is not run here (it would be discarded).
        output, _ = self.rbm_forward_pass(combined)

        # Skip connection for better gradient flow. skip_proj exists exactly when
        # in_features != out_features, so identity is out_features wide after this.
        if self.skip_proj is not None:
            identity = self.skip_proj(identity)
        output = output + self.residual * identity

        # Batch normalization
        if self.use_batch_norm:
            output = self.bn(output)

        return output

    def _free_energy(self, visible):
        """Return the RBM free energy ``-v . b_v - sum_j softplus(b_h + v W)_j`` per row."""
        hidden_activation = self.hidden_bias + torch.matmul(visible, self.W)
        # softplus(a) == log(1 + exp(a)), but it does not overflow to inf for large a
        return -F.softplus(hidden_activation).sum(dim=1) - torch.matmul(visible, self.visible_bias)

    def compute_cd_loss(self, visible):
        """
        Compute the contrastive divergence loss.

        Returns:
            ``mean(F(v_pos) - F(v_neg)) + 0.1 * MSE(v_neg, v_pos)``, where
            ``F`` is the free energy and ``v_neg`` is the end of the Gibbs chain.
        """
        pos_hidden_prob, neg_hidden_prob, neg_visible = self.contrastive_divergence(visible)

        # Reconstruction loss
        recon_loss = F.mse_loss(neg_visible, visible)

        # Free energy difference
        cd_loss = torch.mean(self._free_energy(visible) - self._free_energy(neg_visible))

        return cd_loss + 0.1 * recon_loss



class DeepRBMConvNet(nn.Module):
    """
    Deeper RBM network with residual connections between blocks.

    Layout: Linear input projection + ELU -> ``num_blocks`` blocks of
    ``layers_per_block`` ``EnhancedRBMLayer``s (ELU after every layer and an
    identity skip around each block) -> Linear output projection -> LayerNorm
    -> ``log_softmax``.
    """

    def __init__(self, num_features, num_classes, hidden_dim=128,
                 k_steps=1,
                 num_blocks=3, layers_per_block=2, dropout=0.5):
        super(DeepRBMConvNet, self).__init__()

        self.num_blocks = num_blocks
        self.layers_per_block = layers_per_block

        # Input projection
        self.input_proj = nn.Linear(num_features, hidden_dim)

        # RBM blocks
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            block = nn.ModuleList()
            for _ in range(layers_per_block):
                block.append(EnhancedRBMLayer(
                    in_features=hidden_dim,
                    out_features=hidden_dim,
                    k_steps=k_steps,
                    dropout=dropout,
                    alpha=0.1,
                    residual=0.5,
                    temperature=0.5,
                    use_batch_norm=True
                ))
            self.blocks.append(block)

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, num_classes)
        self.output_norm = nn.LayerNorm(num_classes)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module, including the input projection."""
        self.input_proj.reset_parameters()
        for block in self.blocks:
            for layer in block:
                layer.reset_parameters()
        self.output_proj.reset_parameters()
        self.output_norm.reset_parameters()

    def forward(self, data, return_loss=False):
        """
        Run the network on a graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            return_loss: Accepted for interface compatibility; currently unused.

        Returns:
            Per-node log-probabilities over ``num_classes``.
        """
        x, edge_index = data.x, data.edge_index

        # Input projection
        x = F.elu(self.input_proj(x))

        # Process through blocks with residual connections
        for block in self.blocks:
            block_input = x
            for layer in block:
                x = F.elu(layer(x, edge_index))
            # Residual connection at block level
            x = x + block_input

        # Output projection
        x = self.output_norm(self.output_proj(x))

        return F.log_softmax(x, dim=1)

class EnhancedRBMConvNet(nn.Module):
    """
    Enhanced RBM network - pure RBM philosophy with improvements.

    ``num_layers`` ``EnhancedRBMLayer``s map
    ``num_features -> hidden_dim -> ... -> hidden_dim -> num_classes``. The
    hidden layers use dropout and BatchNorm and are followed by ELU. The last
    layer uses neither; its output goes through LayerNorm and ``log_softmax``.
    """

    def __init__(self, num_features, num_classes, hidden_dim=128,
                 num_layers=3, dropout=0.5, k_steps=1):
        super(EnhancedRBMConvNet, self).__init__()

        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        # Layer widths: input, (num_layers - 1) hidden widths, classes
        dims = [num_features] + [hidden_dim] * (num_layers - 1) + [num_classes]

        # Build RBM layers
        for i in range(num_layers):
            is_last = i == num_layers - 1
            self.layers.append(EnhancedRBMLayer(
                in_features=dims[i],
                out_features=dims[i + 1],
                k_steps=k_steps,
                dropout=0.0 if is_last else dropout,
                alpha=0.1,
                residual=0.5,
                temperature=0.5,
                use_batch_norm=not is_last
            ))

        # Additional normalization on output
        self.output_norm = nn.LayerNorm(num_classes)

    def reset_parameters(self):
        """Re-initialise every RBM layer and the output LayerNorm."""
        for layer in self.layers:
            layer.reset_parameters()
        self.output_norm.reset_parameters()

    def forward(self, data, return_loss=False):
        """
        Run the network on a graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            return_loss: Accepted for interface compatibility; currently unused.

        Returns:
            Per-node log-probabilities over ``num_classes``.
        """
        x, edge_index = data.x, data.edge_index

        # Hidden RBM layers, each followed by ELU
        for layer in self.layers[:-1]:
            x = F.elu(layer(x, edge_index))

        # Output layer (no activation) and output normalization
        x = self.output_norm(self.layers[-1](x, edge_index))

        return F.log_softmax(x, dim=1)

    def compute_total_cd_loss(self, data):
        """
        Compute CD loss across all hidden layers (optional, for unsupervised pre-training).

        For each layer, the visible vector is built the way ``layer.forward``
        builds it: LayerNorm, then aggregation over the normalised features
        (dropout is skipped). This way the CD objective models the inputs the RBM
        actually sees. The output layer is excluded.
        """
        x = data.x
        edge_index = data.edge_index

        total_loss = 0
        for layer in self.layers[:-1]:  # Exclude output layer
            h = layer.layer_norm(x)
            aggregated = layer.aggregate_neighbors(h, edge_index)
            combined = layer.self_weight * h + layer.neighbor_weight * aggregated + layer.alpha * h

            total_loss += layer.compute_cd_loss(combined)

            # Forward to next layer
            x = F.elu(layer(x, edge_index))

        return total_loss




import torch
import torch.nn as nn
import torch.nn.functional as F


class EnhancedDeepRBMConvNet(nn.Module):
    """
    Enhanced Deep RBM with pure RBM approach - simple but effective improvements.

    Key enhancements:
    1. Adaptive temperature per block (cooler in deeper blocks: 0.5 * 0.85 ** block_idx)
    2. Progressive dropout reduction across blocks
    3. Skip connections with learnable sigmoid gates (``use_gates=True``)
    4. Feature refinement (Linear + LayerNorm + ELU + Dropout) between blocks

    Layout: input projection (Linear, LayerNorm, Dropout) + ELU -> blocks of
    ``EnhancedRBMLayer``s (ELU after each), each followed by its skip
    connection, LayerNorm and refinement -> MLP output head -> ``log_softmax``.
    """

    def __init__(self, num_features, num_classes, hidden_dim=256,
                 num_blocks=4, layers_per_block=2, dropout=0.5,
                 adaptive_temperature=True, use_gates=True):
        super(EnhancedDeepRBMConvNet, self).__init__()

        self.num_blocks = num_blocks
        self.layers_per_block = layers_per_block
        self.adaptive_temperature = adaptive_temperature
        self.use_gates = use_gates

        # Input projection with normalization
        self.input_proj = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout * 0.3)
        )

        # RBM blocks
        self.blocks = nn.ModuleList()
        self.block_norms = nn.ModuleList()
        self.block_refinements = nn.ModuleList()  # Feature refinement between blocks

        # Learnable skip connection gates (if enabled); sigmoid(1.0) ~= 0.73 at init
        if use_gates:
            self.skip_gates = nn.ParameterList([
                nn.Parameter(torch.ones(1)) for _ in range(num_blocks)
            ])

        for block_idx in range(num_blocks):
            # Adaptive temperature: cooler (more deterministic) in deeper blocks
            if adaptive_temperature:
                temperature = 0.5 * (0.85 ** block_idx)  # 0.5 -> 0.425 -> 0.361 -> 0.307
            else:
                temperature = 0.5

            # Progressive dropout: less dropout in deeper blocks
            block_dropout = dropout * (1.0 - 0.3 * block_idx / num_blocks)

            # Create RBM layers for this block
            block = nn.ModuleList()
            for _ in range(layers_per_block):
                block.append(EnhancedRBMLayer(
                    in_features=hidden_dim,
                    out_features=hidden_dim,
                    k_steps=1,  # CD-1 is efficient
                    dropout=block_dropout,
                    alpha=0.1,
                    residual=0.5,
                    temperature=temperature,
                    use_batch_norm=True
                ))
            self.blocks.append(block)

            # Block normalization
            self.block_norms.append(nn.LayerNorm(hidden_dim))

            # Feature refinement between blocks (1x1 convolution equivalent)
            if block_idx < num_blocks - 1:
                self.block_refinements.append(nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ELU(),
                    nn.Dropout(dropout * 0.2)
                ))

        # Output head for classification
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ELU(),
            nn.Dropout(dropout * 0.3),
            nn.Linear(hidden_dim, num_classes),
            nn.LayerNorm(num_classes)
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all sub-modules and set the skip gates back to 1.0."""
        # Input projection
        for module in self.input_proj:
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        # Blocks
        for block in self.blocks:
            for layer in block:
                layer.reset_parameters()

        # Block norms
        for norm in self.block_norms:
            norm.reset_parameters()

        # Refinements
        for refinement in self.block_refinements:
            for module in refinement:
                if hasattr(module, 'reset_parameters'):
                    module.reset_parameters()

        # Output head
        for module in self.output_head:
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        # Gates
        if self.use_gates:
            for gate in self.skip_gates:
                nn.init.constant_(gate, 1.0)

    def _merge_block(self, block_idx, x, block_input):
        """Apply block ``block_idx``'s skip connection, LayerNorm and (if any) refinement."""
        if self.use_gates:
            # Learnable gated skip connection
            gate = torch.sigmoid(self.skip_gates[block_idx])
            x = gate * x + (1 - gate) * block_input
        else:
            # Standard residual connection
            x = x + block_input

        x = self.block_norms[block_idx](x)

        # The last block has no refinement module
        if block_idx < len(self.block_refinements):
            x = self.block_refinements[block_idx](x)
        return x

    def forward(self, data, return_loss=False):
        """
        Run the network on a graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            return_loss: Accepted for interface compatibility; currently unused.

        Returns:
            Per-node log-probabilities over ``num_classes``.
        """
        x, edge_index = data.x, data.edge_index

        # Input projection
        x = F.elu(self.input_proj(x))

        for block_idx, block in enumerate(self.blocks):
            block_input = x
            for layer in block:
                x = F.elu(layer(x, edge_index))
            x = self._merge_block(block_idx, x, block_input)

        # Output head
        return F.log_softmax(self.output_head(x), dim=1)

    def compute_block_cd_losses(self, data):
        """
        Compute the mean per-block CD loss (for unsupervised pre-training).

        Each layer's visible vector is built the way ``layer.forward`` builds it
        (LayerNorm, then aggregation over the normalised features; dropout is
        skipped). Empty blocks contribute 0. Returns a zero tensor when there are
        no blocks.
        """
        x, edge_index = data.x, data.edge_index

        # Input projection
        x = F.elu(self.input_proj(x))

        block_losses = []
        for block_idx, block in enumerate(self.blocks):
            block_input = x
            block_loss = 0

            for layer in block:
                h = layer.layer_norm(x)
                aggregated = layer.aggregate_neighbors(h, edge_index)
                combined = layer.self_weight * h + layer.neighbor_weight * aggregated + layer.alpha * h

                block_loss += layer.compute_cd_loss(combined)

                x = F.elu(layer(x, edge_index))

            # Average CD loss for this block (guard against empty blocks)
            block_losses.append(block_loss / max(len(block), 1))

            x = self._merge_block(block_idx, x, block_input)

        if not block_losses:
            return x.new_zeros(())
        return sum(block_losses) / len(block_losses)


class UltraDeepRBMConvNet(nn.Module):
    """
    Ultra-deep variant with better gradient flow for very deep networks (6+ blocks).
    Uses pre-activation residuals and layer scale.

    Each block feeds ``elu(LayerNorm(x))`` through its ``EnhancedRBMLayer``s
    (ELU after each). The result is added back to ``x`` after scaling by a
    learnable per-channel vector, ``layer_scale``, initialised to
    ``layer_scale_init``.
    """

    def __init__(self, num_features, num_classes, hidden_dim=256,
                 num_blocks=6, layers_per_block=2, dropout=0.5,
                 layer_scale_init=1e-4):
        super(UltraDeepRBMConvNet, self).__init__()

        self.num_blocks = num_blocks
        self.layers_per_block = layers_per_block
        # Stored so reset_parameters() restores the requested value rather than a constant
        self.layer_scale_init = layer_scale_init

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )

        # RBM blocks with layer scale
        self.blocks = nn.ModuleList()
        self.pre_norms = nn.ModuleList()  # Pre-activation normalization
        self.layer_scales = nn.ParameterList()  # Layer scale parameters

        for block_idx in range(num_blocks):
            # Adaptive temperature and dropout (both decrease with depth)
            temperature = 0.5 * (0.9 ** block_idx)
            block_dropout = dropout * (1.0 - 0.4 * block_idx / num_blocks)

            # Pre-activation norm
            self.pre_norms.append(nn.LayerNorm(hidden_dim))

            # RBM block
            block = nn.ModuleList()
            for _ in range(layers_per_block):
                block.append(EnhancedRBMLayer(
                    in_features=hidden_dim,
                    out_features=hidden_dim,
                    k_steps=1,
                    dropout=block_dropout,
                    alpha=0.1,
                    residual=0.5,
                    temperature=temperature,
                    use_batch_norm=True
                ))
            self.blocks.append(block)

            # Layer scale (helps with very deep networks)
            self.layer_scales.append(
                nn.Parameter(torch.ones(hidden_dim) * layer_scale_init)
            )

        # Output
        self.final_norm = nn.LayerNorm(hidden_dim)
        self.output_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ELU(),
            nn.Dropout(dropout * 0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all sub-modules; layer scales go back to ``layer_scale_init``."""
        for module in self.input_proj:
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        for block in self.blocks:
            for layer in block:
                layer.reset_parameters()

        for norm in self.pre_norms:
            norm.reset_parameters()

        for scale in self.layer_scales:
            nn.init.constant_(scale, self.layer_scale_init)

        self.final_norm.reset_parameters()

        for module in self.output_head:
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    def forward(self, data, return_loss=False):
        """
        Run the network on a graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.
            return_loss: Accepted for interface compatibility; currently unused.

        Returns:
            Per-node log-probabilities over ``num_classes``.
        """
        x, edge_index = data.x, data.edge_index

        # Input projection
        x = self.input_proj(x)

        # Process through blocks with pre-activation
        for pre_norm, block, layer_scale in zip(self.pre_norms, self.blocks, self.layer_scales):
            # Pre-activation normalization
            h = F.elu(pre_norm(x))

            # Process through RBM layers
            for layer in block:
                h = F.elu(layer(h, edge_index))

            # Layer scale + residual
            x = x + layer_scale * h

        # Final normalization and output
        x = self.output_head(self.final_norm(x))

        return F.log_softmax(x, dim=1)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class MemoryEfficientRBMLayer(nn.Module):
    """
    RBM-style graph layer that reuses buffers where autograd allows it.

    forward(x, edge_index):
      1. (training only) dropout on ``x``;
      2. ``combined = self_weight * x + neighbor_weight * agg + alpha * x``,
         where ``agg`` is the degree-normalised neighbour sum from
         ``aggregate_neighbors``;
      3. hidden probabilities
         ``sigmoid((combined @ W + hidden_bias [+ Gumbel noise in training]) / temperature)``;
      4. ``+ residual * identity`` (identity = the pre-dropout input when the
         widths match, otherwise ``skip_proj`` of the dropped-out input);
      5. optional BatchNorm without running statistics (batch statistics are
         used in both train and eval mode).

    In-place ops are used only where autograd does not need the overwritten
    value. Dropout and the residual add are deliberately out-of-place. Doing
    them in place overwrote tensors saved for backward (the sigmoid output, or
    a caller's in-place ELU output), which made ``backward()`` raise during
    training. In-place dropout also mutated the caller's input tensor.

    Args:
        in_features, out_features: input width and hidden (output) width.
        k_steps: stored for API compatibility; not read by this class.
        dropout: input dropout probability (no dropout module when <= 0).
        alpha: extra weight on the node's own features.
        residual: scale of the residual branch.
        temperature: divisor applied to the hidden pre-activation.
        use_batch_norm: apply ``BatchNorm1d(out_features)`` at the end.
    """

    def __init__(self, in_features, out_features, k_steps=1, dropout=0.5,
                 alpha=0.1, residual=0.5, temperature=0.5, use_batch_norm=True):
        super(MemoryEfficientRBMLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.k_steps = k_steps  # not read anywhere in this class
        self.dropout = dropout
        self.alpha = alpha
        self.residual = residual
        self.temperature = temperature

        # RBM parameters; values are filled in by reset_parameters().
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        self.visible_bias = nn.Parameter(torch.zeros(in_features))
        self.hidden_bias = nn.Parameter(torch.zeros(out_features))

        # Learnable scalar mixing weights shared by all nodes.
        self.neighbor_weight = nn.Parameter(torch.tensor(1.0))
        self.self_weight = nn.Parameter(torch.tensor(1.0))

        # BatchNorm without running statistics: batch statistics are used in
        # training and evaluation alike.
        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.bn = nn.BatchNorm1d(out_features, track_running_stats=False)

        # Out-of-place on purpose. In-place dropout overwrote the caller's
        # tensor (and the residual `identity`, which aliases it), and broke
        # autograd when that tensor had been saved for backward.
        if dropout > 0:
            self.dropout_layer = nn.Dropout(dropout)

        # Projection for the residual branch when the widths differ.
        self.skip_proj = None
        if in_features != out_features:
            self.skip_proj = nn.Linear(in_features, out_features, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise W (gain 0.5), zero the biases, reset mixing weights to 1."""
        nn.init.xavier_uniform_(self.W, gain=0.5)
        nn.init.zeros_(self.visible_bias)
        nn.init.zeros_(self.hidden_bias)
        nn.init.constant_(self.neighbor_weight, 1.0)
        nn.init.constant_(self.self_weight, 1.0)

        if self.skip_proj is not None:
            nn.init.xavier_uniform_(self.skip_proj.weight, gain=0.5)

        if self.use_batch_norm:
            self.bn.reset_parameters()

    def aggregate_neighbors(self, x, edge_index):
        """Degree-normalised neighbour sum.

        For every edge ``(i, j) = (edge_index[0, e], edge_index[1, e])``, adds
        ``x[j] / sqrt(deg(i) * deg(j))`` to row ``i`` of a new tensor shaped
        like ``x``. Here ``deg`` counts occurrences in ``edge_index[0]``,
        clamped to at least 1. No self-loops are added.
        """
        row, col = edge_index
        num_nodes = x.size(0)

        degree = torch.zeros(num_nodes, device=x.device, dtype=x.dtype)
        degree.scatter_add_(0, row, torch.ones(row.size(0), device=x.device, dtype=x.dtype))
        degree.clamp_(min=1)  # isolated nodes: avoid 0 ** -0.5

        deg_inv_sqrt = degree.pow_(-0.5)  # reuses the degree buffer
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        out = torch.zeros_like(x)
        out.index_add_(0, row, x[col] * norm.unsqueeze(-1))

        return out

    def rbm_forward_pass(self, visible):
        """Hidden-unit probabilities for ``visible``.

        Computes ``sigmoid((visible @ W + hidden_bias + g) / temperature)``.
        ``g`` is Gumbel noise in training mode and absent in eval mode.

        The bias add, noise add, division and sigmoid all reuse the matmul
        output buffer. This is safe because matmul's backward does not need
        its output. The returned tensor is the sigmoid result that autograd
        keeps for backward, so callers must not modify it in place.
        """
        hidden_activation = torch.matmul(visible, self.W)
        hidden_activation.add_(self.hidden_bias)

        if self.training:
            # Gumbel(0, 1) sample -log(-log(U)), with eps guards.
            noise = torch.rand_like(hidden_activation)
            noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10)
            hidden_activation.add_(noise)

        hidden_prob = torch.sigmoid_(hidden_activation.div_(self.temperature))

        return hidden_prob

    def contrastive_divergence_minimal(self, visible):
        """One contrastive-divergence step.

        Returns ``(pos_hidden, neg_hidden_final)``. ``pos_hidden`` is
        ``rbm_forward_pass(visible)``. The negative chain starts from a
        detached copy of ``visible``, thresholds the hidden and visible
        probabilities at 0.5, and runs a final hidden pass. ``forward`` only
        uses the first element.
        """
        pos_hidden = self.rbm_forward_pass(visible)

        neg_visible = visible.detach()
        neg_hidden = self.rbm_forward_pass(neg_visible)
        neg_hidden_binary = (neg_hidden > 0.5).to(neg_hidden.dtype)

        # Reconstruct the visible layer (safe in-place ops on fresh buffers).
        neg_visible_activation = torch.matmul(neg_hidden_binary, self.W.t())
        neg_visible_activation.add_(self.visible_bias)
        neg_visible = torch.sigmoid_(neg_visible_activation)
        neg_visible_binary = (neg_visible > 0.5).to(neg_visible.dtype)

        neg_hidden_final = self.rbm_forward_pass(neg_visible_binary)

        return pos_hidden, neg_hidden_final

    def forward(self, x, edge_index):
        """Apply the layer to node features ``x`` using graph ``edge_index``.

        The input tensor is never modified.
        """
        # Residual source: the (pre-dropout) input, when no projection is needed.
        identity = x if self.skip_proj is None and x.size(-1) == self.out_features else None

        if hasattr(self, 'dropout_layer') and self.training:
            x = self.dropout_layer(x)  # out-of-place: caller's tensor untouched

        aggregated = self.aggregate_neighbors(x, edge_index)

        combined = self.self_weight * x + self.neighbor_weight * aggregated
        # Safe in place: `combined` is a fresh sum that nothing has saved yet.
        combined.add_(self.alpha * x)

        if self.training:
            output, _ = self.contrastive_divergence_minimal(combined)
        else:
            output = self.rbm_forward_pass(combined)

        if self.skip_proj is not None:
            identity = self.skip_proj(x)
        if identity is not None:
            # Out-of-place: `output` is the sigmoid result saved for backward.
            output = output + self.residual * identity

        if self.use_batch_norm:
            output = self.bn(output)

        return output


class MemoryOptimizedDeepRBM(nn.Module):
    """
    Deep stack of MemoryEfficientRBMLayer blocks for large graphs like ogbn-arxiv.

    Memory-related choices:
    1. Optional per-block gradient checkpointing (trades compute for memory).
    2. In-place ELU activations, used only where autograd does not need the
       overwritten value.
    3. No extra normalisation between blocks (each RBM layer is built with
       its own BatchNorm).
    4. A single linear output head.

    Returns per-class log-probabilities.
    """

    def __init__(self, num_features, num_classes, hidden_dim=256,
                 num_blocks=4, layers_per_block=2, dropout=0.5,
                 use_checkpoint=True, adaptive_temperature=True):
        super(MemoryOptimizedDeepRBM, self).__init__()

        self.num_blocks = num_blocks
        self.layers_per_block = layers_per_block
        self.use_checkpoint = use_checkpoint
        self.adaptive_temperature = adaptive_temperature

        # Input projection (no normalisation layer here).
        self.input_proj = nn.Linear(num_features, hidden_dim)
        # Out-of-place on purpose. This dropout runs on an in-place ELU output
        # that autograd saved for backward; inplace=True made backward fail.
        self.input_dropout = nn.Dropout(dropout * 0.3)

        # blocks[b][l] is RBM layer l of block b.
        self.blocks = nn.ModuleList()

        for block_idx in range(num_blocks):
            # Optionally lower the sigmoid temperature for later blocks.
            if adaptive_temperature:
                temperature = 0.5 * (0.85 ** block_idx)
            else:
                temperature = 0.5

            # Dropout shrinks linearly with depth (never below 0.7 * dropout).
            block_dropout = dropout * (1.0 - 0.3 * block_idx / num_blocks)

            block = nn.ModuleList()
            for _ in range(layers_per_block):
                layer = MemoryEfficientRBMLayer(
                    in_features=hidden_dim,
                    out_features=hidden_dim,
                    k_steps=1,
                    dropout=block_dropout,
                    alpha=0.1,
                    residual=0.5,
                    temperature=temperature,
                    use_batch_norm=True
                )
                block.append(layer)

            self.blocks.append(block)

        # Minimal output head.
        self.output_proj = nn.Linear(hidden_dim, num_classes)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.input_proj.weight, gain=0.5)
        nn.init.zeros_(self.input_proj.bias)

        for block in self.blocks:
            for layer in block:
                layer.reset_parameters()

        nn.init.xavier_uniform_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def _forward_block(self, x, edge_index, block):
        """Run one block (each layer followed by ELU) plus the block residual.

        Does not modify ``x``; safe to wrap in gradient checkpointing.
        """
        block_input = x

        for layer in block:
            x = layer(x, edge_index)
            # In place is fine here: the layers end with BatchNorm
            # (use_batch_norm=True in __init__), whose backward does not need
            # its output.
            x = F.elu_(x)

        # Out-of-place: `x` is the ELU result saved for backward, so an
        # in-place add here made training's backward fail.
        return x + block_input

    def forward(self, data, return_loss=False):
        """Return log-probabilities for every node of ``data`` (``x``, ``edge_index``)."""
        x, edge_index = data.x, data.edge_index

        x = self.input_proj(x)
        x = F.elu_(x)  # in place is fine: Linear's backward does not need its output
        x = self.input_dropout(x)  # out-of-place (see __init__)

        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Recompute the block during backward instead of storing activations.
                x = checkpoint(self._forward_block, x, edge_index, block, use_reentrant=False)
            else:
                x = self._forward_block(x, edge_index, block)

        x = self.output_proj(x)

        return F.log_softmax(x, dim=1)

    @torch.no_grad()
    def inference(self, data):
        """Run ``forward`` under ``no_grad`` with checkpointing disabled.

        The previous ``use_checkpoint`` value is restored even if forward
        raises. This does not switch the module to eval mode; call ``.eval()``
        first if dropout and Gumbel noise should be off.
        """
        old_checkpoint = self.use_checkpoint
        self.use_checkpoint = False
        try:
            return self.forward(data)
        finally:
            self.use_checkpoint = old_checkpoint


class CompactDeepRBM(nn.Module):
    """
    Weight-shared deep RBM for tight memory budgets.

    A single stack of ``layers_per_block`` MemoryEfficientRBMLayer modules
    (built without BatchNorm) is applied ``num_blocks`` times, with a residual
    around each pass. Depth therefore grows without adding parameters. The
    input and output projections are bias-free. Returns per-class
    log-probabilities.
    """

    def __init__(self, num_features, num_classes, hidden_dim=128,
                 num_blocks=3, layers_per_block=2, dropout=0.5):
        super(CompactDeepRBM, self).__init__()

        self.num_blocks = num_blocks

        self.input_proj = nn.Linear(num_features, hidden_dim, bias=False)

        # The same layers are reused by every block (weight sharing).
        self.shared_layers = nn.ModuleList([
            MemoryEfficientRBMLayer(
                in_features=hidden_dim,
                out_features=hidden_dim,
                k_steps=1,
                dropout=dropout,
                alpha=0.1,
                residual=0.5,
                temperature=0.5,
                use_batch_norm=False  # no BatchNorm inside the shared layers
            ) for _ in range(layers_per_block)
        ])

        self.output_proj = nn.Linear(hidden_dim, num_classes, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.input_proj.weight, gain=0.5)
        for layer in self.shared_layers:
            layer.reset_parameters()
        nn.init.xavier_uniform_(self.output_proj.weight)

    def forward(self, data, return_loss=False):
        """Return log-probabilities for every node of ``data`` (``x``, ``edge_index``)."""
        x, edge_index = data.x, data.edge_index

        x = self.input_proj(x)
        x = F.elu_(x)  # in place is fine: Linear's backward does not need its output

        for _ in range(self.num_blocks):
            block_input = x

            for layer in self.shared_layers:
                x = layer(x, edge_index)
                # Out-of-place: without BatchNorm the layer may hand back a
                # tensor autograd saved for backward (its sigmoid output),
                # which an in-place ELU would corrupt.
                x = F.elu(x)

            # Out-of-place: `x` is the ELU output saved for backward, so an
            # in-place add made training's backward fail.
            x = x + block_input

        x = self.output_proj(x)
        return F.log_softmax(x, dim=1)


# Factory function with memory profiles
def create_memory_optimized_rbm(num_features, num_classes, memory_level='medium'):
    """
    Build an RBM model sized for a VRAM budget.

    memory_level:
    - 'high':    16GB+ VRAM  -> MemoryOptimizedDeepRBM, hidden_dim=256, 4 blocks, no checkpointing
    - 'medium':  8-16GB VRAM -> MemoryOptimizedDeepRBM, hidden_dim=128, 4 blocks, checkpointing
    - 'low':     4-8GB VRAM  -> MemoryOptimizedDeepRBM, hidden_dim=64, 3 blocks, checkpointing
    - 'minimal': <4GB VRAM   -> CompactDeepRBM (shared layers), hidden_dim=64, 3 blocks

    All profiles use 2 layers per block and dropout 0.5. A banner line is
    printed for the chosen profile.

    Raises:
        ValueError: if ``memory_level`` is not one of the names above.
    """
    # (level, banner, use CompactDeepRBM?, hidden_dim, num_blocks, use_checkpoint)
    profiles = (
        ('high', "Creating HIGH memory model (16GB+ VRAM)", False, 256, 4, False),
        ('medium', "Creating MEDIUM memory model (8-16GB VRAM)", False, 128, 4, True),
        ('low', "Creating LOW memory model (4-8GB VRAM)", False, 64, 3, True),
        ('minimal', "Creating MINIMAL memory model (<4GB VRAM)", True, 64, 3, None),
    )

    for level, banner, compact, width, depth, use_ckpt in profiles:
        if memory_level == level:
            print(banner)
            if compact:
                return CompactDeepRBM(
                    num_features=num_features,
                    num_classes=num_classes,
                    hidden_dim=width,
                    num_blocks=depth,
                    layers_per_block=2,
                    dropout=0.5
                )
            return MemoryOptimizedDeepRBM(
                num_features=num_features,
                num_classes=num_classes,
                hidden_dim=width,
                num_blocks=depth,
                layers_per_block=2,
                dropout=0.5,
                use_checkpoint=use_ckpt,
                adaptive_temperature=True
            )

    raise ValueError(f"Unknown memory level: {memory_level}")



# Recommended configurations for ogbn-arxiv
def create_rbm_model_for_arxiv(num_features, num_classes, model_type='enhanced'):
    """
    Build one of the recommended RBM models for ogbn-arxiv.

    model_type:
    - 'enhanced':   EnhancedDeepRBMConvNet, hidden_dim=256, 4 blocks, dropout 0.5 (balanced default)
    - 'ultra_deep': UltraDeepRBMConvNet, hidden_dim=256, 6 blocks, dropout 0.5, layer_scale_init=1e-4
    - 'wide':       EnhancedDeepRBMConvNet, hidden_dim=512, 3 blocks, dropout 0.6

    Every configuration uses 2 layers per block. The Enhanced variants use
    adaptive temperature and gates.

    Raises:
        ValueError: for any other ``model_type``.
    """
    shared = dict(num_features=num_features, num_classes=num_classes, layers_per_block=2)

    if model_type == 'enhanced':
        return EnhancedDeepRBMConvNet(hidden_dim=256, num_blocks=4, dropout=0.5,
                                      adaptive_temperature=True, use_gates=True, **shared)
    if model_type == 'ultra_deep':
        return UltraDeepRBMConvNet(hidden_dim=256, num_blocks=6, dropout=0.5,
                                   layer_scale_init=1e-4, **shared)
    if model_type == 'wide':
        return EnhancedDeepRBMConvNet(hidden_dim=512, num_blocks=3, dropout=0.6,
                                      adaptive_temperature=True, use_gates=True, **shared)

    raise ValueError(f"Unknown model type: {model_type}")


# Training configuration for best results
OPTIMAL_CONFIG = dict(
    lr=0.01,            # comparatively high learning rate for the RBM models
    weight_decay=0.0,   # weight decay disabled
    grad_clip=1.0,
    scheduler='cosine',
    epochs=500,
    patience=100,
    use_cd_loss=False,  # purely supervised objective (no contrastive-divergence loss)
    device='cuda',
)



import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    GATConv, GCNConv, APPNP, JumpingKnowledge, SGConv, TransformerConv, MessagePassing
)
from torch_geometric.nn.inits import glorot, zeros



# =====================================================================
# 1️⃣ GAT (GATConv layers) - Reference: "How Attentive are Graph Attention Networks?" (GATv2, ICLR 2022)
# =====================================================================
class GATNet(nn.Module):
    """
    Multi-layer GAT node classifier with LayerNorm and optional residuals.

    Layout (PyG ``GATConv`` throughout):
    - an input layer ``in_dim -> hidden_dim`` with ``num_heads`` concatenated
      heads (width ``hidden_dim * num_heads``);
    - ``num_layers - 2`` hidden layers of that width;
    - a single-head output layer producing ``num_classes`` raw logits.

    Each non-output layer is followed by LayerNorm, an optional residual, ELU
    and dropout. The residual uses a Linear projection on the first layer when
    the widths differ, and identity otherwise.

    NOTE(review): despite the "GATv2 Style" printed banner, the layers are
    ``GATConv`` (original GAT attention). ``attention_type``, ``edge_dim``,
    ``return_attention_weights`` and ``self.edge_dropout`` are accepted or
    stored but not used.
    """

    def __init__(self, dataset, hidden_dim=32, num_heads=8, num_layers=3, dropout=0.6,
                 edge_dim=None, use_residual=True, attention_type='gatv2', **kwargs):
        super(GATNet, self).__init__()
        in_dim, out_dim = dataset.num_features, dataset.num_classes
        width = hidden_dim * num_heads  # concatenated multi-head width

        self.num_layers = num_layers
        self.dropout = dropout
        self.use_residual = use_residual

        self.convs = nn.ModuleList()
        self.skips = nn.ModuleList()  # residual branches, one per non-output layer
        self.norms = nn.ModuleList()

        # Input layer plus hidden layers (at least the input layer is built).
        for layer_idx in range(max(num_layers - 1, 1)):
            layer_in = in_dim if layer_idx == 0 else width
            self.convs.append(
                GATConv(layer_in, hidden_dim, heads=num_heads, dropout=dropout,
                        add_self_loops=True, bias=True, concat=True)
            )
            self.norms.append(nn.LayerNorm(width))
            if use_residual:
                needs_projection = layer_idx == 0 and in_dim != width
                self.skips.append(nn.Linear(in_dim, width) if needs_projection else nn.Identity())

        # Single-head output layer.
        self.convs.append(
            GATConv(width, out_dim, heads=1, dropout=dropout,
                    add_self_loops=True, bias=True, concat=False)
        )

        self.input_dropout = nn.Dropout(dropout)
        self.edge_dropout = dropout  # stored only; not read by this class

        self.reset_parameters()
        self.print_model_stats()

    def reset_parameters(self):
        for module in (*self.convs, *self.norms):
            module.reset_parameters()
        for skip in self.skips:
            reset = getattr(skip, 'reset_parameters', None)
            if reset is not None:
                reset()

    def forward(self, data, return_loss=False, return_attention_weights=False):
        """Return raw per-node logits from the final GATConv (no softmax)."""
        edge_index = data.edge_index
        h = self.input_dropout(data.x)

        for idx, conv in enumerate(self.convs[:-1]):
            shortcut = h
            h = self.norms[idx](conv(h, edge_index))
            if self.use_residual and idx < len(self.skips):
                h = h + self.skips[idx](shortcut)
            h = F.dropout(F.elu(h), p=self.dropout, training=self.training)

        return self.convs[-1](h, edge_index)

    def get_num_params(self):
        return sum(param.numel() for param in self.parameters())

    def print_model_stats(self):
        total = self.get_num_params()
        rule = "=" * 60
        print(rule)
        print(f"{'GATNet (GATv2 Style)':^60}")
        print(rule)
        print(f"Total Params: {total:,} | Model Size: {total * 4 / (1024 ** 2):.2f} MB")
        print(f"Layers: {self.num_layers} | Residual: {self.use_residual}")
        print(rule)


# =====================================================================
# 2️⃣ Deep GCN - Paper: "DeepGCNs: Can GCNs Go as Deep as CNNs?" (ICCV 2019)
# =====================================================================
class GCNNet(nn.Module):
    """
    Deep residual GCN (ResGCN-style) node classifier.

    - Input: Linear -> ReLU -> dropout; the result is kept as ``x_initial``.
    - ``num_layers`` pre-activation blocks: norm -> ReLU -> dropout ->
      ``GCNConv(improved=True)``, plus a residual scaled by a learnable scalar.
    - From the second block on, ``0.1 * x_initial`` is also added when
      ``use_initial_residual`` is set.
    - Output: norm -> ReLU -> dropout -> Linear, returning raw logits.

    NOTE(review): the output stage reuses ``self.norms[-1]`` (the last block's
    pre-activation norm). Its affine parameters are therefore shared between
    two feature distributions, and so are its running statistics when
    ``use_batch_norm=True``. DenseGCN-style dense connections are not
    implemented.
    """

    def __init__(self, dataset, hidden_dim=128, num_layers=4, dropout=0.6,
                 use_initial_residual=True, use_batch_norm=False, **kwargs):
        super(GCNNet, self).__init__()
        in_dim, out_dim = dataset.num_features, dataset.num_classes
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_initial_residual = use_initial_residual

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.residual_weights = nn.ParameterList()

        self.input_linear = nn.Linear(in_dim, hidden_dim)

        make_norm = nn.BatchNorm1d if use_batch_norm else nn.LayerNorm
        for _ in range(num_layers):
            self.convs.append(
                GCNConv(hidden_dim, hidden_dim, improved=True,
                        add_self_loops=True, normalize=True, bias=True)
            )
            self.norms.append(make_norm(hidden_dim))
            # Learnable scalar weight on the residual branch.
            self.residual_weights.append(nn.Parameter(torch.ones(1)))

        self.output_linear = nn.Linear(hidden_dim, out_dim)

        self.reset_parameters()
        self.print_model_stats()

    def reset_parameters(self):
        for module in (self.input_linear, self.output_linear, *self.convs, *self.norms):
            module.reset_parameters()
        for weight in self.residual_weights:
            nn.init.constant_(weight, 1.0)

    def forward(self, data, return_loss=False):
        """Return raw per-node logits."""
        edge_index = data.edge_index

        h = F.relu(self.input_linear(data.x))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h_initial = h

        for depth, conv in enumerate(self.convs):
            shortcut = h
            z = F.relu(self.norms[depth](h))
            z = F.dropout(z, p=self.dropout, training=self.training)
            h = conv(z, edge_index) + self.residual_weights[depth] * shortcut
            if self.use_initial_residual and depth >= 1:
                h = h + 0.1 * h_initial

        out = F.relu(self.norms[-1](h))
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.output_linear(out)

    def get_num_params(self):
        return sum(param.numel() for param in self.parameters())

    def print_model_stats(self):
        total = self.get_num_params()
        rule = "=" * 60
        print(rule)
        print(f"{'GCNNet (DeepGCN Style)':^60}")
        print(rule)
        print(f"Total Params: {total:,} | Model Size: {total * 4 / (1024 ** 2):.2f} MB")
        print(f"Layers: {self.num_layers} | Initial Residual: {self.use_initial_residual}")
        print(rule)


# =====================================================================
# 3️⃣ APPNP - Paper: "Predict then Propagate" (ICLR 2019)
# =====================================================================
class APPNPNet(nn.Module):
    """
    APPNP node classifier: predict with an MLP, then propagate the predictions.

    - MLP: ``mlp_layers`` Linear layers. Every layer except the last is
      followed by ReLU and dropout; the last maps to ``num_classes``.
    - Propagation: PyG ``APPNP`` with ``K`` steps and teleport probability
      ``alpha`` (no learnable parameters), applied without extra dropout.

    Returns raw (propagated) logits.
    """

    def __init__(self, dataset, hidden_dim=64, K=10, alpha=0.1, dropout=0.5,
                 mlp_layers=2, **kwargs):
        super(APPNPNet, self).__init__()
        in_dim, out_dim = dataset.num_features, dataset.num_classes
        self.K = K
        self.alpha = alpha

        # Each MLP stage contributes exactly three modules (Linear, act, drop)
        # so that Sequential indices (and state_dict keys) stay fixed. The last
        # stage uses Identity placeholders.
        modules = []
        width = in_dim
        for stage in range(mlp_layers):
            is_last = stage == mlp_layers - 1
            next_width = out_dim if is_last else hidden_dim
            modules.append(nn.Linear(width, next_width))
            modules.append(nn.Identity() if is_last else nn.ReLU())
            modules.append(nn.Identity() if is_last else nn.Dropout(dropout))
            width = next_width

        self.mlp = nn.Sequential(*modules)

        self.prop = APPNP(K=K, alpha=alpha, add_self_loops=True, normalize=True)

        self.reset_parameters()
        self.print_model_stats()

    def reset_parameters(self):
        for module in self.mlp:
            reset = getattr(module, 'reset_parameters', None)
            if reset is not None:
                reset()
        self.prop.reset_parameters()

    def forward(self, data, return_loss=False):
        """MLP predictions propagated over ``data.edge_index``."""
        predictions = self.mlp(data.x)
        return self.prop(predictions, data.edge_index)

    def get_num_params(self):
        return sum(param.numel() for param in self.parameters())

    def print_model_stats(self):
        total = self.get_num_params()
        rule = "=" * 60
        print(rule)
        print(f"{'APPNP (Original Paper)':^60}")
        print(rule)
        print(f"Total Params: {total:,} | Model Size: {total * 4 / (1024 ** 2):.2f} MB")
        print(f"K: {self.K} | Alpha: {self.alpha}")
        print(rule)


# =====================================================================
# 4️⃣ JKNet - Paper: "Representation Learning on Graphs with Jumping Knowledge Networks" (ICML 2018)
# =====================================================================
class JKNet(nn.Module):
    """
    Jumping Knowledge Network node classifier.

    Stacks GCNConv(improved=True) -> BatchNorm -> ReLU -> dropout layers and
    keeps every layer's output. PyG ``JumpingKnowledge`` combines those
    outputs according to ``mode``; then a Linear layer produces raw logits.
    The combined width is ``hidden_dim * num_layers`` for ``mode='cat'`` and
    ``hidden_dim`` otherwise.
    """

    def __init__(self, dataset, hidden_dim=64, num_layers=4, mode='cat', dropout=0.5,
                 **kwargs):
        super(JKNet, self).__init__()
        in_dim, out_dim = dataset.num_features, dataset.num_classes
        self.num_layers = num_layers
        self.mode = mode

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # At least the input layer is always built.
        for depth in range(max(num_layers, 1)):
            layer_in = in_dim if depth == 0 else hidden_dim
            self.convs.append(GCNConv(layer_in, hidden_dim, improved=True))
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        self.jump = JumpingKnowledge(mode, channels=hidden_dim, num_layers=num_layers)

        jk_width = hidden_dim * num_layers if mode == 'cat' else hidden_dim
        self.lin = nn.Linear(jk_width, out_dim)

        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()
        self.print_model_stats()

    def reset_parameters(self):
        for module in (*self.convs, *self.bns, self.jump, self.lin):
            module.reset_parameters()

    def forward(self, data, return_loss=False):
        """Return raw per-node logits."""
        edge_index = data.edge_index
        h = data.x
        per_layer = []

        for conv, bn in zip(self.convs, self.bns):
            h = self.dropout(F.relu(bn(conv(h, edge_index))))
            per_layer.append(h)

        return self.lin(self.jump(per_layer))

    def get_num_params(self):
        return sum(param.numel() for param in self.parameters())

    def print_model_stats(self):
        total = self.get_num_params()
        rule = "=" * 60
        print(rule)
        print(f"{'JKNet (Original Paper)':^60}")
        print(rule)
        print(f"Total Params: {total:,} | Model Size: {total * 4 / (1024 ** 2):.2f} MB")
        print(f"Layers: {self.num_layers} | Mode: {self.mode}")
        print(rule)


# =====================================================================
# 5️⃣ Implicit Graph Neural Networks (IGNN) - Paper: "Implicit Graph Neural Networks" (NeurIPS 2020)
# =====================================================================
class IterativeAlgoGNN(MessagePassing):
    """
    Weight-tied iterative GNN inspired by IGNN (Implicit Graph Neural Networks).

    Node features are projected (``W_in`` -> ReLU -> dropout) to ``x``.
    Starting from ``z = x``, the update ``z <- W(relu(z)) + A z + x`` is then
    repeated up to ``num_layers`` times with shared weights. ``A z`` is the
    plain sum of neighbour states (``aggr='add'``, no degree normalisation).
    ``z`` is clamped to [-10, 10] after each step, and the loop stops early
    once the relative change drops below 1e-4. Output: ReLU -> dropout ->
    ``W_out`` (raw logits).

    NOTE(review): no Anderson acceleration is implemented, and convergence to
    a fixed point is not guaranteed. Scaling ``W`` by ``kappa / 2`` does not
    bound its spectral norm, and the unnormalised sum aggregation is not
    contractive in general; the clamp is what keeps ``z`` bounded.
    """

    def __init__(self, dataset, hidden_dim=64, num_layers=10, dropout=0.5,
                 kappa=0.9, **kwargs):
        super().__init__(aggr='add')
        in_dim, out_dim = dataset.num_features, dataset.num_classes
        self.num_layers = num_layers  # maximum number of update iterations
        self.kappa = kappa  # W is rescaled by kappa / 2 (see _apply_spectral_norm)

        # Input transformation
        self.W_in = nn.Linear(in_dim, hidden_dim, bias=True)

        # Update weight W (shared across iterations)
        self.W = nn.Linear(hidden_dim, hidden_dim, bias=True)

        # Output transformation
        self.W_out = nn.Linear(hidden_dim, out_dim, bias=True)

        self.dropout = nn.Dropout(dropout)

        # Rescale W. This is superseded by reset_parameters() below, which
        # re-initialises W and rescales it again.
        self._apply_spectral_norm()

        self.reset_parameters()
        self.print_model_stats()

    def _apply_spectral_norm(self):
        """Rescale ``W.weight`` by ``kappa / 2``.

        Despite the name, this is not spectral normalisation. It multiplies
        the current weights by a constant and neither computes nor bounds
        their spectral norm.
        """
        with torch.no_grad():
            self.W.weight.data = self.W.weight.data * (self.kappa / 2.0)

    def reset_parameters(self):
        """Glorot-initialise the three weight matrices, zero the biases, then rescale W."""
        glorot(self.W_in.weight)
        glorot(self.W.weight)
        glorot(self.W_out.weight)
        zeros(self.W_in.bias)
        zeros(self.W.bias)
        zeros(self.W_out.bias)
        self._apply_spectral_norm()

    def forward(self, data, return_loss=False):
        """Return raw per-node logits for ``data`` (``x``, ``edge_index``)."""
        x, edge_index = data.x, data.edge_index

        # Input transformation
        x = self.W_in(x)
        x = F.relu(x)
        x = self.dropout(x)

        # Iterate the weight-tied update
        z = x.clone()
        for _ in range(self.num_layers):
            z_prev = z

            # Update: z = W(relu(z)) + A z + x, where A z is the neighbour-state sum
            agg = self.propagate(edge_index, x=z)
            z = self.W(F.relu(z)) + agg + x
            z = torch.clamp(z, -10, 10)  # keeps z bounded (no contraction guarantee)

            # Early stop on small relative change. The Python `if` on a tensor
            # forces a device sync each iteration.
            if torch.norm(z - z_prev) / (torch.norm(z_prev) + 1e-8) < 1e-4:
                break

        # Output transformation
        z = F.relu(z)
        z = self.dropout(z)
        z = self.W_out(z)

        return z

    def message(self, x_j):
        # Each edge forwards the neighbour's current state unchanged.
        return x_j

    def get_num_params(self):
        return sum(p.numel() for p in self.parameters())

    def print_model_stats(self):
        total = self.get_num_params()
        print("=" * 60)
        print(f"{'IGNN (Implicit GNN)':^60}")
        print("=" * 60)
        print(f"Total Params: {total:,} | Model Size: {total * 4 / (1024 ** 2):.2f} MB")
        print(f"Iterations: {self.num_layers} | Kappa: {self.kappa}")
        print("=" * 60)


# =====================================================================
# 6️⃣ Graph Transformer - Paper: "A Generalization of Transformer Networks to Graphs" (AAAI 2021)
# =====================================================================
class EnergyTransformerNet(nn.Module):
    """
    Pre-LayerNorm graph transformer built on PyG ``TransformerConv``.

    Pipeline:
      - ``input_proj``: Linear(num_features -> hidden_dim), then dropout.
      - ``num_layers`` blocks, each:
          * x = x + dropout(TransformerConv(LayerNorm(x)))
            (multi-head attention over graph edges; ``beta=True`` enables
            PyG's learned gate between aggregated messages and the root skip)
          * x = x + FFN(LayerNorm(x)), where FFN is
            Linear(h, 4h) -> GELU -> Dropout -> Linear(4h, h) -> Dropout
      - ``output_proj``: Linear(hidden_dim -> num_classes), no final norm.

    No positional encoding (e.g. Laplacian eigenvectors) is computed. The
    model reads only ``data.x``, ``data.edge_index`` and, when constructed
    with ``edge_dim``, ``data.edge_attr``.
    """

    def __init__(self, dataset, hidden_dim=64, num_layers=4, heads=8, dropout=0.5,
                 edge_dim=None, **kwargs):
        """
        Args:
            dataset: object exposing ``num_features`` and ``num_classes``.
            hidden_dim: model width; must be divisible by ``heads`` because
                each head emits ``hidden_dim // heads`` channels and the heads
                are concatenated back to ``hidden_dim`` for the residual add.
            num_layers: number of attention + FFN blocks.
            heads: number of attention heads per ``TransformerConv``.
            dropout: used for attention dropout, the FFN and the residual
                branches.
            edge_dim: edge-feature size; when set, ``data.edge_attr`` is fed
                to every ``TransformerConv``.
            **kwargs: accepted for factory compatibility and ignored.

        Raises:
            ValueError: if ``hidden_dim`` is not divisible by ``heads``.
        """
        super(EnergyTransformerNet, self).__init__()
        in_dim, out_dim = dataset.num_features, dataset.num_classes
        if hidden_dim % heads != 0:
            # Otherwise the concatenated heads are narrower than hidden_dim and
            # the residual ``x + conv(...)`` fails with a shape mismatch.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by heads ({heads})")
        self.num_layers = num_layers
        self.edge_dim = edge_dim

        # Input projection
        self.input_proj = nn.Linear(in_dim, hidden_dim)

        # Transformer layers
        self.transformer_convs = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.norm1s = nn.ModuleList()
        self.norm2s = nn.ModuleList()

        for _ in range(num_layers):
            # Multi-head graph attention
            self.transformer_convs.append(
                TransformerConv(hidden_dim, hidden_dim // heads, heads=heads,
                                dropout=dropout, edge_dim=edge_dim, beta=True,
                                concat=True, bias=True)
            )

            # Layer normalization (pre-norm for attention and FFN sub-blocks)
            self.norm1s.append(nn.LayerNorm(hidden_dim))
            self.norm2s.append(nn.LayerNorm(hidden_dim))

            # Position-wise feedforward network (as in original Transformer)
            self.ffns.append(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 4),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim * 4, hidden_dim),
                nn.Dropout(dropout)
            ))

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, out_dim)

        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()
        self.print_model_stats()

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.input_proj.reset_parameters()
        self.output_proj.reset_parameters()
        for conv in self.transformer_convs:
            conv.reset_parameters()
        # Plain unpacking: no temporary ModuleList needs to be built.
        for norm in (*self.norm1s, *self.norm2s):
            norm.reset_parameters()
        for ffn in self.ffns:
            for layer in ffn:
                # GELU / Dropout have no parameters to reset.
                if hasattr(layer, 'reset_parameters'):
                    layer.reset_parameters()

    def forward(self, data, return_loss=False):
        """Return per-node logits of shape (num_nodes, num_classes).

        ``return_loss`` is accepted for interface parity and is unused.
        """
        x, edge_index = data.x, data.edge_index
        # The convs only have an edge projection when built with edge_dim;
        # in that case they require edge features, otherwise pass None.
        edge_attr = getattr(data, 'edge_attr', None) if self.edge_dim is not None else None

        # Input projection
        x = self.input_proj(x)
        x = self.dropout(x)

        # Transformer layers with Pre-LN (pre-normalization)
        for conv, ffn, norm1, norm2 in zip(self.transformer_convs, self.ffns,
                                           self.norm1s, self.norm2s):
            # Multi-head attention with residual
            x_norm = norm1(x)
            x = x + self.dropout(conv(x_norm, edge_index, edge_attr))

            # Feedforward with residual
            x_norm = norm2(x)
            x = x + ffn(x_norm)

        # Output projection
        x = self.output_proj(x)

        return x

    def get_num_params(self):
        """Return the total count of scalar entries over all parameters."""
        return sum(p.numel() for p in self.parameters())

    def print_model_stats(self):
        """Print parameter count, size at 4 bytes/param and depth."""
        total = self.get_num_params()
        print("=" * 60)
        print(f"{'Graph Transformer':^60}")
        print("=" * 60)
        print(f"Total Params: {total:,} | Model Size: {total * 4 / (1024 ** 2):.2f} MB")
        print(f"Layers: {self.num_layers}")
        print("=" * 60)


# =====================================================================
# 7️⃣ Simple Graph Convolution (SGC) - Paper: "Simplifying Graph Convolutional Networks" (ICML 2019)
# =====================================================================
class SimpleGCNNet(nn.Module):
    """
    Simple Graph Convolution (SGC) baseline, after "Simplifying Graph
    Convolutional Networks" (Wu et al., ICML 2019).

    The whole model is one ``SGConv``: K hops of feature propagation (with
    self-loops) followed by a single linear map straight to class logits.
    There are no hidden layers and no nonlinearities. The only extra piece is
    an optional dropout on the input features, which is off by default.
    """

    def __init__(self, dataset, K=2, cached=True, dropout=0.0, **kwargs):
        """
        Args:
            dataset: object exposing ``num_features`` and ``num_classes``.
            K: number of propagation hops inside ``SGConv``.
            cached: forwarded to ``SGConv``. Per the PyG docs, the propagated
                features from the first call are cached and reused, which is
                only valid for a fixed graph and fixed input features.
            dropout: input-feature dropout probability.
            **kwargs: accepted for factory compatibility and ignored.
        """
        super(SimpleGCNNet, self).__init__()
        in_dim, out_dim = dataset.num_features, dataset.num_classes
        self.K = K

        if cached and dropout > 0:
            # SGConv would cache features propagated from the *dropped-out*
            # input of its first call and reuse that single random mask for
            # every later call (eval included), so input dropout has no
            # regularising effect and skews evaluation.
            import warnings
            warnings.warn(
                "SimpleGCNNet: cached=True with dropout > 0 freezes the first "
                "dropout mask into SGConv's cache; use cached=False or dropout=0.",
                stacklevel=2,
            )

        # Single linear layer (as per SGC paper - no hidden layers!)
        self.conv = SGConv(in_dim, out_dim, K=K, cached=cached,
                           add_self_loops=True, bias=True)

        # Minimal dropout (paper uses very little or no dropout)
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()
        self.print_model_stats()

    def reset_parameters(self):
        """Re-initialise the SGConv layer."""
        self.conv.reset_parameters()

    def forward(self, data, return_loss=False):
        """Return per-node logits; ``return_loss`` is accepted and unused."""
        x, edge_index = data.x, data.edge_index

        # Optional input dropout
        x = self.dropout(x)

        # Single SGC layer (does K-hop propagation internally)
        x = self.conv(x, edge_index)

        return x

    def get_num_params(self):
        """Return the total count of scalar entries over all parameters."""
        return sum(p.numel() for p in self.parameters())

    def print_model_stats(self):
        """Print parameter count, size at 4 bytes/param and hop count K."""
        total = self.get_num_params()
        print("=" * 60)
        print(f"{'SimpleGCN (SGC)':^60}")
        print("=" * 60)
        print(f"Total Params: {total:,} | Model Size: {total * 4 / (1024 ** 2):.2f} MB")
        print(f"K-hop Propagation: {self.K}")
        print("=" * 60)





def get_model(model_name, dataset, k, alpha=0.5, residual=0.5,
              num_layers=2,
              forward_sampling='gumbel_softmax',
              backward_sampling='sigmoid',
              hidden_dim=128,
              num_blocks=3,
              layers_per_block=2,
              dropout=0.2,
              **kwargs):
    """Build a model instance from its registry name.

    Args:
        model_name: key selecting the architecture (see ``builders`` below).
        dataset: object exposing ``num_features`` / ``num_classes``; some
            constructors receive the dataset itself.
        k: forwarded as ``k_steps`` to ``RBMConvNet``, ``enhanced_rbm`` and
            ``deep_rbm``.
        alpha, residual: forwarded to ``RBMConvNet``; ``None`` falls back to 0.5.
        num_layers: used by ``RBMConvNet`` and ``enhanced_rbm``.
        forward_sampling, backward_sampling: used by ``RBMConvNet`` only.
        hidden_dim, dropout: used by ``enhanced_rbm`` and ``deep_rbm`` only.
        num_blocks, layers_per_block: used by ``deep_rbm`` only.
        **kwargs: forwarded to ``RBMConvNet`` only.

    All other entries use fixed, hard-coded configurations and ignore the
    tuning arguments above.

    Raises:
        ValueError: if ``model_name`` is not a known key.
    """
    # Callers may pass None explicitly; fall back to the documented defaults.
    if alpha is None:
        debug("Using default alpha value of 0.5")
        alpha = 0.5
    if residual is None:
        debug("Using default residual value of 0.5")
        residual = 0.5

    def dims():
        # Read lazily so only the selected builder touches ``dataset``.
        return {'num_features': dataset.num_features,
                'num_classes': dataset.num_classes}

    # Each entry is a zero-argument builder, so nothing is constructed
    # until the requested name has been resolved.
    builders = {
        'RBMConvNet': lambda: RBMConvNet(
            dataset, k_steps=k, alpha=alpha, residual=residual,
            num_layers=num_layers,
            forward_sampling=forward_sampling,
            backward_sampling=backward_sampling,
            **kwargs),
        'enhanced_rbm': lambda: EnhancedRBMConvNet(
            **dims(), hidden_dim=hidden_dim, num_layers=num_layers,
            dropout=dropout, k_steps=k),
        'deep_rbm': lambda: DeepRBMConvNet(
            **dims(), hidden_dim=hidden_dim, num_blocks=num_blocks,
            k_steps=k, layers_per_block=layers_per_block, dropout=dropout),
        # Best balance: Enhanced Deep RBM
        'enhanced': lambda: EnhancedDeepRBMConvNet(
            **dims(), hidden_dim=256, num_blocks=4, layers_per_block=2,
            dropout=0.5, adaptive_temperature=True, use_gates=True),
        # For more capacity: Ultra Deep RBM
        'ultra_deep': lambda: UltraDeepRBMConvNet(
            **dims(), hidden_dim=256, num_blocks=6, layers_per_block=2,
            dropout=0.5, layer_scale_init=1e-4),
        # Gradient checkpointing enabled to save memory
        'deep': lambda: MemoryOptimizedDeepRBM(
            **dims(), hidden_dim=128, num_blocks=4, layers_per_block=2,
            dropout=0.5, use_checkpoint=True, adaptive_temperature=True),
        # Wider but shallower
        'wide': lambda: EnhancedDeepRBMConvNet(
            **dims(), hidden_dim=512, num_blocks=3, layers_per_block=2,
            dropout=0.6, adaptive_temperature=True, use_gates=True),
        # Baselines with fixed configurations
        'GCNNet': lambda: GCNNet(dataset, hidden_dim=64, num_layers=2, dropout=0.6),
        'GATNet': lambda: GATNet(dataset, hidden_dim=8, num_heads=8,
                                 num_layers=2, dropout=0.6),
        'APPNPNet': lambda: APPNPNet(dataset, hidden_dim=64, K=10, alpha=0.1,
                                     dropout=0.6),
        'JKNet': lambda: JKNet(dataset, hidden_dim=64, num_layers=2, dropout=0.6),
        'SimpleGCNNet': lambda: SimpleGCNNet(dataset, K=2),
        'IterativeAlgoGNN': lambda: IterativeAlgoGNN(dataset, hidden_dim=64,
                                                     num_layers=2, dropout=0.6),
        'EnergyTransformerNet': lambda: EnergyTransformerNet(dataset, hidden_dim=64,
                                                             num_layers=2, dropout=0.6),
    }

    build = builders.get(model_name)
    if build is None:
        raise ValueError(f"Unknown model name: {model_name}")
    return build()
