import torch
import torch.nn as nn
import util as ut
import numpy as np
from torch_geometric.nn import GINConv, SAGEConv
from torch_geometric.data import Data, Batch

class Generator(nn.Module):
    """Generator network.

    An MLP (Linear -> Tanh -> Dropout for every hidden width) maps latent
    vectors to per-vertex node logits of shape
    ``(-1, vertexes, nodes_embedding)``.
    """
    def __init__(self, conv_dims, z_dim, vertexes, edges, nodes_embedding, dropout):
        """
        Args:
            conv_dims: hidden layer widths of the MLP, in order (any sequence).
            z_dim: width of the input latent vector.
            vertexes: number of vertices per generated graph.
            edges: number of edge types (stored on the module; not used by it).
            nodes_embedding: width of each vertex's output vector.
            dropout: dropout probability used after every hidden layer and on
                the final logits.
        """
        super(Generator, self).__init__()

        self.vertexes = vertexes  # number of vertices
        self.edges = edges  # number of edge types
        self.nodes = nodes_embedding  # per-vertex output width

        dims = list(conv_dims)  # allow tuples as well as lists
        layers = []
        for c0, c1 in zip([z_dim] + dims[:-1], dims):
            layers.append(nn.Linear(c0, c1))
            layers.append(nn.Tanh())
            # Not in-place: Tanh saves its output for the backward pass, so an
            # in-place dropout on that tensor makes autograd raise during
            # backward() whenever dropout > 0 in training mode.
            layers.append(nn.Dropout(p=dropout))
        self.layers = nn.Sequential(*layers)

        self.nodes_layer = nn.Linear(dims[-1], vertexes * nodes_embedding)
        # Attribute name (including its spelling) kept so existing references resolve.
        self.dropoout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map latent vectors ``x`` (last dim ``z_dim``) to node logits.

        Returns:
            Tensor reshaped to ``(-1, vertexes, nodes_embedding)``, with
            dropout applied.
        """
        hidden = self.layers(x)
        nodes_logits = self.nodes_layer(hidden)
        nodes_logits = nodes_logits.view(-1, self.vertexes, self.nodes)
        return self.dropoout(nodes_logits)


class Discriminator(nn.Module):
    """Discriminator network.

    Applies ``num_layers`` graph-convolution layers (GIN or GraphSAGE) to a
    batch of graphs, then maps each node's final embedding to one score with
    a linear layer. No graph-level pooling is done here, so scores are per node.
    """
    def __init__(self, conv_dim, n_node_features, num_strings, stremb_dim, net_type, device, dropout, h_dim=64):
        """
        Args:
            conv_dim: 3-tuple ``(graph_conv_dim, aux_dim, linear_dim)``. It is
                unpacked for interface compatibility, but no layer is sized
                from it; the output layer's input width is the width of the
                last graph-conv layer.
            n_node_features: input feature width of the first conv layer.
            num_strings: number of rows in the string embedding table.
            stremb_dim: width of each string embedding.
            net_type: ``'GIN'`` or ``'GraphSAGE'``.
            device: device that batches are moved to.
            dropout: accepted for interface compatibility; currently unused.
            h_dim: hidden width of the graph-conv layers.

        Raises:
            ValueError: if ``net_type`` is not a supported GNN type.
        """
        super(Discriminator, self).__init__()
        # Unpacking still enforces the 3-tuple shape callers already pass.
        _graph_conv_dim, _aux_dim, _linear_dim = conv_dim
        self.num_strings, self.stremb_dim = num_strings, stremb_dim
        self.str_embedding_layer = torch.nn.Embedding(self.num_strings, self.stremb_dim)
        self.layers = nn.ModuleList()
        self.num_layers = 1
        self.device = device
        d1, d2 = n_node_features, h_dim
        # graph conv layers
        for _ in range(self.num_layers):
            if net_type == 'GIN':
                layer = GINConv(nn.Sequential(nn.Linear(d1, d2), nn.ReLU(), nn.Linear(d2, d2)))
            elif net_type == 'GraphSAGE':
                layer = SAGEConv(d1, d2)
            else:
                raise ValueError(f'No such GNN type: {net_type}')
            self.layers.append(layer)
            d1, d2 = h_dim, h_dim

        # After the loop, d1 is the width of the conv stack's output, which is
        # exactly what forward() feeds into this layer. Sizing it from
        # linear_dim[-1] would break forward() unless linear_dim[-1] == h_dim.
        self.output_layer = nn.Linear(d1, 1)

    def forward(self, activatation=None, needs_embedding=False, **kwargs):
        """Score every node of a batch of graphs.

        Args:
            activatation: optional callable applied to the raw scores
                (parameter name kept as-is for caller compatibility).
            needs_embedding: if True, read raw graphs from
                ``kwargs['batch_data']`` (format as in ``convert_to_batch``).
                Node features are then replaced by
                ``ut.batched_target_att_to_vec(x, self.str_embedding_layer)``.
                Otherwise, node features come from ``kwargs['data']`` and
                edge lists from ``kwargs['e_list']``.

        Returns:
            ``(output, edge_list)``: the per-node scores (last dim 1) and the
            per-graph edge lists that were used to build the batch.
        """
        if needs_embedding:
            bt, edge_list = self.convert_to_batch(kwargs['batch_data'])
            bt.x = ut.batched_target_att_to_vec(bt.x.to(self.device), self.str_embedding_layer)
        else:
            edge_list, data = kwargs['e_list'], kwargs['data']
            bt = self._make_batch(data, edge_list)

        x, edge_index = bt.x, bt.edge_index
        for layer in self.layers:
            x = layer(x, edge_index)
        output = self.output_layer(x)
        if activatation is not None:
            output = activatation(output)
        return output, edge_list

    def _make_batch(self, node_feats, edge_lists):
        """Pack per-graph node features and edge tensors into a Batch on ``self.device``."""
        # Edge tensors are transposed into edge_index; presumably each is (E, 2)
        # as produced by convert_to_batch -- verify for the e_list path.
        data_list = [
            Data(x=torch.squeeze(feats).float(), edge_index=edges.t().contiguous())
            for feats, edges in zip(node_feats, edge_lists)
        ]
        return Batch.from_data_list(data_list).to(self.device)

    @staticmethod
    def _edges_to_tensor(edge_list):
        """Return the first two columns of ``edge_list`` as a tensor.

        An empty edge list yields an empty ``(0, 2)`` long tensor.
        """
        arr = np.array(edge_list)
        if arr.size == 0:
            # np.array([]) is 1-D, so the 2-D slice below would raise IndexError.
            return torch.empty((0, 2), dtype=torch.long)
        return torch.from_numpy(arr[:, 0:2])

    def convert_to_batch(self, batch_data_list):
        """Convert raw graphs into a Batch.

        Args:
            batch_data_list: iterable of items whose element 0 is a numpy
                array of node features and whose element 1 is an edge list.
                Only the first two columns of each edge row are kept.

        Returns:
            ``(batch, e_list)``: the Batch on ``self.device`` and the list of
            per-graph edge tensors. Graphs may have different node counts.
        """
        n_list, e_list = [], []
        for d_list in batch_data_list:
            data, edge_list = d_list[0], d_list[1]
            e_list.append(self._edges_to_tensor(edge_list))
            n_list.append(torch.from_numpy(data))
        return self._make_batch(n_list, e_list), e_list
