import torch
from torch.nn import Linear, BatchNorm1d
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool
from trainable_scattering.models.scatter import scatter_moments
from trainable_scattering.models.rbf import RBF, gaussian


class SoftmaxWavelets(torch.nn.Linear):
    """Wavelet bank formed from softmax-weighted mixtures of the middle scales.

    The last dimension of the input holds ``in_features`` scales. The first
    and last scale are passed through unchanged; the ``in_features - 2``
    scales in between are mixed into ``out_features - 1`` intermediate scales
    using row-wise softmax weights. The returned wavelets are the differences
    between consecutive entries of ``[first, intermediates..., last]``, which
    yields ``out_features`` values on the last dimension.

    Args:
        in_features: Total number of scales on the last input dimension,
            including the first and last one.
        out_features: Number of wavelets produced on the last dimension.
    """

    def __init__(self, in_features, out_features):
        # Only the middle scales are mixed, hence the reduced weight shape
        # of [out_features - 1, in_features - 2]. No bias is used.
        super().__init__(in_features - 2, out_features - 1, bias=False)

    def _ordered_soft_weights(self):
        """Return ``(soft_weights, peaks, order)`` for the current weights.

        ``soft_weights`` is the softmax of the weight matrix over the input
        dimension, ``peaks`` the argmax position of every row, and ``order``
        the row permutation that sorts rows by ascending peak position.
        """
        soft_weights = F.softmax(self.weight, dim=1)  # Softmax on the input dimension
        peaks = torch.argmax(soft_weights, dim=1)
        order = torch.argsort(peaks)
        return soft_weights, peaks, order

    def forward(self, input):
        """Compute the wavelets for ``input``.

        Args:
            input: Tensor whose last dimension has ``self.in_features + 2``
                entries (the ``in_features`` given to the constructor).

        Returns:
            Tensor with the same leading dimensions as ``input`` and
            ``self.out_features + 1`` entries on the last dimension.
        """
        # Attention on the middle scales. The split size follows the layer's
        # own in_features rather than a fixed 15, so any scale count works.
        s0, sint, smax = torch.split(input, [1, self.in_features, 1], dim=-1)
        soft_weights, _, order = self._ordered_soft_weights()
        # Rows are sorted by peak position before mixing the middle scales.
        int_scales = F.linear(sint, soft_weights[order], None)
        # Differences of consecutive scales: [s0, int...] - [int..., smax].
        wavelets = torch.cat([s0, int_scales], dim=-1) - torch.cat([int_scales, smax], dim=-1)
        return wavelets

    def get_scales(self):
        """Describe which middle scale each weight row concentrates on.

        Returns:
            A 3-tuple of
            * the result of ``torch.max`` over the input dimension of the
              softmax weights (a ``(values, indices)`` pair, rows unsorted),
            * the detached argmax position of every row (rows unsorted),
            * the row permutation that sorts rows by ascending argmax.
        """
        soft_weights, peaks, order = self._ordered_soft_weights()
        return torch.max(soft_weights, dim=1), peaks.detach(), order


class SortNet(torch.nn.Module):
    """Graph-level MLP on mean-pooled features and learned wavelets.

    Args:
        in_channels: Stored on the instance; not otherwise used here.
        intermediate_channels: Width of every hidden linear layer.
        num_layers: Number of hidden ``Linear`` layers applied between
            ``lin1`` and ``lin2``.
        num_wavelets: Number of wavelets produced per filter.
        out_channels: Size of the network output.
    """

    def __init__(self, in_channels, intermediate_channels, num_layers, num_wavelets, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.num_wavelets = num_wavelets
        # Input width: (5 leading channels + 4 filters * num_wavelets) * 2 features.
        self.lin1 = Linear((5 + 4 * self.num_wavelets) * 2, intermediate_channels)
        # A ModuleList (not a plain list) registers the hidden layers, so they
        # show up in parameters(), state_dict() and follow .to(device).
        # Exactly num_layers layers are built, matching what forward applies.
        self.innerlayers = torch.nn.ModuleList(
            [Linear(intermediate_channels, intermediate_channels) for _ in range(num_layers)]
        )
        self.lin2 = Linear(intermediate_channels, out_channels)
        self.act = torch.nn.LeakyReLU()
        self.wavelets = SoftmaxWavelets(17, self.num_wavelets)

    def forward(self, data):
        """Run the network on a batch of graphs.

        Args:
            data: Object exposing ``x`` and ``batch``. ``x`` is indexed as a
                3-D tensor; its channels from index 5 on are reshaped to
                ``(-1, 4, 17, 2)``.

        Returns:
            Output of the final linear layer, one row per pooled graph.
        """
        x = data.x
        # First 5 channels bypass the wavelet layer; the rest feed it.
        s01, s2 = x[:, :5, :], x[:, 5:, :]
        s2 = torch.reshape(s2, (-1, 4, 17, 2))  # [Nodes x Filters x Diffusions x Features]
        s2 = torch.transpose(s2, -2, -1)        # [Nodes x Filters x Features x Diffusions]
        # Selection over diffusions
        wavelets = self.wavelets(s2)            # [Nodes x Filters x Features x Wavelets]
        wavelets = torch.transpose(wavelets, -2, -1)
        wavelets = torch.reshape(wavelets, (-1, 4 * self.num_wavelets, 2))
        wavelets = torch.abs(wavelets)
        x = torch.cat([s01, wavelets], dim=1)
        # Mean pooling per graph; pooling via scatter_moments is a disabled
        # alternative.
        # NOTE(review): x is 3-D here, so this relies on global_mean_pool
        # reducing over dim 0 - verify against the installed torch_geometric.
        x = global_mean_pool(x, data.batch)
        x = torch.reshape(x, (-1, (5 + 4 * self.num_wavelets) * 2))
        x = self.act(self.lin1(x))
        for layer in self.innerlayers:
            x = self.act(layer(x))
        x = self.lin2(x)
        return x


class SortRBFNet(torch.nn.Module):
    """Graph-level network: learned wavelets, mean pooling, then an RBF head.

    Args:
        in_channels: Stored on the instance; not otherwise used here.
        intermediate_channels: Output width of ``lin1`` and input width of
            the batch norm and RBF layer.
        num_wavelets: Number of wavelets produced per filter.
        num_centres: Number of RBF centres, i.e. input width of ``lin2``.
        out_channels: Size of the network output.
    """

    def __init__(self, in_channels, intermediate_channels, num_wavelets, num_centres, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.num_wavelets = num_wavelets
        self.innerlayers = []
        # (5 leading channels + 4 filters * num_wavelets) * 2 features.
        flat_width = (5 + 4 * num_wavelets) * 2
        self.lin1 = Linear(flat_width, intermediate_channels)
        self.bn = BatchNorm1d(intermediate_channels, affine=False)
        self.rbf = RBF(intermediate_channels, num_centres, gaussian)
        self.lin2 = Linear(num_centres, out_channels)
        self.wavelets = SoftmaxWavelets(17, num_wavelets)

    def forward(self, data):
        """Run the network on a batch of graphs.

        Args:
            data: Object exposing ``x`` and ``batch``. ``x`` is indexed as a
                3-D tensor; its channels from index 5 on are reshaped to
                ``(-1, 4, 17, 2)``.

        Returns:
            Output of the final linear layer, one row per pooled graph.
        """
        nodes = data.x
        # First 5 channels bypass the wavelet layer.
        low_order = nodes[:, :5, :]
        # [Nodes x Filters x Diffusions x Features] -> diffusions moved last.
        second_order = nodes[:, 5:, :].reshape(-1, 4, 17, 2).transpose(-2, -1)
        # Selection over diffusions: [Nodes x Filters x Features x Wavelets],
        # then back to features-last layout.
        bands = self.wavelets(second_order).transpose(-2, -1)
        bands = bands.reshape(-1, 4 * self.num_wavelets, 2).abs()
        # NOTE(review): the pooled tensor is 3-D, so this relies on
        # global_mean_pool reducing over dim 0 - verify against the installed
        # torch_geometric.
        pooled = global_mean_pool(torch.cat([low_order, bands], dim=1), data.batch)
        flat = pooled.reshape(-1, (5 + 4 * self.num_wavelets) * 2)
        hidden = self.bn(self.lin1(flat))
        return self.lin2(self.rbf(hidden))
