import itertools
import math

import numpy as np
import torch
from torch_geometric.nn import GINEConv, GCNConv, GATv2Conv, GATConv, RGCNConv, RGATConv, FastRGCNConv, SAGEConv
import torch.nn.functional as F
from torch.nn import ModuleList


def setup_conv_layers(gnn_arch, num_layers, fl_input_size, nl_input_size, ef_dim, lf_dim, hyperparams, gds_metadata):
    """Build the stack of graph convolution layers for the chosen architecture.

    The first layer maps ``fl_input_size`` features to ``lf_dim``; each of the
    remaining ``num_layers - 1`` layers maps ``nl_input_size`` to ``lf_dim``.
    A first layer is always created, even when ``num_layers`` is below 1.

    Args:
        gnn_arch: One of 'gine', 'gcn', 'sage', 'gat', 'rgcn' or 'rgat'.
        num_layers: Total number of convolution layers requested.
        fl_input_size: Input feature size of the first layer.
        nl_input_size: Input feature size of every layer after the first.
        ef_dim: Edge feature size, forwarded as ``edge_dim`` to the 'gine',
            'gat' and 'rgat' layers; ignored by the other architectures.
        lf_dim: Output feature size of every layer.
        hyperparams: Only consulted (via ``find_num_relations``) for the
            relational architectures 'rgcn' and 'rgat'.
        gds_metadata: Only consulted (via ``find_num_relations``) for the
            relational architectures 'rgcn' and 'rgat'.

    Returns:
        A ``ModuleList`` holding the layers in order, first layer at index 0.

    Raises:
        ValueError: If ``gnn_arch`` is not one of the known architectures.
    """
    # Input size seen by each layer, in stacking order. A non-positive repeat
    # count yields an empty tail, so the first layer is always present.
    input_sizes = [fl_input_size] + [nl_input_size] * (num_layers - 1)

    if gnn_arch == 'gine':
        ## todo: still need to adjust to allow use of another activation (e.g. selu)
        layers = [
            GINEConv(
                torch.nn.Sequential(torch.nn.Linear(in_size, lf_dim), torch.nn.ReLU()),
                edge_dim=ef_dim,
                train_eps=True,
            )
            for in_size in input_sizes
        ]

    elif gnn_arch == 'gcn':
        layers = [GCNConv(in_size, lf_dim) for in_size in input_sizes]

    elif gnn_arch == 'sage':
        layers = [SAGEConv(in_size, lf_dim) for in_size in input_sizes]

    elif gnn_arch == 'gat':
        # TODO: consider switching to GATv2; needs pytorch-geometric>=2.0.3.
        # Every layer uses a single attention head.
        layers = [GATConv(in_size, lf_dim, heads=1, edge_dim=ef_dim) for in_size in input_sizes]

    elif gnn_arch == 'rgcn':
        relation_count = find_num_relations(hyperparams, gds_metadata)
        # FastRGCNConv is the layer in use; RGCNConv was the previously
        # considered alternative with the same constructor arguments.
        layers = [
            FastRGCNConv(in_size, lf_dim, relation_count, root_weight=False)
            for in_size in input_sizes
        ]

    elif gnn_arch == 'rgat':
        relation_count = find_num_relations(hyperparams, gds_metadata)
        layers = [
            RGATConv(in_size, lf_dim, relation_count, edge_dim=ef_dim)
            for in_size in input_sizes
        ]

    else:
        raise ValueError(f"arch {gnn_arch} not known.")

    return ModuleList(layers)

def find_num_relations(hyperparams, gds_metadata):
    """Return the number of edge relations implied by the edge-coloring scheme.

    Args:
        hyperparams: Mapping that must contain the key ``'edge_coloring'``.
        gds_metadata: Mapping that must contain ``'max_num_edges'`` when the
            coloring is ``'unique_edge'``.

    Returns:
        ``1`` for ``'uniform'`` coloring (all edges share one relation), or
        ``gds_metadata['max_num_edges']`` for ``'unique_edge'`` coloring.

    Raises:
        ValueError: If the edge coloring is not a recognised scheme. Without
            this the function would silently return ``None``, which only
            fails later inside the relational conv layer constructors.
    """
    edge_coloring = hyperparams['edge_coloring']
    if edge_coloring == 'uniform':
        return 1
    if edge_coloring == 'unique_edge':
        return gds_metadata['max_num_edges']
    raise ValueError(f"edge coloring {edge_coloring} not known.")


def prefix_sum(tensor):
    """Return the cumulative sum of ``tensor`` along dim 0 with a leading zero.

    The result has one more entry along dim 0 than the input:
    ``result[0]`` is zero and ``result[k]`` is the sum of ``tensor[:k]``.
    It is allocated from ``tensor`` via ``new_empty``, so it shares the
    input's dtype and device.
    """
    length = tensor.size(0)
    result = tensor.new_empty((length + 1, *tensor.shape[1:]))
    # Zero the leading slot, then let cumsum write directly into the remainder.
    result[:1].zero_()
    torch.cumsum(tensor, dim=0, out=result[1:])
    return result


def remove_zero_padding(x, edge_index, num_graphs, batch_num_nodes, batch_num_edges, max_num_nodes):
    """Drop padded rows from ``x`` and shift ``edge_index`` to match.

    For graph ``g`` the rows ``g * max_num_nodes`` up to (but excluding)
    ``g * max_num_nodes + batch_num_nodes[g]`` of ``x`` are kept; everything
    else is discarded. Presumably each graph occupies a block of
    ``max_num_nodes`` rows in ``x`` with real nodes first -- confirm
    against the caller.

    Returns:
        A tuple ``(x, true_edge_index)`` of the filtered node tensor and the
        edge index produced by ``unpad_edge_index``.
    """
    kept_rows = []
    for graph_idx in range(num_graphs):
        block_start = graph_idx * max_num_nodes
        real_nodes = batch_num_nodes[graph_idx].item()
        kept_rows.extend(range(block_start, block_start + real_nodes))

    x = x[torch.LongTensor(kept_rows)]
    return x, unpad_edge_index(edge_index, batch_num_nodes, batch_num_edges, max_num_nodes)


def unpad_edge_index(edge_index, batch_num_nodes, batch_num_edges, max_num_nodes):
    """Return a copy of ``edge_index`` with padding offsets removed per graph.

    For graph ``g`` the columns ``[edge_start[g], edge_start[g + 1])`` of the
    copy are decreased by the total number of padding nodes
    (``max_num_nodes - batch_num_nodes``) of all graphs before ``g``. The
    input tensor itself is not modified. Assumes edge columns are grouped by
    graph in batch order -- TODO confirm against the caller.
    """
    shifted = edge_index.clone()

    # Cumulative count of padding nodes preceding each graph.
    padding_before = prefix_sum(max_num_nodes - batch_num_nodes)

    # Column boundaries of each graph's edges inside edge_index.
    edge_bounds = prefix_sum(batch_num_edges)

    graph_spans = zip(edge_bounds[:-1], edge_bounds[1:])
    for graph_idx, (col_start, col_end) in enumerate(graph_spans):
        shifted[:, col_start:col_end] -= padding_before[graph_idx]

    return shifted