import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np


def boolean_string(s):
    """Parse a strict boolean string into a ``bool``.

    Only the exact strings ``'True'``/``'1'`` (-> ``True``) and
    ``'False'``/``'0'`` (-> ``False``) are accepted; matching is
    case-sensitive.

    Raises:
        ValueError: if ``s`` is any other value.
    """
    lookup = {'True': True, '1': True, 'False': False, '0': False}
    try:
        return lookup[s]
    except KeyError:
        raise ValueError('Not a valid boolean string') from None

def binarize(tens, thresh=0.5):
    """Return a binarized copy of ``tens``: 1 where ``tens >= thresh``, 0 where ``tens < thresh``.

    A ``torch.Tensor`` is cloned; anything else is copied with ``np.copy``
    (so lists etc. come back as NumPy arrays). The input is never modified.
    Elements for which neither comparison holds (e.g. NaN) are left as-is.

    Args:
        tens: tensor / array-like to threshold.
        thresh: threshold; values equal to it map to 1.

    Returns:
        The thresholded copy, same container type as produced by clone/copy.
    """
    if isinstance(tens, torch.Tensor):
        tens = tens.clone()
    else:
        tens = np.copy(tens)
    # Compute both masks before writing. Zeroing first and then testing
    # ``>= thresh`` would flip the new zeros to 1 whenever thresh <= 0.
    below = tens < thresh
    at_or_above = tens >= thresh
    tens[below] = 0.
    tens[at_or_above] = 1.
    return tens

def get_dev(argdev=None):
    """Resolve a device specification into a ``torch.device``.

    Args:
        argdev: ``None``, ``"cpu"``, a CUDA spec such as ``"cuda"`` or
            ``"cuda:1"``, or a ``torch.device`` (normalized via ``str``).

    Returns:
        If CUDA is unavailable: always the CPU device, whatever was requested.
        If CUDA is available: the CPU device for ``None`` or ``"cpu"`` (each
        prints a notice), otherwise the requested CUDA device.

    Raises:
        ValueError: if CUDA is available and ``argdev`` is not recognized.
    """
    if not torch.cuda.is_available():
        # No GPU: deliberately fall back to CPU for any request.
        return torch.device("cpu")
    # Handle None before any string test. Previously ``"cuda" in None`` was
    # evaluated first and raised TypeError for the default argument.
    if argdev is None:
        print("use cpu by default")
        return torch.device("cpu")
    argdev = str(argdev)
    if argdev == "cpu":
        print("Warning: using cpu while cuda is available")
        return torch.device("cpu")
    if argdev.startswith("cuda"):
        return torch.device(argdev)
    raise ValueError(f"unrecognizable device {argdev}")

def get_full_dataloader(bs, is_shuffle=False, root="./data"):
    """Build a DataLoader over the MNIST *test* split.

    Despite the name, only the test split (``train=False``) is loaded. Images
    go through ``transforms.ToTensor()`` only. The dataset is downloaded into
    ``root`` if not already present (network/disk side effect).

    Args:
        bs: batch size.
        is_shuffle: whether the loader shuffles samples each epoch.
        root: directory where MNIST is stored/downloaded (default ``"./data"``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the MNIST test set.
    """
    test_trans = transforms.Compose([transforms.ToTensor()])
    test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=test_trans)
    return DataLoader(test_dataset, batch_size=bs, shuffle=is_shuffle)