import torch


def get_available_device(cfg=None):
    """Return the torch device to run on.

    An explicit request in ``cfg.device`` takes precedence when it is
    ``"cpu"``, ``"mps"``, or an indexed CUDA device such as ``"cuda:1"``.
    Otherwise the device is chosen automatically. This covers no ``cfg``, a
    missing or ``None`` ``device`` attribute, a bare ``"cuda"``, or any other
    value. Automatic selection uses the CUDA GPU with the most free memory
    when CUDA is available, else MPS if available, else CPU.

    Args:
        cfg: Optional config object. Its ``device`` attribute may be a string
            or a ``torch.device``. A missing attribute or ``None`` means
            "choose automatically".

    Returns:
        torch.device: The selected device.
    """
    requested = getattr(cfg, "device", None)
    if requested is not None:
        # Normalise torch.device objects (and other values) to their string
        # form, e.g. "cuda:0" / "cpu", so the checks below never raise.
        requested = str(requested)
        if requested in ("cpu", "mps"):
            return torch.device(requested)
        if requested.startswith("cuda:"):
            print(f"Using GPU {requested} with manual override.")
            return torch.device(requested)
        # Anything else (e.g. bare "cuda") falls through to auto-selection.

    if torch.cuda.is_available():
        # mem_get_info(i) returns (free_bytes, total_bytes); pick the GPU with
        # the most free memory (first index wins on ties).
        best_gpu = max(
            range(torch.cuda.device_count()),
            key=lambda idx: torch.cuda.mem_get_info(idx)[0],
        )
        print(f"Using GPU {best_gpu}")
        return torch.device(f"cuda:{best_gpu}")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")