import itertools
import math
import os

import numpy as np
import torch
from torch_geometric.nn import GINEConv, GCNConv, GATv2Conv, GATConv
import torch.nn.functional as F
from torch.nn import ModuleList

from relnet.agent.gnn.gnn_common import prefix_sum, remove_zero_padding, unpad_edge_index, setup_conv_layers


class DemandRegressorGNN(torch.nn.Module):
    """Graph-level regressor: message passing, per-graph pooling, linear read-out.

    Node embeddings are produced by a stack of conv layers built by
    ``setup_conv_layers``. They are then split per graph, pooled according to
    ``hyperparams['subgraph_agg']``, and mapped by ``out_layer`` (a
    ``Linear(lf_dim, 1)``) to one output per graph.

    Args:
        nf_dim: input node-feature size; used as the first conv layer's input size.
        ef_dim: edge-feature size, forwarded to ``setup_conv_layers``.
        demand_rep: stored as ``self.demand_rep``; not read within this class.
        gnn_arch: architecture key. ``apply_conv`` handles 'gcn', 'sage', 'gat',
            'rgcn' and 'rgat'.
        hyperparams: dict. Read keys: 'lf_dim', 'num_layers', 'activation_fn'
            ('relu' or 'selu') and, at prediction time, 'subgraph_agg'
            ('sum' or 'max').
        gds_metadata: dict that must contain 'max_num_nodes' and
            'max_num_edges'; also forwarded to ``setup_conv_layers``.

    Raises:
        ValueError: unknown ``activation_fn`` (at construction), unknown
            ``subgraph_agg`` (at prediction), or unsupported ``gnn_arch``
            (during message passing).
    """

    # Supported values of hyperparams['activation_fn'].
    _ACTIVATIONS = {"relu": F.relu, "selu": F.selu}

    # Supported values of hyperparams['subgraph_agg']: reduce over dim 0 (the
    # node axis of one graph's embeddings). torch.max(..., dim=0) returns
    # (values, indices); only the values are kept.
    _SUBGRAPH_AGGS = {
        "sum": lambda g_embed: torch.sum(g_embed, dim=0),
        "max": lambda g_embed: torch.max(g_embed, dim=0).values,
    }

    def __init__(self, nf_dim, ef_dim, demand_rep, gnn_arch, hyperparams, gds_metadata):
        super().__init__()
        self.nf_dim = nf_dim
        self.ef_dim = ef_dim

        self.hyperparams = hyperparams
        self.demand_rep = demand_rep
        self.gnn_arch = gnn_arch

        self.gds_metadata = gds_metadata
        # Not read within this class; kept because other code may access them.
        self.max_num_nodes = gds_metadata['max_num_nodes']
        self.max_num_edges = gds_metadata['max_num_edges']

        self.lf_dim = self.hyperparams['lf_dim']
        self.num_layers = self.hyperparams['num_layers']
        # Not read within this class; retained for external readers.
        self.in_offset = 0

        activation_name = self.hyperparams['activation_fn']
        if activation_name not in self._ACTIVATIONS:
            raise ValueError(f"don't know about activation function {self.hyperparams['activation_fn']}")
        self.activation_fn = self._ACTIVATIONS[activation_name]

        fl_input_size = self.nf_dim

        # NOTE(review): argument order presumably is (input size, hidden size,
        # edge-feature size, output size) -- confirm against gnn_common.
        self.convs = setup_conv_layers(self.gnn_arch, self.num_layers,
                                       fl_input_size,
                                       self.lf_dim,
                                       self.ef_dim,
                                       self.lf_dim,
                                       self.hyperparams,
                                       self.gds_metadata
                                       )

        self.out_layer = torch.nn.Linear(self.lf_dim, 1)

        # Not read within this class.
        self.print_progress = False

    def forward(self, data):
        """Run message passing on ``data`` and return one prediction per graph.

        Reads ``data.x``, ``data.edge_index``, ``data.edge_attr``,
        ``data.edge_type`` and ``data.state_num_nodes``. Returns the list
        produced by ``emb_to_preds``.
        """
        x, edge_index, edge_attr, edge_type = data.x, data.edge_index, data.edge_attr, data.edge_type
        x = self.propagate_messages(x, edge_index, edge_attr, edge_type)
        return self.emb_to_preds(x, data)

    def propagate_messages(self, x, edge_index, edge_attr, edge_type):
        """Apply every conv layer in order.

        ``self.activation_fn`` is applied between layers but not after the last
        one, so the final layer's output is returned un-activated.
        """
        last_index = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = self.apply_conv(conv, x, edge_index, edge_type, edge_attr)
            if i < last_index:
                x = self.activation_fn(x)
        return x

    def apply_conv(self, conv, x, mp_edge_index, mp_edge_type, mp_edge_attr):
        """Call one conv layer with the keyword arguments its architecture takes.

        'gcn'/'sage' get only the edge index. 'gat' also gets edge attributes.
        'rgcn' gets edge types. 'rgat' gets both edge attributes and edge types.

        Raises:
            ValueError: if ``self.gnn_arch`` is none of the above.
        """
        if self.gnn_arch in ['gcn', 'sage']:
            return conv(x=x, edge_index=mp_edge_index)
        if self.gnn_arch == 'gat':
            return conv(x=x, edge_index=mp_edge_index, edge_attr=mp_edge_attr)
        if self.gnn_arch == 'rgcn':
            return conv(x=x, edge_index=mp_edge_index, edge_type=mp_edge_type)
        if self.gnn_arch == 'rgat':
            return conv(x=x, edge_index=mp_edge_index, edge_attr=mp_edge_attr, edge_type=mp_edge_type)
        raise ValueError(f"arch {self.gnn_arch} not compatible with demand relational graph.")

    def emb_to_preds(self, x, data):
        """Pool node embeddings per graph and map each pooled vector through ``out_layer``.

        Args:
            x: node embeddings for the whole batch, with graphs stored
                contiguously along dim 0. Assumed (num_nodes, lf_dim); TODO confirm.
            data: object whose ``state_num_nodes`` gives the node count of each
                graph, in order. It may be a tensor/ndarray (anything with
                ``.tolist()``) or a plain sequence of ints.

        Returns:
            A list with one ``out_layer`` output per graph. Under the shape
            assumption above, each output has shape (1,).

        Raises:
            ValueError: if ``hyperparams['subgraph_agg']`` is not 'sum' or 'max'.
                This is checked before any work, so it is also raised for an
                empty batch.
        """
        # Resolve the aggregator once instead of re-checking it for every graph.
        agg_name = self.hyperparams['subgraph_agg']
        try:
            aggregate = self._SUBGRAPH_AGGS[agg_name]
        except KeyError:
            raise ValueError(f"unknown subgraph agg method {self.hyperparams['subgraph_agg']}") from None

        batch_num_nodes = data.state_num_nodes
        # torch.split needs Python ints; accept tensors/arrays or plain sequences.
        if hasattr(batch_num_nodes, "tolist"):
            split_sizes = batch_num_nodes.tolist()
        else:
            split_sizes = list(batch_num_nodes)
        split_embeds = torch.split(x, split_sizes)

        # NOTE: with 'max', a graph with zero nodes makes torch.max raise
        # (a reduction over an empty dimension); 'sum' yields zeros instead.
        return [self.out_layer(aggregate(g_embed)) for g_embed in split_embeds]
