import torch
import torch.nn as nn

from .utils import scatter_add


class Conv(nn.Module):
    """Basic convolution module for non-spherical features.

    For every edge, a message is built from the sender feature
    ``x[edge_index[0]]`` and an edge kernel ``self.kernel(kernel_basis)``.
    Messages are summed onto the receiver ``edge_index[1]``, or averaged over
    the receiver's incoming edges when ``avg_pool=True``. An optional
    per-output-channel bias is then added.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels. Must equal
            ``in_channels`` when ``depthwise=True``.
        kernel_dim: Size of the last dimension of ``kernel_basis``.
        bias: If True, add a learnable, zero-initialised bias of size
            ``out_channels``.
        depthwise: If True, the kernel scales each channel independently.
            Otherwise it is a full ``out_channels x in_channels`` matrix per
            edge.
        avg_pool: If True, divide the aggregated messages by each receiver's
            number of incoming edges.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_dim: int,
        bias: bool = True,
        depthwise: bool = True,
        avg_pool: bool = False,
    ):
        super().__init__()

        # Check arguments
        self.depthwise = depthwise
        if depthwise:
            assert in_channels == out_channels, "in_channels and out_channels must be equal for depthwise convolution"
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.avg_pool = avg_pool

        # Construct kernels: one value per channel (depthwise) or a flattened
        # out x in matrix per edge.
        self.kernel = nn.Linear(kernel_dim, in_channels if depthwise else out_channels * in_channels, bias=False)

        # Construct bias
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
            self.bias.data.zero_()
        else:
            self.register_parameter("bias", None)

        # Flag for the one-time data-dependent re-initialization done on the
        # first training forward pass. It is a buffer so it is saved in the
        # state_dict.
        self.register_buffer("initialized", torch.tensor(False))

    def forward(self, x, kernel_basis, edge_index, num_nodes=None):
        """Aggregate kernel-weighted sender features onto receiver nodes.

        Args:
            x: Node features whose rows are gathered with ``edge_index[0]``.
                In the non-depthwise path this must be ``(num_nodes, in_channels)``.
            kernel_basis: Per-edge basis fed to ``self.kernel``. Presumably
                ``(num_edges, kernel_dim)``; TODO confirm against callers.
            edge_index: Pair of integer index tensors ``(senders, receivers)``.
            num_nodes: Optional number of receiver rows in the output. Defaults
                to ``edge_index[1].max() + 1``. Pass it explicitly to keep
                trailing receivers that have no incoming edges and to avoid a
                host-device synchronisation.

        Returns:
            Tensor with ``num_nodes`` rows holding the aggregated messages (plus
            bias).
        """
        if self.training and not self.initialized:
            self._initialize(edge_index)

        senders, receivers = edge_index[0], edge_index[1]
        x_send = x[senders]
        kernel = self.kernel(kernel_basis)
        if self.depthwise:
            message = kernel * x_send
        else:
            kernel = kernel.unflatten(-1, (self.out_channels, self.in_channels))
            message = torch.einsum("boi,bi->bo", kernel, x_send)

        if num_nodes is None:
            num_nodes = int(receivers.max().item()) + 1
        out = scatter_add(src=message, index=receivers, dim_size=num_nodes)

        if self.avg_pool:
            ones = torch.ones_like(receivers, dtype=torch.float32)
            in_degree = scatter_add(src=ones, index=receivers, dim_size=num_nodes)
            # Receivers without incoming edges have a zero sum. Clamping keeps
            # them at zero instead of producing 0/0 = NaN. Casting keeps the
            # output in the feature dtype (e.g. fp16) rather than promoting it
            # to float32.
            in_degree = in_degree.clamp(min=1.0).to(out.dtype)
            out = out / in_degree.view(-1, *([1] * (out.dim() - 1)))

        if self.bias is not None:
            out = out + self.bias

        return out

    def _initialize(self, edge_index, kernel_var=1.0):
        """Redraw the kernel weights with a fan-in-scaled std (run once).

        The fan-in is the average number of incoming edges per receiver (1 when
        ``avg_pool`` already normalises by it), times ``in_channels`` for
        non-depthwise kernels. Weights are drawn from
        ``N(0, 1 / (fan_in * kernel_var))``.

        Args:
            edge_index: Pair ``(senders, receivers)`` of integer index tensors.
            kernel_var: Assumed variance of the kernel basis entries.
        """
        with torch.no_grad():
            if self.avg_pool:
                avg_num_neigh = 1.0
            else:
                receivers = edge_index[1]
                num_nodes = int(receivers.max().item()) + 1
                ones = torch.ones_like(receivers, dtype=torch.float32)
                in_degree = scatter_add(src=ones, index=receivers, dim_size=num_nodes)
                avg_num_neigh = float(in_degree.mean())
            fan_in = avg_num_neigh * (1 if self.depthwise else self.in_channels)
            std_W_target = (1.0 / (fan_in * kernel_var)) ** 0.5
            torch.nn.init.normal_(self.kernel.weight, mean=0.0, std=std_W_target)
            # Set (rather than toggle) the flag so a repeated call can't un-set it.
            self.initialized.fill_(True)
    
class PointwiseSphereConv(nn.Module):
    """Pointwise linear map that mixes orientations and channels.

    The kernel ``self.kernel(kernel_basis)`` is reshaped to
    ``(P, O, out_channels, in_channels)`` and applied as
    ``out[..., p, d] = sum_{o, c} x[..., o, c] * kernel[p, o, d, c]``.
    If enabled, a per-output-channel bias is added afterwards.

    Args:
        in_channels: Number of input channels (last axis of ``x``).
        out_channels: Number of output channels.
        kernel_dim: Size of the last axis of ``kernel_basis``.
        bias: If True, add a learnable, zero-initialised bias of size
            ``out_channels``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_dim: int,
            bias: bool = True,
            ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_dim = kernel_dim
        self.kernel = nn.Linear(kernel_dim, out_channels * in_channels, bias=False)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
            self.bias.data.zero_()
        else:
            self.register_parameter("bias", None)
        # One-time data-dependent re-initialization flag (saved in state_dict).
        self.register_buffer("initialized", torch.tensor(False))

    def forward(self, x, kernel_basis):
        """Apply the orientation/channel-mixing map.

        Args:
            x: Features of shape ``(..., O, in_channels)``.
            kernel_basis: Basis of shape ``(P, O, kernel_dim)``. The einsum
                below requires the kernel to be 4-D.

        Returns:
            Tensor of shape ``(..., P, out_channels)``.
        """
        if self.training and not self.initialized:
            self._initialize(kernel_basis)

        kernel = self.kernel(kernel_basis).unflatten(-1, (self.out_channels, self.in_channels))
        out = torch.einsum("...oc,podc->...pd", x, kernel)

        # Per-output-channel bias, broadcast over the leading and orientation axes.
        if self.bias is not None:
            out = out + self.bias

        return out

    def _initialize(self, sphere_kernel_basis):
        """Redraw kernel weights with a fan-in-scaled std (run once).

        Each output entry sums ``O * in_channels`` products. Each kernel entry
        is itself a sum of ``kernel_dim`` weight-times-basis terms. Assuming
        independent zero-mean terms, this gives a weight std of
        ``1 / sqrt(O * in_channels * kernel_dim * var(basis))``.
        """
        with torch.no_grad():
            # Axis 1 of the (P, O, kernel_dim) basis is ``o``, the orientation
            # axis contracted in ``forward``. Axis 0 is the output orientation
            # ``p``, which does not contribute to the fan-in.
            num_ori = sphere_kernel_basis.shape[1]
            kernel_var = float(sphere_kernel_basis.var())
            N_sum = num_ori * self.in_channels
            std_W_target = (1.0 / (N_sum * self.kernel_dim * kernel_var)) ** 0.5
            torch.nn.init.normal_(self.kernel.weight, mean=0.0, std=std_W_target)
            # Set (rather than toggle) the flag so a repeated call can't un-set it.
            self.initialized.fill_(True)

class ConvNext(nn.Module):
    """ConvNext block with separable convolution for both spatial and spherical features.

    Pipeline: depthwise ``Conv`` -> ``LayerNorm`` -> pointwise layer widened by
    ``widening_factor`` -> activation -> pointwise layer to ``out_channels``.
    When the result has the same shape as the input, it is scaled by
    ``layer_scale`` (if any) and added to the input as a residual.

    Args:
        in_channels: Number of input channels.
        kernel_dim: Size of the last axis of the kernel bases.
        out_channels: Number of output channels (defaults to ``in_channels``).
        act: Activation module. Defaults to a fresh ``nn.GELU()`` per block.
        layer_scale: Initial value of the learnable residual scale, or
            ``None`` to disable it.
        widening_factor: Hidden-width multiplier for the pointwise layers.
        avg_pool: Passed to ``Conv``; average instead of sum over neighbours.
        lifted: If True, the pointwise layers are ``PointwiseSphereConv`` and
            ``forward`` requires ``sphere_kernel_basis``. Otherwise they are
            ``nn.Linear``.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_dim: int,
        out_channels: int = None,
        act: nn.Module = None,
        layer_scale: float = 1e-6,
        widening_factor: int = 4,
        avg_pool: bool = False,
        lifted: bool = False,
    ):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels
        hidden_channels = widening_factor * in_channels
        self.lifted = lifted

        self.conv = Conv(in_channels, in_channels, kernel_dim, bias=True, depthwise=True, avg_pool=avg_pool)

        # Build a new activation per block instead of sharing one module
        # instance through a default argument.
        self.act_fn = nn.GELU() if act is None else act

        # Submodule attribute names are kept so existing state_dicts still load.
        if lifted:
            self.linear_1_module = PointwiseSphereConv(in_channels, hidden_channels, kernel_dim, bias=True)
            self.linear_2_module = PointwiseSphereConv(hidden_channels, out_channels, kernel_dim, bias=True)
        else:
            self.linear_1_module = nn.Linear(in_channels, hidden_channels, bias=True)
            self.linear_2_module = nn.Linear(hidden_channels, out_channels, bias=True)

        if layer_scale is not None:
            self.layer_scale = nn.Parameter(torch.ones(out_channels) * layer_scale)
        else:
            self.register_buffer("layer_scale", None)

        self.norm = nn.LayerNorm(in_channels)

    def _pointwise(self, module, x_input, fkb):
        """Apply one pointwise layer, passing the spherical basis only when lifted."""
        if not self.lifted:
            return module(x_input)
        if fkb is None:
            raise ValueError("sphere_kernel_basis is required when ConvNext is built with lifted=True")
        return module(x_input, fkb)

    # Methods (not lambdas stored on ``self``) so that copy.deepcopy / pickling
    # bind to the copied module's own layers.
    def linear_1(self, x_input, fkb=None):
        """First (widening) pointwise layer."""
        return self._pointwise(self.linear_1_module, x_input, fkb)

    def linear_2(self, x_input, fkb=None):
        """Second (projecting) pointwise layer."""
        return self._pointwise(self.linear_2_module, x_input, fkb)

    def forward(self, x, kernel_basis, edge_index, sphere_kernel_basis=None):
        """Forward pass of the ConvNext block.

        Args:
            x: Node features passed to ``Conv``.
            kernel_basis: Per-edge basis for the spatial ``Conv``.
            edge_index: Pair ``(senders, receivers)`` of integer index tensors.
            sphere_kernel_basis: Basis for the lifted pointwise layers. It is
                ignored when ``lifted=False``.

        Returns:
            The block output. It includes the residual when its shape matches
            ``x``.
        """
        residual = x

        x = self.conv(x, kernel_basis, edge_index)
        x = self.norm(x)
        x = self.linear_1(x, sphere_kernel_basis)
        x = self.act_fn(x)
        x = self.linear_2(x, sphere_kernel_basis)

        # The residual is only possible when the block preserves the shape.
        if x.shape == residual.shape:
            if self.layer_scale is not None:
                x = self.layer_scale * x
            x = x + residual

        return x

class PolynomialFeatures(nn.Module):
    """Concatenate repeated outer-product (tensor-power) features of the input.

    The output starts with ``x`` itself. Each of the ``degree`` further steps
    takes the outer product of the previous term with ``x`` along the last
    axis and flattens it. The result therefore contains terms of polynomial
    degree 1 through ``degree + 1``, with repeated (non-unique) monomials. Its
    last dimension is ``d + d**2 + ... + d**(degree + 1)`` for input width ``d``.
    """

    def __init__(self, degree: int):
        super().__init__()
        self.degree = degree

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the concatenated tensor-power features of ``x`` (last axis)."""
        features = [x]
        term = x
        for _ in range(self.degree):
            # Outer product over the last axis, broadcasting the leading axes.
            term = (term.unsqueeze(-1) * x.unsqueeze(-2)).flatten(-2, -1)
            features.append(term)
        return torch.cat(features, dim=-1)

class PolynomialCutoff(torch.nn.Module):
    """
    Smooth polynomial window from DimeNet that reaches zero at ``r_max``.

    With ``u = x / r_max`` the window is
    ``1 - (p+1)(p+2)/2 * u**p + p(p+2) * u**(p+1) - p(p+1)/2 * u**(p+2)``
    for ``x < r_max``, and ``0`` for ``x >= r_max``. When ``r_max`` is
    ``None`` there is no cutoff and the window is identically one.

    Reference:
    Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020.
    """
    def __init__(self, r_max, p=6):
        super().__init__()
        if r_max is None:
            # No cutoff: ``forward`` returns ones and never reads ``p``.
            self.r_max = None
            return
        dtype = torch.get_default_dtype()
        # Stored as buffers so they follow the module across devices and dtypes.
        self.register_buffer("p", torch.tensor(p, dtype=dtype))
        self.register_buffer("r_max", torch.tensor(r_max, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.r_max is None:
            return torch.ones_like(x)

        p = self.p
        u = x / self.r_max
        coef_p = (p + 1.0) * (p + 2.0) / 2.0
        coef_p1 = p * (p + 2.0)
        coef_p2 = p * (p + 1.0) / 2
        envelope = (
            1.0
            - coef_p * torch.pow(u, p)
            + coef_p1 * torch.pow(u, p + 1)
            - coef_p2 * torch.pow(u, p + 2)
        )
        # Zero out everything at or beyond the cutoff radius.
        inside = x < self.r_max
        return envelope * inside