""" Reworked learnable scale construction using the attention mechanism. """
import numpy as np
import torch
from torch.nn import Linear
from torch_scatter import scatter_mean
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree

from trainable_scattering.models.scatter import scatter_moments


class LazyLayer(torch.nn.Module):
    """Learned per-channel laziness: a convex mix of a signal and its diffusion.

    Holds a ``(2, n)`` weight matrix. A softmax over its first axis turns each
    column into a pair ``(w0, w1)`` with ``w0 + w1 == 1``. The output is
    ``w0 * x + w1 * propogated``, applied separately to each of the ``n``
    channels.
    """

    def __init__(self, n):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.empty(2, n))
        # torch.empty leaves the memory uninitialised, so initialise explicitly.
        # Ones give a softmax of 0.5 / 0.5, the same mix as the fixed lazy walk.
        self.reset_parameters()

    def forward(self, x, propogated):
        """Return the per-channel convex mix of ``x`` and ``propogated``.

        Both inputs must have the same shape, with at least 2 dimensions, and
        the ``n`` channels on axis 1 (e.g. ``(N, n)`` or ``(N, n, k)``).
        """
        mix = torch.nn.functional.softmax(self.weights, dim=0)
        # Reshape (2, n) -> (2, 1, n, 1, ...) so the mix broadcasts over the
        # node axis and over any trailing axes of the stacked inputs.
        mix = mix.view(2, 1, -1, *([1] * (x.dim() - 2)))
        stacked = torch.stack((x, propogated), dim=0)
        return torch.sum(stacked * mix, dim=0)

    def reset_parameters(self):
        """Reset the laziness to an even 50/50 mix for every channel."""
        torch.nn.init.ones_(self.weights)


class Diffuse(MessagePassing):
    """One lazy random-walk diffusion step, with optional learned weights.

    Each edge's message is scaled by the inverse degree of the node in
    ``edge_index[0]`` (degree counted over ``edge_index[0]``). Messages are
    summed at the receiving node; with PyG's default flow this presumably means
    sources in row 0 and targets in row 1 (verify). The result ``walked`` is
    mixed with the input as ``0.5 * (x + walked)``. When ``trainable_laziness``
    is set, a learned per-channel ``LazyLayer`` does the mixing instead. Nodes
    are indexed on axis ``-3`` of the feature tensor (``node_dim=-3``).
    """

    def __init__(
        self, in_channels, out_channels, trainable_laziness=False, fixed_weights=True
    ):
        # Messages are summed; nodes live on axis -3 of the feature tensor.
        super().__init__(aggr="add", node_dim=-3)
        # NOTE(review): in_channels == out_channels is not enforced here. With
        # fixed_weights=True (the default) out_channels is never used.
        self.trainable_laziness = trainable_laziness
        self.fixed_weights = fixed_weights
        if trainable_laziness:
            self.lazy_layer = LazyLayer(in_channels)
        if not fixed_weights:
            self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Diffuse ``x`` one step along ``edge_index`` and apply laziness."""
        # Optional learned transform. torch.nn.Linear acts on the LAST axis of x.
        # NOTE(review): for (N, C, k) inputs that axis is k, not the channels.
        if not self.fixed_weights:
            x = self.lin(x)

        num_nodes = x.size(0)
        src, _dst = edge_index
        # Random-walk (row) normalisation, not GCN's symmetric one. Nodes absent
        # from src get inf here, but only nodes present in src are gathered below.
        inv_deg = degree(src, num_nodes, dtype=x.dtype).pow(-1)
        walked = self.propagate(
            edge_index, size=(num_nodes, num_nodes), x=x, norm=inv_deg[src]
        )

        if self.trainable_laziness:
            return self.lazy_layer(x, walked)
        return 0.5 * (x + walked)

    def message(self, x_j, norm):
        """Scale each edge's message by its per-edge weight ``norm``."""
        # Edges sit on axis 0 of x_j. Assuming x_j is 3-D (E, C, k), moving
        # edges to axis 1 lets the (E, 1) weight column broadcast over k.
        edges_second = torch.transpose(x_j, 0, 1)
        scaled = norm.view(-1, 1) * edges_second
        return torch.transpose(scaled, 0, 1)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out

def feng_filters(n_filters=4):
    """Return the flat indices of the strictly lower-triangular cells of a grid.

    The grid is ``n_filters x n_filters`` in row-major order. Index
    ``n_filters * i + j`` is included for every ``j < i``, ordered by ``i``
    and then ``j``. With the default ``n_filters=4`` the result is
    ``[4, 8, 9, 12, 13, 14]``. ``Scatter.forward`` uses these to pick 6 of its
    16 flattened second-order filter pairs.
    """
    return [n_filters * i + j for i in range(1, n_filters) for j in range(i)]

class Scatter(torch.nn.Module):
    """Two-layer learnable scattering transform whose scales are chosen by attention.

    Each layer works as follows:

    - Stack the signal with its first 16 lazy random-walk diffusions
      (17 levels).
    - Let a single-head attention module blend those levels into three
      learned scales ``b0, b1, b2`` per node.
    - Form four filters from consecutive blends: ``signal - b0``,
      ``b0 - b1``, ``b1 - b2`` and ``b2``.

    The absolute filter responses are the next layer's input. Zeroth-, first-
    and selected second-order coefficients are finally pooled per graph by
    ``scatter_moments``.
    """

    def __init__(self, in_channels, trainable_laziness=False):
        super().__init__()
        self.in_channels = in_channels
        self.trainable_laziness = trainable_laziness
        self.diffusion_layer1 = Diffuse(in_channels, in_channels, trainable_laziness)
        # NOTE(review): with trainable_laziness=True this layer's LazyLayer is
        # sized 4 * in_channels, but the first-layer output s1 passed to it has
        # shape (N, in_channels, 4). Confirm the intended sizing.
        self.diffusion_layer2 = Diffuse(
            4 * in_channels, 4 * in_channels, trainable_laziness
        )
        # Single-head attention over diffusion levels. The embedding dim is the
        # number of features per node seen by that layer. Inputs are
        # sequence-first (batch_first=False): (levels, nodes, dim).
        self.attention = torch.nn.MultiheadAttention(in_channels, 1)
        self.attention2 = torch.nn.MultiheadAttention(4 * in_channels, 1)
        # Query generators. They are fed an all-ones tensor, so each query is
        # effectively a learned vector (weight row sums plus bias).
        self.query_trainer = torch.nn.Linear(in_channels, in_channels)
        self.query_trainer2 = torch.nn.Linear(4 * in_channels, 4 * in_channels)
        # Number of forward calls so far (only used for optional debugging).
        self.count = 0

    @staticmethod
    def _diffusion_levels(diffuse, signal, edge_index, steps=16):
        """Stack ``signal`` and its first ``steps`` diffusions along a new axis 0."""
        levels = [signal]
        for _ in range(steps):
            levels.append(diffuse(levels[-1], edge_index))
        return torch.stack(levels)

    @staticmethod
    def _band_filters(signal, blends, dim):
        """Return ``|filters|`` concatenated along ``dim``.

        The filters are ``signal - b0``, ``b0 - b1``, ``b1 - b2`` and ``b2``;
        the last one keeps the telescoping sum equal to ``signal``.
        """
        filters = [
            signal - blends[0],
            blends[0] - blends[1],
            blends[1] - blends[2],
            blends[2],
        ]
        return torch.abs(torch.cat(filters, dim=dim))

    def forward(self, data):
        """Compute pooled scattering moments for a (possibly batched) graph.

        Args:
            data: graph object exposing ``x`` (node features, presumably
                ``(N, in_channels)``), ``edge_index`` and optionally ``batch``.
                A missing or ``None`` batch is treated as a single graph.

        Returns:
            The result of ``scatter_moments`` on the ``(N, 11, in_channels)``
            coefficient tensor.
        """
        x, edge_index = data.x, data.edge_index
        num_nodes = len(x)

        # ---- First layer -------------------------------------------------
        s0 = x[:, :, None]  # (N, C, 1)
        # squeeze(-1) drops only the trailing singleton axis. A plain squeeze()
        # would also collapse N == 1 or C == 1 and break attention's shapes.
        levels1 = self._diffusion_levels(
            self.diffusion_layer1, s0, edge_index
        ).squeeze(-1)  # (17, N, C)
        query1 = self.query_trainer(x.new_ones(3, num_nodes, self.in_channels))
        blends1, _ = self.attention(query1, levels1, levels1)  # (3, N, C)
        self.count += 1
        s1 = self._band_filters(s0, blends1.unsqueeze(-1), dim=-1)  # (N, C, 4)

        # ---- Second layer ------------------------------------------------
        levels2 = self._diffusion_levels(
            self.diffusion_layer2, s1, edge_index
        )  # (17, N, C, 4)
        # Merge (C, 4) into one embedding axis for attention: (17, N, 4C).
        flat_levels2 = levels2.view(levels2.shape[0], levels2.shape[1], -1)
        query2 = self.query_trainer2(x.new_ones(3, len(s1), self.in_channels * 4))
        blends2, _ = self.attention2(query2, flat_levels2, flat_levels2)
        blends2 = blends2.view(-1, *levels2.shape[1:])  # (3, N, C, 4)
        s2 = self._band_filters(s1, blends2, dim=1)  # (N, 4C, 4)

        # Regroup into 16 flattened filter pairs per channel, then keep the
        # pairs selected by feng_filters().
        s2 = torch.reshape(s2, (-1, self.in_channels, 4))
        s2 = torch.reshape(torch.transpose(s2, 1, 2), (-1, 16, self.in_channels))
        s2 = s2[:, feng_filters()]  # (N, 6, C)

        # Per node: 1 zeroth + 4 first + 6 second-order coefficient groups.
        coeffs = torch.transpose(torch.cat([s0, s1], dim=2), 1, 2)  # (N, 5, C)
        coeffs = torch.cat([coeffs, s2], dim=1)  # (N, 11, C)

        # PyG 2.x Data objects expose `batch` as None when absent, so hasattr()
        # alone is not enough.
        batch = getattr(data, "batch", None)
        if batch is None:
            # Single graph: every node belongs to graph 0. Build the index on
            # the features' device, as int64 like PyG batch vectors.
            batch = torch.zeros(
                coeffs.shape[0], dtype=torch.long, device=coeffs.device
            )
        return scatter_moments(coeffs, batch, 2)

    def out_shape(self):
        """Return the flattened output size: 11 groups x 2 moments x channels."""
        # 11 = 1 zeroth-order + 4 first-order + 6 second-order coefficient groups.
        return 11 * 2 * self.in_channels

