import logging
import re

import torch
from torch import nn
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from utils.data_utils import batched_index_select

class SAGEConv(pyg_nn.MessagePassing):
    """GraphSAGE-style message-passing layer using ``'add'`` aggregation.

    Every neighbour feature vector is passed through a linear layer, scaled
    by its edge weight, and combined per node by the ``'add'`` aggregation.
    The aggregated message is then concatenated with the node's own input
    features and projected to ``out_channels`` by a second linear layer.
    """

    def __init__(self, in_channels, out_channels):
        """Create the message and update projections.

        Args:
            in_channels: Size of each input node feature vector.
            out_channels: Size of each output node feature vector.
        """
        super().__init__(aggr='add')  # "Add" aggregation.
        self.lin = nn.Linear(in_channels, out_channels)
        self.lin_update = nn.Linear(out_channels + in_channels, out_channels)

    def forward(self, x, edge_index, edge_weights=None):
        """Run one round of message passing.

        Args:
            x: Node feature matrix of shape ``[N, in_channels]``.
            edge_index: Edge list of shape ``[2, E]``.
            edge_weights: Optional per-edge weights, shape ``[E]`` or
                ``[E, 1]``. Defaults to a weight of 1 for every edge.

        Returns:
            Updated node features of shape ``[N, out_channels]``.
        """
        # Hand the weights to remove_self_loops as edge attributes so that the
        # rows belonging to dropped self-loop edges are dropped with them;
        # otherwise the weights would no longer line up with edge_index.
        edge_index, edge_weights = pyg_utils.remove_self_loops(
            edge_index, edge_weights)

        if edge_weights is None:
            # Allocate on the same device / dtype as the features so the
            # multiplication in message() does not fail for GPU inputs.
            edge_weights = torch.ones(
                (edge_index.size(1), 1), dtype=x.dtype, device=x.device)
        elif edge_weights.dim() == 1:
            # Make [E] weights broadcast over the channel dimension.
            edge_weights = edge_weights.unsqueeze(-1)

        return self.propagate(
            edge_index,
            size=(x.size(0), x.size(0)),
            x=x,
            edge_weights=edge_weights,
        )

    def message(self, x_i, x_j, edge_index, size, edge_weights):
        """Compute the weighted, linearly transformed message for each edge.

        ``x_j`` holds the neighbour features per edge (``[E, in_channels]``);
        the result has shape ``[E, out_channels]``. ``x_i``, ``edge_index``
        and ``size`` are accepted but not used.
        """
        return torch.mul(edge_weights, self.lin(x_j))

    def update(self, aggr_out, x):
        """Combine the aggregated messages with the node's own features.

        Args:
            aggr_out: Aggregated messages, shape ``[N, out_channels]``.
            x: Original node features, shape ``[N, in_channels]``.

        Returns:
            Tensor of shape ``[N, out_channels]``.
        """
        combined = torch.cat((aggr_out, x), dim=1)
        return self.lin_update(combined)

class SiameseGNN(nn.Module):
    """Two-tower GNN that scores (search node, query node) pairs.

    A "search" graph and a "query" graph are each encoded by their own stack
    of ``SAGEConv`` layers. The embeddings of the selected center nodes are
    concatenated (search first, then query) and fed through a small MLP that
    emits a single logit per pair.
    """

    def __init__(self, search_layers: list, query_layers: list, devices=(0,)):
        """Build both encoder towers and the classification head.

        Args:
            search_layers: One spec per search-tower layer; each spec is
                indexed with the keys ``'in_dim'`` and ``'out_dim'``.
            query_layers: Same as ``search_layers``, for the query tower.
            devices: Not used by this class; kept so existing callers that
                pass it keep working.
        """
        super().__init__()

        self.search_network = self.build_network(search_layers)
        self.query_network = self.build_network(query_layers)

        # The head consumes the concatenation of one search embedding and one
        # query embedding, hence the sum of both output widths.
        in_features = query_layers[-1]["out_dim"] + search_layers[-1]["out_dim"]
        self.final_hidden = nn.Linear(in_features=in_features,
                                      out_features=256)
        self.final_class = nn.Linear(in_features=256,
                                     out_features=1)

    def build_network(self, layers):
        """Return a ``ModuleList`` with one ``SAGEConv`` per entry of ``layers``.

        Args:
            layers: Iterable of specs providing ``'in_dim'`` and ``'out_dim'``.
        """
        return nn.ModuleList(
            SAGEConv(in_channels=layer['in_dim'], out_channels=layer['out_dim'])
            for layer in layers
        )

    @staticmethod
    def _encode(network, x, edge_index):
        """Apply every layer of ``network``, with ReLU between (not after) layers."""
        last = len(network) - 1
        for idx, layer in enumerate(network):
            x = layer(x, edge_index)
            if idx != last:
                x = nn.functional.relu(x)
            # TODO: add dropout
        return x

    def forward(self, batch):
        """Encode both graphs and classify the center-node pairs.

        Args:
            batch: Mapping with the keys ``"search_graph"`` and
                ``"query_graph"``; each value exposes the attributes ``feat``,
                ``edge_index`` and ``center_index``.

        Returns:
            dict with
                ``"out"``: logits from the classification head,
                ``"search_emb"``: embeddings of all search-graph nodes,
                ``"query_emb"``: embeddings of all query-graph nodes.
        """
        search_graph = batch["search_graph"]
        query_graph = batch["query_graph"]

        search_x = self._encode(
            self.search_network, search_graph.feat, search_graph.edge_index)
        query_x = self._encode(
            self.query_network, query_graph.feat, query_graph.edge_index)

        # Pick the center-node embeddings and flatten both selections to
        # [num_selected, emb_dim] so they can be concatenated row by row.
        # NOTE(review): presumably center_index is 1-D (one center per graph),
        # in which case the reshape is a no-op -- confirm against the loader.
        search_targets = search_x[search_graph.center_index, :]
        search_targets = search_targets.reshape(-1, search_targets.shape[-1])
        query_targets = query_x[query_graph.center_index, :]
        query_targets = query_targets.reshape(-1, query_targets.shape[-1])

        out = self.final_hidden(
            torch.cat((search_targets, query_targets), dim=-1))
        out = nn.functional.relu(out)
        out = self.final_class(out)

        return {
            "out": out,
            "search_emb": search_x,
            "query_emb": query_x,
        }

    def inference(self, search_embs, query_embs):
        """Score every (query node, search node) pair; assumes 1 query, 1 search graph.

        Args:
            search_embs: Tensor of shape ``[num_search_nodes, emb_dim]``.
            query_embs: Tensor of shape ``[num_query_nodes, emb_dim]``.

        Returns:
            Logits of shape ``[num_query_nodes, num_search_nodes, 1]``; entry
            ``[q, s]`` is the score for query node ``q`` and search node ``s``.
        """
        num_search_nodes = search_embs.size(0)
        num_query_nodes = query_embs.size(0)

        # [1, S, D] -> [Q, S, D]: every query row sees all search nodes.
        search_embs = torch.unsqueeze(search_embs, dim=0)
        search_embs = search_embs.expand(num_query_nodes, -1, -1)
        # [Q, 1, D] -> [Q, S, D]: each query embedding is tiled across the
        # search nodes (not the query nodes) so both tensors line up.
        query_embs = torch.unsqueeze(query_embs, dim=1)
        query_embs = query_embs.expand(-1, num_search_nodes, -1)

        combined = torch.cat([search_embs, query_embs], dim=-1)

        # MLP approach
        out = self.final_hidden(combined)
        out = nn.functional.relu(out)
        out = self.final_class(out)
        return out

    def loss(self, outputs, targets):
        """Mean binary cross-entropy between logits and 0/1 targets.

        Args:
            outputs: Logits; dimension 1 is the one ``targets`` is expanded over.
            targets: Labels with one entry per row of ``outputs``.

        Returns:
            Scalar loss tensor.
        """
        targets = targets.unsqueeze(1).expand_as(outputs).float()
        # The default reduction of the functional API is already the mean.
        return nn.functional.binary_cross_entropy_with_logits(outputs, targets)

    def predict(self, batch):
        """Return the sigmoid of the ``"out"`` logits produced by ``forward``."""
        out = self.forward(batch)
        return torch.sigmoid(out["out"])

    # Disabled metrics factory, kept as an inert string literal. It refers to
    # `ignite` and `AUROC`, neither of which is imported in this file.
    '''
    def init_metrics(self):
        def output_to_class(output, threshold=0.5, activation=torch.sigmoid):
            outputs, targets = output
            probs = activation(outputs)
            preds = (probs > threshold).int()
            targets = targets.unsqueeze(1).unsqueeze(1).expand(preds.shape)
            preds, targets = preds.reshape(-1), targets.reshape(-1)  # flatten so all nodes
            return (preds.contiguous(), targets.contiguous()) 

        def output_to_probs(output, activation=torch.sigmoid):
            outputs, targets = output
            probs = activation(outputs)
            targets = targets.unsqueeze(1).unsqueeze(1).expand(probs.shape) 
            probs, targets = probs.reshape(-1), targets.reshape(-1)  # flatten so all nodes
            return (probs, targets)

        metrics = {
            "accuracy": ignite.metrics.Accuracy(output_transform=output_to_class),
            "loss": ignite.metrics.Loss(self.loss),
            "precision": ignite.metrics.Precision(output_transform=output_to_class),
            "recall": ignite.metrics.Recall(output_transform=output_to_class),
            "auroc": AUROC(output_transform=output_to_probs)
        }
        return metrics
    '''
   
