from typing import Callable, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree
from torch_scatter import scatter_add


def pairwise_squared_euclidean(x: Tensor) -> Tensor:
    """Return the matrix of squared Euclidean distances between the rows of ``x``.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b``. Floating-point
    cancellation in that expansion can yield tiny negative values (e.g. for
    identical or nearly identical rows); they are clamped to zero so callers can
    safely take a square root.

    Args:
        x: 2-D tensor of shape [n, d].

    Returns:
        Tensor of shape [n, n] whose (i, j) entry is ``||x_i - x_j||^2``.
    """
    sq_norms = x.pow(2).sum(dim=1, keepdim=True)  # [n, 1]
    dist = sq_norms + sq_norms.T - 2.0 * (x @ x.T)
    return dist.clamp_min(0.0)


class RUNGConv(nn.Module):
    """RUNG-style robust feature propagation via iteratively reweighted smoothing.

    Every iteration first re-weights each edge ``(i, j)`` with
    ``w_func(||f_i / sqrt(d_i) - f_j / sqrt(d_j)||)``. The distances are
    detached, so no gradient flows through the weights' inputs. It then updates
    the node features with either a quasi-Newton fixed-point step or a plain
    gradient step. Both updates share the stationary condition

        sum_j W_ij A_ij (f_i / d_i - f_j / sqrt(d_i d_j)) + lam * (f_i - f0_i) = 0

    where ``A_ij`` is the edge weight, ``d_i`` is the weighted degree of node
    ``i`` (summed over ``edge_index[0]``), ``f0`` is the input feature and
    ``lam = 1 / lam_hat - 1``.

    Args:
        lam_hat: Maps to the fidelity weight ``lam``; must lie in ``(0, 1]``.
            ``lam_hat == 1`` gives ``lam == 0``, i.e. no pull back towards the input.
        w_func: Callable turning a 1-D tensor of edge distances into edge
            weights. Presumably elementwise (TODO confirm for custom weight
            functions). NaN outputs are replaced by 1 and self-loop weights are
            forced to 0.
        quasi_newton: If True use the fixed-point update, otherwise gradient descent.
        eta: Gradient-descent step size. Only used when ``quasi_newton`` is False
            (and then must be > 0).
        prop_step: Number of propagation iterations.
    """

    def __init__(
        self,
        lam_hat: float = 0.9,
        w_func: Optional[Callable[[Tensor], Tensor]] = None,
        quasi_newton: bool = True,
        eta: Optional[float] = 0.01,
        prop_step: int = 10
    ):
        super().__init__()
        # lam_hat == 0 would make lam = 1 / 0 - 1 fail, so the range is open at 0.
        assert 0 < lam_hat <= 1, 'lam_hat should be in (0, 1]'
        if quasi_newton:
            assert eta is None or eta == 0.01, 'eta is ignored when using quasi_newton'
        else:
            assert eta is not None and eta > 0, 'eta must be > 0 when not using quasi_newton'

        self.lam_hat = lam_hat
        self.lam = 1 / lam_hat - 1
        self.w_func = w_func
        self.quasi_newton = quasi_newton
        self.eta = eta
        self.prop_step = prop_step

    def reset_parameters(self):
        """Recompute the derived ``lam`` from ``lam_hat``.

        This module creates no parameters of its own, so nothing else needs resetting.
        """
        self.lam = 1 / self.lam_hat - 1

    @staticmethod
    def _scatter_sum(src: Tensor, index: Tensor, num_nodes: int) -> Tensor:
        """Sum the rows of ``src`` into ``num_nodes`` buckets selected by ``index``."""
        out = src.new_zeros((num_nodes, *src.shape[1:]))
        return out.index_add_(0, index, src)

    def _edge_weights(self, F_norm: Tensor, row: Tensor, col: Tensor,
                      self_loop: Tensor) -> Tensor:
        """Return one IRLS weight per edge from detached normalized-feature distances."""
        F_norm = F_norm.detach()
        # Per-edge differences are exact (never negative, unlike the norm
        # expansion) and cost O(E * d) instead of a dense O(N^2 * d) matrix.
        dist = (F_norm[row] - F_norm[col]).pow(2).sum(dim=-1).sqrt()
        W = self.w_func(dist)
        W = W.masked_fill(self_loop, 0)
        # Mirrors the dense formulation: undefined weights fall back to 1.
        return torch.where(torch.isnan(W), torch.ones_like(W), W)

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_weight: Optional[Tensor] = None) -> Tensor:
        """Propagate ``x`` for ``prop_step`` iterations.

        Args:
            x: Node features of shape [num_nodes, num_features].
            edge_index: Dense [2, num_edges] index tensor. It is unpacked into
                rows and columns, so sparse adjacency types are not supported.
            edge_weight: Optional [num_edges] edge weights; all ones when omitted.

        Returns:
            Propagated features with the same shape as ``x``.
        """
        row, col = edge_index
        num_nodes = x.size(0)
        if edge_weight is None:
            edge_weight = x.new_ones(row.numel())
        else:
            edge_weight = edge_weight.to(x.dtype).view(-1)

        # Weighted degree; the epsilon keeps isolated nodes finite under pow(-0.5).
        deg = self._scatter_sum(edge_weight, row, num_nodes) + 1e-12
        deg_inv_sqrt = deg.pow(-0.5)
        # Symmetric normalization A_ij / sqrt(d_i * d_j), matching the
        # D^{-1/2}-normalized features used for the distances.
        edge_norm = edge_weight * deg_inv_sqrt[row] * deg_inv_sqrt[col]
        deg_inv_sqrt = deg_inv_sqrt.view(-1, 1)
        self_loop = row == col

        F0 = x
        F = x
        for _ in range(self.prop_step):
            F_norm = F * deg_inv_sqrt
            W = self._edge_weights(F_norm, row, col, self_loop)

            if self.quasi_newton:
                # Diagonal of the quasi-Newton system: sum_j W_ij A_ij / d_i + lam.
                q = self._scatter_sum(W * edge_weight, row, num_nodes) / deg + self.lam
                q = q.view(-1, 1)
                agg = self._scatter_sum((W * edge_norm).view(-1, 1) * F[col], row, num_nodes)
                # q == 0 only when lam == 0 and a node has no non-zero-weight
                # neighbour. Such nodes keep their features instead of 0 / 0. The
                # safe denominator also keeps gradients finite through torch.where.
                has_mass = q != 0
                safe_q = torch.where(has_mass, q, torch.ones_like(q))
                F = torch.where(has_mass, (agg + self.lam * F0) / safe_q, F)
            else:
                # grad_i = sum_j W_ij A_ij (f_i / sqrt(d_i) - f_j / sqrt(d_j)) / sqrt(d_i)
                diff = F_norm[row] - F_norm[col]
                grad_smoothing = deg_inv_sqrt * self._scatter_sum(
                    (W * edge_weight).view(-1, 1) * diff, row, num_nodes
                )
                grad_reg = 2 * self.lam * (F - F0)
                F = F - self.eta * (2 * grad_smoothing + grad_reg)

        return F
