import torch
import rich
from overparameterized_ensembles.utils.utils import device

# Process-wide torch configuration: turn off autograd tracking and make
# double precision (float64) the default floating-point dtype.
torch.set_grad_enabled(False)
torch.set_default_dtype(torch.double)

# Report the default dtype that is now in effect.
rich.print("Default data type:", torch.get_default_dtype())

# `device` is imported from overparameterized_ensembles.utils.utils;
# it is only reported here, not chosen.
print(f"Using device: {device}")
