""" The attention scatter everyone loves, but now with a sorting step to ensure that the higher diffusion level  is always subtracted from the lower one."""
import numpy as np
import torch
from torch.nn import Linear
from torch_scatter import scatter_mean
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree


# TODO (anonymous) only scatter higher diffusions using masking
def scatter_moments(graph, batch_indices, moments_returned=4, concat_all=False):
    """Compute per-graph statistics of node features for a batch of disjoint graphs.

    ``graph`` holds the node features of several disjoint graphs batched
    together. Its first dimension indexes nodes; all remaining dimensions are
    flattened (row-major) into one feature vector per node. ``batch_indices``
    gives, for every node, the id of the graph it belongs to.

    For every graph and every flattened feature the following population
    statistics are computed (moments are divided by the node count n, not n-1):

    * mean
    * variance                              (``moments_returned >= 2``)
    * skew = m3 / variance**1.5             (``moments_returned >= 3``);
      entries that are NaN or > 1e15 (zero variance) are replaced by 0
    * kurtosis = m4 / variance**2 - 3       (``moments_returned >= 4``,
      Fisher's definition); NaN or > 1e15 entries are replaced by -3

    Args:
        graph: tensor of shape ``[num_nodes, ...]``.
        batch_indices: integer tensor of shape ``[num_nodes]`` with
            non-negative graph ids.
        moments_returned: how many of the statistics above to compute (1-4).
        concat_all: if False (default) only the mean and, when requested,
            the variance are returned — the original output layout. If True,
            every requested statistic is returned.

    Returns:
        Tensor of shape ``[num_graphs, k * num_features]`` where
        ``num_graphs = batch_indices.max() + 1`` and ``k`` is the number of
        returned statistics, concatenated in the order listed above.
        Graph ids that own no nodes produce NaN rows.
    """
    # Match graph's device so CPU batch ids work with GPU features.
    batch = batch_indices.to(device=graph.device, dtype=torch.long)
    num_graphs = int(batch.max()) + 1

    # One row per node; flattening order equals node_features.view(-1).
    flat = graph.reshape(graph.shape[0], -1)
    if not flat.is_floating_point():
        flat = flat.float()

    # Node count per graph, as a column so it broadcasts over features.
    counts = torch.bincount(batch, minlength=num_graphs).to(flat.dtype).unsqueeze(1)

    def _per_graph_average(values):
        """Average the rows of ``values`` within each graph (divides by n)."""
        sums = values.new_zeros(num_graphs, values.shape[1]).index_add(0, batch, values)
        return sums / counts

    mean = _per_graph_average(flat)
    stats = [mean]

    # Only mean and variance are returned by default, so skip higher-moment
    # work unless the caller explicitly asks for it.
    num_stats = moments_returned if concat_all else min(moments_returned, 2)
    if num_stats >= 2:
        # Deviation of every node's features from its own graph's mean.
        deviation = flat - mean[batch]
        variance = _per_graph_average(deviation ** 2)
        stats.append(variance)
    if num_stats >= 3:
        skew = _per_graph_average(deviation ** 3) / variance ** 1.5
        # Zero variance gives 0/0 = NaN (or a huge value); report it as 0.
        skew = skew.masked_fill(torch.isnan(skew) | (skew > 1e15), 0)
        stats.append(skew)
    if num_stats >= 4:
        kurtosis = _per_graph_average(deviation ** 4) / variance ** 2 - 3
        kurtosis = kurtosis.masked_fill(torch.isnan(kurtosis) | (kurtosis > 1e15), -3)
        stats.append(kurtosis)
    return torch.cat(stats, dim=1)


class LazyLayer(torch.nn.Module):
    """Learnable laziness: a per-feature convex mix of a signal and its diffusion.

    Holds a ``[2, n]`` parameter. A softmax over its first axis turns each
    column into two non-negative weights that sum to one; row 0 weights ``x``
    and row 1 weights ``propogated``.
    """

    def __init__(self, n):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.empty(2, n))
        # torch.empty leaves memory uninitialised, so initialise eagerly rather
        # than relying on callers to invoke reset_parameters().
        self.reset_parameters()

    def forward(self, x, propogated):
        """Blend ``x`` and ``propogated`` with softmax-normalised weights.

        All axes after the first (node) axis are flattened, so inputs of shape
        ``[N, n]`` or ``[N, a, b]`` with ``a * b == n`` are accepted. The output
        has the shape of ``x``.
        """
        inp = torch.stack((x.flatten(1), propogated.flatten(1)), dim=1)  # [N, 2, n]
        s_weights = torch.nn.functional.softmax(self.weights, dim=0)  # columns sum to 1
        return torch.sum(inp * s_weights, dim=1).reshape(x.shape)

    def reset_parameters(self):
        """Set all weights to one, i.e. an equal 0.5 / 0.5 mix after the softmax."""
        torch.nn.init.ones_(self.weights)


class Diffuse(MessagePassing):
    """One lazy random-walk diffusion step over ``edge_index``.

    ``forward`` returns ``0.5 * (x + propagated)``, or a learned mix via
    ``LazyLayer`` when ``trainable_laziness`` is set. ``propagated`` is the
    "add"-aggregation of per-edge messages ``x_j / deg(row)``, where ``deg``
    counts how often each node appears in ``edge_index[0]``. With PyG's
    default ``source_to_target`` flow, ``x_j`` is gathered from
    ``edge_index[0]`` — confirm if the flow is changed.
    """

    def __init__(
            self, in_channels, out_channels, trainable_laziness=False, fixed_weights=True
    ):
        super().__init__(aggr="add")  # "Add" aggregation.
        assert in_channels == out_channels
        self.trainable_laziness = trainable_laziness
        self.fixed_weights = fixed_weights
        if trainable_laziness:
            self.lazy_layer = LazyLayer(in_channels)
        if not self.fixed_weights:
            # NOTE(review): nn.Linear acts on the LAST axis; for [N, C, K] inputs
            # that is K, not the channel axis — confirm before enabling.
            self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply one diffusion step; ``x`` is ``[N, ...]`` and ``edge_index`` is ``[2, E]``.

        Self-loops are deliberately not added.
        """
        if not self.fixed_weights:
            x = self.lin(x)

        row = edge_index[0]
        deg = degree(row, x.size(0), dtype=x.dtype)
        # Random-walk normalisation (not GCN's symmetric D^-1/2 A D^-1/2).
        # Only nodes present in `row` are indexed, so every used degree is >= 1.
        deg_inv = deg.pow(-1)
        norm = deg_inv[row]

        propagated = self.propagate(
            edge_index, size=(x.size(0), x.size(0)), x=x, norm=norm
        )
        if not self.trainable_laziness:
            return 0.5 * (x + propagated)
        return self.lazy_layer(x, propagated)

    def message(self, x_j, norm):
        """Scale each edge's source features ``x_j`` (``[E, ...]``) by its ``norm`` (``[E]``)."""
        # Reshape norm to [E, 1, ..., 1] so it broadcasts along the edge axis for
        # any feature rank (2-D inputs included).
        return norm.view(-1, *([1] * (x_j.dim() - 1))) * x_j

    def update(self, aggr_out):
        """Return the aggregated messages unchanged as the new node embeddings."""
        return aggr_out


def feng_filters(num_wavelets=4):
    """Return the flat indices ``num_wavelets * i + j`` of all pairs with ``j < i``.

    These are the strictly lower-triangular cells of a
    ``num_wavelets x num_wavelets`` grid, listed in row-major order. With the
    default of 4 the result is ``[4, 8, 9, 12, 13, 14]``.

    NOTE(review): ``Scatter.forward`` uses these to keep 6 of 16 second-order
    coefficients. Which wavelet ``i`` and ``j`` denote follows from the
    reshape/transpose there; confirm before relying on a scale ordering.
    """
    return [num_wavelets * i + j for i in range(1, num_wavelets) for j in range(i)]


class Scatter(torch.nn.Module):
    """Two-layer geometric scattering whose diffusion scales are blended by attention.

    Each layer:

    * diffuses its input 16 times with a ``Diffuse`` operator;
    * lets single-head attention blend the 17 diffusion levels into 3 learned
      mixtures b0, b1, b2;
    * builds 4 filters from telescoping differences: ``input - b0``,
      ``b0 - b1``, ``b1 - b2``, ``b2``.

    The absolute first-layer responses (s1) are the input of the second layer
    (s2). Per-graph moments of the stacked coefficients are returned.
    """

    def __init__(self, in_channels, trainable_laziness=False):
        super().__init__()
        self.in_channels = in_channels
        self.trainable_laziness = trainable_laziness
        # Construction order is kept as-is so parameter initialisation draws
        # from the RNG in the same sequence.
        self.diffusion_layer1 = Diffuse(in_channels, in_channels, trainable_laziness)
        self.diffusion_layer2 = Diffuse(
            4 * in_channels, 4 * in_channels, trainable_laziness
        )
        # Single-head attention over diffusion levels; embed dim = channel count.
        self.attention = torch.nn.MultiheadAttention(in_channels, 1)
        self.attention2 = torch.nn.MultiheadAttention(4 * in_channels, 1)  # second layer
        # Queries are Linear(all-ones), i.e. a learned constant per query slot.
        self.query_trainer = torch.nn.Linear(in_channels, in_channels)
        self.query_trainer2 = torch.nn.Linear(4 * in_channels, 4 * in_channels)
        self.count = 0  # forward-call counter, used for periodic debug printing

    def forward(self, data):
        """Compute per-graph scattering statistics for ``data``.

        Args:
            data: object with ``x``, ``edge_index`` and optionally ``batch``
                (graph id per node). ``x`` is presumably ``[N, in_channels]``;
                the attention embed size requires the second axis to equal
                ``in_channels``.

        Returns:
            ``scatter_moments`` of the ``[N, 11, in_channels]`` coefficients
            (1 raw + 4 first-order + 6 second-order). Those are the per-graph
            mean and variance, of width ``out_shape()``.
        """
        x, edge_index = data.x, data.edge_index
        num_nodes = x.size(0)
        s0 = x[:, :, None]  # zeroth-order coefficients, [N, C, 1]

        # ---- first layer ----
        avgs = [s0]
        for _ in range(16):
            avgs.append(self.diffusion_layer1(avgs[-1], edge_index))
        # Key/value tensor [17, N, C]. Only the trailing singleton axis is
        # dropped; a bare squeeze() would also drop N or C when either is 1.
        diffusion_levels = torch.stack(avgs).squeeze(-1)
        # Three learned queries per node (the filters below assume exactly 3).
        # Created on x's device so the module also runs on GPU.
        query = self.query_trainer(
            torch.ones(3, num_nodes, self.in_channels, device=x.device)
        )
        diffusion_blends, attn_weights = self.attention(
            query, diffusion_levels, diffusion_levels
        )

        self.count += 1
        if self.count % 1000 == 0:
            print("The attention weights are ", attn_weights)

        blend0 = diffusion_blends[0][:, :, None]
        blend1 = diffusion_blends[1][:, :, None]
        blend2 = diffusion_blends[2][:, :, None]
        filter1 = s0 - blend0  # I - (some scale of diffusions)
        filter2 = blend0 - blend1
        filter3 = blend1 - blend2
        filter4 = blend2  # keeps the telescoping sum equal to the input
        s1 = torch.abs(torch.cat([filter1, filter2, filter3, filter4], dim=-1))  # [N, C, 4]

        # ---- second layer, 4x the channels ----
        avgs2 = [s1]
        for _ in range(16):
            avgs2.append(self.diffusion_layer2(avgs2[-1], edge_index))
        diffusion_levels2 = torch.stack(avgs2)  # [17, N, C, 4]
        # Merge the channel and filter axes into one attention embedding (4C).
        reshaped_dls = diffusion_levels2.reshape(
            diffusion_levels2.shape[0], num_nodes, -1
        )
        query2 = self.query_trainer2(
            torch.ones(3, num_nodes, 4 * self.in_channels, device=x.device)
        )
        diffusion_blends2, _ = self.attention2(query2, reshaped_dls, reshaped_dls)
        diffusion_blends2 = diffusion_blends2.reshape(-1, *s1.shape)  # [3, N, C, 4]
        filter1 = s1 - diffusion_blends2[0]  # I - (some scale of diffusions)
        filter2 = diffusion_blends2[0] - diffusion_blends2[1]
        filter3 = diffusion_blends2[1] - diffusion_blends2[2]
        filter4 = diffusion_blends2[2]  # keeps the telescoping sum equal to the input
        s2 = torch.abs(torch.cat([filter1, filter2, filter3, filter4], dim=1))  # [N, 4C, 4]

        # Reorder to [N, 16, C], then keep the 6 entries chosen by feng_filters().
        s2_reshaped = torch.reshape(s2, (-1, self.in_channels, 4))
        s2_swapped = torch.reshape(
            torch.transpose(s2_reshaped, 1, 2), (-1, 16, self.in_channels)
        )
        s2 = s2_swapped[:, feng_filters()]

        x = torch.cat([s0, s1], dim=2)  # [N, C, 5]
        x = torch.transpose(x, 1, 2)  # [N, 5, C]
        x = torch.cat([x, s2], dim=1)  # [N, 11, C]

        # Fall back to a single graph when `batch` is absent OR set to None.
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = torch.zeros(num_nodes, dtype=torch.int32, device=x.device)
        return scatter_moments(x, batch, 4)

    def out_shape(self):
        """Width of ``forward``'s output: 11 coefficient maps x 2 moments x channels."""
        return 11 * 2 * self.in_channels

