import torch
import numpy as np
from typing import Union, Callable
from torch._six import string_classes


def pin_memory(data):
    """Recursively copy tensors in a (possibly nested) batch into pinned memory.

    Containers are traversed and rebuilt:

    * ``torch.Tensor`` -> ``tensor.pin_memory()``.
    * ``str`` / ``bytes`` -> returned unchanged.
    * ``dict`` -> new ``dict`` with every value processed recursively.
    * namedtuple -> same namedtuple type with every field processed.
    * plain ``tuple`` -> new ``tuple`` with every element processed.
    * ``list`` -> elements are processed. The result is then stacked along a
      new leading axis when it is non-empty and all elements are
      ``torch.Tensor`` (or all ``np.ndarray``) of one shape. Tensors become
      one pinned tensor and ndarrays become one ndarray. Otherwise the
      processed list is returned.
    * any other object with a ``pin_memory`` method -> ``obj.pin_memory()``.
    * anything else -> returned unchanged.

    Args:
        data: The batch (or a piece of it) to pin.

    Returns:
        A structure mirroring ``data`` with tensors pinned, and with uniform
        lists of tensors/ndarrays stacked as described above.
    """
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    # (str, bytes) is what torch._six.string_classes was on Python 3; that
    # private module was removed in PyTorch 2.x.
    if isinstance(data, (str, bytes)):
        return data
    if isinstance(data, dict):
        return {k: pin_memory(sample) for k, sample in data.items()}
    if isinstance(data, tuple):
        if hasattr(data, "_fields"):  # namedtuple
            return type(data)(*(pin_memory(sample) for sample in data))
        return tuple(pin_memory(sample) for sample in data)
    if isinstance(data, list):
        pinned = [pin_memory(sample) for sample in data]
        # Stack only non-empty lists whose elements all share one array type
        # and one shape; lists of scalars, strings, dicts or mixed array
        # types are returned as plain lists instead of raising.
        if pinned:
            first = pinned[0]
            if isinstance(first, torch.Tensor):
                if all(
                    isinstance(item, torch.Tensor) and item.shape == first.shape
                    for item in pinned
                ):
                    # torch.stack allocates a new (pageable) tensor, so the
                    # per-element pinning is lost; pin the stacked result.
                    return torch.stack(pinned).pin_memory()
            elif isinstance(first, np.ndarray):
                if all(
                    isinstance(item, np.ndarray) and item.shape == first.shape
                    for item in pinned
                ):
                    return np.stack(pinned)
        return pinned
    if hasattr(data, "pin_memory"):
        return data.pin_memory()
    return data
