import torch.nn as nn
from torch import cat
from torch_geometric.nn import DenseSAGEConv, dense_diff_pool
from math import ceil
import torch


class DiffPoolLayer(nn.Module):
    """Stack of dense GraphSAGE convolutions used as a DiffPool sub-network.

    Every layer applies ``DenseSAGEConv`` -> ReLU -> ``BatchNorm1d``. The
    outputs of all layers are concatenated along the last (feature) axis and,
    when ``lin`` is enabled, projected to ``out_channels`` by a linear layer
    followed by ReLU.

    Args:
        in_channels: Feature width consumed by the first layer.
        hidden_channels: Output width of every layer except the last one.
        out_channels: Output width of the last layer and of the projection.
        num_layers: Number of convolution layers (must be at least 1).
        device: Device the sub-modules are moved to at construction time.
        normalize: Forwarded to ``DenseSAGEConv``.
        lin: If true, add the final linear projection.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, device,
                 normalize=False, lin=True):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        self.num_layers = num_layers
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.device = device

        # ModuleList (not a plain list) so the layers are registered as
        # sub-modules: their weights show up in parameters()/state_dict() and
        # they follow .to(), .train() and .eval() calls on the parent.
        self.conv_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()
        for i in range(num_layers):
            # The first layer always reads the raw input width and the last
            # layer always emits out_channels, including when num_layers == 1.
            layer_in = in_channels if i == 0 else hidden_channels
            layer_out = out_channels if i == num_layers - 1 else hidden_channels
            self.conv_layers.append(
                DenseSAGEConv(layer_in, layer_out, normalize).to(self.device))
            self.bn_layers.append(nn.BatchNorm1d(layer_out).to(self.device))

        if lin:
            # Width of the concatenation built in forward(): (num_layers - 1)
            # hidden-sized outputs plus the out_channels-sized last output.
            concat_channels = (num_layers - 1) * hidden_channels + out_channels
            self.lin = nn.Linear(concat_channels, out_channels).to(self.device)
        else:
            self.lin = None

    def bn(self, i, x):
        """Apply the ``i``-th batch norm to a 3-D tensor.

        ``x`` is unpacked as ``(batch_size, num_nodes, num_channels)``; the
        first two axes are flattened so that ``BatchNorm1d`` normalizes each
        channel over all nodes of all graphs, then the shape is restored.
        """
        batch_size, num_nodes, num_channels = x.size()

        # reshape (rather than view) also accepts non-contiguous inputs.
        flat = x.reshape(-1, num_channels)
        flat = self.bn_layers[i](flat)
        return flat.reshape(batch_size, num_nodes, num_channels)

    def forward(self, x, adj, mask=None):
        """Run all layers and return their concatenated (optionally projected) outputs.

        Args:
            x: Node features, a 3-D tensor (see :meth:`bn`).
            adj: Adjacency tensor, forwarded unchanged to every convolution.
            mask: Optional mask, forwarded unchanged to every convolution.

        Returns:
            The layer outputs concatenated on the last axis, or that
            concatenation passed through the linear projection and ReLU when
            the projection is enabled.
        """
        outputs = []
        hidden = x
        for i, conv in enumerate(self.conv_layers):
            hidden = self.bn(i, conv(hidden, adj, mask).relu())
            outputs.append(hidden)

        x = cat(outputs, dim=-1)

        if self.lin is not None:
            x = self.lin(x).relu()

        return x


class DiffPool(nn.Module):
    """Graph-level model with two DiffPool coarsening stages.

    Each stage runs a "pool" GNN (producing cluster assignments) and an
    "embed" GNN (producing node embeddings) and feeds both to
    ``dense_diff_pool``. A third embedding GNN runs on the coarsest graph,
    node embeddings are averaged, and two linear layers produce the logits.

    Args:
        input_size: Width of the input node features.
        hidden_size: Hidden width of every GNN layer.
        output_size: Width of the final output.
        num_layers_gnn: Number of layers in each ``DiffPoolLayer``.
        max_nodes: Maximum node count; each stage keeps ``ceil(25%)`` clusters.
        additional_losses: If true, report the auxiliary DiffPool losses.
        device: ``"gpu"``, a string containing ``"cuda"`` (an explicit
            ``"cuda:N"`` index is honoured), a ``torch.device``, or anything
            else for CPU.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 output_size,
                 num_layers_gnn,
                 max_nodes, additional_losses, device):
        super().__init__()
        device = self._resolve_device(device)

        # An embedding DiffPoolLayer (lin=False) concatenates the outputs of
        # all of its layers, so its output width is num_layers_gnn *
        # hidden_size rather than a fixed 3 * hidden_size.
        embed_width = num_layers_gnn * hidden_size

        num_nodes = ceil(0.25 * max_nodes)
        self.gnn1_pool = DiffPoolLayer(input_size, hidden_size, num_nodes, num_layers_gnn, device)
        self.gnn1_embed = DiffPoolLayer(input_size, hidden_size, hidden_size, num_layers_gnn, device, lin=False)

        num_nodes = ceil(0.25 * num_nodes)
        self.gnn2_pool = DiffPoolLayer(embed_width, hidden_size, num_nodes, num_layers_gnn, device)
        self.gnn2_embed = DiffPoolLayer(embed_width, hidden_size, hidden_size, num_layers_gnn, device, lin=False)

        self.gnn3_embed = DiffPoolLayer(embed_width, hidden_size, hidden_size, num_layers_gnn, device, lin=False)

        self.lin1 = torch.nn.Linear(embed_width, hidden_size).to(device)
        self.lin2 = torch.nn.Linear(hidden_size, output_size).to(device)
        self.additional_losses = additional_losses

    @staticmethod
    def _resolve_device(device):
        """Translate the configured device into a ``torch.device``."""
        if isinstance(device, torch.device):
            return device
        if device.lower() == "gpu":
            return torch.device("cuda:0")
        if "cuda" in device:
            # Honour an explicit index such as "cuda:1"; anything that does
            # not parse to an indexed CUDA device falls back to cuda:0.
            try:
                parsed = torch.device(device)
            except RuntimeError:
                return torch.device("cuda:0")
            if parsed.type == "cuda" and parsed.index is not None:
                return parsed
            return torch.device("cuda:0")
        return torch.device("cpu")

    def forward(self, batch):
        """Compute logits (and auxiliary losses) for a dense batch.

        Args:
            batch: Object exposing ``x``, ``adj`` and ``mask`` attributes.

        Returns:
            dict with ``"logits"``, ``"loss_l"`` and ``"loss_e"``. The two
            losses are the sums of the values returned by both
            ``dense_diff_pool`` calls when ``additional_losses`` is set,
            otherwise ``0``.
        """
        x = batch.x
        adj = batch.adj
        mask = batch.mask

        # Stage 1: the mask is only passed here, on the original graph.
        s = self.gnn1_pool(x, adj, mask)
        x = self.gnn1_embed(x, adj, mask)
        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)

        # Stage 2: operates on the pooled graph, without a mask.
        s = self.gnn2_pool(x, adj)
        x = self.gnn2_embed(x, adj)
        x, adj, l2, e2 = dense_diff_pool(x, adj, s)

        x = self.gnn3_embed(x, adj)

        # Readout: average over the node axis, then the two-layer head.
        x = x.mean(dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        if x.ndim == 3:
            assert (x.shape[-1] == 1)
            x = x.squeeze(-1)
        dictionary = {"logits": x}

        if self.additional_losses:
            dictionary["loss_l"] = l1 + l2
            dictionary["loss_e"] = e1 + e2
        else:
            dictionary["loss_l"] = 0
            dictionary["loss_e"] = 0
        return dictionary

    @staticmethod
    def kwargs(cfg, preparator):
        """Build the constructor keyword arguments from a config and a preparator.

        ``output_size`` is ``cfg.model.dim_latent`` when that is positive,
        otherwise ``preparator.label_dim()``.
        """
        if cfg.model.dim_latent > 0:
            output_size = cfg.model.dim_latent
        else:
            output_size = preparator.label_dim()

        return {
            'input_size': preparator.x_dim(),
            'hidden_size': cfg.model.dim_inner,
            'output_size': output_size,
            'num_layers_gnn': cfg.model.num_layers,
            'max_nodes': preparator.max_nodes,
            'additional_losses': cfg.model.additional_losses,
            'device': cfg.device,
        }
