import numbers

import torch

def extract_input(dataset):
    """Stack the input half of every ``(input, target)`` pair into one tensor.

    Each element of ``dataset`` is unpacked as a two-item pair; the inputs
    are collected in order and joined along a new leading dimension, so the
    result's first axis indexes the samples.

    Args:
        dataset: An iterable of ``(input, target)`` pairs whose inputs are
            tensors. ``torch.stack`` requires them all to share one shape.

    Returns:
        A tensor of shape ``(len(dataset), *input.shape)``.
    """
    features = [sample_input for sample_input, _target in dataset]
    # torch.stack defaults to dim=0, i.e. a new leading "sample" axis.
    return torch.stack(features)

def extract_output(dataset):
    """Collect the target half of every ``(input, target)`` pair into one tensor.

    The kind of the *first* sample's target decides how all targets are
    combined:

    * ``torch.Tensor`` targets are stacked along a new leading dimension.
    * Real scalar targets (``int``, ``bool``, ``float``, and other
      ``numbers.Real`` types) are packed with ``torch.tensor``, which infers
      the dtype (e.g. ``int`` -> ``int64``, ``float`` -> the default float).

    Args:
        dataset: An indexable iterable of ``(input, target)`` pairs.

    Returns:
        A tensor whose first axis indexes the samples.

    Raises:
        IndexError: If ``dataset`` is empty (from ``dataset[0]``).
        TypeError: If the first target is neither a tensor nor a real scalar.
            Previously this case silently returned ``None``.
    """
    first_target = dataset[0][1]
    if isinstance(first_target, torch.Tensor):
        return torch.stack([y for x, y in dataset], dim=0)
    # numbers.Real covers int (and its subclass bool) as before, plus float
    # and NumPy scalar types, which register themselves with `numbers`.
    if isinstance(first_target, numbers.Real):
        return torch.tensor([y for x, y in dataset])
    raise TypeError(
        "extract_output: unsupported target type "
        f"{type(first_target).__name__!r}; expected torch.Tensor or a real scalar"
    )