import re
import numpy as np
from tqdm import trange
from multiprocessing import Pool
from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.*')

import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj, to_dense_batch, dense_to_sparse
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max, scatter_std
torch.multiprocessing.set_sharing_strategy('file_system')



## graph-level feature: sorted node-degree vector computed from the adjacency
def extract_graph_feature(x, adj, node_mask=None):
    """Return a sorted per-graph degree vector as a graph-level feature.

    For every graph, the adjacency rows of the selected nodes are summed
    (giving column sums, i.e. node degrees when ``adj`` is symmetric) and the
    resulting vector is sorted in ascending order.

    Args:
        x: node features; only ``x.size(0)`` is used, as the number of
            output rows.
        adj: ``(B, N, N)`` dense adjacency.
        node_mask: optional ``(B, N)`` mask of rows to include. Any dtype is
            accepted and interpreted as boolean. Defaults to the nodes whose
            adjacency row sum is positive.

    Returns:
        ``(x.size(0), N)`` tensor with ``adj``'s dtype, sorted along the last
        dimension.
    """
    if node_mask is None:
        node_mask = adj.sum(dim=-1) > 0
    # A non-bool mask (e.g. a 0/1 long tensor) would silently switch
    # ``adj[node_mask]`` to integer gather indexing.
    node_mask = node_mask.bool()
    batch_index = node_mask.nonzero()[:, 0]  # graph id of every selected row
    selected_rows = adj[node_mask]  # (num_selected, N)
    graph_rep = torch.zeros(
        (x.size(0), adj.size(-1)), dtype=adj.dtype, device=adj.device
    )
    graph_rep.index_add_(0, batch_index, selected_rows)  # degree distribution
    return torch.sort(graph_rep, dim=-1)[0]




# -------- convert sparse input to dense input --------
def convert_sparse_to_dense(batch_index, node_feature, edge_index, edge_attr, augment_mask=None, return_node_mask=True):
    """Densify a PyG-style batched sparse graph into padded dense tensors.

    Args:
        batch_index: graph id of every node (PyG ``batch`` vector).
        node_feature: unused; the dense node features are constant ones.
        edge_index: ``(2, E)`` edge list over the whole batch.
        edge_attr: unused; it is not passed to ``to_dense_adj``.
        augment_mask: optional index/mask applied to the batch dimension of
            the dense adjacency.
        return_node_mask: if True, also return the node mask.

    Returns:
        ``(dense_x, None, dense_adj, None)`` or, when ``return_node_mask`` is
        True, ``(dense_x, None, dense_adj, None, node_mask)``. ``dense_x`` is a
        ``(B, N, 1)`` float32 tensor of ones, ``dense_adj`` is ``(B, N, N)``
        float32 and ``node_mask`` is a ``(B, N)`` bool tensor marking nodes
        with a positive adjacency row sum. The ``None`` slots stand for the
        "disable" node features / adjacency.
    """
    # Largest per-graph node count. ``max_num_nodes`` is typed Optional[int],
    # so pass a Python int rather than a 0-d tensor.
    max_count = int(torch.unique(batch_index, return_counts=True)[1].max())
    dense_adj = to_dense_adj(edge_index, batch=batch_index, max_num_nodes=max_count).to(torch.float32)
    if augment_mask is not None:
        dense_adj = dense_adj[augment_mask]
    dense_x = torch.ones(
        (dense_adj.size(0), dense_adj.size(1), 1),
        dtype=torch.float32,
        device=dense_adj.device,
    )
    if not return_node_mask:
        return dense_x, None, dense_adj, None
    node_mask = dense_adj.sum(dim=-1) > 0
    return dense_x, None, dense_adj, None, node_mask


# -------- convert dense input back to sparse PyG graphs --------
def standardize_adj(adj):
    """Symmetrise a dense adjacency and clear its diagonal.

    Args:
        adj: ``(..., N, N)`` adjacency; works for both a single ``(N, N)``
            matrix and a batched ``(B, N, N)`` tensor.

    Returns:
        A new tensor ``(adj + adj^T) / 2`` with zeros on the diagonal. The
        input tensor is not modified.
    """
    num_nodes = adj.size(-1)
    symmetric = (adj + adj.transpose(-1, -2)) / 2  # fresh tensor, safe to edit in place
    # An (N, N) mask broadcasts over any leading batch dims; a (1, N, N) mask
    # would not fit an unbatched (N, N) input for the in-place fill.
    diagonal = torch.eye(num_nodes, dtype=torch.bool, device=adj.device)
    return symmetric.masked_fill_(diagonal, 0)
def convert_dense_to_rawpyg(dense_x, dense_adj, augmented_labels, n_jobs=20, reverse_label=False):
    """Convert a batch of dense (possibly soft) adjacencies into PyG graphs.

    The adjacencies are symmetrised, their diagonals cleared and binarised at
    0.5. The batch is then split into ``n_jobs`` chunks, and each chunk is
    converted by ``get_pyg_data_from_dense_batch``, in a process pool when
    ``n_jobs > 1``.

    Args:
        dense_x: node features, presumably ``(B, N, F)``; they are only split
            alongside ``dense_adj`` and moved to CPU.
        dense_adj: ``(B, N, N)`` adjacency scores.
        augmented_labels: per-graph labels indexed along dim 0. Any tensor or
            array-like accepted by ``torch.as_tensor`` works.
        n_jobs: number of chunks / worker processes; ``<= 1`` runs in-process.
        reverse_label: if True, each stored label is ``1 - label``.

    Returns:
        List of ``Data`` objects in input order. Graphs with no edges after
        thresholding are dropped, so the list can be shorter than ``B``.
    """
    augmented_labels = torch.as_tensor(augmented_labels).cpu()
    dense_x = dense_x.cpu()
    # Symmetrise, zero the diagonal, then binarise at 0.5.
    dense_adj = standardize_adj(dense_adj.cpu())
    dense_adj = torch.where(dense_adj < 0.5, 0, 1)

    num_graphs = dense_x.shape[0]
    num_chunks = max(1, n_jobs)
    augment_trace = torch.arange(num_graphs)
    tasks = [
        (x_chunk, adj_chunk, label_chunk, trace_chunk, reverse_label)
        for x_chunk, adj_chunk, label_chunk, trace_chunk in zip(
            torch.tensor_split(dense_x, num_chunks),
            torch.tensor_split(dense_adj, num_chunks),
            torch.tensor_split(augmented_labels, num_chunks),
            torch.tensor_split(augment_trace, num_chunks),
        )
    ]
    if n_jobs <= 1:
        results = [get_pyg_data_from_dense_batch(task) for task in tasks]
    else:
        # No point in spawning more worker processes than there are graphs.
        with Pool(min(n_jobs, max(1, num_graphs))) as pool:
            results = pool.map(get_pyg_data_from_dense_batch, tasks)  # order-preserving
    return [graph for chunk_graphs in results for graph in chunk_graphs]

def get_pyg_data_from_dense_batch(params):
    """Build PyG ``Data`` objects from one chunk of binarised dense adjacencies.

    Takes a single tuple argument so it can be used directly with
    ``Pool.map``.

    Args:
        params: ``(batched_x, batched_adj, augmented_labels, batch_traces,
            reverse_label)``. ``batched_x`` only bounds the iteration (zipped
            with ``batched_adj``) and ``batch_traces`` is not used.
            ``augmented_labels[i]`` is the label of ``batched_adj[i]``.

    Returns:
        List of ``Data`` objects with constant long node/edge features, one
        per input graph that still has an edge after isolated nodes are
        removed. Edgeless graphs are skipped.
    """
    batched_x, batched_adj, augmented_labels, batch_traces, reverse_label = params
    pyg_graph_list = []
    for b_index, (_x_single, adj_single) in enumerate(zip(batched_x, batched_adj)):
        # Keep only nodes with at least one outgoing edge (drops padding).
        node_mask_single = adj_single.sum(dim=-1) > 0
        adj_single = adj_single[node_mask_single][:, node_mask_single]
        edge_index = dense_to_sparse(adj_single)[0]
        if edge_index.numel() == 0:
            continue
        # Python int: PyG types ``Data.num_nodes`` as Optional[int].
        num_nodes = int(edge_index.max()) + 1
        label = augmented_labels[b_index]
        if reverse_label:
            label = 1 - label
        g = Data()
        g.__num_nodes__ = num_nodes  # ogb < 1.3.4
        g.num_nodes = num_nodes  # ogb > 1.3.4
        g.edge_index = edge_index
        g.edge_attr = torch.ones((edge_index.size(1), 1), dtype=torch.long)
        g.x = torch.ones((num_nodes, 1), dtype=torch.long)
        g.species_id = torch.tensor([-1], dtype=torch.long)
        g.y = label.view(1, -1)
        pyg_graph_list.append(g)
    return pyg_graph_list


# -------- convert dense adj to sparse with address: inaccurate with external edges --------
def convert_dense_adj_to_sparse_with_attr(adj, node_mask):
    """Turn masked dense adjacency rows into an edge list with unit attributes.

    Args:
        adj: dense adjacency, presumably ``(B, N, N)``; entries ``> 0.5``
            count as edges.
        node_mask: boolean mask over the first two dims of ``adj`` selecting
            which rows to keep.

    Returns:
        ``(edge_index, edge_attr)``: ``edge_index`` is ``(2, E)`` and
        ``edge_attr`` is an ``(E, 1)`` long tensor of ones on ``adj``'s device.

    NOTE(review): only rows are masked. Row indices are positions in
    ``adj[node_mask]`` (rows stacked across the whole batch), while column
    indices stay local to each graph, and edges into masked-out columns are
    kept. For B > 1 the two endpoints therefore do not share one index
    space — confirm callers expect this.
    """
    kept_rows = adj[node_mask]
    src, dst = (kept_rows > 0.5).nonzero(as_tuple=True)
    unit_attr = torch.ones((src.numel(), 1), dtype=torch.long, device=adj.device)
    return torch.stack((src, dst)), unit_attr
