import torch
import os


class SetDevice(object):
    def __init__(self, device):
        self.device = device

    def __enter__(self):
        if os.environ.get("DISABLE_SET_DEVICE", 0) == 0:
            self.prev_device = torch.get_default_device()
            torch.set_default_device(self.device)

    def __exit__(self, *args):
        if os.environ.get("DISABLE_SET_DEVICE", 0) == 0:
            torch.set_default_device(self.prev_device)
