from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from greatx.nn.layers import activations
# Assumes greatx.utils.wrapper is available; if it is not, remove or replace this import and the `@wrapper` decorator that uses it.
from greatx.utils import wrapper


class _Propagate(MessagePassing):
    """One preconditioned propagation step (sum aggregation).

    With ``P = lam * deg + (1 - lam)`` taken per node, ``forward`` returns
    ``(1 - alp) * y + alp * lam * P^-1/2 A P^-1/2 y + alp * P^-1 x_residual``
    where ``A y`` is the (optionally edge-weighted) sum computed by
    ``propagate``.
    """

    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, y, x_residual, edge_index, edge_weight, alp, lam):
        """Return the updated node features.

        ``y`` is the feature matrix being propagated (``y.size(0)`` is used
        as the node count), ``x_residual`` is the residual term, and
        ``edge_index`` must unpack into exactly two index rows.
        ``edge_weight`` may be ``None`` (see ``message``).
        """
        # Degree is an edge count over the second row of ``edge_index``.
        # NOTE(review): ``edge_weight`` does not enter the degree -- confirm
        # this is intended when attention-derived weights are in use.
        _, targets = edge_index
        node_deg = degree(targets, y.size(0), dtype=y.dtype)
        precond = lam * node_deg + (1 - lam)

        # A zero preconditioner gives inf under a negative power; those
        # entries are zeroed instead.
        # NOTE(review): a negative base (e.g. lam > 1 at a degree-0 node)
        # would give NaN, which this mask does not catch.
        inv_sqrt = precond.pow(-0.5)
        inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0.0)
        inv = precond.pow(-1)
        inv.masked_fill_(inv == float('inf'), 0.0)

        left_scale = inv_sqrt.view(-1, 1)
        smoothed = self.propagate(
            edge_index, x=left_scale * y, edge_weight=edge_weight)
        smoothed = left_scale * smoothed

        residual = inv.view(-1, 1) * x_residual

        # All three terms share the shape of ``y``.
        return (1 - alp) * y + alp * lam * smoothed + alp * residual

    def message(self, x_j, edge_weight):
        """Per-edge message: the neighbour feature, scaled by the edge weight if given."""
        if edge_weight is None:
            return x_j
        return x_j * edge_weight.view(-1, 1)


class _PropagateNoPrecond(MessagePassing):
    """One propagation step without preconditioning (sum aggregation).

    ``forward`` returns
    ``(1 - alp * lam - alp) * y + alp * lam * D^-1/2 A D^-1/2 y + alp * x_residual``
    where ``A y`` is the (optionally edge-weighted) sum computed by
    ``propagate`` and ``D`` is the per-node degree.
    """

    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, y, x_residual, edge_index, edge_weight, alp, lam):
        """Return the updated node features.

        ``y.size(0)`` is used as the node count and ``edge_index`` must
        unpack into exactly two index rows. ``edge_weight`` may be ``None``.
        """
        # Degree is an edge count over the second row of ``edge_index``.
        # NOTE(review): ``edge_weight`` does not enter the degree -- confirm
        # this is intended when attention-derived weights are in use.
        _, targets = edge_index
        node_deg = degree(targets, y.size(0), dtype=y.dtype)

        # Degree-0 nodes give inf under the negative power; zero them.
        inv_sqrt = node_deg.pow(-0.5)
        inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0.0)
        scale = inv_sqrt.view(-1, 1)

        aggregated = self.propagate(
            edge_index, x=y * scale, edge_weight=edge_weight)
        smoothed = aggregated * scale

        # All three terms share the shape of ``y``.
        self_coeff = 1 - alp * lam - alp
        return self_coeff * y + alp * lam * smoothed + alp * x_residual

    def message(self, x_j, edge_weight):
        """Per-edge message: the neighbour feature, scaled by the edge weight if given."""
        if edge_weight is None:
            return x_j
        return x_j * edge_weight.view(-1, 1)


class _Attention(nn.Module):
    """Turn per-edge feature distances into edge weights.

    For every edge the squared Euclidean distance between its two endpoint
    features is raised to ``1 - p / 2``, clamped from below at ``tau``,
    replaced by infinity above ``T`` (only when ``T > 0``), and inverted.
    Edges beyond ``T`` therefore receive weight 0.
    """

    def __init__(self, tau, T, p, attn_dropout=0.0):
        super().__init__()
        self.tau, self.T, self.p = tau, T, p
        self.attn_dropout = attn_dropout

    def forward(self, y, edge_index, etas=None):
        """Return one weight per edge (1-D, length ``edge_index.size(1)``).

        ``etas``, when given, rescales the feature columns of ``y`` before
        distances are measured; it is reshaped to ``(1, -1)`` to broadcast
        over the rows of ``y``.
        """
        feats = y if etas is None else y * etas.view(1, -1)

        src, dst = edge_index
        sq_dist = (feats[src] - feats[dst]).pow(2).sum(dim=-1)

        # relu + epsilon keeps the base strictly positive for the power.
        scores = (F.relu(sq_dist) + 1e-7).pow(1 - 0.5 * self.p)

        # Lower clamp at tau.
        scores = scores.masked_fill(scores < self.tau, self.tau)

        # Anything above T is pushed to infinity so its inverse is 0.
        if self.T > 0:
            scores = scores.masked_fill(scores > self.T, float("inf"))

        weights = 1.0 / (scores + 1e-9)

        if self.attn_dropout > 0:
            weights = F.dropout(
                weights, self.attn_dropout, training=self.training)
        return weights


class _UnfoldingAndAttention(nn.Module):
    """Stack of propagation layers with optional attention reweighting.

    Edge weights can be recomputed by an attention layer before the first
    propagation step (``init_att``) and once more right after the
    ``attn_aft``-th step (when ``attn_aft >= 0``).
    """

    def __init__(self, d, alp, lam, prop_step, attn_aft, tau, T, p, use_eta,
                 init_att, attn_dropout, precond):
        super().__init__()
        # A non-positive ``alp`` selects the automatic step size 1 / (lam + 1).
        if alp > 0:
            self.alp = alp
        else:
            self.alp = 1 / (lam + 1)
        self.lam = lam
        self.prop_step = prop_step
        self.attn_aft = attn_aft
        self.use_eta = use_eta
        self.init_att = init_att

        layer_cls = _Propagate if precond else _PropagateNoPrecond
        self.prop_layers = nn.ModuleList(
            [layer_cls() for _ in range(prop_step)])

        if self.init_att:
            self.init_attn_layer = _Attention(tau, T, p, attn_dropout)
        else:
            self.init_attn_layer = None

        if self.attn_aft >= 0:
            self.attn_layer = _Attention(tau, T, p, attn_dropout)
        else:
            self.attn_layer = None

        # One learnable scale per feature column (``d`` of them), or None.
        if self.use_eta:
            self.etas = nn.Parameter(torch.ones(d))
        else:
            self.etas = None

    def reset_parameters(self):
        """Restore ``etas`` to all ones when it exists."""
        if self.etas is None:
            return
        nn.init.ones_(self.etas)

    def forward(self, y, x_residual, edge_index, edge_weight):
        """Run every propagation layer and return the final features.

        ``y`` is the feature being propagated; ``x_residual`` is handed
        unchanged to every layer as its residual input.
        """
        if self.init_attn_layer is not None:
            edge_weight = self.init_attn_layer(y, edge_index, self.etas)

        for step, layer in enumerate(self.prop_layers, start=1):
            y = layer(y, x_residual, edge_index, edge_weight,
                      self.alp, self.lam)

            # Refresh the edge weights once, right after step ``attn_aft``.
            if step == self.attn_aft and self.attn_layer is not None:
                edge_weight = self.attn_layer(y, edge_index, self.etas)
        return y


class _MLP(nn.Module):
    """
    Stack of linear layers with optional BatchNorm, activation and dropout.

    Between consecutive linear layers the order is norm -> activation ->
    dropout; nothing follows the last linear layer. With ``init_activate``
    the same norm/activation/dropout step is also applied to the input before
    the first linear layer. ``num_layers == 0`` makes the module an identity.
    """
    def __init__(self, in_channels, out_channels, hidden_channels, num_layers,
                 act, dropout, bn, bias, init_activate):
        super().__init__()

        self.init_activate = init_activate
        self.num_layers = num_layers
        self.dropout_rate = dropout
        self.act_fn = activations.get(act)
        self.bn = bn

        # Feature widths along the stack; consecutive pairs become layers.
        if num_layers == 0:
            widths = []
        elif num_layers == 1:
            widths = [in_channels, out_channels]
        else:
            # Any other count gets at least one hidden width, which
            # reproduces the first/intermediate/output layer layout.
            n_hidden = max(num_layers - 1, 1)
            widths = [in_channels] + [hidden_channels] * n_hidden \
                + [out_channels]

        self.linear_layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths, widths[1:]):
            self.linear_layers.append(nn.Linear(fan_in, fan_out, bias=bias))

        # Norm layout: one norm on the raw input when ``init_activate`` is
        # set (and there is at least one layer), then one norm on
        # ``hidden_channels`` after every linear layer except the last.
        self.norm_layers = nn.ModuleList()
        if bn:
            if init_activate and num_layers > 0:
                self.norm_layers.append(nn.BatchNorm1d(in_channels))
            for _ in range(num_layers - 1):
                self.norm_layers.append(nn.BatchNorm1d(hidden_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise linear weights (Xavier normal), zero biases, reset norms."""
        for linear in self.linear_layers:
            if not isinstance(linear, nn.Linear):
                continue
            nn.init.xavier_normal_(linear.weight.data)
            if linear.bias is not None:
                nn.init.constant_(linear.bias.data, 0)
        for norm in self.norm_layers:
            if isinstance(norm, (nn.BatchNorm1d, nn.LayerNorm)):
                norm.reset_parameters()

    def _apply_activation_dropout_and_norm(self, x, norm_idx):
        """Apply norm ``norm_idx`` (if enabled and present), activation, then dropout."""
        use_norm = self.bn and norm_idx < len(self.norm_layers)
        if use_norm:
            x = self.norm_layers[norm_idx](x)
        x = self.act_fn(x)
        if self.dropout_rate > 0:
            x = F.dropout(x, self.dropout_rate, training=self.training)
        return x

    def forward(self, x):
        """Run the stack on ``x``; returns ``x`` untouched when ``num_layers == 0``."""
        if self.num_layers == 0:
            return x

        norm_idx = 0
        if self.init_activate:
            x = self._apply_activation_dropout_and_norm(x, norm_idx)
            norm_idx += 1

        n_linear = len(self.linear_layers)
        for position, linear in enumerate(self.linear_layers, start=1):
            x = linear(x)
            if position == n_linear:
                # The output of the final linear layer is returned raw.
                break
            x = self._apply_activation_dropout_and_norm(x, norm_idx)
            norm_idx += 1
        return x


class TWIRLS(nn.Module):
    r"""TWIRLS model: optional MLP, unfolded graph propagation, optional MLP.

    ``forward`` applies, in order: input dropout, ``mlp_bef``, the
    ``_UnfoldingAndAttention`` block (``prop_step`` propagation layers with
    optional attention-based edge reweighting) and ``mlp_aft``.

    The output has ``out_channels`` columns whenever at least one of the two
    MLPs has layers: ``mlp_aft`` produces them when ``num_post_layers > 0``,
    otherwise ``mlp_bef`` does. When both layer counts are 0 the model has no
    linear layer and the output keeps ``in_channels`` columns.

    Parameters
    ----------
    in_channels : int
        Width of the input node features.
    out_channels : int
        Width of the model output (see above).
    hidden_channels : int
        Width of the hidden layers of both MLPs, and of the propagated
        features when both MLPs have layers.
    num_pre_layers : int
        Number of linear layers applied before propagation (0 skips them).
    num_post_layers : int
        Number of linear layers applied after propagation (0 skips them).
    prop_step : int
        Number of propagation layers.
    act : str
        Activation name handed to ``_MLP``.
    dropout : float
        Dropout rate inside both MLPs.
    input_dropout : float
        Dropout rate applied to the raw input features.
    attn_dropout : float
        Dropout rate applied to attention edge weights.
    bias : bool
        Whether the linear layers have a bias.
    bn : bool
        Whether the MLPs use batch normalisation.
    precond : bool
        Selects ``_Propagate`` (True) or ``_PropagateNoPrecond`` (False).
    attention : bool
        If True, edge weights are recomputed by attention after
        ``prop_step // 2`` propagation layers.
    alp : float
        Propagation step size; a non-positive value selects
        ``1 / (lam + 1)``.
    lam : float
        Weight of the propagation term.
    tau, T, p : float
        Attention lower clamp, upper cut-off (disabled when ``T <= 0``) and
        exponent parameter (distances are raised to ``1 - p / 2``).
    use_eta : bool
        If True, attention uses a learnable per-feature scaling.
    attn_bef : bool
        If True, attention is also applied before the first propagation
        layer.
    """
    @wrapper
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        num_pre_layers: int = 1,
        num_post_layers: int = 0,
        prop_step: int = 32,
        act: str = 'relu',
        dropout: float = 0.5,
        input_dropout: float = 0.0,
        attn_dropout: float = 0.0,
        bias: bool = True,
        bn: bool = False,
        precond: bool = True,
        attention: bool = True,
        alp: float = 1.0,
        lam: float = 1.0,
        tau: float = 0.2,
        T: float = 2.0,
        p: float = 0.1,
        use_eta: bool = False,
        attn_bef: bool = False,
    ):
        super().__init__()

        self.input_dropout = input_dropout
        self.attention = attention
        self.num_pre_layers = num_pre_layers

        # Width of the features that enter and leave the unfolding block.
        # Without a pre-MLP the raw input is propagated. Without a post-MLP
        # the pre-MLP is the only place left to reach ``out_channels``, so
        # it must emit that width itself (previously it always emitted
        # ``hidden_channels`` and the model returned the wrong width).
        if num_pre_layers == 0:
            unfolding_io_channels = in_channels
        elif num_post_layers == 0:
            unfolding_io_channels = out_channels
        else:
            unfolding_io_channels = hidden_channels

        # The unfolded result may be reused in eval mode only if nothing
        # trainable or stochastic feeds it: no pre-MLP, no input dropout, no
        # mid-propagation attention, and no initial attention that reads the
        # learnable ``etas``.
        self.cacheable = (
            (not self.attention)
            and num_pre_layers == 0
            and self.input_dropout <= 0
            and not (attn_bef and use_eta)
        )
        self.cached_unfolding = None

        # Pre-propagation MLP; skipped in ``forward`` when it has no layers.
        self.mlp_bef = _MLP(in_channels, unfolding_io_channels,
                            hidden_channels, num_pre_layers, act, dropout,
                            bn, bias, init_activate=False)

        self.unfolding = _UnfoldingAndAttention(
            d=unfolding_io_channels,  # must match the propagated width
            alp=alp,
            lam=lam,
            prop_step=prop_step,
            attn_aft=prop_step // 2 if attention else -1,
            tau=tau,
            T=T,
            p=p,
            use_eta=use_eta,
            init_att=attn_bef,
            attn_dropout=attn_dropout,
            precond=precond)

        # Post-propagation MLP. It starts with an activation step only when
        # a pre-MLP exists and this MLP itself has layers.
        self.mlp_aft = _MLP(unfolding_io_channels, out_channels,
                            hidden_channels, num_post_layers, act, dropout,
                            bn, bias,
                            init_activate=(num_pre_layers > 0)
                            and (num_post_layers > 0))

    def reset_parameters(self):
        """Reset all sub-modules and drop any cached propagation result."""
        self.mlp_bef.reset_parameters()
        self.unfolding.reset_parameters()
        self.mlp_aft.reset_parameters()
        self.cached_unfolding = None

    def forward(self, x, edge_index, edge_weight=None):
        """Compute the model output for node features ``x``.

        ``edge_weight`` defaults to all ones, one per column of
        ``edge_index``. In eval mode, when ``self.cacheable`` is set, the
        propagation result of the first call is stored and reused by later
        calls regardless of their inputs.
        """
        use_cache = self.cacheable and not self.training

        if use_cache and self.cached_unfolding is not None:
            unfolded = self.cached_unfolding
        else:
            if edge_weight is None:
                # Follow the feature dtype so half/double inputs are not
                # silently promoted or mixed with float32 weights.
                weight_dtype = (x.dtype if x.is_floating_point()
                                else torch.get_default_dtype())
                edge_weight = torch.ones(edge_index.size(1),
                                         dtype=weight_dtype,
                                         device=edge_index.device)

            if self.input_dropout > 0:
                x = F.dropout(x, self.input_dropout, training=self.training)

            if self.num_pre_layers > 0:
                x = self.mlp_bef(x)

            # The same tensor is both the propagated signal and the residual.
            unfolded = self.unfolding(x, x, edge_index, edge_weight)
            if use_cache:
                self.cached_unfolding = unfolded

        return self.mlp_aft(unfolded)