import math

import torch
import torch.nn as nn

from .rapidash_utils import scatter_add, scatter_softmax


class SeparableFiberBundleConv(nn.Module):
    """Separable fiber-bundle convolution: spatial message passing, then fiber mixing.

    Step 1 (spatial): every edge carries the source-node features
    ``x[edge_index[0]]`` (optionally concatenated with ``edge_attr`` and projected
    by ``pre_linear``), modulated channel-wise by ``kernel(kernel_basis)``,
    optionally weighted by dot-product attention, and summed onto the target
    nodes ``edge_index[1]``.

    Step 2 (fiber): the aggregated features are contracted over their dim -2
    with ``fiber_kernel(fiber_kernel_basis)`` and divided by the size of that
    dimension. In depth-wise mode each channel is mixed independently; in dense
    mode channels are also mixed ``in_channels -> out_channels``.

    On the first forward pass in training mode the two kernels are rescaled once
    from measured standard deviations (see ``callibrate``).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_dim: size of the last dim of ``kernel_basis`` / ``fiber_kernel_basis``.
        bias: if True, add a zero-initialised learnable bias of size ``out_channels``.
        groups: ``1`` (dense) or ``in_channels == out_channels`` (depth-wise).
        attention: if True, weight messages with scaled dot-product attention.
        edge_attr_dim: size of the ``edge_attr`` features concatenated before
            ``pre_linear``.

    Raises:
        ValueError: if ``groups`` is neither 1 nor equal to both channel counts.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_dim,
        bias=True,
        groups=1,
        attention=False,
        edge_attr_dim=0,
    ):
        super().__init__()

        # Only dense (groups=1) and depth-wise (groups=in=out) variants exist.
        if groups == 1:
            self.depthwise = False
        elif groups == in_channels and groups == out_channels:
            self.depthwise = True
        else:
            raise ValueError(
                "Invalid option for groups, should be groups=1 or groups=in_channels=out_channels (depth-wise separable)"
            )
        # The dense branch of forward() unflattens the fiber kernel with these,
        # so they must be stored in both modes.
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Construct kernels
        # Projects [source features, edge_attr] back to in_channels; only used
        # when forward() receives edge_attr.
        self.pre_linear = nn.Linear(in_channels + edge_attr_dim, in_channels, bias=False)
        # Spatial kernel: one weight per input channel.
        self.kernel = nn.Linear(kernel_dim, in_channels, bias=False)
        # Fiber kernel: in_channels weights (depth-wise) or
        # out_channels * in_channels weights (dense).
        self.fiber_kernel = nn.Linear(
            kernel_dim, in_channels * out_channels // groups, bias=False
        )
        self.attention = attention
        if self.attention:
            key_dim = 128
            self.key_transform = nn.Linear(in_channels, key_dim)
            self.query_transform = nn.Linear(in_channels, key_dim)
            nn.init.xavier_uniform_(self.key_transform.weight)
            nn.init.xavier_uniform_(self.query_transform.weight)
            self.key_transform.bias.data.fill_(0)
            self.query_transform.bias.data.fill_(0)

        # Construct bias
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
            self.bias.data.zero_()
        else:
            self.register_parameter("bias", None)

        # Persistent flag: becomes True once the data-driven rescaling has run.
        self.register_buffer("callibrated", torch.tensor(False))

    def forward(self, x, kernel_basis, fiber_kernel_basis, edge_index, edge_attr=None):
        """Apply the spatial convolution followed by the fiber convolution.

        Args:
            x: node features, indexed by the entries of ``edge_index``.
            kernel_basis: per-edge basis; ``self.kernel`` maps its last dim
                ``kernel_dim -> in_channels``.
            fiber_kernel_basis: fiber basis; the einsums below treat
                ``self.fiber_kernel(fiber_kernel_basis)`` as ``[P, O, ...]``.
            edge_index: row 0 holds source node indices, row 1 target node indices.
            edge_attr: optional per-edge features concatenated to the source
                features and projected with ``pre_linear``.

        Returns:
            Features of shape ``[num_targets, P, out_channels]`` (plus bias when
            enabled), where ``num_targets = edge_index[1].max() + 1``.
        """
        source, target = edge_index[0], edge_index[1]

        if edge_attr is not None:
            x_s = self.pre_linear(torch.cat([x[source], edge_attr], dim=-1))
        else:
            x_s = x[source]

        # 1. Spatial convolution: channel-wise modulation of the gathered features.
        message = x_s * self.kernel(kernel_basis)  # [num_edges, num_ori, in_channels]
        if self.attention:
            keys = self.key_transform(x)
            queries = self.query_transform(x)
            d_k = keys.size(-1)
            att_logits = (keys[source] * queries[target]).sum(
                dim=-1, keepdim=True
            ) / math.sqrt(d_k)
            # Presumably a softmax over the edges sharing a target node.
            att_weights = scatter_softmax(att_logits, target, x.size(0))
            message = message * att_weights
        # NOTE(review): dim_size is max(target) + 1, which is smaller than
        # x.size(0) (used by scatter_softmax above) whenever the last node(s)
        # receive no edges; confirm this is intended.
        x_1 = scatter_add(src=message, index=target, dim_size=target.max().item() + 1)

        # 2. Fiber convolution. The normaliser is the contracted dim (-2) of the
        # kernel; read it before the dense branch unflattens the last dim.
        fiber_kernel = self.fiber_kernel(fiber_kernel_basis)
        num_fiber = fiber_kernel.shape[-2]
        if self.depthwise:
            x_2 = torch.einsum("boc,poc->bpc", x_1, fiber_kernel) / num_fiber
        else:
            dense_kernel = fiber_kernel.unflatten(-1, (self.out_channels, self.in_channels))
            x_2 = torch.einsum("boc,podc->bpd", x_1, dense_kernel) / num_fiber

        # One-time rescaling on the first training pass (uses this batch's stats).
        if self.training and not self.callibrated:
            self.callibrate(x.std(), x_1.std(), x_2.std())

        if self.bias is not None:
            return x_2 + self.bias
        return x_2

    def callibrate(self, std_in, std_1, std_2):
        """Rescale both kernels once from measured standard deviations.

        ``kernel`` is multiplied by ``std_in / std_1`` and ``fiber_kernel`` by
        ``std_1 / std_2``. The spatial output is linear in ``kernel``, and the
        fiber output is linear in both ``fiber_kernel`` and the spatial output.
        After the rescale, both outputs on the measured batch therefore have
        standard deviation ``std_in``.

        If any std is zero or non-finite (e.g. a single-element batch gives a
        NaN std), the weights are left untouched and the flag stays False so a
        later batch can calibrate instead.

        Args:
            std_in: std of the layer input.
            std_1: std after the spatial convolution.
            std_2: std after the fiber convolution (before bias).
        """
        stds = [float(s) for s in (std_in, std_1, std_2)]
        if not all(math.isfinite(s) and s > 0 for s in stds):
            return
        std_in, std_1, std_2 = stds
        print("Callibrating...")
        with torch.no_grad():
            self.kernel.weight.mul_(std_in / std_1)
            self.fiber_kernel.weight.mul_(std_1 / std_2)
            self.callibrated.fill_(True)

class SeparableFiberBundleConvNext(nn.Module):
    """ConvNeXt-style block around a depth-wise ``SeparableFiberBundleConv``.

    Pipeline: depth-wise conv -> LayerNorm -> Linear (widen) -> activation ->
    Linear (to ``out_channels``) -> optional per-channel layer scale ->
    residual add when the output shape equals the input shape.

    Args:
        in_channels: number of input channels.
        kernel_dim: size of the last dim of the kernel bases.
        out_channels: output channels of the final linear layer
            (defaults to ``in_channels``).
        act: activation module between the two linear layers; ``None`` (default)
            creates a new ``nn.GELU()`` for this block.
        layer_scale: initial value of the learnable per-channel output scale;
            ``None`` disables it.
        widening_factor: hidden width multiplier of the MLP.
        attention: forwarded to the convolution.
        edge_attr_dim: size of the ``edge_attr`` features the convolution
            accepts; must be set when ``edge_attr`` is passed to ``forward``.
    """

    def __init__(
        self,
        in_channels,
        kernel_dim,
        out_channels=None,
        act=None,
        layer_scale=1e-6,
        widening_factor=4,
        attention=False,
        edge_attr_dim=0,
    ):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels

        self.conv = SeparableFiberBundleConv(
            in_channels,
            in_channels,
            kernel_dim,
            groups=in_channels,
            attention=attention,
            edge_attr_dim=edge_attr_dim,
        )

        # A fresh module per block instead of one default instance shared by all blocks.
        self.act_fn = nn.GELU() if act is None else act

        self.linear_1 = nn.Linear(in_channels, widening_factor * in_channels)
        self.linear_2 = nn.Linear(widening_factor * in_channels, out_channels)

        if layer_scale is not None:
            self.layer_scale = nn.Parameter(torch.ones(out_channels) * layer_scale)
        else:
            self.register_buffer("layer_scale", None)

        self.norm = nn.LayerNorm(in_channels)

    def forward(self, x, kernel_basis, fiber_kernel_basis, edge_index, edge_attr=None):
        """Run the block; arguments are forwarded unchanged to the convolution."""
        residual = x
        x = self.conv(x, kernel_basis, fiber_kernel_basis, edge_index, edge_attr)
        x = self.norm(x)
        x = self.linear_1(x)
        x = self.act_fn(x)
        x = self.linear_2(x)

        if self.layer_scale is not None:
            x = self.layer_scale * x

        # The residual is skipped (not an error) whenever shapes differ, e.g.
        # when out_channels != in_channels.
        if x.shape == residual.shape:
            x = x + residual

        return x
    
class Conv(nn.Module):
    """Message-passing convolution without a fiber step.

    Every edge carries the source-node features ``x[edge_index[0]]``
    (optionally concatenated with ``edge_attr`` and projected by ``pre_linear``).
    These are transformed by the per-edge kernel ``kernel(kernel_basis)``:
    element-wise in depth-wise mode, or by a per-edge
    ``[out_channels, in_channels]`` matrix in dense mode. The messages are
    optionally weighted by dot-product attention and summed onto the target
    nodes ``edge_index[1]``.

    On the first forward pass in training mode the kernel is rescaled once from
    measured standard deviations (see ``callibrate``).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_dim: size of the last dim of ``kernel_basis``.
        bias: if True, add a zero-initialised learnable bias of size ``out_channels``.
        groups: ``1`` (dense) or ``in_channels == out_channels`` (depth-wise).
        attention: if True, weight messages with scaled dot-product attention.
        edge_attr_dim: size of the ``edge_attr`` features concatenated before
            ``pre_linear``.

    Raises:
        ValueError: if ``groups`` is neither 1 nor equal to both channel counts.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_dim,
        bias=True,
        groups=1,
        attention=False,
        edge_attr_dim=0,
    ):
        super().__init__()

        # Only dense (groups=1) and depth-wise (groups=in=out) variants exist.
        if groups == 1:
            self.depthwise = False
        elif groups == in_channels and groups == out_channels:
            self.depthwise = True
        else:
            raise ValueError(
                "Invalid option for groups, should be groups=1 or groups=in_channels=out_channels (depth-wise separable)"
            )
        # The dense branch of forward() unflattens the kernel with these, so
        # they must be stored in both modes.
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Construct kernels
        # Projects [source features, edge_attr] back to in_channels; only used
        # when forward() receives edge_attr.
        self.pre_linear = nn.Linear(in_channels + edge_attr_dim, in_channels, bias=False)
        # in_channels weights (depth-wise) or out_channels * in_channels (dense).
        self.kernel = nn.Linear(kernel_dim, out_channels * in_channels // groups, bias=False)
        self.attention = attention
        if self.attention:
            key_dim = 128
            self.key_transform = nn.Linear(in_channels, key_dim)
            self.query_transform = nn.Linear(in_channels, key_dim)
            nn.init.xavier_uniform_(self.key_transform.weight)
            nn.init.xavier_uniform_(self.query_transform.weight)
            self.key_transform.bias.data.fill_(0)
            self.query_transform.bias.data.fill_(0)

        # Construct bias
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
            self.bias.data.zero_()
        else:
            self.register_parameter("bias", None)

        # Persistent flag: becomes True once the data-driven rescaling has run.
        self.register_buffer("callibrated", torch.tensor(False))

    def forward(self, x, kernel_basis, fiber_kernel_basis, edge_index, edge_attr=None):
        """Aggregate kernel-transformed source features onto target nodes.

        Args:
            x: node features, indexed by the entries of ``edge_index``.
            kernel_basis: per-edge basis; ``self.kernel`` maps its last dim
                ``kernel_dim -> in_channels`` (depth-wise) or
                ``kernel_dim -> out_channels * in_channels`` (dense).
            fiber_kernel_basis: unused; accepted so the call signature matches
                ``SeparableFiberBundleConv``.
            edge_index: row 0 holds source node indices, row 1 target node indices.
            edge_attr: optional per-edge features concatenated to the source
                features and projected with ``pre_linear``.

        Returns:
            Aggregated features with ``edge_index[1].max() + 1`` rows (plus bias
            when enabled).
        """
        source, target = edge_index[0], edge_index[1]

        if edge_attr is not None:
            x_s = self.pre_linear(torch.cat([x[source], edge_attr], dim=-1))
        else:
            x_s = x[source]

        kernel = self.kernel(kernel_basis)  # [num_edges, in_channels] or [num_edges, out_channels * in_channels]
        if self.depthwise:
            message = x_s * kernel
        else:
            # Per-edge matrix-vector product: [E, out, in] x [E, in] -> [E, out].
            message = torch.einsum(
                "boi,bi->bo",
                kernel.unflatten(-1, (self.out_channels, self.in_channels)),
                x_s,
            )
        if self.attention:
            keys = self.key_transform(x)
            queries = self.query_transform(x)
            d_k = keys.size(-1)
            att_logits = (keys[source] * queries[target]).sum(
                dim=-1, keepdim=True
            ) / math.sqrt(d_k)
            # Presumably a softmax over the edges sharing a target node.
            att_weights = scatter_softmax(att_logits, target, x.size(0))
            message = message * att_weights
        # NOTE(review): dim_size is max(target) + 1, which is smaller than
        # x.size(0) (used by scatter_softmax above) whenever the last node(s)
        # receive no edges; confirm this is intended.
        x_1 = scatter_add(src=message, index=target, dim_size=target.max().item() + 1)

        # One-time rescaling on the first training pass (uses this batch's stats).
        if self.training and not self.callibrated:
            self.callibrate(x.std(), x_1.std())

        if self.bias is not None:
            return x_1 + self.bias
        return x_1

    def callibrate(self, std_in, std_1):
        """Rescale the kernel once so the output std matches the input std.

        ``kernel`` is multiplied by ``std_in / std_1``. The output is linear in
        ``kernel``, so after the rescale the output on the measured batch has
        standard deviation ``std_in``.

        If either std is zero or non-finite, the weights are left untouched and
        the flag stays False so a later batch can calibrate instead.

        Args:
            std_in: std of the layer input.
            std_1: std of the aggregated output (before bias).
        """
        stds = [float(s) for s in (std_in, std_1)]
        if not all(math.isfinite(s) and s > 0 for s in stds):
            return
        std_in, std_1 = stds
        print("Callibrating...")
        with torch.no_grad():
            self.kernel.weight.mul_(std_in / std_1)
            self.callibrated.fill_(True)

class ConvNext(nn.Module):
    """ConvNeXt-style block around a depth-wise ``Conv``.

    Pipeline: depth-wise conv -> LayerNorm -> Linear (widen) -> activation ->
    Linear (to ``out_channels``) -> optional per-channel layer scale ->
    residual add when the output shape equals the input shape.

    Args:
        in_channels: number of input channels.
        kernel_dim: size of the last dim of the kernel basis.
        out_channels: output channels of the final linear layer
            (defaults to ``in_channels``).
        act: activation module between the two linear layers; ``None`` (default)
            creates a new ``nn.GELU()`` for this block.
        layer_scale: initial value of the learnable per-channel output scale;
            ``None`` disables it.
        widening_factor: hidden width multiplier of the MLP.
        attention: forwarded to the convolution.
        edge_attr_dim: size of the ``edge_attr`` features the convolution
            accepts; must be set when ``edge_attr`` is passed to ``forward``.
    """

    def __init__(
        self,
        in_channels,
        kernel_dim,
        out_channels=None,
        act=None,
        layer_scale=1e-6,
        widening_factor=4,
        attention=False,
        edge_attr_dim=0,
    ):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels

        self.conv = Conv(
            in_channels,
            in_channels,
            kernel_dim,
            groups=in_channels,
            attention=attention,
            edge_attr_dim=edge_attr_dim,
        )

        # A fresh module per block instead of one default instance shared by all blocks.
        self.act_fn = nn.GELU() if act is None else act

        self.linear_1 = nn.Linear(in_channels, widening_factor * in_channels)
        self.linear_2 = nn.Linear(widening_factor * in_channels, out_channels)

        if layer_scale is not None:
            self.layer_scale = nn.Parameter(torch.ones(out_channels) * layer_scale)
        else:
            self.register_buffer("layer_scale", None)

        self.norm = nn.LayerNorm(in_channels)

    def forward(self, x, kernel_basis, fiber_kernel_basis, edge_index, edge_attr=None):
        """Run the block; arguments are forwarded unchanged to the convolution."""
        residual = x
        x = self.conv(x, kernel_basis, fiber_kernel_basis, edge_index, edge_attr)
        x = self.norm(x)
        x = self.linear_1(x)
        x = self.act_fn(x)
        x = self.linear_2(x)

        if self.layer_scale is not None:
            x = self.layer_scale * x

        # The residual is skipped (not an error) whenever shapes differ, e.g.
        # when out_channels != in_channels.
        if x.shape == residual.shape:
            x = x + residual

        return x






class PolynomialFeatures(nn.Module):
    """Concatenate a tensor with its iterated outer products along the last dim.

    For input ``x`` with last dimension ``d``, the output concatenates
    ``x``, ``x (x) x``, ``x (x) x (x) x``, ... with ``degree`` outer-product
    steps. The result therefore holds products of order 1 through
    ``degree + 1``, and its last dimension is ``d + d**2 + ... + d**(degree + 1)``.
    Leading dimensions broadcast as in ``torch.einsum``'s ``...``.

    NOTE(review): ``degree`` counts the extra outer products, so the highest
    monomial order is ``degree + 1``; confirm this is intended.
    """

    def __init__(self, degree):
        super().__init__()
        self.degree = degree

    def forward(self, x):
        """Return ``x`` concatenated with its flattened iterated outer products."""
        current = x
        terms = [current]
        for _ in range(self.degree):
            # Outer product with x on the last dim, then flatten back to one dim.
            current = torch.einsum("...i,...j->...ij", current, x).flatten(-2, -1)
            terms.append(current)
        return torch.cat(terms, -1)