import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from models.classifier import Classifier
import models.torch_utils


class DynamicEdgeConv(gnn.MessagePassing):
    """Edge convolution over a k-NN graph rebuilt from the inputs on every call.

    Each node carries a feature tensor of arbitrary trailing shape (for
    example a ``(C, H, W)`` image). ``forward`` flattens the node features,
    builds a k-nearest-neighbour graph over them with ``gnn.knn_graph``
    (``loop=False``, so no self loops), and runs message passing. For each
    edge, ``model`` receives the centre features concatenated along dim 1
    with the neighbour-minus-centre difference. The per-edge outputs are
    reduced with ``aggr`` and reshaped back to ``(num_nodes, *out_shape)``.

    The per-node tensor shape is handed from ``forward`` to ``message`` (and
    the model's output shape back again) through ``self.shape``. The two
    methods therefore rely on being run in that order within one call.
    """

    def __init__(self, model, aggr="max", k=4, **kwargs):
        """
        :param model: module applied to the stacked edge features
        :param aggr: aggregation scheme forwarded to ``MessagePassing``
        :param k: number of nearest neighbours per node
        :param kwargs: accepted so callers may pass extra options; ignored
        """
        super().__init__(aggr=aggr)
        self.model = model
        self.k = k

    def message(self, x_i, x_j):
        # PyG hands over flattened node features gathered per edge; restore
        # the per-node shape recorded in self.shape before running the model.
        num_edges = x_i.shape[0]
        per_node = tuple(self.shape[1:])
        centre = x_i.reshape(num_edges, *per_node)
        neighbour = x_j.reshape(num_edges, *per_node)

        edge_out = self.model(torch.cat([centre, neighbour - centre], dim=1))
        # Remember the model's output shape so forward() can un-flatten the
        # aggregated result.
        self.shape = edge_out.shape
        return edge_out.reshape(num_edges, -1)

    def forward(self, x, batch=None):
        num_nodes = x.shape[0]
        flat_x = x.reshape(num_nodes, -1)

        # batch is passed straight to knn_graph (presumably PyG's per-node
        # graph-assignment vector; None treats all nodes as one graph).
        edge_index = gnn.knn_graph(flat_x, self.k, batch, loop=False, flow=self.flow)
        self.shape = x.shape

        aggregated = self.propagate(edge_index, x=flat_x)
        # message() has overwritten self.shape with the model output shape.
        return aggregated.reshape(num_nodes, *self.shape[1:])


class FixedLeNetGNN(Classifier):
    """LeNet-style classifier whose blocks are wrapped in ``DynamicEdgeConv``.

    The layer sizes are hard-coded for ``(1, 28, 28)`` inputs. Every block is
    wrapped in ``DynamicEdgeConv``, which feeds it the centre features
    concatenated with the neighbour difference along dim 1. That is why each
    block's input width is twice the previous block's output width.
    """

    def __init__(self, N_class=10, resolution=(1, 28, 28), **kwargs):
        """
        Initialize classifier.

        :param N_class: number of classes to classify
        :type N_class: int
        :param resolution: input resolution as (channels, height, width);
            must be exactly (1, 28, 28)
        :type resolution: tuple(int, int, int)
        :param kwargs: forwarded unchanged both to ``Classifier.__init__`` and
            to every ``DynamicEdgeConv`` (which reads ``aggr`` and ``k`` and
            ignores anything else)
        :raises AssertionError: if ``resolution`` is not (1, 28, 28)
        """
        # The hard-coded sizes below (notably the 7 x 7 feature map after two
        # 2x2 poolings of a 28 x 28 input) only hold for this resolution.
        assert resolution[0] == 1 and resolution[1] == 28 and resolution[2] == 28, (
            f"FixedLeNetGNN requires resolution (1, 28, 28), got {resolution!r}"
        )

        super().__init__(N_class, resolution, **kwargs)

        # Input channel counts are doubled because DynamicEdgeConv concatenates
        # x_i and (x_j - x_i) along dim 1 before calling each block.
        blocks = [
            # 2 x 28 x 28 -> 32 x 28 x 28 (padding=2 keeps size) -> 32 x 14 x 14
            nn.Sequential(
                nn.Conv2d(resolution[0] * 2, 32, 5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ),
            # 64 x 14 x 14 -> 64 x 14 x 14 -> 64 x 7 x 7
            nn.Sequential(
                nn.Conv2d(32 * 2, 64, 5, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ),
            # 128 x 7 x 7 flattened -> 1024 -> N_class logits
            # (self.N_class is presumably set by Classifier.__init__).
            nn.Sequential(
                models.torch_utils.Flatten(),
                nn.Linear(7 * 7 * 64 * 2, 1024),
                nn.ReLU(),
                nn.Linear(1024, self.N_class),
            ),
        ]

        # Hand each wrapped block to Classifier.append_layer under the names
        # "0", "1", "2", in order.
        for i, b in enumerate(blocks):
            self.append_layer(str(i), DynamicEdgeConv(b, **kwargs))

