import json
import os

import numpy as np
import networkx as nx
import torch
from torch._six import container_abcs, string_classes, int_classes
from torch_geometric.data import Data as Graph
from torch_geometric.data.batch import Batch


def place_on_gpu(data, device=0):
    """
    Recursively moves all 'torch.Tensor's and torch_geometric graphs in data to
    `device`.

    Exact `list`/`tuple` instances are rebuilt with the same container type and
    exact `dict` instances are rebuilt with the same keys; every element is
    processed recursively. Tensors and `Batch`/`Graph` objects are moved with
    `.to(device)`. Anything else is returned unchanged.

    Tensors are NOT detached: any autograd history is preserved.

    Args:
        data: a tensor, graph, or (nested) list/tuple/dict containing them.
        device: any device spec accepted by `.to()`; the default 0 is an
            integer device index (torch treats it as a CUDA device ordinal).

    Returns:
        A structure mirroring `data` with tensors/graphs moved to `device`.

    source: inspired by place_on_gpu from Snorkel Metal
    https://github.com/HazyResearch/metal/blob/master/metal/utils.py
    """
    data_type = type(data)
    if data_type in (list, tuple):
        return data_type(place_on_gpu(item, device) for item in data)
    if data_type is dict:
        return {key: place_on_gpu(val, device) for key, val in data.items()}
    if isinstance(data, torch.Tensor):
        return data.to(device)
    # Checked after torch.Tensor so plain tensors are handled without touching
    # the torch_geometric types.
    if isinstance(data, (Batch, Graph)):
        return data.to(device)
    return data

def place_on_cpu(data):
    """
    Recursively moves every 'torch.Tensor' in data to the CPU and detaches it
    from the autograd graph.

    Exact `list`/`tuple` instances are rebuilt with the same container type and
    exact `dict` instances are rebuilt with the same keys, processing each
    element recursively. Any other object is returned unchanged.
    source: inspired by place_on_gpu from Snorkel Metal
    https://github.com/HazyResearch/metal/blob/master/metal/utils.py
    """
    container = type(data)
    if container is list or container is tuple:
        return container(place_on_cpu(item) for item in data)
    if container is dict:
        return {name: place_on_cpu(value) for name, value in data.items()}
    if not isinstance(data, torch.Tensor):
        # Unrecognised leaf: pass it through untouched.
        return data
    return data.cpu().detach()

def batched_index_select(inputs: torch.Tensor, dim: int, index: torch.Tensor):
    """Performs an index select across a batch of tensors where each batch element
    can select a different set of indices along ``dim``.

    ``out[b, ..., j, ...] = inputs[b, ..., index[b, j], ...]`` where the ``j``
    position is ``dim``.

    Args:
        inputs (torch.Tensor): a tensor [batch_size, n, ...]
        dim (int): the (non-batch) dimension of ``inputs`` along which to select.
            Negative values count from the last dimension.
        index (torch.Tensor): [batch_size, num_select]; must be an int64 tensor,
            as required by ``torch.gather``.

    Returns:
        torch.Tensor: [batch_size, ..., num_select, ...] -- the shape of
        ``inputs`` with size ``num_select`` at ``dim``.
    """
    if dim < 0:
        # Normalize so the ``i == dim`` test below can match; otherwise a
        # negative dim leaves no -1 in ``views`` and ``view`` fails.
        dim += inputs.dim()
    # Shape index as [batch_size, 1, ..., num_select (at dim), ..., 1] ...
    views = [inputs.shape[0]] + [-1 if i == dim else 1 for i in range(1, inputs.dim())]
    # ... then broadcast it across every other non-batch axis of inputs.
    expanse = list(inputs.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.view(views).expand(expanse)
    return torch.gather(inputs, dim, index)
